]> CyberLeo.Net >> Repos - FreeBSD/releng/8.1.git/blob - sbin/hastd/proto_tcp4.c
Copy stable/8 to releng/8.1 in preparation for 8.1-RC1.
[FreeBSD/releng/8.1.git] / sbin / hastd / proto_tcp4.c
1 /*-
2  * Copyright (c) 2009-2010 The FreeBSD Foundation
3  * All rights reserved.
4  *
5  * This software was developed by Pawel Jakub Dawidek under sponsorship from
6  * the FreeBSD Foundation.
7  *
8  * Redistribution and use in source and binary forms, with or without
9  * modification, are permitted provided that the following conditions
10  * are met:
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in the
15  *    documentation and/or other materials provided with the distribution.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
18  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
21  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27  * SUCH DAMAGE.
28  */
29
30 #include <sys/cdefs.h>
31 __FBSDID("$FreeBSD$");
32
33 #include <sys/param.h>  /* MAXHOSTNAMELEN */
34
35 #include <netinet/in.h>
36 #include <netinet/tcp.h>
37
38 #include <assert.h>
39 #include <errno.h>
40 #include <fcntl.h>
41 #include <netdb.h>
42 #include <stdbool.h>
43 #include <stdint.h>
44 #include <stdio.h>
45 #include <string.h>
46 #include <unistd.h>
47
48 #include "hast.h"
49 #include "pjdlog.h"
50 #include "proto_impl.h"
51 #include "subr.h"
52
53 #define TCP4_CTX_MAGIC  0x7c441c
54 struct tcp4_ctx {
55         int                     tc_magic;
56         struct sockaddr_in      tc_sin;
57         int                     tc_fd;
58         int                     tc_side;
59 #define TCP4_SIDE_CLIENT        0
60 #define TCP4_SIDE_SERVER_LISTEN 1
61 #define TCP4_SIDE_SERVER_WORK   2
62 };
63
64 static void tcp4_close(void *ctx);
65
66 static in_addr_t
67 str2ip(const char *str)
68 {
69         struct hostent *hp;
70         in_addr_t ip;
71
72         ip = inet_addr(str);
73         if (ip != INADDR_NONE) {
74                 /* It is a valid IP address. */
75                 return (ip);
76         }
77         /* Check if it is a valid host name. */
78         hp = gethostbyname(str);
79         if (hp == NULL)
80                 return (INADDR_NONE);
81         return (((struct in_addr *)(void *)hp->h_addr)->s_addr);
82 }
83
84 /*
85  * Function converts the given string to unsigned number.
86  */
87 static int
88 numfromstr(const char *str, intmax_t minnum, intmax_t maxnum, intmax_t *nump)
89 {
90         intmax_t digit, num;
91
92         if (str[0] == '\0')
93                 goto invalid;   /* Empty string. */
94         num = 0;
95         for (; *str != '\0'; str++) {
96                 if (*str < '0' || *str > '9')
97                         goto invalid;   /* Non-digit character. */
98                 digit = *str - '0';
99                 if (num > num * 10 + digit)
100                         goto invalid;   /* Overflow. */
101                 num = num * 10 + digit;
102                 if (num > maxnum)
103                         goto invalid;   /* Too big. */
104         }
105         if (num < minnum)
106                 goto invalid;   /* Too small. */
107         *nump = num;
108         return (0);
109 invalid:
110         errno = EINVAL;
111         return (-1);
112 }
113
114 static int
115 tcp4_addr(const char *addr, struct sockaddr_in *sinp)
116 {
117         char iporhost[MAXHOSTNAMELEN];
118         const char *pp;
119         size_t size;
120         in_addr_t ip;
121
122         if (addr == NULL)
123                 return (-1);
124
125         if (strncasecmp(addr, "tcp4://", 7) == 0)
126                 addr += 7;
127         else if (strncasecmp(addr, "tcp://", 6) == 0)
128                 addr += 6;
129         else if (addr[0] != '/' &&      /* If this is not path... */
130             strstr(addr, "://") == NULL)/* ...and has no prefix... */
131                 ;                       /* ...tcp4 is the default. */
132         else
133                 return (-1);
134
135         sinp->sin_family = AF_INET;
136         sinp->sin_len = sizeof(*sinp);
137         /* Extract optional port. */
138         pp = strrchr(addr, ':');
139         if (pp == NULL) {
140                 /* Port not given, use the default. */
141                 sinp->sin_port = htons(HASTD_PORT);
142         } else {
143                 intmax_t port;
144
145                 if (numfromstr(pp + 1, 1, 65535, &port) < 0)
146                         return (errno);
147                 sinp->sin_port = htons(port);
148         }
149         /* Extract host name or IP address. */
150         if (pp == NULL) {
151                 size = sizeof(iporhost);
152                 if (strlcpy(iporhost, addr, size) >= size)
153                         return (ENAMETOOLONG);
154         } else {
155                 size = (size_t)(pp - addr + 1);
156                 if (size > sizeof(iporhost))
157                         return (ENAMETOOLONG);
158                 strlcpy(iporhost, addr, size);
159         }
160         /* Convert string (IP address or host name) to in_addr_t. */
161         ip = str2ip(iporhost);
162         if (ip == INADDR_NONE)
163                 return (EINVAL);
164         sinp->sin_addr.s_addr = ip;
165
166         return (0);
167 }
168
169 static int
170 tcp4_common_setup(const char *addr, void **ctxp, int side)
171 {
172         struct tcp4_ctx *tctx;
173         int ret, val;
174
175         tctx = malloc(sizeof(*tctx));
176         if (tctx == NULL)
177                 return (errno);
178
179         /* Parse given address. */
180         if ((ret = tcp4_addr(addr, &tctx->tc_sin)) != 0) {
181                 free(tctx);
182                 return (ret);
183         }
184
185         tctx->tc_fd = socket(AF_INET, SOCK_STREAM, 0);
186         if (tctx->tc_fd == -1) {
187                 ret = errno;
188                 free(tctx);
189                 return (ret);
190         }
191
192         /* Socket settings. */
193         val = 1;
194         if (setsockopt(tctx->tc_fd, IPPROTO_TCP, TCP_NODELAY, &val,
195             sizeof(val)) == -1) {
196                 pjdlog_warning("Unable to set TCP_NOELAY on %s", addr);
197         }
198         val = 131072;
199         if (setsockopt(tctx->tc_fd, SOL_SOCKET, SO_SNDBUF, &val,
200             sizeof(val)) == -1) {
201                 pjdlog_warning("Unable to set send buffer size on %s", addr);
202         }
203         val = 131072;
204         if (setsockopt(tctx->tc_fd, SOL_SOCKET, SO_RCVBUF, &val,
205             sizeof(val)) == -1) {
206                 pjdlog_warning("Unable to set receive buffer size on %s", addr);
207         }
208
209         tctx->tc_side = side;
210         tctx->tc_magic = TCP4_CTX_MAGIC;
211         *ctxp = tctx;
212
213         return (0);
214 }
215
216 static int
217 tcp4_client(const char *addr, void **ctxp)
218 {
219
220         return (tcp4_common_setup(addr, ctxp, TCP4_SIDE_CLIENT));
221 }
222
223 static int
224 tcp4_connect(void *ctx)
225 {
226         struct tcp4_ctx *tctx = ctx;
227         struct timeval tv;
228         fd_set fdset;
229         socklen_t esize;
230         int error, flags, ret;
231
232         assert(tctx != NULL);
233         assert(tctx->tc_magic == TCP4_CTX_MAGIC);
234         assert(tctx->tc_side == TCP4_SIDE_CLIENT);
235         assert(tctx->tc_fd >= 0);
236
237         flags = fcntl(tctx->tc_fd, F_GETFL);
238         if (flags == -1) {
239                 KEEP_ERRNO(pjdlog_common(LOG_DEBUG, 1, errno,
240                     "fcntl(F_GETFL) failed"));
241                 return (errno);
242         }
243         /*
244          * We make socket non-blocking so we have decided about connection
245          * timeout.
246          */
247         flags |= O_NONBLOCK;
248         if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
249                 KEEP_ERRNO(pjdlog_common(LOG_DEBUG, 1, errno,
250                     "fcntl(F_SETFL, O_NONBLOCK) failed"));
251                 return (errno);
252         }
253
254         if (connect(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sin,
255             sizeof(tctx->tc_sin)) == 0) {
256                 error = 0;
257                 goto done;
258         }
259         if (errno != EINPROGRESS) {
260                 error = errno;
261                 pjdlog_common(LOG_DEBUG, 1, errno, "connect() failed");
262                 goto done;
263         }
264         /*
265          * Connection can't be established immediately, let's wait
266          * for HAST_TIMEOUT seconds.
267          */
268         tv.tv_sec = HAST_TIMEOUT;
269         tv.tv_usec = 0;
270 again:
271         FD_ZERO(&fdset);
272         FD_SET(tctx->tc_fd, &fdset); 
273         ret = select(tctx->tc_fd + 1, NULL, &fdset, NULL, &tv);
274         if (ret == 0) {
275                 error = ETIMEDOUT;
276                 goto done;
277         } else if (ret == -1) {
278                 if (errno == EINTR)
279                         goto again;
280                 error = errno;
281                 pjdlog_common(LOG_DEBUG, 1, errno, "select() failed");
282                 goto done;
283         }
284         assert(ret > 0);
285         assert(FD_ISSET(tctx->tc_fd, &fdset));
286         esize = sizeof(error);
287         if (getsockopt(tctx->tc_fd, SOL_SOCKET, SO_ERROR, &error,
288             &esize) == -1) {
289                 error = errno;
290                 pjdlog_common(LOG_DEBUG, 1, errno,
291                     "getsockopt(SO_ERROR) failed");
292                 goto done;
293         }
294         if (error != 0) {
295                 pjdlog_common(LOG_DEBUG, 1, error,
296                     "getsockopt(SO_ERROR) returned error");
297                 goto done;
298         }
299         error = 0;
300 done:
301         flags &= ~O_NONBLOCK;
302         if (fcntl(tctx->tc_fd, F_SETFL, flags) == -1) {
303                 if (error == 0)
304                         error = errno;
305                 pjdlog_common(LOG_DEBUG, 1, errno,
306                     "fcntl(F_SETFL, ~O_NONBLOCK) failed");
307         }
308         return (error);
309 }
310
311 static int
312 tcp4_server(const char *addr, void **ctxp)
313 {
314         struct tcp4_ctx *tctx;
315         int ret, val;
316
317         ret = tcp4_common_setup(addr, ctxp, TCP4_SIDE_SERVER_LISTEN);
318         if (ret != 0)
319                 return (ret);
320
321         tctx = *ctxp;
322
323         val = 1;
324         /* Ignore failure. */
325         (void)setsockopt(tctx->tc_fd, SOL_SOCKET, SO_REUSEADDR, &val,
326            sizeof(val));
327
328         if (bind(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sin,
329             sizeof(tctx->tc_sin)) < 0) {
330                 ret = errno;
331                 tcp4_close(tctx);
332                 return (ret);
333         }
334         if (listen(tctx->tc_fd, 8) < 0) {
335                 ret = errno;
336                 tcp4_close(tctx);
337                 return (ret);
338         }
339
340         return (0);
341 }
342
343 static int
344 tcp4_accept(void *ctx, void **newctxp)
345 {
346         struct tcp4_ctx *tctx = ctx;
347         struct tcp4_ctx *newtctx;
348         socklen_t fromlen;
349         int ret;
350
351         assert(tctx != NULL);
352         assert(tctx->tc_magic == TCP4_CTX_MAGIC);
353         assert(tctx->tc_side == TCP4_SIDE_SERVER_LISTEN);
354         assert(tctx->tc_fd >= 0);
355
356         newtctx = malloc(sizeof(*newtctx));
357         if (newtctx == NULL)
358                 return (errno);
359
360         fromlen = sizeof(tctx->tc_sin);
361         newtctx->tc_fd = accept(tctx->tc_fd, (struct sockaddr *)&tctx->tc_sin,
362             &fromlen);
363         if (newtctx->tc_fd < 0) {
364                 ret = errno;
365                 free(newtctx);
366                 return (ret);
367         }
368
369         newtctx->tc_side = TCP4_SIDE_SERVER_WORK;
370         newtctx->tc_magic = TCP4_CTX_MAGIC;
371         *newctxp = newtctx;
372
373         return (0);
374 }
375
376 static int
377 tcp4_send(void *ctx, const unsigned char *data, size_t size)
378 {
379         struct tcp4_ctx *tctx = ctx;
380
381         assert(tctx != NULL);
382         assert(tctx->tc_magic == TCP4_CTX_MAGIC);
383         assert(tctx->tc_fd >= 0);
384
385         return (proto_common_send(tctx->tc_fd, data, size));
386 }
387
388 static int
389 tcp4_recv(void *ctx, unsigned char *data, size_t size)
390 {
391         struct tcp4_ctx *tctx = ctx;
392
393         assert(tctx != NULL);
394         assert(tctx->tc_magic == TCP4_CTX_MAGIC);
395         assert(tctx->tc_fd >= 0);
396
397         return (proto_common_recv(tctx->tc_fd, data, size));
398 }
399
400 static int
401 tcp4_descriptor(const void *ctx)
402 {
403         const struct tcp4_ctx *tctx = ctx;
404
405         assert(tctx != NULL);
406         assert(tctx->tc_magic == TCP4_CTX_MAGIC);
407
408         return (tctx->tc_fd);
409 }
410
411 static void
412 sin2str(struct sockaddr_in *sinp, char *addr, size_t size)
413 {
414         in_addr_t ip;
415         unsigned int port;
416
417         assert(addr != NULL);
418         assert(sinp->sin_family == AF_INET);
419
420         ip = ntohl(sinp->sin_addr.s_addr);
421         port = ntohs(sinp->sin_port);
422         snprintf(addr, size, "tcp4://%u.%u.%u.%u:%u", ((ip >> 24) & 0xff),
423             ((ip >> 16) & 0xff), ((ip >> 8) & 0xff), (ip & 0xff), port);
424 }
425
426 static bool
427 tcp4_address_match(const void *ctx, const char *addr)
428 {
429         const struct tcp4_ctx *tctx = ctx;
430         struct sockaddr_in sin;
431         socklen_t sinlen;
432         in_addr_t ip1, ip2;
433
434         assert(tctx != NULL);
435         assert(tctx->tc_magic == TCP4_CTX_MAGIC);
436
437         if (tcp4_addr(addr, &sin) != 0)
438                 return (false);
439         ip1 = sin.sin_addr.s_addr;
440
441         sinlen = sizeof(sin);
442         if (getpeername(tctx->tc_fd, (struct sockaddr *)&sin, &sinlen) < 0)
443                 return (false);
444         ip2 = sin.sin_addr.s_addr;
445
446         return (ip1 == ip2);
447 }
448
449 static void
450 tcp4_local_address(const void *ctx, char *addr, size_t size)
451 {
452         const struct tcp4_ctx *tctx = ctx;
453         struct sockaddr_in sin;
454         socklen_t sinlen;
455
456         assert(tctx != NULL);
457         assert(tctx->tc_magic == TCP4_CTX_MAGIC);
458
459         sinlen = sizeof(sin);
460         if (getsockname(tctx->tc_fd, (struct sockaddr *)&sin, &sinlen) < 0) {
461                 strlcpy(addr, "N/A", size);
462                 return;
463         }
464         sin2str(&sin, addr, size);
465 }
466
467 static void
468 tcp4_remote_address(const void *ctx, char *addr, size_t size)
469 {
470         const struct tcp4_ctx *tctx = ctx;
471         struct sockaddr_in sin;
472         socklen_t sinlen;
473
474         assert(tctx != NULL);
475         assert(tctx->tc_magic == TCP4_CTX_MAGIC);
476
477         sinlen = sizeof(sin);
478         if (getpeername(tctx->tc_fd, (struct sockaddr *)&sin, &sinlen) < 0) {
479                 strlcpy(addr, "N/A", size);
480                 return;
481         }
482         sin2str(&sin, addr, size);
483 }
484
485 static void
486 tcp4_close(void *ctx)
487 {
488         struct tcp4_ctx *tctx = ctx;
489
490         assert(tctx != NULL);
491         assert(tctx->tc_magic == TCP4_CTX_MAGIC);
492
493         if (tctx->tc_fd >= 0)
494                 close(tctx->tc_fd);
495         tctx->tc_magic = 0;
496         free(tctx);
497 }
498
499 static struct hast_proto tcp4_proto = {
500         .hp_name = "tcp4",
501         .hp_client = tcp4_client,
502         .hp_connect = tcp4_connect,
503         .hp_server = tcp4_server,
504         .hp_accept = tcp4_accept,
505         .hp_send = tcp4_send,
506         .hp_recv = tcp4_recv,
507         .hp_descriptor = tcp4_descriptor,
508         .hp_address_match = tcp4_address_match,
509         .hp_local_address = tcp4_local_address,
510         .hp_remote_address = tcp4_remote_address,
511         .hp_close = tcp4_close
512 };
513
514 static __constructor void
515 tcp4_ctor(void)
516 {
517
518         proto_register(&tcp4_proto);
519 }