diff --git a/net/l2tp/l2tp_core.c b/net/l2tp/l2tp_core.c index c80ab3f26084..c97cd0fd8514 100644 --- a/net/l2tp/l2tp_core.c +++ b/net/l2tp/l2tp_core.c @@ -150,15 +150,43 @@ static void l2tp_session_free(struct l2tp_session *session) kfree(session); } -struct l2tp_tunnel *l2tp_sk_to_tunnel(struct sock *sk) +static struct l2tp_tunnel *__l2tp_sk_to_tunnel(const struct sock *sk) { - struct l2tp_tunnel *tunnel = sk->sk_user_data; + const struct net *net = sock_net(sk); + unsigned long tunnel_id, tmp; + struct l2tp_tunnel *tunnel; + struct l2tp_net *pn; - if (tunnel) - if (WARN_ON(tunnel->magic != L2TP_TUNNEL_MAGIC)) - return NULL; + WARN_ON_ONCE(!rcu_read_lock_bh_held()); + pn = l2tp_pernet(net); + idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) { + if (tunnel && tunnel->sock == sk) + return tunnel; + } - return tunnel; + return NULL; +} + +struct l2tp_tunnel *l2tp_sk_to_tunnel(const struct sock *sk) +{ + const struct net *net = sock_net(sk); + unsigned long tunnel_id, tmp; + struct l2tp_tunnel *tunnel; + struct l2tp_net *pn; + + rcu_read_lock_bh(); + pn = l2tp_pernet(net); + idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) { + if (tunnel && + tunnel->sock == sk && + refcount_inc_not_zero(&tunnel->ref_count)) { + rcu_read_unlock_bh(); + return tunnel; + } + } + rcu_read_unlock_bh(); + + return NULL; } EXPORT_SYMBOL_GPL(l2tp_sk_to_tunnel); @@ -1213,8 +1241,10 @@ EXPORT_SYMBOL_GPL(l2tp_xmit_skb); */ static void l2tp_tunnel_destruct(struct sock *sk) { - struct l2tp_tunnel *tunnel = l2tp_sk_to_tunnel(sk); + struct l2tp_tunnel *tunnel; + rcu_read_lock_bh(); + tunnel = __l2tp_sk_to_tunnel(sk); if (!tunnel) goto end; @@ -1242,6 +1272,7 @@ static void l2tp_tunnel_destruct(struct sock *sk) kfree_rcu(tunnel, rcu); end: + rcu_read_unlock_bh(); return; } @@ -1308,10 +1339,13 @@ static void l2tp_tunnel_closeall(struct l2tp_tunnel *tunnel) /* Tunnel socket destroy hook for UDP encapsulation */ static void l2tp_udp_encap_destroy(struct sock *sk) { - struct l2tp_tunnel *tunnel = l2tp_sk_to_tunnel(sk); + struct l2tp_tunnel *tunnel; - if (tunnel) + tunnel = l2tp_sk_to_tunnel(sk); + if (tunnel) { l2tp_tunnel_delete(tunnel); + l2tp_tunnel_dec_refcount(tunnel); + } } static void l2tp_tunnel_remove(struct net *net, struct l2tp_tunnel *tunnel) diff --git a/net/l2tp/l2tp_core.h b/net/l2tp/l2tp_core.h index 8ac81bc1bc6f..a41cf6795df0 100644 --- a/net/l2tp/l2tp_core.h +++ b/net/l2tp/l2tp_core.h @@ -273,10 +273,7 @@ void l2tp_nl_unregister_ops(enum l2tp_pwtype pw_type); /* IOCTL helper for IP encap modules. */ int l2tp_ioctl(struct sock *sk, int cmd, int *karg); -/* Extract the tunnel structure from a socket's sk_user_data pointer, - * validating the tunnel magic feather. - */ -struct l2tp_tunnel *l2tp_sk_to_tunnel(struct sock *sk); +struct l2tp_tunnel *l2tp_sk_to_tunnel(const struct sock *sk); static inline int l2tp_get_l2specific_len(struct l2tp_session *session) { diff --git a/net/l2tp/l2tp_ip.c b/net/l2tp/l2tp_ip.c index e48aa177d74c..78243f993cda 100644 --- a/net/l2tp/l2tp_ip.c +++ b/net/l2tp/l2tp_ip.c @@ -235,14 +235,17 @@ static void l2tp_ip_close(struct sock *sk, long timeout) static void l2tp_ip_destroy_sock(struct sock *sk) { - struct l2tp_tunnel *tunnel = l2tp_sk_to_tunnel(sk); + struct l2tp_tunnel *tunnel; struct sk_buff *skb; while ((skb = __skb_dequeue_tail(&sk->sk_write_queue)) != NULL) kfree_skb(skb); - if (tunnel) + tunnel = l2tp_sk_to_tunnel(sk); + if (tunnel) { l2tp_tunnel_delete(tunnel); + l2tp_tunnel_dec_refcount(tunnel); + } } static int l2tp_ip_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len) diff --git a/net/l2tp/l2tp_ip6.c b/net/l2tp/l2tp_ip6.c index d217ff1f229e..3b0465f2d60d 100644 --- a/net/l2tp/l2tp_ip6.c +++ b/net/l2tp/l2tp_ip6.c @@ -246,14 +246,17 @@ static void l2tp_ip6_close(struct sock *sk, long timeout) static void l2tp_ip6_destroy_sock(struct sock *sk) { - struct l2tp_tunnel *tunnel = l2tp_sk_to_tunnel(sk); + struct l2tp_tunnel *tunnel; lock_sock(sk); ip6_flush_pending_frames(sk); release_sock(sk); - if (tunnel) + tunnel = l2tp_sk_to_tunnel(sk); + if (tunnel) { l2tp_tunnel_delete(tunnel); + l2tp_tunnel_dec_refcount(tunnel); + } } static int l2tp_ip6_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len)