--- x/net/core/sock_map.c
+++ y/net/core/sock_map.c
@@ -178,8 +178,10 @@ static void sock_map_del_link(struct soc
 		if (verdict_stop)
 			sk_psock_stop_verdict(sk, psock);
 
-		if (psock->psock_update_sk_prot)
+		if (psock->psock_update_sk_prot) {
 			psock->psock_update_sk_prot(sk, psock, false);
+			WARN_ON_ONCE(sk->sk_prot->close == sock_map_close);
+		}
 		write_unlock_bh(&sk->sk_callback_lock);
 	}
 }
@@ -196,10 +198,13 @@ static void sock_map_unref(struct sock *
 
 static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
 {
+	int rc;
 	if (!sk->sk_prot->psock_update_sk_prot)
 		return -EINVAL;
 	psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
-	return sk->sk_prot->psock_update_sk_prot(sk, psock, false);
+	rc = sk->sk_prot->psock_update_sk_prot(sk, psock, false);
+	WARN_ON_ONCE(sk->sk_prot->close == sock_map_close);
+	return rc;
 }
 
 static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)