diff --git a/drivers/infiniband/core/ucma.c b/drivers/infiniband/core/ucma.c index 6e700b974033..89c444c6f317 100644 --- a/drivers/infiniband/core/ucma.c +++ b/drivers/infiniband/core/ucma.c @@ -109,6 +109,7 @@ struct ucma_multicast { u8 join_state; struct list_head list; struct sockaddr_storage addr; + atomic_t ref; }; struct ucma_event { @@ -257,6 +258,12 @@ static void ucma_copy_ud_event(struct ib_device *device, dst->qkey = src->qkey; } +static void ucma_put_mc(struct ucma_multicast *mc) +{ + if (mc && atomic_dec_and_test(&mc->ref)) + kfree(mc); +} + static struct ucma_event *ucma_create_uevent(struct ucma_context *ctx, struct rdma_cm_event *event) { @@ -274,6 +281,7 @@ static struct ucma_event *ucma_create_uevent(struct ucma_context *ctx, event->param.ud.private_data; uevent->resp.uid = uevent->mc->uid; uevent->resp.id = uevent->mc->id; + ucma_put_mc(uevent->mc); break; default: uevent->resp.uid = ctx->uid; @@ -1471,6 +1479,7 @@ static ssize_t ucma_process_join(struct ucma_file *file, mc->ctx = ctx; mc->join_state = join_state; mc->uid = cmd->uid; + atomic_set(&mc->ref, 1); memcpy(&mc->addr, addr, cmd->addr_size); xa_lock(&multicast_table); @@ -1489,6 +1498,7 @@ static ssize_t ucma_process_join(struct ucma_file *file, mutex_unlock(&ctx->mutex); if (ret) goto err_xa_erase; + atomic_inc(&mc->ref); resp.id = mc->id; if (copy_to_user(u64_to_user_ptr(cmd->response), @@ -1513,7 +1523,7 @@ static ssize_t ucma_process_join(struct ucma_file *file, __xa_erase(&multicast_table, mc->id); err_free_mc: xa_unlock(&multicast_table); - kfree(mc); + ucma_put_mc(mc); err_put_ctx: ucma_put_ctx(ctx); return ret;