diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c
index f0d118e..b359013 100644
--- a/net/smc/af_smc.c
+++ b/net/smc/af_smc.c
@@ -235,10 +235,25 @@ struct proto smc_proto6 = {
 
 static void smc_restore_fallback_changes(struct smc_sock *smc)
 {
 	if (smc->clcsock->file) { /* non-accepted sockets have no file yet */
+		struct sock *clcsk = smc->clcsock->sk;
+
 		smc->clcsock->file->private_data = smc->sk.sk_socket;
 		smc->clcsock->file = NULL;
+
+		/* smc_switch_to_fallback() is expected to install the
+		 * smc_fback_* callbacks only together with the file hand-over
+		 * undone above. Without it the saved smc->clcsk_* copies may
+		 * be unset, and restoring them would install NULL callbacks.
+		 */
+		rcu_assign_sk_user_data(clcsk, NULL);
+		clcsk->sk_state_change = smc->clcsk_state_change;
+		clcsk->sk_data_ready = smc->clcsk_data_ready;
+		clcsk->sk_write_space = smc->clcsk_write_space;
+		clcsk->sk_error_report = smc->clcsk_error_report;
+		/* wait for in-flight smc_fback_* callbacks to finish */
+		synchronize_rcu();
 	}
 }
 
 static int __smc_release(struct smc_sock *smc)
@@ -710,6 +725,7 @@ static int smc_fback_mark_woken(wait_queue_entry_t *wait,
 	return 0;
 }
 
+/* must be called under rcu_read_lock */
 static void smc_fback_forward_wakeup(struct smc_sock *smc, struct sock *clcsk,
 				     void (*clcsock_callback)(struct sock *sk))
 {
@@ -718,58 +734,67 @@ static void smc_fback_forward_wakeup(struct smc_sock *smc, struct sock *clcsk,
 
 	init_waitqueue_func_entry(&mark.wait_entry,
 				  smc_fback_mark_woken);
-	rcu_read_lock();
 	wq = rcu_dereference(clcsk->sk_wq);
 	if (!wq)
-		goto out;
+		return;
 	add_wait_queue(sk_sleep(clcsk), &mark.wait_entry);
 	clcsock_callback(clcsk);
 	remove_wait_queue(sk_sleep(clcsk), &mark.wait_entry);
 
 	if (mark.woken)
 		smc_fback_wakeup_waitqueue(smc, mark.key);
-out:
-	rcu_read_unlock();
 }
 
 static void smc_fback_state_change(struct sock *clcsk)
 {
-	struct smc_sock *smc =
-		smc_clcsock_user_data(clcsk);
+	struct smc_sock *smc;
 
+	rcu_read_lock();
+	smc = rcu_dereference_sk_user_data(clcsk);
 	if (!smc)
-		return;
+		goto out;
 	smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_state_change);
+out:
+	rcu_read_unlock();
 }
 
 static void smc_fback_data_ready(struct sock *clcsk)
 {
-	struct smc_sock *smc =
-		smc_clcsock_user_data(clcsk);
+	struct smc_sock *smc;
 
+	rcu_read_lock();
+	smc = rcu_dereference_sk_user_data(clcsk);
 	if (!smc)
-		return;
+		goto out;
 	smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_data_ready);
+out:
+	rcu_read_unlock();
 }
 
 static void smc_fback_write_space(struct sock *clcsk)
 {
-	struct smc_sock *smc =
-		smc_clcsock_user_data(clcsk);
+	struct smc_sock *smc;
 
+	rcu_read_lock();
+	smc = rcu_dereference_sk_user_data(clcsk);
 	if (!smc)
-		return;
+		goto out;
 	smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_write_space);
+out:
+	rcu_read_unlock();
 }
 
 static void smc_fback_error_report(struct sock *clcsk)
 {
-	struct smc_sock *smc =
-		smc_clcsock_user_data(clcsk);
+	struct smc_sock *smc;
 
+	rcu_read_lock();
+	smc = rcu_dereference_sk_user_data(clcsk);
 	if (!smc)
-		return;
+		goto out;
 	smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_error_report);
+out:
+	rcu_read_unlock();
 }
 
 static int smc_switch_to_fallback(struct smc_sock *smc, int reason_code)
@@ -810,8 +835,7 @@ static int smc_switch_to_fallback(struct smc_sock *smc, int reason_code)
 		clcsk->sk_write_space = smc_fback_write_space;
 		clcsk->sk_error_report = smc_fback_error_report;
 
-		smc->clcsock->sk->sk_user_data =
-			(void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY);
+		rcu_assign_sk_user_data_nocopy(smc->clcsock->sk, smc);
 	}
 out:
 	mutex_unlock(&smc->clcsock_release_lock);