diff --git a/drivers/iommu/iommufd/main.c b/drivers/iommu/iommufd/main.c
index e71523cbd0de..8d2dba522baf 100644
--- a/drivers/iommu/iommufd/main.c
+++ b/drivers/iommu/iommufd/main.c
@@ -137,6 +137,7 @@ static struct iommufd_object *iommufd_object_remove(struct iommufd_ctx *ictx,
 						    u32 id, bool extra_put)
 {
 	struct iommufd_object *obj;
+	struct iommufd_ioas *ioas;
 	XA_STATE(xas, &ictx->objects, id);
 
 	xa_lock(&ictx->objects);
@@ -159,7 +160,12 @@ static struct iommufd_object *iommufd_object_remove(struct iommufd_ctx *ictx,
 	}
 
 	xas_store(&xas, NULL);
-	if (ictx->vfio_ioas == container_of(obj, struct iommufd_ioas, obj))
+	/*
+	 * obj is not necessarily an IOAS, so ioas is only compared against
+	 * vfio_ioas here and must never be dereferenced or written through.
+	 */
+	ioas = container_of(obj, struct iommufd_ioas, obj);
+	if (ictx->vfio_ioas == ioas)
 		ictx->vfio_ioas = NULL;
 
 out_xa:
diff --git a/drivers/iommu/iommufd/vfio_compat.c b/drivers/iommu/iommufd/vfio_compat.c
index 6c810bf80f99..06317d0bd95e 100644
--- a/drivers/iommu/iommufd/vfio_compat.c
+++ b/drivers/iommu/iommufd/vfio_compat.c
@@ -140,6 +140,10 @@ int iommufd_vfio_ioas(struct iommufd_ucmd *ucmd)
 		ioas = iommufd_get_ioas(ucmd->ictx, cmd->ioas_id);
 		if (IS_ERR(ioas))
 			return PTR_ERR(ioas);
+		/*
+		 * Publish under the objects xa_lock, which iommufd_object_remove()
+		 * also holds while it clears vfio_ioas.
+		 */
 		xa_lock(&ucmd->ictx->objects);
 		ucmd->ictx->vfio_ioas = ioas;
 		xa_unlock(&ucmd->ictx->objects);