net/smc: protect fallback callback replacement with sk_callback_lock

Serialize all reads and writes of clcsock->sk->sk_user_data and the
saved sk_state_change/sk_data_ready/sk_write_space/sk_error_report
callbacks with clcsock->sk->sk_callback_lock, so that switching to
fallback and restoring the original callbacks cannot race with the
wrapped callbacks running from softirq context.

NOTE(review): as originally posted, the four smc_fback_* wrappers kept
the old "if (!smc) return;" AFTER taking
read_lock_bh(&clcsk->sk_callback_lock), returning with the read lock
still held whenever sk_user_data had already been cleared (e.g. by
smc_fback_restore_callbacks() or smc_close_active()).  Any later
write_lock_bh() on the same lock would then deadlock.  The wrappers
below release the lock on every path instead, matching the upstream
fix (commit 97b9af7a7093, "net/smc: Fix slab-out-of-bounds issue in
fallback").

diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c
index 14ddc40..6e03b1f 100644
--- a/net/smc/af_smc.c
+++ b/net/smc/af_smc.c
@@ -243,11 +243,14 @@ struct proto smc_proto6 = {
 };
 EXPORT_SYMBOL_GPL(smc_proto6);
 
+static void smc_fback_restore_callbacks(struct smc_sock *smc);
+
 static void smc_restore_fallback_changes(struct smc_sock *smc)
 {
 	if (smc->clcsock->file) { /* non-accepted sockets have no file yet */
 		smc->clcsock->file->private_data = smc->sk.sk_socket;
 		smc->clcsock->file = NULL;
+		smc_fback_restore_callbacks(smc);
 	}
 }
 
@@ -744,47 +747,104 @@ static void smc_fback_forward_wakeup(struct smc_sock *smc, struct sock *clcsk,
 static void smc_fback_state_change(struct sock *clcsk)
 {
-	struct smc_sock *smc =
-		smc_clcsock_user_data(clcsk);
+	struct smc_sock *smc;
 
-	if (!smc)
-		return;
-	smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_state_change);
+	read_lock_bh(&clcsk->sk_callback_lock);
+	smc = smc_clcsock_user_data(clcsk);
+	if (smc)
+		smc_fback_forward_wakeup(smc, clcsk,
+					 smc->clcsk_state_change);
+	read_unlock_bh(&clcsk->sk_callback_lock);
 }
 
 static void smc_fback_data_ready(struct sock *clcsk)
 {
-	struct smc_sock *smc =
-		smc_clcsock_user_data(clcsk);
+	struct smc_sock *smc;
 
-	if (!smc)
-		return;
-	smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_data_ready);
+	read_lock_bh(&clcsk->sk_callback_lock);
+	smc = smc_clcsock_user_data(clcsk);
+	if (smc)
+		smc_fback_forward_wakeup(smc, clcsk,
+					 smc->clcsk_data_ready);
+	read_unlock_bh(&clcsk->sk_callback_lock);
 }
 
 static void smc_fback_write_space(struct sock *clcsk)
 {
-	struct smc_sock *smc =
-		smc_clcsock_user_data(clcsk);
+	struct smc_sock *smc;
 
-	if (!smc)
-		return;
-	smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_write_space);
+	read_lock_bh(&clcsk->sk_callback_lock);
+	smc = smc_clcsock_user_data(clcsk);
+	if (smc)
+		smc_fback_forward_wakeup(smc, clcsk,
+					 smc->clcsk_write_space);
+	read_unlock_bh(&clcsk->sk_callback_lock);
 }
 
 static void smc_fback_error_report(struct sock *clcsk)
 {
-	struct smc_sock *smc =
-		smc_clcsock_user_data(clcsk);
+	struct smc_sock *smc;
 
-	if (!smc)
-		return;
-	smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_error_report);
+	read_lock_bh(&clcsk->sk_callback_lock);
+	smc = smc_clcsock_user_data(clcsk);
+	if (smc)
+		smc_fback_forward_wakeup(smc, clcsk,
+					 smc->clcsk_error_report);
+	read_unlock_bh(&clcsk->sk_callback_lock);
 }
 
