iov_iter: Introduce fault_in_iov_iter_writeable

Introduce a new fault_in_iov_iter_writeable helper for safely faulting
in an iterator for writing.  Uses get_user_pages() to fault in the pages
without actually writing to them, which would be destructive.

We'll use fault_in_iov_iter_writeable in gfs2 once we've determined that
the iterator passed to .read_iter isn't in memory.

Signed-off-by: Andreas Gruenbacher <agruenba@redhat.com>
This commit is contained in:
Andreas Gruenbacher 2021-07-05 17:26:28 +02:00
parent a6294593e8
commit cdd591fc86
4 changed files with 104 additions and 0 deletions

View File

@ -736,6 +736,7 @@ extern void add_page_wait_queue(struct page *page, wait_queue_entry_t *waiter);
* Fault in userspace address range.
*/
size_t fault_in_writeable(char __user *uaddr, size_t size);
size_t fault_in_safe_writeable(const char __user *uaddr, size_t size);
size_t fault_in_readable(const char __user *uaddr, size_t size);
int add_to_page_cache_locked(struct page *page, struct address_space *mapping,

View File

@ -134,6 +134,7 @@ size_t copy_page_from_iter_atomic(struct page *page, unsigned offset,
void iov_iter_advance(struct iov_iter *i, size_t bytes);
void iov_iter_revert(struct iov_iter *i, size_t bytes);
size_t fault_in_iov_iter_readable(const struct iov_iter *i, size_t bytes);
size_t fault_in_iov_iter_writeable(const struct iov_iter *i, size_t bytes);
size_t iov_iter_single_seg_count(const struct iov_iter *i);
size_t copy_page_to_iter(struct page *page, size_t offset, size_t bytes,
struct iov_iter *i);

View File

@ -467,6 +467,45 @@ size_t fault_in_iov_iter_readable(const struct iov_iter *i, size_t size)
}
EXPORT_SYMBOL(fault_in_iov_iter_readable);
/*
* fault_in_iov_iter_writeable - fault in iov iterator for writing
* @i: iterator
* @size: maximum length
*
* Faults in the iterator using get_user_pages(), i.e., without triggering
* hardware page faults. This is primarily useful when we already know that
* some or all of the pages in @i aren't in memory.
*
* Returns the number of bytes not faulted in, like copy_to_user() and
* copy_from_user().
*
* Always returns 0 for non-user-space iterators.
*/
size_t fault_in_iov_iter_writeable(const struct iov_iter *i, size_t size)
{
if (iter_is_iovec(i)) {
size_t count = min(size, iov_iter_count(i));
const struct iovec *p;
size_t skip;
size -= count;
for (p = i->iov, skip = i->iov_offset; count; p++, skip = 0) {
size_t len = min(count, p->iov_len - skip);
size_t ret;
if (unlikely(!len))
continue;
ret = fault_in_safe_writeable(p->iov_base + skip, len);
count -= len - ret;
if (ret)
break;
}
return count + size;
}
return 0;
}
EXPORT_SYMBOL(fault_in_iov_iter_writeable);
void iov_iter_init(struct iov_iter *i, unsigned int direction,
const struct iovec *iov, unsigned long nr_segs,
size_t count)

View File

@ -1691,6 +1691,69 @@ size_t fault_in_writeable(char __user *uaddr, size_t size)
}
EXPORT_SYMBOL(fault_in_writeable);
/*
* fault_in_safe_writeable - fault in an address range for writing
* @uaddr: start of address range
* @size: length of address range
*
* Faults in an address range using get_user_pages, i.e., without triggering
* hardware page faults. This is primarily useful when we already know that
* some or all of the pages in the address range aren't in memory.
*
* Other than fault_in_writeable(), this function is non-destructive.
*
* Note that we don't pin or otherwise hold the pages referenced that we fault
* in. There's no guarantee that they'll stay in memory for any duration of
* time.
*
* Returns the number of bytes not faulted in, like copy_to_user() and
* copy_from_user().
*/
size_t fault_in_safe_writeable(const char __user *uaddr, size_t size)
{
unsigned long start = (unsigned long)untagged_addr(uaddr);
unsigned long end, nstart, nend;
struct mm_struct *mm = current->mm;
struct vm_area_struct *vma = NULL;
int locked = 0;
nstart = start & PAGE_MASK;
end = PAGE_ALIGN(start + size);
if (end < nstart)
end = 0;
for (; nstart != end; nstart = nend) {
unsigned long nr_pages;
long ret;
if (!locked) {
locked = 1;
mmap_read_lock(mm);
vma = find_vma(mm, nstart);
} else if (nstart >= vma->vm_end)
vma = vma->vm_next;
if (!vma || vma->vm_start >= end)
break;
nend = end ? min(end, vma->vm_end) : vma->vm_end;
if (vma->vm_flags & (VM_IO | VM_PFNMAP))
continue;
if (nstart < vma->vm_start)
nstart = vma->vm_start;
nr_pages = (nend - nstart) / PAGE_SIZE;
ret = __get_user_pages_locked(mm, nstart, nr_pages,
NULL, NULL, &locked,
FOLL_TOUCH | FOLL_WRITE);
if (ret <= 0)
break;
nend = nstart + ret * PAGE_SIZE;
}
if (locked)
mmap_read_unlock(mm);
if (nstart == end)
return 0;
return size - min_t(size_t, nstart - start, size);
}
EXPORT_SYMBOL(fault_in_safe_writeable);
/**
* fault_in_readable - fault in userspace address range for reading
* @uaddr: start of user address range