Use READ/WRITE_ONCE() for IP local_port_range.

Commit 227b60f510 added a seqlock to ensure that the low and high
port numbers were always updated together.
This is overkill because the two 16bit port numbers can be held in
a u32 and read/written in a single instruction.

More recently 91d0b78c51 added support for finer per-socket limits.
The user-supplied value is 'high << 16 | low' but they are held
separately and the socket options protected by the socket lock.

Use a u32 containing 'high << 16 | low' for both the 'net' and 'sk'
fields and use READ_ONCE()/WRITE_ONCE() to ensure both values are
always updated together.

Change (the now trival) inet_get_local_port_range() to a static inline
to optimise the calling code.
(In particular avoiding returning integers by reference.)

Signed-off-by: David Laight <david.laight@aculab.com>
Reviewed-by: Eric Dumazet <edumazet@google.com>
Reviewed-by: David Ahern <dsahern@kernel.org>
Acked-by: Mat Martineau <martineau@kernel.org>
Reviewed-by: Kuniyuki Iwashima <kuniyu@amazon.com>
Link: https://lore.kernel.org/r/4e505d4198e946a8be03fb1b4c3072b0@AcuMS.aculab.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
This commit is contained in:
David Laight 2023-12-06 13:44:20 +00:00 committed by Jakub Kicinski
parent 36b0bdb6d3
commit d9f28735af
7 changed files with 43 additions and 57 deletions

View File

@ -234,10 +234,7 @@ struct inet_sock {
int uc_index;
int mc_index;
__be32 mc_addr;
struct {
__u16 lo;
__u16 hi;
} local_port_range;
u32 local_port_range; /* high << 16 | low */
struct ip_mc_socklist __rcu *mc_list;
struct inet_cork_full cork;

View File

@ -349,7 +349,13 @@ static inline u64 snmp_fold_field64(void __percpu *mib, int offt, size_t syncp_o
} \
}
void inet_get_local_port_range(const struct net *net, int *low, int *high);
static inline void inet_get_local_port_range(const struct net *net, int *low, int *high)
{
u32 range = READ_ONCE(net->ipv4.ip_local_ports.range);
*low = range & 0xffff;
*high = range >> 16;
}
void inet_sk_get_local_port_range(const struct sock *sk, int *low, int *high);
#ifdef CONFIG_SYSCTL

View File

@ -19,8 +19,7 @@ struct hlist_head;
struct fib_table;
struct sock;
struct local_ports {
seqlock_t lock;
int range[2];
u32 range; /* high << 16 | low */
bool warned;
};

View File