+static void smc_fback_init_saved_callbacks(struct smc_sock *smc)
+{
+	smc->clcsk_state_change = NULL;
+	smc->clcsk_data_ready = NULL;
+	smc->clcsk_write_space = NULL;
+	smc->clcsk_error_report = NULL;
+}
+
+static void smc_fback_replace_callbacks(struct smc_sock *smc)
+{
+	struct sock *clcsk = smc->clcsock->sk;
+
+	write_lock_bh(&clcsk->sk_callback_lock);
+	clcsk->sk_user_data = (void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY);
+	/* sk_state_change */
+	smc_fback_replace_cb(&smc->clcsk_state_change, &clcsk->sk_state_change,
+			     smc_fback_state_change);
+	/* sk_data_ready */
+	smc_fback_replace_cb(&smc->clcsk_data_ready, &clcsk->sk_data_ready,
+			     smc_fback_data_ready);
+	/* sk_write_space */
+	smc_fback_replace_cb(&smc->clcsk_write_space, &clcsk->sk_write_space,
+			     smc_fback_write_space);
+	/* sk_error_report */
+	smc_fback_replace_cb(&smc->clcsk_error_report, &clcsk->sk_error_report,
+			     smc_fback_error_report);
+	write_unlock_bh(&clcsk->sk_callback_lock);
+}
+
+static void smc_fback_restore_callbacks(struct smc_sock *smc)
+{
+	struct sock *clcsk = smc->clcsock->sk;
+
+	write_lock_bh(&clcsk->sk_callback_lock);
+	clcsk->sk_user_data = NULL;
+	/* sk_state_change */
+	if (smc->clcsk_state_change)
+		clcsk->sk_state_change = smc->clcsk_state_change;
+	/* sk_data_ready */
+	if (smc->clcsk_data_ready)
+		clcsk->sk_data_ready = smc->clcsk_data_ready;
+	/* sk_write_space */
+	if (smc->clcsk_write_space)
+		clcsk->sk_write_space = smc->clcsk_write_space;
+	/* sk_error_report */
+	if (smc->clcsk_error_report)
+		clcsk->sk_error_report = smc->clcsk_error_report;
+	write_unlock_bh(&clcsk->sk_callback_lock);
+}
+
 static int smc_switch_to_fallback(struct smc_sock *smc, int reason_code)
 {
-	struct sock *clcsk;
 	int rc = 0;
 
 	mutex_lock(&smc->clcsock_release_lock);
 	if (!smc->clcsock) {
@@ -792,10 +852,7 @@ static int smc_switch_to_fallback(struct smc_sock *smc, int reason_code)
 		rc = -EBADF;
 		goto out;
 	}
-	clcsk = smc->clcsock->sk;
-	if (smc->use_fallback)
-		goto out;
 	smc->use_fallback = true;
 	smc->fallback_rsn = reason_code;
 	smc_stat_fallback(smc);
 
@@ -810,18 +867,7 @@
 		 * in smc sk->sk_wq and they should be woken up
 		 * as clcsock's wait queue is woken up.
 		 */
-		smc->clcsk_state_change = clcsk->sk_state_change;
-		smc->clcsk_data_ready = clcsk->sk_data_ready;
-		smc->clcsk_write_space = clcsk->sk_write_space;
-		smc->clcsk_error_report = clcsk->sk_error_report;
-
-		clcsk->sk_state_change = smc_fback_state_change;
-		clcsk->sk_data_ready = smc_fback_data_ready;
-		clcsk->sk_write_space = smc_fback_write_space;
-		clcsk->sk_error_report = smc_fback_error_report;
-
-		smc->clcsock->sk->sk_user_data =
-			(void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY);
+		smc_fback_replace_callbacks(smc);
 	}
 out:
 	mutex_unlock(&smc->clcsock_release_lock);
@@ -2395,10 +2441,12 @@
 	/* save original sk_data_ready function and establish
 	 * smc-specific sk_data_ready function
 	 */
-	smc->clcsk_data_ready = smc->clcsock->sk->sk_data_ready;
-	smc->clcsock->sk->sk_data_ready = smc_clcsock_data_ready;
+	write_lock_bh(&smc->clcsock->sk->sk_callback_lock);
 	smc->clcsock->sk->sk_user_data =
 		(void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY);
+	smc_fback_replace_cb(&smc->clcsk_data_ready,
+			     &smc->clcsock->sk->sk_data_ready, smc_clcsock_data_ready);
+	write_unlock_bh(&smc->clcsock->sk->sk_callback_lock);
 
 	/* save original ops */
 	smc->ori_af_ops = inet_csk(smc->clcsock->sk)->icsk_af_ops;
@@ -2413,7 +2461,9 @@
 
 	rc = kernel_listen(smc->clcsock, backlog);
 	if (rc) {
+		write_lock_bh(&smc->clcsock->sk->sk_callback_lock);
 		smc->clcsock->sk->sk_data_ready = smc->clcsk_data_ready;
+		write_unlock_bh(&smc->clcsock->sk->sk_callback_lock);
 		goto out;
 	}
 	sk->sk_max_ack_backlog = backlog;
@@ -3092,6 +3142,7 @@
 	smc = smc_sk(sk);
 	smc->use_fallback = false; /* assume rdma capability first */
 	smc->fallback_rsn = 0;
+	smc_fback_init_saved_callbacks(smc);
 
 	/* default behavior from limit_smc_hs in every net namespace */
 	smc->limit_smc_hs = net->smc.limit_smc_hs;
diff --git a/net/smc/smc.h b/net/smc/smc.h
index ea06205..a371ec6 100644
--- a/net/smc/smc.h
+++ b/net/smc/smc.h
@@ -294,6 +294,15 @@ static inline struct smc_sock *smc_clcsock_user_data(const struct sock *clcsk)
 	       ((uintptr_t)clcsk->sk_user_data & ~SK_USER_DATA_NOCOPY);
 }
 
+static inline void smc_fback_replace_cb(void (**saved_cb)(struct sock *sk),
+	void (**orig_cb)(struct sock *sk), void (*new_cb)(struct sock *sk))
+{
+	/* only save once */
+	if (!*saved_cb)
+		*saved_cb = *orig_cb;
+	*orig_cb = new_cb;
+}
+
 extern struct workqueue_struct	*smc_hs_wq;	/* wq for handshake work */
 extern struct workqueue_struct	*smc_close_wq;	/* wq for close work */
 
diff --git a/net/smc/smc_close.c b/net/smc/smc_close.c
index 676cb23..f5ef3ee 100644
--- a/net/smc/smc_close.c
+++ b/net/smc/smc_close.c
@@ -214,8 +214,10 @@
 		sk->sk_state = SMC_CLOSED;
 		sk->sk_state_change(sk); /* wake up accept */
 		if (smc->clcsock && smc->clcsock->sk) {
+			write_lock_bh(&smc->clcsock->sk->sk_callback_lock);
 			smc->clcsock->sk->sk_data_ready = smc->clcsk_data_ready;
 			smc->clcsock->sk->sk_user_data = NULL;
+			write_unlock_bh(&smc->clcsock->sk->sk_callback_lock);
 			rc = kernel_sock_shutdown(smc->clcsock, SHUT_RDWR);
 		}
 		smc_close_cleanup_listen(sk);