]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - tests/sys/net/routing/rtsock_common.h
Fix dst/netmask handling in routing socket code.
[FreeBSD/FreeBSD.git] / tests / sys / net / routing / rtsock_common.h
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3  *
4  * Copyright (c) 2019 Alexander V. Chernikov
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright
12  *    notice, this list of conditions and the following disclaimer in the
13  *    documentation and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18  * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
19  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
21  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
22  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
23  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
24  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
25  * SUCH DAMAGE.
26  *
27  * $FreeBSD$
28  */
29
30 #ifndef _NET_ROUTING_RTSOCK_COMMON_H_
31 #define _NET_ROUTING_RTSOCK_COMMON_H_
32
33 #include <stdio.h>
34 #include <stdlib.h>
35 #include <string.h>
36 #include <unistd.h>
37 #include <fcntl.h>
38 #include <stdbool.h>
39 #include <ctype.h>
40 #include <poll.h>
41
42 #include <sys/types.h>
43 #include <sys/time.h>
44 #include <sys/param.h>
45 #include <sys/socket.h>
46 #include <sys/ioctl.h>
47 #include <sys/jail.h>
48 #include <sys/linker.h>
49 #include <net/if.h>
50 #include <net/if_dl.h>
51 #include <net/route.h>
52
53 #include <arpa/inet.h>
54 #include <net/ethernet.h>
55
56 #include <netinet/in.h>
57 #include <netinet6/in6_var.h>
58 #include <netinet6/nd6.h>
59
60 #include <ifaddrs.h>
61
62 #include <errno.h>
63 #include <err.h>
64 #include <sysexits.h>
65
66 #include <atf-c.h>
67 #include "freebsd_test_suite/macros.h"
68
69 #include "rtsock_print.h"
70 #include "params.h"
71
72 void rtsock_update_rtm_len(struct rt_msghdr *rtm);
73 void rtsock_validate_message(char *buffer, ssize_t len);
74 void rtsock_add_rtm_sa(struct rt_msghdr *rtm, int addr_type, struct sockaddr *sa);
75
76 void file_append_line(char *fname, char *text);
77
78 static int _rtm_seq = 42;
79
80
81 /*
82  * Checks if the interface cloner module is present for @name.
83  */
84 static int
85 _check_cloner(char *name)
86 {
87         struct if_clonereq ifcr;
88         char *cp, *buf;
89         int idx;
90         int s;
91         int found = 0;
92
93         s = socket(AF_LOCAL, SOCK_DGRAM, 0);
94         if (s == -1)
95                 err(1, "socket(AF_LOCAL,SOCK_DGRAM)");
96
97         memset(&ifcr, 0, sizeof(ifcr));
98
99         if (ioctl(s, SIOCIFGCLONERS, &ifcr) < 0)
100                 err(1, "SIOCIFGCLONERS for count");
101
102         buf = malloc(ifcr.ifcr_total * IFNAMSIZ);
103         if (buf == NULL)
104                 err(1, "unable to allocate cloner name buffer");
105
106         ifcr.ifcr_count = ifcr.ifcr_total;
107         ifcr.ifcr_buffer = buf;
108
109         if (ioctl(s, SIOCIFGCLONERS, &ifcr) < 0)
110                 err(1, "SIOCIFGCLONERS for names");
111
112         /*
113          * In case some disappeared in the mean time, clamp it down.
114          */
115         if (ifcr.ifcr_count > ifcr.ifcr_total)
116                 ifcr.ifcr_count = ifcr.ifcr_total;
117
118         for (cp = buf, idx = 0; idx < ifcr.ifcr_count; idx++, cp += IFNAMSIZ) {
119                 if (!strcmp(cp, name)) {
120                         found = 1;
121                         break;
122                 }
123         }
124
125         free(buf);
126         close(s);
127
128         return (found);
129 }
130
131 static char *
132 iface_create(char *ifname_orig)
133 {
134         struct ifreq ifr;
135         int s;
136         char prefix[IFNAMSIZ], ifname[IFNAMSIZ], *result;
137
138         char *src, *dst;
139         for (src = ifname_orig, dst = prefix; *src && isalpha(*src); src++)
140                 *dst++ = *src;
141         *dst = '\0';
142
143         memset(&ifr, 0, sizeof(struct ifreq));
144
145         s = socket(AF_LOCAL, SOCK_DGRAM, 0);
146         strlcpy(ifr.ifr_name, ifname_orig, sizeof(ifr.ifr_name));
147
148         RLOG("creating iface %s %s", prefix, ifr.ifr_name);
149         if (ioctl(s, SIOCIFCREATE2, &ifr) < 0)
150                 err(1, "SIOCIFCREATE2");
151
152         strlcpy(ifname, ifr.ifr_name, IFNAMSIZ);
153         RLOG("created interface %s", ifname);
154
155         result = strdup(ifname);
156
157         file_append_line(IFACES_FNAME, ifname);
158         if (strstr(ifname, "epair") == ifname) {
159                 /* call returned epairXXXa, need to add epairXXXb */
160                 ifname[strlen(ifname) - 1] = 'b';
161                 file_append_line(IFACES_FNAME, ifname);
162         }
163
164         return (result);
165 }
166
167 static int
168 iface_destroy(char *ifname)
169 {
170         struct ifreq ifr;
171         int s;
172
173         s = socket(AF_LOCAL, SOCK_DGRAM, 0);
174         strlcpy(ifr.ifr_name, ifname, sizeof(ifr.ifr_name));
175
176         RLOG("destroying interface %s", ifname);
177         if (ioctl(s, SIOCIFDESTROY, &ifr) < 0)
178                 return (0);
179
180         return (1);
181 }
182
183 /*
184  * Open tunneling device such as tuntap and returns fd.
185  */
186 int
187 iface_open(char *ifname)
188 {
189         char path[256];
190
191         snprintf(path, sizeof(path), "/dev/%s", ifname);
192
193         RLOG("opening interface %s", ifname);
194         int fd = open(path, O_RDWR|O_EXCL);
195         if (fd == -1) {
196                 RLOG_ERRNO("unable to open interface %s", ifname);
197                 return (-1);
198         }
199
200         return (fd);
201 }
202
203 /*
204  * Sets primary IPv4 addr.
205  * Returns 0 on success.
206  */
207 static inline int
208 iface_setup_addr(char *ifname, char *addr, int plen)
209 {
210         char cmd[512];
211         char *af;
212
213         if (strchr(addr, ':'))
214                 af = "inet6";
215         else
216                 af = "inet";
217         RLOG("setting af_%s %s/%d on %s", af, addr, plen, ifname);
218         snprintf(cmd, sizeof(cmd), "/sbin/ifconfig %s %s %s/%d", ifname,
219                 af, addr, plen);
220
221         return system(cmd);
222 }
223
224 /*
225  * Removes primary IPv4 prefix.
226  * Returns 0 on success.
227  */
228 static inline int
229 iface_delete_addr(char *ifname, char *addr)
230 {
231         char cmd[512];
232
233         if (strchr(addr, ':')) {
234                 RLOG("removing IPv6 %s from %s", addr, ifname);
235                 snprintf(cmd, sizeof(cmd), "/sbin/ifconfig %s inet6 %s delete", ifname, addr);
236         } else {
237                 RLOG("removing IPv4 %s from %s", addr, ifname);
238                 snprintf(cmd, sizeof(cmd), "/sbin/ifconfig %s -alias %s", ifname, addr);
239         }
240
241         return system(cmd);
242 }
243
244 int
245 iface_turn_up(char *ifname)
246 {
247         struct ifreq ifr;
248         int s;
249
250         if ((s = socket(AF_INET6, SOCK_DGRAM, 0)) < 0) {
251                 RLOG_ERRNO("socket");
252                 return (-1);
253         }
254         memset(&ifr, 0, sizeof(struct ifreq));
255         strlcpy(ifr.ifr_name, ifname, sizeof(ifr.ifr_name));
256         if (ioctl(s, SIOCGIFFLAGS, (caddr_t)&ifr) < 0) {
257                 RLOG_ERRNO("ioctl(SIOCGIFFLAGS)");
258                 return (-1);
259         }
260         /* Update flags */
261         if ((ifr.ifr_flags & IFF_UP) == 0) {
262                 ifr.ifr_flags |= IFF_UP;
263                 if (ioctl(s, SIOCSIFFLAGS, (caddr_t)&ifr) < 0) {
264                         RLOG_ERRNO("ioctl(SIOSGIFFLAGS)");
265                         return (-1);
266                 }
267                 RLOG("turned interface %s up", ifname);
268         }
269
270         return (0);
271 }
272
273 /*
274  * Removes ND6_IFF_IFDISABLED from IPv6 interface flags.
275  * Returns 0 on success.
276  */
277 int
278 iface_enable_ipv6(char *ifname)
279 {
280         struct in6_ndireq nd;
281         int s;
282
283         if ((s = socket(AF_INET6, SOCK_DGRAM, 0)) < 0) {
284                 err(1, "socket");
285         }
286         memset(&nd, 0, sizeof(nd));
287         strlcpy(nd.ifname, ifname, sizeof(nd.ifname));
288         if (ioctl(s, SIOCGIFINFO_IN6, (caddr_t)&nd) < 0) {
289                 RLOG_ERRNO("ioctl(SIOCGIFINFO_IN6)");
290                 return (-1);
291         }
292         /* Update flags */
293         if ((nd.ndi.flags & ND6_IFF_IFDISABLED) != 0) {
294                 nd.ndi.flags &= ~ND6_IFF_IFDISABLED;
295                 if (ioctl(s, SIOCSIFINFO_IN6, (caddr_t)&nd) < 0) {
296                         RLOG_ERRNO("ioctl(SIOCSIFINFO_IN6)");
297                         return (-1);
298                 }
299                 RLOG("enabled IPv6 for %s", ifname);
300         }
301
302         return (0);
303 }
304
305 void
306 file_append_line(char *fname, char *text)
307 {
308         FILE *f;
309
310         f = fopen(fname, "a");
311         fputs(text, f);
312         fputs("\n", f);
313         fclose(f);
314 }
315
316 static int
317 vnet_wait_interface(char *vnet_name, char *ifname)
318 {
319         char buf[512], cmd[512], *line, *token;
320         FILE *fp;
321         int i;
322
323         snprintf(cmd, sizeof(cmd), "/usr/sbin/jexec %s /sbin/ifconfig -l", vnet_name);
324         for (int i = 0; i < 50; i++) {
325                 fp = popen(cmd, "r");
326                 line = fgets(buf, sizeof(buf), fp);
327                 /* cut last\n */
328                 if (line[0])
329                         line[strlen(line)-1] = '\0';
330                 while ((token = strsep(&line, " ")) != NULL) {
331                         if (strcmp(token, ifname) == 0)
332                                 return (1);
333                 }
334
335                 /* sleep 100ms */
336                 usleep(1000 * 100);
337         }
338
339         return (0);
340 }
341
342 void
343 vnet_switch(char *vnet_name, char **ifnames, int count)
344 {
345         char buf[512], cmd[512], *line;
346         FILE *fp;
347         int jid, len, ret;
348
349         RLOG("switching to vnet %s with interface(s) %s", vnet_name, ifnames[0]);
350         len = snprintf(cmd, sizeof(cmd),
351             "/usr/sbin/jail -i -c name=%s persist vnet", vnet_name);
352         for (int i = 0; i < count && len < sizeof(cmd); i++) {
353                 len += snprintf(&cmd[len], sizeof(cmd) - len,
354                     " vnet.interface=%s", ifnames[i]);
355         }
356         RLOG("jail cmd: \"%s\"\n", cmd);
357
358         fp = popen(cmd, "r");
359         if (fp == NULL)
360                 atf_tc_fail("jail creation failed");
361         line = fgets(buf, sizeof(buf), fp);
362         if (line == NULL)
363                 atf_tc_fail("empty output from jail(8)");
364         jid = strtol(line, NULL, 10);
365         if (jid <= 0) {
366                 atf_tc_fail("invalid jail output: %s", line);
367         }
368
369         RLOG("created jail jid=%d", jid);
370         file_append_line(JAILS_FNAME, vnet_name);
371
372         /* Wait while interface appearsh inside vnet */
373         for (int i = 0; i < count; i++) {
374                 if (vnet_wait_interface(vnet_name, ifnames[i]))
375                         continue;
376                 atf_tc_fail("unable to move interface %s to jail %s",
377                     ifnames[i], vnet_name);
378         }
379
380         if (jail_attach(jid) == -1) {
381                 RLOG_ERRNO("jail %s attach failed: ret=%d", vnet_name, errno);
382                 atf_tc_fail("jail attach failed");
383         }
384
385         RLOG("attached to the jail");
386 }
387
388 void
389 vnet_switch_one(char *vnet_name, char *ifname)
390 {
391         char *ifnames[1];
392
393         ifnames[0] = ifname;
394         vnet_switch(vnet_name, ifnames, 1);
395 }
396
397
398 #define SA_F_IGNORE_IFNAME      0x01
399 #define SA_F_IGNORE_IFTYPE      0x02
400 #define SA_F_IGNORE_MEMCMP      0x04
401 int
402 sa_equal_msg_flags(const struct sockaddr *a, const struct sockaddr *b, char *msg, size_t sz, int flags)
403 {
404         char a_s[64], b_s[64];
405         const struct sockaddr_in *a4, *b4;
406         const struct sockaddr_in6 *a6, *b6;
407         const struct sockaddr_dl *al, *bl;
408
409         if (a == NULL) {
410                 snprintf(msg, sz, "first sa is NULL");
411                 return 0;
412         }
413         if (b == NULL) {
414                 snprintf(msg, sz, "second sa is NULL");
415                 return 0;
416         }
417
418         if (a->sa_family != b->sa_family) {
419                 snprintf(msg, sz, "family: %d vs %d", a->sa_family, b->sa_family);
420                 return 0;
421         }
422         if (a->sa_len != b->sa_len) {
423                 snprintf(msg, sz, "len: %d vs %d", a->sa_len, b->sa_len);
424                 return 0;
425         }
426
427         switch (a->sa_family) {
428         case AF_INET:
429                 a4 = (const struct sockaddr_in *)a;
430                 b4 = (const struct sockaddr_in *)b;
431                 if (a4->sin_addr.s_addr != b4->sin_addr.s_addr) {
432                         inet_ntop(AF_INET, &a4->sin_addr, a_s, sizeof(a_s));
433                         inet_ntop(AF_INET, &b4->sin_addr, b_s, sizeof(b_s));
434                         snprintf(msg, sz, "addr diff: %s vs %s", a_s, b_s);
435                         return 0;
436                 }
437                 if (a4->sin_port != b4->sin_port) {
438                         snprintf(msg, sz, "port diff: %d vs %d",
439                                         ntohs(a4->sin_port), ntohs(b4->sin_port));
440                         //return 0;
441                 }
442                 const uint32_t *a32, *b32;
443                 a32 = (const uint32_t *)a4->sin_zero;
444                 b32 = (const uint32_t *)b4->sin_zero;
445                 if ((*a32 != *b32) || (*(a32 + 1) != *(b32 + 1))) {
446                         snprintf(msg, sz, "zero diff: 0x%08X%08X vs 0x%08X%08X",
447                                         ntohl(*a32), ntohl(*(a32 + 1)),
448                                         ntohl(*b32), ntohl(*(b32 + 1)));
449                         return 0;
450                 }
451                 return 1;
452         case AF_INET6:
453                 a6 = (const struct sockaddr_in6 *)a;
454                 b6 = (const struct sockaddr_in6 *)b;
455                 if (!IN6_ARE_ADDR_EQUAL(&a6->sin6_addr, &b6->sin6_addr)) {
456                         inet_ntop(AF_INET6, &a6->sin6_addr, a_s, sizeof(a_s));
457                         inet_ntop(AF_INET6, &b6->sin6_addr, a_s, sizeof(a_s));
458                         snprintf(msg, sz, "addr diff: %s vs %s", a_s, b_s);
459                         return 0;
460                 }
461                 if (a6->sin6_scope_id != b6->sin6_scope_id) {
462                         snprintf(msg, sz, "scope diff: %u vs %u", a6->sin6_scope_id, b6->sin6_scope_id);
463                         return 0;
464                 }
465                 break;
466         case AF_LINK:
467                 al = (const struct sockaddr_dl *)a;
468                 bl = (const struct sockaddr_dl *)b;
469
470                 if (al->sdl_index != bl->sdl_index) {
471                         snprintf(msg, sz, "sdl_index diff: %u vs %u", al->sdl_index, bl->sdl_index);
472                         return 0;
473                 }
474
475                 if ((al->sdl_alen != bl->sdl_alen) || (memcmp(LLADDR(al), LLADDR(bl), al->sdl_alen) != 0)) {
476                         char abuf[64], bbuf[64];
477                         sa_print_hd(abuf, sizeof(abuf), LLADDR(al), al->sdl_alen);
478                         sa_print_hd(bbuf, sizeof(bbuf), LLADDR(bl), bl->sdl_alen);
479                         snprintf(msg, sz, "sdl_alen diff: {%s} (%d) vs {%s} (%d)",
480                             abuf, al->sdl_alen, bbuf, bl->sdl_alen);
481                         return 0;
482                 }
483
484                 if (((flags & SA_F_IGNORE_IFTYPE) == 0) && (al->sdl_type != bl->sdl_type)) {
485                         snprintf(msg, sz, "sdl_type diff: %u vs %u", al->sdl_type, bl->sdl_type);
486                         return 0;
487                 }
488
489                 if (((flags & SA_F_IGNORE_IFNAME) == 0) && ((al->sdl_nlen != bl->sdl_nlen) ||
490                             (memcmp(al->sdl_data, bl->sdl_data, al->sdl_nlen) != 0))) {
491                         char abuf[64], bbuf[64];
492                         memcpy(abuf, al->sdl_data, al->sdl_nlen);
493                         abuf[al->sdl_nlen] = '\0';
494                         memcpy(bbuf, bl->sdl_data, bl->sdl_nlen);
495                         abuf[bl->sdl_nlen] = '\0';
496                         snprintf(msg, sz, "sdl_nlen diff: {%s} (%d) vs {%s} (%d)",
497                             abuf, al->sdl_nlen, bbuf, bl->sdl_nlen);
498                         return 0;
499                 }
500
501                 if (flags & SA_F_IGNORE_MEMCMP)
502                         return 1;
503                 break;
504         }
505
506         if (memcmp(a, b, a->sa_len)) {
507                 int i;
508                 for (i = 0; i < a->sa_len; i++)
509                         if (((const char *)a)[i] != ((const char *)b)[i])
510                                 break;
511
512                 sa_print(a, 1);
513                 sa_print(b, 1);
514
515                 snprintf(msg, sz, "overall memcmp() reports diff for af %d offset %d",
516                                 a->sa_family, i);
517                 return 0;
518         }
519         return 1;
520 }
521
522 int
523 sa_equal_msg(const struct sockaddr *a, const struct sockaddr *b, char *msg, size_t sz)
524 {
525
526         return sa_equal_msg_flags(a, b, msg, sz, 0);
527 }
528
529 void
530 sa_fill_mask4(struct sockaddr_in *sin, int plen)
531 {
532
533         memset(sin, 0, sizeof(struct sockaddr_in));
534         sin->sin_family = AF_INET;
535         sin->sin_len = sizeof(struct sockaddr_in);
536         sin->sin_addr.s_addr = htonl(plen ? ~((1 << (32 - plen)) - 1) : 0);
537 }
538
539 void
540 sa_fill_mask6(struct sockaddr_in6 *sin6, uint8_t mask)
541 {
542         uint32_t *cp;
543
544         memset(sin6, 0, sizeof(struct sockaddr_in6));
545         sin6->sin6_family = AF_INET6;
546         sin6->sin6_len = sizeof(struct sockaddr_in6);
547
548         for (cp = (uint32_t *)&sin6->sin6_addr; mask >= 32; mask -= 32)
549                 *cp++ = 0xFFFFFFFF;
550         if (mask > 0)
551                 *cp = htonl(mask ? ~((1 << (32 - mask)) - 1) : 0);
552 }
553
554 /* 52:54:00:14:e3:10 */
555 #define ETHER_MAC_MAX_LENGTH    17
556
557 int
558 sa_convert_str_to_sa(const char *_addr, struct sockaddr *sa)
559 {
560         int error;
561
562         int af = AF_UNSPEC;
563
564         char *addr = strdup(_addr);
565         int retcode = 0;
566
567         /* classify AF by str */
568         if (strchr(addr, ':')) {
569                 /* inet6 or ether */
570                 char *k;
571                 int delim_cnt = 0;
572                 for (k = addr; *k; k++)
573                         if (*k == ':')
574                                 delim_cnt++;
575                 af = AF_INET6;
576
577                 if (delim_cnt == 5) {
578                         k = strchr(addr, '%');
579                         if (k != NULL && (k - addr) <= ETHER_MAC_MAX_LENGTH)
580                                 af = AF_LINK;
581                 }
582         } else if (strchr(addr, '.'))
583                 af = AF_INET;
584
585         /* */
586         char *delimiter;
587         int ifindex = 0;
588         char *ifname = NULL;
589         if ((delimiter = strchr(addr, '%')) != NULL) {
590                 *delimiter = '\0';
591                 ifname = delimiter + 1;
592                 ifindex = if_nametoindex(ifname);
593                 if (ifindex == 0)
594                         RLOG("unable to find ifindex for '%s'", ifname);
595                 else
596                         RLOG("if %s mapped to %d", ifname, ifindex);
597         }
598
599         if (af == AF_INET6) {
600                 struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)sa;
601                 memset(sin6, 0, sizeof(struct sockaddr_in6));
602                 sin6->sin6_family = AF_INET6;
603                 sin6->sin6_len = sizeof(struct sockaddr_in6);
604                 sin6->sin6_scope_id = ifindex;
605                 error = inet_pton(AF_INET6, addr, &sin6->sin6_addr);
606                 if (error != 1)
607                         RLOG_ERRNO("inet_ntop() failed: ret=%d", error);
608                 else
609                         retcode = 1;
610         } else if (af == AF_INET) {
611                 struct sockaddr_in *sin = (struct sockaddr_in *)sa;
612                 memset(sin, 0, sizeof(struct sockaddr_in));
613                 sin->sin_family = AF_INET;
614                 sin->sin_len = sizeof(struct sockaddr_in);
615                 error = inet_pton(AF_INET, addr, &sin->sin_addr);
616                 if (error != 1)
617                         RLOG("inet_ntop() failed: ret=%d", error);
618                 else
619                         retcode = 1;
620         } else if (af == AF_LINK) {
621                 struct sockaddr_dl *sdl = (struct sockaddr_dl *)sa;
622                 memset(sdl, 0, sizeof(struct sockaddr_dl));
623                 sdl->sdl_family = AF_LINK;
624                 sdl->sdl_len = sizeof(struct sockaddr_dl);
625                 sdl->sdl_index = ifindex;
626                 sdl->sdl_alen = 6;
627                 struct ether_addr *ea = (struct ether_addr *)LLADDR(sdl);
628                 if (ether_aton_r(addr, ea) == NULL)
629                         RLOG("ether_aton() failed");
630                 else
631                         retcode = 1;
632         }
633
634         return (retcode);
635 }
636
637
638 int
639 rtsock_setup_socket()
640 {
641         int fd;
642         int af = AF_UNSPEC; /* 0 to capture messages from all AFs */
643         fd = socket(PF_ROUTE, SOCK_RAW, af);
644
645         ATF_REQUIRE_MSG(fd != -1, "rtsock open failed: %s", strerror(errno));
646
647         /* Listen for our messages */
648         int on = 1;
649         if (setsockopt(fd, SOL_SOCKET,SO_USELOOPBACK, &on, sizeof(on)) < 0)
650                 RLOG_ERRNO("setsockopt failed");
651
652         return (fd);
653 }
654
655 ssize_t
656 rtsock_send_rtm(int fd, struct rt_msghdr *rtm)
657 {
658         int my_errno;
659         ssize_t len;
660
661         rtsock_update_rtm_len(rtm);
662
663         len = write(fd, rtm, rtm->rtm_msglen);
664         my_errno = errno;
665         RTSOCK_ATF_REQUIRE_MSG(rtm, len == rtm->rtm_msglen,
666             "rtsock write failed: want %d got %zd (%s)",
667             rtm->rtm_msglen, len, strerror(my_errno));
668
669         return (len);
670 }
671
672 struct rt_msghdr *
673 rtsock_read_rtm(int fd, char *buffer, size_t buflen)
674 {
675         ssize_t len;
676         struct pollfd pfd;
677         int poll_delay = 5 * 1000; /* 5 seconds */
678
679         /* Check for the data available to read first */
680         memset(&pfd, 0, sizeof(pfd));
681         pfd.fd = fd;
682         pfd.events = POLLIN;
683
684         if (poll(&pfd, 1, poll_delay) == 0)
685                 ATF_REQUIRE_MSG(1 == 0, "rtsock read timed out (%d seconds passed)",
686                     poll_delay / 1000);
687
688         len = read(fd, buffer, buflen);
689         int my_errno = errno;
690         ATF_REQUIRE_MSG(len > 0, "rtsock read failed: %s", strerror(my_errno));
691
692         rtsock_validate_message(buffer, len);
693         return ((struct rt_msghdr *)buffer);
694 }
695
696 struct rt_msghdr *
697 rtsock_read_rtm_reply(int fd, char *buffer, size_t buflen, int seq)
698 {
699         struct rt_msghdr *rtm;
700         int found = 0;
701
702         while (true) {
703                 rtm = rtsock_read_rtm(fd, buffer, buflen);
704                 if (rtm->rtm_pid == getpid() && rtm->rtm_seq == seq)
705                         found = 1;
706                 if (found)
707                         RLOG("--- MATCHED RTSOCK MESSAGE ---");
708                 else
709                         RLOG("--- SKIPPED RTSOCK MESSAGE ---");
710                 rtsock_print_rtm(rtm);
711                 if (found)
712                         return (rtm);
713         }
714
715         /* NOTREACHED */
716 }
717
718 void
719 rtsock_prepare_route_message_base(struct rt_msghdr *rtm, int cmd)
720 {
721
722         memset(rtm, 0, sizeof(struct rt_msghdr));
723         rtm->rtm_type = cmd;
724         rtm->rtm_version = RTM_VERSION;
725         rtm->rtm_seq = _rtm_seq++;
726 }
727
728 void
729 rtsock_prepare_route_message(struct rt_msghdr *rtm, int cmd, struct sockaddr *dst,
730   struct sockaddr *mask, struct sockaddr *gw)
731 {
732
733         rtsock_prepare_route_message_base(rtm, cmd);
734         if (dst != NULL)
735                 rtsock_add_rtm_sa(rtm, RTA_DST, dst);
736
737         if (gw != NULL) {
738                 rtsock_add_rtm_sa(rtm, RTA_GATEWAY, gw);
739                 rtm->rtm_flags |= RTF_GATEWAY;
740         }
741
742         if (mask != NULL)
743                 rtsock_add_rtm_sa(rtm, RTA_NETMASK, mask);
744 }
745
746 void
747 rtsock_add_rtm_sa(struct rt_msghdr *rtm, int addr_type, struct sockaddr *sa)
748 {
749         char *ptr = (char *)(rtm + 1);
750         for (int i = 0; i < RTAX_MAX; i++) {
751                 if (rtm->rtm_addrs & (1 << i)) {
752                         /* add */
753                         ptr += ALIGN(((struct sockaddr *)ptr)->sa_len);
754                 }
755         }
756
757         rtm->rtm_addrs |= addr_type;
758         memcpy(ptr, sa, sa->sa_len);
759 }
760
761 struct sockaddr *
762 rtsock_find_rtm_sa(struct rt_msghdr *rtm, int addr_type)
763 {
764         char *ptr = (char *)(rtm + 1);
765         for (int i = 0; i < RTAX_MAX; i++) {
766                 if (rtm->rtm_addrs & (1 << i)) {
767                         if (addr_type == (1 << i))
768                                 return ((struct sockaddr *)ptr);
769                         /* add */
770                         ptr += ALIGN(((struct sockaddr *)ptr)->sa_len);
771                 }
772         }
773
774         return (NULL);
775 }
776
777 size_t
778 rtsock_calc_rtm_len(struct rt_msghdr *rtm)
779 {
780         size_t len = sizeof(struct rt_msghdr);
781
782         char *ptr = (char *)(rtm + 1);
783         for (int i = 0; i < RTAX_MAX; i++) {
784                 if (rtm->rtm_addrs & (1 << i)) {
785                         /* add */
786                         int sa_len = ALIGN(((struct sockaddr *)ptr)->sa_len);
787                         len += sa_len;
788                         ptr += sa_len;
789                 }
790         }
791
792         return len;
793 }
794
795 void
796 rtsock_update_rtm_len(struct rt_msghdr *rtm)
797 {
798
799         rtm->rtm_msglen = rtsock_calc_rtm_len(rtm);
800 }
801
802 static void
803 _validate_message_sockaddrs(char *buffer, int rtm_len, size_t offset, int rtm_addrs)
804 {
805         struct sockaddr *sa;
806         size_t parsed_len = offset;
807
808         /* Offset denotes initial header size */
809         sa = (struct sockaddr *)(buffer + offset);
810
811         for (int i = 0; i < RTAX_MAX; i++) {
812                 if ((rtm_addrs & (1 << i)) == 0)
813                         continue;
814                 parsed_len += SA_SIZE(sa);
815                 RTSOCK_ATF_REQUIRE_MSG((struct rt_msghdr *)buffer, parsed_len <= rtm_len,
816                     "SA %d: len %d exceeds msg size %d", i, (int)sa->sa_len, rtm_len);
817                 if (sa->sa_family == AF_LINK) {
818                         struct sockaddr_dl *sdl = (struct sockaddr_dl *)sa;
819                         int data_len = sdl->sdl_nlen + sdl->sdl_alen;
820                         data_len += offsetof(struct sockaddr_dl, sdl_data);
821
822                         RTSOCK_ATF_REQUIRE_MSG((struct rt_msghdr *)buffer,
823                             data_len <= rtm_len,
824                             "AF_LINK data size exceeds total len: %u vs %u, nlen=%d alen=%d",
825                             data_len, rtm_len, sdl->sdl_nlen, sdl->sdl_alen);
826                 }
827                 sa = (struct sockaddr *)((char *)sa + SA_SIZE(sa));
828         }
829 }
830
831 /*
832  * Raises error if base syntax checks fails.
833  */
834 void
835 rtsock_validate_message(char *buffer, ssize_t len)
836 {
837         struct rt_msghdr *rtm;
838
839         ATF_REQUIRE_MSG(len > 0, "read() return %zd, error: %s", len, strerror(errno));
840
841         rtm = (struct rt_msghdr *)buffer;
842         ATF_REQUIRE_MSG(rtm->rtm_version == RTM_VERSION, "unknown RTM_VERSION: expected %d got %d",
843                         RTM_VERSION, rtm->rtm_version);
844         ATF_REQUIRE_MSG(rtm->rtm_msglen <= len, "wrong message length: expected %d got %d",
845                         (int)len, (int)rtm->rtm_msglen);
846
847         switch (rtm->rtm_type) {
848         case RTM_GET:
849         case RTM_ADD:
850         case RTM_DELETE:
851         case RTM_CHANGE:
852                 _validate_message_sockaddrs(buffer, rtm->rtm_msglen,
853                     sizeof(struct rt_msghdr), rtm->rtm_addrs);
854                 break;
855         case RTM_DELADDR:
856         case RTM_NEWADDR:
857                 _validate_message_sockaddrs(buffer, rtm->rtm_msglen,
858                     sizeof(struct ifa_msghdr), ((struct ifa_msghdr *)buffer)->ifam_addrs);
859                 break;
860         }
861 }
862
863 void
864 rtsock_validate_pid_ours(struct rt_msghdr *rtm)
865 {
866         RTSOCK_ATF_REQUIRE_MSG(rtm, rtm->rtm_pid == getpid(), "expected pid %d, got %d",
867             getpid(), rtm->rtm_pid);
868 }
869
870 void
871 rtsock_validate_pid_user(struct rt_msghdr *rtm)
872 {
873         RTSOCK_ATF_REQUIRE_MSG(rtm, rtm->rtm_pid > 0, "expected non-zero pid, got %d",
874             rtm->rtm_pid);
875 }
876
877 void
878 rtsock_validate_pid_kernel(struct rt_msghdr *rtm)
879 {
880         RTSOCK_ATF_REQUIRE_MSG(rtm, rtm->rtm_pid == 0, "expected zero pid, got %d",
881             rtm->rtm_pid);
882 }
883
884 #endif