diff --git a/io_uring/net.c b/io_uring/net.c index 021ca2edf44a..fdb69a3fde76 100644 --- a/io_uring/net.c +++ b/io_uring/net.c @@ -59,9 +59,10 @@ struct io_sr_msg { unsigned done_io; unsigned msg_flags; u16 flags; - /* used only for sendzc */ + /* initialised and used only by !msg send variants */ u16 addr_len; void __user *addr; + /* used only for send zerocopy */ struct io_kiocb *notif; }; @@ -180,7 +181,7 @@ static int io_sendmsg_copy_hdr(struct io_kiocb *req, &iomsg->free_iov); } -int io_sendzc_prep_async(struct io_kiocb *req) +int io_send_prep_async(struct io_kiocb *req) { struct io_sr_msg *zc = io_kiocb_to_cmd(req, struct io_sr_msg); struct io_async_msghdr *io; @@ -234,8 +235,14 @@ int io_sendmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe) { struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg); - if (unlikely(sqe->file_index || sqe->addr2)) + if (req->opcode == IORING_OP_SEND) { + if (READ_ONCE(sqe->__pad3[0])) + return -EINVAL; + sr->addr = u64_to_user_ptr(READ_ONCE(sqe->addr2)); + sr->addr_len = READ_ONCE(sqe->addr_len); + } else if (sqe->addr2 || sqe->file_index) { return -EINVAL; + } sr->umsg = u64_to_user_ptr(READ_ONCE(sqe->addr)); sr->len = READ_ONCE(sqe->len); @@ -315,6 +322,7 @@ int io_sendmsg(struct io_kiocb *req, unsigned int issue_flags) int io_send(struct io_kiocb *req, unsigned int issue_flags) { + struct sockaddr_storage __address; struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg); struct msghdr msg; struct iovec iov; @@ -323,9 +331,23 @@ int io_send(struct io_kiocb *req, unsigned int issue_flags) int min_ret = 0; int ret; + if (sr->addr) { + if (req_has_async_data(req)) { + struct io_async_msghdr *io = req->async_data; + + msg.msg_name = &io->addr; + } else { + ret = move_addr_to_kernel(sr->addr, sr->addr_len, &__address); + if (unlikely(ret < 0)) + return ret; + msg.msg_name = (struct sockaddr *)&__address; + } + msg.msg_namelen = sr->addr_len; + } + if (!(req->flags & REQ_F_POLLED) && (sr->flags & IORING_RECVSEND_POLL_FIRST)) - return -EAGAIN; + return io_setup_async_addr(req, &__address, issue_flags); sock = sock_from_file(req->file); if (unlikely(!sock)) @@ -351,13 +373,14 @@ int io_send(struct io_kiocb *req, unsigned int issue_flags) ret = sock_sendmsg(sock, &msg); if (ret < min_ret) { if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK)) - return -EAGAIN; + return io_setup_async_addr(req, &__address, issue_flags); + if (ret > 0 && io_net_retry(sock, flags)) { sr->len -= ret; sr->buf += ret; sr->done_io += ret; req->flags |= REQ_F_PARTIAL_IO; - return -EAGAIN; + return io_setup_async_addr(req, &__address, issue_flags); } if (ret == -ERESTARTSYS) ret = -EINTR; diff --git a/io_uring/net.h b/io_uring/net.h index e7366aac335c..488d4dc7eee2 100644 --- a/io_uring/net.h +++ b/io_uring/net.h @@ -31,12 +31,13 @@ struct io_async_connect { int io_shutdown_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe); int io_shutdown(struct io_kiocb *req, unsigned int issue_flags); -int io_sendzc_prep_async(struct io_kiocb *req); int io_sendmsg_prep_async(struct io_kiocb *req); void io_sendmsg_recvmsg_cleanup(struct io_kiocb *req); int io_sendmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe); int io_sendmsg(struct io_kiocb *req, unsigned int issue_flags); + int io_send(struct io_kiocb *req, unsigned int issue_flags); +int io_send_prep_async(struct io_kiocb *req); int io_recvmsg_prep_async(struct io_kiocb *req); int io_recvmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe); diff --git a/io_uring/opdef.c b/io_uring/opdef.c index 4fbefb7d70c7..849514abd046 100644 --- a/io_uring/opdef.c +++ b/io_uring/opdef.c @@ -316,11 +316,14 @@ const struct io_op_def io_op_defs[] = { .pollout = 1, .audit_skip = 1, .ioprio = 1, + .manual_alloc = 1, .name = "SEND", #if defined(CONFIG_NET) + .async_size = sizeof(struct io_async_msghdr), .prep = io_sendmsg_prep, .issue = io_send, .fail = io_sendrecv_fail, + .prep_async = io_send_prep_async, #else .prep = io_eopnotsupp_prep, #endif @@ -495,7 +498,7 @@ const struct io_op_def io_op_defs[] = { .async_size = sizeof(struct io_async_msghdr), .prep = io_sendzc_prep, .issue = io_sendzc, - .prep_async = io_sendzc_prep_async, + .prep_async = io_send_prep_async, .cleanup = io_sendzc_cleanup, .fail = io_send_zc_fail, #else