bpf: sockmap/tls, close can race with map free

When a map free is called and in parallel a socket is closed we
have two paths that can potentially reset the socket prot ops, the
bpf close() path and the map free path. This creates a problem
with which prot ops should be used from the socket closed side.

If the map_free side completes first then we want to call the
original lowest level ops. However, if the tls path runs first
we want to call the sockmap ops. Additionally there was no locking
around prot updates in TLS code paths so the prot ops could
be changed multiple times once from TLS path and again from sockmap
side potentially leaving ops pointed at either TLS or sockmap
when psock and/or tls context have already been destroyed.

To fix this race first only update ops inside callback lock
so that TLS, sockmap and lowest level all agree on prot state.
Second and a ULP callback update() so that lower layers can
inform the upper layer when they are being removed allowing the
upper layer to reset prot ops.

This gets us close to allowing sockmap and tls to be stacked
in arbitrary order but will save that patch for *next trees.

v4:
 - make sure we don't free things for device;
 - remove the checks which swap the callbacks back
   only if TLS is at the top.

Reported-by: syzbot+06537213db7ba2745c4a@syzkaller.appspotmail.com
Fixes: 02c558b2d5d6 ("bpf: sockmap, support for msg_peek in sk_msg with redirect ingress")
Signed-off-by: John Fastabend <john.fastabend@gmail.com>
Signed-off-by: Jakub Kicinski <jakub.kicinski@netronome.com>
Reviewed-by: Dirk van der Merwe <dirk.vandermerwe@netronome.com>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
This commit is contained in:
John Fastabend 2019-07-19 10:29:22 -07:00 committed by Daniel Borkmann
parent 0e858739c2
commit 95fa145479
5 changed files with 53 additions and 8 deletions

View File

@ -354,7 +354,13 @@ static inline void sk_psock_restore_proto(struct sock *sk,
sk->sk_write_space = psock->saved_write_space; sk->sk_write_space = psock->saved_write_space;
if (psock->sk_proto) { if (psock->sk_proto) {
sk->sk_prot = psock->sk_proto; struct inet_connection_sock *icsk = inet_csk(sk);
bool has_ulp = !!icsk->icsk_ulp_data;
if (has_ulp)
tcp_update_ulp(sk, psock->sk_proto);
else
sk->sk_prot = psock->sk_proto;
psock->sk_proto = NULL; psock->sk_proto = NULL;
} }
} }

View File

