um: generalize os_rcv_fd

Change os_rcv_fd() to os_rcv_fd_msg() that can more generally
receive any number of FDs in any kind of message.

Link: https://patch.msgid.link/20240702192118.40b78b2bfe4e.Ic6ec12d72630e5bcae1e597d6bd5c6f29f441563@changeid
Signed-off-by: Johannes Berg <johannes.berg@intel.com>
This commit is contained in:
Johannes Berg 2024-07-02 19:21:19 +02:00
parent 6555acdefc
commit 5cde6096a4
6 changed files with 54 additions and 41 deletions

View File

@ -45,15 +45,17 @@ struct connection {
static irqreturn_t pipe_interrupt(int irq, void *data)
{
struct connection *conn = data;
int fd;
int n_fds = 1, fd = -1;
ssize_t ret;
fd = os_rcv_fd(conn->socket[0], &conn->helper_pid);
if (fd < 0) {
if (fd == -EAGAIN)
ret = os_rcv_fd_msg(conn->socket[0], &fd, n_fds, &conn->helper_pid,
sizeof(conn->helper_pid));
if (ret != sizeof(conn->helper_pid)) {
if (ret == -EAGAIN)
return IRQ_NONE;
printk(KERN_ERR "pipe_interrupt : os_rcv_fd returned %d\n",
-fd);
printk(KERN_ERR "pipe_interrupt : os_rcv_fd_msg returned %zd\n",
ret);
os_close_file(conn->fd);
}

View File

@ -156,7 +156,7 @@ static int xterm_open(int input, int output, int primary, void *d,
new = xterm_fd(fd, &data->helper_pid);
if (new < 0) {
err = new;
printk(UM_KERN_ERR "xterm_open : os_rcv_fd failed, err = %d\n",
printk(UM_KERN_ERR "xterm_open : xterm_fd failed, err = %d\n",
-err);
goto out_kill;
}

View File

@ -21,12 +21,19 @@ struct xterm_wait {
static irqreturn_t xterm_interrupt(int irq, void *data)
{
struct xterm_wait *xterm = data;
int fd;
int fd = -1, n_fds = 1;
ssize_t ret;
fd = os_rcv_fd(xterm->fd, &xterm->pid);
if (fd == -EAGAIN)
ret = os_rcv_fd_msg(xterm->fd, &fd, n_fds,
&xterm->pid, sizeof(xterm->pid));
if (ret == -EAGAIN)
return IRQ_NONE;
if (ret < 0)
fd = ret;
else if (ret != sizeof(xterm->pid))
fd = -EMSGSIZE;
xterm->new_fd = fd;
complete(&xterm->ready);

View File

@ -165,7 +165,8 @@ extern int os_create_unix_socket(const char *file, int len, int close_on_exec);
extern int os_shutdown_socket(int fd, int r, int w);
extern int os_dup_file(int fd);
extern void os_close_file(int fd);
extern int os_rcv_fd(int fd, int *helper_pid_out);
ssize_t os_rcv_fd_msg(int fd, int *fds, unsigned int n_fds,
void *data, size_t data_len);
extern int os_connect_socket(const char *name);
extern int os_file_type(char *file);
extern int os_file_mode(const char *file, struct openflags *mode_out);

View File

@ -33,7 +33,7 @@ EXPORT_SYMBOL(os_shutdown_socket);
EXPORT_SYMBOL(os_create_unix_socket);
EXPORT_SYMBOL(os_connect_socket);
EXPORT_SYMBOL(os_accept_connection);
EXPORT_SYMBOL(os_rcv_fd);
EXPORT_SYMBOL(os_rcv_fd_msg);
EXPORT_SYMBOL(run_helper);
EXPORT_SYMBOL(os_major);
EXPORT_SYMBOL(os_minor);

View File

@ -512,44 +512,47 @@ int os_shutdown_socket(int fd, int r, int w)
return 0;
}
int os_rcv_fd(int fd, int *helper_pid_out)
/**
* os_rcv_fd_msg - receive message with (optional) FDs
* @fd: the FD to receive from
* @fds: the array for FDs to write to
* @n_fds: number of FDs to receive (@fds array size)
* @data: the message buffer
* @data_len: the size of the message to receive
*
* Receive a message with FDs.
*
* Returns: the size of the received message, or an error code
*/
ssize_t os_rcv_fd_msg(int fd, int *fds, unsigned int n_fds,
void *data, size_t data_len)
{
int new, n;
char buf[CMSG_SPACE(sizeof(new))];
struct msghdr msg;
char buf[CMSG_SPACE(sizeof(*fds) * n_fds)];
struct cmsghdr *cmsg;
struct iovec iov;
msg.msg_name = NULL;
msg.msg_namelen = 0;
iov = ((struct iovec) { .iov_base = helper_pid_out,
.iov_len = sizeof(*helper_pid_out) });
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
msg.msg_control = buf;
msg.msg_controllen = sizeof(buf);
msg.msg_flags = 0;
struct iovec iov = {
.iov_base = data,
.iov_len = data_len,
};
struct msghdr msg = {
.msg_iov = &iov,
.msg_iovlen = 1,
.msg_control = buf,
.msg_controllen = sizeof(buf),
};
int n;
n = recvmsg(fd, &msg, 0);
if (n < 0)
return -errno;
else if (n != iov.iov_len)
*helper_pid_out = -1;
cmsg = CMSG_FIRSTHDR(&msg);
if (cmsg == NULL) {
printk(UM_KERN_ERR "rcv_fd didn't receive anything, "
"error = %d\n", errno);
return -1;
}
if ((cmsg->cmsg_level != SOL_SOCKET) ||
(cmsg->cmsg_type != SCM_RIGHTS)) {
printk(UM_KERN_ERR "rcv_fd didn't receive a descriptor\n");
return -1;
}
if (!cmsg ||
cmsg->cmsg_level != SOL_SOCKET ||
cmsg->cmsg_type != SCM_RIGHTS)
return n;
new = ((int *) CMSG_DATA(cmsg))[0];
return new;
memcpy(fds, CMSG_DATA(cmsg), cmsg->cmsg_len);
return n;
}
int os_create_unix_socket(const char *file, int len, int close_on_exec)