diff --git a/net/bpf/test_run.c b/net/bpf/test_run.c index 36ae54f57bf5..df14f4c76186 100644 --- a/net/bpf/test_run.c +++ b/net/bpf/test_run.c @@ -970,7 +970,7 @@ static struct proto bpf_dummy_proto = { int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr, union bpf_attr __user *uattr) { - bool is_l2 = false, is_direct_pkt_access = false; + bool is_l2 = false, is_direct_pkt_access = false, ctx_needed = true; struct net *net = current->nsproxy->net_ns; struct net_device *dev = net->loopback_dev; u32 size = kattr->test.data_size_in; @@ -998,6 +998,35 @@ int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr, return PTR_ERR(ctx); } + switch (prog->type) { + case BPF_PROG_TYPE_SOCKET_FILTER: + case BPF_PROG_TYPE_SCHED_CLS: + case BPF_PROG_TYPE_SCHED_ACT: + case BPF_PROG_TYPE_XDP: + case BPF_PROG_TYPE_CGROUP_SKB: + case BPF_PROG_TYPE_CGROUP_SOCK: + case BPF_PROG_TYPE_SOCK_OPS: + case BPF_PROG_TYPE_SK_SKB: + case BPF_PROG_TYPE_SK_MSG: + case BPF_PROG_TYPE_CGROUP_SOCK_ADDR: + case BPF_PROG_TYPE_LWT_SEG6LOCAL: + case BPF_PROG_TYPE_SK_REUSEPORT: + case BPF_PROG_TYPE_NETFILTER: + case BPF_PROG_TYPE_LWT_IN: + case BPF_PROG_TYPE_LWT_OUT: + case BPF_PROG_TYPE_LWT_XMIT: + ctx_needed = true; + break; + default: + ctx_needed = false; + } + + if (!ctx && ctx_needed) { + kfree(data); + kfree(ctx); + return -EINVAL; + } + switch (prog->type) { case BPF_PROG_TYPE_SCHED_CLS: case BPF_PROG_TYPE_SCHED_ACT: