bpf: generalize is_scalar_branch_taken() logic

Generalize is_branch_taken logic for SCALAR_VALUE register to handle
cases when both registers are not constants. Previously supported
<range> vs <scalar> cases are a natural subset of more generic <range>
vs <range> set of cases.

Generalized logic relies on straightforward segment intersection checks.

Acked-by: Eduard Zingerman <eddyz87@gmail.com>
Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
Acked-by: Shung-Hsi Yu <shung-hsi.yu@suse.com>
Link: https://lore.kernel.org/r/20231112010609.848406-3-andrii@kernel.org
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
This commit is contained in:
Andrii Nakryiko 2023-11-11 17:05:58 -08:00 committed by Alexei Starovoitov
parent 67420501e8
commit 96381879a3

View File

@ -14261,82 +14261,99 @@ static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_sta
u8 opcode, bool is_jmp32) u8 opcode, bool is_jmp32)
{ {
struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off; struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off;
struct tnum t2 = is_jmp32 ? tnum_subreg(reg2->var_off) : reg2->var_off;
u64 umin1 = is_jmp32 ? (u64)reg1->u32_min_value : reg1->umin_value; u64 umin1 = is_jmp32 ? (u64)reg1->u32_min_value : reg1->umin_value;
u64 umax1 = is_jmp32 ? (u64)reg1->u32_max_value : reg1->umax_value; u64 umax1 = is_jmp32 ? (u64)reg1->u32_max_value : reg1->umax_value;
s64 smin1 = is_jmp32 ? (s64)reg1->s32_min_value : reg1->smin_value; s64 smin1 = is_jmp32 ? (s64)reg1->s32_min_value : reg1->smin_value;
s64 smax1 = is_jmp32 ? (s64)reg1->s32_max_value : reg1->smax_value; s64 smax1 = is_jmp32 ? (s64)reg1->s32_max_value : reg1->smax_value;
u64 uval = is_jmp32 ? (u32)tnum_subreg(reg2->var_off).value : reg2->var_off.value; u64 umin2 = is_jmp32 ? (u64)reg2->u32_min_value : reg2->umin_value;
s64 sval = is_jmp32 ? (s32)uval : (s64)uval; u64 umax2 = is_jmp32 ? (u64)reg2->u32_max_value : reg2->umax_value;
s64 smin2 = is_jmp32 ? (s64)reg2->s32_min_value : reg2->smin_value;
s64 smax2 = is_jmp32 ? (s64)reg2->s32_max_value : reg2->smax_value;
switch (opcode) { switch (opcode) {
case BPF_JEQ: case BPF_JEQ:
if (tnum_is_const(t1)) /* constants, umin/umax and smin/smax checks would be
return !!tnum_equals_const(t1, uval); * redundant in this case because they all should match
else if (uval < umin1 || uval > umax1) */
if (tnum_is_const(t1) && tnum_is_const(t2))
return t1.value == t2.value;
/* non-overlapping ranges */
if (umin1 > umax2 || umax1 < umin2)
return 0; return 0;
else if (sval < smin1 || sval > smax1) if (smin1 > smax2 || smax1 < smin2)
return 0; return 0;
break; break;
case BPF_JNE: case BPF_JNE:
if (tnum_is_const(t1)) /* constants, umin/umax and smin/smax checks would be
return !tnum_equals_const(t1, uval); * redundant in this case because they all should match
else if (uval < umin1 || uval > umax1) */
if (tnum_is_const(t1) && tnum_is_const(t2))
return t1.value != t2.value;
/* non-overlapping ranges */
if (umin1 > umax2 || umax1 < umin2)
return 1; return 1;
else if (sval < smin1 || sval > smax1) if (smin1 > smax2 || smax1 < smin2)
return 1; return 1;
break; break;
case BPF_JSET: case BPF_JSET:
if ((~t1.mask & t1.value) & uval) if (!is_reg_const(reg2, is_jmp32)) {
swap(reg1, reg2);
swap(t1, t2);
}
if (!is_reg_const(reg2, is_jmp32))
return -1;
if ((~t1.mask & t1.value) & t2.value)
return 1; return 1;
if (!((t1.mask | t1.value) & uval)) if (!((t1.mask | t1.value) & t2.value))
return 0; return 0;
break; break;
case BPF_JGT: case BPF_JGT:
if (umin1 > uval ) if (umin1 > umax2)
return 1; return 1;
else if (umax1 <= uval) else if (umax1 <= umin2)
return 0; return 0;
break; break;
case BPF_JSGT: case BPF_JSGT:
if (smin1 > sval) if (smin1 > smax2)
return 1; return 1;
else if (smax1 <= sval) else if (smax1 <= smin2)
return 0; return 0;
break; break;
case BPF_JLT: case BPF_JLT:
if (umax1 < uval) if (umax1 < umin2)
return 1; return 1;
else if (umin1 >= uval) else if (umin1 >= umax2)
return 0; return 0;
break; break;
case BPF_JSLT: case BPF_JSLT:
if (smax1 < sval) if (smax1 < smin2)
return 1; return 1;
else if (smin1 >= sval) else if (smin1 >= smax2)
return 0; return 0;
break; break;
case BPF_JGE: case BPF_JGE:
if (umin1 >= uval) if (umin1 >= umax2)
return 1; return 1;
else if (umax1 < uval) else if (umax1 < umin2)
return 0; return 0;
break; break;
case BPF_JSGE: case BPF_JSGE:
if (smin1 >= sval) if (smin1 >= smax2)
return 1; return 1;
else if (smax1 < sval) else if (smax1 < smin2)
return 0; return 0;
break; break;
case BPF_JLE: case BPF_JLE:
if (umax1 <= uval) if (umax1 <= umin2)
return 1; return 1;
else if (umin1 > uval) else if (umin1 > umax2)
return 0; return 0;
break; break;
case BPF_JSLE: case BPF_JSLE:
if (smax1 <= sval) if (smax1 <= smin2)
return 1; return 1;
else if (smin1 > sval) else if (smin1 > smax2)
return 0; return 0;
break; break;
} }
@ -14415,28 +14432,28 @@ static int is_pkt_ptr_branch_taken(struct bpf_reg_state *dst_reg,
static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2, static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2,
u8 opcode, bool is_jmp32) u8 opcode, bool is_jmp32)
{ {
u64 val;
if (reg_is_pkt_pointer_any(reg1) && reg_is_pkt_pointer_any(reg2) && !is_jmp32) if (reg_is_pkt_pointer_any(reg1) && reg_is_pkt_pointer_any(reg2) && !is_jmp32)
return is_pkt_ptr_branch_taken(reg1, reg2, opcode); return is_pkt_ptr_branch_taken(reg1, reg2, opcode);
/* try to make sure reg2 is a constant SCALAR_VALUE */ if (__is_pointer_value(false, reg1) || __is_pointer_value(false, reg2)) {
u64 val;
/* arrange that reg2 is a scalar, and reg1 is a pointer */
if (!is_reg_const(reg2, is_jmp32)) { if (!is_reg_const(reg2, is_jmp32)) {
opcode = flip_opcode(opcode); opcode = flip_opcode(opcode);
swap(reg1, reg2); swap(reg1, reg2);
} }
/* for now we expect reg2 to be a constant to make any useful decisions */ /* and ensure that reg2 is a constant */
if (!is_reg_const(reg2, is_jmp32)) if (!is_reg_const(reg2, is_jmp32))
return -1; return -1;
val = reg_const_value(reg2, is_jmp32);
if (__is_pointer_value(false, reg1)) {
if (!reg_not_null(reg1)) if (!reg_not_null(reg1))
return -1; return -1;
/* If pointer is valid tests against zero will fail so we can /* If pointer is valid tests against zero will fail so we can
* use this to direct branch taken. * use this to direct branch taken.
*/ */
val = reg_const_value(reg2, is_jmp32);
if (val != 0) if (val != 0)
return -1; return -1;
@ -14450,6 +14467,7 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg
} }
} }
/* now deal with two scalars, but not necessarily constants */
return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32); return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32);
} }