diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index dc9301d31f12..67ea39c5d7d9 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -300,53 +300,30 @@ static void vhost_vq_meta_reset(struct vhost_dev *d)
 }
 
 #if VHOST_ARCH_CAN_ACCEL_UACCESS
-static void vhost_map_unprefetch(struct vhost_map *map)
-{
-	kfree(map->pages);
-	map->pages = NULL;
-	map->npages = 0;
-	map->addr = NULL;
-}
-
-static void vhost_uninit_vq_maps(struct vhost_virtqueue *vq)
+static void __vhost_cleanup_vq_maps(struct vhost_virtqueue *vq)
 {
 	struct vhost_map *map[VHOST_NUM_ADDRS];
-	int i;
+	int i, j;
 
-	spin_lock(&vq->mmu_lock);
 	for (i = 0; i < VHOST_NUM_ADDRS; i++) {
 		map[i] = rcu_dereference_protected(vq->maps[i],
 				  lockdep_is_held(&vq->mmu_lock));
-		if (map[i])
+		if (map[i]) {
+			if (vq->uaddrs[i].write) {
+				for (j = 0; j < map[i]->npages; j++)
+					set_page_dirty(map[i]->pages[j]);
+			}
 			rcu_assign_pointer(vq->maps[i], NULL);
+			kfree_rcu(map[i], head);
+		}
 	}
-	spin_unlock(&vq->mmu_lock);
-
-	synchronize_rcu();
-
-	for (i = 0; i < VHOST_NUM_ADDRS; i++)
-		if (map[i])
-			vhost_map_unprefetch(map[i]);
-
-}
-
-static void vhost_reset_vq_maps(struct vhost_virtqueue *vq)
-{
-	int i;
-
-	vhost_uninit_vq_maps(vq);
-	for (i = 0; i < VHOST_NUM_ADDRS; i++)
-		vq->uaddrs[i].size = 0;
 }
 
