diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h index c5a2d6f50f25..81bfa1a33623 100644 --- a/include/linux/skmsg.h +++ b/include/linux/skmsg.h @@ -277,7 +277,7 @@ static inline void sk_msg_sg_copy_clear(struct sk_msg *msg, u32 start) static inline struct sk_psock *sk_psock(const struct sock *sk) { - return rcu_dereference_sk_user_data(sk); + return rcu_dereference_sk_user_data_psock(sk); } static inline void sk_psock_set_state(struct sk_psock *psock, diff --git a/include/net/sock.h b/include/net/sock.h index 9fa54762e077..316c0313b2bf 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -549,10 +549,17 @@ enum sk_pacing { * when cloning the socket. For instance, it can point to a reference * counted object. sk_user_data bottom bit is set if pointer must not * be copied. + * + * SK_USER_DATA_NOCOPY - test if pointer must not copied + * SK_USER_DATA_BPF - managed by BPF + * SK_USER_DATA_NOTPSOCK - test if pointer points to psock */ #define SK_USER_DATA_NOCOPY 1UL -#define SK_USER_DATA_BPF 2UL /* Managed by BPF */ -#define SK_USER_DATA_PTRMASK ~(SK_USER_DATA_NOCOPY | SK_USER_DATA_BPF) +#define SK_USER_DATA_BPF 2UL +#define SK_USER_DATA_NOTPSOCK 4UL +#define SK_USER_DATA_PTRMASK ~(SK_USER_DATA_NOCOPY | SK_USER_DATA_BPF |\ + SK_USER_DATA_NOTPSOCK) + /** * sk_user_data_is_nocopy - Test if sk_user_data pointer must not be copied @@ -584,6 +591,22 @@ static inline bool sk_user_data_is_nocopy(const struct sock *sk) __tmp | SK_USER_DATA_NOCOPY); \ }) +/** + * rcu_dereference_sk_user_data_psock - return psock if sk_user_data points + * to the psock + * @sk: socket + */ +static inline +struct sk_psock *rcu_dereference_sk_user_data_psock(const struct sock *sk) +{ + uintptr_t __tmp = (uintptr_t)rcu_dereference(__sk_user_data((sk))); + + if (__tmp & SK_USER_DATA_NOTPSOCK) + return NULL; + return (struct sk_psock *)(__tmp & SK_USER_DATA_PTRMASK); +} + + static inline struct net *sock_net(const struct sock *sk) { diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c index 433bb5a7df31..d0feccf824c8 100644 --- a/net/smc/af_smc.c +++ b/net/smc/af_smc.c @@ -812,7 +812,8 @@ static void smc_fback_replace_callbacks(struct smc_sock *smc) struct sock *clcsk = smc->clcsock->sk; write_lock_bh(&clcsk->sk_callback_lock); - clcsk->sk_user_data = (void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY); + clcsk->sk_user_data = (void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY | + SK_USER_DATA_NOTPSOCK); smc_clcsock_replace_cb(&clcsk->sk_state_change, smc_fback_state_change, &smc->clcsk_state_change); @@ -2470,7 +2471,8 @@ static int smc_listen(struct socket *sock, int backlog) */ write_lock_bh(&smc->clcsock->sk->sk_callback_lock); smc->clcsock->sk->sk_user_data = - (void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY); + (void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY | + SK_USER_DATA_NOTPSOCK); smc_clcsock_replace_cb(&smc->clcsock->sk->sk_data_ready, smc_clcsock_data_ready, &smc->clcsk_data_ready); write_unlock_bh(&smc->clcsock->sk->sk_callback_lock); diff --git a/net/smc/smc.h b/net/smc/smc.h index 5ed765ea0c73..c24d0469d267 100644 --- a/net/smc/smc.h +++ b/net/smc/smc.h @@ -299,7 +299,7 @@ static inline void smc_init_saved_callbacks(struct smc_sock *smc) static inline struct smc_sock *smc_clcsock_user_data(const struct sock *clcsk) { return (struct smc_sock *) - ((uintptr_t)clcsk->sk_user_data & ~SK_USER_DATA_NOCOPY); + ((uintptr_t)clcsk->sk_user_data & SK_USER_DATA_PTRMASK); } /* save target_cb in saved_cb, and replace target_cb with new_cb */