2 * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
4 * Copyright (c) 2019-2020 Rubicon Communications, LLC (Netgate)
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
 * are met:
9 * 1. Redistributions of source code must retain the above copyright
10 * notice, this list of conditions and the following disclaimer.
11 * 2. Redistributions in binary form must reproduce the above copyright
12 * notice, this list of conditions and the following disclaimer in the
13 * documentation and/or other materials provided with the distribution.
15 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
16 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
19 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
21 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
22 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
23 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
24 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
28 #include <sys/cdefs.h>
29 __FBSDID("$FreeBSD$");
32 #include "opt_inet6.h"
33 #include <sys/param.h>
34 #include <sys/types.h>
35 #include <sys/systm.h>
36 #include <sys/kernel.h>
39 #include <sys/mutex.h>
41 #include <sys/module.h>
43 #include <sys/socket.h>
44 #include <sys/sockio.h>
45 #include <sys/queue.h>
49 #include <net/ethernet.h>
50 #include <net/if_var.h>
51 #include <net/iflib.h>
52 #include <net/if_clone.h>
53 #include <net/radix.h>
55 #include <net/mp_ring.h>
59 #include <sys/wg_module.h>
60 #include <crypto/zinc.h>
61 #include <sys/wg_noise.h>
62 #include <sys/if_wg_session_vars.h>
63 #include <sys/if_wg_session.h>
/* Malloc type used for the allocations made by this driver (see M_WG uses). */
65 MALLOC_DEFINE(M_WG, "WG", "wireguard");
/* Capability bits stored in isc_capenable by wg_cloneattach(). */
67 #define WG_CAPS IFCAP_LINKSTATE
/*
 * Slot in the mbuf packet header local storage that carries the address
 * family from wg_output() to wg_transmit().
 */
68 #define ph_family PH_loc.eight[5]
/* Task queue group that the handshake and crypto group tasks attach to. */
70 TASKQGROUP_DECLARE(if_io_tqg);
/*
 * Count of cloned interfaces: incremented in wg_cloneattach(), decremented in
 * wg_detach(), and tested in wg_module_event_handler().
 */
72 static int clone_count;
/* UMA zone of struct ratelimit objects; handed to cookie_checker_init(). */
73 uma_zone_t ratelimit_zone;
/*
 * Kick the per-CPU encryption group tasks of sc.  For every CPU index below
 * mp_ncpus the TASK_ENQUEUED flag of the task is tested and the task is
 * passed to GROUPTASK_ENQUEUE.
 *
 * NOTE(review): the statement controlled by the flag test is not visible in
 * this view; presumably tasks that are already queued are skipped -- confirm.
 */
