diff --git a/drivers/iommu/iommufd/main.c b/drivers/iommu/iommufd/main.c index e71523cbd0de..6b75f9fab6ce 100644 --- a/drivers/iommu/iommufd/main.c +++ b/drivers/iommu/iommufd/main.c @@ -208,11 +208,14 @@ static int iommufd_destroy(struct iommufd_ucmd *ucmd) { struct iommu_destroy *cmd = ucmd->cmd; struct iommufd_object *obj; + struct iommufd_ioas *ioas; obj = iommufd_object_remove(ucmd->ictx, cmd->id, false); if (IS_ERR(obj)) return PTR_ERR(obj); iommufd_object_ops[obj->type].destroy(obj); + ioas = container_of(obj, struct iommufd_ioas, obj); + ioas->obj = NULL; kfree(obj); return 0; } diff --git a/drivers/iommu/iommufd/vfio_compat.c b/drivers/iommu/iommufd/vfio_compat.c index 6c810bf80f99..06317d0bd95e 100644 --- a/drivers/iommu/iommufd/vfio_compat.c +++ b/drivers/iommu/iommufd/vfio_compat.c @@ -140,6 +140,8 @@ int iommufd_vfio_ioas(struct iommufd_ucmd *ucmd) ioas = iommufd_get_ioas(ucmd->ictx, cmd->ioas_id); if (IS_ERR(ioas)) return PTR_ERR(ioas); + if (!ioas->obj) + return -EINVAL; xa_lock(&ucmd->ictx->objects); ucmd->ictx->vfio_ioas = ioas; xa_unlock(&ucmd->ictx->objects);