diff options
Diffstat (limited to 'net')
| -rw-r--r-- | net/psp/Kconfig | 1 | ||||
| -rw-r--r-- | net/psp/Makefile | 2 | ||||
| -rw-r--r-- | net/psp/psp-nl-gen.c | 39 | ||||
| -rw-r--r-- | net/psp/psp-nl-gen.h | 7 | ||||
| -rw-r--r-- | net/psp/psp.h | 22 | ||||
| -rw-r--r-- | net/psp/psp_main.c | 26 | ||||
| -rw-r--r-- | net/psp/psp_nl.c | 232 | ||||
| -rw-r--r-- | net/psp/psp_sock.c | 274 |
8 files changed, 601 insertions, 2 deletions
diff --git a/net/psp/Kconfig b/net/psp/Kconfig index 5e3908a40945..a7d24691a7e1 100644 --- a/net/psp/Kconfig +++ b/net/psp/Kconfig @@ -6,6 +6,7 @@ config INET_PSP bool "PSP Security Protocol support" depends on INET select SKB_DECRYPTED + select SOCK_VALIDATE_XMIT help Enable kernel support for the PSP protocol. For more information see: diff --git a/net/psp/Makefile b/net/psp/Makefile index 41b51d06e560..eb5ff3c5bfb2 100644 --- a/net/psp/Makefile +++ b/net/psp/Makefile @@ -2,4 +2,4 @@ obj-$(CONFIG_INET_PSP) += psp.o -psp-y := psp_main.o psp_nl.o psp-nl-gen.o +psp-y := psp_main.o psp_nl.o psp_sock.o psp-nl-gen.o diff --git a/net/psp/psp-nl-gen.c b/net/psp/psp-nl-gen.c index 7f49577ac72f..9fdd6f831803 100644 --- a/net/psp/psp-nl-gen.c +++ b/net/psp/psp-nl-gen.c @@ -10,6 +10,12 @@ #include <uapi/linux/psp.h> +/* Common nested types */ +const struct nla_policy psp_keys_nl_policy[PSP_A_KEYS_SPI + 1] = { + [PSP_A_KEYS_KEY] = { .type = NLA_BINARY, }, + [PSP_A_KEYS_SPI] = { .type = NLA_U32, }, +}; + /* PSP_CMD_DEV_GET - do */ static const struct nla_policy psp_dev_get_nl_policy[PSP_A_DEV_ID + 1] = { [PSP_A_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1), @@ -26,6 +32,21 @@ static const struct nla_policy psp_key_rotate_nl_policy[PSP_A_DEV_ID + 1] = { [PSP_A_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1), }; +/* PSP_CMD_RX_ASSOC - do */ +static const struct nla_policy psp_rx_assoc_nl_policy[PSP_A_ASSOC_SOCK_FD + 1] = { + [PSP_A_ASSOC_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1), + [PSP_A_ASSOC_VERSION] = NLA_POLICY_MAX(NLA_U32, 3), + [PSP_A_ASSOC_SOCK_FD] = { .type = NLA_U32, }, +}; + +/* PSP_CMD_TX_ASSOC - do */ +static const struct nla_policy psp_tx_assoc_nl_policy[PSP_A_ASSOC_SOCK_FD + 1] = { + [PSP_A_ASSOC_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1), + [PSP_A_ASSOC_VERSION] = NLA_POLICY_MAX(NLA_U32, 3), + [PSP_A_ASSOC_TX_KEY] = NLA_POLICY_NESTED(psp_keys_nl_policy), + [PSP_A_ASSOC_SOCK_FD] = { .type = NLA_U32, }, +}; + /* Ops table for psp */ static const struct genl_split_ops psp_nl_ops[] = { { @@ -60,6 +81,24 @@ static const struct genl_split_ops psp_nl_ops[] = { .maxattr = PSP_A_DEV_ID, .flags = GENL_CMD_CAP_DO, }, + { + .cmd = PSP_CMD_RX_ASSOC, + .pre_doit = psp_assoc_device_get_locked, + .doit = psp_nl_rx_assoc_doit, + .post_doit = psp_device_unlock, + .policy = psp_rx_assoc_nl_policy, + .maxattr = PSP_A_ASSOC_SOCK_FD, + .flags = GENL_CMD_CAP_DO, + }, + { + .cmd = PSP_CMD_TX_ASSOC, + .pre_doit = psp_assoc_device_get_locked, + .doit = psp_nl_tx_assoc_doit, + .post_doit = psp_device_unlock, + .policy = psp_tx_assoc_nl_policy, + .maxattr = PSP_A_ASSOC_SOCK_FD, + .flags = GENL_CMD_CAP_DO, + }, }; static const struct genl_multicast_group psp_nl_mcgrps[] = { diff --git a/net/psp/psp-nl-gen.h b/net/psp/psp-nl-gen.h index 00a2d4ec59e4..25268ed11fb5 100644 --- a/net/psp/psp-nl-gen.h +++ b/net/psp/psp-nl-gen.h @@ -11,8 +11,13 @@ #include <uapi/linux/psp.h> +/* Common nested types */ +extern const struct nla_policy psp_keys_nl_policy[PSP_A_KEYS_SPI + 1]; + int psp_device_get_locked(const struct genl_split_ops *ops, struct sk_buff *skb, struct genl_info *info); +int psp_assoc_device_get_locked(const struct genl_split_ops *ops, + struct sk_buff *skb, struct genl_info *info); void psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb, struct genl_info *info); @@ -21,6 +26,8 @@ int psp_nl_dev_get_doit(struct sk_buff *skb, struct genl_info *info); int psp_nl_dev_get_dumpit(struct sk_buff *skb, struct netlink_callback *cb); int psp_nl_dev_set_doit(struct sk_buff *skb, struct genl_info *info); int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info); +int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info); +int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info); enum { PSP_NLGRP_MGMT, diff --git a/net/psp/psp.h b/net/psp/psp.h index 94d0cc31a61f..defd3e3fd5e7 100644 --- a/net/psp/psp.h +++ b/net/psp/psp.h @@ -4,6 +4,7 @@ #define __PSP_PSP_H #include <linux/list.h> +#include <linux/lockdep.h> #include <linux/mutex.h> #include <net/netns/generic.h> #include <net/psp.h> @@ -17,15 +18,36 @@ int psp_dev_check_access(struct psp_dev *psd, struct net *net); void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd); +struct psp_assoc *psp_assoc_create(struct psp_dev *psd); +struct psp_dev *psp_dev_get_for_sock(struct sock *sk); +void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas); +int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas, + struct psp_key_parsed *key, + struct netlink_ext_ack *extack); +int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd, + u32 version, struct psp_key_parsed *key, + struct netlink_ext_ack *extack); + static inline void psp_dev_get(struct psp_dev *psd) { refcount_inc(&psd->refcnt); } +static inline bool psp_dev_tryget(struct psp_dev *psd) +{ + return refcount_inc_not_zero(&psd->refcnt); +} + static inline void psp_dev_put(struct psp_dev *psd) { if (refcount_dec_and_test(&psd->refcnt)) psp_dev_destroy(psd); } +static inline bool psp_dev_is_registered(struct psp_dev *psd) +{ + lockdep_assert_held(&psd->lock); + return !!psd->ops; +} + #endif /* __PSP_PSP_H */ diff --git a/net/psp/psp_main.c b/net/psp/psp_main.c index f60155493afc..a1ae3c8920c3 100644 --- a/net/psp/psp_main.c +++ b/net/psp/psp_main.c @@ -55,7 +55,10 @@ psp_dev_create(struct net_device *netdev, if (WARN_ON(!psd_caps->versions || !psd_ops->set_config || - !psd_ops->key_rotate)) + !psd_ops->key_rotate || + !psd_ops->rx_spi_alloc || + !psd_ops->tx_key_add || + !psd_ops->tx_key_del)) return ERR_PTR(-EINVAL); psd = kzalloc(sizeof(*psd), GFP_KERNEL); @@ -68,6 +71,7 @@ psp_dev_create(struct net_device *netdev, psd->drv_priv = priv_ptr; mutex_init(&psd->lock); + INIT_LIST_HEAD(&psd->active_assocs); refcount_set(&psd->refcnt, 1); mutex_lock(&psp_devs_lock); @@ -107,6 +111,8 @@ void psp_dev_destroy(struct psp_dev *psd) */ void psp_dev_unregister(struct psp_dev *psd) { + struct psp_assoc *pas, *next; + mutex_lock(&psp_devs_lock); mutex_lock(&psd->lock); @@ -119,6 +125,9 @@ void psp_dev_unregister(struct psp_dev *psd) xa_store(&psp_devs, psd->id, NULL, GFP_KERNEL); mutex_unlock(&psp_devs_lock); + list_for_each_entry_safe(pas, next, &psd->active_assocs, assocs_list) + psp_dev_tx_key_del(psd, pas); + rcu_assign_pointer(psd->main_netdev->psp_dev, NULL); psd->ops = NULL; @@ -130,6 +139,21 @@ void psp_dev_unregister(struct psp_dev *psd) } EXPORT_SYMBOL(psp_dev_unregister); +unsigned int psp_key_size(u32 version) +{ + switch (version) { + case PSP_VERSION_HDR0_AES_GCM_128: + case PSP_VERSION_HDR0_AES_GMAC_128: + return 16; + case PSP_VERSION_HDR0_AES_GCM_256: + case PSP_VERSION_HDR0_AES_GMAC_256: + return 32; + default: + return 0; + } +} +EXPORT_SYMBOL(psp_key_size); + static int __init psp_init(void) { mutex_init(&psp_devs_lock); diff --git a/net/psp/psp_nl.c b/net/psp/psp_nl.c index 75f2702c1029..1b1d08fce637 100644 --- a/net/psp/psp_nl.c +++ b/net/psp/psp_nl.c @@ -79,9 +79,12 @@ void psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb, struct genl_info *info) { + struct socket *socket = info->user_ptr[1]; struct psp_dev *psd = info->user_ptr[0]; mutex_unlock(&psd->lock); + if (socket) + sockfd_put(socket); } static int @@ -261,3 +264,232 @@ err_free_rsp: nlmsg_free(rsp); return err; } + +/* Key etc. */ + +int psp_assoc_device_get_locked(const struct genl_split_ops *ops, + struct sk_buff *skb, struct genl_info *info) +{ + struct socket *socket; + struct psp_dev *psd; + struct nlattr *id; + int fd, err; + + if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_SOCK_FD)) + return -EINVAL; + + fd = nla_get_u32(info->attrs[PSP_A_ASSOC_SOCK_FD]); + socket = sockfd_lookup(fd, &err); + if (!socket) + return err; + + if (!sk_is_tcp(socket->sk)) { + NL_SET_ERR_MSG_ATTR(info->extack, + info->attrs[PSP_A_ASSOC_SOCK_FD], + "Unsupported socket family and type"); + err = -EOPNOTSUPP; + goto err_sock_put; + } + + psd = psp_dev_get_for_sock(socket->sk); + if (psd) { + err = psp_dev_check_access(psd, genl_info_net(info)); + if (err) { + psp_dev_put(psd); + psd = NULL; + } + } + + if (!psd && GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_DEV_ID)) { + err = -EINVAL; + goto err_sock_put; + } + + id = info->attrs[PSP_A_ASSOC_DEV_ID]; + if (psd) { + mutex_lock(&psd->lock); + if (id && psd->id != nla_get_u32(id)) { + mutex_unlock(&psd->lock); + NL_SET_ERR_MSG_ATTR(info->extack, id, + "Device id vs socket mismatch"); + err = -EINVAL; + goto err_psd_put; + } + + psp_dev_put(psd); + } else { + psd = psp_device_get_and_lock(genl_info_net(info), id); + if (IS_ERR(psd)) { + err = PTR_ERR(psd); + goto err_sock_put; + } + } + + info->user_ptr[0] = psd; + info->user_ptr[1] = socket; + + return 0; + +err_psd_put: + psp_dev_put(psd); +err_sock_put: + sockfd_put(socket); + return err; +} + +static int +psp_nl_parse_key(struct genl_info *info, u32 attr, struct psp_key_parsed *key, + unsigned int key_sz) +{ + struct nlattr *nest = info->attrs[attr]; + struct nlattr *tb[PSP_A_KEYS_SPI + 1]; + u32 spi; + int err; + + err = nla_parse_nested(tb, ARRAY_SIZE(tb) - 1, nest, + psp_keys_nl_policy, info->extack); + if (err) + return err; + + if (NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_KEY) || + NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_SPI)) + return -EINVAL; + + if (nla_len(tb[PSP_A_KEYS_KEY]) != key_sz) { + NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY], + "incorrect key length"); + return -EINVAL; + } + + spi = nla_get_u32(tb[PSP_A_KEYS_SPI]); + if (!(spi & PSP_SPI_KEY_ID)) { + NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY], + "invalid SPI: lower 31b must be non-zero"); + return -EINVAL; + } + + key->spi = cpu_to_be32(spi); + memcpy(key->key, nla_data(tb[PSP_A_KEYS_KEY]), key_sz); + + return 0; +} + +static int +psp_nl_put_key(struct sk_buff *skb, u32 attr, u32 version, + struct psp_key_parsed *key) +{ + int key_sz = psp_key_size(version); + void *nest; + + nest = nla_nest_start(skb, attr); + + if (nla_put_u32(skb, PSP_A_KEYS_SPI, be32_to_cpu(key->spi)) || + nla_put(skb, PSP_A_KEYS_KEY, key_sz, key->key)) { + nla_nest_cancel(skb, nest); + return -EMSGSIZE; + } + + nla_nest_end(skb, nest); + + return 0; +} + +int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct socket *socket = info->user_ptr[1]; + struct psp_dev *psd = info->user_ptr[0]; + struct psp_key_parsed key; + struct psp_assoc *pas; + struct sk_buff *rsp; + u32 version; + int err; + + if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION)) + return -EINVAL; + + version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]); + if (!(psd->caps->versions & (1 << version))) { + NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]); + return -EOPNOTSUPP; + } + + rsp = psp_nl_reply_new(info); + if (!rsp) + return -ENOMEM; + + pas = psp_assoc_create(psd); + if (!pas) { + err = -ENOMEM; + goto err_free_rsp; + } + pas->version = version; + + err = psd->ops->rx_spi_alloc(psd, version, &key, info->extack); + if (err) + goto err_free_pas; + + if (nla_put_u32(rsp, PSP_A_ASSOC_DEV_ID, psd->id) || + psp_nl_put_key(rsp, PSP_A_ASSOC_RX_KEY, version, &key)) { + err = -EMSGSIZE; + goto err_free_pas; + } + + err = psp_sock_assoc_set_rx(socket->sk, pas, &key, info->extack); + if (err) { + NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_SOCK_FD]); + goto err_free_pas; + } + psp_assoc_put(pas); + + return psp_nl_reply_send(rsp, info); + +err_free_pas: + psp_assoc_put(pas); +err_free_rsp: + nlmsg_free(rsp); + return err; +} + +int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct socket *socket = info->user_ptr[1]; + struct psp_dev *psd = info->user_ptr[0]; + struct psp_key_parsed key; + struct sk_buff *rsp; + unsigned int key_sz; + u32 version; + int err; + + if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION) || + GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_TX_KEY)) + return -EINVAL; + + version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]); + if (!(psd->caps->versions & (1 << version))) { + NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]); + return -EOPNOTSUPP; + } + + key_sz = psp_key_size(version); + if (!key_sz) + return -EINVAL; + + err = psp_nl_parse_key(info, PSP_A_ASSOC_TX_KEY, &key, key_sz); + if (err < 0) + return err; + + rsp = psp_nl_reply_new(info); + if (!rsp) + return -ENOMEM; + + err = psp_sock_assoc_set_tx(socket->sk, psd, version, &key, + info->extack); + if (err) + goto err_free_msg; + + return psp_nl_reply_send(rsp, info); + +err_free_msg: + nlmsg_free(rsp); + return err; +} diff --git a/net/psp/psp_sock.c b/net/psp/psp_sock.c new file mode 100644 index 000000000000..8ebccee94593 --- /dev/null +++ b/net/psp/psp_sock.c @@ -0,0 +1,274 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include <linux/file.h> +#include <linux/net.h> +#include <linux/rcupdate.h> +#include <linux/tcp.h> + +#include <net/ip.h> +#include <net/psp.h> +#include "psp.h" + +struct psp_dev *psp_dev_get_for_sock(struct sock *sk) +{ + struct dst_entry *dst; + struct psp_dev *psd; + + dst = sk_dst_get(sk); + if (!dst) + return NULL; + + rcu_read_lock(); + psd = rcu_dereference(dst->dev->psp_dev); + if (psd && !psp_dev_tryget(psd)) + psd = NULL; + rcu_read_unlock(); + + dst_release(dst); + + return psd; +} + +static struct sk_buff * +psp_validate_xmit(struct sock *sk, struct net_device *dev, struct sk_buff *skb) +{ + struct psp_assoc *pas; + bool good; + + rcu_read_lock(); + pas = psp_skb_get_assoc_rcu(skb); + good = !pas || rcu_access_pointer(dev->psp_dev) == pas->psd; + rcu_read_unlock(); + if (!good) { + kfree_skb_reason(skb, SKB_DROP_REASON_PSP_OUTPUT); + return NULL; + } + + return skb; +} + +struct psp_assoc *psp_assoc_create(struct psp_dev *psd) +{ + struct psp_assoc *pas; + + lockdep_assert_held(&psd->lock); + + pas = kzalloc(struct_size(pas, drv_data, psd->caps->assoc_drv_spc), + GFP_KERNEL_ACCOUNT); + if (!pas) + return NULL; + + pas->psd = psd; + pas->dev_id = psd->id; + psp_dev_get(psd); + refcount_set(&pas->refcnt, 1); + + list_add_tail(&pas->assocs_list, &psd->active_assocs); + + return pas; +} + +static struct psp_assoc *psp_assoc_dummy(struct psp_assoc *pas) +{ + struct psp_dev *psd = pas->psd; + size_t sz; + + lockdep_assert_held(&psd->lock); + + sz = struct_size(pas, drv_data, psd->caps->assoc_drv_spc); + return kmemdup(pas, sz, GFP_KERNEL); +} + +static int psp_dev_tx_key_add(struct psp_dev *psd, struct psp_assoc *pas, + struct netlink_ext_ack *extack) +{ + return psd->ops->tx_key_add(psd, pas, extack); +} + +void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas) +{ + if (pas->tx.spi) + psd->ops->tx_key_del(psd, pas); + list_del(&pas->assocs_list); +} + +static void psp_assoc_free(struct work_struct *work) +{ + struct psp_assoc *pas = container_of(work, struct psp_assoc, work); + struct psp_dev *psd = pas->psd; + + mutex_lock(&psd->lock); + if (psd->ops) + psp_dev_tx_key_del(psd, pas); + mutex_unlock(&psd->lock); + psp_dev_put(psd); + kfree(pas); +} + +static void psp_assoc_free_queue(struct rcu_head *head) +{ + struct psp_assoc *pas = container_of(head, struct psp_assoc, rcu); + + INIT_WORK(&pas->work, psp_assoc_free); + schedule_work(&pas->work); +} + +/** + * psp_assoc_put() - release a reference on a PSP association + * @pas: association to release + */ +void psp_assoc_put(struct psp_assoc *pas) +{ + if (pas && refcount_dec_and_test(&pas->refcnt)) + call_rcu(&pas->rcu, psp_assoc_free_queue); +} + +void psp_sk_assoc_free(struct sock *sk) +{ + struct psp_assoc *pas = rcu_dereference_protected(sk->psp_assoc, 1); + + rcu_assign_pointer(sk->psp_assoc, NULL); + psp_assoc_put(pas); +} + +int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas, + struct psp_key_parsed *key, + struct netlink_ext_ack *extack) +{ + int err; + + memcpy(&pas->rx, key, sizeof(*key)); + + lock_sock(sk); + + if (psp_sk_assoc(sk)) { + NL_SET_ERR_MSG(extack, "Socket already has PSP state"); + err = -EBUSY; + goto exit_unlock; + } + + refcount_inc(&pas->refcnt); + rcu_assign_pointer(sk->psp_assoc, pas); + err = 0; + +exit_unlock: + release_sock(sk); + + return err; +} + +static int psp_sock_recv_queue_check(struct sock *sk, struct psp_assoc *pas) +{ + struct psp_skb_ext *pse; + struct sk_buff *skb; + + skb_rbtree_walk(skb, &tcp_sk(sk)->out_of_order_queue) { + pse = skb_ext_find(skb, SKB_EXT_PSP); + if (!psp_pse_matches_pas(pse, pas)) + return -EBUSY; + } + + skb_queue_walk(&sk->sk_receive_queue, skb) { + pse = skb_ext_find(skb, SKB_EXT_PSP); + if (!psp_pse_matches_pas(pse, pas)) + return -EBUSY; + } + return 0; +} + +int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd, + u32 version, struct psp_key_parsed *key, + struct netlink_ext_ack *extack) +{ + struct psp_assoc *pas, *dummy; + int err; + + lock_sock(sk); + + pas = psp_sk_assoc(sk); + if (!pas) { + NL_SET_ERR_MSG(extack, "Socket has no Rx key"); + err = -EINVAL; + goto exit_unlock; + } + if (pas->psd != psd) { + NL_SET_ERR_MSG(extack, "Rx key from different device"); + err = -EINVAL; + goto exit_unlock; + } + if (pas->version != version) { + NL_SET_ERR_MSG(extack, + "PSP version mismatch with existing state"); + err = -EINVAL; + goto exit_unlock; + } + if (pas->tx.spi) { + NL_SET_ERR_MSG(extack, "Tx key already set"); + err = -EBUSY; + goto exit_unlock; + } + + err = psp_sock_recv_queue_check(sk, pas); + if (err) { + NL_SET_ERR_MSG(extack, "Socket has incompatible segments already in the recv queue"); + goto exit_unlock; + } + + /* Pass a fake association to drivers to make sure they don't + * try to store pointers to it. For re-keying we'll need to + * re-allocate the assoc structures. + */ + dummy = psp_assoc_dummy(pas); + if (!dummy) { + err = -ENOMEM; + goto exit_unlock; + } + + memcpy(&dummy->tx, key, sizeof(*key)); + err = psp_dev_tx_key_add(psd, dummy, extack); + if (err) + goto exit_free_dummy; + + memcpy(pas->drv_data, dummy->drv_data, psd->caps->assoc_drv_spc); + memcpy(&pas->tx, key, sizeof(*key)); + + WRITE_ONCE(sk->sk_validate_xmit_skb, psp_validate_xmit); + tcp_write_collapse_fence(sk); + pas->upgrade_seq = tcp_sk(sk)->rcv_nxt; + +exit_free_dummy: + kfree(dummy); +exit_unlock: + release_sock(sk); + return err; +} + +void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk) +{ + struct psp_assoc *pas = psp_sk_assoc(sk); + + if (pas) + refcount_inc(&pas->refcnt); + rcu_assign_pointer(tw->psp_assoc, pas); + tw->tw_validate_xmit_skb = psp_validate_xmit; +} + +void psp_twsk_assoc_free(struct inet_timewait_sock *tw) +{ + struct psp_assoc *pas = rcu_dereference_protected(tw->psp_assoc, 1); + + rcu_assign_pointer(tw->psp_assoc, NULL); + psp_assoc_put(pas); +} + +void psp_reply_set_decrypted(struct sk_buff *skb) +{ + struct psp_assoc *pas; + + rcu_read_lock(); + pas = psp_sk_get_assoc_rcu(skb->sk); + if (pas && pas->tx.spi) + skb->decrypted = 1; + rcu_read_unlock(); +} +EXPORT_IPV6_MOD_GPL(psp_reply_set_decrypted); |