diff --git a/drivers/vdpa/virtio_pci/vp_vdpa.c b/drivers/vdpa/virtio_pci/vp_vdpa.c index 04522077735b..d448db0c4de3 100644 --- a/drivers/vdpa/virtio_pci/vp_vdpa.c +++ b/drivers/vdpa/virtio_pci/vp_vdpa.c @@ -17,6 +17,7 @@ #include #include #include +#include #define VP_VDPA_QUEUE_MAX 256 #define VP_VDPA_DRIVER_NAME "vp_vdpa" @@ -35,6 +36,7 @@ struct vp_vdpa { struct virtio_pci_modern_device *mdev; struct vp_vring *vring; struct vdpa_callback config_cb; + u64 device_features; char msix_name[VP_VDPA_NAME_SIZE]; int config_irq; int queues; @@ -66,9 +68,9 @@ static struct virtio_pci_modern_device *vp_vdpa_to_mdev(struct vp_vdpa *vp_vdpa) static u64 vp_vdpa_get_device_features(struct vdpa_device *vdpa) { - struct virtio_pci_modern_device *mdev = vdpa_to_mdev(vdpa); + struct vp_vdpa *vp_vdpa = vdpa_to_vp(vdpa); - return vp_modern_get_features(mdev); + return vp_vdpa->device_features; } static int vp_vdpa_set_driver_features(struct vdpa_device *vdpa, u64 features) @@ -475,6 +477,7 @@ static int vp_vdpa_dev_add(struct vdpa_mgmt_dev *v_mdev, const char *name, struct pci_dev *pdev = mdev->pci_dev; struct device *dev = &pdev->dev; struct vp_vdpa *vp_vdpa = NULL; + u64 device_features; int ret, i; vp_vdpa = vdpa_alloc_device(struct vp_vdpa, vdpa, @@ -491,6 +494,20 @@ static int vp_vdpa_dev_add(struct vdpa_mgmt_dev *v_mdev, const char *name, vp_vdpa->queues = vp_modern_get_num_queues(mdev); vp_vdpa->mdev = mdev; + device_features = vp_modern_get_features(mdev); + if (add_config->mask & BIT_ULL(VDPA_ATTR_DEV_FEATURES)) { + if (add_config->device_features & ~device_features) { + ret = -EINVAL; + dev_err(&pdev->dev, "Try to provision features " + "that are not supported by the device: " + "device_features 0x%llx provisioned 0x%llx\n", + device_features, add_config->device_features); + goto err; + } + device_features = add_config->device_features; + } + vp_vdpa->device_features = device_features; + ret = devm_add_action_or_reset(dev, vp_vdpa_free_irq_vectors, pdev); if (ret) { dev_err(&pdev->dev, @@ -599,6 +616,7 @@ static int vp_vdpa_probe(struct pci_dev *pdev, const struct pci_device_id *id) mgtdev->id_table = mdev_id; mgtdev->max_supported_vqs = vp_modern_get_num_queues(mdev); mgtdev->supported_features = vp_modern_get_features(mdev); + mgtdev->config_attr_mask = (1 << VDPA_ATTR_DEV_FEATURES); pci_set_master(pdev); pci_set_drvdata(pdev, vp_vdpa_mgtdev);