// SPDX-License-Identifier: GPL-2.0 /* Copyright 2025 Google LLC */ #include #include "sk_bypass_prot_mem.skel.h" #include "network_helpers.h" #define NR_PAGES 32 #define NR_SOCKETS 2 #define BUF_TOTAL (NR_PAGES * 4096 / NR_SOCKETS) #define BUF_SINGLE 1024 #define NR_SEND (BUF_TOTAL / BUF_SINGLE) struct test_case { char name[8]; int family; int type; int (*create_sockets)(struct test_case *test_case, int sk[], int len); long (*get_memory_allocated)(struct test_case *test_case, struct sk_bypass_prot_mem *skel); }; static int tcp_create_sockets(struct test_case *test_case, int sk[], int len) { int server, i, err = 0; server = start_server(test_case->family, test_case->type, NULL, 0, 0); if (!ASSERT_GE(server, 0, "start_server_str")) return server; /* Keep for-loop so we can change NR_SOCKETS easily. */ for (i = 0; i < len; i += 2) { sk[i] = connect_to_fd(server, 0); if (sk[i] < 0) { ASSERT_GE(sk[i], 0, "connect_to_fd"); err = sk[i]; break; } sk[i + 1] = accept(server, NULL, NULL); if (sk[i + 1] < 0) { ASSERT_GE(sk[i + 1], 0, "accept"); err = sk[i + 1]; break; } } close(server); return err; } static int udp_create_sockets(struct test_case *test_case, int sk[], int len) { int i, j, err, rcvbuf = BUF_TOTAL; /* Keep for-loop so we can change NR_SOCKETS easily. */ for (i = 0; i < len; i += 2) { sk[i] = start_server(test_case->family, test_case->type, NULL, 0, 0); if (sk[i] < 0) { ASSERT_GE(sk[i], 0, "start_server"); return sk[i]; } sk[i + 1] = connect_to_fd(sk[i], 0); if (sk[i + 1] < 0) { ASSERT_GE(sk[i + 1], 0, "connect_to_fd"); return sk[i + 1]; } err = connect_fd_to_fd(sk[i], sk[i + 1], 0); if (err) { ASSERT_EQ(err, 0, "connect_fd_to_fd"); return err; } for (j = 0; j < 2; j++) { err = setsockopt(sk[i + j], SOL_SOCKET, SO_RCVBUF, &rcvbuf, sizeof(int)); if (err) { ASSERT_EQ(err, 0, "setsockopt(SO_RCVBUF)"); return err; } } } return 0; } static long get_memory_allocated(struct test_case *test_case, bool *activated, long *memory_allocated) { int sk; *activated = true; /* AF_INET and AF_INET6 share the same memory_allocated. * tcp_init_sock() is called by AF_INET and AF_INET6, * but udp_lib_init_sock() is inline. */ sk = socket(AF_INET, test_case->type, 0); if (!ASSERT_GE(sk, 0, "get_memory_allocated")) return -1; close(sk); return *memory_allocated; } static long tcp_get_memory_allocated(struct test_case *test_case, struct sk_bypass_prot_mem *skel) { return get_memory_allocated(test_case, &skel->bss->tcp_activated, &skel->bss->tcp_memory_allocated); } static long udp_get_memory_allocated(struct test_case *test_case, struct sk_bypass_prot_mem *skel) { return get_memory_allocated(test_case, &skel->bss->udp_activated, &skel->bss->udp_memory_allocated); } static int check_bypass(struct test_case *test_case, struct sk_bypass_prot_mem *skel, bool bypass) { char buf[BUF_SINGLE] = {}; long memory_allocated[2]; int sk[NR_SOCKETS]; int err, i, j; for (i = 0; i < ARRAY_SIZE(sk); i++) sk[i] = -1; err = test_case->create_sockets(test_case, sk, ARRAY_SIZE(sk)); if (err) goto close; memory_allocated[0] = test_case->get_memory_allocated(test_case, skel); /* allocate pages >= NR_PAGES */ for (i = 0; i < ARRAY_SIZE(sk); i++) { for (j = 0; j < NR_SEND; j++) { int bytes = send(sk[i], buf, sizeof(buf), 0); /* Avoid too noisy logs when something failed. */ if (bytes != sizeof(buf)) { ASSERT_EQ(bytes, sizeof(buf), "send"); if (bytes < 0) { err = bytes; goto drain; } } } } memory_allocated[1] = test_case->get_memory_allocated(test_case, skel); if (bypass) ASSERT_LE(memory_allocated[1], memory_allocated[0] + 10, "bypass"); else ASSERT_GT(memory_allocated[1], memory_allocated[0] + NR_PAGES, "no bypass"); drain: if (test_case->type == SOCK_DGRAM) { /* UDP starts purging sk->sk_receive_queue after one RCU * grace period, then udp_memory_allocated goes down, * so drain the queue before close(). */ for (i = 0; i < ARRAY_SIZE(sk); i++) { for (j = 0; j < NR_SEND; j++) { int bytes = recv(sk[i], buf, 1, MSG_DONTWAIT | MSG_TRUNC); if (bytes == sizeof(buf)) continue; if (bytes != -1 || errno != EAGAIN) PRINT_FAIL("bytes: %d, errno: %s\n", bytes, strerror(errno)); break; } } } close: for (i = 0; i < ARRAY_SIZE(sk); i++) { if (sk[i] < 0) break; close(sk[i]); } return err; } static void run_test(struct test_case *test_case) { struct sk_bypass_prot_mem *skel; struct nstoken *nstoken; int cgroup, err; skel = sk_bypass_prot_mem__open_and_load(); if (!ASSERT_OK_PTR(skel, "open_and_load")) return; skel->bss->nr_cpus = libbpf_num_possible_cpus(); err = sk_bypass_prot_mem__attach(skel); if (!ASSERT_OK(err, "attach")) goto destroy_skel; cgroup = test__join_cgroup("/sk_bypass_prot_mem"); if (!ASSERT_GE(cgroup, 0, "join_cgroup")) goto destroy_skel; err = make_netns("sk_bypass_prot_mem"); if (!ASSERT_EQ(err, 0, "make_netns")) goto close_cgroup; nstoken = open_netns("sk_bypass_prot_mem"); if (!ASSERT_OK_PTR(nstoken, "open_netns")) goto remove_netns; err = check_bypass(test_case, skel, false); if (!ASSERT_EQ(err, 0, "test_bypass(false)")) goto close_netns; err = write_sysctl("/proc/sys/net/core/bypass_prot_mem", "1"); if (!ASSERT_EQ(err, 0, "write_sysctl(1)")) goto close_netns; err = check_bypass(test_case, skel, true); if (!ASSERT_EQ(err, 0, "test_bypass(true by sysctl)")) goto close_netns; err = write_sysctl("/proc/sys/net/core/bypass_prot_mem", "0"); if (!ASSERT_EQ(err, 0, "write_sysctl(0)")) goto close_netns; skel->links.sock_create = bpf_program__attach_cgroup(skel->progs.sock_create, cgroup); if (!ASSERT_OK_PTR(skel->links.sock_create, "attach_cgroup(sock_create)")) goto close_netns; err = check_bypass(test_case, skel, true); ASSERT_EQ(err, 0, "test_bypass(true by bpf)"); close_netns: close_netns(nstoken); remove_netns: remove_netns("sk_bypass_prot_mem"); close_cgroup: close(cgroup); destroy_skel: sk_bypass_prot_mem__destroy(skel); } static struct test_case test_cases[] = { { .name = "TCP ", .family = AF_INET, .type = SOCK_STREAM, .create_sockets = tcp_create_sockets, .get_memory_allocated = tcp_get_memory_allocated, }, { .name = "UDP ", .family = AF_INET, .type = SOCK_DGRAM, .create_sockets = udp_create_sockets, .get_memory_allocated = udp_get_memory_allocated, }, { .name = "TCPv6", .family = AF_INET6, .type = SOCK_STREAM, .create_sockets = tcp_create_sockets, .get_memory_allocated = tcp_get_memory_allocated, }, { .name = "UDPv6", .family = AF_INET6, .type = SOCK_DGRAM, .create_sockets = udp_create_sockets, .get_memory_allocated = udp_get_memory_allocated, }, }; void serial_test_sk_bypass_prot_mem(void) { int i; for (i = 0; i < ARRAY_SIZE(test_cases); i++) { if (test__start_subtest(test_cases[i].name)) run_test(&test_cases[i]); } }