diff --git a/io_uring/io_uring.c b/io_uring/io_uring.c index 1b53a2ab0a27..e8096d502a7c 100644 --- a/io_uring/io_uring.c +++ b/io_uring/io_uring.c @@ -149,7 +149,6 @@ static bool io_uring_try_cancel_requests(struct io_ring_ctx *ctx, static void io_queue_sqe(struct io_kiocb *req); static void io_move_task_work_from_local(struct io_ring_ctx *ctx); static void __io_submit_flush_completions(struct io_ring_ctx *ctx); -static __cold void io_fallback_tw(struct io_uring_task *tctx); struct kmem_cache *req_cachep; @@ -1238,6 +1237,34 @@ static inline struct llist_node *io_llist_cmpxchg(struct llist_head *head, return cmpxchg(&head->first, old, new); } +static __cold void io_fallback_tw(struct io_uring_task *tctx, bool sync) +{ + struct llist_node *node = llist_del_all(&tctx->task_list); + struct io_ring_ctx *last_ctx = NULL; + struct io_kiocb *req; + + while (node) { + req = container_of(node, struct io_kiocb, io_task_work.node); + node = node->next; + if (sync && last_ctx != req->ctx) { + if (last_ctx) { + flush_delayed_work(&last_ctx->fallback_work); + percpu_ref_put(&last_ctx->refs); + } + last_ctx = req->ctx; + percpu_ref_get(&last_ctx->refs); + } + if (llist_add(&req->io_task_work.node, + &req->ctx->fallback_llist)) + schedule_delayed_work(&req->ctx->fallback_work, 1); + } + + if (last_ctx) { + flush_delayed_work(&last_ctx->fallback_work); + percpu_ref_put(&last_ctx->refs); + } +} + void tctx_task_work(struct callback_head *cb) { struct io_tw_state ts = {}; @@ -1250,7 +1277,7 @@ void tctx_task_work(struct callback_head *cb) unsigned int count = 0; if (unlikely(current->flags & PF_EXITING)) { - io_fallback_tw(tctx); + io_fallback_tw(tctx, true); return; } @@ -1279,20 +1306,6 @@ void tctx_task_work(struct callback_head *cb) trace_io_uring_task_work_run(tctx, count, loops); } -static __cold void io_fallback_tw(struct io_uring_task *tctx) -{ - struct llist_node *node = llist_del_all(&tctx->task_list); - struct io_kiocb *req; - - while (node) { - req = container_of(node, struct io_kiocb, io_task_work.node); - node = node->next; - if (llist_add(&req->io_task_work.node, - &req->ctx->fallback_llist)) - schedule_delayed_work(&req->ctx->fallback_work, 1); - } -} - static inline void io_req_local_work_add(struct io_kiocb *req, unsigned flags) { struct io_ring_ctx *ctx = req->ctx; @@ -1359,7 +1372,7 @@ static void io_req_normal_work_add(struct io_kiocb *req) if (likely(!task_work_add(req->task, &tctx->task_work, ctx->notify_method))) return; - io_fallback_tw(tctx); + io_fallback_tw(tctx, false); } void __io_req_task_work_add(struct io_kiocb *req, unsigned flags) @@ -3109,6 +3122,8 @@ static __cold void io_ring_ctx_wait_and_kill(struct io_ring_ctx *ctx) if (ctx->rings) io_kill_timeouts(ctx, NULL, true); + flush_delayed_work(&ctx->fallback_work); + INIT_WORK(&ctx->exit_work, io_ring_exit_work); /* * Use system_unbound_wq to avoid spawning tons of event kworkers diff --git a/io_uring/net.c b/io_uring/net.c index d7e6efe89f48..eb1f51ddcb23 100644 --- a/io_uring/net.c +++ b/io_uring/net.c @@ -631,7 +631,7 @@ static inline bool io_recv_finish(struct io_kiocb *req, int *ret, unsigned int cflags; cflags = io_put_kbuf(req, issue_flags); - if (msg->msg_inq && msg->msg_inq != -1U) + if (msg->msg_inq && msg->msg_inq != -1) cflags |= IORING_CQE_F_SOCK_NONEMPTY; if (!(req->flags & REQ_F_APOLL_MULTISHOT)) { @@ -646,7 +646,7 @@ static inline bool io_recv_finish(struct io_kiocb *req, int *ret, io_recv_prep_retry(req); /* Known not-empty or unknown state, retry */ if (cflags & IORING_CQE_F_SOCK_NONEMPTY || - msg->msg_inq == -1U) + msg->msg_inq == -1) return false; if (issue_flags & IO_URING_F_MULTISHOT) *ret = IOU_ISSUE_SKIP_COMPLETE; @@ -805,7 +805,7 @@ int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags) flags |= MSG_DONTWAIT; kmsg->msg.msg_get_inq = 1; - kmsg->msg.msg_inq = -1U; + kmsg->msg.msg_inq = -1; if (req->flags & REQ_F_APOLL_MULTISHOT) { ret = io_recvmsg_multishot(sock, sr, kmsg, flags, &mshot_finished); @@ -903,7 +903,7 @@ int io_recv(struct io_kiocb *req, unsigned int issue_flags) if (unlikely(ret)) goto out_free; - msg.msg_inq = -1U; + msg.msg_inq = -1; msg.msg_flags = 0; flags = sr->msg_flags;