summaryrefslogtreecommitdiff
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rw-r--r--net/psp/Kconfig1
-rw-r--r--net/psp/Makefile2
-rw-r--r--net/psp/psp-nl-gen.c39
-rw-r--r--net/psp/psp-nl-gen.h7
-rw-r--r--net/psp/psp.h22
-rw-r--r--net/psp/psp_main.c26
-rw-r--r--net/psp/psp_nl.c232
-rw-r--r--net/psp/psp_sock.c274
8 files changed, 601 insertions, 2 deletions
diff --git a/net/psp/Kconfig b/net/psp/Kconfig
index 5e3908a40945..a7d24691a7e1 100644
--- a/net/psp/Kconfig
+++ b/net/psp/Kconfig
@@ -6,6 +6,7 @@ config INET_PSP
bool "PSP Security Protocol support"
depends on INET
select SKB_DECRYPTED
+ select SOCK_VALIDATE_XMIT
help
Enable kernel support for the PSP protocol.
For more information see:
diff --git a/net/psp/Makefile b/net/psp/Makefile
index 41b51d06e560..eb5ff3c5bfb2 100644
--- a/net/psp/Makefile
+++ b/net/psp/Makefile
@@ -2,4 +2,4 @@
obj-$(CONFIG_INET_PSP) += psp.o
-psp-y := psp_main.o psp_nl.o psp-nl-gen.o
+psp-y := psp_main.o psp_nl.o psp_sock.o psp-nl-gen.o
diff --git a/net/psp/psp-nl-gen.c b/net/psp/psp-nl-gen.c
index 7f49577ac72f..9fdd6f831803 100644
--- a/net/psp/psp-nl-gen.c
+++ b/net/psp/psp-nl-gen.c
@@ -10,6 +10,12 @@
#include <uapi/linux/psp.h>
+/* Common nested types */
+const struct nla_policy psp_keys_nl_policy[PSP_A_KEYS_SPI + 1] = {
+ [PSP_A_KEYS_KEY] = { .type = NLA_BINARY, },
+ [PSP_A_KEYS_SPI] = { .type = NLA_U32, },
+};
+
/* PSP_CMD_DEV_GET - do */
static const struct nla_policy psp_dev_get_nl_policy[PSP_A_DEV_ID + 1] = {
[PSP_A_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1),
@@ -26,6 +32,21 @@ static const struct nla_policy psp_key_rotate_nl_policy[PSP_A_DEV_ID + 1] = {
[PSP_A_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1),
};
+/* PSP_CMD_RX_ASSOC - do */
+static const struct nla_policy psp_rx_assoc_nl_policy[PSP_A_ASSOC_SOCK_FD + 1] = {
+ [PSP_A_ASSOC_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1),
+ [PSP_A_ASSOC_VERSION] = NLA_POLICY_MAX(NLA_U32, 3),
+ [PSP_A_ASSOC_SOCK_FD] = { .type = NLA_U32, },
+};
+
+/* PSP_CMD_TX_ASSOC - do */
+static const struct nla_policy psp_tx_assoc_nl_policy[PSP_A_ASSOC_SOCK_FD + 1] = {
+ [PSP_A_ASSOC_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1),
+ [PSP_A_ASSOC_VERSION] = NLA_POLICY_MAX(NLA_U32, 3),
+ [PSP_A_ASSOC_TX_KEY] = NLA_POLICY_NESTED(psp_keys_nl_policy),
+ [PSP_A_ASSOC_SOCK_FD] = { .type = NLA_U32, },
+};
+
/* Ops table for psp */
static const struct genl_split_ops psp_nl_ops[] = {
{
@@ -60,6 +81,24 @@ static const struct genl_split_ops psp_nl_ops[] = {
.maxattr = PSP_A_DEV_ID,
.flags = GENL_CMD_CAP_DO,
},
+ {
+ .cmd = PSP_CMD_RX_ASSOC,
+ .pre_doit = psp_assoc_device_get_locked,
+ .doit = psp_nl_rx_assoc_doit,
+ .post_doit = psp_device_unlock,
+ .policy = psp_rx_assoc_nl_policy,
+ .maxattr = PSP_A_ASSOC_SOCK_FD,
+ .flags = GENL_CMD_CAP_DO,
+ },
+ {
+ .cmd = PSP_CMD_TX_ASSOC,
+ .pre_doit = psp_assoc_device_get_locked,
+ .doit = psp_nl_tx_assoc_doit,
+ .post_doit = psp_device_unlock,
+ .policy = psp_tx_assoc_nl_policy,
+ .maxattr = PSP_A_ASSOC_SOCK_FD,
+ .flags = GENL_CMD_CAP_DO,
+ },
};
static const struct genl_multicast_group psp_nl_mcgrps[] = {
diff --git a/net/psp/psp-nl-gen.h b/net/psp/psp-nl-gen.h
index 00a2d4ec59e4..25268ed11fb5 100644
--- a/net/psp/psp-nl-gen.h
+++ b/net/psp/psp-nl-gen.h
@@ -11,8 +11,13 @@
#include <uapi/linux/psp.h>
+/* Common nested types */
+extern const struct nla_policy psp_keys_nl_policy[PSP_A_KEYS_SPI + 1];
+
int psp_device_get_locked(const struct genl_split_ops *ops,
struct sk_buff *skb, struct genl_info *info);
+int psp_assoc_device_get_locked(const struct genl_split_ops *ops,
+ struct sk_buff *skb, struct genl_info *info);
void
psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb,
struct genl_info *info);
@@ -21,6 +26,8 @@ int psp_nl_dev_get_doit(struct sk_buff *skb, struct genl_info *info);
int psp_nl_dev_get_dumpit(struct sk_buff *skb, struct netlink_callback *cb);
int psp_nl_dev_set_doit(struct sk_buff *skb, struct genl_info *info);
int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info);
+int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info);
+int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info);
enum {
PSP_NLGRP_MGMT,
diff --git a/net/psp/psp.h b/net/psp/psp.h
index 94d0cc31a61f..defd3e3fd5e7 100644
--- a/net/psp/psp.h
+++ b/net/psp/psp.h
@@ -4,6 +4,7 @@
#define __PSP_PSP_H
#include <linux/list.h>
+#include <linux/lockdep.h>
#include <linux/mutex.h>
#include <net/netns/generic.h>
#include <net/psp.h>
@@ -17,15 +18,36 @@ int psp_dev_check_access(struct psp_dev *psd, struct net *net);
void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd);
+struct psp_assoc *psp_assoc_create(struct psp_dev *psd);
+struct psp_dev *psp_dev_get_for_sock(struct sock *sk);
+void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas);
+int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas,
+ struct psp_key_parsed *key,
+ struct netlink_ext_ack *extack);
+int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd,
+ u32 version, struct psp_key_parsed *key,
+ struct netlink_ext_ack *extack);
+
static inline void psp_dev_get(struct psp_dev *psd)
{
refcount_inc(&psd->refcnt);
}
+static inline bool psp_dev_tryget(struct psp_dev *psd)
+{
+ return refcount_inc_not_zero(&psd->refcnt);
+}
+
static inline void psp_dev_put(struct psp_dev *psd)
{
if (refcount_dec_and_test(&psd->refcnt))
psp_dev_destroy(psd);
}
+static inline bool psp_dev_is_registered(struct psp_dev *psd)
+{
+ lockdep_assert_held(&psd->lock);
+ return !!psd->ops;
+}
+
#endif /* __PSP_PSP_H */
diff --git a/net/psp/psp_main.c b/net/psp/psp_main.c
index f60155493afc..a1ae3c8920c3 100644
--- a/net/psp/psp_main.c
+++ b/net/psp/psp_main.c
@@ -55,7 +55,10 @@ psp_dev_create(struct net_device *netdev,
if (WARN_ON(!psd_caps->versions ||
!psd_ops->set_config ||
- !psd_ops->key_rotate))
+ !psd_ops->key_rotate ||
+ !psd_ops->rx_spi_alloc ||
+ !psd_ops->tx_key_add ||
+ !psd_ops->tx_key_del))
return ERR_PTR(-EINVAL);
psd = kzalloc(sizeof(*psd), GFP_KERNEL);
@@ -68,6 +71,7 @@ psp_dev_create(struct net_device *netdev,
psd->drv_priv = priv_ptr;
mutex_init(&psd->lock);
+ INIT_LIST_HEAD(&psd->active_assocs);
refcount_set(&psd->refcnt, 1);
mutex_lock(&psp_devs_lock);
@@ -107,6 +111,8 @@ void psp_dev_destroy(struct psp_dev *psd)
*/
void psp_dev_unregister(struct psp_dev *psd)
{
+ struct psp_assoc *pas, *next;
+
mutex_lock(&psp_devs_lock);
mutex_lock(&psd->lock);
@@ -119,6 +125,9 @@ void psp_dev_unregister(struct psp_dev *psd)
xa_store(&psp_devs, psd->id, NULL, GFP_KERNEL);
mutex_unlock(&psp_devs_lock);
+ list_for_each_entry_safe(pas, next, &psd->active_assocs, assocs_list)
+ psp_dev_tx_key_del(psd, pas);
+
rcu_assign_pointer(psd->main_netdev->psp_dev, NULL);
psd->ops = NULL;
@@ -130,6 +139,21 @@ void psp_dev_unregister(struct psp_dev *psd)
}
EXPORT_SYMBOL(psp_dev_unregister);
+unsigned int psp_key_size(u32 version)
+{
+ switch (version) {
+ case PSP_VERSION_HDR0_AES_GCM_128:
+ case PSP_VERSION_HDR0_AES_GMAC_128:
+ return 16;
+ case PSP_VERSION_HDR0_AES_GCM_256:
+ case PSP_VERSION_HDR0_AES_GMAC_256:
+ return 32;
+ default:
+ return 0;
+ }
+}
+EXPORT_SYMBOL(psp_key_size);
+
static int __init psp_init(void)
{
mutex_init(&psp_devs_lock);
diff --git a/net/psp/psp_nl.c b/net/psp/psp_nl.c
index 75f2702c1029..1b1d08fce637 100644
--- a/net/psp/psp_nl.c
+++ b/net/psp/psp_nl.c
@@ -79,9 +79,12 @@ void
psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb,
struct genl_info *info)
{
+ struct socket *socket = info->user_ptr[1];
struct psp_dev *psd = info->user_ptr[0];
mutex_unlock(&psd->lock);
+ if (socket)
+ sockfd_put(socket);
}
static int
@@ -261,3 +264,232 @@ err_free_rsp:
nlmsg_free(rsp);
return err;
}
+
+/* Key etc. */
+
+int psp_assoc_device_get_locked(const struct genl_split_ops *ops,
+ struct sk_buff *skb, struct genl_info *info)
+{
+ struct socket *socket;
+ struct psp_dev *psd;
+ struct nlattr *id;
+ int fd, err;
+
+ if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_SOCK_FD))
+ return -EINVAL;
+
+ fd = nla_get_u32(info->attrs[PSP_A_ASSOC_SOCK_FD]);
+ socket = sockfd_lookup(fd, &err);
+ if (!socket)
+ return err;
+
+ if (!sk_is_tcp(socket->sk)) {
+ NL_SET_ERR_MSG_ATTR(info->extack,
+ info->attrs[PSP_A_ASSOC_SOCK_FD],
+ "Unsupported socket family and type");
+ err = -EOPNOTSUPP;
+ goto err_sock_put;
+ }
+
+ psd = psp_dev_get_for_sock(socket->sk);
+ if (psd) {
+ err = psp_dev_check_access(psd, genl_info_net(info));
+ if (err) {
+ psp_dev_put(psd);
+ psd = NULL;
+ }
+ }
+
+ if (!psd && GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_DEV_ID)) {
+ err = -EINVAL;
+ goto err_sock_put;
+ }
+
+ id = info->attrs[PSP_A_ASSOC_DEV_ID];
+ if (psd) {
+ mutex_lock(&psd->lock);
+ if (id && psd->id != nla_get_u32(id)) {
+ mutex_unlock(&psd->lock);
+ NL_SET_ERR_MSG_ATTR(info->extack, id,
+ "Device id vs socket mismatch");
+ err = -EINVAL;
+ goto err_psd_put;
+ }
+
+ psp_dev_put(psd);
+ } else {
+ psd = psp_device_get_and_lock(genl_info_net(info), id);
+ if (IS_ERR(psd)) {
+ err = PTR_ERR(psd);
+ goto err_sock_put;
+ }
+ }
+
+ info->user_ptr[0] = psd;
+ info->user_ptr[1] = socket;
+
+ return 0;
+
+err_psd_put:
+ psp_dev_put(psd);
+err_sock_put:
+ sockfd_put(socket);
+ return err;
+}
+
+static int
+psp_nl_parse_key(struct genl_info *info, u32 attr, struct psp_key_parsed *key,
+ unsigned int key_sz)
+{
+ struct nlattr *nest = info->attrs[attr];
+ struct nlattr *tb[PSP_A_KEYS_SPI + 1];
+ u32 spi;
+ int err;
+
+ err = nla_parse_nested(tb, ARRAY_SIZE(tb) - 1, nest,
+ psp_keys_nl_policy, info->extack);
+ if (err)
+ return err;
+
+ if (NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_KEY) ||
+ NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_SPI))
+ return -EINVAL;
+
+ if (nla_len(tb[PSP_A_KEYS_KEY]) != key_sz) {
+ NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY],
+ "incorrect key length");
+ return -EINVAL;
+ }
+
+ spi = nla_get_u32(tb[PSP_A_KEYS_SPI]);
+ if (!(spi & PSP_SPI_KEY_ID)) {
+ NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY],
+ "invalid SPI: lower 31b must be non-zero");
+ return -EINVAL;
+ }
+
+ key->spi = cpu_to_be32(spi);
+ memcpy(key->key, nla_data(tb[PSP_A_KEYS_KEY]), key_sz);
+
+ return 0;
+}
+
+static int
+psp_nl_put_key(struct sk_buff *skb, u32 attr, u32 version,
+ struct psp_key_parsed *key)
+{
+ int key_sz = psp_key_size(version);
+ void *nest;
+
+ nest = nla_nest_start(skb, attr);
+
+ if (nla_put_u32(skb, PSP_A_KEYS_SPI, be32_to_cpu(key->spi)) ||
+ nla_put(skb, PSP_A_KEYS_KEY, key_sz, key->key)) {
+ nla_nest_cancel(skb, nest);
+ return -EMSGSIZE;
+ }
+
+ nla_nest_end(skb, nest);
+
+ return 0;
+}
+
+int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info)
+{
+ struct socket *socket = info->user_ptr[1];
+ struct psp_dev *psd = info->user_ptr[0];
+ struct psp_key_parsed key;
+ struct psp_assoc *pas;
+ struct sk_buff *rsp;
+ u32 version;
+ int err;
+
+ if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION))
+ return -EINVAL;
+
+ version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]);
+ if (!(psd->caps->versions & (1 << version))) {
+ NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]);
+ return -EOPNOTSUPP;
+ }
+
+ rsp = psp_nl_reply_new(info);
+ if (!rsp)
+ return -ENOMEM;
+
+ pas = psp_assoc_create(psd);
+ if (!pas) {
+ err = -ENOMEM;
+ goto err_free_rsp;
+ }
+ pas->version = version;
+
+ err = psd->ops->rx_spi_alloc(psd, version, &key, info->extack);
+ if (err)
+ goto err_free_pas;
+
+ if (nla_put_u32(rsp, PSP_A_ASSOC_DEV_ID, psd->id) ||
+ psp_nl_put_key(rsp, PSP_A_ASSOC_RX_KEY, version, &key)) {
+ err = -EMSGSIZE;
+ goto err_free_pas;
+ }
+
+ err = psp_sock_assoc_set_rx(socket->sk, pas, &key, info->extack);
+ if (err) {
+ NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_SOCK_FD]);
+ goto err_free_pas;
+ }
+ psp_assoc_put(pas);
+
+ return psp_nl_reply_send(rsp, info);
+
+err_free_pas:
+ psp_assoc_put(pas);
+err_free_rsp:
+ nlmsg_free(rsp);
+ return err;
+}
+
+int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info)
+{
+ struct socket *socket = info->user_ptr[1];
+ struct psp_dev *psd = info->user_ptr[0];
+ struct psp_key_parsed key;
+ struct sk_buff *rsp;
+ unsigned int key_sz;
+ u32 version;
+ int err;
+
+ if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION) ||
+ GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_TX_KEY))
+ return -EINVAL;
+
+ version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]);
+ if (!(psd->caps->versions & (1 << version))) {
+ NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]);
+ return -EOPNOTSUPP;
+ }
+
+ key_sz = psp_key_size(version);
+ if (!key_sz)
+ return -EINVAL;
+
+ err = psp_nl_parse_key(info, PSP_A_ASSOC_TX_KEY, &key, key_sz);
+ if (err < 0)
+ return err;
+
+ rsp = psp_nl_reply_new(info);
+ if (!rsp)
+ return -ENOMEM;
+
+ err = psp_sock_assoc_set_tx(socket->sk, psd, version, &key,
+ info->extack);
+ if (err)
+ goto err_free_msg;
+
+ return psp_nl_reply_send(rsp, info);
+
+err_free_msg:
+ nlmsg_free(rsp);
+ return err;
+}
diff --git a/net/psp/psp_sock.c b/net/psp/psp_sock.c
new file mode 100644
index 000000000000..8ebccee94593
--- /dev/null
+++ b/net/psp/psp_sock.c
@@ -0,0 +1,274 @@
+// SPDX-License-Identifier: GPL-2.0-only
+
+#include <linux/file.h>
+#include <linux/net.h>
+#include <linux/rcupdate.h>
+#include <linux/tcp.h>
+
+#include <net/ip.h>
+#include <net/psp.h>
+#include "psp.h"
+
+struct psp_dev *psp_dev_get_for_sock(struct sock *sk)
+{
+ struct dst_entry *dst;
+ struct psp_dev *psd;
+
+ dst = sk_dst_get(sk);
+ if (!dst)
+ return NULL;
+
+ rcu_read_lock();
+ psd = rcu_dereference(dst->dev->psp_dev);
+ if (psd && !psp_dev_tryget(psd))
+ psd = NULL;
+ rcu_read_unlock();
+
+ dst_release(dst);
+
+ return psd;
+}
+
+static struct sk_buff *
+psp_validate_xmit(struct sock *sk, struct net_device *dev, struct sk_buff *skb)
+{
+ struct psp_assoc *pas;
+ bool good;
+
+ rcu_read_lock();
+ pas = psp_skb_get_assoc_rcu(skb);
+ good = !pas || rcu_access_pointer(dev->psp_dev) == pas->psd;
+ rcu_read_unlock();
+ if (!good) {
+ kfree_skb_reason(skb, SKB_DROP_REASON_PSP_OUTPUT);
+ return NULL;
+ }
+
+ return skb;
+}
+
+struct psp_assoc *psp_assoc_create(struct psp_dev *psd)
+{
+ struct psp_assoc *pas;
+
+ lockdep_assert_held(&psd->lock);
+
+ pas = kzalloc(struct_size(pas, drv_data, psd->caps->assoc_drv_spc),
+ GFP_KERNEL_ACCOUNT);
+ if (!pas)
+ return NULL;
+
+ pas->psd = psd;
+ pas->dev_id = psd->id;
+ psp_dev_get(psd);
+ refcount_set(&pas->refcnt, 1);
+
+ list_add_tail(&pas->assocs_list, &psd->active_assocs);
+
+ return pas;
+}
+
+static struct psp_assoc *psp_assoc_dummy(struct psp_assoc *pas)
+{
+ struct psp_dev *psd = pas->psd;
+ size_t sz;
+
+ lockdep_assert_held(&psd->lock);
+
+ sz = struct_size(pas, drv_data, psd->caps->assoc_drv_spc);
+ return kmemdup(pas, sz, GFP_KERNEL);
+}
+
+static int psp_dev_tx_key_add(struct psp_dev *psd, struct psp_assoc *pas,
+ struct netlink_ext_ack *extack)
+{
+ return psd->ops->tx_key_add(psd, pas, extack);
+}
+
+void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas)
+{
+ if (pas->tx.spi)
+ psd->ops->tx_key_del(psd, pas);
+ list_del(&pas->assocs_list);
+}
+
+static void psp_assoc_free(struct work_struct *work)
+{
+ struct psp_assoc *pas = container_of(work, struct psp_assoc, work);
+ struct psp_dev *psd = pas->psd;
+
+ mutex_lock(&psd->lock);
+ if (psd->ops)
+ psp_dev_tx_key_del(psd, pas);
+ mutex_unlock(&psd->lock);
+ psp_dev_put(psd);
+ kfree(pas);
+}
+
+static void psp_assoc_free_queue(struct rcu_head *head)
+{
+ struct psp_assoc *pas = container_of(head, struct psp_assoc, rcu);
+
+ INIT_WORK(&pas->work, psp_assoc_free);
+ schedule_work(&pas->work);
+}
+
+/**
+ * psp_assoc_put() - release a reference on a PSP association
+ * @pas: association to release
+ */
+void psp_assoc_put(struct psp_assoc *pas)
+{
+ if (pas && refcount_dec_and_test(&pas->refcnt))
+ call_rcu(&pas->rcu, psp_assoc_free_queue);
+}
+
+void psp_sk_assoc_free(struct sock *sk)
+{
+ struct psp_assoc *pas = rcu_dereference_protected(sk->psp_assoc, 1);
+
+ rcu_assign_pointer(sk->psp_assoc, NULL);
+ psp_assoc_put(pas);
+}
+
+int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas,
+ struct psp_key_parsed *key,
+ struct netlink_ext_ack *extack)
+{
+ int err;
+
+ memcpy(&pas->rx, key, sizeof(*key));
+
+ lock_sock(sk);
+
+ if (psp_sk_assoc(sk)) {
+ NL_SET_ERR_MSG(extack, "Socket already has PSP state");
+ err = -EBUSY;
+ goto exit_unlock;
+ }
+
+ refcount_inc(&pas->refcnt);
+ rcu_assign_pointer(sk->psp_assoc, pas);
+ err = 0;
+
+exit_unlock:
+ release_sock(sk);
+
+ return err;
+}
+
+static int psp_sock_recv_queue_check(struct sock *sk, struct psp_assoc *pas)
+{
+ struct psp_skb_ext *pse;
+ struct sk_buff *skb;
+
+ skb_rbtree_walk(skb, &tcp_sk(sk)->out_of_order_queue) {
+ pse = skb_ext_find(skb, SKB_EXT_PSP);
+ if (!psp_pse_matches_pas(pse, pas))
+ return -EBUSY;
+ }
+
+ skb_queue_walk(&sk->sk_receive_queue, skb) {
+ pse = skb_ext_find(skb, SKB_EXT_PSP);
+ if (!psp_pse_matches_pas(pse, pas))
+ return -EBUSY;
+ }
+ return 0;
+}
+
+int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd,
+ u32 version, struct psp_key_parsed *key,
+ struct netlink_ext_ack *extack)
+{
+ struct psp_assoc *pas, *dummy;
+ int err;
+
+ lock_sock(sk);
+
+ pas = psp_sk_assoc(sk);
+ if (!pas) {
+ NL_SET_ERR_MSG(extack, "Socket has no Rx key");
+ err = -EINVAL;
+ goto exit_unlock;
+ }
+ if (pas->psd != psd) {
+ NL_SET_ERR_MSG(extack, "Rx key from different device");
+ err = -EINVAL;
+ goto exit_unlock;
+ }
+ if (pas->version != version) {
+ NL_SET_ERR_MSG(extack,
+ "PSP version mismatch with existing state");
+ err = -EINVAL;
+ goto exit_unlock;
+ }
+ if (pas->tx.spi) {
+ NL_SET_ERR_MSG(extack, "Tx key already set");
+ err = -EBUSY;
+ goto exit_unlock;
+ }
+
+ err = psp_sock_recv_queue_check(sk, pas);
+ if (err) {
+ NL_SET_ERR_MSG(extack, "Socket has incompatible segments already in the recv queue");
+ goto exit_unlock;
+ }
+
+ /* Pass a fake association to drivers to make sure they don't
+ * try to store pointers to it. For re-keying we'll need to
+ * re-allocate the assoc structures.
+ */
+ dummy = psp_assoc_dummy(pas);
+ if (!dummy) {
+ err = -ENOMEM;
+ goto exit_unlock;
+ }
+
+ memcpy(&dummy->tx, key, sizeof(*key));
+ err = psp_dev_tx_key_add(psd, dummy, extack);
+ if (err)
+ goto exit_free_dummy;
+
+ memcpy(pas->drv_data, dummy->drv_data, psd->caps->assoc_drv_spc);
+ memcpy(&pas->tx, key, sizeof(*key));
+
+ WRITE_ONCE(sk->sk_validate_xmit_skb, psp_validate_xmit);
+ tcp_write_collapse_fence(sk);
+ pas->upgrade_seq = tcp_sk(sk)->rcv_nxt;
+
+exit_free_dummy:
+ kfree(dummy);
+exit_unlock:
+ release_sock(sk);
+ return err;
+}
+
+void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk)
+{
+ struct psp_assoc *pas = psp_sk_assoc(sk);
+
+ if (pas)
+ refcount_inc(&pas->refcnt);
+ rcu_assign_pointer(tw->psp_assoc, pas);
+ tw->tw_validate_xmit_skb = psp_validate_xmit;
+}
+
+void psp_twsk_assoc_free(struct inet_timewait_sock *tw)
+{
+ struct psp_assoc *pas = rcu_dereference_protected(tw->psp_assoc, 1);
+
+ rcu_assign_pointer(tw->psp_assoc, NULL);
+ psp_assoc_put(pas);
+}
+
+void psp_reply_set_decrypted(struct sk_buff *skb)
+{
+ struct psp_assoc *pas;
+
+ rcu_read_lock();
+ pas = psp_sk_get_assoc_rcu(skb->sk);
+ if (pas && pas->tx.spi)
+ skb->decrypted = 1;
+ rcu_read_unlock();
+}
+EXPORT_IPV6_MOD_GPL(psp_reply_set_decrypted);