diff --git a/include/net/9p/transport.h b/include/net/9p/transport.h index 2a25dec30211..5122b5e40f78 100644 --- a/include/net/9p/transport.h +++ b/include/net/9p/transport.h @@ -61,7 +61,7 @@ struct p9_trans_module { int (*cancel) (struct p9_client *, struct p9_req_t *req); int (*cancelled)(struct p9_client *, struct p9_req_t *req); int (*zc_request)(struct p9_client *, struct p9_req_t *, - char *, char *, int , int, int, int); + struct iov_iter *, struct iov_iter *, int , int, int); }; void v9fs_register_trans(struct p9_trans_module *m); diff --git a/net/9p/client.c b/net/9p/client.c index e86a9bea1d16..9ef5d85f082f 100644 --- a/net/9p/client.c +++ b/net/9p/client.c @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -555,7 +556,7 @@ out_err: */ static int p9_check_zc_errors(struct p9_client *c, struct p9_req_t *req, - char *uidata, int in_hdrlen, int kern_buf) + struct iov_iter *uidata, int in_hdrlen) { int err; int ecode; @@ -591,16 +592,11 @@ static int p9_check_zc_errors(struct p9_client *c, struct p9_req_t *req, ename = &req->rc->sdata[req->rc->offset]; if (len > inline_len) { /* We have error in external buffer */ - if (kern_buf) { - memcpy(ename + inline_len, uidata, - len - inline_len); - } else { - err = copy_from_user(ename + inline_len, - uidata, len - inline_len); - if (err) { - err = -EFAULT; - goto out_err; - } + err = copy_from_iter(ename + inline_len, + len - inline_len, uidata); + if (err != len - inline_len) { + err = -EFAULT; + goto out_err; } } ename = NULL; @@ -806,8 +802,8 @@ reterr: * p9_client_zc_rpc - issue a request and wait for a response * @c: client session * @type: type of request - * @uidata: user bffer that should be ued for zero copy read - * @uodata: user buffer that shoud be user for zero copy write + * @uidata: destination for zero copy read + * @uodata: source for zero copy write * @inlen: read buffer size * @olen: write buffer size * @hdrlen: reader header size, This is the size of response protocol data @@ -816,9 +812,10 @@ reterr: * Returns request structure (which client must free using p9_free_req) */ static struct p9_req_t *p9_client_zc_rpc(struct p9_client *c, int8_t type, - char *uidata, char *uodata, + struct iov_iter *uidata, + struct iov_iter *uodata, int inlen, int olen, int in_hdrlen, - int kern_buf, const char *fmt, ...) + const char *fmt, ...) { va_list ap; int sigpending, err; @@ -841,12 +838,8 @@ static struct p9_req_t *p9_client_zc_rpc(struct p9_client *c, int8_t type, } else sigpending = 0; - /* If we are called with KERNEL_DS force kern_buf */ - if (segment_eq(get_fs(), KERNEL_DS)) - kern_buf = 1; - err = c->trans_mod->zc_request(c, req, uidata, uodata, - inlen, olen, in_hdrlen, kern_buf); + inlen, olen, in_hdrlen); if (err < 0) { if (err == -EIO) c->status = Disconnected; @@ -876,7 +869,7 @@ static struct p9_req_t *p9_client_zc_rpc(struct p9_client *c, int8_t type, if (err < 0) goto reterr; - err = p9_check_zc_errors(c, req, uidata, in_hdrlen, kern_buf); + err = p9_check_zc_errors(c, req, uidata, in_hdrlen); trace_9p_client_res(c, type, req->rc->tag, err); if (!err) return req; @@ -1545,11 +1538,24 @@ p9_client_read(struct p9_fid *fid, char *data, char __user *udata, u64 offset, u32 count) { char *dataptr; - int kernel_buf = 0; struct p9_req_t *req; struct p9_client *clnt; int err, rsize, non_zc = 0; + struct iov_iter to; + union { + struct kvec kv; + struct iovec iov; + } v; + if (data) { + v.kv.iov_base = data; + v.kv.iov_len = count; + iov_iter_kvec(&to, ITER_KVEC | READ, &v.kv, 1, count); + } else { + v.iov.iov_base = udata; + v.iov.iov_len = count; + iov_iter_init(&to, READ, &v.iov, 1, count); + } p9_debug(P9_DEBUG_9P, ">>> TREAD fid %d offset %llu %d\n", fid->fid, (unsigned long long) offset, count); @@ -1565,18 +1571,12 @@ p9_client_read(struct p9_fid *fid, char *data, char __user *udata, u64 offset, /* Don't bother zerocopy for small IO (< 1024) */ if (clnt->trans_mod->zc_request && rsize > 1024) { - char *indata; - if (data) { - kernel_buf = 1; - indata = data; - } else - indata = (__force char *)udata; /* * response header len is 11 * PDU Header(7) + IO Size (4) */ - req = p9_client_zc_rpc(clnt, P9_TREAD, indata, NULL, rsize, 0, - 11, kernel_buf, "dqd", fid->fid, + req = p9_client_zc_rpc(clnt, P9_TREAD, &to, NULL, rsize, 0, + 11, "dqd", fid->fid, offset, rsize); } else { non_zc = 1; @@ -1596,16 +1596,9 @@ p9_client_read(struct p9_fid *fid, char *data, char __user *udata, u64 offset, p9_debug(P9_DEBUG_9P, "<<< RREAD count %d\n", count); - if (non_zc) { - if (data) { - memmove(data, dataptr, count); - } else { - err = copy_to_user(udata, dataptr, count); - if (err) { - err = -EFAULT; - goto free_and_error; - } - } + if (non_zc && copy_to_iter(dataptr, count, &to) != count) { + err = -EFAULT; + goto free_and_error; } p9_free_req(clnt, req); return count; @@ -1622,9 +1615,23 @@ p9_client_write(struct p9_fid *fid, char *data, const char __user *udata, u64 offset, u32 count) { int err, rsize; - int kernel_buf = 0; struct p9_client *clnt; struct p9_req_t *req; + struct iov_iter from; + union { + struct kvec kv; + struct iovec iov; + } v; + + if (data) { + v.kv.iov_base = data; + v.kv.iov_len = count; + iov_iter_kvec(&from, ITER_KVEC | WRITE, &v.kv, 1, count); + } else { + v.iov.iov_base = udata; + v.iov.iov_len = count; + iov_iter_init(&from, WRITE, &v.iov, 1, count); + } p9_debug(P9_DEBUG_9P, ">>> TWRITE fid %d offset %llu count %d\n", fid->fid, (unsigned long long) offset, count); @@ -1640,22 +1647,12 @@ p9_client_write(struct p9_fid *fid, char *data, const char __user *udata, /* Don't bother zerocopy for small IO (< 1024) */ if (clnt->trans_mod->zc_request && rsize > 1024) { - char *odata; - if (data) { - kernel_buf = 1; - odata = data; - } else - odata = (char *)udata; - req = p9_client_zc_rpc(clnt, P9_TWRITE, NULL, odata, 0, rsize, - P9_ZC_HDR_SZ, kernel_buf, "dqd", + req = p9_client_zc_rpc(clnt, P9_TWRITE, NULL, &from, 0, rsize, + P9_ZC_HDR_SZ, "dqd", fid->fid, offset, rsize); } else { - if (data) - req = p9_client_rpc(clnt, P9_TWRITE, "dqD", fid->fid, - offset, rsize, data); - else - req = p9_client_rpc(clnt, P9_TWRITE, "dqU", fid->fid, - offset, rsize, udata); + req = p9_client_rpc(clnt, P9_TWRITE, "dqV", fid->fid, + offset, rsize, &from); } if (IS_ERR(req)) { err = PTR_ERR(req); @@ -2068,6 +2065,10 @@ int p9_client_readdir(struct p9_fid *fid, char *data, u32 count, u64 offset) struct p9_client *clnt; struct p9_req_t *req; char *dataptr; + struct kvec kv = {.iov_base = data, .iov_len = count}; + struct iov_iter to; + + iov_iter_kvec(&to, READ | ITER_KVEC, &kv, 1, count); p9_debug(P9_DEBUG_9P, ">>> TREADDIR fid %d offset %llu count %d\n", fid->fid, (unsigned long long) offset, count); @@ -2088,8 +2089,8 @@ int p9_client_readdir(struct p9_fid *fid, char *data, u32 count, u64 offset) * response header len is 11 * PDU Header(7) + IO Size (4) */ - req = p9_client_zc_rpc(clnt, P9_TREADDIR, data, NULL, rsize, 0, - 11, 1, "dqd", fid->fid, offset, rsize); + req = p9_client_zc_rpc(clnt, P9_TREADDIR, &to, NULL, rsize, 0, + 11, "dqd", fid->fid, offset, rsize); } else { non_zc = 1; req = p9_client_rpc(clnt, P9_TREADDIR, "dqd", fid->fid, diff --git a/net/9p/protocol.c b/net/9p/protocol.c index ab9127ec5b7a..e9d0f0c1a048 100644 --- a/net/9p/protocol.c +++ b/net/9p/protocol.c @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include "protocol.h" @@ -69,10 +70,11 @@ static size_t pdu_write(struct p9_fcall *pdu, const void *data, size_t size) } static size_t -pdu_write_u(struct p9_fcall *pdu, const char __user *udata, size_t size) +pdu_write_u(struct p9_fcall *pdu, struct iov_iter *from, size_t size) { size_t len = min(pdu->capacity - pdu->size, size); - if (copy_from_user(&pdu->sdata[pdu->size], udata, len)) + struct iov_iter i = *from; + if (copy_from_iter(&pdu->sdata[pdu->size], len, &i) != len) len = 0; pdu->size += len; @@ -437,23 +439,13 @@ p9pdu_vwritef(struct p9_fcall *pdu, int proto_version, const char *fmt, stbuf->extension, stbuf->n_uid, stbuf->n_gid, stbuf->n_muid); } break; - case 'D':{ - uint32_t count = va_arg(ap, uint32_t); - const void *data = va_arg(ap, const void *); - - errcode = p9pdu_writef(pdu, proto_version, "d", - count); - if (!errcode && pdu_write(pdu, data, count)) - errcode = -EFAULT; - } - break; - case 'U':{ + case 'V':{ int32_t count = va_arg(ap, int32_t); - const char __user *udata = - va_arg(ap, const void __user *); + struct iov_iter *from = + va_arg(ap, struct iov_iter *); errcode = p9pdu_writef(pdu, proto_version, "d", count); - if (!errcode && pdu_write_u(pdu, udata, count)) + if (!errcode && pdu_write_u(pdu, from, count)) errcode = -EFAULT; } break; diff --git a/net/9p/trans_virtio.c b/net/9p/trans_virtio.c index 36a1a739ad68..e62bcbbabb5e 100644 --- a/net/9p/trans_virtio.c +++ b/net/9p/trans_virtio.c @@ -217,15 +217,15 @@ static int p9_virtio_cancel(struct p9_client *client, struct p9_req_t *req) * @start: which segment of the sg_list to start at * @pdata: a list of pages to add into sg. * @nr_pages: number of pages to pack into the scatter/gather list - * @data: data to pack into scatter/gather list + * @offs: amount of data in the beginning of first page _not_ to pack * @count: amount of data to pack into the scatter/gather list */ static int pack_sg_list_p(struct scatterlist *sg, int start, int limit, - struct page **pdata, int nr_pages, char *data, int count) + struct page **pdata, int nr_pages, size_t offs, int count) { int i = 0, s; - int data_off; + int data_off = offs; int index = start; BUG_ON(nr_pages > (limit - start)); @@ -233,16 +233,14 @@ pack_sg_list_p(struct scatterlist *sg, int start, int limit, * if the first page doesn't start at * page boundary find the offset */ - data_off = offset_in_page(data); while (nr_pages) { - s = rest_of_page(data); + s = PAGE_SIZE - data_off; if (s > count) s = count; /* Make sure we don't terminate early. */ sg_unmark_end(&sg[index]); sg_set_page(&sg[index++], pdata[i++], s, data_off); data_off = 0; - data += s; count -= s; nr_pages--; } @@ -314,11 +312,20 @@ req_retry: } static int p9_get_mapped_pages(struct virtio_chan *chan, - struct page **pages, char *data, - int nr_pages, int write, int kern_buf) + struct page ***pages, + struct iov_iter *data, + int count, + size_t *offs, + int *need_drop) { + int nr_pages; int err; - if (!kern_buf) { + + if (!iov_iter_count(data)) + return 0; + + if (!(data->type & ITER_KVEC)) { + int n; /* * We allow only p9_max_pages pinned. We wait for the * Other zc request to finish here @@ -329,26 +336,49 @@ static int p9_get_mapped_pages(struct virtio_chan *chan, if (err == -ERESTARTSYS) return err; } - err = p9_payload_gup(data, &nr_pages, pages, write); - if (err < 0) - return err; + n = iov_iter_get_pages_alloc(data, pages, count, offs); + if (n < 0) + return n; + *need_drop = 1; + nr_pages = DIV_ROUND_UP(n + *offs, PAGE_SIZE); atomic_add(nr_pages, &vp_pinned); + return n; } else { /* kernel buffer, no need to pin pages */ - int s, index = 0; - int count = nr_pages; - while (nr_pages) { - s = rest_of_page(data); - if (is_vmalloc_addr(data)) - pages[index++] = vmalloc_to_page(data); - else - pages[index++] = kmap_to_page(data); - data += s; - nr_pages--; + int index; + size_t len; + void *p; + + /* we'd already checked that it's non-empty */ + while (1) { + len = iov_iter_single_seg_count(data); + if (likely(len)) { + p = data->kvec->iov_base + data->iov_offset; + break; + } + iov_iter_advance(data, 0); } - nr_pages = count; + if (len > count) + len = count; + + nr_pages = DIV_ROUND_UP((unsigned long)p + len, PAGE_SIZE) - + (unsigned long)p / PAGE_SIZE; + + *pages = kmalloc(sizeof(struct page *) * nr_pages, GFP_NOFS); + if (!*pages) + return -ENOMEM; + + *need_drop = 0; + p -= (*offs = (unsigned long)p % PAGE_SIZE); + for (index = 0; index < nr_pages; index++) { + if (is_vmalloc_addr(p)) + (*pages)[index] = vmalloc_to_page(p); + else + (*pages)[index] = kmap_to_page(p); + p += PAGE_SIZE; + } + return len; } - return nr_pages; } /** @@ -364,8 +394,8 @@ static int p9_get_mapped_pages(struct virtio_chan *chan, */ static int p9_virtio_zc_request(struct p9_client *client, struct p9_req_t *req, - char *uidata, char *uodata, int inlen, - int outlen, int in_hdr_len, int kern_buf) + struct iov_iter *uidata, struct iov_iter *uodata, + int inlen, int outlen, int in_hdr_len) { int in, out, err, out_sgs, in_sgs; unsigned long flags; @@ -373,41 +403,32 @@ p9_virtio_zc_request(struct p9_client *client, struct p9_req_t *req, struct page **in_pages = NULL, **out_pages = NULL; struct virtio_chan *chan = client->trans; struct scatterlist *sgs[4]; + size_t offs; + int need_drop = 0; p9_debug(P9_DEBUG_TRANS, "virtio request\n"); if (uodata) { - out_nr_pages = p9_nr_pages(uodata, outlen); - out_pages = kmalloc(sizeof(struct page *) * out_nr_pages, - GFP_NOFS); - if (!out_pages) { - err = -ENOMEM; - goto err_out; + int n = p9_get_mapped_pages(chan, &out_pages, uodata, + outlen, &offs, &need_drop); + if (n < 0) + return n; + out_nr_pages = DIV_ROUND_UP(n + offs, PAGE_SIZE); + if (n != outlen) { + __le32 v = cpu_to_le32(n); + memcpy(&req->tc->sdata[req->tc->size - 4], &v, 4); + outlen = n; } - out_nr_pages = p9_get_mapped_pages(chan, out_pages, uodata, - out_nr_pages, 0, kern_buf); - if (out_nr_pages < 0) { - err = out_nr_pages; - kfree(out_pages); - out_pages = NULL; - goto err_out; - } - } - if (uidata) { - in_nr_pages = p9_nr_pages(uidata, inlen); - in_pages = kmalloc(sizeof(struct page *) * in_nr_pages, - GFP_NOFS); - if (!in_pages) { - err = -ENOMEM; - goto err_out; - } - in_nr_pages = p9_get_mapped_pages(chan, in_pages, uidata, - in_nr_pages, 1, kern_buf); - if (in_nr_pages < 0) { - err = in_nr_pages; - kfree(in_pages); - in_pages = NULL; - goto err_out; + } else if (uidata) { + int n = p9_get_mapped_pages(chan, &in_pages, uidata, + inlen, &offs, &need_drop); + if (n < 0) + return n; + in_nr_pages = DIV_ROUND_UP(n + offs, PAGE_SIZE); + if (n != inlen) { + __le32 v = cpu_to_le32(n); + memcpy(&req->tc->sdata[req->tc->size - 4], &v, 4); + inlen = n; } } req->status = REQ_STATUS_SENT; @@ -426,7 +447,7 @@ req_retry_pinned: if (out_pages) { sgs[out_sgs++] = chan->sg + out; out += pack_sg_list_p(chan->sg, out, VIRTQUEUE_NUM, - out_pages, out_nr_pages, uodata, outlen); + out_pages, out_nr_pages, offs, outlen); } /* @@ -444,7 +465,7 @@ req_retry_pinned: if (in_pages) { sgs[out_sgs + in_sgs++] = chan->sg + out + in; in += pack_sg_list_p(chan->sg, out + in, VIRTQUEUE_NUM, - in_pages, in_nr_pages, uidata, inlen); + in_pages, in_nr_pages, offs, inlen); } BUG_ON(out_sgs + in_sgs > ARRAY_SIZE(sgs)); @@ -478,7 +499,7 @@ req_retry_pinned: * Non kernel buffers are pinned, unpin them */ err_out: - if (!kern_buf) { + if (need_drop) { if (in_pages) { p9_release_pages(in_pages, in_nr_pages); atomic_sub(in_nr_pages, &vp_pinned);