76 wg_encrypt_dispatch(struct wg_softc *sc)
78 for (int i = 0; i < mp_ncpus; i++) {
79 if (sc->sc_encrypt[i].gt_task.ta_flags & TASK_ENQUEUED)
81 GROUPTASK_ENQUEUE(&sc->sc_encrypt[i]);
/*
 * Decryption counterpart of wg_encrypt_dispatch(): walks the sc_decrypt
 * group tasks, one per CPU index below mp_ncpus, tests TASK_ENQUEUED and
 * enqueues the task.
 *
 * NOTE(review): the statement controlled by the flag test is not visible in
 * this view; presumably tasks that are already queued are skipped -- confirm.
 */
86 wg_decrypt_dispatch(struct wg_softc *sc)
88 for (int i = 0; i < mp_ncpus; i++) {
89 if (sc->sc_decrypt[i].gt_task.ta_flags & TASK_ENQUEUED)
91 GROUPTASK_ENQUEUE(&sc->sc_decrypt[i]);
/*
 * Allocate and attach the per-CPU crypto group tasks of sc.
 *
 * Two arrays of mp_ncpus struct grouptask are allocated from M_WG with
 * M_WAITOK.  Entry i of each array is initialised to run wg_softc_encrypt or
 * wg_softc_decrypt with sc as argument and is attached to CPU i of
 * qgroup_if_io_tqg.  crypto_taskq_destroy() undoes this.
 *
 * NOTE(review): the arrays are allocated without M_ZERO, so every field that
 * is read later (for example ta_flags in the dispatch functions) has to be
 * set by GROUPTASK_INIT or the attach call -- verify.
 */
96 crypto_taskq_setup(struct wg_softc *sc)
98 device_t dev = iflib_get_dev(sc->wg_ctx);
100 sc->sc_encrypt = malloc(sizeof(struct grouptask)*mp_ncpus, M_WG, M_WAITOK);
101 sc->sc_decrypt = malloc(sizeof(struct grouptask)*mp_ncpus, M_WG, M_WAITOK);
103 for (int i = 0; i < mp_ncpus; i++) {
104 GROUPTASK_INIT(&sc->sc_encrypt[i], 0,
105 (gtask_fn_t *)wg_softc_encrypt, sc);
106 taskqgroup_attach_cpu(qgroup_if_io_tqg, &sc->sc_encrypt[i], sc, i, dev, NULL, "wg encrypt");
107 GROUPTASK_INIT(&sc->sc_decrypt[i], 0,
108 (gtask_fn_t *)wg_softc_decrypt, sc);
109 taskqgroup_attach_cpu(qgroup_if_io_tqg, &sc->sc_decrypt[i], sc, i, dev, NULL, "wg decrypt");
/*
 * Tear down what crypto_taskq_setup() created: detach every per-CPU encrypt
 * and decrypt group task from qgroup_if_io_tqg, then free both task arrays
 * back to M_WG.
 */
114 crypto_taskq_destroy(struct wg_softc *sc)
116 for (int i = 0; i < mp_ncpus; i++) {
117 taskqgroup_detach(qgroup_if_io_tqg, &sc->sc_encrypt[i]);
118 taskqgroup_detach(qgroup_if_io_tqg, &sc->sc_decrypt[i]);
120 free(sc->sc_encrypt, M_WG);
121 free(sc->sc_decrypt, M_WG);
/*
 * ifdi_cloneattach handler: configure a freshly cloned wg interface from the
 * parameters supplied by userland.
 *
 * params is copied in as an iovec that describes a packed nvlist.  The nvlist
 * has to contain the number listen-port and the binary private-key (of
 * CURVE25519_KEY_SIZE bytes); each missing or malformed item is reported with
 * device_printf.  With those values the function sets up the local noise
 * state and cookie checker, stores the listen port, bumps clone_count, fills
 * in the iflib shared context and creates the queues, locks, rings and tasks
 * used by the data path.
 *
 * Several statements of this function (error returns, cleanup labels, some
 * declarations) are not visible in this view.
 */
125 wg_cloneattach(if_ctx_t ctx, struct if_clone *ifc, const char *name, caddr_t params)
127 struct wg_softc *sc = iflib_get_softc(ctx);
128 if_softc_ctx_t scctx;
133 struct noise_local *local;
134 uint8_t public[WG_KEY_SIZE];
135 struct noise_upcall noise_upcall;
137 uint16_t listen_port;
142 dev = iflib_get_dev(ctx);
143 if (params == NULL) {
/* Fetch the iovec describing the packed nvlist from user space. */
150 if (copyin(params, &iov, sizeof(iov)))
/*
 * NOTE(review): size is set on a line that is not visible here.  No upper
 * bound check on this user-supplied length is visible before it is used for
 * an M_WAITOK allocation -- confirm whether one exists.
 */
152 /* check that this is reasonable */
154 packed = malloc(size, M_TEMP, M_WAITOK);
155 if (copyin(iov.iov_base, packed, size)) {
159 nvl = nvlist_unpack(packed, size, 0);
161 device_printf(dev, "%s nvlist_unpack failed\n", __func__);
165 if (!nvlist_exists_number(nvl, "listen-port")) {
166 device_printf(dev, "%s listen-port not set\n", __func__);
/*
 * NOTE(review): listen_port is a uint16_t, so a number above 65535 is
 * silently truncated here; no range check is visible.
 */
170 listen_port = nvlist_get_number(nvl, "listen-port");
172 if (!nvlist_exists_binary(nvl, "private-key")) {
173 device_printf(dev, "%s private-key not set\n", __func__);
/* size is reused here to receive the length of the key blob. */
177 key = nvlist_get_binary(nvl, "private-key", &size);
178 if (size != CURVE25519_KEY_SIZE) {
179 device_printf(dev, "%s bad length for private-key %zu\n", __func__, size);
/* Wire the noise layer callbacks to the wg_* implementations, with sc as context. */
184 local = &sc->sc_local;
185 noise_upcall.u_arg = sc;
186 noise_upcall.u_remote_get =
187 (struct noise_remote *(*)(void *, uint8_t *))wg_remote_get;
188 noise_upcall.u_index_set =
189 (uint32_t (*)(void *, struct noise_remote *))wg_index_set;
190 noise_upcall.u_index_drop =
191 (void (*)(void *, uint32_t))wg_index_drop;
192 noise_local_init(local, &noise_upcall);
193 cookie_checker_init(&sc->sc_cookie, ratelimit_zone);
195 sc->sc_socket.so_port = listen_port;
/* Install the private key, then feed the resulting public key to the cookie checker. */
198 noise_local_set_private(local, __DECONST(uint8_t *, key));
199 noise_local_keys(local, public, NULL);
200 cookie_checker_update(&sc->sc_cookie, public);
202 atomic_add_int(&clone_count, 1);
203 scctx = sc->shared = iflib_get_softc_ctx(ctx);
204 scctx->isc_capenable = WG_CAPS;
/*
 * NOTE(review): CSUM_IP6_TCP appears twice in the mask below; the second
 * occurrence is redundant and was presumably meant to be a different flag
 * (CSUM_IP6_TSO?) -- confirm.
 */
205 scctx->isc_tx_csum_flags = CSUM_TCP | CSUM_UDP | CSUM_TSO | CSUM_IP6_TCP \
206 | CSUM_IP6_UDP | CSUM_IP6_TCP;
208 sc->sc_ifp = iflib_get_ifp(ctx);
/* Data path state: handshake queue, softc mutex, index lock, encap/decap rings. */
210 mbufq_init(&sc->sc_handshake_queue, MAX_QUEUED_INCOMING_HANDSHAKES);
211 mtx_init(&sc->sc_mtx, NULL, "wg softc lock", MTX_DEF);
212 rw_init(&sc->sc_index_lock, "wg index lock");
213 sc->sc_encap_ring = buf_ring_alloc(MAX_QUEUED_PACKETS, M_WG, M_WAITOK, &sc->sc_mtx);
214 sc->sc_decap_ring = buf_ring_alloc(MAX_QUEUED_PACKETS, M_WG, M_WAITOK, &sc->sc_mtx);
/* Handshake receive task plus the per-CPU crypto tasks. */
215 GROUPTASK_INIT(&sc->sc_handshake, 0,
216 (gtask_fn_t *)wg_softc_handshake_receive, sc);
217 taskqgroup_attach(qgroup_if_io_tqg, &sc->sc_handshake, sc, dev, NULL, "wg tx initiation");
218 crypto_taskq_setup(sc);
/* The packed copy of the user nvlist is released here. */
223 free(packed, M_TEMP);
/*
 * if_transmit handler (installed by wg_attach_post()): pick the peer for an
 * outbound mbuf and queue the packet for encryption.
 *
 * Visible steps: check IFF_DYING on the interface, fetch the wg tag of the
 * mbuf, tap the packet to BPF with its address family as header, look up the
 * peer in the routing table, require the peer endpoint to be AF_INET or
 * AF_INET6, record the interface MTU in the tag, queue the packet with
 * wg_queue_out() and kick the encryption tasks.  An output error is counted
 * on the interface on the failure path.
 *
 * The bodies of the error branches and the return statements are not visible
 * in this view.
 */
228 wg_transmit(struct ifnet *ifp, struct mbuf *m)
232 struct epoch_tracker et;
233 struct wg_peer *peer;
240 * Work around lifetime issue in the ipv6 mld code.
242 if (__predict_false(ifp->if_flags & IFF_DYING))
246 sc = iflib_get_softc(ifp->if_softc);
247 if ((t = wg_tag_get(m)) == NULL) {
/* The family was stored in the packet header by wg_output(). */
251 af = m->m_pkthdr.ph_family;
252 BPF_MTAP2(ifp, &af, sizeof(af), m);
/* Route lookup in the outbound direction selects the destination peer. */
255 peer = wg_route_lookup(&sc->sc_routes, m, OUT);
256 if (__predict_false(peer == NULL)) {
258 printf("peer not found - dropping %p\n", m);
/* A peer whose endpoint family is neither IPv4 nor IPv6 cannot be sent to. */
263 family = atomic_load_acq(peer->p_endpoint.e_remote.r_sa.sa_family);
264 if (__predict_false(family != AF_INET && family != AF_INET6)) {
272 t->t_mtu = ifp->if_mtu;
274 rc = wg_queue_out(peer, m);
276 wg_encrypt_dispatch(peer->p_sc);
282 if_inc_counter(sc->sc_ifp, IFCOUNTER_OERRORS, 1);
283 /* XXX send ICMP unreachable */
/*
 * if_output handler (installed by wg_attach_post()): stash the address family
 * of the destination sockaddr in the mbuf packet header, where wg_transmit()
 * reads it back, and hand the packet to wg_transmit().  Returns whatever
 * wg_transmit() returns; rt is not used in the visible lines.
 */
289 wg_output(struct ifnet *ifp, struct mbuf *m, const struct sockaddr *sa, struct route *rt)
291 m->m_pkthdr.ph_family = sa->sa_family;
292 return (wg_transmit(ifp, m));
/*
 * ifdi_attach_post handler: finish interface setup once iflib has created the
 * ifnet.  Sets the MTU, adjusts the interface flags, installs the wg transmit
 * and output routines and initialises the peer hashtable, the index hash and
 * the routing table of the softc.
 */
296 wg_attach_post(if_ctx_t ctx)
301 sc = iflib_get_softc(ctx);
302 ifp = iflib_get_ifp(ctx);
/* Leave 80 bytes of headroom below the Ethernet MTU (presumably for tunnel overhead). */
303 if_setmtu(ifp, ETHERMTU - 80);
/*
 * NOTE(review): if the arguments are (ifp, set, clear) as in the FreeBSD
 * if_setflagbits() accessor, this sets IFF_NOARP and CLEARS
 * IFF_POINTOPOINT -- confirm that clearing is intended.
 */
305 if_setflagbits(ifp, IFF_NOARP, IFF_POINTOPOINT);
306 ifp->if_transmit = wg_transmit;
307 ifp->if_output = wg_output;
309 wg_hashtable_init(&sc->sc_hashtable);
/* The index hash is allocated from M_DEVBUF, not M_WG. */
310 sc->sc_index = hashinit(HASHTABLE_INDEX_SIZE, M_DEVBUF, &sc->sc_index_mask);
311 wg_route_init(&sc->sc_routes);
/*
 * ifdi_mtu_set handler (see wg_if_methods).  Only the signature is visible in
 * this view; the body is not shown.
 */
317 wg_mtu_set(if_ctx_t ctx, uint32_t mtu)
/*
 * ifdi_promisc_set handler (see wg_if_methods).  Only the signature is
 * visible in this view; the body is not shown.
 */
324 wg_set_promisc(if_ctx_t ctx, int flags)
/*
 * ifdi_detach handler: shut a wg interface down and release its resources.
 *
 * Order of the visible steps: mark the link down, reset the socket state,
 * drain the task queue group, sleep, remove all peers, sleep again, destroy
 * the softc locks, detach the handshake and crypto tasks, free the
 * encap/decap rings, destroy the routing table and peer hashtable, and drop
 * clone_count.
 *
 * NOTE(review): sc_mtx and sc_index_lock are destroyed before the group
 * tasks are detached and before the buf rings (which were allocated with
 * &sc->sc_mtx) are freed.  Verify that no task can still run and touch those
 * locks after the drain and pauses.
 * NOTE(review): the fixed pause() delays look like a substitute for explicit
 * synchronisation -- confirm.
 */
331 wg_detach(if_ctx_t ctx)
335 sc = iflib_get_softc(ctx);
336 if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
338 wg_socket_reinit(sc, NULL, NULL);
339 taskqgroup_drain_all(qgroup_if_io_tqg);
340 pause("link_down", hz/4);
341 wg_peer_remove_all(sc);
342 pause("link_down", hz);
343 mtx_destroy(&sc->sc_mtx);
344 rw_destroy(&sc->sc_index_lock);
345 taskqgroup_detach(qgroup_if_io_tqg, &sc->sc_handshake);
346 crypto_taskq_destroy(sc);
347 buf_ring_free(sc->sc_encap_ring, M_WG);
348 buf_ring_free(sc->sc_decap_ring, M_WG);
350 wg_route_destroy(&sc->sc_routes);
351 wg_hashtable_destroy(&sc->sc_hashtable);
352 atomic_add_int(&clone_count, -1);
/*
 * ifdi_init handler: initialise the socket state of the interface with
 * wg_socket_init() and report the link as up.  How a non-zero rc from
 * wg_socket_init() is handled is not visible in this view.
 */
357 wg_init(if_ctx_t ctx)
363 sc = iflib_get_softc(ctx);
364 ifp = iflib_get_ifp(ctx);
365 rc = wg_socket_init(sc);
368 if_link_state_change(ifp, LINK_STATE_UP);
/*
 * ifdi_stop handler: report the link of the interface as down.  No other
 * action is visible in this view.
 */
372 wg_stop(if_ctx_t ctx)
377 sc = iflib_get_softc(ctx);
378 ifp = iflib_get_ifp(ctx);
379 if_link_state_change(ifp, LINK_STATE_DOWN);
/*
 * Build an nvlist describing one peer: its public key, its endpoint and its
 * allowed-ips (the r_cidr of every route on the p_routes list of the peer,
 * packed into one binary array).  nvlist_create() failure is checked; the
 * return statements are not visible in this view.
 *
 * The route list is walked twice: the first loop (body not visible,
 * presumably counting entries into count) sizes the temporary array, the
 * second copies the entries.
 *
 * NOTE(review): the endpoint is exported with sizeof(struct sockaddr) bytes.
 * wg_transmit() accepts AF_INET6 endpoints, and a sockaddr_in6 does not fit
 * in a struct sockaddr, so IPv6 endpoints look truncated here -- confirm.
 * NOTE(review): the release of the temporary aip array is not visible here;
 * verify it is freed after nvlist_add_binary().
 */
383 wg_peer_to_nvl(struct wg_peer *peer)
389 struct wg_allowedip *aip;
391 if ((nvl = nvlist_create(0)) == NULL)
393 key = peer->p_remote.r_public;
394 nvlist_add_binary(nvl, "public-key", key, WG_KEY_SIZE);
395 nvlist_add_binary(nvl, "endpoint", &peer->p_endpoint.e_remote, sizeof(struct sockaddr));
397 CK_LIST_FOREACH(rt, &peer->p_routes, r_entry) {
400 aip = malloc(count*sizeof(*aip), M_TEMP, M_WAITOK);
401 CK_LIST_FOREACH(rt, &peer->p_routes, r_entry) {
402 memcpy(&aip[i++], &rt->r_cidr, sizeof(*aip));
404 nvlist_add_binary(nvl, "allowed-ips", aip, count*sizeof(*aip));
/*
 * Convert every peer of sc into an nvlist (via wg_peer_to_nvl()) and collect
 * the results in a freshly allocated array.
 *
 * sc          - interface whose peer list is walked
 * nvlp        - when non-NULL, a container nvlist is created and the array is
 *               added to it under the key peer-list
 * nvl_arrayp  - receives the array of per-peer nvlists
 * peer_countp - receives the number of entries actually produced
 *
 * The array is sized from h_num_peers and filled by walking h_peers_list.
 * Each per-peer nvlist is packed once and the packed buffer is freed again,
 * which appears to serve only as a validity check.  Error returns, cleanup
 * labels and the locking or epoch section are not visible in this view.
 *
 * NOTE(review): nvl is only assigned when nvlp is non-NULL, yet the visible
 * code passes nvl to nvlist_error() and nvlist_add_nvlist_array().
 * wgc_get() calls this function with nvlp == NULL.  Verify that the lines
 * not visible here guard those uses.
 * NOTE(review): the array is sized from h_num_peers but filled from the list
 * walk; verify that i cannot exceed peer_count if peers are added in between.
 */
410 wg_marshal_peers(struct wg_softc *sc, nvlist_t **nvlp, nvlist_t ***nvl_arrayp, int *peer_countp)
412 struct wg_peer *peer;
413 int err, i, peer_count;
414 nvlist_t *nvl, **nvl_array;
415 struct epoch_tracker et;
428 peer_count = sc->sc_hashtable.h_num_peers;
429 if (peer_count == 0) {
430 printf("no peers found\n");
434 if (nvlp && (nvl = nvlist_create(0)) == NULL)
437 nvl_array = malloc(peer_count*sizeof(void*), M_TEMP, M_WAITOK);
439 CK_LIST_FOREACH(peer, &sc->sc_hashtable.h_peers_list, p_entry) {
440 nvl_array[i] = wg_peer_to_nvl(peer);
441 if (nvl_array[i] == NULL) {
442 printf("wg_peer_to_nvl failed on %d peer\n", i);
446 packed = nvlist_pack(nvl_array[i], &size);
447 if (packed == NULL) {
/*
 * NOTE(review): this message has no trailing newline, and it reports the
 * error of nvl although the list that failed to pack is nvl_array[i].
 */
448 printf("nvlist_pack(%p, %p) => %d",
449 nvl_array[i], &size, nvlist_error(nvl));
451 free(packed, M_NVLIST);
/* Report the number of peers actually converted, not the initial estimate. */
458 *peer_countp = peer_count = i;
459 if (peer_count == 0) {
460 printf("no peers found in list\n");
465 nvlist_add_nvlist_array(nvl, "peer-list",
466 (const nvlist_t * const *)nvl_array, peer_count);
467 if ((err = nvlist_error(nvl))) {
468 printf("nvlist_add_nvlist_array(%p, \"peer-list\", %p, %d) => %d\n",
469 nvl, nvl_array, peer_count, err);
474 *nvl_arrayp = nvl_array;
/*
 * Export the interface configuration to userland as a packed nvlist.
 *
 * The nvlist contains listen-port (when the port is non-zero), public-key
 * and, for uid 0 only, private-key (when a local identity is set), and
 * peer-list (when the hashtable holds peers).  The packed form is copied out
 * to ifd->ifd_data.  ifd_len == 0, ifd_len < size and ifd_data == NULL are
 * each tested first; the bodies of those branches are not visible in this
 * view (presumably a size query and error returns -- confirm).
 *
 * NOTE(review): the private key is gated on cr_uid == 0, whereas
 * wg_priv_ioctl() uses priv_check() for the set path.  Consider whether the
 * same privilege mechanism should apply here.
 * NOTE(review): the release of nvl_array and of the per-peer nvlists returned
 * by wg_marshal_peers() is not visible here -- verify.
 */
481 wgc_get(struct wg_softc *sc, struct ifdrv *ifd)
483 nvlist_t *nvl, **nvl_array;
488 nvl = nvlist_create(0);
494 if (sc->sc_socket.so_port != 0)
495 nvlist_add_number(nvl, "listen-port", sc->sc_socket.so_port);
496 if (sc->sc_local.l_has_identity) {
497 nvlist_add_binary(nvl, "public-key", sc->sc_local.l_public, WG_KEY_SIZE);
498 if (curthread->td_ucred->cr_uid == 0)
499 nvlist_add_binary(nvl, "private-key", sc->sc_local.l_private, WG_KEY_SIZE);
501 if (sc->sc_hashtable.h_num_peers > 0) {
/* Only the array form is requested here (nvlp is NULL). */
502 err = wg_marshal_peers(sc, NULL, &nvl_array, &peer_count);
505 nvlist_add_nvlist_array(nvl, "peer-list",
506 (const nvlist_t * const *)nvl_array, peer_count);
508 packed = nvlist_pack(nvl, &size);
511 if (ifd->ifd_len == 0) {
515 if (ifd->ifd_len < size) {
519 if (ifd->ifd_data == NULL) {
523 err = copyout(packed, ifd->ifd_data, size);
527 free(packed, M_NVLIST);
/*
 * Validity check for one allowed-ip entry; wg_peer_add() treats a zero
 * result as invalid.  Only the signature is visible in this view.
 */
532 wg_allowedip_valid(const struct wg_allowedip *wip)
/*
 * Add, update or remove one peer of sc as described by the nvlist nvl.
 *
 * Keys read from nvl:
 *   public-key (binary, required, CURVE25519_KEY_SIZE bytes)
 *   peer-remove (bool)           - remove the existing peer
 *   replace-allowedips (bool)    - drop the current routes of the peer
 *   endpoint (binary)            - sockaddr of the remote end
 *   pre-shared-key (binary)
 *   persistent-keepalive-interval (number)
 *   allowed-ips (binary)         - array of struct wg_allowedip
 *
 * A public key equal to the public key of the local interface is rejected.
 * The branch conditions that decide between updating an existing peer and
 * allocating a new one, the error returns and the cleanup label are not
 * visible in this view.
 *
 * NOTE(review): the endpoint blob is validated against sizeof(*endpoint),
 * i.e. sizeof(struct sockaddr), but sizeof(peer->p_endpoint.e_remote) bytes
 * are then copied from it.  If e_remote is larger than struct sockaddr this
 * reads past the end of the nvlist buffer -- confirm the type of e_remote.
 * NOTE(review): no length check on the pre-shared-key blob is visible before
 * it is passed to noise_remote_set_psk() -- confirm.
 * NOTE(review): the second allowed-ips loop declares its own int i, which
 * shadows the function-scope i used by the validation loop.
 */
539 wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl)
541 uint8_t public[WG_KEY_SIZE];
543 const struct sockaddr *endpoint;
544 int i, err, allowedip_count;
547 struct wg_peer *peer = NULL;
548 bool need_insert = false;
549 dev = iflib_get_dev(sc->wg_ctx);
551 if (!nvlist_exists_binary(nvl, "public-key")) {
552 device_printf(dev, "peer has no public-key\n");
555 pub_key = nvlist_get_binary(nvl, "public-key", &size);
556 if (size != CURVE25519_KEY_SIZE) {
557 device_printf(dev, "%s bad length for public-key %zu\n", __func__, size);
/* Refuse a peer whose public key is the public key of this interface. */
560 if (noise_local_keys(&sc->sc_local, public, NULL) == 0 &&
561 bcmp(public, pub_key, WG_KEY_SIZE) == 0) {
562 device_printf(dev, "public-key for peer already in use by host\n");
565 peer = wg_peer_lookup(sc, pub_key);
/* Removal request: unhook the peer from the hashtable and destroy it. */
566 if (nvlist_exists_bool(nvl, "peer-remove") &&
567 nvlist_get_bool(nvl, "peer-remove")) {
569 wg_hashtable_peer_remove(&sc->sc_hashtable, peer);
570 wg_peer_destroy(peer);
572 printf("peer removed\n");
576 if (nvlist_exists_bool(nvl, "replace-allowedips") &&
577 nvlist_get_bool(nvl, "replace-allowedips") &&
580 wg_route_delete(&peer->p_sc->sc_routes, peer);
/* New peer: allocate it and seed the noise and cookie state from the public key. */
584 peer = wg_peer_alloc(sc);
585 noise_remote_init(&peer->p_remote, pub_key, &sc->sc_local);
586 cookie_maker_init(&peer->p_cookie, pub_key);
588 if (nvlist_exists_binary(nvl, "endpoint")) {
589 endpoint = nvlist_get_binary(nvl, "endpoint", &size);
590 if (size != sizeof(*endpoint)) {
591 device_printf(dev, "%s bad length for endpoint %zu\n", __func__, size);
595 memcpy(&peer->p_endpoint.e_remote, endpoint,
596 sizeof(peer->p_endpoint.e_remote));
598 if (nvlist_exists_binary(nvl, "pre-shared-key")) {
601 key = nvlist_get_binary(nvl, "pre-shared-key", &size);
602 noise_remote_set_psk(&peer->p_remote, key);
604 if (nvlist_exists_number(nvl, "persistent-keepalive-interval")) {
607 pki = nvlist_get_number(nvl, "persistent-keepalive-interval");
608 wg_timers_set_persistent_keepalive(&peer->p_timers, pki);
610 if (nvlist_exists_binary(nvl, "allowed-ips")) {
611 const struct wg_allowedip *aip, *aip_base;
613 aip = aip_base = nvlist_get_binary(nvl, "allowed-ips", &size);
614 if (size % sizeof(struct wg_allowedip) != 0) {
615 device_printf(dev, "%s bad length for allowed-ips %zu not integer multiple of struct size\n", __func__, size);
/* First pass validates every entry, second pass installs the routes. */
619 allowedip_count = size/sizeof(struct wg_allowedip);
620 for (i = 0; i < allowedip_count; i++) {
621 if (!wg_allowedip_valid(&aip_base[i])) {
622 device_printf(dev, "%s allowedip %d not valid\n", __func__, i);
627 for (int i = 0; i < allowedip_count; i++, aip++) {
628 if ((err = wg_route_add(&sc->sc_routes, peer, aip)) != 0) {
629 printf("route add %d failed -> %d\n", i, err);
634 wg_hashtable_peer_insert(&sc->sc_hashtable, peer);
/* Presumably the error path: the peer is destroyed -- confirm against the hidden labels. */
638 wg_peer_destroy(peer);
/*
 * Apply a configuration supplied by userland as a packed nvlist in
 * ifd->ifd_data (ifd->ifd_len bytes).
 *
 * Keys read from the nvlist:
 *   replace-peers (bool) - remove all current peers first
 *   listen-port (number) - take the link down, reset and re-create the
 *                          socket on the new port, bring the link up again
 *   private-key (binary, CURVE25519_KEY_SIZE bytes) - install a new local
 *                          key and update the cookie checker
 *   user-cookie (number) - stored in sc_user_cookie
 *   peer-list (nvlist array) - each entry is passed to wg_peer_add()
 *
 * Error returns and cleanup labels are not visible in this view.
 *
 * NOTE(review): ifd_len comes from userland and no upper bound is visible
 * before it sizes an M_WAITOK allocation -- confirm.
 * NOTE(review): listen_port is declared __unused but is used a few lines
 * later, and the uint64_t from nvlist_get_number() is narrowed to int with
 * no visible range check.
 * NOTE(review): the result of wg_peer_add() is discarded in the peer loop,
 * so a failing peer does not appear to fail the ioctl -- confirm intent.
 */
643 wgc_set(struct wg_softc *sc, struct ifdrv *ifd)
645 uint8_t public[WG_KEY_SIZE];
652 if (ifd->ifd_len == 0 || ifd->ifd_data == NULL)
655 dev = iflib_get_dev(sc->wg_ctx);
656 nvlpacked = malloc(ifd->ifd_len, M_TEMP, M_WAITOK);
657 err = copyin(ifd->ifd_data, nvlpacked, ifd->ifd_len);
660 nvl = nvlist_unpack(nvlpacked, ifd->ifd_len, 0);
662 device_printf(dev, "%s nvlist_unpack failed\n", __func__);
666 if (nvlist_exists_bool(nvl, "replace-peers") &&
667 nvlist_get_bool(nvl, "replace-peers"))
668 wg_peer_remove_all(sc);
669 if (nvlist_exists_number(nvl, "listen-port")) {
670 int listen_port __unused = nvlist_get_number(nvl, "listen-port");
/* Rebind: link down, short sleep, reset the socket, set the port, reopen, link up. */
674 if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
675 pause("link_down", hz/4);
676 wg_socket_reinit(sc, NULL, NULL);
677 sc->sc_socket.so_port = listen_port;
678 if ((err = wg_socket_init(sc)) != 0)
680 if_link_state_change(sc->sc_ifp, LINK_STATE_UP);
682 if (nvlist_exists_binary(nvl, "private-key")) {
683 struct noise_local *local;
684 const void *key = nvlist_get_binary(nvl, "private-key", &size);
686 if (size != CURVE25519_KEY_SIZE) {
687 device_printf(dev, "%s bad length for private-key %zu\n", __func__, size);
/* Same key installation sequence as in wg_cloneattach(). */
694 local = &sc->sc_local;
695 noise_local_set_private(local, __DECONST(uint8_t *, key));
696 noise_local_keys(local, public, NULL);
697 cookie_checker_update(&sc->sc_cookie, public);
699 if (nvlist_exists_number(nvl, "user-cookie")) {
700 sc->sc_user_cookie = nvlist_get_number(nvl, "user-cookie");
705 if (nvlist_exists_nvlist_array(nvl, "peer-list")) {
707 const nvlist_t * const*nvl_peers;
709 nvl_peers = nvlist_get_nvlist_array(nvl, "peer-list", &peercount);
710 for (int i = 0; i < peercount; i++) {
711 wg_peer_add(sc, nvl_peers[i]);
717 free(nvlpacked, M_TEMP);
/*
 * ifdi_priv_ioctl handler: dispatch driver-private requests.  data is
 * treated as a struct ifdrv and its ifd_cmd selects between wgc_get() and
 * wgc_set().  The set path is preceded by a priv_check() for
 * PRIV_NET_HWIOCTL.  The test on command, the switch on ifd_cmd and the
 * error returns are not visible in this view.
 */
722 wg_priv_ioctl(if_ctx_t ctx, u_long command, caddr_t data)
724 struct wg_softc *sc = iflib_get_softc(ctx);
725 struct ifdrv *ifd = (struct ifdrv *)data;
731 ifd_cmd = ifd->ifd_cmd;
738 return (wgc_get(sc, ifd));
/* Changing the configuration requires the hardware ioctl privilege. */
741 if (priv_check(curthread, PRIV_NET_HWIOCTL))
743 return (wgc_set(sc, ifd));
/*
 * iflib device method table: maps each ifdi_* entry point to the wg_*
 * handler defined in this file.
 */
749 static device_method_t wg_if_methods[] = {
750 DEVMETHOD(ifdi_cloneattach, wg_cloneattach),
751 DEVMETHOD(ifdi_attach_post, wg_attach_post),
752 DEVMETHOD(ifdi_detach, wg_detach),
753 DEVMETHOD(ifdi_init, wg_init),
754 DEVMETHOD(ifdi_stop, wg_stop),
755 DEVMETHOD(ifdi_priv_ioctl, wg_priv_ioctl),
756 DEVMETHOD(ifdi_mtu_set, wg_mtu_set),
757 DEVMETHOD(ifdi_promisc_set, wg_set_promisc),
/* Driver description: name wg, the method table above, softc size. */
761 static driver_t wg_iflib_driver = {
762 "wg", wg_if_methods, sizeof(struct wg_softc)
/* Version string referenced by isc_driver_version in wg_sctx_init. */
765 char wg_driver_version[] = "0.0.1";
/*
 * Shared context handed to iflib: identifies the driver and its version and
 * sets the IFLIB_PSEUDO flag.
 */
767 static struct if_shared_ctx wg_sctx_init = {
768 .isc_magic = IFLIB_MAGIC,
769 .isc_driver_version = wg_driver_version,
770 .isc_driver = &wg_iflib_driver,
771 .isc_flags = IFLIB_PSEUDO,
/* Pointer form of the shared context, passed to iflib_clone_register(). */
775 if_shared_ctx_t wg_sctx = &wg_sctx_init;
/* Handle returned by iflib_clone_register(); needed to deregister the cloner. */
776 static if_pseudo_t wg_pseudo;
/*
 * Create the UMA zone backing struct ratelimit objects.  The enclosing
 * function signature is not visible in this view; presumably this is the body
 * of wg_ctx_init(), which the module init code calls -- confirm.
 */
782 ratelimit_zone = uma_zcreate("wg ratelimit", sizeof(struct ratelimit),
783 NULL, NULL, NULL, NULL, 0, 0);
/*
 * Destroy the ratelimit zone again.  The enclosing function signature is not
 * visible in this view; presumably the teardown counterpart -- confirm.
 */
790 uma_zdestroy(ratelimit_zone);
/*
 * Module initialisation steps: run wg_ctx_init(), then register the wg
 * cloner with iflib and keep the returned handle in wg_pseudo.  A NULL handle
 * is tested for; what happens then is not visible.  The enclosing function
 * signature is not visible either; presumably this is wg_module_init(),
 * which wg_module_event_handler() calls -- confirm.
 */
798 if ((rc = wg_ctx_init()))
801 wg_pseudo = iflib_clone_register(wg_sctx);
802 if (wg_pseudo == NULL)
/*
 * Module teardown: deregister the wg cloner using the handle stored in
 * wg_pseudo.  Any further teardown steps are not visible in this view.
 */
809 wg_module_deinit(void)
812 iflib_clone_deregister(wg_pseudo);
/*
 * Module event handler registered through wg_moduledata.  The visible lines
 * call wg_module_init() and test clone_count against zero; the switch on
 * what and the statements controlled by these tests are not visible.
 * Presumably the clone_count test decides whether unloading is allowed while
 * interfaces exist -- confirm.
 */
816 wg_module_event_handler(module_t mod, int what, void *arg)
822 if ((err = wg_module_init()) != 0)
826 if (clone_count == 0)
/* Module descriptor; only the event handler member is visible in this view. */
837 static moduledata_t wg_moduledata = {
839 wg_module_event_handler,
/*
 * Register the module at SI_SUB_PSEUDO and declare its version and its
 * dependencies on the iflib, blake2 and crypto modules.
 */
843 DECLARE_MODULE(wg, wg_moduledata, SI_SUB_PSEUDO, SI_ORDER_ANY);
844 MODULE_VERSION(wg, 1);
845 MODULE_DEPEND(wg, iflib, 1, 1, 1);
846 MODULE_DEPEND(wg, blake2, 1, 1, 1);
847 MODULE_DEPEND(wg, crypto, 1, 1, 1);