@ -2103,6 +2103,8 @@ struct tcp_ulp_ops {
/* initialize ulp */ /* initialize ulp */
int (*init)(struct sock *sk); int (*init)(struct sock *sk);
/* update ulp */
void (*update)(struct sock *sk, struct proto *p);
/* cleanup ulp */ /* cleanup ulp */
void (*release)(struct sock *sk); void (*release)(struct sock *sk);
@ -2114,6 +2116,7 @@ void tcp_unregister_ulp(struct tcp_ulp_ops *type);
int tcp_set_ulp(struct sock *sk, const char *name); int tcp_set_ulp(struct sock *sk, const char *name);
void tcp_get_available_ulp(char *buf, size_t len); void tcp_get_available_ulp(char *buf, size_t len);
void tcp_cleanup_ulp(struct sock *sk); void tcp_cleanup_ulp(struct sock *sk);
void tcp_update_ulp(struct sock *sk, struct proto *p);
#define MODULE_ALIAS_TCP_ULP(name) \ #define MODULE_ALIAS_TCP_ULP(name) \
__MODULE_INFO(alias, alias_userspace, name); \ __MODULE_INFO(alias, alias_userspace, name); \

View File

@ -585,12 +585,12 @@ EXPORT_SYMBOL_GPL(sk_psock_destroy);
void sk_psock_drop(struct sock *sk, struct sk_psock *psock) void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
{ {
rcu_assign_sk_user_data(sk, NULL);
sk_psock_cork_free(psock); sk_psock_cork_free(psock);
sk_psock_zap_ingress(psock); sk_psock_zap_ingress(psock);
sk_psock_restore_proto(sk, psock);
write_lock_bh(&sk->sk_callback_lock); write_lock_bh(&sk->sk_callback_lock);
sk_psock_restore_proto(sk, psock);
rcu_assign_sk_user_data(sk, NULL);
if (psock->progs.skb_parser) if (psock->progs.skb_parser)
sk_psock_stop_strp(sk, psock); sk_psock_stop_strp(sk, psock);
write_unlock_bh(&sk->sk_callback_lock); write_unlock_bh(&sk->sk_callback_lock);

View File

@ -96,6 +96,19 @@ void tcp_get_available_ulp(char *buf, size_t maxlen)
rcu_read_unlock(); rcu_read_unlock();
} }
void tcp_update_ulp(struct sock *sk, struct proto *proto)
{
struct inet_connection_sock *icsk = inet_csk(sk);
if (!icsk->icsk_ulp_ops) {
sk->sk_prot = proto;
return;
}
if (icsk->icsk_ulp_ops->update)
icsk->icsk_ulp_ops->update(sk, proto);
}
void tcp_cleanup_ulp(struct sock *sk) void tcp_cleanup_ulp(struct sock *sk)
{ {
struct inet_connection_sock *icsk = inet_csk(sk); struct inet_connection_sock *icsk = inet_csk(sk);

View File

@ -328,7 +328,10 @@ static void tls_sk_proto_unhash(struct sock *sk)
ctx = tls_get_ctx(sk); ctx = tls_get_ctx(sk);
tls_sk_proto_cleanup(sk, ctx, timeo); tls_sk_proto_cleanup(sk, ctx, timeo);
write_lock_bh(&sk->sk_callback_lock);
icsk->icsk_ulp_data = NULL; icsk->icsk_ulp_data = NULL;
sk->sk_prot = ctx->sk_proto;
write_unlock_bh(&sk->sk_callback_lock);
if (ctx->sk_proto->unhash) if (ctx->sk_proto->unhash)
ctx->sk_proto->unhash(sk); ctx->sk_proto->unhash(sk);
@ -337,7 +340,7 @@ static void tls_sk_proto_unhash(struct sock *sk)
static void tls_sk_proto_close(struct sock *sk, long timeout) static void tls_sk_proto_close(struct sock *sk, long timeout)
{ {
void (*sk_proto_close)(struct sock *sk, long timeout); struct inet_connection_sock *icsk = inet_csk(sk);
struct tls_context *ctx = tls_get_ctx(sk); struct tls_context *ctx = tls_get_ctx(sk);
long timeo = sock_sndtimeo(sk, 0); long timeo = sock_sndtimeo(sk, 0);
bool free_ctx; bool free_ctx;
@ -347,12 +350,15 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
lock_sock(sk); lock_sock(sk);
free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW; free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;
sk_proto_close = ctx->sk_proto_close;
if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE) if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
tls_sk_proto_cleanup(sk, ctx, timeo); tls_sk_proto_cleanup(sk, ctx, timeo);
write_lock_bh(&sk->sk_callback_lock);
if (free_ctx)
icsk->icsk_ulp_data = NULL;
sk->sk_prot = ctx->sk_proto; sk->sk_prot = ctx->sk_proto;
write_unlock_bh(&sk->sk_callback_lock);
release_sock(sk); release_sock(sk);
if (ctx->tx_conf == TLS_SW) if (ctx->tx_conf == TLS_SW)
tls_sw_free_ctx_tx(ctx); tls_sw_free_ctx_tx(ctx);
@ -360,7 +366,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
tls_sw_strparser_done(ctx); tls_sw_strparser_done(ctx);
if (ctx->rx_conf == TLS_SW) if (ctx->rx_conf == TLS_SW)
tls_sw_free_ctx_rx(ctx); tls_sw_free_ctx_rx(ctx);
sk_proto_close(sk, timeout); ctx->sk_proto_close(sk, timeout);
if (free_ctx) if (free_ctx)
tls_ctx_free(ctx); tls_ctx_free(ctx);
@ -827,7 +833,7 @@ static int tls_init(struct sock *sk)
int rc = 0; int rc = 0;
if (tls_hw_prot(sk)) if (tls_hw_prot(sk))
goto out; return 0;
/* The TLS ulp is currently supported only for TCP sockets /* The TLS ulp is currently supported only for TCP sockets
* in ESTABLISHED state. * in ESTABLISHED state.
@ -838,22 +844,38 @@ static int tls_init(struct sock *sk)
if (sk->sk_state != TCP_ESTABLISHED) if (sk->sk_state != TCP_ESTABLISHED)
return -ENOTSUPP; return -ENOTSUPP;
tls_build_proto(sk);
/* allocate tls context */ /* allocate tls context */
write_lock_bh(&sk->sk_callback_lock);
ctx = create_ctx(sk); ctx = create_ctx(sk);
if (!ctx) { if (!ctx) {
rc = -ENOMEM; rc = -ENOMEM;
goto out; goto out;
} }
tls_build_proto(sk);
ctx->tx_conf = TLS_BASE; ctx->tx_conf = TLS_BASE;
ctx->rx_conf = TLS_BASE; ctx->rx_conf = TLS_BASE;
ctx->sk_proto = sk->sk_prot; ctx->sk_proto = sk->sk_prot;
update_sk_prot(sk, ctx); update_sk_prot(sk, ctx);
out: out:
write_unlock_bh(&sk->sk_callback_lock);
return rc; return rc;
} }
static void tls_update(struct sock *sk, struct proto *p)
{
struct tls_context *ctx;
ctx = tls_get_ctx(sk);
if (likely(ctx)) {
ctx->sk_proto_close = p->close;
ctx->sk_proto = p;
} else {
sk->sk_prot = p;
}
}
void tls_register_device(struct tls_device *device) void tls_register_device(struct tls_device *device)
{ {
spin_lock_bh(&device_spinlock); spin_lock_bh(&device_spinlock);
@ -874,6 +896,7 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
.name = "tls", .name = "tls",
.owner = THIS_MODULE, .owner = THIS_MODULE,
.init = tls_init, .init = tls_init,
.update = tls_update,
}; };
static int __init tls_register(void) static int __init tls_register(void)