diff --git a/drivers/infiniband/core/ucma.c b/drivers/infiniband/core/ucma.c index 6e700b974033..7fa02963a309 100644 --- a/drivers/infiniband/core/ucma.c +++ b/drivers/infiniband/core/ucma.c @@ -109,6 +109,7 @@ struct ucma_multicast { u8 join_state; struct list_head list; struct sockaddr_storage addr; + atomic_t ref; }; struct ucma_event { @@ -257,6 +258,17 @@ static void ucma_copy_ud_event(struct ib_device *device, dst->qkey = src->qkey; } +static bool ucma_get_mc(struct ucma_multicast *mc) +{ + return mc && atomic_inc_not_zero(&mc->ref); +} + +static void ucma_put_mc(struct ucma_multicast *mc) +{ + if (mc && atomic_dec_and_test(&mc->ref)) + kfree(mc); +} + static struct ucma_event *ucma_create_uevent(struct ucma_context *ctx, struct rdma_cm_event *event) { @@ -272,9 +284,12 @@ static struct ucma_event *ucma_create_uevent(struct ucma_context *ctx, case RDMA_CM_EVENT_MULTICAST_ERROR: uevent->mc = (struct ucma_multicast *) event->param.ud.private_data; - uevent->resp.uid = uevent->mc->uid; - uevent->resp.id = uevent->mc->id; - break; + if (ucma_get_mc(uevent->mc)) { + uevent->resp.uid = uevent->mc->uid; + uevent->resp.id = uevent->mc->id; + ucma_put_mc(uevent->mc); + break; + } default: uevent->resp.uid = ctx->uid; uevent->resp.id = ctx->id; @@ -498,7 +513,7 @@ static void ucma_cleanup_multicast(struct ucma_context *ctx) * lock on the reader and this is enough serialization */ __xa_erase(&multicast_table, mc->id); - kfree(mc); + ucma_put_mc(mc); } xa_unlock(&multicast_table); } @@ -1471,6 +1486,7 @@ static ssize_t ucma_process_join(struct ucma_file *file, mc->ctx = ctx; mc->join_state = join_state; mc->uid = cmd->uid; + atomic_set(&mc->ref, 1); memcpy(&mc->addr, addr, cmd->addr_size); xa_lock(&multicast_table); @@ -1490,6 +1506,7 @@ static ssize_t ucma_process_join(struct ucma_file *file, if (ret) goto err_xa_erase; + ucma_get_mc(mc); resp.id = mc->id; if (copy_to_user(u64_to_user_ptr(cmd->response), &resp, sizeof(resp))) { @@ -1506,6 +1523,7 @@ static ssize_t ucma_process_join(struct ucma_file *file, mutex_lock(&ctx->mutex); rdma_leave_multicast(ctx->cm_id, (struct sockaddr *) &mc->addr); mutex_unlock(&ctx->mutex); + ucma_put_mc(mc); ucma_cleanup_mc_events(mc); err_xa_erase: xa_lock(&multicast_table); @@ -1513,7 +1531,7 @@ static ssize_t ucma_process_join(struct ucma_file *file, __xa_erase(&multicast_table, mc->id); err_free_mc: xa_unlock(&multicast_table); - kfree(mc); + ucma_put_mc(mc); err_put_ctx: ucma_put_ctx(ctx); return ret; @@ -1599,7 +1617,7 @@ static ssize_t ucma_leave_multicast(struct ucma_file *file, ucma_put_ctx(mc->ctx); resp.events_reported = mc->events_reported; - kfree(mc); + ucma_put_mc(mc); if (copy_to_user(u64_to_user_ptr(cmd.response), &resp, sizeof(resp)))