]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - sbin/hastd/proto_tcp.c
zfs: merge openzfs/zfs@e0bd8118d
[FreeBSD/FreeBSD.git] / sbin / hastd / proto_tcp.c
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause
3  *
4  * Copyright (c) 2009-2010 The FreeBSD Foundation
5  * Copyright (c) 2011 Pawel Jakub Dawidek <pawel@dawidek.net>
6  * All rights reserved.
7  *
8  * This software was developed by Pawel Jakub Dawidek under sponsorship from
9  * the FreeBSD Foundation.
10  *
11  * Redistribution and use in source and binary forms, with or without
12  * modification, are permitted provided that the following conditions
13  * are met:
14  * 1. Redistributions of source code must retain the above copyright
15  *    notice, this list of conditions and the following disclaimer.
16  * 2. Redistributions in binary form must reproduce the above copyright
17  *    notice, this list of conditions and the following disclaimer in the
18  *    documentation and/or other materials provided with the distribution.
19  *
20  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
21  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
24  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
26  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
27  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
28  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
29  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
30  * SUCH DAMAGE.
31  */
32
33 #include <sys/param.h>  /* MAXHOSTNAMELEN */
34 #include <sys/socket.h>
35
36 #include <arpa/inet.h>
37
38 #include <netinet/in.h>
39 #include <netinet/tcp.h>
40
41 #include <errno.h>
42 #include <fcntl.h>
43 #include <netdb.h>
44 #include <stdbool.h>
45 #include <stdint.h>
46 #include <stdio.h>
47 #include <string.h>
48 #include <unistd.h>
49
50 #include "pjdlog.h"
51 #include "proto_impl.h"
52 #include "subr.h"
53
54 #define TCP_CTX_MAGIC   0x7c41c
55 struct tcp_ctx {
56         int                     tc_magic;
57         struct sockaddr_storage tc_sa;
58         int                     tc_fd;
59         int                     tc_side;
60 #define TCP_SIDE_CLIENT         0
61 #define TCP_SIDE_SERVER_LISTEN  1
62 #define TCP_SIDE_SERVER_WORK    2
63 };
64
65 static int tcp_connect_wait(void *ctx, int timeout);
66 static void tcp_close(void *ctx);
67
68 /*
69  * Function converts the given string to unsigned number.
70  */
71 static int
72 numfromstr(const char *str, intmax_t minnum, intmax_t maxnum, intmax_t *nump)
73 {
74         intmax_t digit, num;
75
76         if (str[0] == '\0')
77                 goto invalid;   /* Empty string. */
78         num = 0;
79         for (; *str != '\0'; str++) {
80                 if (*str < '0' || *str > '9')
81                         goto invalid;   /* Non-digit character. */
82                 digit = *str - '0';
83                 if (num > num * 10 + digit)
84                         goto invalid;   /* Overflow. */
85                 num = num * 10 + digit;
86                 if (num > maxnum)
87                         goto invalid;   /* Too big. */
88         }
89         if (num < minnum)
90                 goto invalid;   /* Too small. */
91         *nump = num;
92         return (0);
93 invalid:
94         errno = EINVAL;
95         return (-1);
96 }
97
98 static int
99 tcp_addr(const char *addr, int defport, struct sockaddr_storage *sap)
100 {
101         char iporhost[MAXHOSTNAMELEN], portstr[6];
102         struct addrinfo hints;
103         struct addrinfo *res;
104         const char *pp;
105         intmax_t port;
106         size_t size;
107         int error;
108
109         if (addr == NULL)
110                 return (-1);
111
112         bzero(&hints, sizeof(hints));
113         hints.ai_flags = AI_ADDRCONFIG | AI_NUMERICSERV;
114         hints.ai_family = PF_UNSPEC;
115         hints.ai_socktype = SOCK_STREAM;
116         hints.ai_protocol = IPPROTO_TCP;
117
118         if (strncasecmp(addr, "tcp4://", 7) == 0) {
119                 addr += 7;
120                 hints.ai_family = PF_INET;
121         } else if (strncasecmp(addr, "tcp6://", 7) == 0) {
122                 addr += 7;
123                 hints.ai_family = PF_INET6;
124         } else if (strncasecmp(addr, "tcp://", 6) == 0) {
125                 addr += 6;
126         } else {
127                 /*
128                  * Because TCP is the default assume IP or host is given without
129                  * prefix.
130                  */
131         }
132
133         /*
134          * Extract optional port.
135          * There are three cases to consider.
136          * 1. hostname with port, eg. freefall.freebsd.org:8457
137          * 2. IPv4 address with port, eg. 192.168.0.101:8457
138          * 3. IPv6 address with port, eg. [fe80::1]:8457
139          * We discover IPv6 address by checking for two colons and if port is
140          * given, the address has to start with [.
141          */
142         pp = NULL;
143         if (strchr(addr, ':') != strrchr(addr, ':')) {
144                 if (addr[0] == '[')
145                         pp = strrchr(addr, ':');
146         } else {
147                 pp = strrchr(addr, ':');
148         }
149         if (pp == NULL) {
150                 /* Port not given, use the default. */
151                 port = defport;
152         } else {
153                 if (numfromstr(pp + 1, 1, 65535, &port) == -1)
154                         return (errno);
155         }
156         (void)snprintf(portstr, sizeof(portstr), "%jd", (intmax_t)port);
157         /* Extract host name or IP address. */
158         if (pp == NULL) {
159                 size = sizeof(iporhost);
160                 if (strlcpy(iporhost, addr, size) >= size)
161                         return (ENAMETOOLONG);
162         } else if (addr[0] == '[' && pp[-1] == ']') {
163                 size = (size_t)(pp - addr - 2 + 1);
164                 if (size > sizeof(iporhost))
165                         return (ENAMETOOLONG);
166                 (void)strlcpy(iporhost, addr + 1, size);
167         } else {
168                 size = (size_t)(pp - addr + 1);
169                 if (size > sizeof(iporhost))
170                         return (ENAMETOOLONG);
171                 (void)strlcpy(iporhost, addr, size);
172         }
173
174         error = getaddrinfo(iporhost, portstr, &hints, &res);
175         if (error != 0) {
176                 pjdlog_debug(1, "getaddrinfo(%s, %s) failed: %s.", iporhost,
177                     portstr, gai_strerror(error));
178                 return (EINVAL);
179         }
180         if (res == NULL)
181                 return (ENOENT);
182
183         memcpy(sap, res->ai_addr, res->ai_addrlen);
184
185         freeaddrinfo(res);
186
187         return (0);
188 }
189
190 static int
191 tcp_setup_new(const char *addr, int side, void **ctxp)
192 {
193         struct tcp_ctx *tctx;
194         int ret, nodelay;
195
196         PJDLOG_ASSERT(addr != NULL);
197         PJDLOG_ASSERT(side == TCP_SIDE_CLIENT ||
198             side == TCP_SIDE_SERVER_LISTEN);
199         PJDLOG_ASSERT(ctxp != NULL);
200
201         tctx = malloc(sizeof(*tctx));
202         if (tctx == NULL)
203                 return (errno);
204
205         /* Parse given address. */
206         if ((ret = tcp_addr(addr, PROTO_TCP_DEFAULT_PORT, &tctx->tc_sa)) != 0) {
207                 free(tctx);
208                 return (ret);
209         }
210
211         PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
212
213         tctx->tc_fd = socket(tctx->tc_sa.ss_family, SOCK_STREAM, 0);
214         if (tctx->tc_fd == -1) {
215                 ret = errno;
216                 free(tctx);
217                 return (ret);
218         }
219
220         PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
221
222         /* Socket settings. */
223         nodelay = 1;
224         if (setsockopt(tctx->tc_fd, IPPROTO_TCP, TCP_NODELAY, &nodelay,
225             sizeof(nodelay)) == -1) {
226                 pjdlog_errno(LOG_WARNING, "Unable to set TCP_NOELAY");
227         }
228
229         tctx->tc_side = side;
230         tctx->tc_magic = TCP_CTX_MAGIC;
231         *ctxp = tctx;
232
233         return (0);
234 }
235
236 static int
237 tcp_setup_wrap(int fd, int side, void **ctxp)
238 {
239         struct tcp_ctx *tctx;
240
241         PJDLOG_ASSERT(fd >= 0);
242         PJDLOG_ASSERT(side == TCP_SIDE_CLIENT ||
243             side == TCP_SIDE_SERVER_WORK);
244         PJDLOG_ASSERT(ctxp != NULL);
245
246         tctx = malloc(sizeof(*tctx));
247         if (tctx == NULL)
248                 return (errno);
249
250         tctx->tc_fd = fd;
251         tctx->tc_sa.ss_family = AF_UNSPEC;
252         tctx->tc_side = side;
253         tctx->tc_magic = TCP_CTX_MAGIC;
254         *ctxp = tctx;
255
256         return (0);
257 }
258
259 static int
260 tcp_client(const char *srcaddr, const char *dstaddr, void **ctxp)
261 {
262         struct tcp_ctx *tctx;
263         struct sockaddr_storage sa;
264         int ret;
265
266         ret = tcp_setup_new(dstaddr, TCP_SIDE_CLIENT, ctxp);
267         if (ret != 0)
268                 return (ret);
269         tctx = *ctxp;
270         if (srcaddr == NULL)
271                 return (0);
272         ret = tcp_addr(srcaddr, 0, &sa);
273         if (ret != 0) {
274                 tcp_close(tctx);
275                 return (ret);
276         }
277         if (bind(tctx->tc_fd, (struct sockaddr *)&sa, sa.ss_len) == -1) {
278                 ret = errno;
279                 tcp_close(tctx);
280                 return (ret);
281         }
282         return (0);
283 }
284
285 static int
286 tcp_connect(void *ctx, int timeout)
287 {
288         struct tcp_ctx *tctx = ctx;
289         int error, flags;
290
291         PJDLOG_ASSERT(tctx != NULL);
292         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
293         PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_CLIENT);
294         PJDLOG_ASSERT(tctx->tc_fd >= 0);
295         PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
296         PJDLOG_ASSERT(timeout >= -1);
297
298         flags = fcntl(tctx->tc_fd, F_GETFL);
299         if (flags == -1) {
300                 pjdlog_common(LOG_DEBUG, 1, errno, "fcntl(F_GETFL) failed");
301                 return (errno);
302         }
303         /*
304          * We make socket non-blocking so we can handle connection timeout
305          * manually.
306          */
307         flags |= O_NONBLOCK;
308         if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
309                 pjdlog_common(LOG_DEBUG, 1, errno,
310                     "fcntl(F_SETFL, O_NONBLOCK) failed");
311                 return (errno);
312         }
313
314         if (connect(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
315             tctx->tc_sa.ss_len) == 0) {
316                 if (timeout == -1)
317                         return (0);
318                 error = 0;
319                 goto done;
320         }
321         if (errno != EINPROGRESS) {
322                 error = errno;
323                 pjdlog_common(LOG_DEBUG, 1, errno, "connect() failed");
324                 goto done;
325         }
326         if (timeout == -1)
327                 return (0);
328         return (tcp_connect_wait(ctx, timeout));
329 done:
330         flags &= ~O_NONBLOCK;
331         if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
332                 if (error == 0)
333                         error = errno;
334                 pjdlog_common(LOG_DEBUG, 1, errno,
335                     "fcntl(F_SETFL, ~O_NONBLOCK) failed");
336         }
337         return (error);
338 }
339
340 static int
341 tcp_connect_wait(void *ctx, int timeout)
342 {
343         struct tcp_ctx *tctx = ctx;
344         struct timeval tv;
345         fd_set fdset;
346         socklen_t esize;
347         int error, flags, ret;
348
349         PJDLOG_ASSERT(tctx != NULL);
350         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
351         PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_CLIENT);
352         PJDLOG_ASSERT(tctx->tc_fd >= 0);
353         PJDLOG_ASSERT(timeout >= 0);
354
355         tv.tv_sec = timeout;
356         tv.tv_usec = 0;
357 again:
358         FD_ZERO(&fdset);
359         FD_SET(tctx->tc_fd, &fdset);
360         ret = select(tctx->tc_fd + 1, NULL, &fdset, NULL, &tv);
361         if (ret == 0) {
362                 error = ETIMEDOUT;
363                 goto done;
364         } else if (ret == -1) {
365                 if (errno == EINTR)
366                         goto again;
367                 error = errno;
368                 pjdlog_common(LOG_DEBUG, 1, errno, "select() failed");
369                 goto done;
370         }
371         PJDLOG_ASSERT(ret > 0);
372         PJDLOG_ASSERT(FD_ISSET(tctx->tc_fd, &fdset));
373         esize = sizeof(error);
374         if (getsockopt(tctx->tc_fd, SOL_SOCKET, SO_ERROR, &error,
375             &esize) == -1) {
376                 error = errno;
377                 pjdlog_common(LOG_DEBUG, 1, errno,
378                     "getsockopt(SO_ERROR) failed");
379                 goto done;
380         }
381         if (error != 0) {
382                 pjdlog_common(LOG_DEBUG, 1, error,
383                     "getsockopt(SO_ERROR) returned error");
384                 goto done;
385         }
386         error = 0;
387 done:
388         flags = fcntl(tctx->tc_fd, F_GETFL);
389         if (flags == -1) {
390                 if (error == 0)
391                         error = errno;
392                 pjdlog_common(LOG_DEBUG, 1, errno, "fcntl(F_GETFL) failed");
393                 return (error);
394         }
395         flags &= ~O_NONBLOCK;
396         if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
397                 if (error == 0)
398                         error = errno;
399                 pjdlog_common(LOG_DEBUG, 1, errno,
400                     "fcntl(F_SETFL, ~O_NONBLOCK) failed");
401         }
402         return (error);
403 }
404
405 static int
406 tcp_server(const char *addr, void **ctxp)
407 {
408         struct tcp_ctx *tctx;
409         int ret, val;
410
411         ret = tcp_setup_new(addr, TCP_SIDE_SERVER_LISTEN, ctxp);
412         if (ret != 0)
413                 return (ret);
414
415         tctx = *ctxp;
416
417         val = 1;
418         /* Ignore failure. */
419         (void)setsockopt(tctx->tc_fd, SOL_SOCKET, SO_REUSEADDR, &val,
420            sizeof(val));
421
422         PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
423
424         if (bind(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
425             tctx->tc_sa.ss_len) == -1) {
426                 ret = errno;
427                 tcp_close(tctx);
428                 return (ret);
429         }
430         if (listen(tctx->tc_fd, 8) == -1) {
431                 ret = errno;
432                 tcp_close(tctx);
433                 return (ret);
434         }
435
436         return (0);
437 }
438
439 static int
440 tcp_accept(void *ctx, void **newctxp)
441 {
442         struct tcp_ctx *tctx = ctx;
443         struct tcp_ctx *newtctx;
444         socklen_t fromlen;
445         int ret;
446
447         PJDLOG_ASSERT(tctx != NULL);
448         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
449         PJDLOG_ASSERT(tctx->tc_side == TCP_SIDE_SERVER_LISTEN);
450         PJDLOG_ASSERT(tctx->tc_fd >= 0);
451         PJDLOG_ASSERT(tctx->tc_sa.ss_family != AF_UNSPEC);
452
453         newtctx = malloc(sizeof(*newtctx));
454         if (newtctx == NULL)
455                 return (errno);
456
457         fromlen = tctx->tc_sa.ss_len;
458         newtctx->tc_fd = accept(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sa,
459             &fromlen);
460         if (newtctx->tc_fd == -1) {
461                 ret = errno;
462                 free(newtctx);
463                 return (ret);
464         }
465
466         newtctx->tc_side = TCP_SIDE_SERVER_WORK;
467         newtctx->tc_magic = TCP_CTX_MAGIC;
468         *newctxp = newtctx;
469
470         return (0);
471 }
472
473 static int
474 tcp_wrap(int fd, bool client, void **ctxp)
475 {
476
477         return (tcp_setup_wrap(fd,
478             client ? TCP_SIDE_CLIENT : TCP_SIDE_SERVER_WORK, ctxp));
479 }
480
481 static int
482 tcp_send(void *ctx, const unsigned char *data, size_t size, int fd)
483 {
484         struct tcp_ctx *tctx = ctx;
485
486         PJDLOG_ASSERT(tctx != NULL);
487         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
488         PJDLOG_ASSERT(tctx->tc_fd >= 0);
489         PJDLOG_ASSERT(fd == -1);
490
491         return (proto_common_send(tctx->tc_fd, data, size, -1));
492 }
493
494 static int
495 tcp_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
496 {
497         struct tcp_ctx *tctx = ctx;
498
499         PJDLOG_ASSERT(tctx != NULL);
500         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
501         PJDLOG_ASSERT(tctx->tc_fd >= 0);
502         PJDLOG_ASSERT(fdp == NULL);
503
504         return (proto_common_recv(tctx->tc_fd, data, size, NULL));
505 }
506
507 static int
508 tcp_descriptor(const void *ctx)
509 {
510         const struct tcp_ctx *tctx = ctx;
511
512         PJDLOG_ASSERT(tctx != NULL);
513         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
514
515         return (tctx->tc_fd);
516 }
517
518 static bool
519 tcp_address_match(const void *ctx, const char *addr)
520 {
521         const struct tcp_ctx *tctx = ctx;
522         struct sockaddr_storage sa1, sa2;
523         socklen_t salen;
524
525         PJDLOG_ASSERT(tctx != NULL);
526         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
527
528         if (tcp_addr(addr, PROTO_TCP_DEFAULT_PORT, &sa1) != 0)
529                 return (false);
530
531         salen = sizeof(sa2);
532         if (getpeername(tctx->tc_fd, (struct sockaddr *)&sa2, &salen) == -1)
533                 return (false);
534
535         if (sa1.ss_family != sa2.ss_family || sa1.ss_len != sa2.ss_len)
536                 return (false);
537
538         switch (sa1.ss_family) {
539         case AF_INET:
540             {
541                 struct sockaddr_in *sin1, *sin2;
542
543                 sin1 = (struct sockaddr_in *)&sa1;
544                 sin2 = (struct sockaddr_in *)&sa2;
545
546                 return (memcmp(&sin1->sin_addr, &sin2->sin_addr,
547                     sizeof(sin1->sin_addr)) == 0);
548             }
549         case AF_INET6:
550             {
551                 struct sockaddr_in6 *sin1, *sin2;
552
553                 sin1 = (struct sockaddr_in6 *)&sa1;
554                 sin2 = (struct sockaddr_in6 *)&sa2;
555
556                 return (memcmp(&sin1->sin6_addr, &sin2->sin6_addr,
557                     sizeof(sin1->sin6_addr)) == 0);
558             }
559         default:
560                 return (false);
561         }
562 }
563
564 static void
565 tcp_local_address(const void *ctx, char *addr, size_t size)
566 {
567         const struct tcp_ctx *tctx = ctx;
568         struct sockaddr_storage sa;
569         socklen_t salen;
570
571         PJDLOG_ASSERT(tctx != NULL);
572         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
573
574         salen = sizeof(sa);
575         if (getsockname(tctx->tc_fd, (struct sockaddr *)&sa, &salen) == -1) {
576                 PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
577                 return;
578         }
579         PJDLOG_VERIFY(snprintf(addr, size, "tcp://%S", &sa) < (ssize_t)size);
580 }
581
582 static void
583 tcp_remote_address(const void *ctx, char *addr, size_t size)
584 {
585         const struct tcp_ctx *tctx = ctx;
586         struct sockaddr_storage sa;
587         socklen_t salen;
588
589         PJDLOG_ASSERT(tctx != NULL);
590         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
591
592         salen = sizeof(sa);
593         if (getpeername(tctx->tc_fd, (struct sockaddr *)&sa, &salen) == -1) {
594                 PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
595                 return;
596         }
597         PJDLOG_VERIFY(snprintf(addr, size, "tcp://%S", &sa) < (ssize_t)size);
598 }
599
600 static void
601 tcp_close(void *ctx)
602 {
603         struct tcp_ctx *tctx = ctx;
604
605         PJDLOG_ASSERT(tctx != NULL);
606         PJDLOG_ASSERT(tctx->tc_magic == TCP_CTX_MAGIC);
607
608         if (tctx->tc_fd >= 0)
609                 close(tctx->tc_fd);
610         tctx->tc_magic = 0;
611         free(tctx);
612 }
613
614 static struct proto tcp_proto = {
615         .prt_name = "tcp",
616         .prt_client = tcp_client,
617         .prt_connect = tcp_connect,
618         .prt_connect_wait = tcp_connect_wait,
619         .prt_server = tcp_server,
620         .prt_accept = tcp_accept,
621         .prt_wrap = tcp_wrap,
622         .prt_send = tcp_send,
623         .prt_recv = tcp_recv,
624         .prt_descriptor = tcp_descriptor,
625         .prt_address_match = tcp_address_match,
626         .prt_local_address = tcp_local_address,
627         .prt_remote_address = tcp_remote_address,
628         .prt_close = tcp_close
629 };
630
631 static __constructor void
632 tcp_ctor(void)
633 {
634
635         proto_register(&tcp_proto, true);
636 }