diff --git a/drivers/iommu/apple-dart.c b/drivers/iommu/apple-dart.c index 2566cf3111c7..126da0d89f0d 100644 --- a/drivers/iommu/apple-dart.c +++ b/drivers/iommu/apple-dart.c @@ -568,10 +568,9 @@ apple_dart_setup_translation(struct apple_dart_domain *domain, stream_map->dart->hw->invalidate_tlb(stream_map); } -static int apple_dart_finalize_domain(struct iommu_domain *domain, +static int apple_dart_finalize_domain(struct apple_dart_domain *dart_domain, struct apple_dart_master_cfg *cfg) { - struct apple_dart_domain *dart_domain = to_dart_domain(domain); struct apple_dart *dart = cfg->stream_maps[0].dart; struct io_pgtable_cfg pgtbl_cfg; int ret = 0; @@ -597,17 +596,18 @@ static int apple_dart_finalize_domain(struct iommu_domain *domain, .iommu_dev = dart->dev, }; - dart_domain->pgtbl_ops = - alloc_io_pgtable_ops(dart->hw->fmt, &pgtbl_cfg, domain); + dart_domain->pgtbl_ops = alloc_io_pgtable_ops(dart->hw->fmt, &pgtbl_cfg, + &dart_domain->domain); if (!dart_domain->pgtbl_ops) { ret = -ENOMEM; goto done; } - domain->pgsize_bitmap = pgtbl_cfg.pgsize_bitmap; - domain->geometry.aperture_start = 0; - domain->geometry.aperture_end = (dma_addr_t)DMA_BIT_MASK(dart->ias); - domain->geometry.force_aperture = true; + dart_domain->domain.pgsize_bitmap = pgtbl_cfg.pgsize_bitmap; + dart_domain->domain.geometry.aperture_start = 0; + dart_domain->domain.geometry.aperture_end = + (dma_addr_t)DMA_BIT_MASK(dart->ias); + dart_domain->domain.geometry.force_aperture = true; dart_domain->finalized = true; @@ -662,7 +662,7 @@ static int apple_dart_attach_dev_paging(struct iommu_domain *domain, if (cfg->stream_maps[0].dart->force_bypass) return -EINVAL; - ret = apple_dart_finalize_domain(domain, cfg); + ret = apple_dart_finalize_domain(dart_domain, cfg); if (ret) return ret; @@ -758,6 +758,16 @@ static struct iommu_domain *apple_dart_domain_alloc_paging(struct device *dev) mutex_init(&dart_domain->init_lock); + if (dev) { + struct apple_dart_master_cfg *cfg = dev_iommu_priv_get(dev); + int ret; + + ret = apple_dart_finalize_domain(dart_domain, cfg); + if (ret) { + kfree(dart_domain); + return ERR_PTR(ret); + } + } return &dart_domain->domain; }