diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 1d89715af89d..0536f8526359 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -299,30 +299,53 @@ static void vhost_vq_meta_reset(struct vhost_dev *d)
 }
 
 #if VHOST_ARCH_CAN_ACCEL_UACCESS
-static void __vhost_cleanup_vq_maps(struct vhost_virtqueue *vq)
+static void vhost_map_unprefetch(struct vhost_map *map)
+{
+	kfree(map->pages);
+	map->pages = NULL;
+	map->npages = 0;
+	map->addr = NULL;
+}
+
+static void vhost_uninit_vq_maps(struct vhost_virtqueue *vq)
 {
 	struct vhost_map *map[VHOST_NUM_ADDRS];
 	int i;
 
+	spin_lock(&vq->mmu_lock);
 	for (i = 0; i < VHOST_NUM_ADDRS; i++) {
 		map[i] = rcu_dereference_protected(vq->maps[i],
				  lockdep_is_held(&vq->mmu_lock));
-		if (map[i]) {
-			if (vq->uaddrs[i].write) {
-				for (i = 0; i < map[i]->npages; i++)
-					set_page_dirty(map[i]->pages[i]);
-			}
+		if (map[i])
 			rcu_assign_pointer(vq->maps[i], NULL);
-			kfree_rcu(map[i], head);
-		}
 	}
+	spin_unlock(&vq->mmu_lock);
+
+	synchronize_rcu();
+
+	for (i = 0; i < VHOST_NUM_ADDRS; i++)
+		if (map[i])
+			vhost_map_unprefetch(map[i]);
+
 }
 