@ -1847,9 +1847,7 @@ static __net_init int inet_init_net(struct net *net)
/*
* Set defaults for local port range
*/
seqlock_init(&net->ipv4.ip_local_ports.lock);
net->ipv4.ip_local_ports.range[0] = 32768;
net->ipv4.ip_local_ports.range[1] = 60999;
net->ipv4.ip_local_ports.range = 60999u << 16 | 32768u;
seqlock_init(&net->ipv4.ping_group_range.lock);
/*

View File

@ -117,34 +117,25 @@ bool inet_rcv_saddr_any(const struct sock *sk)
return !sk->sk_rcv_saddr;
}
void inet_get_local_port_range(const struct net *net, int *low, int *high)
{
unsigned int seq;
do {
seq = read_seqbegin(&net->ipv4.ip_local_ports.lock);
*low = net->ipv4.ip_local_ports.range[0];
*high = net->ipv4.ip_local_ports.range[1];
} while (read_seqretry(&net->ipv4.ip_local_ports.lock, seq));
}
EXPORT_SYMBOL(inet_get_local_port_range);
void inet_sk_get_local_port_range(const struct sock *sk, int *low, int *high)
{
const struct inet_sock *inet = inet_sk(sk);
const struct net *net = sock_net(sk);
int lo, hi, sk_lo, sk_hi;
u32 sk_range;
inet_get_local_port_range(net, &lo, &hi);
sk_lo = inet->local_port_range.lo;
sk_hi = inet->local_port_range.hi;
sk_range = READ_ONCE(inet->local_port_range);
if (unlikely(sk_range)) {
sk_lo = sk_range & 0xffff;
sk_hi = sk_range >> 16;
if (unlikely(lo <= sk_lo && sk_lo <= hi))
lo = sk_lo;
if (unlikely(lo <= sk_hi && sk_hi <= hi))
hi = sk_hi;
if (lo <= sk_lo && sk_lo <= hi)
lo = sk_lo;
if (lo <= sk_hi && sk_hi <= hi)
hi = sk_hi;
}
*low = lo;
*high = hi;

View File

@ -1055,6 +1055,19 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname,
case IP_TOS: /* This sets both TOS and Precedence */
ip_sock_set_tos(sk, val);
return 0;
case IP_LOCAL_PORT_RANGE:
{
u16 lo = val;
u16 hi = val >> 16;
if (optlen != sizeof(u32))
return -EINVAL;
if (lo != 0 && hi != 0 && lo > hi)
return -EINVAL;
WRITE_ONCE(inet->local_port_range, val);
return 0;
}
}
err = 0;
@ -1332,20 +1345,6 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname,
err = xfrm_user_policy(sk, optname, optval, optlen);
break;
case IP_LOCAL_PORT_RANGE:
{
const __u16 lo = val;
const __u16 hi = val >> 16;
if (optlen != sizeof(__u32))
goto e_inval;
if (lo != 0 && hi != 0 && lo > hi)
goto e_inval;
inet->local_port_range.lo = lo;
inet->local_port_range.hi = hi;
break;
}
default:
err = -ENOPROTOOPT;
break;
@ -1692,6 +1691,9 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
return -EFAULT;
return 0;
}
case IP_LOCAL_PORT_RANGE:
val = READ_ONCE(inet->local_port_range);
goto copyval;
}
if (needs_rtnl)
@ -1721,9 +1723,6 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
else
err = ip_get_mcast_msfilter(sk, optval, optlen, len);
goto out;
case IP_LOCAL_PORT_RANGE:
val = inet->local_port_range.hi << 16 | inet->local_port_range.lo;
break;
case IP_PROTOCOL:
val = inet_sk(sk)->inet_num;
break;

View File

@ -50,26 +50,22 @@ static int tcp_plb_max_cong_thresh = 256;
static int sysctl_tcp_low_latency __read_mostly;
/* Update system visible IP port range */
static void set_local_port_range(struct net *net, int range[2])
static void set_local_port_range(struct net *net, unsigned int low, unsigned int high)
{
bool same_parity = !((range[0] ^ range[1]) & 1);
bool same_parity = !((low ^ high) & 1);
write_seqlock_bh(&net->ipv4.ip_local_ports.lock);
if (same_parity && !net->ipv4.ip_local_ports.warned) {
net->ipv4.ip_local_ports.warned = true;
pr_err_ratelimited("ip_local_port_range: prefer different parity for start/end values.\n");
}
net->ipv4.ip_local_ports.range[0] = range[0];
net->ipv4.ip_local_ports.range[1] = range[1];
write_sequnlock_bh(&net->ipv4.ip_local_ports.lock);
WRITE_ONCE(net->ipv4.ip_local_ports.range, high << 16 | low);
}
/* Validate changes from /proc interface. */
static int ipv4_local_port_range(struct ctl_table *table, int write,
void *buffer, size_t *lenp, loff_t *ppos)
{
struct net *net =
container_of(table->data, struct net, ipv4.ip_local_ports.range);
struct net *net = table->data;
int ret;
int range[2];
struct ctl_table tmp = {
@ -93,7 +89,7 @@ static int ipv4_local_port_range(struct ctl_table *table, int write,
(range[0] < READ_ONCE(net->ipv4.sysctl_ip_prot_sock)))
ret = -EINVAL;
else
set_local_port_range(net, range);
set_local_port_range(net, range[0], range[1]);
}
return ret;
@ -733,8 +729,8 @@ static struct ctl_table ipv4_net_table[] = {
},
{
.procname = "ip_local_port_range",
.maxlen = sizeof(init_net.ipv4.ip_local_ports.range),
.data = &init_net.ipv4.ip_local_ports.range,
.maxlen = 0,
.data = &init_net,
.mode = 0644,
.proc_handler = ipv4_local_port_range,
},