diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 34c0d970bcbc..058191d5efad 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -630,6 +630,7 @@ void vhost_dev_init(struct vhost_dev *dev, dev->iov_limit = iov_limit; dev->weight = weight; dev->byte_weight = byte_weight; + dev->has_notifier = false; init_llist_head(&dev->work_list); init_waitqueue_head(&dev->wait); INIT_LIST_HEAD(&dev->read_list); @@ -731,6 +732,7 @@ long vhost_dev_set_owner(struct vhost_dev *dev) if (err) goto err_mmu_notifier; #endif + dev->has_notifier = true; return 0; @@ -960,7 +962,11 @@ void vhost_dev_cleanup(struct vhost_dev *dev) } if (dev->mm) { #if VHOST_ARCH_CAN_ACCEL_UACCESS - mmu_notifier_unregister(&dev->mmu_notifier, dev->mm); + if (dev->has_notifier) { + mmu_notifier_unregister(&dev->mmu_notifier, + dev->mm); + dev->has_notifier = false; + } #endif mmput(dev->mm); } @@ -2065,8 +2071,10 @@ static long vhost_vring_set_num_addr(struct vhost_dev *d, /* Unregister MMU notifer to allow invalidation callback * can access vq->uaddrs[] without holding a lock. */ - if (d->mm) + if (d->has_notifier) { mmu_notifier_unregister(&d->mmu_notifier, d->mm); + d->has_notifier = false; + } vhost_uninit_vq_maps(vq); #endif @@ -2086,8 +2094,11 @@ static long vhost_vring_set_num_addr(struct vhost_dev *d, if (r == 0) vhost_setup_vq_uaddr(vq); - if (d->mm) - mmu_notifier_register(&d->mmu_notifier, d->mm); + if (d->mm) { + r = mmu_notifier_register(&d->mmu_notifier, d->mm); + if (!r) + d->has_notifier = true; + } #endif mutex_unlock(&vq->mutex); diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h index c5d950cf7627..859268a09169 100644 --- a/drivers/vhost/vhost.h +++ b/drivers/vhost/vhost.h @@ -211,6 +211,7 @@ struct vhost_dev { int iov_limit; int weight; int byte_weight; + bool has_notifier; }; bool vhost_exceeds_weight(struct vhost_virtqueue *vq, int pkts, int total_len);