sock_map: take a checked psock reference in unhash/destroy/close

sock_map_psock_get_checked() is reworked. It returns:
  * NULL when the socket has no psock and sk->sk_prot->close is not
    sock_map_close;
  * ERR_PTR(-EBUSY) when a psock exists but sk->sk_prot->close is not
    (yet) sock_map_close;
  * the psock with its refcount raised when sk->sk_prot->close is
    sock_map_close and refcount_inc_not_zero() succeeds;
  * otherwise ERR_PTR(-EBUSY), after briefly taking and releasing
    sk->sk_callback_lock. Per the in-code comment, a NULL psock or a zero
    refcount does not guarantee that sk_psock_restore_proto() has finished.
    The lock/unlock pair presumably waits for such a writer. NOTE(review):
    confirm that sk_psock_restore_proto() always runs under
    write_lock_bh(&sk->sk_callback_lock).

sock_map_unhash(), sock_map_destroy() and sock_map_close() now use this
helper instead of sk_psock()/sk_psock_get(). They treat both NULL and an
ERR_PTR as "no psock" and fall back to the callback in the current
sk->sk_prot. Other changes:
  * sock_map_unhash() now drops the reference it took via sk_psock_put().
  * sock_map_close() no longer needs the second sk_psock_get() or the
    no_psock label.

diff --git a/net/core/sock_map.c b/net/core/sock_map.c
index b0e96337a269..54bd5e3378c7 100644
--- a/net/core/sock_map.c
+++ b/net/core/sock_map.c
@@ -200,15 +200,25 @@ static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
 
 	rcu_read_lock();
 	psock = sk_psock(sk);
-	if (psock) {
-		if (sk->sk_prot->close != sock_map_close) {
-			psock = ERR_PTR(-EBUSY);
+
+	if (READ_ONCE(sk->sk_prot)->close != sock_map_close) {
+		if (likely(!psock))
 			goto out;
-		}
 
-		if (!refcount_inc_not_zero(&psock->refcnt))
-			psock = ERR_PTR(-EBUSY);
+		/* sock_map_init_proto() has not finished yet. */
+		psock = ERR_PTR(-EBUSY);
+		goto out;
 	}
+
+	if (likely(psock && refcount_inc_not_zero(&psock->refcnt)))
+		goto out;
+
+	/* sk_psock() being NULL or psock->refcnt being 0 does not
+	 * guarantee that sk_psock_restore_proto() has finished.
+	 */
+	read_lock_bh(&sk->sk_callback_lock);
+	read_unlock_bh(&sk->sk_callback_lock);
+	psock = ERR_PTR(-EBUSY);
 out:
 	rcu_read_unlock();
 	return psock;
@@ -1631,14 +1641,15 @@ void sock_map_unhash(struct sock *sk)
 	struct sk_psock *psock;
 
 	rcu_read_lock();
-	psock = sk_psock(sk);
-	if (unlikely(!psock)) {
+	psock = sock_map_psock_get_checked(sk);
+	if (IS_ERR_OR_NULL(psock)) {
 		rcu_read_unlock();
 		saved_unhash = READ_ONCE(sk->sk_prot)->unhash;
 	} else {
 		saved_unhash = psock->saved_unhash;
 		sock_map_remove_links(sk, psock);
 		rcu_read_unlock();
+		sk_psock_put(sk, psock);
 	}
 	if (WARN_ON_ONCE(saved_unhash == sock_map_unhash))
 		return;
@@ -1653,8 +1664,8 @@ void sock_map_destroy(struct sock *sk)
 	struct sk_psock *psock;
 
 	rcu_read_lock();
-	psock = sk_psock_get(sk);
-	if (unlikely(!psock)) {
+	psock = sock_map_psock_get_checked(sk);
+	if (IS_ERR_OR_NULL(psock)) {
 		rcu_read_unlock();
 		saved_destroy = READ_ONCE(sk->sk_prot)->destroy;
 	} else {
@@ -1678,13 +1689,10 @@ void sock_map_close(struct sock *sk, long timeout)
 
 	lock_sock(sk);
 	rcu_read_lock();
-	psock = sk_psock(sk);
-	if (likely(psock)) {
+	psock = sock_map_psock_get_checked(sk);
+	if (!IS_ERR_OR_NULL(psock)) {
 		saved_close = psock->saved_close;
 		sock_map_remove_links(sk, psock);
-		psock = sk_psock_get(sk);
-		if (unlikely(!psock))
-			goto no_psock;
 		rcu_read_unlock();
 		sk_psock_stop(psock);
 		release_sock(sk);
@@ -1692,7 +1700,6 @@ void sock_map_close(struct sock *sk, long timeout)
 		sk_psock_put(sk, psock);
 	} else {
 		saved_close = READ_ONCE(sk->sk_prot)->close;
-no_psock:
 		rcu_read_unlock();
 		release_sock(sk);
 	}