io_uring/rsrc: cleanup io_pin_pages()

This function is overly convoluted with a goto error path, and checks
under the mmap_read_lock() that don't need to be at all. Rearrange it
a bit so the checks and errors fall out naturally, rather than needing
to jump around for it.

Signed-off-by: Jens Axboe <axboe@kernel.dk>
This commit is contained in:
Jens Axboe 2023-10-02 18:25:23 -06:00
parent 93b8cc60c3
commit 922a2c78f1

View File

@ -1037,39 +1037,36 @@ struct page **io_pin_pages(unsigned long ubuf, unsigned long len, int *npages)
{ {
unsigned long start, end, nr_pages; unsigned long start, end, nr_pages;
struct page **pages = NULL; struct page **pages = NULL;
int pret, ret = -ENOMEM; int ret;
end = (ubuf + len + PAGE_SIZE - 1) >> PAGE_SHIFT; end = (ubuf + len + PAGE_SIZE - 1) >> PAGE_SHIFT;
start = ubuf >> PAGE_SHIFT; start = ubuf >> PAGE_SHIFT;
nr_pages = end - start; nr_pages = end - start;
WARN_ON(!nr_pages);
pages = kvmalloc_array(nr_pages, sizeof(struct page *), GFP_KERNEL); pages = kvmalloc_array(nr_pages, sizeof(struct page *), GFP_KERNEL);
if (!pages) if (!pages)
goto done; return ERR_PTR(-ENOMEM);
ret = 0;
mmap_read_lock(current->mm); mmap_read_lock(current->mm);
pret = pin_user_pages(ubuf, nr_pages, FOLL_WRITE | FOLL_LONGTERM, ret = pin_user_pages(ubuf, nr_pages, FOLL_WRITE | FOLL_LONGTERM, pages);
pages);
if (pret == nr_pages)
*npages = nr_pages;
else
ret = pret < 0 ? pret : -EFAULT;
mmap_read_unlock(current->mm); mmap_read_unlock(current->mm);
if (ret) {
/* success, mapped all pages */
if (ret == nr_pages) {
*npages = nr_pages;
return pages;
}
/* partial map, or didn't map anything */
if (ret >= 0) {
/* if we did partial map, release any pages we did get */ /* if we did partial map, release any pages we did get */
if (pret > 0) if (ret)
unpin_user_pages(pages, pret); unpin_user_pages(pages, ret);
goto done; ret = -EFAULT;
} }
ret = 0; kvfree(pages);
done: return ERR_PTR(ret);
if (ret < 0) {
kvfree(pages);
pages = ERR_PTR(ret);
}
return pages;
} }
static int io_sqe_buffer_register(struct io_ring_ctx *ctx, struct iovec *iov, static int io_sqe_buffer_register(struct io_ring_ctx *ctx, struct iovec *iov,