]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - tools/tools/switch_tls/switch_tls.c
MFV r354382,r354385: 10601 10757 Pool allocation classes
[FreeBSD/FreeBSD.git] / tools / tools / switch_tls / switch_tls.c
1 /* $OpenBSD: tcpdrop.c,v 1.4 2004/05/22 23:55:22 deraadt Exp $ */
2
3 /*-
4  * Copyright (c) 2009 Juli Mallett <jmallett@FreeBSD.org>
5  * Copyright (c) 2004 Markus Friedl <markus@openbsd.org>
6  *
7  * Permission to use, copy, modify, and distribute this software for any
8  * purpose with or without fee is hereby granted, provided that the above
9  * copyright notice and this permission notice appear in all copies.
10  *
11  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
12  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
13  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
14  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
15  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
16  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
17  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
18  */
19
20 #include <sys/cdefs.h>
21 __FBSDID("$FreeBSD$");
22
23 #include <sys/param.h>
24 #include <sys/types.h>
25 #include <sys/socket.h>
26 #include <sys/socketvar.h>
27 #include <sys/sysctl.h>
28
29 #include <netinet/in.h>
30 #include <netinet/in_pcb.h>
31 #define TCPSTATES
32 #include <netinet/tcp_fsm.h>
33 #include <netinet/tcp_var.h>
34
35 #include <err.h>
36 #include <netdb.h>
37 #include <stdbool.h>
38 #include <stdio.h>
39 #include <stdlib.h>
40 #include <string.h>
41 #include <unistd.h>
42
43 #define TCPDROP_FOREIGN         0
44 #define TCPDROP_LOCAL           1
45
46 #define SW_TLS                  0
47 #define IFNET_TLS               1
48
49 struct host_service {
50         char hs_host[NI_MAXHOST];
51         char hs_service[NI_MAXSERV];
52 };
53
54 static bool tcpswitch_list_commands = false;
55
56 static char *findport(const char *);
57 static struct xinpgen *getxpcblist(const char *);
58 static void sockinfo(const struct sockaddr *, struct host_service *);
59 static bool tcpswitch(const struct sockaddr *, const struct sockaddr *, int);
60 static bool tcpswitchall(const char *, int);
61 static bool tcpswitchbyname(const char *, const char *, const char *,
62     const char *, int);
63 static bool tcpswitchconn(const struct in_conninfo *, int);
64 static void usage(void);
65
66 /*
67  * Switch a tcp connection.
68  */
69 int
70 main(int argc, char *argv[])
71 {
72         char stack[TCP_FUNCTION_NAME_LEN_MAX];
73         char *lport, *fport;
74         bool switchall, switchallstack;
75         int ch, mode;
76
77         switchall = false;
78         switchallstack = false;
79         stack[0] = '\0';
80         mode = SW_TLS;
81
82         while ((ch = getopt(argc, argv, "ailS:s")) != -1) {
83                 switch (ch) {
84                 case 'a':
85                         switchall = true;
86                         break;
87                 case 'i':
88                         mode = IFNET_TLS;
89                         break;
90                 case 'l':
91                         tcpswitch_list_commands = true;
92                         break;
93                 case 'S':
94                         switchallstack = true;
95                         strlcpy(stack, optarg, sizeof(stack));
96                         break;
97                 case 's':
98                         mode = SW_TLS;
99                         break;
100                 default:
101                         usage();
102                 }
103         }
104         argc -= optind;
105         argv += optind;
106
107         if (switchall && switchallstack)
108                 usage();
109         if (switchall || switchallstack) {
110                 if (argc != 0)
111                         usage();
112                 if (!tcpswitchall(stack, mode))
113                         exit(1);
114                 exit(0);
115         }
116
117         if ((argc != 2 && argc != 4) || tcpswitch_list_commands)
118                 usage();
119
120         if (argc == 2) {
121                 lport = findport(argv[0]);
122                 fport = findport(argv[1]);
123                 if (lport == NULL || lport[1] == '\0' || fport == NULL ||
124                     fport[1] == '\0')
125                         usage();
126                 *lport++ = '\0';
127                 *fport++ = '\0';
128                 if (!tcpswitchbyname(argv[0], lport, argv[1], fport, mode))
129                         exit(1);
130         } else if (!tcpswitchbyname(argv[0], argv[1], argv[2], argv[3], mode))
131                 exit(1);
132
133         exit(0);
134 }
135
136 static char *
137 findport(const char *arg)
138 {
139         char *dot, *colon;
140
141         /* A strrspn() or strrpbrk() would be nice. */
142         dot = strrchr(arg, '.');
143         colon = strrchr(arg, ':');
144         if (dot == NULL)
145                 return (colon);
146         if (colon == NULL)
147                 return (dot);
148         if (dot < colon)
149                 return (colon);
150         else
151                 return (dot);
152 }
153
154 static struct xinpgen *
155 getxpcblist(const char *name)
156 {
157         struct xinpgen *xinp;
158         size_t len;
159         int rv;
160
161         len = 0;
162         rv = sysctlbyname(name, NULL, &len, NULL, 0);
163         if (rv == -1)
164                 err(1, "sysctlbyname %s", name);
165
166         if (len == 0)
167                 errx(1, "%s is empty", name);
168
169         xinp = malloc(len);
170         if (xinp == NULL)
171                 errx(1, "malloc failed");
172
173         rv = sysctlbyname(name, xinp, &len, NULL, 0);
174         if (rv == -1)
175                 err(1, "sysctlbyname %s", name);
176
177         return (xinp);
178 }
179
180 static void
181 sockinfo(const struct sockaddr *sa, struct host_service *hs)
182 {
183         static const int flags = NI_NUMERICHOST | NI_NUMERICSERV;
184         int rv;
185
186         rv = getnameinfo(sa, sa->sa_len, hs->hs_host, sizeof hs->hs_host,
187             hs->hs_service, sizeof hs->hs_service, flags);
188         if (rv == -1)
189                 err(1, "getnameinfo");
190 }
191
192 static bool
193 tcpswitch(const struct sockaddr *lsa, const struct sockaddr *fsa, int mode)
194 {
195         struct host_service local, foreign;
196         struct sockaddr_storage addrs[2];
197         int rv;
198
199         memcpy(&addrs[TCPDROP_FOREIGN], fsa, fsa->sa_len);
200         memcpy(&addrs[TCPDROP_LOCAL], lsa, lsa->sa_len);
201
202         sockinfo(lsa, &local);
203         sockinfo(fsa, &foreign);
204
205         if (tcpswitch_list_commands) {
206                 printf("switch_tls %s %s %s %s %s\n",
207                     mode == SW_TLS ? "-s" : "-i",
208                     local.hs_host, local.hs_service,
209                     foreign.hs_host, foreign.hs_service);
210                 return (true);
211         }
212
213         rv = sysctlbyname(mode == SW_TLS ? "net.inet.tcp.switch_to_sw_tls" :
214             "net.inet.tcp.switch_to_ifnet_tls", NULL, NULL, &addrs,
215             sizeof addrs);
216         if (rv == -1) {
217                 warn("%s %s %s %s", local.hs_host, local.hs_service,
218                     foreign.hs_host, foreign.hs_service);
219                 return (false);
220         }
221         printf("%s %s %s %s: switched\n", local.hs_host, local.hs_service,
222             foreign.hs_host, foreign.hs_service);
223         return (true);
224 }
225
226 static bool
227 tcpswitchall(const char *stack, int mode)
228 {
229         struct xinpgen *head, *xinp;
230         struct xtcpcb *xtp;
231         struct xinpcb *xip;
232         bool ok;
233
234         ok = true;
235
236         head = getxpcblist("net.inet.tcp.pcblist");
237
238 #define XINP_NEXT(xinp)                                                 \
239         ((struct xinpgen *)(uintptr_t)((uintptr_t)(xinp) + (xinp)->xig_len))
240
241         for (xinp = XINP_NEXT(head); xinp->xig_len > sizeof *xinp;
242             xinp = XINP_NEXT(xinp)) {
243                 xtp = (struct xtcpcb *)xinp;
244                 xip = &xtp->xt_inp;
245
246                 /*
247                  * XXX
248                  * Check protocol, support just v4 or v6, etc.
249                  */
250
251                 /* Ignore PCBs which were freed during copyout.  */
252                 if (xip->inp_gencnt > head->xig_gen)
253                         continue;
254
255                 /* Skip listening sockets.  */
256                 if (xtp->t_state == TCPS_LISTEN)
257                         continue;
258
259                 /* If requested, skip sockets not having the requested stack. */
260                 if (stack[0] != '\0' &&
261                     strncmp(xtp->xt_stack, stack, TCP_FUNCTION_NAME_LEN_MAX))
262                         continue;
263
264                 if (!tcpswitchconn(&xip->inp_inc, mode))
265                         ok = false;
266         }
267         free(head);
268
269         return (ok);
270 }
271
272 static bool
273 tcpswitchbyname(const char *lhost, const char *lport, const char *fhost,
274     const char *fport, int mode)
275 {
276         static const struct addrinfo hints = {
277                 /*
278                  * Look for streams in all domains.
279                  */
280                 .ai_family = AF_UNSPEC,
281                 .ai_socktype = SOCK_STREAM,
282         };
283         struct addrinfo *ail, *local, *aif, *foreign;
284         int error;
285         bool ok, infamily;
286
287         error = getaddrinfo(lhost, lport, &hints, &local);
288         if (error != 0)
289                 errx(1, "getaddrinfo: %s port %s: %s", lhost, lport,
290                     gai_strerror(error));
291
292         error = getaddrinfo(fhost, fport, &hints, &foreign);
293         if (error != 0) {
294                 freeaddrinfo(local); /* XXX gratuitous */
295                 errx(1, "getaddrinfo: %s port %s: %s", fhost, fport,
296                     gai_strerror(error));
297         }
298
299         ok = true;
300         infamily = false;
301
302         /*
303          * Try every combination of local and foreign address pairs.
304          */
305         for (ail = local; ail != NULL; ail = ail->ai_next) {
306                 for (aif = foreign; aif != NULL; aif = aif->ai_next) {
307                         if (ail->ai_family != aif->ai_family)
308                                 continue;
309                         infamily = true;
310                         if (!tcpswitch(ail->ai_addr, aif->ai_addr, mode))
311                                 ok = false;
312                 }
313         }
314
315         if (!infamily) {
316                 warnx("%s %s %s %s: different address families", lhost, lport,
317                     fhost, fport);
318                 ok = false;
319         }
320
321         freeaddrinfo(local);
322         freeaddrinfo(foreign);
323
324         return (ok);
325 }
326
327 static bool
328 tcpswitchconn(const struct in_conninfo *inc, int mode)
329 {
330         struct sockaddr *local, *foreign;
331         struct sockaddr_in6 sin6[2];
332         struct sockaddr_in sin4[2];
333
334         if ((inc->inc_flags & INC_ISIPV6) != 0) {
335                 memset(sin6, 0, sizeof sin6);
336
337                 sin6[TCPDROP_LOCAL].sin6_len = sizeof sin6[TCPDROP_LOCAL];
338                 sin6[TCPDROP_LOCAL].sin6_family = AF_INET6;
339                 sin6[TCPDROP_LOCAL].sin6_port = inc->inc_lport;
340                 memcpy(&sin6[TCPDROP_LOCAL].sin6_addr, &inc->inc6_laddr,
341                     sizeof inc->inc6_laddr);
342                 local = (struct sockaddr *)&sin6[TCPDROP_LOCAL];
343
344                 sin6[TCPDROP_FOREIGN].sin6_len = sizeof sin6[TCPDROP_FOREIGN];
345                 sin6[TCPDROP_FOREIGN].sin6_family = AF_INET6;
346                 sin6[TCPDROP_FOREIGN].sin6_port = inc->inc_fport;
347                 memcpy(&sin6[TCPDROP_FOREIGN].sin6_addr, &inc->inc6_faddr,
348                     sizeof inc->inc6_faddr);
349                 foreign = (struct sockaddr *)&sin6[TCPDROP_FOREIGN];
350         } else {
351                 memset(sin4, 0, sizeof sin4);
352
353                 sin4[TCPDROP_LOCAL].sin_len = sizeof sin4[TCPDROP_LOCAL];
354                 sin4[TCPDROP_LOCAL].sin_family = AF_INET;
355                 sin4[TCPDROP_LOCAL].sin_port = inc->inc_lport;
356                 memcpy(&sin4[TCPDROP_LOCAL].sin_addr, &inc->inc_laddr,
357                     sizeof inc->inc_laddr);
358                 local = (struct sockaddr *)&sin4[TCPDROP_LOCAL];
359
360                 sin4[TCPDROP_FOREIGN].sin_len = sizeof sin4[TCPDROP_FOREIGN];
361                 sin4[TCPDROP_FOREIGN].sin_family = AF_INET;
362                 sin4[TCPDROP_FOREIGN].sin_port = inc->inc_fport;
363                 memcpy(&sin4[TCPDROP_FOREIGN].sin_addr, &inc->inc_faddr,
364                     sizeof inc->inc_faddr);
365                 foreign = (struct sockaddr *)&sin4[TCPDROP_FOREIGN];
366         }
367
368         return (tcpswitch(local, foreign, mode));
369 }
370
371 static void
372 usage(void)
373 {
374         fprintf(stderr,
375 "usage: switch_tls [-i | -s] local-address local-port foreign-address foreign-port\n"
376 "       switch_tls [-i | -s] local-address:local-port foreign-address:foreign-port\n"
377 "       switch_tls [-i | -s] local-address.local-port foreign-address.foreign-port\n"
378 "       switch_tls [-l | -i | -s] -a\n"
379 "       switch_tls [-l | -i | -s] -S stack\n");
380         exit(1);
381 }