]> CyberLeo.Net >> Repos - FreeBSD/releng/9.2.git/blob - sbin/hastd/proto_tcp.c
- Copy stable/9 to releng/9.2 as part of the 9.2-RELEASE cycle.
[FreeBSD/releng/9.2.git] / sbin / hastd / proto_tcp.c
1 /*-
2  * Copyright (c) 2009-2010 The FreeBSD Foundation
3  * Copyright (c) 2011 Pawel Jakub Dawidek <pawel@dawidek.net>
4  * All rights reserved.
5  *
6  * This software was developed by Pawel Jakub Dawidek under sponsorship from
7  * the FreeBSD Foundation.
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  * 1. Redistributions of source code must retain the above copyright
13  *    notice, this list of conditions and the following disclaimer.
14  * 2. Redistributions in binary form must reproduce the above copyright
15  *    notice, this list of conditions and the following disclaimer in the
16  *    documentation and/or other materials provided with the distribution.
17  *
18  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
19  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
22  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
24  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
25  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
26  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
27  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
28  * SUCH DAMAGE.
29  */
30
31 #include <sys/cdefs.h>
32 __FBSDID("$FreeBSD$");
33
34 #include <sys/param.h>  /* MAXHOSTNAMELEN */
35 #include <sys/socket.h>
36
37 #include <arpa/inet.h>
38
39 #include <netinet/in.h>
40 #include <netinet/tcp.h>
41
42 #include <errno.h>
43 #include <fcntl.h>
44 #include <netdb.h>
45 #include <stdbool.h>
46 #include <stdint.h>
47 #include <stdio.h>
48 #include <string.h>
49 #include <unistd.h>
50
51 #include "pjdlog.h"
52 #include "proto_impl.h"
53 #include "subr.h"
54
55 #define TCP_CTX_MAGIC   0x7c41c
56 struct tcp_ctx {
57         int                     tc_magic;
58         struct sockaddr_storage tc_sa;
59         int                     tc_fd;
60         int                     tc_side;
61 #define TCP_SIDE_CLIENT         0
62 #define TCP_SIDE_SERVER_LISTEN  1
63 #define TCP_SIDE_SERVER_WORK    2
64 };
65
66 static int tcp_connect_wait(void *ctx, int timeout);
67 static void tcp_close(void *ctx);
68
69 /*
70  * Function converts the given string to unsigned number.
71  */
72 static int
73 numfromstr(const char *str, intmax_t minnum, intmax_t maxnum, intmax_t *nump)
74 {
75         intmax_t digit, num;
76
77         if (str[0] == '\0')
78                 goto invalid;   /* Empty string. */
79         num = 0;
80         for (; *str != '\0'; str++) {
81                 if (*str < '0' || *str > '9')
82                         goto invalid;   /* Non-digit character. */
83                 digit = *str - '0';
84                 if (num > num * 10 + digit)
85                         goto invalid;   /* Overflow. */
86                 num = num * 10 + digit;
87                 if (num > maxnum)
88                         goto invalid;   /* Too big. */
89         }
90         if (num < minnum)
91                 goto invalid;   /* Too small. */
92         *nump = num;
93         return (0);
94 invalid:
95         errno = EINVAL;
96         return (-1);
97 }
98
99 static int
100 tcp_addr(const char *addr, int defport, struct sockaddr_storage *sap)
101 {
102         char iporhost[MAXHOSTNAMELEN], portstr[6];
103         struct addrinfo hints;
104         struct addrinfo *res;
105         const char *pp;
106         intmax_t port;
107         size_t size;
108         int error;
109
110         if (addr == NULL)
111                 return (-1);
112
113         bzero(&hints, sizeof(hints));
114         hints.ai_flags = AI_ADDRCONFIG | AI_NUMERICSERV;
115         hints.ai_family = PF_UNSPEC;
116         hints.ai_socktype = SOCK_STREAM;
117         hints.ai_protocol = IPPROTO_TCP;
118
119         if (strncasecmp(addr, "tcp4://", 7) == 0) {
120                 addr += 7;
121                 hints.ai_family = PF_INET;
122         } else if (strncasecmp(addr, "tcp6://", 7) == 0) {
123                 addr += 7;
124                 hints.ai_family = PF_INET6;
125         } else if (strncasecmp(addr, "tcp://", 6) == 0) {
126                 addr += 6;
127         } else {
128                 /*
129                  * Because TCP is the default assume IP or host is given without
130                  * prefix.
131                  */
132         }
133
134         /*
135          * Extract optional port.
136          * There are three cases to consider.
137          * 1. hostname with port, eg. freefall.freebsd.org:8457
138          * 2. IPv4 address with port, eg. 192.168.0.101:8457
139          * 3. IPv6 address with port, eg. [fe80::1]:8457
140          * We discover IPv6 address by checking for two colons and if port is
141          * given, the address has to start with [.
142          */
143         pp = NULL;
144         if (strchr(addr, ':') != strrchr(addr, ':')) {
145                 if (addr[0] == '[')
146                         pp = strrchr(addr, ':');
147         } else {
148                 pp = strrchr(addr, ':');
149         }
150         if (pp == NULL) {
151                 /* Port not given, use the default. */
152                 port = defport;
153         } else {
154                 if (numfromstr(pp + 1, 1, 65535, &port) == -1)
155                         return (errno);
156         }
157         (void)snprintf(portstr, sizeof(portstr), "%jd", (intmax_t)port);
158         /* Extract host name or IP address. */
159         if (pp == NULL) {
160                 size = sizeof(iporhost);
161                 if (strlcpy(iporhost, addr, size) >= size)
162                         return (ENAMETOOLONG);
163         } else if (addr[0] == '[' && pp[-1] == ']') {
164                 size = (size_t)(pp - addr - 2 + 1);
165                 if (size > sizeof(iporhost))
166                         return (ENAMETOOLONG);
167                 (void)strlcpy(iporhost, addr + 1, size);
168         } else {
169                 size = (size_t)(pp - addr + 1);
170                 if (size > sizeof(iporhost))
171                         return (ENAMETOOLONG);
172                 (void)strlcpy(iporhost, addr, size);
173         }
174
175         error = getaddrinfo(iporhost, portstr, &hints, &res);
176         if (error != 0) {
177                 pjdlog_debug(1, "getaddrinfo(%s, %s) failed: %s.", iporhost,
178                     portstr, gai_strerror(error));
179                 return (EINVAL);
180         }
181         if (res == NULL)
182                 return (ENOENT);
183
184         memcpy(sap, res->ai_addr, res->ai_addrlen);
185
186         freeaddrinfo(res);
187
188         return (0);
189 }
190
191 static int
192 tcp_setup_new(const char *addr, int side, void **ctxp)
193 {
194         struct tcp_ctx *tctx;
195         int ret, nodelay;
196
197         PJDLOG_ASSERT(addr != NULL);
198         PJDLOG_ASSERT(side == TCP_SIDE_CLIENT ||
199             side == TCP_SIDE_SERVER_LISTEN);
200         PJDLOG_ASSERT(ctxp != NULL);
201
202         tctx = malloc(sizeof(*tctx));
203         if (tctx == NULL)
204                 return (errno);
205
206         /* Parse given address. */
207         if ((ret = tcp_addr(addr, PROTO_TCP_DEFAULT_PORT, &tctx->tc_sa)) != 0) {
208                 free(tctx);
209                 return (ret);
210         }
211
212         PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
213
214         tctx->tc_fd = socket(tctx->tc_sa.ss_family, SOCK_STREAM, 0);
215         if (tctx->tc_fd == -1) {
216                 ret = errno;
217                 free(tctx);
218                 return (ret);
219         }
220
221         PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
222
223         /* Socket settings. */
224         nodelay = 1;
225         if (setsockopt(tctx->tc_fd, IPPROTO_TCP, TCP_NODELAY, &nodelay,
226             sizeof(nodelay)) == -1) {
227                 pjdlog_errno(LOG_WARNING, "Unable to set TCP_NOELAY");
228         }
229
230         tctx->tc_side = side;
231         tctx->tc_magic = TCP_CTX_MAGIC;
232         *ctxp = tctx;
233
234         return (0);
235 }
236
237 static int
238 tcp_setup_wrap(int fd, int side, void **ctxp)
239 {
240         struct tcp_ctx *tctx;
241
242         PJDLOG_ASSERT(fd >= 0);
243         PJDLOG_ASSERT(side == TCP_SIDE_CLIENT ||
244             side == TCP_SIDE_SERVER_WORK);
245         PJDLOG_ASSERT(ctxp != NULL);
246
247         tctx = malloc(sizeof(*tctx));
248         if (tctx == NULL)
249                 return (errno);
250
251         tctx->tc_fd = fd;
252         tctx->tc_sa.ss_family = AF_UNSPEC;
253         tctx->tc_side = side;
254         tctx->tc_magic = TCP_CTX_MAGIC;
255         *ctxp = tctx;
256
257         return (0);
258 }
259
260 static int
261 tcp_client(const char *srcaddr, const char *dstaddr, void **ctxp)
262 {
263         struct tcp_ctx *tctx;
264         struct sockaddr_storage sa;
265         int ret;
266
267         ret = tcp_setup_new(dstaddr, TCP_SIDE_CLIENT, ctxp);
268         if (ret != 0)
269                 return (ret);
270         tctx = *ctxp;
271         if (srcaddr == NULL)
272                 return (0);
273         ret = tcp_addr(srcaddr, 0, &sa);
274         if (ret != 0) {
275                 tcp_close(tctx);
276                 return (ret);
277         }
278         if (bind(tctx->tc_fd, (struct sockaddr *)&sa, sa.ss_len) == -1) {
279                 ret = errno;
280                 tcp_close(tctx);
281                 return (ret);
282         }
283         return (0);
284 }
285
286 static int
287 tcp_connect(void *ctx, int timeout)
288 {
289         struct tcp_ctx *tctx = ctx;
290         int error, flags;
291
292         PJDLOG_ASSERT(tctx != NULL);
293         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
294         PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_CLIENT);
295         PJDLOG_ASSERT(tctx->tc_fd >= 0);
296         PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
297         PJDLOG_ASSERT(timeout >= -1);
298
299         flags = fcntl(tctx->tc_fd, F_GETFL);
300         if (flags == -1) {
301                 pjdlog_common(LOG_DEBUG, 1, errno, "fcntl(F_GETFL) failed");
302                 return (errno);
303         }
304         /*
305          * We make socket non-blocking so we can handle connection timeout
306          * manually.
307          */
308         flags |= O_NONBLOCK;
309         if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
310                 pjdlog_common(LOG_DEBUG, 1, errno,
311                     "fcntl(F_SETFL, O_NONBLOCK) failed");
312                 return (errno);
313         }
314
315         if (connect(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
316             tctx->tc_sa.ss_len) == 0) {
317                 if (timeout == -1)
318                         return (0);
319                 error = 0;
320                 goto done;
321         }
322         if (errno != EINPROGRESS) {
323                 error = errno;
324                 pjdlog_common(LOG_DEBUG, 1, errno, "connect() failed");
325                 goto done;
326         }
327         if (timeout == -1)
328                 return (0);
329         return (tcp_connect_wait(ctx, timeout));
330 done:
331         flags &= ~O_NONBLOCK;
332         if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
333                 if (error == 0)
334                         error = errno;
335                 pjdlog_common(LOG_DEBUG, 1, errno,
336                     "fcntl(F_SETFL, ~O_NONBLOCK) failed");
337         }
338         return (error);
339 }
340
341 static int
342 tcp_connect_wait(void *ctx, int timeout)
343 {
344         struct tcp_ctx *tctx = ctx;
345         struct timeval tv;
346         fd_set fdset;
347         socklen_t esize;
348         int error, flags, ret;
349
350         PJDLOG_ASSERT(tctx != NULL);
351         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
352         PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_CLIENT);
353         PJDLOG_ASSERT(tctx->tc_fd >= 0);
354         PJDLOG_ASSERT(timeout >= 0);
355
356         tv.tv_sec = timeout;
357         tv.tv_usec = 0;
358 again:
359         FD_ZERO(&fdset);
360         FD_SET(tctx->tc_fd, &fdset);
361         ret = select(tctx->tc_fd + 1, NULL, &fdset, NULL, &tv);
362         if (ret == 0) {
363                 error = ETIMEDOUT;
364                 goto done;
365         } else if (ret == -1) {
366                 if (errno == EINTR)
367                         goto again;
368                 error = errno;
369                 pjdlog_common(LOG_DEBUG, 1, errno, "select() failed");
370                 goto done;
371         }
372         PJDLOG_ASSERT(ret > 0);
373         PJDLOG_ASSERT(FD_ISSET(tctx->tc_fd, &fdset));
374         esize = sizeof(error);
375         if (getsockopt(tctx->tc_fd, SOL_SOCKET, SO_ERROR, &error,
376             &esize) == -1) {
377                 error = errno;
378                 pjdlog_common(LOG_DEBUG, 1, errno,
379                     "getsockopt(SO_ERROR) failed");
380                 goto done;
381         }
382         if (error != 0) {
383                 pjdlog_common(LOG_DEBUG, 1, error,
384                     "getsockopt(SO_ERROR) returned error");
385                 goto done;
386         }
387         error = 0;
388 done:
389         flags = fcntl(tctx->tc_fd, F_GETFL);
390         if (flags == -1) {
391                 if (error == 0)
392                         error = errno;
393                 pjdlog_common(LOG_DEBUG, 1, errno, "fcntl(F_GETFL) failed");
394                 return (error);
395         }
396         flags &= ~O_NONBLOCK;
397         if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
398                 if (error == 0)
399                         error = errno;
400                 pjdlog_common(LOG_DEBUG, 1, errno,
401                     "fcntl(F_SETFL, ~O_NONBLOCK) failed");
402         }
403         return (error);
404 }
405
406 static int
407 tcp_server(const char *addr, void **ctxp)
408 {
409         struct tcp_ctx *tctx;
410         int ret, val;
411
412         ret = tcp_setup_new(addr, TCP_SIDE_SERVER_LISTEN, ctxp);
413         if (ret != 0)
414                 return (ret);
415
416         tctx = *ctxp;
417
418         val = 1;
419         /* Ignore failure. */
420         (void)setsockopt(tctx->tc_fd, SOL_SOCKET, SO_REUSEADDR, &val,
421            sizeof(val));
422
423         PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
424
425         if (bind(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
426             tctx->tc_sa.ss_len) == -1) {
427                 ret = errno;
428                 tcp_close(tctx);
429                 return (ret);
430         }
431         if (listen(tctx->tc_fd, 8) == -1) {
432                 ret = errno;
433                 tcp_close(tctx);
434                 return (ret);
435         }
436
437         return (0);
438 }
439
440 static int
441 tcp_accept(void *ctx, void **newctxp)
442 {
443         struct tcp_ctx *tctx = ctx;
444         struct tcp_ctx *newtctx;
445         socklen_t fromlen;
446         int ret;
447
448         PJDLOG_ASSERT(tctx != NULL);
449         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
450         PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_SERVER_LISTEN);
451         PJDLOG_ASSERT(tctx->tc_fd >= 0);
452         PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
453
454         newtctx = malloc(sizeof(*newtctx));
455         if (newtctx == NULL)
456                 return (errno);
457
458         fromlen = tctx->tc_sa.ss_len;
459         newtctx->tc_fd = accept(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
460             &fromlen);
461         if (newtctx->tc_fd == -1) {
462                 ret = errno;
463                 free(newtctx);
464                 return (ret);
465         }
466
467         newtctx->tc_side = TCP_SIDE_SERVER_WORK;
468         newtctx->tc_magic = TCP_CTX_MAGIC;
469         *newctxp = newtctx;
470
471         return (0);
472 }
473
474 static int
475 tcp_wrap(int fd, bool client, void **ctxp)
476 {
477
478         return (tcp_setup_wrap(fd,
479             client ? TCP_SIDE_CLIENT : TCP_SIDE_SERVER_WORK, ctxp));
480 }
481
482 static int
483 tcp_send(void *ctx, const unsigned char *data, size_t size, int fd)
484 {
485         struct tcp_ctx *tctx = ctx;
486
487         PJDLOG_ASSERT(tctx != NULL);
488         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
489         PJDLOG_ASSERT(tctx->tc_fd >= 0);
490         PJDLOG_ASSERT(fd == -1);
491
492         return (proto_common_send(tctx->tc_fd, data, size, -1));
493 }
494
495 static int
496 tcp_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
497 {
498         struct tcp_ctx *tctx = ctx;
499
500         PJDLOG_ASSERT(tctx != NULL);
501         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
502         PJDLOG_ASSERT(tctx->tc_fd >= 0);
503         PJDLOG_ASSERT(fdp == NULL);
504
505         return (proto_common_recv(tctx->tc_fd, data, size, NULL));
506 }
507
508 static int
509 tcp_descriptor(const void *ctx)
510 {
511         const struct tcp_ctx *tctx = ctx;
512
513         PJDLOG_ASSERT(tctx != NULL);
514         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
515
516         return (tctx->tc_fd);
517 }
518
519 static bool
520 tcp_address_match(const void *ctx, const char *addr)
521 {
522         const struct tcp_ctx *tctx = ctx;
523         struct sockaddr_storage sa1, sa2;
524         socklen_t salen;
525
526         PJDLOG_ASSERT(tctx != NULL);
527         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
528
529         if (tcp_addr(addr, PROTO_TCP_DEFAULT_PORT, &sa1) != 0)
530                 return (false);
531
532         salen = sizeof(sa2);
533         if (getpeername(tctx->tc_fd, (struct sockaddr *)&sa2, &salen) == -1)
534                 return (false);
535
536         if (sa1.ss_family != sa2.ss_family || sa1.ss_len != sa2.ss_len)
537                 return (false);
538
539         switch (sa1.ss_family) {
540         case AF_INET:
541             {
542                 struct sockaddr_in *sin1, *sin2;
543
544                 sin1 = (struct sockaddr_in *)&sa1;
545                 sin2 = (struct sockaddr_in *)&sa2;
546
547                 return (memcmp(&sin1->sin_addr, &sin2->sin_addr,
548                     sizeof(sin1->sin_addr)) == 0);
549             }
550         case AF_INET6:
551             {
552                 struct sockaddr_in6 *sin1, *sin2;
553
554                 sin1 = (struct sockaddr_in6 *)&sa1;
555                 sin2 = (struct sockaddr_in6 *)&sa2;
556
557                 return (memcmp(&sin1->sin6_addr, &sin2->sin6_addr,
558                     sizeof(sin1->sin6_addr)) == 0);
559             }
560         default:
561                 return (false);
562         }
563 }
564
565 static void
566 tcp_local_address(const void *ctx, char *addr, size_t size)
567 {
568         const struct tcp_ctx *tctx = ctx;
569         struct sockaddr_storage sa;
570         socklen_t salen;
571
572         PJDLOG_ASSERT(tctx != NULL);
573         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
574
575         salen = sizeof(sa);
576         if (getsockname(tctx->tc_fd, (struct sockaddr *)&sa, &salen) == -1) {
577                 PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
578                 return;
579         }
580         PJDLOG_VERIFY(snprintf(addr, size, "tcp://%S", &sa) < (ssize_t)size);
581 }
582
583 static void
584 tcp_remote_address(const void *ctx, char *addr, size_t size)
585 {
586         const struct tcp_ctx *tctx = ctx;
587         struct sockaddr_storage sa;
588         socklen_t salen;
589
590         PJDLOG_ASSERT(tctx != NULL);
591         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
592
593         salen = sizeof(sa);
594         if (getpeername(tctx->tc_fd, (struct sockaddr *)&sa, &salen) == -1) {
595                 PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
596                 return;
597         }
598         PJDLOG_VERIFY(snprintf(addr, size, "tcp://%S", &sa) < (ssize_t)size);
599 }
600
601 static void
602 tcp_close(void *ctx)
603 {
604         struct tcp_ctx *tctx = ctx;
605
606         PJDLOG_ASSERT(tctx != NULL);
607         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
608
609         if (tctx->tc_fd >= 0)
610                 close(tctx->tc_fd);
611         tctx->tc_magic = 0;
612         free(tctx);
613 }
614
615 static struct proto tcp_proto = {
616         .prt_name = "tcp",
617         .prt_client = tcp_client,
618         .prt_connect = tcp_connect,
619         .prt_connect_wait = tcp_connect_wait,
620         .prt_server = tcp_server,
621         .prt_accept = tcp_accept,
622         .prt_wrap = tcp_wrap,
623         .prt_send = tcp_send,
624         .prt_recv = tcp_recv,
625         .prt_descriptor = tcp_descriptor,
626         .prt_address_match = tcp_address_match,
627         .prt_local_address = tcp_local_address,
628         .prt_remote_address = tcp_remote_address,
629         .prt_close = tcp_close
630 };
631
632 static __constructor void
633 tcp_ctor(void)
634 {
635
636         proto_register(&tcp_proto, true);
637 }