diff --git a/virt/kvm/guest_memfd.c b/virt/kvm/guest_memfd.c index 0176089be731..bfe437098b79 100644 --- a/virt/kvm/guest_memfd.c +++ b/virt/kvm/guest_memfd.c @@ -528,33 +528,29 @@ void kvm_gmem_unbind(struct kvm_memory_slot *slot) fput(file); } -int kvm_gmem_get_pfn(struct kvm *kvm, struct kvm_memory_slot *slot, - gfn_t gfn, kvm_pfn_t *pfn, int *max_order) +static int __kvm_gmem_get_pfn(struct file *file, struct kvm_memory_slot *slot, + gfn_t gfn, kvm_pfn_t *pfn, int *max_order, bool prepare) { pgoff_t index = gfn - slot->base_gfn + slot->gmem.pgoff; - struct kvm_gmem *gmem; + struct kvm_gmem *gmem = file->private_data; struct folio *folio; struct page *page; - struct file *file; int r; - file = kvm_gmem_get_file(slot); - if (!file) + if (file != slot->gmem.file) { + WARN_ON_ONCE(slot->gmem.file); return -EFAULT; + } gmem = file->private_data; - if (xa_load(&gmem->bindings, index) != slot) { WARN_ON_ONCE(xa_load(&gmem->bindings, index)); - r = -EIO; - goto out_fput; + return -EIO; } - folio = kvm_gmem_get_folio(file_inode(file), index, true); - if (IS_ERR(folio)) { - r = PTR_ERR(folio); - goto out_fput; - } + folio = kvm_gmem_get_folio(file_inode(file), index, prepare); + if (IS_ERR(folio)) + return PTR_ERR(folio); if (folio_test_hwpoison(folio)) { r = -EHWPOISON; @@ -571,9 +567,21 @@ int kvm_gmem_get_pfn(struct kvm *kvm, struct kvm_memory_slot *slot, out_unlock: folio_unlock(folio); -out_fput: - fput(file); return r; } + +int kvm_gmem_get_pfn(struct kvm *kvm, struct kvm_memory_slot *slot, + gfn_t gfn, kvm_pfn_t *pfn, int *max_order) +{ + struct file *file = kvm_gmem_get_file(slot); + int r; + + if (!file) + return -EFAULT; + + r = __kvm_gmem_get_pfn(file, slot, gfn, pfn, max_order, true); + fput(file); + return r; +} EXPORT_SYMBOL_GPL(kvm_gmem_get_pfn);