]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - sys/dev/if_wg/module/module.c
Import wireguard fixes from pfSense 2.5
[FreeBSD/FreeBSD.git] / sys / dev / if_wg / module / module.c
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3  *
4  * Copyright (c) 2019-2020 Rubicon Communications, LLC (Netgate)
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  *   1. Redistributions of source code must retain the above copyright
10  *      notice, this list of conditions and the following disclaimer.
11  *   2. Redistributions in binary form must reproduce the above copyright
12  *      notice, this list of conditions and the following disclaimer in the
13  *      documentation and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
19  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
21  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
22  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
23  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
24  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
25  * SUCH DAMAGE.
26  */
27
28 #include <sys/cdefs.h>
29 __FBSDID("$FreeBSD$");
30
31 #include "opt_inet.h"
32 #include "opt_inet6.h"
33 #include <sys/param.h>
34 #include <sys/types.h>
35 #include <sys/systm.h>
36 #include <sys/kernel.h>
37 #include <sys/lock.h>
38 #include <sys/priv.h>
39 #include <sys/mutex.h>
40 #include <sys/mbuf.h>
41 #include <sys/module.h>
42 #include <sys/proc.h>
43 #include <sys/socket.h>
44 #include <sys/sockio.h>
45 #include <sys/queue.h>
46 #include <sys/smp.h>
47
48 #include <net/if.h>
49 #include <net/ethernet.h>
50 #include <net/if_var.h>
51 #include <net/iflib.h>
52 #include <net/if_clone.h>
53 #include <net/radix.h>
54 #include <net/bpf.h>
55 #include <net/mp_ring.h>
56
57 #include "ifdi_if.h"
58
59 #include <sys/wg_module.h>
60 #include <crypto/zinc.h>
61 #include <sys/wg_noise.h>
62 #include <sys/if_wg_session_vars.h>
63 #include <sys/if_wg_session.h>
64
65 MALLOC_DEFINE(M_WG, "WG", "wireguard");
66
67 #define WG_CAPS         IFCAP_LINKSTATE
68 #define ph_family       PH_loc.eight[5]
69
70 TASKQGROUP_DECLARE(if_io_tqg);
71
72 static int clone_count;
73 uma_zone_t ratelimit_zone;
74
75 void
76 wg_encrypt_dispatch(struct wg_softc *sc)
77 {
78         for (int i = 0; i < mp_ncpus; i++) {
79                 if (sc->sc_encrypt[i].gt_task.ta_flags & TASK_ENQUEUED)
80                         continue;
81                 GROUPTASK_ENQUEUE(&sc->sc_encrypt[i]);
82         }
83 }
84
85 void
86 wg_decrypt_dispatch(struct wg_softc *sc)
87 {
88         for (int i = 0; i < mp_ncpus; i++) {
89                 if (sc->sc_decrypt[i].gt_task.ta_flags & TASK_ENQUEUED)
90                         continue;
91                 GROUPTASK_ENQUEUE(&sc->sc_decrypt[i]);
92         }
93 }
94
95 static void
96 crypto_taskq_setup(struct wg_softc *sc)
97 {
98         device_t dev = iflib_get_dev(sc->wg_ctx);
99
100         sc->sc_encrypt = malloc(sizeof(struct grouptask)*mp_ncpus, M_WG, M_WAITOK);
101         sc->sc_decrypt = malloc(sizeof(struct grouptask)*mp_ncpus, M_WG, M_WAITOK);
102
103         for (int i = 0; i < mp_ncpus; i++) {
104                 GROUPTASK_INIT(&sc->sc_encrypt[i], 0,
105                      (gtask_fn_t *)wg_softc_encrypt, sc);
106                 taskqgroup_attach_cpu(qgroup_if_io_tqg, &sc->sc_encrypt[i], sc,  i, dev, NULL, "wg encrypt");
107                 GROUPTASK_INIT(&sc->sc_decrypt[i], 0,
108                     (gtask_fn_t *)wg_softc_decrypt, sc);
109                 taskqgroup_attach_cpu(qgroup_if_io_tqg, &sc->sc_decrypt[i], sc, i, dev, NULL, "wg decrypt");
110         }
111 }
112
113 static void
114 crypto_taskq_destroy(struct wg_softc *sc)
115 {
116         for (int i = 0; i < mp_ncpus; i++) {
117                 taskqgroup_detach(qgroup_if_io_tqg, &sc->sc_encrypt[i]);
118                 taskqgroup_detach(qgroup_if_io_tqg, &sc->sc_decrypt[i]);
119         }
120         free(sc->sc_encrypt, M_WG);
121         free(sc->sc_decrypt, M_WG);
122 }
123
124 static int
125 wg_cloneattach(if_ctx_t ctx, struct if_clone *ifc, const char *name, caddr_t params)
126 {
127         struct wg_softc *sc = iflib_get_softc(ctx);
128         if_softc_ctx_t scctx;
129         device_t dev;
130         struct iovec iov;
131         nvlist_t *nvl;
132         void *packed;
133         struct noise_local *local;
134         uint8_t                  public[WG_KEY_SIZE];
135         struct noise_upcall      noise_upcall;
136         int err;
137         uint16_t listen_port;
138         const void *key;
139         size_t size;
140
141         err = 0;
142         dev = iflib_get_dev(ctx);
143         if (params == NULL) {
144                 key = NULL;
145                 listen_port = 0;
146                 nvl = NULL;
147                 packed = NULL;
148                 goto unpacked;
149         }
150         if (copyin(params, &iov, sizeof(iov)))
151                 return (EFAULT);
152         /* check that this is reasonable */
153         size = iov.iov_len;
154         packed = malloc(size, M_TEMP, M_WAITOK);
155         if (copyin(iov.iov_base, packed, size)) {
156                 err = EFAULT;
157                 goto out;
158         }
159         nvl = nvlist_unpack(packed, size, 0);
160         if (nvl == NULL) {
161                 device_printf(dev, "%s nvlist_unpack failed\n", __func__);
162                 err = EBADMSG;
163                 goto out;
164         }
165         if (!nvlist_exists_number(nvl, "listen-port")) {
166                 device_printf(dev, "%s listen-port not set\n", __func__);
167                 err = EBADMSG;
168                 goto nvl_out;
169         }
170         listen_port = nvlist_get_number(nvl, "listen-port");
171
172         if (!nvlist_exists_binary(nvl, "private-key")) {
173                 device_printf(dev, "%s private-key not set\n", __func__);
174                 err = EBADMSG;
175                 goto nvl_out;
176         }
177         key = nvlist_get_binary(nvl, "private-key", &size);
178         if (size != CURVE25519_KEY_SIZE) {
179                 device_printf(dev, "%s bad length for private-key %zu\n", __func__, size);
180                 err = EBADMSG;
181                 goto nvl_out;
182         }
183 unpacked:
184         local = &sc->sc_local;
185         noise_upcall.u_arg = sc;
186         noise_upcall.u_remote_get =
187                 (struct noise_remote *(*)(void *, uint8_t *))wg_remote_get;
188         noise_upcall.u_index_set =
189                 (uint32_t (*)(void *, struct noise_remote *))wg_index_set;
190         noise_upcall.u_index_drop =
191                 (void (*)(void *, uint32_t))wg_index_drop;
192         noise_local_init(local, &noise_upcall);
193         cookie_checker_init(&sc->sc_cookie, ratelimit_zone);
194
195         sc->sc_socket.so_port = listen_port;
196
197         if (key != NULL) {
198                 noise_local_set_private(local, __DECONST(uint8_t *, key));
199                 noise_local_keys(local, public, NULL);
200                 cookie_checker_update(&sc->sc_cookie, public);
201         }
202         atomic_add_int(&clone_count, 1);
203         scctx = sc->shared = iflib_get_softc_ctx(ctx);
204         scctx->isc_capenable = WG_CAPS;
205         scctx->isc_tx_csum_flags = CSUM_TCP | CSUM_UDP | CSUM_TSO | CSUM_IP6_TCP \
206                 | CSUM_IP6_UDP | CSUM_IP6_TCP;
207         sc->wg_ctx = ctx;
208         sc->sc_ifp = iflib_get_ifp(ctx);
209
210         mbufq_init(&sc->sc_handshake_queue, MAX_QUEUED_INCOMING_HANDSHAKES);
211         mtx_init(&sc->sc_mtx, NULL, "wg softc lock",  MTX_DEF);
212         rw_init(&sc->sc_index_lock, "wg index lock");
213         sc->sc_encap_ring = buf_ring_alloc(MAX_QUEUED_PACKETS, M_WG, M_WAITOK, &sc->sc_mtx);
214         sc->sc_decap_ring = buf_ring_alloc(MAX_QUEUED_PACKETS, M_WG, M_WAITOK, &sc->sc_mtx);
215         GROUPTASK_INIT(&sc->sc_handshake, 0,
216             (gtask_fn_t *)wg_softc_handshake_receive, sc);
217         taskqgroup_attach(qgroup_if_io_tqg, &sc->sc_handshake, sc, dev, NULL, "wg tx initiation");
218         crypto_taskq_setup(sc);
219  nvl_out:
220         if (nvl != NULL)
221                 nvlist_destroy(nvl);
222 out:
223         free(packed, M_TEMP);
224         return (err);
225 }
226
227 static int
228 wg_transmit(struct ifnet *ifp, struct mbuf *m)
229 {
230         struct wg_softc *sc;
231         sa_family_t family;
232         struct epoch_tracker et;
233         struct wg_peer *peer;
234         struct wg_tag *t;
235         uint32_t af;
236         int rc;
237
238
239         /*
240          * Work around lifetime issue in the ipv6 mld code.
241          */
242         if (__predict_false(ifp->if_flags & IFF_DYING))
243                 return (ENXIO);
244
245         rc = 0;
246         sc = iflib_get_softc(ifp->if_softc);
247         if ((t = wg_tag_get(m)) == NULL) {
248                 rc = ENOBUFS;
249                 goto early_out;
250         }
251         af = m->m_pkthdr.ph_family;
252         BPF_MTAP2(ifp, &af, sizeof(af), m);
253
254         NET_EPOCH_ENTER(et);
255         peer = wg_route_lookup(&sc->sc_routes, m, OUT);
256         if (__predict_false(peer == NULL)) {
257                 rc = ENOKEY;
258                 /* XXX log */
259                 goto err;
260         }
261
262         family = atomic_load_acq(peer->p_endpoint.e_remote.r_sa.sa_family);
263         if (__predict_false(family != AF_INET && family != AF_INET6)) {
264                 rc = EHOSTUNREACH;
265                 /* XXX log */
266                 goto err;
267         }
268         t->t_peer = peer;
269         t->t_mbuf = NULL;
270         t->t_done = 0;
271         t->t_mtu = ifp->if_mtu;
272
273         rc = wg_queue_out(peer, m);
274         if (rc == 0)
275                 wg_encrypt_dispatch(peer->p_sc);
276         NET_EPOCH_EXIT(et);
277         return (rc); 
278 err:
279         NET_EPOCH_EXIT(et);
280 early_out:
281         if_inc_counter(sc->sc_ifp, IFCOUNTER_OERRORS, 1);
282         /* XXX send ICMP unreachable */
283         m_free(m);
284         return (rc);
285 }
286
287 static int
288 wg_output(struct ifnet *ifp, struct mbuf *m, const struct sockaddr *sa, struct route *rt)
289 {
290         m->m_pkthdr.ph_family =  sa->sa_family;
291         return (wg_transmit(ifp, m));
292 }
293
294 static int
295 wg_attach_post(if_ctx_t ctx)
296 {
297         struct ifnet *ifp;
298         struct wg_softc *sc;
299
300         sc = iflib_get_softc(ctx);
301         ifp = iflib_get_ifp(ctx);
302         if_setmtu(ifp, ETHERMTU - 80);
303
304         if_setflagbits(ifp, IFF_NOARP, IFF_POINTOPOINT);
305         ifp->if_transmit = wg_transmit;
306         ifp->if_output = wg_output;
307
308         wg_hashtable_init(&sc->sc_hashtable);
309         sc->sc_index = hashinit(HASHTABLE_INDEX_SIZE, M_DEVBUF, &sc->sc_index_mask);
310         wg_route_init(&sc->sc_routes);
311
312         return (0);
313 }
314
315 static int
316 wg_mtu_set(if_ctx_t ctx, uint32_t mtu)
317 {
318
319         return (0);
320 }
321
322 static int
323 wg_set_promisc(if_ctx_t ctx, int flags)
324 {
325
326         return (0);
327 }
328
329 static int
330 wg_detach(if_ctx_t ctx)
331 {
332         struct wg_softc *sc;
333
334         sc = iflib_get_softc(ctx);
335         if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
336         NET_EPOCH_WAIT();
337         wg_socket_reinit(sc, NULL, NULL);
338         taskqgroup_drain_all(qgroup_if_io_tqg);
339         pause("link_down", hz/4);
340         wg_peer_remove_all(sc);
341         pause("link_down", hz);
342         mtx_destroy(&sc->sc_mtx);
343         rw_destroy(&sc->sc_index_lock);
344         taskqgroup_detach(qgroup_if_io_tqg, &sc->sc_handshake);
345         crypto_taskq_destroy(sc);
346         buf_ring_free(sc->sc_encap_ring, M_WG);
347         buf_ring_free(sc->sc_decap_ring, M_WG);
348
349         wg_route_destroy(&sc->sc_routes);
350         wg_hashtable_destroy(&sc->sc_hashtable);
351         atomic_add_int(&clone_count, -1);
352         return (0);
353 }
354
355 static void
356 wg_init(if_ctx_t ctx)
357 {
358         struct ifnet *ifp;
359         struct wg_softc *sc;
360         int rc;
361
362         if (iflib_in_detach(ctx))
363                 return;
364
365         sc = iflib_get_softc(ctx);
366         ifp = iflib_get_ifp(ctx);
367         if (sc->sc_socket.so_so4 != NULL)
368                 printf("XXX wg_init, socket non-NULL %p\n",
369                     sc->sc_socket.so_so4);
370         wg_socket_reinit(sc, NULL, NULL);
371         rc = wg_socket_init(sc);
372         if (rc)
373                 return;
374         if_link_state_change(ifp, LINK_STATE_UP);
375 }
376
377 static void
378 wg_stop(if_ctx_t ctx)
379 {
380         struct wg_softc *sc;
381         struct ifnet *ifp;
382
383         sc  = iflib_get_softc(ctx);
384         ifp = iflib_get_ifp(ctx);
385         if_link_state_change(ifp, LINK_STATE_DOWN);
386         wg_socket_reinit(sc, NULL, NULL);
387 }
388
389 static nvlist_t *
390 wg_peer_to_nvl(struct wg_peer *peer)
391 {
392         struct wg_route *rt;
393         int i, count;
394         nvlist_t *nvl;
395         caddr_t key;
396         size_t sa_sz;
397         struct wg_allowedip *aip;
398         struct wg_endpoint *ep;
399
400         if ((nvl = nvlist_create(0)) == NULL)
401                 return (NULL);
402         key = peer->p_remote.r_public;
403         nvlist_add_binary(nvl, "public-key", key, WG_KEY_SIZE);
404         ep = &peer->p_endpoint;
405         if (ep->e_remote.r_sa.sa_family != 0) {
406                 sa_sz = (ep->e_remote.r_sa.sa_family == AF_INET) ?
407                         sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
408                 nvlist_add_binary(nvl, "endpoint", &ep->e_remote, sa_sz);
409         }
410         i = count = 0;
411         CK_LIST_FOREACH(rt, &peer->p_routes, r_entry) {
412                 count++;
413         }
414         aip = malloc(count*sizeof(*aip), M_TEMP, M_WAITOK);
415         CK_LIST_FOREACH(rt, &peer->p_routes, r_entry) {
416                 memcpy(&aip[i++], &rt->r_cidr, sizeof(*aip));
417         }
418         nvlist_add_binary(nvl, "allowed-ips", aip, count*sizeof(*aip));
419         free(aip, M_TEMP);
420         return (nvl);
421 }
422
423 static int
424 wg_marshal_peers(struct wg_softc *sc, nvlist_t **nvlp, nvlist_t ***nvl_arrayp, int *peer_countp)
425 {
426         struct wg_peer *peer;
427         int err, i, peer_count;
428         nvlist_t *nvl, **nvl_array;
429         struct epoch_tracker et;
430 #ifdef INVARIANTS
431         void *packed;
432         size_t size;
433 #endif
434         nvl = NULL;
435         nvl_array = NULL;
436         if (nvl_arrayp)
437                 *nvl_arrayp = NULL;
438         if (nvlp)
439                 *nvlp = NULL;
440         if (peer_countp)
441                 *peer_countp = 0;
442         peer_count = sc->sc_hashtable.h_num_peers;
443         if (peer_count == 0) {
444                 printf("no peers found\n");
445                 return (ENOENT);
446         }
447
448         if (nvlp && (nvl = nvlist_create(0)) == NULL)
449                 return (ENOMEM);
450         err = i = 0;
451         nvl_array = malloc(peer_count*sizeof(void*), M_TEMP, M_WAITOK);
452         NET_EPOCH_ENTER(et);
453         CK_LIST_FOREACH(peer, &sc->sc_hashtable.h_peers_list, p_entry) {
454                 nvl_array[i] = wg_peer_to_nvl(peer);
455                 if (nvl_array[i] == NULL) {
456                         printf("wg_peer_to_nvl failed on %d peer\n", i);
457                         break;
458                 }
459 #ifdef INVARIANTS
460                 packed = nvlist_pack(nvl_array[i], &size);
461                 if (packed == NULL) {
462                         printf("nvlist_pack(%p, %p) => %d",
463                                    nvl_array[i], &size, nvlist_error(nvl));
464                 }
465                 free(packed, M_NVLIST);
466 #endif  
467                 i++;
468                 if (i == peer_count)
469                         break;
470         }
471         NET_EPOCH_EXIT(et);
472         *peer_countp = peer_count = i;
473         if (peer_count == 0) {
474                 printf("no peers found in list\n");
475                 err = ENOENT;
476                 goto out;
477         }
478         if (nvl) {
479                 nvlist_add_nvlist_array(nvl, "peer-list",
480                     (const nvlist_t * const *)nvl_array, peer_count);
481                 if ((err = nvlist_error(nvl))) {
482                         printf("nvlist_add_nvlist_array(%p, \"peer-list\", %p, %d) => %d\n",
483                             nvl, nvl_array, peer_count, err);
484                         goto out;
485                 }
486                 *nvlp = nvl;
487         }
488         *nvl_arrayp = nvl_array;
489         return (0);
490  out:
491         return (err);
492 }
493
494 static int
495 wgc_get(struct wg_softc *sc, struct ifdrv *ifd)
496 {
497         nvlist_t *nvl, **nvl_array;
498         void *packed;
499         size_t size;
500         int peer_count, err;
501
502         nvl = nvlist_create(0);
503         if (nvl == NULL)
504                 return (ENOMEM);
505
506         err = 0;
507         packed = NULL;
508         if (sc->sc_socket.so_port != 0)
509                 nvlist_add_number(nvl, "listen-port", sc->sc_socket.so_port);
510         if (sc->sc_local.l_has_identity) {
511                 nvlist_add_binary(nvl, "public-key", sc->sc_local.l_public, WG_KEY_SIZE);
512                 if (curthread->td_ucred->cr_uid == 0)
513                         nvlist_add_binary(nvl, "private-key", sc->sc_local.l_private, WG_KEY_SIZE);
514         }
515         if (sc->sc_hashtable.h_num_peers > 0) {
516                 err = wg_marshal_peers(sc, NULL, &nvl_array, &peer_count);
517                 if (err)
518                         goto out;
519                 nvlist_add_nvlist_array(nvl, "peer-list",
520                     (const nvlist_t * const *)nvl_array, peer_count);
521         }
522         packed = nvlist_pack(nvl, &size);
523         if (packed == NULL)
524                 return (ENOMEM);
525         if (ifd->ifd_len == 0) {
526                 ifd->ifd_len = size;
527                 goto out;
528         }
529         if (ifd->ifd_len < size) {
530                 err = ENOSPC;
531                 goto out;
532         }
533         if (ifd->ifd_data == NULL) {
534                 err = EFAULT;
535                 goto out;
536         }
537         err = copyout(packed, ifd->ifd_data, size);
538         ifd->ifd_len = size;
539  out:
540         nvlist_destroy(nvl);
541         free(packed, M_NVLIST);
542         return (err);
543 }
544
545 static bool
546 wg_allowedip_valid(const struct wg_allowedip *wip)
547 {
548
549         return (true);
550 }
551
552 static int
553 wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl)
554 {
555         uint8_t                  public[WG_KEY_SIZE];
556         const void *pub_key;
557         const struct sockaddr *endpoint;
558         int i, err, allowedip_count;
559         device_t dev;
560         size_t size;
561         struct wg_peer *peer = NULL;
562         bool need_insert = false;
563         dev = iflib_get_dev(sc->wg_ctx);
564
565         if (!nvlist_exists_binary(nvl, "public-key")) {
566                 device_printf(dev, "peer has no public-key\n");
567                 return (EINVAL);
568         }
569         pub_key = nvlist_get_binary(nvl, "public-key", &size);
570         if (size != CURVE25519_KEY_SIZE) {
571                 device_printf(dev, "%s bad length for public-key %zu\n", __func__, size);
572                 return (EINVAL);
573         }
574         if (noise_local_keys(&sc->sc_local, public, NULL) == 0 &&
575             bcmp(public, pub_key, WG_KEY_SIZE) == 0) {
576                 device_printf(dev, "public-key for peer already in use by host\n");
577                 return (EINVAL);
578         }
579         peer = wg_peer_lookup(sc, pub_key);
580         if (nvlist_exists_bool(nvl, "peer-remove") &&
581                 nvlist_get_bool(nvl, "peer-remove")) {
582                 if (peer != NULL) {
583                         wg_hashtable_peer_remove(&sc->sc_hashtable, peer);
584                         wg_peer_destroy(peer);
585                         /* XXX free */
586                         printf("peer removed\n");
587                 }
588                 return (0);
589         }
590         if (nvlist_exists_bool(nvl, "replace-allowedips") &&
591                 nvlist_get_bool(nvl, "replace-allowedips") &&
592             peer != NULL) {
593
594                 wg_route_delete(&peer->p_sc->sc_routes, peer);
595         }
596         if (peer == NULL) {
597                 need_insert = true;
598                 peer = wg_peer_alloc(sc);
599                 noise_remote_init(&peer->p_remote, pub_key, &sc->sc_local);
600                 cookie_maker_init(&peer->p_cookie, pub_key);
601         }
602         if (nvlist_exists_binary(nvl, "endpoint")) {
603                 endpoint = nvlist_get_binary(nvl, "endpoint", &size);
604                 if (size > sizeof(peer->p_endpoint.e_remote)) {
605                         device_printf(dev, "%s bad length for endpoint %zu\n", __func__, size);
606                         err = EBADMSG;
607                         goto out;
608                 }
609                 memcpy(&peer->p_endpoint.e_remote, endpoint, size);
610         }
611         if (nvlist_exists_binary(nvl, "pre-shared-key")) {
612                 const void *key;
613
614                 key = nvlist_get_binary(nvl, "pre-shared-key", &size);
615                 noise_remote_set_psk(&peer->p_remote, key);
616         }
617         if (nvlist_exists_number(nvl, "persistent-keepalive-interval")) {
618                 uint16_t pki;
619
620                 pki = nvlist_get_number(nvl, "persistent-keepalive-interval");
621                 wg_timers_set_persistent_keepalive(&peer->p_timers, pki);
622         }
623         if (nvlist_exists_binary(nvl, "allowed-ips")) {
624                 const struct wg_allowedip *aip, *aip_base;
625
626                 aip = aip_base = nvlist_get_binary(nvl, "allowed-ips", &size);
627                 if (size % sizeof(struct wg_allowedip) != 0) {
628                         device_printf(dev, "%s bad length for allowed-ips %zu not integer multiple of struct size\n", __func__, size);
629                         err = EBADMSG;
630                         goto out;
631                 }
632                 allowedip_count = size/sizeof(struct wg_allowedip);
633                 for (i = 0; i < allowedip_count; i++) {
634                         if (!wg_allowedip_valid(&aip_base[i])) {
635                                 device_printf(dev, "%s allowedip %d not valid\n", __func__, i);
636                                 err = EBADMSG;
637                                 goto out;
638                         }
639                 }
640                 for (int i = 0; i < allowedip_count; i++, aip++) {
641                         if ((err = wg_route_add(&sc->sc_routes, peer, aip)) != 0) {
642                                 printf("route add %d failed -> %d\n", i, err);
643                         }
644                 }
645         }
646         if (need_insert)
647                 wg_hashtable_peer_insert(&sc->sc_hashtable, peer);
648         return (0);
649
650 out:
651         wg_peer_destroy(peer);
652         return (err);
653 }
654
655 static int
656 wgc_set(struct wg_softc *sc, struct ifdrv *ifd)
657 {
658         uint8_t                  public[WG_KEY_SIZE];
659         void *nvlpacked;
660         nvlist_t *nvl;
661         device_t dev;
662         ssize_t size;
663         int err;
664
665         if (ifd->ifd_len == 0 || ifd->ifd_data == NULL)
666                 return (EFAULT);
667
668         dev = iflib_get_dev(sc->wg_ctx);
669         nvlpacked = malloc(ifd->ifd_len, M_TEMP, M_WAITOK);
670         err = copyin(ifd->ifd_data, nvlpacked, ifd->ifd_len);
671         if (err)
672                 goto out;
673         nvl = nvlist_unpack(nvlpacked, ifd->ifd_len, 0);
674         if (nvl == NULL) {
675                 device_printf(dev, "%s nvlist_unpack failed\n", __func__);
676                 err = EBADMSG;
677                 goto out;
678         }
679         if (nvlist_exists_bool(nvl, "replace-peers") &&
680                 nvlist_get_bool(nvl, "replace-peers"))
681                 wg_peer_remove_all(sc);
682         if (nvlist_exists_number(nvl, "listen-port")) {
683                 int listen_port __unused = nvlist_get_number(nvl, "listen-port");
684                         /*
685                          * Set listen port
686                          */
687                 if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
688                 pause("link_down", hz/4);
689                 wg_socket_reinit(sc, NULL, NULL);
690                 sc->sc_socket.so_port = listen_port;
691                 if ((err = wg_socket_init(sc)) != 0)
692                         goto out;
693            if_link_state_change(sc->sc_ifp, LINK_STATE_UP);
694         }
695         if (nvlist_exists_binary(nvl, "private-key")) {
696                 struct noise_local *local;
697                 const void *key = nvlist_get_binary(nvl, "private-key", &size);
698
699                 if (size != CURVE25519_KEY_SIZE) {
700                         device_printf(dev, "%s bad length for private-key %zu\n", __func__, size);
701                         err = EBADMSG;
702                         goto nvl_out;
703                 }
704                 /*
705                  * set private key
706                  */
707                 local = &sc->sc_local;
708                 noise_local_set_private(local, __DECONST(uint8_t *, key));
709                 noise_local_keys(local, public, NULL);
710                 cookie_checker_update(&sc->sc_cookie, public);
711         }
712         if (nvlist_exists_number(nvl, "user-cookie")) {
713                 sc->sc_user_cookie = nvlist_get_number(nvl, "user-cookie");
714                 /*
715                  * setsockopt
716                  */
717         }
718         if (nvlist_exists_nvlist_array(nvl, "peer-list")) {
719                 size_t peercount;
720                 const nvlist_t * const*nvl_peers;
721
722                 nvl_peers = nvlist_get_nvlist_array(nvl, "peer-list", &peercount);
723                 for (int i = 0; i < peercount; i++) {
724                         wg_peer_add(sc, nvl_peers[i]);
725                 }
726         }
727 nvl_out:
728         nvlist_destroy(nvl);
729 out:
730         free(nvlpacked, M_TEMP);
731         return (err);
732 }
733
734 static int
735 wg_priv_ioctl(if_ctx_t ctx, u_long command, caddr_t data)
736 {
737         struct wg_softc *sc = iflib_get_softc(ctx);
738         struct ifdrv *ifd = (struct ifdrv *)data;
739         int ifd_cmd;
740
741         switch (command) {
742                 case SIOCGDRVSPEC:
743                 case SIOCSDRVSPEC:
744                         ifd_cmd = ifd->ifd_cmd;
745                         break;
746                 default:
747                         return (EINVAL);
748         }
749         switch (ifd_cmd) {
750                 case WGC_GET:
751                         return (wgc_get(sc, ifd));
752                         break;
753                 case WGC_SET:
754                         if (priv_check(curthread, PRIV_NET_HWIOCTL))
755                                 return (EPERM);
756                         return (wgc_set(sc, ifd));
757                         break;
758         }
759         return (ENOTSUP);
760 }
761
762 static device_method_t wg_if_methods[] = {
763         DEVMETHOD(ifdi_cloneattach, wg_cloneattach),
764         DEVMETHOD(ifdi_attach_post, wg_attach_post),
765         DEVMETHOD(ifdi_detach, wg_detach),
766         DEVMETHOD(ifdi_init, wg_init),
767         DEVMETHOD(ifdi_stop, wg_stop),
768         DEVMETHOD(ifdi_priv_ioctl, wg_priv_ioctl),
769         DEVMETHOD(ifdi_mtu_set, wg_mtu_set),
770         DEVMETHOD(ifdi_promisc_set, wg_set_promisc),
771         DEVMETHOD_END
772 };
773
774 static driver_t wg_iflib_driver = {
775         "wg", wg_if_methods, sizeof(struct wg_softc)
776 };
777
778 char wg_driver_version[] = "0.0.1";
779
780 static struct if_shared_ctx wg_sctx_init = {
781         .isc_magic = IFLIB_MAGIC,
782         .isc_driver_version = wg_driver_version,
783         .isc_driver = &wg_iflib_driver,
784         .isc_flags = IFLIB_PSEUDO,
785         .isc_name = "wg",
786 };
787
788 if_shared_ctx_t wg_sctx = &wg_sctx_init;
789 static if_pseudo_t wg_pseudo;
790
791
792 int
793 wg_ctx_init(void)
794 {
795         ratelimit_zone = uma_zcreate("wg ratelimit", sizeof(struct ratelimit),
796              NULL, NULL, NULL, NULL, 0, 0);
797         return (0);
798 }
799
800 void
801 wg_ctx_uninit(void)
802 {
803         uma_zdestroy(ratelimit_zone);
804 }
805
806 static int
807 wg_module_init(void)
808 {
809         int rc;
810
811         if ((rc = wg_ctx_init()))
812                 return (rc);
813
814         wg_pseudo = iflib_clone_register(wg_sctx);
815         if (wg_pseudo == NULL)
816                 return (ENXIO);
817
818         return (0);
819 }
820
821 static void
822 wg_module_deinit(void)
823 {
824         wg_ctx_uninit();
825         iflib_clone_deregister(wg_pseudo);
826 }
827
828 static int
829 wg_module_event_handler(module_t mod, int what, void *arg)
830 {
831         int err;
832
833         switch (what) {
834                 case MOD_LOAD:
835                         if ((err = wg_module_init()) != 0)
836                                 return (err);
837                         break;
838                 case MOD_UNLOAD:
839                         if (clone_count == 0)
840                                 wg_module_deinit();
841                         else
842                                 return (EBUSY);
843                         break;
844                 default:
845                         return (EOPNOTSUPP);
846         }
847         return (0);
848 }
849
850 static moduledata_t wg_moduledata = {
851         "wg",
852         wg_module_event_handler,
853         NULL
854 };
855
856 DECLARE_MODULE(wg, wg_moduledata, SI_SUB_PSEUDO, SI_ORDER_ANY);
857 MODULE_VERSION(wg, 1);
858 MODULE_DEPEND(wg, iflib, 1, 1, 1);
859 #if defined(__amd64__) || defined(__i386__)
860 /* Optimized blake2 implementations are only available on x86. */
861 MODULE_DEPEND(wg, blake2, 1, 1, 1);
862 #endif
863 MODULE_DEPEND(wg, crypto, 1, 1, 1);