-static bool vhost_map_range_overlap(struct vhost_uaddr *uaddr,
-				    unsigned long start,
-				    unsigned long end)
+static void vhost_cleanup_vq_maps(struct vhost_virtqueue *vq)
 {
-	if (unlikely(!uaddr->size))
-		return false;
-
-	return !(end < uaddr->uaddr || start > uaddr->uaddr - 1 + uaddr->size);
+	spin_lock(&vq->mmu_lock);
+	__vhost_cleanup_vq_maps(vq);
+	spin_unlock(&vq->mmu_lock);
 }
 
 static void vhost_invalidate_vq_start(struct vhost_virtqueue *vq,
@@ -354,31 +331,11 @@ static void vhost_invalidate_vq_start(struct vhost_virtqueue *vq,
 				      unsigned long start,
 				      unsigned long end)
 {
-	struct vhost_uaddr *uaddr = &vq->uaddrs[index];
-	struct vhost_map *map;
-	int i;
-
-	if (!vhost_map_range_overlap(uaddr, start, end))
-		return;
-
 	spin_lock(&vq->mmu_lock);
 	++vq->invalidate_count;
 
-	map = rcu_dereference_protected(vq->maps[index],
-					lockdep_is_held(&vq->mmu_lock));
-	if (map) {
-		if (uaddr->write) {
-			for (i = 0; i < map->npages; i++)
-				set_page_dirty(map->pages[i]);
-		}
-		rcu_assign_pointer(vq->maps[index], NULL);
-	}
+	__vhost_cleanup_vq_maps(vq);
 	spin_unlock(&vq->mmu_lock);
-
-	if (map) {
-		synchronize_rcu();
-		vhost_map_unprefetch(map);
-	}
 }
 
 static void vhost_invalidate_vq_end(struct vhost_virtqueue *vq,
@@ -386,9 +343,6 @@ static void vhost_invalidate_vq_end(struct vhost_virtqueue *vq,
 				    unsigned long start,
 				    unsigned long end)
 {
-	if (!vhost_map_range_overlap(&vq->uaddrs[index], start, end))
-		return;
-
 	spin_lock(&vq->mmu_lock);
 	--vq->invalidate_count;
 	spin_unlock(&vq->mmu_lock);
 }
@@ -484,7 +438,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 	vq->invalidate_count = 0;
 	__vhost_vq_meta_reset(vq);
 #if VHOST_ARCH_CAN_ACCEL_UACCESS
-	vhost_reset_vq_maps(vq);
+	vhost_cleanup_vq_maps(vq);
 #endif
 }
@@ -834,6 +788,7 @@ static void vhost_setup_uaddr(struct vhost_virtqueue *vq,
 			      size_t size, bool write)
 {
 	struct vhost_uaddr *addr = &vq->uaddrs[index];
+
 	addr->uaddr = uaddr;
 	addr->size = size;
 	addr->write = write;
@@ -842,6 +797,8 @@ static void vhost_setup_vq_uaddr(struct vhost_virtqueue *vq)
 {
+	spin_lock(&vq->mmu_lock);
+
 	vhost_setup_uaddr(vq, VHOST_ADDR_DESC,
 			  (unsigned long)vq->desc,
 			  vhost_get_desc_size(vq, vq->num),
@@ -854,6 +811,8 @@ static void vhost_setup_vq_uaddr(struct vhost_virtqueue *vq)
 			  (unsigned long)vq->used,
 			  vhost_get_used_size(vq, vq->num),
 			  true);
+
+	spin_unlock(&vq->mmu_lock);
 }
 
 static int vhost_map_prefetch(struct vhost_virtqueue *vq,
@@ -875,13 +834,11 @@ static int vhost_map_prefetch(struct vhost_virtqueue *vq,
 		goto err;
 
 	err = -ENOMEM;
-	map = kmalloc(sizeof(*map), GFP_ATOMIC);
+	map = kmalloc(sizeof(*map) + sizeof(*map->pages) * npages, GFP_ATOMIC);
 	if (!map)
 		goto err;
 
-	pages = kmalloc_array(npages, sizeof(struct page *), GFP_ATOMIC);
-	if (!pages)
-		goto err_pages;
+	pages = map->pages;
 
 	err = EFAULT;
 	npinned = __get_user_pages_fast(uaddr->uaddr, npages,
@@ -908,7 +865,6 @@ static int vhost_map_prefetch(struct vhost_virtqueue *vq,
 
 	map->addr = vaddr + (uaddr->uaddr & (PAGE_SIZE - 1));
 	map->npages = npages;
-	map->pages = pages;
 
 	rcu_assign_pointer(vq->maps[index], map);
 	/* No need for a synchronize_rcu(). This function should be
@@ -920,8 +876,6 @@ static int vhost_map_prefetch(struct vhost_virtqueue *vq,
 	return 0;
 
 err_gup:
-	kfree(pages);
-err_pages:
 	kfree(map);
 err:
 	spin_unlock(&vq->mmu_lock);
@@ -943,6 +897,10 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
 		vhost_vq_reset(dev, dev->vqs[i]);
 	}
 	vhost_dev_free_iovecs(dev);
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+	if (dev->mm)
+		mmu_notifier_unregister(&dev->mmu_notifier, dev->mm);
+#endif
 	if (dev->log_ctx)
 		eventfd_ctx_put(dev->log_ctx);
 	dev->log_ctx = NULL;
@@ -958,16 +916,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
 		kthread_stop(dev->worker);
 		dev->worker = NULL;
 	}
-	if (dev->mm) {
-#if VHOST_ARCH_CAN_ACCEL_UACCESS
-		mmu_notifier_unregister(&dev->mmu_notifier, dev->mm);
-#endif
+	if (dev->mm)
 		mmput(dev->mm);
-	}
-#if VHOST_ARCH_CAN_ACCEL_UACCESS
-	for (i = 0; i < dev->nvqs; i++)
-		vhost_uninit_vq_maps(dev->vqs[i]);
-#endif
 	dev->mm = NULL;
 }
 EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
@@ -1427,7 +1377,7 @@ static inline int vhost_get_used_event(struct vhost_virtqueue *vq,
 	map = rcu_dereference(vq->maps[VHOST_ADDR_AVAIL]);
 	if (likely(map)) {
 		avail = map->addr;
-		*event = (__virtio16)avail->ring[vq->num];
+		*event = avail->ring[vq->num];
 		rcu_read_unlock();
 		return 0;
 	}
@@ -1831,6 +1781,8 @@ static void vhost_vq_map_prefetch(struct vhost_virtqueue *vq)
 	struct vhost_map __rcu *map;
 	int i;
 
+	vhost_setup_vq_uaddr(vq);
+
 	for (i = 0; i < VHOST_NUM_ADDRS; i++) {
 		rcu_read_lock();
 		map = rcu_dereference(vq->maps[i]);
@@ -1839,6 +1791,10 @@ static void vhost_vq_map_prefetch(struct vhost_virtqueue *vq)
 			vhost_map_prefetch(vq, i);
 	}
 }
+#else
+static void vhost_vq_map_prefetch(struct vhost_virtqueue *vq)
+{
+}
 #endif
 
 int vq_meta_prefetch(struct vhost_virtqueue *vq)
@@ -1846,9 +1802,7 @@
 	unsigned int num = vq->num;
 
 	if (!vq->iotlb) {
-#if VHOST_ARCH_CAN_ACCEL_UACCESS
 		vhost_vq_map_prefetch(vq);
-#endif
 		return 1;
 	}
@@ -2061,16 +2015,6 @@ static long vhost_vring_set_num_addr(struct vhost_dev *d,
 	mutex_lock(&vq->mutex);
 
-#if VHOST_ARCH_CAN_ACCEL_UACCESS
-	/* Unregister MMU notifer to allow invalidation callback
-	 * can access vq->uaddrs[] without holding a lock.
-	 */
-	if (d->mm)
-		mmu_notifier_unregister(&d->mmu_notifier, d->mm);
-
-	vhost_uninit_vq_maps(vq);
-#endif
-
 	switch (ioctl) {
 	case VHOST_SET_VRING_NUM:
 		r = vhost_vring_set_num(d, vq, argp);
@@ -2082,13 +2026,6 @@ static long vhost_vring_set_num_addr(struct vhost_dev *d,
 		BUG();
 	}
 
-#if VHOST_ARCH_CAN_ACCEL_UACCESS
-	vhost_setup_vq_uaddr(vq);
-
-	if (d->mm)
-		mmu_notifier_register(&d->mmu_notifier, d->mm);
-#endif
-
 	mutex_unlock(&vq->mutex);
 
 	return r;
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 819296332913..584bb13c4d6d 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -86,7 +86,8 @@ enum vhost_uaddr_type {
 struct vhost_map {
 	int npages;
 	void *addr;
-	struct page **pages;
+	struct rcu_head head;
+	struct page *pages[];
 };
 
 struct vhost_uaddr {