]> CyberLeo.Net >> Repos - FreeBSD/releng/9.0.git/blob - sbin/hastd/proto_tcp.c
Copy stable/9 to releng/9.0 as part of the FreeBSD 9.0-RELEASE release
[FreeBSD/releng/9.0.git] / sbin / hastd / proto_tcp.c
1 /*-
2  * Copyright (c) 2009-2010 The FreeBSD Foundation
3  * Copyright (c) 2011 Pawel Jakub Dawidek <pawel@dawidek.net>
4  * All rights reserved.
5  *
6  * This software was developed by Pawel Jakub Dawidek under sponsorship from
7  * the FreeBSD Foundation.
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  * 1. Redistributions of source code must retain the above copyright
13  *    notice, this list of conditions and the following disclaimer.
14  * 2. Redistributions in binary form must reproduce the above copyright
15  *    notice, this list of conditions and the following disclaimer in the
16  *    documentation and/or other materials provided with the distribution.
17  *
18  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
19  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
22  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
24  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
25  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
26  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
27  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
28  * SUCH DAMAGE.
29  */
30
31 #include <sys/cdefs.h>
32 __FBSDID("$FreeBSD$");
33
34 #include <sys/param.h>  /* MAXHOSTNAMELEN */
35 #include <sys/socket.h>
36
37 #include <arpa/inet.h>
38
39 #include <netinet/in.h>
40 #include <netinet/tcp.h>
41
42 #include <errno.h>
43 #include <fcntl.h>
44 #include <netdb.h>
45 #include <stdbool.h>
46 #include <stdint.h>
47 #include <stdio.h>
48 #include <string.h>
49 #include <unistd.h>
50
51 #include "pjdlog.h"
52 #include "proto_impl.h"
53 #include "subr.h"
54
55 #define TCP_CTX_MAGIC   0x7c41c
56 struct tcp_ctx {
57         int                     tc_magic;
58         struct sockaddr_storage tc_sa;
59         int                     tc_fd;
60         int                     tc_side;
61 #define TCP_SIDE_CLIENT         0
62 #define TCP_SIDE_SERVER_LISTEN  1
63 #define TCP_SIDE_SERVER_WORK    2
64 };
65
66 static int tcp_connect_wait(void *ctx, int timeout);
67 static void tcp_close(void *ctx);
68
69 /*
70  * Function converts the given string to unsigned number.
71  */
72 static int
73 numfromstr(const char *str, intmax_t minnum, intmax_t maxnum, intmax_t *nump)
74 {
75         intmax_t digit, num;
76
77         if (str[0] == '\0')
78                 goto invalid;   /* Empty string. */
79         num = 0;
80         for (; *str != '\0'; str++) {
81                 if (*str < '0' || *str > '9')
82                         goto invalid;   /* Non-digit character. */
83                 digit = *str - '0';
84                 if (num > num * 10 + digit)
85                         goto invalid;   /* Overflow. */
86                 num = num * 10 + digit;
87                 if (num > maxnum)
88                         goto invalid;   /* Too big. */
89         }
90         if (num < minnum)
91                 goto invalid;   /* Too small. */
92         *nump = num;
93         return (0);
94 invalid:
95         errno = EINVAL;
96         return (-1);
97 }
98
99 static int
100 tcp_addr(const char *addr, int defport, struct sockaddr_storage *sap)
101 {
102         char iporhost[MAXHOSTNAMELEN], portstr[6];
103         struct addrinfo hints;
104         struct addrinfo *res;
105         const char *pp;
106         intmax_t port;
107         size_t size;
108         int error;
109
110         if (addr == NULL)
111                 return (-1);
112
113         bzero(&hints, sizeof(hints));
114         hints.ai_flags = AI_ADDRCONFIG | AI_NUMERICSERV;
115         hints.ai_family = PF_UNSPEC;
116         hints.ai_socktype = SOCK_STREAM;
117         hints.ai_protocol = IPPROTO_TCP;
118
119         if (strncasecmp(addr, "tcp4://", 7) == 0) {
120                 addr += 7;
121                 hints.ai_family = PF_INET;
122         } else if (strncasecmp(addr, "tcp6://", 7) == 0) {
123                 addr += 7;
124                 hints.ai_family = PF_INET6;
125         } else if (strncasecmp(addr, "tcp://", 6) == 0) {
126                 addr += 6;
127         } else {
128                 /*
129                  * Because TCP is the default assume IP or host is given without
130                  * prefix.
131                  */
132         }
133
134         /*
135          * Extract optional port.
136          * There are three cases to consider.
137          * 1. hostname with port, eg. freefall.freebsd.org:8457
138          * 2. IPv4 address with port, eg. 192.168.0.101:8457
139          * 3. IPv6 address with port, eg. [fe80::1]:8457
140          * We discover IPv6 address by checking for two colons and if port is
141          * given, the address has to start with [.
142          */
143         pp = NULL;
144         if (strchr(addr, ':') != strrchr(addr, ':')) {
145                 if (addr[0] == '[')
146                         pp = strrchr(addr, ':');
147         } else {
148                 pp = strrchr(addr, ':');
149         }
150         if (pp == NULL) {
151                 /* Port not given, use the default. */
152                 port = defport;
153         } else {
154                 if (numfromstr(pp + 1, 1, 65535, &port) < 0)
155                         return (errno);
156         }
157         (void)snprintf(portstr, sizeof(portstr), "%jd", (intmax_t)port);
158         /* Extract host name or IP address. */
159         if (pp == NULL) {
160                 size = sizeof(iporhost);
161                 if (strlcpy(iporhost, addr, size) >= size)
162                         return (ENAMETOOLONG);
163         } else if (addr[0] == '[' && pp[-1] == ']') {
164                 size = (size_t)(pp - addr - 2 + 1);
165                 if (size > sizeof(iporhost))
166                         return (ENAMETOOLONG);
167                 (void)strlcpy(iporhost, addr + 1, size);
168         } else {
169                 size = (size_t)(pp - addr + 1);
170                 if (size > sizeof(iporhost))
171                         return (ENAMETOOLONG);
172                 (void)strlcpy(iporhost, addr, size);
173         }
174
175         error = getaddrinfo(iporhost, portstr, &hints, &res);
176         if (error != 0) {
177                 pjdlog_debug(1, "getaddrinfo(%s, %s) failed: %s.", iporhost,
178                     portstr, gai_strerror(error));
179                 return (EINVAL);
180         }
181         if (res == NULL)
182                 return (ENOENT);
183
184         memcpy(sap, res->ai_addr, res->ai_addrlen);
185
186         freeaddrinfo(res);
187
188         return (0);
189 }
190
191 static int
192 tcp_setup_new(const char *addr, int side, void **ctxp)
193 {
194         struct tcp_ctx *tctx;
195         int ret, nodelay;
196
197         PJDLOG_ASSERT(addr != NULL);
198         PJDLOG_ASSERT(side == TCP_SIDE_CLIENT ||
199             side == TCP_SIDE_SERVER_LISTEN);
200         PJDLOG_ASSERT(ctxp != NULL);
201
202         tctx = malloc(sizeof(*tctx));
203         if (tctx == NULL)
204                 return (errno);
205
206         /* Parse given address. */
207         if ((ret = tcp_addr(addr, PROTO_TCP_DEFAULT_PORT, &tctx->tc_sa)) != 0) {
208                 free(tctx);
209                 return (ret);
210         }
211
212         PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
213
214         tctx->tc_fd = socket(tctx->tc_sa.ss_family, SOCK_STREAM, 0);
215         if (tctx->tc_fd == -1) {
216                 ret = errno;
217                 free(tctx);
218                 return (ret);
219         }
220
221         PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
222
223         /* Socket settings. */
224         nodelay = 1;
225         if (setsockopt(tctx->tc_fd, IPPROTO_TCP, TCP_NODELAY, &nodelay,
226             sizeof(nodelay)) == -1) {
227                 pjdlog_errno(LOG_WARNING, "Unable to set TCP_NOELAY");
228         }
229
230         tctx->tc_side = side;
231         tctx->tc_magic = TCP_CTX_MAGIC;
232         *ctxp = tctx;
233
234         return (0);
235 }
236
237 static int
238 tcp_setup_wrap(int fd, int side, void **ctxp)
239 {
240         struct tcp_ctx *tctx;
241
242         PJDLOG_ASSERT(fd >= 0);
243         PJDLOG_ASSERT(side == TCP_SIDE_CLIENT ||
244             side == TCP_SIDE_SERVER_WORK);
245         PJDLOG_ASSERT(ctxp != NULL);
246
247         tctx = malloc(sizeof(*tctx));
248         if (tctx == NULL)
249                 return (errno);
250
251         tctx->tc_fd = fd;
252         tctx->tc_sa.ss_family = AF_UNSPEC;
253         tctx->tc_side = side;
254         tctx->tc_magic = TCP_CTX_MAGIC;
255         *ctxp = tctx;
256
257         return (0);
258 }
259
260 static int
261 tcp_client(const char *srcaddr, const char *dstaddr, void **ctxp)
262 {
263         struct tcp_ctx *tctx;
264         struct sockaddr_storage sa;
265         int ret;
266
267         ret = tcp_setup_new(dstaddr, TCP_SIDE_CLIENT, ctxp);
268         if (ret != 0)
269                 return (ret);
270         tctx = *ctxp;
271         if (srcaddr == NULL)
272                 return (0);
273         ret = tcp_addr(srcaddr, 0, &sa);
274         if (ret != 0) {
275                 tcp_close(tctx);
276                 return (ret);
277         }
278         if (bind(tctx->tc_fd, (struct sockaddr *)&sa, sa.ss_len) < 0) {
279                 ret = errno;
280                 tcp_close(tctx);
281                 return (ret);
282         }
283         return (0);
284 }
285
286 static int
287 tcp_connect(void *ctx, int timeout)
288 {
289         struct tcp_ctx *tctx = ctx;
290         int error, flags;
291
292         PJDLOG_ASSERT(tctx != NULL);
293         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
294         PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_CLIENT);
295         PJDLOG_ASSERT(tctx->tc_fd >= 0);
296         PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
297         PJDLOG_ASSERT(timeout >= -1);
298
299         flags = fcntl(tctx->tc_fd, F_GETFL);
300         if (flags == -1) {
301                 KEEP_ERRNO(pjdlog_common(LOG_DEBUG, 1, errno,
302                     "fcntl(F_GETFL) failed"));
303                 return (errno);
304         }
305         /*
306          * We make socket non-blocking so we can handle connection timeout
307          * manually.
308          */
309         flags |= O_NONBLOCK;
310         if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
311                 KEEP_ERRNO(pjdlog_common(LOG_DEBUG, 1, errno,
312                     "fcntl(F_SETFL, O_NONBLOCK) failed"));
313                 return (errno);
314         }
315
316         if (connect(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
317             tctx->tc_sa.ss_len) == 0) {
318                 if (timeout == -1)
319                         return (0);
320                 error = 0;
321                 goto done;
322         }
323         if (errno != EINPROGRESS) {
324                 error = errno;
325                 pjdlog_common(LOG_DEBUG, 1, errno, "connect() failed");
326                 goto done;
327         }
328         if (timeout == -1)
329                 return (0);
330         return (tcp_connect_wait(ctx, timeout));
331 done:
332         flags &= ~O_NONBLOCK;
333         if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
334                 if (error == 0)
335                         error = errno;
336                 pjdlog_common(LOG_DEBUG, 1, errno,
337                     "fcntl(F_SETFL, ~O_NONBLOCK) failed");
338         }
339         return (error);
340 }
341
342 static int
343 tcp_connect_wait(void *ctx, int timeout)
344 {
345         struct tcp_ctx *tctx = ctx;
346         struct timeval tv;
347         fd_set fdset;
348         socklen_t esize;
349         int error, flags, ret;
350
351         PJDLOG_ASSERT(tctx != NULL);
352         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
353         PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_CLIENT);
354         PJDLOG_ASSERT(tctx->tc_fd >= 0);
355         PJDLOG_ASSERT(timeout >= 0);
356
357         tv.tv_sec = timeout;
358         tv.tv_usec = 0;
359 again:
360         FD_ZERO(&fdset);
361         FD_SET(tctx->tc_fd, &fdset);
362         ret = select(tctx->tc_fd + 1, NULL, &fdset, NULL, &tv);
363         if (ret == 0) {
364                 error = ETIMEDOUT;
365                 goto done;
366         } else if (ret == -1) {
367                 if (errno == EINTR)
368                         goto again;
369                 error = errno;
370                 pjdlog_common(LOG_DEBUG, 1, errno, "select() failed");
371                 goto done;
372         }
373         PJDLOG_ASSERT(ret > 0);
374         PJDLOG_ASSERT(FD_ISSET(tctx->tc_fd, &fdset));
375         esize = sizeof(error);
376         if (getsockopt(tctx->tc_fd, SOL_SOCKET, SO_ERROR, &error,
377             &esize) == -1) {
378                 error = errno;
379                 pjdlog_common(LOG_DEBUG, 1, errno,
380                     "getsockopt(SO_ERROR) failed");
381                 goto done;
382         }
383         if (error != 0) {
384                 pjdlog_common(LOG_DEBUG, 1, error,
385                     "getsockopt(SO_ERROR) returned error");
386                 goto done;
387         }
388         error = 0;
389 done:
390         flags = fcntl(tctx->tc_fd, F_GETFL);
391         if (flags == -1) {
392                 if (error == 0)
393                         error = errno;
394                 pjdlog_common(LOG_DEBUG, 1, errno, "fcntl(F_GETFL) failed");
395                 return (error);
396         }
397         flags &= ~O_NONBLOCK;
398         if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
399                 if (error == 0)
400                         error = errno;
401                 pjdlog_common(LOG_DEBUG, 1, errno,
402                     "fcntl(F_SETFL, ~O_NONBLOCK) failed");
403         }
404         return (error);
405 }
406
407 static int
408 tcp_server(const char *addr, void **ctxp)
409 {
410         struct tcp_ctx *tctx;
411         int ret, val;
412
413         ret = tcp_setup_new(addr, TCP_SIDE_SERVER_LISTEN, ctxp);
414         if (ret != 0)
415                 return (ret);
416
417         tctx = *ctxp;
418
419         val = 1;
420         /* Ignore failure. */
421         (void)setsockopt(tctx->tc_fd, SOL_SOCKET, SO_REUSEADDR, &val,
422            sizeof(val));
423
424         PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
425
426         if (bind(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
427             tctx->tc_sa.ss_len) < 0) {
428                 ret = errno;
429                 tcp_close(tctx);
430                 return (ret);
431         }
432         if (listen(tctx->tc_fd, 8) < 0) {
433                 ret = errno;
434                 tcp_close(tctx);
435                 return (ret);
436         }
437
438         return (0);
439 }
440
441 static int
442 tcp_accept(void *ctx, void **newctxp)
443 {
444         struct tcp_ctx *tctx = ctx;
445         struct tcp_ctx *newtctx;
446         socklen_t fromlen;
447         int ret;
448
449         PJDLOG_ASSERT(tctx != NULL);
450         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
451         PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_SERVER_LISTEN);
452         PJDLOG_ASSERT(tctx->tc_fd >= 0);
453         PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
454
455         newtctx = malloc(sizeof(*newtctx));
456         if (newtctx == NULL)
457                 return (errno);
458
459         fromlen = tctx->tc_sa.ss_len;
460         newtctx->tc_fd = accept(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
461             &fromlen);
462         if (newtctx->tc_fd < 0) {
463                 ret = errno;
464                 free(newtctx);
465                 return (ret);
466         }
467
468         newtctx->tc_side = TCP_SIDE_SERVER_WORK;
469         newtctx->tc_magic = TCP_CTX_MAGIC;
470         *newctxp = newtctx;
471
472         return (0);
473 }
474
475 static int
476 tcp_wrap(int fd, bool client, void **ctxp)
477 {
478
479         return (tcp_setup_wrap(fd,
480             client ? TCP_SIDE_CLIENT : TCP_SIDE_SERVER_WORK, ctxp));
481 }
482
483 static int
484 tcp_send(void *ctx, const unsigned char *data, size_t size, int fd)
485 {
486         struct tcp_ctx *tctx = ctx;
487
488         PJDLOG_ASSERT(tctx != NULL);
489         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
490         PJDLOG_ASSERT(tctx->tc_fd >= 0);
491         PJDLOG_ASSERT(fd == -1);
492
493         return (proto_common_send(tctx->tc_fd, data, size, -1));
494 }
495
496 static int
497 tcp_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
498 {
499         struct tcp_ctx *tctx = ctx;
500
501         PJDLOG_ASSERT(tctx != NULL);
502         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
503         PJDLOG_ASSERT(tctx->tc_fd >= 0);
504         PJDLOG_ASSERT(fdp == NULL);
505
506         return (proto_common_recv(tctx->tc_fd, data, size, NULL));
507 }
508
509 static int
510 tcp_descriptor(const void *ctx)
511 {
512         const struct tcp_ctx *tctx = ctx;
513
514         PJDLOG_ASSERT(tctx != NULL);
515         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
516
517         return (tctx->tc_fd);
518 }
519
520 static bool
521 tcp_address_match(const void *ctx, const char *addr)
522 {
523         const struct tcp_ctx *tctx = ctx;
524         struct sockaddr_storage sa1, sa2;
525         socklen_t salen;
526
527         PJDLOG_ASSERT(tctx != NULL);
528         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
529
530         if (tcp_addr(addr, PROTO_TCP_DEFAULT_PORT, &sa1) != 0)
531                 return (false);
532
533         salen = sizeof(sa2);
534         if (getpeername(tctx->tc_fd, (struct sockaddr *)&sa2, &salen) < 0)
535                 return (false);
536
537         if (sa1.ss_family != sa2.ss_family || sa1.ss_len != sa2.ss_len)
538                 return (false);
539
540         switch (sa1.ss_family) {
541         case AF_INET:
542             {
543                 struct sockaddr_in *sin1, *sin2;
544
545                 sin1 = (struct sockaddr_in *)&sa1;
546                 sin2 = (struct sockaddr_in *)&sa2;
547
548                 return (memcmp(&sin1->sin_addr, &sin2->sin_addr,
549                     sizeof(sin1->sin_addr)) == 0);
550             }
551         case AF_INET6:
552             {
553                 struct sockaddr_in6 *sin1, *sin2;
554
555                 sin1 = (struct sockaddr_in6 *)&sa1;
556                 sin2 = (struct sockaddr_in6 *)&sa2;
557
558                 return (memcmp(&sin1->sin6_addr, &sin2->sin6_addr,
559                     sizeof(sin1->sin6_addr)) == 0);
560             }
561         default:
562                 return (false);
563         }
564 }
565
566 static void
567 tcp_local_address(const void *ctx, char *addr, size_t size)
568 {
569         const struct tcp_ctx *tctx = ctx;
570         struct sockaddr_storage sa;
571         socklen_t salen;
572
573         PJDLOG_ASSERT(tctx != NULL);
574         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
575
576         salen = sizeof(sa);
577         if (getsockname(tctx->tc_fd, (struct sockaddr *)&sa, &salen) < 0) {
578                 PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
579                 return;
580         }
581         PJDLOG_VERIFY(snprintf(addr, size, "tcp://%S", &sa) < (ssize_t)size);
582 }
583
584 static void
585 tcp_remote_address(const void *ctx, char *addr, size_t size)
586 {
587         const struct tcp_ctx *tctx = ctx;
588         struct sockaddr_storage sa;
589         socklen_t salen;
590
591         PJDLOG_ASSERT(tctx != NULL);
592         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
593
594         salen = sizeof(sa);
595         if (getpeername(tctx->tc_fd, (struct sockaddr *)&sa, &salen) < 0) {
596                 PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
597                 return;
598         }
599         PJDLOG_VERIFY(snprintf(addr, size, "tcp://%S", &sa) < (ssize_t)size);
600 }
601
602 static void
603 tcp_close(void *ctx)
604 {
605         struct tcp_ctx *tctx = ctx;
606
607         PJDLOG_ASSERT(tctx != NULL);
608         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
609
610         if (tctx->tc_fd >= 0)
611                 close(tctx->tc_fd);
612         tctx->tc_magic = 0;
613         free(tctx);
614 }
615
616 static struct proto tcp_proto = {
617         .prt_name = "tcp",
618         .prt_client = tcp_client,
619         .prt_connect = tcp_connect,
620         .prt_connect_wait = tcp_connect_wait,
621         .prt_server = tcp_server,
622         .prt_accept = tcp_accept,
623         .prt_wrap = tcp_wrap,
624         .prt_send = tcp_send,
625         .prt_recv = tcp_recv,
626         .prt_descriptor = tcp_descriptor,
627         .prt_address_match = tcp_address_match,
628         .prt_local_address = tcp_local_address,
629         .prt_remote_address = tcp_remote_address,
630         .prt_close = tcp_close
631 };
632
633 static __constructor void
634 tcp_ctor(void)
635 {
636
637         proto_register(&tcp_proto, true);
638 }