]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - sys/dev/if_wg/module/module.c
Import kernel WireGuard support
[FreeBSD/FreeBSD.git] / sys / dev / if_wg / module / module.c
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3  *
4  * Copyright (c) 2019-2020 Rubicon Communications, LLC (Netgate)
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  *   1. Redistributions of source code must retain the above copyright
10  *      notice, this list of conditions and the following disclaimer.
11  *   2. Redistributions in binary form must reproduce the above copyright
12  *      notice, this list of conditions and the following disclaimer in the
13  *      documentation and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
19  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
21  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
22  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
23  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
24  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
25  * SUCH DAMAGE.
26  */
27
28 #include <sys/cdefs.h>
29 __FBSDID("$FreeBSD$");
30
31 #include "opt_inet.h"
32 #include "opt_inet6.h"
33 #include <sys/param.h>
34 #include <sys/types.h>
35 #include <sys/systm.h>
36 #include <sys/kernel.h>
37 #include <sys/lock.h>
38 #include <sys/priv.h>
39 #include <sys/mutex.h>
40 #include <sys/mbuf.h>
41 #include <sys/module.h>
42 #include <sys/proc.h>
43 #include <sys/socket.h>
44 #include <sys/sockio.h>
45 #include <sys/queue.h>
46 #include <sys/smp.h>
47
48 #include <net/if.h>
49 #include <net/ethernet.h>
50 #include <net/if_var.h>
51 #include <net/iflib.h>
52 #include <net/if_clone.h>
53 #include <net/radix.h>
54 #include <net/bpf.h>
55 #include <net/mp_ring.h>
56
57 #include "ifdi_if.h"
58
59 #include <sys/wg_module.h>
60 #include <crypto/zinc.h>
61 #include <sys/wg_noise.h>
62 #include <sys/if_wg_session_vars.h>
63 #include <sys/if_wg_session.h>
64
65 MALLOC_DEFINE(M_WG, "WG", "wireguard");
66
67 #define WG_CAPS         IFCAP_LINKSTATE
68 #define ph_family       PH_loc.eight[5]
69
70 TASKQGROUP_DECLARE(if_io_tqg);
71
72 static int clone_count;
73 uma_zone_t ratelimit_zone;
74
75 void
76 wg_encrypt_dispatch(struct wg_softc *sc)
77 {
78         for (int i = 0; i < mp_ncpus; i++) {
79                 if (sc->sc_encrypt[i].gt_task.ta_flags & TASK_ENQUEUED)
80                         continue;
81                 GROUPTASK_ENQUEUE(&sc->sc_encrypt[i]);
82         }
83 }
84
85 void
86 wg_decrypt_dispatch(struct wg_softc *sc)
87 {
88         for (int i = 0; i < mp_ncpus; i++) {
89                 if (sc->sc_decrypt[i].gt_task.ta_flags & TASK_ENQUEUED)
90                         continue;
91                 GROUPTASK_ENQUEUE(&sc->sc_decrypt[i]);
92         }
93 }
94
95 static void
96 crypto_taskq_setup(struct wg_softc *sc)
97 {
98         device_t dev = iflib_get_dev(sc->wg_ctx);
99
100         sc->sc_encrypt = malloc(sizeof(struct grouptask)*mp_ncpus, M_WG, M_WAITOK);
101         sc->sc_decrypt = malloc(sizeof(struct grouptask)*mp_ncpus, M_WG, M_WAITOK);
102
103         for (int i = 0; i < mp_ncpus; i++) {
104                 GROUPTASK_INIT(&sc->sc_encrypt[i], 0,
105                      (gtask_fn_t *)wg_softc_encrypt, sc);
106                 taskqgroup_attach_cpu(qgroup_if_io_tqg, &sc->sc_encrypt[i], sc,  i, dev, NULL, "wg encrypt");
107                 GROUPTASK_INIT(&sc->sc_decrypt[i], 0,
108                     (gtask_fn_t *)wg_softc_decrypt, sc);
109                 taskqgroup_attach_cpu(qgroup_if_io_tqg, &sc->sc_decrypt[i], sc, i, dev, NULL, "wg decrypt");
110         }
111 }
112
113 static void
114 crypto_taskq_destroy(struct wg_softc *sc)
115 {
116         for (int i = 0; i < mp_ncpus; i++) {
117                 taskqgroup_detach(qgroup_if_io_tqg, &sc->sc_encrypt[i]);
118                 taskqgroup_detach(qgroup_if_io_tqg, &sc->sc_decrypt[i]);
119         }
120         free(sc->sc_encrypt, M_WG);
121         free(sc->sc_decrypt, M_WG);
122 }
123
124 static int
125 wg_cloneattach(if_ctx_t ctx, struct if_clone *ifc, const char *name, caddr_t params)
126 {
127         struct wg_softc *sc = iflib_get_softc(ctx);
128         if_softc_ctx_t scctx;
129         device_t dev;
130         struct iovec iov;
131         nvlist_t *nvl;
132         void *packed;
133         struct noise_local *local;
134         uint8_t                  public[WG_KEY_SIZE];
135         struct noise_upcall      noise_upcall;
136         int err;
137         uint16_t listen_port;
138         const void *key;
139         size_t size;
140
141         err = 0;
142         dev = iflib_get_dev(ctx);
143         if (params == NULL) {
144                 key = NULL;
145                 listen_port = 0;
146                 nvl = NULL;
147                 packed = NULL;
148                 goto unpacked;
149         }
150         if (copyin(params, &iov, sizeof(iov)))
151                 return (EFAULT);
152         /* check that this is reasonable */
153         size = iov.iov_len;
154         packed = malloc(size, M_TEMP, M_WAITOK);
155         if (copyin(iov.iov_base, packed, size)) {
156                 err = EFAULT;
157                 goto out;
158         }
159         nvl = nvlist_unpack(packed, size, 0);
160         if (nvl == NULL) {
161                 device_printf(dev, "%s nvlist_unpack failed\n", __func__);
162                 err = EBADMSG;
163                 goto out;
164         }
165         if (!nvlist_exists_number(nvl, "listen-port")) {
166                 device_printf(dev, "%s listen-port not set\n", __func__);
167                 err = EBADMSG;
168                 goto nvl_out;
169         }
170         listen_port = nvlist_get_number(nvl, "listen-port");
171
172         if (!nvlist_exists_binary(nvl, "private-key")) {
173                 device_printf(dev, "%s private-key not set\n", __func__);
174                 err = EBADMSG;
175                 goto nvl_out;
176         }
177         key = nvlist_get_binary(nvl, "private-key", &size);
178         if (size != CURVE25519_KEY_SIZE) {
179                 device_printf(dev, "%s bad length for private-key %zu\n", __func__, size);
180                 err = EBADMSG;
181                 goto nvl_out;
182         }
183 unpacked:
184         local = &sc->sc_local;
185         noise_upcall.u_arg = sc;
186         noise_upcall.u_remote_get =
187                 (struct noise_remote *(*)(void *, uint8_t *))wg_remote_get;
188         noise_upcall.u_index_set =
189                 (uint32_t (*)(void *, struct noise_remote *))wg_index_set;
190         noise_upcall.u_index_drop =
191                 (void (*)(void *, uint32_t))wg_index_drop;
192         noise_local_init(local, &noise_upcall);
193         cookie_checker_init(&sc->sc_cookie, ratelimit_zone);
194
195         sc->sc_socket.so_port = listen_port;
196
197         if (key != NULL) {
198                 noise_local_set_private(local, __DECONST(uint8_t *, key));
199                 noise_local_keys(local, public, NULL);
200                 cookie_checker_update(&sc->sc_cookie, public);
201         }
202         atomic_add_int(&clone_count, 1);
203         scctx = sc->shared = iflib_get_softc_ctx(ctx);
204         scctx->isc_capenable = WG_CAPS;
205         scctx->isc_tx_csum_flags = CSUM_TCP | CSUM_UDP | CSUM_TSO | CSUM_IP6_TCP \
206                 | CSUM_IP6_UDP | CSUM_IP6_TCP;
207         sc->wg_ctx = ctx;
208         sc->sc_ifp = iflib_get_ifp(ctx);
209
210         mbufq_init(&sc->sc_handshake_queue, MAX_QUEUED_INCOMING_HANDSHAKES);
211         mtx_init(&sc->sc_mtx, NULL, "wg softc lock",  MTX_DEF);
212         rw_init(&sc->sc_index_lock, "wg index lock");
213         sc->sc_encap_ring = buf_ring_alloc(MAX_QUEUED_PACKETS, M_WG, M_WAITOK, &sc->sc_mtx);
214         sc->sc_decap_ring = buf_ring_alloc(MAX_QUEUED_PACKETS, M_WG, M_WAITOK, &sc->sc_mtx);
215         GROUPTASK_INIT(&sc->sc_handshake, 0,
216             (gtask_fn_t *)wg_softc_handshake_receive, sc);
217         taskqgroup_attach(qgroup_if_io_tqg, &sc->sc_handshake, sc, dev, NULL, "wg tx initiation");
218         crypto_taskq_setup(sc);
219  nvl_out:
220         if (nvl != NULL)
221                 nvlist_destroy(nvl);
222 out:
223         free(packed, M_TEMP);
224         return (err);
225 }
226
227 static int
228 wg_transmit(struct ifnet *ifp, struct mbuf *m)
229 {
230         struct wg_softc *sc;
231         sa_family_t family;
232         struct epoch_tracker et;
233         struct wg_peer *peer;
234         struct wg_tag *t;
235         uint32_t af;
236         int rc;
237
238
239         /*
240          * Work around lifetime issue in the ipv6 mld code.
241          */
242         if (__predict_false(ifp->if_flags & IFF_DYING))
243                 return (ENXIO);
244
245         rc = 0;
246         sc = iflib_get_softc(ifp->if_softc);
247         if ((t = wg_tag_get(m)) == NULL) {
248                 rc = ENOBUFS;
249                 goto early_out;
250         }
251         af = m->m_pkthdr.ph_family;
252         BPF_MTAP2(ifp, &af, sizeof(af), m);
253
254         NET_EPOCH_ENTER(et);
255         peer = wg_route_lookup(&sc->sc_routes, m, OUT);
256         if (__predict_false(peer == NULL)) {
257                 rc = ENOKEY;
258                 printf("peer not found - dropping %p\n", m);
259                 /* XXX log */
260                 goto err;
261         }
262
263         family = atomic_load_acq(peer->p_endpoint.e_remote.r_sa.sa_family);
264         if (__predict_false(family != AF_INET && family != AF_INET6)) {
265                 rc = EHOSTUNREACH;
266                 /* XXX log */
267                 goto err;
268         }
269         t->t_peer = peer;
270         t->t_mbuf = NULL;
271         t->t_done = 0;
272         t->t_mtu = ifp->if_mtu;
273
274         rc = wg_queue_out(peer, m);
275         if (rc == 0)
276                 wg_encrypt_dispatch(peer->p_sc);
277         NET_EPOCH_EXIT(et);
278         return (rc); 
279 err:
280         NET_EPOCH_EXIT(et);
281 early_out:
282         if_inc_counter(sc->sc_ifp, IFCOUNTER_OERRORS, 1);
283         /* XXX send ICMP unreachable */
284         m_free(m);
285         return (rc);
286 }
287
288 static int
289 wg_output(struct ifnet *ifp, struct mbuf *m, const struct sockaddr *sa, struct route *rt)
290 {
291         m->m_pkthdr.ph_family =  sa->sa_family;
292         return (wg_transmit(ifp, m));
293 }
294
295 static int
296 wg_attach_post(if_ctx_t ctx)
297 {
298         struct ifnet *ifp;
299         struct wg_softc *sc;
300
301         sc = iflib_get_softc(ctx);
302         ifp = iflib_get_ifp(ctx);
303         if_setmtu(ifp, ETHERMTU - 80);
304
305         if_setflagbits(ifp, IFF_NOARP, IFF_POINTOPOINT);
306         ifp->if_transmit = wg_transmit;
307         ifp->if_output = wg_output;
308
309         wg_hashtable_init(&sc->sc_hashtable);
310         sc->sc_index = hashinit(HASHTABLE_INDEX_SIZE, M_DEVBUF, &sc->sc_index_mask);
311         wg_route_init(&sc->sc_routes);
312
313         return (0);
314 }
315
316 static int
317 wg_mtu_set(if_ctx_t ctx, uint32_t mtu)
318 {
319
320         return (0);
321 }
322
323 static int
324 wg_set_promisc(if_ctx_t ctx, int flags)
325 {
326
327         return (0);
328 }
329
330 static int
331 wg_detach(if_ctx_t ctx)
332 {
333         struct wg_softc *sc;
334
335         sc = iflib_get_softc(ctx);
336         if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
337         NET_EPOCH_WAIT();
338         wg_socket_reinit(sc, NULL, NULL);
339         taskqgroup_drain_all(qgroup_if_io_tqg);
340         pause("link_down", hz/4);
341         wg_peer_remove_all(sc);
342         pause("link_down", hz);
343         mtx_destroy(&sc->sc_mtx);
344         rw_destroy(&sc->sc_index_lock);
345         taskqgroup_detach(qgroup_if_io_tqg, &sc->sc_handshake);
346         crypto_taskq_destroy(sc);
347         buf_ring_free(sc->sc_encap_ring, M_WG);
348         buf_ring_free(sc->sc_decap_ring, M_WG);
349
350         wg_route_destroy(&sc->sc_routes);
351         wg_hashtable_destroy(&sc->sc_hashtable);
352         atomic_add_int(&clone_count, -1);
353         return (0);
354 }
355
356 static void
357 wg_init(if_ctx_t ctx)
358 {
359         struct ifnet *ifp;
360         struct wg_softc *sc;
361         int rc;
362
363         sc = iflib_get_softc(ctx);
364         ifp = iflib_get_ifp(ctx);
365         rc = wg_socket_init(sc);
366         if (rc)
367                 return;
368         if_link_state_change(ifp, LINK_STATE_UP);
369 }
370
371 static void
372 wg_stop(if_ctx_t ctx)
373 {
374         struct wg_softc *sc;
375         struct ifnet *ifp;
376
377         sc  = iflib_get_softc(ctx);
378         ifp = iflib_get_ifp(ctx);
379         if_link_state_change(ifp, LINK_STATE_DOWN);
380 }
381
382 static nvlist_t *
383 wg_peer_to_nvl(struct wg_peer *peer)
384 {
385         struct wg_route *rt;
386         int i, count;
387         nvlist_t *nvl;
388         caddr_t key;
389         struct wg_allowedip *aip;
390
391         if ((nvl = nvlist_create(0)) == NULL)
392                 return (NULL);
393         key = peer->p_remote.r_public;
394         nvlist_add_binary(nvl, "public-key", key, WG_KEY_SIZE);
395         nvlist_add_binary(nvl, "endpoint", &peer->p_endpoint.e_remote, sizeof(struct sockaddr));
396         i = count = 0;
397         CK_LIST_FOREACH(rt, &peer->p_routes, r_entry) {
398                 count++;
399         }
400         aip = malloc(count*sizeof(*aip), M_TEMP, M_WAITOK);
401         CK_LIST_FOREACH(rt, &peer->p_routes, r_entry) {
402                 memcpy(&aip[i++], &rt->r_cidr, sizeof(*aip));
403         }
404         nvlist_add_binary(nvl, "allowed-ips", aip, count*sizeof(*aip));
405         free(aip, M_TEMP);
406         return (nvl);
407 }
408
409 static int
410 wg_marshal_peers(struct wg_softc *sc, nvlist_t **nvlp, nvlist_t ***nvl_arrayp, int *peer_countp)
411 {
412         struct wg_peer *peer;
413         int err, i, peer_count;
414         nvlist_t *nvl, **nvl_array;
415         struct epoch_tracker et;
416 #ifdef INVARIANTS
417         void *packed;
418         size_t size;
419 #endif
420         nvl = NULL;
421         nvl_array = NULL;
422         if (nvl_arrayp)
423                 *nvl_arrayp = NULL;
424         if (nvlp)
425                 *nvlp = NULL;
426         if (peer_countp)
427                 *peer_countp = 0;
428         peer_count = sc->sc_hashtable.h_num_peers;
429         if (peer_count == 0) {
430                 printf("no peers found\n");
431                 return (ENOENT);
432         }
433
434         if (nvlp && (nvl = nvlist_create(0)) == NULL)
435                 return (ENOMEM);
436         err = i = 0;
437         nvl_array = malloc(peer_count*sizeof(void*), M_TEMP, M_WAITOK);
438         NET_EPOCH_ENTER(et);
439         CK_LIST_FOREACH(peer, &sc->sc_hashtable.h_peers_list, p_entry) {
440                 nvl_array[i] = wg_peer_to_nvl(peer);
441                 if (nvl_array[i] == NULL) {
442                         printf("wg_peer_to_nvl failed on %d peer\n", i);
443                         break;
444                 }
445 #ifdef INVARIANTS
446                 packed = nvlist_pack(nvl_array[i], &size);
447                 if (packed == NULL) {
448                         printf("nvlist_pack(%p, %p) => %d",
449                                    nvl_array[i], &size, nvlist_error(nvl));
450                 }
451                 free(packed, M_NVLIST);
452 #endif  
453                 i++;
454                 if (i == peer_count)
455                         break;
456         }
457         NET_EPOCH_EXIT(et);
458         *peer_countp = peer_count = i;
459         if (peer_count == 0) {
460                 printf("no peers found in list\n");
461                 err = ENOENT;
462                 goto out;
463         }
464         if (nvl) {
465                 nvlist_add_nvlist_array(nvl, "peer-list",
466                     (const nvlist_t * const *)nvl_array, peer_count);
467                 if ((err = nvlist_error(nvl))) {
468                         printf("nvlist_add_nvlist_array(%p, \"peer-list\", %p, %d) => %d\n",
469                             nvl, nvl_array, peer_count, err);
470                         goto out;
471                 }
472                 *nvlp = nvl;
473         }
474         *nvl_arrayp = nvl_array;
475         return (0);
476  out:
477         return (err);
478 }
479
480 static int
481 wgc_get(struct wg_softc *sc, struct ifdrv *ifd)
482 {
483         nvlist_t *nvl, **nvl_array;
484         void *packed;
485         size_t size;
486         int peer_count, err;
487
488         nvl = nvlist_create(0);
489         if (nvl == NULL)
490                 return (ENOMEM);
491
492         err = 0;
493         packed = NULL;
494         if (sc->sc_socket.so_port != 0)
495                 nvlist_add_number(nvl, "listen-port", sc->sc_socket.so_port);
496         if (sc->sc_local.l_has_identity) {
497                 nvlist_add_binary(nvl, "public-key", sc->sc_local.l_public, WG_KEY_SIZE);
498                 if (curthread->td_ucred->cr_uid == 0)
499                         nvlist_add_binary(nvl, "private-key", sc->sc_local.l_private, WG_KEY_SIZE);
500         }
501         if (sc->sc_hashtable.h_num_peers > 0) {
502                 err = wg_marshal_peers(sc, NULL, &nvl_array, &peer_count);
503                 if (err)
504                         goto out;
505                 nvlist_add_nvlist_array(nvl, "peer-list",
506                     (const nvlist_t * const *)nvl_array, peer_count);
507         }
508         packed = nvlist_pack(nvl, &size);
509         if (packed == NULL)
510                 return (ENOMEM);
511         if (ifd->ifd_len == 0) {
512                 ifd->ifd_len = size;
513                 goto out;
514         }
515         if (ifd->ifd_len < size) {
516                 err = ENOSPC;
517                 goto out;
518         }
519         if (ifd->ifd_data == NULL) {
520                 err = EFAULT;
521                 goto out;
522         }
523         err = copyout(packed, ifd->ifd_data, size);
524         ifd->ifd_len = size;
525  out:
526         nvlist_destroy(nvl);
527         free(packed, M_NVLIST);
528         return (err);
529 }
530
531 static bool
532 wg_allowedip_valid(const struct wg_allowedip *wip)
533 {
534
535         return (true);
536 }
537
538 static int
539 wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl)
540 {
541         uint8_t                  public[WG_KEY_SIZE];
542         const void *pub_key;
543         const struct sockaddr *endpoint;
544         int i, err, allowedip_count;
545         device_t dev;
546         size_t size;
547         struct wg_peer *peer = NULL;
548         bool need_insert = false;
549         dev = iflib_get_dev(sc->wg_ctx);
550
551         if (!nvlist_exists_binary(nvl, "public-key")) {
552                 device_printf(dev, "peer has no public-key\n");
553                 return (EINVAL);
554         }
555         pub_key = nvlist_get_binary(nvl, "public-key", &size);
556         if (size != CURVE25519_KEY_SIZE) {
557                 device_printf(dev, "%s bad length for public-key %zu\n", __func__, size);
558                 return (EINVAL);
559         }
560         if (noise_local_keys(&sc->sc_local, public, NULL) == 0 &&
561             bcmp(public, pub_key, WG_KEY_SIZE) == 0) {
562                 device_printf(dev, "public-key for peer already in use by host\n");
563                 return (EINVAL);
564         }
565         peer = wg_peer_lookup(sc, pub_key);
566         if (nvlist_exists_bool(nvl, "peer-remove") &&
567                 nvlist_get_bool(nvl, "peer-remove")) {
568                 if (peer != NULL) {
569                         wg_hashtable_peer_remove(&sc->sc_hashtable, peer);
570                         wg_peer_destroy(peer);
571                         /* XXX free */
572                         printf("peer removed\n");
573                 }
574                 return (0);
575         }
576         if (nvlist_exists_bool(nvl, "replace-allowedips") &&
577                 nvlist_get_bool(nvl, "replace-allowedips") &&
578             peer != NULL) {
579
580                 wg_route_delete(&peer->p_sc->sc_routes, peer);
581         }
582         if (peer == NULL) {
583                 need_insert = true;
584                 peer = wg_peer_alloc(sc);
585                 noise_remote_init(&peer->p_remote, pub_key, &sc->sc_local);
586                 cookie_maker_init(&peer->p_cookie, pub_key);
587         }
588         if (nvlist_exists_binary(nvl, "endpoint")) {
589                 endpoint = nvlist_get_binary(nvl, "endpoint", &size);
590                 if (size != sizeof(*endpoint)) {
591                         device_printf(dev, "%s bad length for endpoint %zu\n", __func__, size);
592                         err = EBADMSG;
593                         goto out;
594                 }
595                 memcpy(&peer->p_endpoint.e_remote, endpoint,
596                     sizeof(peer->p_endpoint.e_remote));
597         }
598         if (nvlist_exists_binary(nvl, "pre-shared-key")) {
599                 const void *key;
600
601                 key = nvlist_get_binary(nvl, "pre-shared-key", &size);
602                 noise_remote_set_psk(&peer->p_remote, key);
603         }
604         if (nvlist_exists_number(nvl, "persistent-keepalive-interval")) {
605                 uint16_t pki;
606
607                 pki = nvlist_get_number(nvl, "persistent-keepalive-interval");
608                 wg_timers_set_persistent_keepalive(&peer->p_timers, pki);
609         }
610         if (nvlist_exists_binary(nvl, "allowed-ips")) {
611                 const struct wg_allowedip *aip, *aip_base;
612
613                 aip = aip_base = nvlist_get_binary(nvl, "allowed-ips", &size);
614                 if (size % sizeof(struct wg_allowedip) != 0) {
615                         device_printf(dev, "%s bad length for allowed-ips %zu not integer multiple of struct size\n", __func__, size);
616                         err = EBADMSG;
617                         goto out;
618                 }
619                 allowedip_count = size/sizeof(struct wg_allowedip);
620                 for (i = 0; i < allowedip_count; i++) {
621                         if (!wg_allowedip_valid(&aip_base[i])) {
622                                 device_printf(dev, "%s allowedip %d not valid\n", __func__, i);
623                                 err = EBADMSG;
624                                 goto out;
625                         }
626                 }
627                 for (int i = 0; i < allowedip_count; i++, aip++) {
628                         if ((err = wg_route_add(&sc->sc_routes, peer, aip)) != 0) {
629                                 printf("route add %d failed -> %d\n", i, err);
630                         }
631                 }
632         }
633         if (need_insert)
634                 wg_hashtable_peer_insert(&sc->sc_hashtable, peer);
635         return (0);
636
637 out:
638         wg_peer_destroy(peer);
639         return (err);
640 }
641
642 static int
643 wgc_set(struct wg_softc *sc, struct ifdrv *ifd)
644 {
645         uint8_t                  public[WG_KEY_SIZE];
646         void *nvlpacked;
647         nvlist_t *nvl;
648         device_t dev;
649         ssize_t size;
650         int err;
651
652         if (ifd->ifd_len == 0 || ifd->ifd_data == NULL)
653                 return (EFAULT);
654
655         dev = iflib_get_dev(sc->wg_ctx);
656         nvlpacked = malloc(ifd->ifd_len, M_TEMP, M_WAITOK);
657         err = copyin(ifd->ifd_data, nvlpacked, ifd->ifd_len);
658         if (err)
659                 goto out;
660         nvl = nvlist_unpack(nvlpacked, ifd->ifd_len, 0);
661         if (nvl == NULL) {
662                 device_printf(dev, "%s nvlist_unpack failed\n", __func__);
663                 err = EBADMSG;
664                 goto out;
665         }
666         if (nvlist_exists_bool(nvl, "replace-peers") &&
667                 nvlist_get_bool(nvl, "replace-peers"))
668                 wg_peer_remove_all(sc);
669         if (nvlist_exists_number(nvl, "listen-port")) {
670                 int listen_port __unused = nvlist_get_number(nvl, "listen-port");
671                         /*
672                          * Set listen port
673                          */
674                 if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
675                 pause("link_down", hz/4);
676                 wg_socket_reinit(sc, NULL, NULL);
677                 sc->sc_socket.so_port = listen_port;
678                 if ((err = wg_socket_init(sc)) != 0)
679                         goto out;
680            if_link_state_change(sc->sc_ifp, LINK_STATE_UP);
681         }
682         if (nvlist_exists_binary(nvl, "private-key")) {
683                 struct noise_local *local;
684                 const void *key = nvlist_get_binary(nvl, "private-key", &size);
685
686                 if (size != CURVE25519_KEY_SIZE) {
687                         device_printf(dev, "%s bad length for private-key %zu\n", __func__, size);
688                         err = EBADMSG;
689                         goto nvl_out;
690                 }
691                 /*
692                  * set private key
693                  */
694                 local = &sc->sc_local;
695                 noise_local_set_private(local, __DECONST(uint8_t *, key));
696                 noise_local_keys(local, public, NULL);
697                 cookie_checker_update(&sc->sc_cookie, public);
698         }
699         if (nvlist_exists_number(nvl, "user-cookie")) {
700                 sc->sc_user_cookie = nvlist_get_number(nvl, "user-cookie");
701                 /*
702                  * setsockopt
703                  */
704         }
705         if (nvlist_exists_nvlist_array(nvl, "peer-list")) {
706                 size_t peercount;
707                 const nvlist_t * const*nvl_peers;
708
709                 nvl_peers = nvlist_get_nvlist_array(nvl, "peer-list", &peercount);
710                 for (int i = 0; i < peercount; i++) {
711                         wg_peer_add(sc, nvl_peers[i]);
712                 }
713         }
714 nvl_out:
715         nvlist_destroy(nvl);
716 out:
717         free(nvlpacked, M_TEMP);
718         return (err);
719 }
720
721 static int
722 wg_priv_ioctl(if_ctx_t ctx, u_long command, caddr_t data)
723 {
724         struct wg_softc *sc = iflib_get_softc(ctx);
725         struct ifdrv *ifd = (struct ifdrv *)data;
726         int ifd_cmd;
727
728         switch (command) {
729                 case SIOCGDRVSPEC:
730                 case SIOCSDRVSPEC:
731                         ifd_cmd = ifd->ifd_cmd;
732                         break;
733                 default:
734                         return (EINVAL);
735         }
736         switch (ifd_cmd) {
737                 case WGC_GET:
738                         return (wgc_get(sc, ifd));
739                         break;
740                 case WGC_SET:
741                         if (priv_check(curthread, PRIV_NET_HWIOCTL))
742                                 return (EPERM);
743                         return (wgc_set(sc, ifd));
744                         break;
745         }
746         return (ENOTSUP);
747 }
748
749 static device_method_t wg_if_methods[] = {
750         DEVMETHOD(ifdi_cloneattach, wg_cloneattach),
751         DEVMETHOD(ifdi_attach_post, wg_attach_post),
752         DEVMETHOD(ifdi_detach, wg_detach),
753         DEVMETHOD(ifdi_init, wg_init),
754         DEVMETHOD(ifdi_stop, wg_stop),
755         DEVMETHOD(ifdi_priv_ioctl, wg_priv_ioctl),
756         DEVMETHOD(ifdi_mtu_set, wg_mtu_set),
757         DEVMETHOD(ifdi_promisc_set, wg_set_promisc),
758         DEVMETHOD_END
759 };
760
761 static driver_t wg_iflib_driver = {
762         "wg", wg_if_methods, sizeof(struct wg_softc)
763 };
764
765 char wg_driver_version[] = "0.0.1";
766
767 static struct if_shared_ctx wg_sctx_init = {
768         .isc_magic = IFLIB_MAGIC,
769         .isc_driver_version = wg_driver_version,
770         .isc_driver = &wg_iflib_driver,
771         .isc_flags = IFLIB_PSEUDO,
772         .isc_name = "wg",
773 };
774
775 if_shared_ctx_t wg_sctx = &wg_sctx_init;
776 static if_pseudo_t wg_pseudo;
777
778
779 int
780 wg_ctx_init(void)
781 {
782         ratelimit_zone = uma_zcreate("wg ratelimit", sizeof(struct ratelimit),
783              NULL, NULL, NULL, NULL, 0, 0);
784         return (0);
785 }
786
787 void
788 wg_ctx_uninit(void)
789 {
790         uma_zdestroy(ratelimit_zone);
791 }
792
793 static int
794 wg_module_init(void)
795 {
796         int rc;
797
798         if ((rc = wg_ctx_init()))
799                 return (rc);
800
801         wg_pseudo = iflib_clone_register(wg_sctx);
802         if (wg_pseudo == NULL)
803                 return (ENXIO);
804
805         return (0);
806 }
807
808 static void
809 wg_module_deinit(void)
810 {
811         wg_ctx_uninit();
812         iflib_clone_deregister(wg_pseudo);
813 }
814
815 static int
816 wg_module_event_handler(module_t mod, int what, void *arg)
817 {
818         int err;
819
820         switch (what) {
821                 case MOD_LOAD:
822                         if ((err = wg_module_init()) != 0)
823                                 return (err);
824                         break;
825                 case MOD_UNLOAD:
826                         if (clone_count == 0)
827                                 wg_module_deinit();
828                         else
829                                 return (EBUSY);
830                         break;
831                 default:
832                         return (EOPNOTSUPP);
833         }
834         return (0);
835 }
836
837 static moduledata_t wg_moduledata = {
838         "wg",
839         wg_module_event_handler,
840         NULL
841 };
842
843 DECLARE_MODULE(wg, wg_moduledata, SI_SUB_PSEUDO, SI_ORDER_ANY);
844 MODULE_VERSION(wg, 1);
845 MODULE_DEPEND(wg, iflib, 1, 1, 1);
846 MODULE_DEPEND(wg, blake2, 1, 1, 1);
847 MODULE_DEPEND(wg, crypto, 1, 1, 1);