summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJakub Kicinski <kuba@kernel.org>2025-10-17 17:20:41 -0700
committerJakub Kicinski <kuba@kernel.org>2025-10-17 17:20:42 -0700
commite90576829ce47a46838685153494bc12cd1bc333 (patch)
tree438eaa3d0ebae810f84532006c3eafd8ac580536
parent37a183d3b7cdb873e7f5f9daef1ad6d8f7c95fb7 (diff)
parent03de843bd0806184505f1e8099ff4ca9a8665dbb (diff)
Merge tag 'for-netdev' of https://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf-next
Martin KaFai Lau says: ==================== pull-request: bpf-next 2025-10-16 We've added 6 non-merge commits during the last 1 day(s) which contain a total of 18 files changed, 577 insertions(+), 38 deletions(-). The main changes are: 1) Bypass the global per-protocol memory accounting either by setting a netns sysctl or using bpf_setsockopt in a bpf program, from Kuniyuki Iwashima. * tag 'for-netdev' of https://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf-next: selftests/bpf: Add test for sk->sk_bypass_prot_mem. bpf: Introduce SK_BPF_BYPASS_PROT_MEM. bpf: Support bpf_setsockopt() for BPF_CGROUP_INET_SOCK_CREATE. net: Introduce net.core.bypass_prot_mem sysctl. net: Allow opt-out from global protocol memory accounting. tcp: Save lock_sock() for memcg in inet_csk_accept(). ==================== Link: https://patch.msgid.link/20251016204539.773707-1-martin.lau@linux.dev Signed-off-by: Jakub Kicinski <kuba@kernel.org>
-rw-r--r--Documentation/admin-guide/sysctl/net.rst8
-rw-r--r--include/net/netns/core.h1
-rw-r--r--include/net/proto_memory.h3
-rw-r--r--include/net/sock.h3
-rw-r--r--include/net/tcp.h3
-rw-r--r--include/uapi/linux/bpf.h2
-rw-r--r--net/core/filter.c85
-rw-r--r--net/core/sock.c37
-rw-r--r--net/core/sysctl_net_core.c9
-rw-r--r--net/ipv4/af_inet.c22
-rw-r--r--net/ipv4/inet_connection_sock.c25
-rw-r--r--net/ipv4/tcp.c3
-rw-r--r--net/ipv4/tcp_output.c7
-rw-r--r--net/mptcp/protocol.c7
-rw-r--r--net/tls/tls_device.c3
-rw-r--r--tools/include/uapi/linux/bpf.h1
-rw-r--r--tools/testing/selftests/bpf/prog_tests/sk_bypass_prot_mem.c292
-rw-r--r--tools/testing/selftests/bpf/progs/sk_bypass_prot_mem.c104
18 files changed, 577 insertions, 38 deletions
diff --git a/Documentation/admin-guide/sysctl/net.rst b/Documentation/admin-guide/sysctl/net.rst
index 40749b3cd356..991773dcb9cf 100644
--- a/Documentation/admin-guide/sysctl/net.rst
+++ b/Documentation/admin-guide/sysctl/net.rst
@@ -212,6 +212,14 @@ mem_pcpu_rsv
Per-cpu reserved forward alloc cache size in page units. Default 1MB per CPU.
+bypass_prot_mem
+---------------
+
+Skip charging socket buffers to the global per-protocol memory
+accounting controlled by net.ipv4.tcp_mem, net.ipv4.udp_mem, etc.
+
+Default: 0 (off)
+
rmem_default
------------
diff --git a/include/net/netns/core.h b/include/net/netns/core.h
index cb9c3e4cd738..9ef3d70e5e9c 100644
--- a/include/net/netns/core.h
+++ b/include/net/netns/core.h
@@ -17,6 +17,7 @@ struct netns_core {
int sysctl_optmem_max;
u8 sysctl_txrehash;
u8 sysctl_tstamp_allow_data;
+ u8 sysctl_bypass_prot_mem;
#ifdef CONFIG_PROC_FS
struct prot_inuse __percpu *prot_inuse;
diff --git a/include/net/proto_memory.h b/include/net/proto_memory.h
index 8e91a8fa31b5..ad6d703ce6fe 100644
--- a/include/net/proto_memory.h
+++ b/include/net/proto_memory.h
@@ -35,6 +35,9 @@ static inline bool sk_under_memory_pressure(const struct sock *sk)
mem_cgroup_sk_under_memory_pressure(sk))
return true;
+ if (sk->sk_bypass_prot_mem)
+ return false;
+
return !!READ_ONCE(*sk->sk_prot->memory_pressure);
}
diff --git a/include/net/sock.h b/include/net/sock.h
index 335d0da82d79..5c564f114ae9 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -118,6 +118,7 @@ typedef __u64 __bitwise __addrpair;
* @skc_reuseport: %SO_REUSEPORT setting
* @skc_ipv6only: socket is IPV6 only
* @skc_net_refcnt: socket is using net ref counting
+ * @skc_bypass_prot_mem: bypass the per-protocol memory accounting for skb
* @skc_bound_dev_if: bound device index if != 0
* @skc_bind_node: bind hash linkage for various protocol lookup tables
* @skc_portaddr_node: second hash linkage for UDP/UDP-Lite protocol
@@ -174,6 +175,7 @@ struct sock_common {
unsigned char skc_reuseport:1;
unsigned char skc_ipv6only:1;
unsigned char skc_net_refcnt:1;
+ unsigned char skc_bypass_prot_mem:1;
int skc_bound_dev_if;
union {
struct hlist_node skc_bind_node;
@@ -381,6 +383,7 @@ struct sock {
#define sk_reuseport __sk_common.skc_reuseport
#define sk_ipv6only __sk_common.skc_ipv6only
#define sk_net_refcnt __sk_common.skc_net_refcnt
+#define sk_bypass_prot_mem __sk_common.skc_bypass_prot_mem
#define sk_bound_dev_if __sk_common.skc_bound_dev_if
#define sk_bind_node __sk_common.skc_bind_node
#define sk_prot __sk_common.skc_prot
diff --git a/include/net/tcp.h b/include/net/tcp.h
index 67fdd2523d92..190b3714e93b 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -303,6 +303,9 @@ static inline bool tcp_under_memory_pressure(const struct sock *sk)
mem_cgroup_sk_under_memory_pressure(sk))
return true;
+ if (sk->sk_bypass_prot_mem)
+ return false;
+
return READ_ONCE(tcp_memory_pressure);
}
/*
diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
index 6829936d33f5..6eb75ad900b1 100644
--- a/include/uapi/linux/bpf.h
+++ b/include/uapi/linux/bpf.h
@@ -7200,6 +7200,8 @@ enum {
TCP_BPF_SYN_MAC = 1007, /* Copy the MAC, IP[46], and TCP header */
TCP_BPF_SOCK_OPS_CB_FLAGS = 1008, /* Get or Set TCP sock ops flags */
SK_BPF_CB_FLAGS = 1009, /* Get or set sock ops flags in socket */
+ SK_BPF_BYPASS_PROT_MEM = 1010, /* Get or Set sk->sk_bypass_prot_mem */
+
};
enum {
diff --git a/net/core/filter.c b/net/core/filter.c
index 76628df1fc82..16105f52927d 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -5733,6 +5733,77 @@ static const struct bpf_func_proto bpf_sock_addr_getsockopt_proto = {
.arg5_type = ARG_CONST_SIZE,
};
+static int sk_bpf_set_get_bypass_prot_mem(struct sock *sk,
+ char *optval, int optlen,
+ bool getopt)
+{
+ int val;
+
+ if (optlen != sizeof(int))
+ return -EINVAL;
+
+ if (!sk_has_account(sk))
+ return -EOPNOTSUPP;
+
+ if (getopt) {
+ *(int *)optval = sk->sk_bypass_prot_mem;
+ return 0;
+ }
+
+ val = *(int *)optval;
+ if (val < 0 || val > 1)
+ return -EINVAL;
+
+ sk->sk_bypass_prot_mem = val;
+ return 0;
+}
+
+BPF_CALL_5(bpf_sock_create_setsockopt, struct sock *, sk, int, level,
+ int, optname, char *, optval, int, optlen)
+{
+ if (level == SOL_SOCKET && optname == SK_BPF_BYPASS_PROT_MEM)
+ return sk_bpf_set_get_bypass_prot_mem(sk, optval, optlen, false);
+
+ return __bpf_setsockopt(sk, level, optname, optval, optlen);
+}
+
+static const struct bpf_func_proto bpf_sock_create_setsockopt_proto = {
+ .func = bpf_sock_create_setsockopt,
+ .gpl_only = false,
+ .ret_type = RET_INTEGER,
+ .arg1_type = ARG_PTR_TO_CTX,
+ .arg2_type = ARG_ANYTHING,
+ .arg3_type = ARG_ANYTHING,
+ .arg4_type = ARG_PTR_TO_MEM | MEM_RDONLY,
+ .arg5_type = ARG_CONST_SIZE,
+};
+
+BPF_CALL_5(bpf_sock_create_getsockopt, struct sock *, sk, int, level,
+ int, optname, char *, optval, int, optlen)
+{
+ if (level == SOL_SOCKET && optname == SK_BPF_BYPASS_PROT_MEM) {
+ int err = sk_bpf_set_get_bypass_prot_mem(sk, optval, optlen, true);
+
+ if (err)
+ memset(optval, 0, optlen);
+
+ return err;
+ }
+
+ return __bpf_getsockopt(sk, level, optname, optval, optlen);
+}
+
+static const struct bpf_func_proto bpf_sock_create_getsockopt_proto = {
+ .func = bpf_sock_create_getsockopt,
+ .gpl_only = false,
+ .ret_type = RET_INTEGER,
+ .arg1_type = ARG_PTR_TO_CTX,
+ .arg2_type = ARG_ANYTHING,
+ .arg3_type = ARG_ANYTHING,
+ .arg4_type = ARG_PTR_TO_UNINIT_MEM,
+ .arg5_type = ARG_CONST_SIZE,
+};
+
BPF_CALL_5(bpf_sock_ops_setsockopt, struct bpf_sock_ops_kern *, bpf_sock,
int, level, int, optname, char *, optval, int, optlen)
{
@@ -8062,6 +8133,20 @@ sock_filter_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
return &bpf_sk_storage_get_cg_sock_proto;
case BPF_FUNC_ktime_get_coarse_ns:
return &bpf_ktime_get_coarse_ns_proto;
+ case BPF_FUNC_setsockopt:
+ switch (prog->expected_attach_type) {
+ case BPF_CGROUP_INET_SOCK_CREATE:
+ return &bpf_sock_create_setsockopt_proto;
+ default:
+ return NULL;
+ }
+ case BPF_FUNC_getsockopt:
+ switch (prog->expected_attach_type) {
+ case BPF_CGROUP_INET_SOCK_CREATE:
+ return &bpf_sock_create_getsockopt_proto;
+ default:
+ return NULL;
+ }
default:
return bpf_base_func_proto(func_id, prog);
}
diff --git a/net/core/sock.c b/net/core/sock.c
index 08ae20069b6d..b78533fb9268 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -1046,9 +1046,13 @@ static int sock_reserve_memory(struct sock *sk, int bytes)
if (!charged)
return -ENOMEM;
+ if (sk->sk_bypass_prot_mem)
+ goto success;
+
/* pre-charge to forward_alloc */
sk_memory_allocated_add(sk, pages);
allocated = sk_memory_allocated(sk);
+
/* If the system goes into memory pressure with this
* precharge, give up and return error.
*/
@@ -1057,6 +1061,8 @@ static int sock_reserve_memory(struct sock *sk, int bytes)
mem_cgroup_sk_uncharge(sk, pages);
return -ENOMEM;
}
+
+success:
sk_forward_alloc_add(sk, pages << PAGE_SHIFT);
WRITE_ONCE(sk->sk_reserved_mem,
@@ -2300,8 +2306,13 @@ struct sock *sk_alloc(struct net *net, int family, gfp_t priority,
* why we need sk_prot_creator -acme
*/
sk->sk_prot = sk->sk_prot_creator = prot;
+
+ if (READ_ONCE(net->core.sysctl_bypass_prot_mem))
+ sk->sk_bypass_prot_mem = 1;
+
sk->sk_kern_sock = kern;
sock_lock_init(sk);
+
sk->sk_net_refcnt = kern ? 0 : 1;
if (likely(sk->sk_net_refcnt)) {
get_net_track(net, &sk->ns_tracker, priority);
@@ -3145,8 +3156,11 @@ bool sk_page_frag_refill(struct sock *sk, struct page_frag *pfrag)
if (likely(skb_page_frag_refill(32U, pfrag, sk->sk_allocation)))
return true;
- sk_enter_memory_pressure(sk);
+ if (!sk->sk_bypass_prot_mem)
+ sk_enter_memory_pressure(sk);
+
sk_stream_moderate_sndbuf(sk);
+
return false;
}
EXPORT_SYMBOL(sk_page_frag_refill);
@@ -3263,10 +3277,12 @@ int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind)
{
bool memcg_enabled = false, charged = false;
struct proto *prot = sk->sk_prot;
- long allocated;
+ long allocated = 0;
- sk_memory_allocated_add(sk, amt);
- allocated = sk_memory_allocated(sk);
+ if (!sk->sk_bypass_prot_mem) {
+ sk_memory_allocated_add(sk, amt);
+ allocated = sk_memory_allocated(sk);
+ }
if (mem_cgroup_sk_enabled(sk)) {
memcg_enabled = true;
@@ -3275,6 +3291,9 @@ int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind)
goto suppress_allocation;
}
+ if (!allocated)
+ return 1;
+
/* Under limit. */
if (allocated <= sk_prot_mem_limits(sk, 0)) {
sk_leave_memory_pressure(sk);
@@ -3353,7 +3372,8 @@ suppress_allocation:
trace_sock_exceed_buf_limit(sk, prot, allocated, kind);
- sk_memory_allocated_sub(sk, amt);
+ if (allocated)
+ sk_memory_allocated_sub(sk, amt);
if (charged)
mem_cgroup_sk_uncharge(sk, amt);
@@ -3392,11 +3412,14 @@ EXPORT_SYMBOL(__sk_mem_schedule);
*/
void __sk_mem_reduce_allocated(struct sock *sk, int amount)
{
- sk_memory_allocated_sub(sk, amount);
-
if (mem_cgroup_sk_enabled(sk))
mem_cgroup_sk_uncharge(sk, amount);
+ if (sk->sk_bypass_prot_mem)
+ return;
+
+ sk_memory_allocated_sub(sk, amount);
+
if (sk_under_global_memory_pressure(sk) &&
(sk_memory_allocated(sk) < sk_prot_mem_limits(sk, 0)))
sk_leave_memory_pressure(sk);
diff --git a/net/core/sysctl_net_core.c b/net/core/sysctl_net_core.c
index f79137826d7f..8d4decb2606f 100644
--- a/net/core/sysctl_net_core.c
+++ b/net/core/sysctl_net_core.c
@@ -683,6 +683,15 @@ static struct ctl_table netns_core_table[] = {
.extra1 = SYSCTL_ZERO,
.extra2 = SYSCTL_ONE
},
+ {
+ .procname = "bypass_prot_mem",
+ .data = &init_net.core.sysctl_bypass_prot_mem,
+ .maxlen = sizeof(u8),
+ .mode = 0644,
+ .proc_handler = proc_dou8vec_minmax,
+ .extra1 = SYSCTL_ZERO,
+ .extra2 = SYSCTL_ONE
+ },
/* sysctl_core_net_init() will set the values after this
* to readonly in network namespaces
*/
diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
index 3109c5ec38f3..e8771faa5bbf 100644
--- a/net/ipv4/af_inet.c
+++ b/net/ipv4/af_inet.c
@@ -755,6 +755,28 @@ EXPORT_SYMBOL(inet_stream_connect);
void __inet_accept(struct socket *sock, struct socket *newsock, struct sock *newsk)
{
+ /* TODO: use sk_clone_lock() in SCTP and remove protocol checks */
+ if (mem_cgroup_sockets_enabled &&
+ (!IS_ENABLED(CONFIG_IP_SCTP) || sk_is_tcp(newsk))) {
+ gfp_t gfp = GFP_KERNEL | __GFP_NOFAIL;
+
+ mem_cgroup_sk_alloc(newsk);
+
+ if (mem_cgroup_from_sk(newsk)) {
+ int amt;
+
+ /* The socket has not been accepted yet, no need
+ * to look at newsk->sk_wmem_queued.
+ */
+ amt = sk_mem_pages(newsk->sk_forward_alloc +
+ atomic_read(&newsk->sk_rmem_alloc));
+ if (amt)
+ mem_cgroup_sk_charge(newsk, amt, gfp);
+ }
+
+ kmem_cache_charge(newsk, gfp);
+ }
+
sock_rps_record_flow(newsk);
WARN_ON(!((1 << newsk->sk_state) &
(TCPF_ESTABLISHED | TCPF_SYN_RECV |
diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c
index cdd1e12aac8c..3b83b66b2284 100644
--- a/net/ipv4/inet_connection_sock.c
+++ b/net/ipv4/inet_connection_sock.c
@@ -712,31 +712,6 @@ struct sock *inet_csk_accept(struct sock *sk, struct proto_accept_arg *arg)
release_sock(sk);
- if (mem_cgroup_sockets_enabled) {
- gfp_t gfp = GFP_KERNEL | __GFP_NOFAIL;
- int amt = 0;
-
- /* atomically get the memory usage, set and charge the
- * newsk->sk_memcg.
- */
- lock_sock(newsk);
-
- mem_cgroup_sk_alloc(newsk);
- if (mem_cgroup_from_sk(newsk)) {
- /* The socket has not been accepted yet, no need
- * to look at newsk->sk_wmem_queued.
- */
- amt = sk_mem_pages(newsk->sk_forward_alloc +
- atomic_read(&newsk->sk_rmem_alloc));
- }
-
- if (amt)
- mem_cgroup_sk_charge(newsk, amt, gfp);
- kmem_cache_charge(newsk, gfp);
-
- release_sock(newsk);
- }
-
if (req)
reqsk_put(req);
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index 0ccc5405e740..e15b38f6bd2d 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -926,7 +926,8 @@ struct sk_buff *tcp_stream_alloc_skb(struct sock *sk, gfp_t gfp,
}
__kfree_skb(skb);
} else {
- sk->sk_prot->enter_memory_pressure(sk);
+ if (!sk->sk_bypass_prot_mem)
+ tcp_enter_memory_pressure(sk);
sk_stream_moderate_sndbuf(sk);
}
return NULL;
diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c
index b94efb3050d2..7f5df7a71f62 100644
--- a/net/ipv4/tcp_output.c
+++ b/net/ipv4/tcp_output.c
@@ -3743,12 +3743,17 @@ void sk_forced_mem_schedule(struct sock *sk, int size)
delta = size - sk->sk_forward_alloc;
if (delta <= 0)
return;
+
amt = sk_mem_pages(delta);
sk_forward_alloc_add(sk, amt << PAGE_SHIFT);
- sk_memory_allocated_add(sk, amt);
if (mem_cgroup_sk_enabled(sk))
mem_cgroup_sk_charge(sk, amt, gfp_memcg_charge() | __GFP_NOFAIL);
+
+ if (sk->sk_bypass_prot_mem)
+ return;
+
+ sk_memory_allocated_add(sk, amt);
}
/* Send a FIN. The caller locks the socket for us.
diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c
index 0292162a14ee..94a5f6dcc577 100644
--- a/net/mptcp/protocol.c
+++ b/net/mptcp/protocol.c
@@ -1065,11 +1065,12 @@ static void mptcp_enter_memory_pressure(struct sock *sk)
mptcp_for_each_subflow(msk, subflow) {
struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
- if (first)
+ if (first && !ssk->sk_bypass_prot_mem) {
tcp_enter_memory_pressure(ssk);
- sk_stream_moderate_sndbuf(ssk);
+ first = false;
+ }
- first = false;
+ sk_stream_moderate_sndbuf(ssk);
}
__mptcp_sync_sndbuf(sk);
}
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index a64ae15b1a60..caa2b5d24622 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -373,7 +373,8 @@ static int tls_do_allocation(struct sock *sk,
if (!offload_ctx->open_record) {
if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
sk->sk_allocation))) {
- READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk);
+ if (!sk->sk_bypass_prot_mem)
+ READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk);
sk_stream_moderate_sndbuf(sk);
return -ENOMEM;
}
diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h
index 6829936d33f5..9b17d937edf7 100644
--- a/tools/include/uapi/linux/bpf.h
+++ b/tools/include/uapi/linux/bpf.h
@@ -7200,6 +7200,7 @@ enum {
TCP_BPF_SYN_MAC = 1007, /* Copy the MAC, IP[46], and TCP header */
TCP_BPF_SOCK_OPS_CB_FLAGS = 1008, /* Get or Set TCP sock ops flags */
SK_BPF_CB_FLAGS = 1009, /* Get or set sock ops flags in socket */
+ SK_BPF_BYPASS_PROT_MEM = 1010, /* Get or Set sk->sk_bypass_prot_mem */
};
enum {
diff --git a/tools/testing/selftests/bpf/prog_tests/sk_bypass_prot_mem.c b/tools/testing/selftests/bpf/prog_tests/sk_bypass_prot_mem.c
new file mode 100644
index 000000000000..e4940583924b
--- /dev/null
+++ b/tools/testing/selftests/bpf/prog_tests/sk_bypass_prot_mem.c
@@ -0,0 +1,292 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Copyright 2025 Google LLC */
+
+#include <test_progs.h>
+#include "sk_bypass_prot_mem.skel.h"
+#include "network_helpers.h"
+
+#define NR_PAGES 32
+#define NR_SOCKETS 2
+#define BUF_TOTAL (NR_PAGES * 4096 / NR_SOCKETS)
+#define BUF_SINGLE 1024
+#define NR_SEND (BUF_TOTAL / BUF_SINGLE)
+
+struct test_case {
+ char name[8];
+ int family;
+ int type;
+ int (*create_sockets)(struct test_case *test_case, int sk[], int len);
+ long (*get_memory_allocated)(struct test_case *test_case, struct sk_bypass_prot_mem *skel);
+};
+
+static int tcp_create_sockets(struct test_case *test_case, int sk[], int len)
+{
+ int server, i, err = 0;
+
+ server = start_server(test_case->family, test_case->type, NULL, 0, 0);
+ if (!ASSERT_GE(server, 0, "start_server_str"))
+ return server;
+
+ /* Keep for-loop so we can change NR_SOCKETS easily. */
+ for (i = 0; i < len; i += 2) {
+ sk[i] = connect_to_fd(server, 0);
+ if (sk[i] < 0) {
+ ASSERT_GE(sk[i], 0, "connect_to_fd");
+ err = sk[i];
+ break;
+ }
+
+ sk[i + 1] = accept(server, NULL, NULL);
+ if (sk[i + 1] < 0) {
+ ASSERT_GE(sk[i + 1], 0, "accept");
+ err = sk[i + 1];
+ break;
+ }
+ }
+
+ close(server);
+
+ return err;
+}
+
+static int udp_create_sockets(struct test_case *test_case, int sk[], int len)
+{
+ int i, j, err, rcvbuf = BUF_TOTAL;
+
+ /* Keep for-loop so we can change NR_SOCKETS easily. */
+ for (i = 0; i < len; i += 2) {
+ sk[i] = start_server(test_case->family, test_case->type, NULL, 0, 0);
+ if (sk[i] < 0) {
+ ASSERT_GE(sk[i], 0, "start_server");
+ return sk[i];
+ }
+
+ sk[i + 1] = connect_to_fd(sk[i], 0);
+ if (sk[i + 1] < 0) {
+ ASSERT_GE(sk[i + 1], 0, "connect_to_fd");
+ return sk[i + 1];
+ }
+
+ err = connect_fd_to_fd(sk[i], sk[i + 1], 0);
+ if (err) {
+ ASSERT_EQ(err, 0, "connect_fd_to_fd");
+ return err;
+ }
+
+ for (j = 0; j < 2; j++) {
+ err = setsockopt(sk[i + j], SOL_SOCKET, SO_RCVBUF, &rcvbuf, sizeof(int));
+ if (err) {
+ ASSERT_EQ(err, 0, "setsockopt(SO_RCVBUF)");
+ return err;
+ }
+ }
+ }
+
+ return 0;
+}
+
+static long get_memory_allocated(struct test_case *test_case,
+ bool *activated, long *memory_allocated)
+{
+ int sk;
+
+ *activated = true;
+
+ /* AF_INET and AF_INET6 share the same memory_allocated.
+ * tcp_init_sock() is called by AF_INET and AF_INET6,
+ * but udp_lib_init_sock() is inline.
+ */
+ sk = socket(AF_INET, test_case->type, 0);
+ if (!ASSERT_GE(sk, 0, "get_memory_allocated"))
+ return -1;
+
+ close(sk);
+
+ return *memory_allocated;
+}
+
+static long tcp_get_memory_allocated(struct test_case *test_case, struct sk_bypass_prot_mem *skel)
+{
+ return get_memory_allocated(test_case,
+ &skel->bss->tcp_activated,
+ &skel->bss->tcp_memory_allocated);
+}
+
+static long udp_get_memory_allocated(struct test_case *test_case, struct sk_bypass_prot_mem *skel)
+{
+ return get_memory_allocated(test_case,
+ &skel->bss->udp_activated,
+ &skel->bss->udp_memory_allocated);
+}
+
+static int check_bypass(struct test_case *test_case,
+ struct sk_bypass_prot_mem *skel, bool bypass)
+{
+ char buf[BUF_SINGLE] = {};
+ long memory_allocated[2];
+ int sk[NR_SOCKETS];
+ int err, i, j;
+
+ for (i = 0; i < ARRAY_SIZE(sk); i++)
+ sk[i] = -1;
+
+ err = test_case->create_sockets(test_case, sk, ARRAY_SIZE(sk));
+ if (err)
+ goto close;
+
+ memory_allocated[0] = test_case->get_memory_allocated(test_case, skel);
+
+ /* allocate pages >= NR_PAGES */
+ for (i = 0; i < ARRAY_SIZE(sk); i++) {
+ for (j = 0; j < NR_SEND; j++) {
+ int bytes = send(sk[i], buf, sizeof(buf), 0);
+
+ /* Avoid too noisy logs when something failed. */
+ if (bytes != sizeof(buf)) {
+ ASSERT_EQ(bytes, sizeof(buf), "send");
+ if (bytes < 0) {
+ err = bytes;
+ goto drain;
+ }
+ }
+ }
+ }
+
+ memory_allocated[1] = test_case->get_memory_allocated(test_case, skel);
+
+ if (bypass)
+ ASSERT_LE(memory_allocated[1], memory_allocated[0] + 10, "bypass");
+ else
+ ASSERT_GT(memory_allocated[1], memory_allocated[0] + NR_PAGES, "no bypass");
+
+drain:
+ if (test_case->type == SOCK_DGRAM) {
+ /* UDP starts purging sk->sk_receive_queue after one RCU
+ * grace period, then udp_memory_allocated goes down,
+ * so drain the queue before close().
+ */
+ for (i = 0; i < ARRAY_SIZE(sk); i++) {
+ for (j = 0; j < NR_SEND; j++) {
+ int bytes = recv(sk[i], buf, 1, MSG_DONTWAIT | MSG_TRUNC);
+
+ if (bytes == sizeof(buf))
+ continue;
+ if (bytes != -1 || errno != EAGAIN)
+ PRINT_FAIL("bytes: %d, errno: %s\n", bytes, strerror(errno));
+ break;
+ }
+ }
+ }
+
+close:
+ for (i = 0; i < ARRAY_SIZE(sk); i++) {
+ if (sk[i] < 0)
+ break;
+
+ close(sk[i]);
+ }
+
+ return err;
+}
+
+static void run_test(struct test_case *test_case)
+{
+ struct sk_bypass_prot_mem *skel;
+ struct nstoken *nstoken;
+ int cgroup, err;
+
+ skel = sk_bypass_prot_mem__open_and_load();
+ if (!ASSERT_OK_PTR(skel, "open_and_load"))
+ return;
+
+ skel->bss->nr_cpus = libbpf_num_possible_cpus();
+
+ err = sk_bypass_prot_mem__attach(skel);
+ if (!ASSERT_OK(err, "attach"))
+ goto destroy_skel;
+
+ cgroup = test__join_cgroup("/sk_bypass_prot_mem");
+ if (!ASSERT_GE(cgroup, 0, "join_cgroup"))
+ goto destroy_skel;
+
+ err = make_netns("sk_bypass_prot_mem");
+ if (!ASSERT_EQ(err, 0, "make_netns"))
+ goto close_cgroup;
+
+ nstoken = open_netns("sk_bypass_prot_mem");
+ if (!ASSERT_OK_PTR(nstoken, "open_netns"))
+ goto remove_netns;
+
+ err = check_bypass(test_case, skel, false);
+ if (!ASSERT_EQ(err, 0, "test_bypass(false)"))
+ goto close_netns;
+
+ err = write_sysctl("/proc/sys/net/core/bypass_prot_mem", "1");
+ if (!ASSERT_EQ(err, 0, "write_sysctl(1)"))
+ goto close_netns;
+
+ err = check_bypass(test_case, skel, true);
+ if (!ASSERT_EQ(err, 0, "test_bypass(true by sysctl)"))
+ goto close_netns;
+
+ err = write_sysctl("/proc/sys/net/core/bypass_prot_mem", "0");
+ if (!ASSERT_EQ(err, 0, "write_sysctl(0)"))
+ goto close_netns;
+
+ skel->links.sock_create = bpf_program__attach_cgroup(skel->progs.sock_create, cgroup);
+ if (!ASSERT_OK_PTR(skel->links.sock_create, "attach_cgroup(sock_create)"))
+ goto close_netns;
+
+ err = check_bypass(test_case, skel, true);
+ ASSERT_EQ(err, 0, "test_bypass(true by bpf)");
+
+close_netns:
+ close_netns(nstoken);
+remove_netns:
+ remove_netns("sk_bypass_prot_mem");
+close_cgroup:
+ close(cgroup);
+destroy_skel:
+ sk_bypass_prot_mem__destroy(skel);
+}
+
+static struct test_case test_cases[] = {
+ {
+ .name = "TCP ",
+ .family = AF_INET,
+ .type = SOCK_STREAM,
+ .create_sockets = tcp_create_sockets,
+ .get_memory_allocated = tcp_get_memory_allocated,
+ },
+ {
+ .name = "UDP ",
+ .family = AF_INET,
+ .type = SOCK_DGRAM,
+ .create_sockets = udp_create_sockets,
+ .get_memory_allocated = udp_get_memory_allocated,
+ },
+ {
+ .name = "TCPv6",
+ .family = AF_INET6,
+ .type = SOCK_STREAM,
+ .create_sockets = tcp_create_sockets,
+ .get_memory_allocated = tcp_get_memory_allocated,
+ },
+ {
+ .name = "UDPv6",
+ .family = AF_INET6,
+ .type = SOCK_DGRAM,
+ .create_sockets = udp_create_sockets,
+ .get_memory_allocated = udp_get_memory_allocated,
+ },
+};
+
+void serial_test_sk_bypass_prot_mem(void)
+{
+ int i;
+
+ for (i = 0; i < ARRAY_SIZE(test_cases); i++) {
+ if (test__start_subtest(test_cases[i].name))
+ run_test(&test_cases[i]);
+ }
+}
diff --git a/tools/testing/selftests/bpf/progs/sk_bypass_prot_mem.c b/tools/testing/selftests/bpf/progs/sk_bypass_prot_mem.c
new file mode 100644
index 000000000000..09a00d11ffcc
--- /dev/null
+++ b/tools/testing/selftests/bpf/progs/sk_bypass_prot_mem.c
@@ -0,0 +1,104 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Copyright 2025 Google LLC */
+
+#include "bpf_tracing_net.h"
+#include <bpf/bpf_helpers.h>
+#include <bpf/bpf_tracing.h>
+#include <errno.h>
+
+extern int tcp_memory_per_cpu_fw_alloc __ksym;
+extern int udp_memory_per_cpu_fw_alloc __ksym;
+
+int nr_cpus;
+bool tcp_activated, udp_activated;
+long tcp_memory_allocated, udp_memory_allocated;
+
+struct sk_prot {
+ long *memory_allocated;
+ int *memory_per_cpu_fw_alloc;
+};
+
+static int drain_memory_per_cpu_fw_alloc(__u32 i, struct sk_prot *sk_prot_ctx)
+{
+ int *memory_per_cpu_fw_alloc;
+
+ memory_per_cpu_fw_alloc = bpf_per_cpu_ptr(sk_prot_ctx->memory_per_cpu_fw_alloc, i);
+ if (memory_per_cpu_fw_alloc)
+ *sk_prot_ctx->memory_allocated += *memory_per_cpu_fw_alloc;
+
+ return 0;
+}
+
+static long get_memory_allocated(struct sock *_sk, int *memory_per_cpu_fw_alloc)
+{
+ struct sock *sk = bpf_core_cast(_sk, struct sock);
+ struct sk_prot sk_prot_ctx;
+ long memory_allocated;
+
+ /* net_aligned_data.{tcp,udp}_memory_allocated was not available. */
+ memory_allocated = sk->__sk_common.skc_prot->memory_allocated->counter;
+
+ sk_prot_ctx.memory_allocated = &memory_allocated;
+ sk_prot_ctx.memory_per_cpu_fw_alloc = memory_per_cpu_fw_alloc;
+
+ bpf_loop(nr_cpus, drain_memory_per_cpu_fw_alloc, &sk_prot_ctx, 0);
+
+ return memory_allocated;
+}
+
+static void fentry_init_sock(struct sock *sk, bool *activated,
+ long *memory_allocated, int *memory_per_cpu_fw_alloc)
+{
+ if (!*activated)
+ return;
+
+ *memory_allocated = get_memory_allocated(sk, memory_per_cpu_fw_alloc);
+ *activated = false;
+}
+
+SEC("fentry/tcp_init_sock")
+int BPF_PROG(fentry_tcp_init_sock, struct sock *sk)
+{
+ fentry_init_sock(sk, &tcp_activated,
+ &tcp_memory_allocated, &tcp_memory_per_cpu_fw_alloc);
+ return 0;
+}
+
+SEC("fentry/udp_init_sock")
+int BPF_PROG(fentry_udp_init_sock, struct sock *sk)
+{
+ fentry_init_sock(sk, &udp_activated,
+ &udp_memory_allocated, &udp_memory_per_cpu_fw_alloc);
+ return 0;
+}
+
+SEC("cgroup/sock_create")
+int sock_create(struct bpf_sock *ctx)
+{
+ int err, val = 1;
+
+ err = bpf_setsockopt(ctx, SOL_SOCKET, SK_BPF_BYPASS_PROT_MEM,
+ &val, sizeof(val));
+ if (err)
+ goto err;
+
+ val = 0;
+
+ err = bpf_getsockopt(ctx, SOL_SOCKET, SK_BPF_BYPASS_PROT_MEM,
+ &val, sizeof(val));
+ if (err)
+ goto err;
+
+ if (val != 1) {
+ err = -EINVAL;
+ goto err;
+ }
+
+ return 1;
+
+err:
+ bpf_set_retval(err);
+ return 0;
+}
+
+char LICENSE[] SEC("license") = "GPL";