-static void vhost_cleanup_vq_maps(struct vhost_virtqueue *vq)
+static void vhost_reset_vq_maps(struct vhost_virtqueue *vq)
 {
-	spin_lock(&vq->mmu_lock);
-	__vhost_cleanup_vq_maps(vq);
-	spin_unlock(&vq->mmu_lock);
+	int i;
+
+	vhost_uninit_vq_maps(vq);
+	for (i = 0; i < VHOST_NUM_ADDRS; i++)
+		vq->uaddrs[i].size = 0;
+}
+
+static bool vhost_map_range_overlap(struct vhost_uaddr *uaddr,
+				    unsigned long start,
+				    unsigned long end)
+{
+	if (unlikely(!uaddr->size))
+		return false;
+
+	return !(end < uaddr->uaddr || start > uaddr->uaddr - 1 + uaddr->size);
 }
 
 static void vhost_invalidate_vq_start(struct vhost_virtqueue *vq,
@@ -330,11 +353,31 @@ static void vhost_invalidate_vq_start(struct vhost_virtqueue *vq,
				      unsigned long start,
				      unsigned long end)
 {
+	struct vhost_uaddr *uaddr = &vq->uaddrs[index];
+	struct vhost_map *map;
+	int i;
+
+	if (!vhost_map_range_overlap(uaddr, start, end))
+		return;
+
 	spin_lock(&vq->mmu_lock);
 	++vq->invalidate_count;
 
-	__vhost_cleanup_vq_maps(vq);
+	map = rcu_dereference_protected(vq->maps[index],
+					lockdep_is_held(&vq->mmu_lock));
+	if (map) {
+		if (uaddr->write) {
+			for (i = 0; i < map->npages; i++)
+				set_page_dirty(map->pages[i]);
+		}
+		rcu_assign_pointer(vq->maps[index], NULL);
+	}
 	spin_unlock(&vq->mmu_lock);
+
+	if (map) {
+		synchronize_rcu();
+		vhost_map_unprefetch(map);
+	}
 }
 
 static void vhost_invalidate_vq_end(struct vhost_virtqueue *vq,
@@ -342,6 +385,9 @@ static void vhost_invalidate_vq_end(struct vhost_virtqueue *vq,
				    unsigned long start,
				    unsigned long end)
 {
+	if (!vhost_map_range_overlap(&vq->uaddrs[index], start, end))
+		return;
+
 	spin_lock(&vq->mmu_lock);
 	--vq->invalidate_count;
 	spin_unlock(&vq->mmu_lock);
@@ -437,7 +483,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 	vq->invalidate_count = 0;
 	__vhost_vq_meta_reset(vq);
 #if VHOST_ARCH_CAN_ACCEL_UACCESS
-	vhost_cleanup_vq_maps(vq);
+	vhost_reset_vq_maps(vq);
 #endif
 }
 
@@ -787,7 +833,6 @@ static void vhost_setup_uaddr(struct vhost_virtqueue *vq,
			      size_t size, bool write)
 {
 	struct vhost_uaddr *addr = &vq->uaddrs[index];
-	spin_lock(&vq->mmu_lock);
 
 	addr->uaddr = uaddr;
 	addr->size = size;
@@ -796,8 +841,6 @@ static void vhost_setup_uaddr(struct vhost_virtqueue *vq,
 
 static void vhost_setup_vq_uaddr(struct vhost_virtqueue *vq)
 {
-	spin_lock(&vq->mmu_lock);
-
 	vhost_setup_uaddr(vq, VHOST_ADDR_DESC,
			  (unsigned long)vq->desc,
			  vhost_get_desc_size(vq, vq->num),
@@ -810,8 +853,6 @@ static void vhost_setup_vq_uaddr(struct vhost_virtqueue *vq)
			  (unsigned long)vq->used,
			  vhost_get_used_size(vq, vq->num),
			  true);
-
-	spin_unlock(&vq->mmu_lock);
 }
 
 static int vhost_map_prefetch(struct vhost_virtqueue *vq,
@@ -833,11 +874,13 @@ static int vhost_map_prefetch(struct vhost_virtqueue *vq,
 		goto err;
 
 	err = -ENOMEM;
-	map = kmalloc(sizeof(*map) + sizeof(*map->pages) * npages, GFP_ATOMIC);
+	map = kmalloc(sizeof(*map), GFP_ATOMIC);
 	if (!map)
 		goto err;
 
-	pages = map->pages;
+	pages = kmalloc_array(npages, sizeof(struct page *), GFP_ATOMIC);
+	if (!pages)
+		goto err_pages;
 
 	err = EFAULT;
 	npinned = __get_user_pages_fast(uaddr->uaddr, npages,
@@ -864,6 +907,7 @@ static int vhost_map_prefetch(struct vhost_virtqueue *vq,
 
 	map->addr = vaddr + (uaddr->uaddr & (PAGE_SIZE - 1));
 	map->npages = npages;
+	map->pages = pages;
 
 	rcu_assign_pointer(vq->maps[index], map);
 	/* No need for a synchronize_rcu(). This function should be
@@ -875,6 +919,8 @@ static int vhost_map_prefetch(struct vhost_virtqueue *vq,
 	return 0;
 
 err_gup:
+	kfree(pages);
+err_pages:
 	kfree(map);
 err:
 	spin_unlock(&vq->mmu_lock);
@@ -896,10 +942,6 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
 		vhost_vq_reset(dev, dev->vqs[i]);
 	}
 	vhost_dev_free_iovecs(dev);
-#if VHOST_ARCH_CAN_ACCEL_UACCESS
-	if (dev->mm)
-		mmu_notifier_unregister(&dev->mmu_notifier, dev->mm);
-#endif
 	if (dev->log_ctx)
 		eventfd_ctx_put(dev->log_ctx);
 	dev->log_ctx = NULL;
@@ -915,8 +957,16 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
 		kthread_stop(dev->worker);
 		dev->worker = NULL;
 	}
-	if (dev->mm)
+	if (dev->mm) {
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+		mmu_notifier_unregister(&dev->mmu_notifier, dev->mm);
+#endif
 		mmput(dev->mm);
+	}
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+	for (i = 0; i < dev->nvqs; i++)
+		vhost_uninit_vq_maps(dev->vqs[i]);
+#endif
 	dev->mm = NULL;
 }
 EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
@@ -1376,7 +1426,7 @@ static inline int vhost_get_used_event(struct vhost_virtqueue *vq,
 		map = rcu_dereference(vq->maps[VHOST_ADDR_AVAIL]);
 		if (likely(map)) {
 			avail = map->addr;
-			*event = avail->ring[vq->num];
+			*event = (__virtio16)avail->ring[vq->num];
 			rcu_read_unlock();
 			return 0;
 		}
@@ -1780,8 +1830,6 @@ static void vhost_vq_map_prefetch(struct vhost_virtqueue *vq)
 	struct vhost_map __rcu *map;
 	int i;
 
-	vhost_setup_vq_uaddr(vq);
-
 	for (i = 0; i < VHOST_NUM_ADDRS; i++) {
 		rcu_read_lock();
 		map = rcu_dereference(vq->maps[i]);
@@ -1790,10 +1838,6 @@ static void vhost_vq_map_prefetch(struct vhost_virtqueue *vq)
 			vhost_map_prefetch(vq, i);
 	}
 }
-#else
-static void vhost_vq_map_prefetch(struct vhost_virtqueue *vq)
-{
-}
 #endif
 
 int vq_meta_prefetch(struct vhost_virtqueue *vq)
@@ -1801,7 +1845,9 @@ int vq_meta_prefetch(struct vhost_virtqueue *vq)
 	unsigned int num = vq->num;
 
 	if (!vq->iotlb) {
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
 		vhost_vq_map_prefetch(vq);
+#endif
 		return 1;
 	}
 
@@ -2014,6 +2060,16 @@ static long vhost_vring_set_num_addr(struct vhost_dev *d,
 
 	mutex_lock(&vq->mutex);
 
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+	/* Unregister the MMU notifier so that the invalidation callback
+	 * can access vq->uaddrs[] without holding a lock.
+	 */
+	if (d->mm)
+		mmu_notifier_unregister(&d->mmu_notifier, d->mm);
+
+	vhost_uninit_vq_maps(vq);
+#endif
+
 	switch (ioctl) {
 	case VHOST_SET_VRING_NUM:
 		r = vhost_vring_set_num(d, vq, argp);
@@ -2025,6 +2081,13 @@ static long vhost_vring_set_num_addr(struct vhost_dev *d,
 		BUG();
 	}
 
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+	vhost_setup_vq_uaddr(vq);
+
+	if (d->mm)
+		mmu_notifier_register(&d->mmu_notifier, d->mm);
+#endif
+
 	mutex_unlock(&vq->mutex);
 
 	return r;
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 584bb13c4d6d..819296332913 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -86,8 +86,7 @@ enum vhost_uaddr_type {
 struct vhost_map {
 	int npages;
 	void *addr;
-	struct rcu_head head;
-	struct page *pages[];
+	struct page **pages;
 };
 
 struct vhost_uaddr {