diff --git a/include/net/xfrm.h b/include/net/xfrm.h index a0bdd58f401c..f5275618e744 100644 --- a/include/net/xfrm.h +++ b/include/net/xfrm.h @@ -188,6 +188,7 @@ struct xfrm_state { refcount_t refcnt; spinlock_t lock; + u32 pcpu_num; struct xfrm_id id; struct xfrm_selector sel; struct xfrm_mark mark; @@ -1684,7 +1685,7 @@ struct xfrmk_spdinfo { u32 spdhmcnt; }; -struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq); +struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq, u32 pcpu_num); int xfrm_state_delete(struct xfrm_state *x); int xfrm_state_flush(struct net *net, u8 proto, bool task_valid, bool sync); int xfrm_dev_state_flush(struct net *net, struct net_device *dev, bool task_valid); @@ -1796,7 +1797,7 @@ int verify_spi_info(u8 proto, u32 min, u32 max, struct netlink_ext_ack *extack); int xfrm_alloc_spi(struct xfrm_state *x, u32 minspi, u32 maxspi, struct netlink_ext_ack *extack); struct xfrm_state *xfrm_find_acq(struct net *net, const struct xfrm_mark *mark, - u8 mode, u32 reqid, u32 if_id, u8 proto, + u8 mode, u32 reqid, u32 if_id, u32 pcpu_num, u8 proto, const xfrm_address_t *daddr, const xfrm_address_t *saddr, int create, unsigned short family); diff --git a/include/uapi/linux/xfrm.h b/include/uapi/linux/xfrm.h index f28701500714..d73a97e3030a 100644 --- a/include/uapi/linux/xfrm.h +++ b/include/uapi/linux/xfrm.h @@ -322,6 +322,7 @@ enum xfrm_attr_type_t { XFRMA_MTIMER_THRESH, /* __u32 in seconds for input SA */ XFRMA_SA_DIR, /* __u8 */ XFRMA_NAT_KEEPALIVE_INTERVAL, /* __u32 in seconds for NAT keepalive */ + XFRMA_SA_PCPU, /* __u32 */ __XFRMA_MAX #define XFRMA_OUTPUT_MARK XFRMA_SET_MARK /* Compatibility */ @@ -437,6 +438,7 @@ struct xfrm_userpolicy_info { #define XFRM_POLICY_LOCALOK 1 /* Allow user to override global policy */ /* Automatically expand selector to include matching ICMP payloads. */ #define XFRM_POLICY_ICMP 2 +#define XFRM_POLICY_CPU_ACQUIRE 4 __u8 share; }; diff --git a/net/key/af_key.c b/net/key/af_key.c index f79fb99271ed..c56bb4f451e6 100644 --- a/net/key/af_key.c +++ b/net/key/af_key.c @@ -1354,7 +1354,7 @@ static int pfkey_getspi(struct sock *sk, struct sk_buff *skb, const struct sadb_ } if (hdr->sadb_msg_seq) { - x = xfrm_find_acq_byseq(net, DUMMY_MARK, hdr->sadb_msg_seq); + x = xfrm_find_acq_byseq(net, DUMMY_MARK, hdr->sadb_msg_seq, UINT_MAX); if (x && !xfrm_addr_equal(&x->id.daddr, xdaddr, family)) { xfrm_state_put(x); x = NULL; @@ -1362,7 +1362,8 @@ static int pfkey_getspi(struct sock *sk, struct sk_buff *skb, const struct sadb_ } if (!x) - x = xfrm_find_acq(net, &dummy_mark, mode, reqid, 0, proto, xdaddr, xsaddr, 1, family); + x = xfrm_find_acq(net, &dummy_mark, mode, reqid, 0, UINT_MAX, + proto, xdaddr, xsaddr, 1, family); if (x == NULL) return -ENOENT; @@ -1417,7 +1418,7 @@ static int pfkey_acquire(struct sock *sk, struct sk_buff *skb, const struct sadb if (hdr->sadb_msg_seq == 0 || hdr->sadb_msg_errno == 0) return 0; - x = xfrm_find_acq_byseq(net, DUMMY_MARK, hdr->sadb_msg_seq); + x = xfrm_find_acq_byseq(net, DUMMY_MARK, hdr->sadb_msg_seq, UINT_MAX); if (x == NULL) return 0; diff --git a/net/xfrm/xfrm_compat.c b/net/xfrm/xfrm_compat.c index 91357ccaf4af..5b9ee63e30b6 100644 --- a/net/xfrm/xfrm_compat.c +++ b/net/xfrm/xfrm_compat.c @@ -132,6 +132,7 @@ static const struct nla_policy compat_policy[XFRMA_MAX+1] = { [XFRMA_MTIMER_THRESH] = { .type = NLA_U32 }, [XFRMA_SA_DIR] = NLA_POLICY_RANGE(NLA_U8, XFRM_SA_DIR_IN, XFRM_SA_DIR_OUT), [XFRMA_NAT_KEEPALIVE_INTERVAL] = { .type = NLA_U32 }, + [XFRMA_SA_PCPU] = { .type = NLA_U32 }, }; static struct nlmsghdr *xfrm_nlmsg_put_compat(struct sk_buff *skb, @@ -282,9 +283,10 @@ static int xfrm_xlate64_attr(struct sk_buff *dst, const struct nlattr *src) case XFRMA_MTIMER_THRESH: case XFRMA_SA_DIR: case XFRMA_NAT_KEEPALIVE_INTERVAL: + case XFRMA_SA_PCPU: return xfrm_nla_cpy(dst, src, nla_len(src)); default: - BUILD_BUG_ON(XFRMA_MAX != XFRMA_NAT_KEEPALIVE_INTERVAL); + BUILD_BUG_ON(XFRMA_MAX != XFRMA_SA_PCPU); pr_warn_once("unsupported nla_type %d\n", src->nla_type); return -EOPNOTSUPP; } @@ -439,7 +441,7 @@ static int xfrm_xlate32_attr(void *dst, const struct nlattr *nla, int err; if (type > XFRMA_MAX) { - BUILD_BUG_ON(XFRMA_MAX != XFRMA_NAT_KEEPALIVE_INTERVAL); + BUILD_BUG_ON(XFRMA_MAX != XFRMA_SA_PCPU); NL_SET_ERR_MSG(extack, "Bad attribute"); return -EOPNOTSUPP; } diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c index 37478d36a8df..ebef07b80afa 100644 --- a/net/xfrm/xfrm_state.c +++ b/net/xfrm/xfrm_state.c @@ -679,6 +679,7 @@ struct xfrm_state *xfrm_state_alloc(struct net *net) x->lft.hard_packet_limit = XFRM_INF; x->replay_maxage = 0; x->replay_maxdiff = 0; + x->pcpu_num = UINT_MAX; spin_lock_init(&x->lock); } return x; @@ -1155,6 +1156,12 @@ static void xfrm_state_look_at(struct xfrm_policy *pol, struct xfrm_state *x, struct xfrm_state **best, int *acq_in_progress, int *error) { + /* We need the cpu id just as a lookup key, + * we don't require it to be stable. + */ + unsigned int pcpu_id = get_cpu(); + put_cpu(); + /* Resolution logic: * 1. There is a valid state with matching selector. Done. * 2. Valid state with inappropriate selector. Skip. @@ -1174,13 +1181,18 @@ static void xfrm_state_look_at(struct xfrm_policy *pol, struct xfrm_state *x, &fl->u.__fl_common)) return; + if (x->pcpu_num != UINT_MAX && x->pcpu_num != pcpu_id) + return; + if (!*best || + ((*best)->pcpu_num == UINT_MAX && x->pcpu_num == pcpu_id) || (*best)->km.dying > x->km.dying || ((*best)->km.dying == x->km.dying && (*best)->curlft.add_time < x->curlft.add_time)) *best = x; } else if (x->km.state == XFRM_STATE_ACQ) { - *acq_in_progress = 1; + if (!*best || x->pcpu_num == pcpu_id) + *acq_in_progress = 1; } else if (x->km.state == XFRM_STATE_ERROR || x->km.state == XFRM_STATE_EXPIRED) { if ((!x->sel.family || @@ -1209,6 +1221,13 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr, unsigned short encap_family = tmpl->encap_family; unsigned int sequence; struct km_event c; + unsigned int pcpu_id; + + /* We need the cpu id just as a lookup key, + * we don't require it to be stable. + */ + pcpu_id = get_cpu(); + put_cpu(); to_put = NULL; @@ -1282,7 +1301,10 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr, } found: - x = best; + if (!(pol->flags & XFRM_POLICY_CPU_ACQUIRE) || + (best && (best->pcpu_num == pcpu_id))) + x = best; + if (!x && !error && !acquire_in_progress) { if (tmpl->id.spi && (x0 = __xfrm_state_lookup_all(net, mark, daddr, @@ -1314,6 +1336,8 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr, xfrm_init_tempstate(x, fl, tmpl, daddr, saddr, family); memcpy(&x->mark, &pol->mark, sizeof(x->mark)); x->if_id = if_id; + if ((pol->flags & XFRM_POLICY_CPU_ACQUIRE) && best) + x->pcpu_num = pcpu_id; error = security_xfrm_state_alloc_acquire(x, pol->security, fl->flowi_secid); if (error) { @@ -1392,6 +1416,11 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr, x = NULL; error = -ESRCH; } + + /* Use the already installed 'fallback' while the CPU-specific + * SA acquire is handled*/ + if (best) + x = best; } out: if (x) { @@ -1524,12 +1553,14 @@ static void __xfrm_state_bump_genids(struct xfrm_state *xnew) unsigned int h; u32 mark = xnew->mark.v & xnew->mark.m; u32 if_id = xnew->if_id; + u32 cpu_id = xnew->pcpu_num; h = xfrm_dst_hash(net, &xnew->id.daddr, &xnew->props.saddr, reqid, family); hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) { if (x->props.family == family && x->props.reqid == reqid && x->if_id == if_id && + x->pcpu_num == cpu_id && (mark & x->mark.m) == x->mark.v && xfrm_addr_equal(&x->id.daddr, &xnew->id.daddr, family) && xfrm_addr_equal(&x->props.saddr, &xnew->props.saddr, family)) @@ -1552,7 +1583,7 @@ EXPORT_SYMBOL(xfrm_state_insert); static struct xfrm_state *__find_acq_core(struct net *net, const struct xfrm_mark *m, unsigned short family, u8 mode, - u32 reqid, u32 if_id, u8 proto, + u32 reqid, u32 if_id, u32 pcpu_num, u8 proto, const xfrm_address_t *daddr, const xfrm_address_t *saddr, int create) @@ -1569,6 +1600,7 @@ static struct xfrm_state *__find_acq_core(struct net *net, x->id.spi != 0 || x->id.proto != proto || (mark & x->mark.m) != x->mark.v || + x->pcpu_num != pcpu_num || !xfrm_addr_equal(&x->id.daddr, daddr, family) || !xfrm_addr_equal(&x->props.saddr, saddr, family)) continue; @@ -1602,6 +1634,7 @@ static struct xfrm_state *__find_acq_core(struct net *net, break; } + x->pcpu_num = pcpu_num; x->km.state = XFRM_STATE_ACQ; x->id.proto = proto; x->props.family = family; @@ -1630,7 +1663,7 @@ static struct xfrm_state *__find_acq_core(struct net *net, return x; } -static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq); +static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq, u32 pcpu_num); int xfrm_state_add(struct xfrm_state *x) { @@ -1656,7 +1689,7 @@ int xfrm_state_add(struct xfrm_state *x) } if (use_spi && x->km.seq) { - x1 = __xfrm_find_acq_byseq(net, mark, x->km.seq); + x1 = __xfrm_find_acq_byseq(net, mark, x->km.seq, x->pcpu_num); if (x1 && ((x1->id.proto != x->id.proto) || !xfrm_addr_equal(&x1->id.daddr, &x->id.daddr, family))) { to_put = x1; @@ -1666,7 +1699,7 @@ int xfrm_state_add(struct xfrm_state *x) if (use_spi && !x1) x1 = __find_acq_core(net, &x->mark, family, x->props.mode, - x->props.reqid, x->if_id, x->id.proto, + x->props.reqid, x->if_id, x->pcpu_num, x->id.proto, &x->id.daddr, &x->props.saddr, 0); __xfrm_state_bump_genids(x); @@ -1791,6 +1824,7 @@ static struct xfrm_state *xfrm_state_clone(struct xfrm_state *orig, x->props.flags = orig->props.flags; x->props.extra_flags = orig->props.extra_flags; + x->pcpu_num = orig->pcpu_num; x->if_id = orig->if_id; x->tfcpad = orig->tfcpad; x->replay_maxdiff = orig->replay_maxdiff; @@ -2066,13 +2100,14 @@ EXPORT_SYMBOL(xfrm_state_lookup_byaddr); struct xfrm_state * xfrm_find_acq(struct net *net, const struct xfrm_mark *mark, u8 mode, u32 reqid, - u32 if_id, u8 proto, const xfrm_address_t *daddr, + u32 if_id, u32 pcpu_num, u8 proto, const xfrm_address_t *daddr, const xfrm_address_t *saddr, int create, unsigned short family) { struct xfrm_state *x; spin_lock_bh(&net->xfrm.xfrm_state_lock); - x = __find_acq_core(net, mark, family, mode, reqid, if_id, proto, daddr, saddr, create); + x = __find_acq_core(net, mark, family, mode, reqid, if_id, pcpu_num, + proto, daddr, saddr, create); spin_unlock_bh(&net->xfrm.xfrm_state_lock); return x; @@ -2207,7 +2242,7 @@ xfrm_state_sort(struct xfrm_state **dst, struct xfrm_state **src, int n, /* Silly enough, but I'm lazy to build resolution list */ -static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq) +static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq, u32 pcpu_num) { unsigned int h = xfrm_seq_hash(net, seq); struct xfrm_state *x; @@ -2215,6 +2250,7 @@ static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 s hlist_for_each_entry_rcu(x, net->xfrm.state_byseq + h, byseq) { if (x->km.seq == seq && (mark & x->mark.m) == x->mark.v && + x->pcpu_num == pcpu_num && x->km.state == XFRM_STATE_ACQ) { xfrm_state_hold(x); return x; @@ -2224,12 +2260,12 @@ static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 s return NULL; } -struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq) +struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq, u32 pcpu_num) { struct xfrm_state *x; spin_lock_bh(&net->xfrm.xfrm_state_lock); - x = __xfrm_find_acq_byseq(net, mark, seq); + x = __xfrm_find_acq_byseq(net, mark, seq, pcpu_num); spin_unlock_bh(&net->xfrm.xfrm_state_lock); return x; } diff --git a/net/xfrm/xfrm_user.c b/net/xfrm/xfrm_user.c index e3b8ce89831a..e4d448950d05 100644 --- a/net/xfrm/xfrm_user.c +++ b/net/xfrm/xfrm_user.c @@ -460,6 +460,12 @@ static int verify_newsa_info(struct xfrm_usersa_info *p, } } + if (!sa_dir && attrs[XFRMA_SA_PCPU]) { + NL_SET_ERR_MSG(extack, "SA_PCPU only supported with SA_DIR"); + err = -EINVAL; + goto out; + } + out: return err; } @@ -841,6 +847,12 @@ static struct xfrm_state *xfrm_state_construct(struct net *net, x->nat_keepalive_interval = nla_get_u32(attrs[XFRMA_NAT_KEEPALIVE_INTERVAL]); + if (attrs[XFRMA_SA_PCPU]) { + x->pcpu_num = nla_get_u32(attrs[XFRMA_SA_PCPU]); + if (x->pcpu_num >= num_possible_cpus()) + goto error; + } + err = __xfrm_init_state(x, false, attrs[XFRMA_OFFLOAD_DEV], extack); if (err) goto error; @@ -1296,6 +1308,11 @@ static int copy_to_user_state_extra(struct xfrm_state *x, if (ret) goto out; } + if (x->pcpu_num != UINT_MAX) { + ret = nla_put_u32(skb, XFRMA_SA_PCPU, x->pcpu_num); + if (ret) + goto out; + } if (x->dir) ret = nla_put_u8(skb, XFRMA_SA_DIR, x->dir); @@ -1700,6 +1717,7 @@ static int xfrm_alloc_userspi(struct sk_buff *skb, struct nlmsghdr *nlh, u32 mark; struct xfrm_mark m; u32 if_id = 0; + u32 pcpu_num = UINT_MAX; p = nlmsg_data(nlh); err = verify_spi_info(p->info.id.proto, p->min, p->max, extack); @@ -1716,8 +1734,16 @@ static int xfrm_alloc_userspi(struct sk_buff *skb, struct nlmsghdr *nlh, if (attrs[XFRMA_IF_ID]) if_id = nla_get_u32(attrs[XFRMA_IF_ID]); + if (attrs[XFRMA_SA_PCPU]) { + pcpu_num = nla_get_u32(attrs[XFRMA_SA_PCPU]); + if (pcpu_num >= num_possible_cpus()) { + err = -EINVAL; + goto out_noput; + } + } + if (p->info.seq) { - x = xfrm_find_acq_byseq(net, mark, p->info.seq); + x = xfrm_find_acq_byseq(net, mark, p->info.seq, pcpu_num); if (x && !xfrm_addr_equal(&x->id.daddr, daddr, family)) { xfrm_state_put(x); x = NULL; @@ -1726,7 +1752,7 @@ static int xfrm_alloc_userspi(struct sk_buff *skb, struct nlmsghdr *nlh, if (!x) x = xfrm_find_acq(net, &m, p->info.mode, p->info.reqid, - if_id, p->info.id.proto, daddr, + if_id, pcpu_num, p->info.id.proto, daddr, &p->info.saddr, 1, family); err = -ENOENT; @@ -2526,7 +2552,8 @@ static inline unsigned int xfrm_aevent_msgsize(struct xfrm_state *x) + nla_total_size(sizeof(struct xfrm_mark)) + nla_total_size(4) /* XFRM_AE_RTHR */ + nla_total_size(4) /* XFRM_AE_ETHR */ - + nla_total_size(sizeof(x->dir)); /* XFRMA_SA_DIR */ + + nla_total_size(sizeof(x->dir)) /* XFRMA_SA_DIR */ + + nla_total_size(4); /* XFRMA_SA_PCPU */ } static int build_aevent(struct sk_buff *skb, struct xfrm_state *x, const struct km_event *c) @@ -2582,6 +2609,8 @@ static int build_aevent(struct sk_buff *skb, struct xfrm_state *x, const struct err = xfrm_if_id_put(skb, x->if_id); if (err) goto out_cancel; + if (x->pcpu_num != UINT_MAX) + err = nla_put_u32(skb, XFRMA_SA_PCPU, x->pcpu_num); if (x->dir) { err = nla_put_u8(skb, XFRMA_SA_DIR, x->dir); @@ -2852,6 +2881,13 @@ static int xfrm_add_acquire(struct sk_buff *skb, struct nlmsghdr *nlh, xfrm_mark_get(attrs, &mark); + if (attrs[XFRMA_SA_PCPU]) { + x->pcpu_num = nla_get_u32(attrs[XFRMA_SA_PCPU]); + err = -EINVAL; + if (x->pcpu_num >= num_possible_cpus()) + goto free_state; + } + err = verify_newpolicy_info(&ua->policy, extack); if (err) goto free_state; @@ -3182,6 +3218,7 @@ const struct nla_policy xfrma_policy[XFRMA_MAX+1] = { [XFRMA_MTIMER_THRESH] = { .type = NLA_U32 }, [XFRMA_SA_DIR] = NLA_POLICY_RANGE(NLA_U8, XFRM_SA_DIR_IN, XFRM_SA_DIR_OUT), [XFRMA_NAT_KEEPALIVE_INTERVAL] = { .type = NLA_U32 }, + [XFRMA_SA_PCPU] = { .type = NLA_U32 }, }; EXPORT_SYMBOL_GPL(xfrma_policy); @@ -3348,7 +3385,8 @@ static inline unsigned int xfrm_expire_msgsize(void) { return NLMSG_ALIGN(sizeof(struct xfrm_user_expire)) + nla_total_size(sizeof(struct xfrm_mark)) + - nla_total_size(sizeof_field(struct xfrm_state, dir)); + nla_total_size(sizeof_field(struct xfrm_state, dir)) + + nla_total_size(4); /* XFRMA_SA_PCPU */ } static int build_expire(struct sk_buff *skb, struct xfrm_state *x, const struct km_event *c) @@ -3374,6 +3412,11 @@ static int build_expire(struct sk_buff *skb, struct xfrm_state *x, const struct err = xfrm_if_id_put(skb, x->if_id); if (err) return err; + if (x->pcpu_num != UINT_MAX) { + err = nla_put_u32(skb, XFRMA_SA_PCPU, x->pcpu_num); + if (err) + return err; + } if (x->dir) { err = nla_put_u8(skb, XFRMA_SA_DIR, x->dir); @@ -3481,6 +3524,8 @@ static inline unsigned int xfrm_sa_len(struct xfrm_state *x) } if (x->if_id) l += nla_total_size(sizeof(x->if_id)); + if (x->pcpu_num) + l += nla_total_size(sizeof(x->pcpu_num)); /* Must count x->lastused as it may become non-zero behind our back. */ l += nla_total_size_64bit(sizeof(u64)); @@ -3587,6 +3632,7 @@ static inline unsigned int xfrm_acquire_msgsize(struct xfrm_state *x, + nla_total_size(sizeof(struct xfrm_user_tmpl) * xp->xfrm_nr) + nla_total_size(sizeof(struct xfrm_mark)) + nla_total_size(xfrm_user_sec_ctx_size(x->security)) + + nla_total_size(4) /* XFRMA_SA_PCPU */ + userpolicy_type_attrsize(); } @@ -3623,6 +3669,8 @@ static int build_acquire(struct sk_buff *skb, struct xfrm_state *x, err = xfrm_if_id_put(skb, xp->if_id); if (!err && xp->xdo.dev) err = copy_user_offload(&xp->xdo, skb); + if (!err && x->pcpu_num != UINT_MAX) + err = nla_put_u32(skb, XFRMA_SA_PCPU, x->pcpu_num); if (err) { nlmsg_cancel(skb, nlh); return err;