]> CyberLeo.Net >> Repos - FreeBSD/releng/10.0.git/blob - contrib/openbsm/bin/auditdistd/proto_tls.c
- Copy stable/10 (r259064) to releng/10.0 as part of the
[FreeBSD/releng/10.0.git] / contrib / openbsm / bin / auditdistd / proto_tls.c
1 /*-
2  * Copyright (c) 2011 The FreeBSD Foundation
3  * All rights reserved.
4  *
5  * This software was developed by Pawel Jakub Dawidek under sponsorship from
6  * the FreeBSD Foundation.
7  *
8  * Redistribution and use in source and binary forms, with or without
9  * modification, are permitted provided that the following conditions
10  * are met:
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in the
15  *    documentation and/or other materials provided with the distribution.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
18  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
21  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27  * SUCH DAMAGE.
28  *
29  * $P4: //depot/projects/trustedbsd/openbsm/bin/auditdistd/proto_tls.c#2 $
30  */
31
32 #include <config/config.h>
33
34 #include <sys/param.h>  /* MAXHOSTNAMELEN */
35 #include <sys/socket.h>
36
37 #include <arpa/inet.h>
38
39 #include <netinet/in.h>
40 #include <netinet/tcp.h>
41
42 #include <errno.h>
43 #include <fcntl.h>
44 #include <netdb.h>
45 #include <signal.h>
46 #include <stdbool.h>
47 #include <stdint.h>
48 #include <stdio.h>
49 #include <string.h>
50 #include <unistd.h>
51
52 #include <openssl/err.h>
53 #include <openssl/ssl.h>
54
55 #include <compat/compat.h>
56 #ifndef HAVE_CLOSEFROM
57 #include <compat/closefrom.h>
58 #endif
59 #ifndef HAVE_STRLCPY
60 #include <compat/strlcpy.h>
61 #endif
62
63 #include "pjdlog.h"
64 #include "proto_impl.h"
65 #include "sandbox.h"
66 #include "subr.h"
67
68 #define TLS_CTX_MAGIC   0x715c7
69 struct tls_ctx {
70         int             tls_magic;
71         struct proto_conn *tls_sock;
72         struct proto_conn *tls_tcp;
73         char            tls_laddr[256];
74         char            tls_raddr[256];
75         int             tls_side;
76 #define TLS_SIDE_CLIENT         0
77 #define TLS_SIDE_SERVER_LISTEN  1
78 #define TLS_SIDE_SERVER_WORK    2
79         bool            tls_wait_called;
80 };
81
82 #define TLS_DEFAULT_TIMEOUT     30
83
84 static int tls_connect_wait(void *ctx, int timeout);
85 static void tls_close(void *ctx);
86
87 static void
88 block(int fd)
89 {
90         int flags;
91
92         flags = fcntl(fd, F_GETFL);
93         if (flags == -1)
94                 pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed");
95         flags &= ~O_NONBLOCK;
96         if (fcntl(fd, F_SETFL, flags) == -1)
97                 pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed");
98 }
99
100 static void
101 nonblock(int fd)
102 {
103         int flags;
104
105         flags = fcntl(fd, F_GETFL);
106         if (flags == -1)
107                 pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed");
108         flags |= O_NONBLOCK;
109         if (fcntl(fd, F_SETFL, flags) == -1)
110                 pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed");
111 }
112
113 static int
114 wait_for_fd(int fd, int timeout)
115 {
116         struct timeval tv;
117         fd_set fdset;
118         int error, ret;
119
120         error = 0;
121
122         for (;;) {
123                 FD_ZERO(&fdset);
124                 FD_SET(fd, &fdset);
125
126                 tv.tv_sec = timeout;
127                 tv.tv_usec = 0;
128
129                 ret = select(fd + 1, NULL, &fdset, NULL,
130                     timeout == -1 ? NULL : &tv);
131                 if (ret == 0) {
132                         error = ETIMEDOUT;
133                         break;
134                 } else if (ret == -1) {
135                         if (errno == EINTR)
136                                 continue;
137                         error = errno;
138                         break;
139                 }
140                 PJDLOG_ASSERT(ret > 0);
141                 PJDLOG_ASSERT(FD_ISSET(fd, &fdset));
142                 break;
143         }
144
145         return (error);
146 }
147
148 static void
149 ssl_log_errors(void)
150 {
151         unsigned long error;
152
153         while ((error = ERR_get_error()) != 0)
154                 pjdlog_error("SSL error: %s", ERR_error_string(error, NULL));
155 }
156
157 static int
158 ssl_check_error(SSL *ssl, int ret)
159 {
160         int error;
161
162         error = SSL_get_error(ssl, ret);
163
164         switch (error) {
165         case SSL_ERROR_NONE:
166                 return (0);
167         case SSL_ERROR_WANT_READ:
168                 pjdlog_debug(2, "SSL_ERROR_WANT_READ");
169                 return (-1);
170         case SSL_ERROR_WANT_WRITE:
171                 pjdlog_debug(2, "SSL_ERROR_WANT_WRITE");
172                 return (-1);
173         case SSL_ERROR_ZERO_RETURN:
174                 pjdlog_exitx(EX_OK, "Connection closed.");
175         case SSL_ERROR_SYSCALL:
176                 ssl_log_errors();
177                 pjdlog_exitx(EX_TEMPFAIL, "SSL I/O error.");
178         case SSL_ERROR_SSL:
179                 ssl_log_errors();
180                 pjdlog_exitx(EX_TEMPFAIL, "SSL protocol error.");
181         default:
182                 ssl_log_errors();
183                 pjdlog_exitx(EX_TEMPFAIL, "Unknown SSL error (%d).", error);
184         }
185 }
186
187 static void
188 tcp_recv_ssl_send(int recvfd, SSL *sendssl)
189 {
190         static unsigned char buf[65536];
191         ssize_t tcpdone;
192         int sendfd, ssldone;
193
194         sendfd = SSL_get_fd(sendssl);
195         PJDLOG_ASSERT(sendfd >= 0);
196         pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd);
197         for (;;) {
198                 tcpdone = recv(recvfd, buf, sizeof(buf), 0);
199                 pjdlog_debug(2, "%s: recv() returned %zd", __func__, tcpdone);
200                 if (tcpdone == 0) {
201                         pjdlog_debug(1, "Connection terminated.");
202                         exit(0);
203                 } else if (tcpdone == -1) {
204                         if (errno == EINTR)
205                                 continue;
206                         else if (errno == EAGAIN)
207                                 break;
208                         pjdlog_exit(EX_TEMPFAIL, "recv() failed");
209                 }
210                 for (;;) {
211                         ssldone = SSL_write(sendssl, buf, (int)tcpdone);
212                         pjdlog_debug(2, "%s: send() returned %d", __func__,
213                             ssldone);
214                         if (ssl_check_error(sendssl, ssldone) == -1) {
215                                 (void)wait_for_fd(sendfd, -1);
216                                 continue;
217                         }
218                         PJDLOG_ASSERT(ssldone == tcpdone);
219                         break;
220                 }
221         }
222         pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd);
223 }
224
225 static void
226 ssl_recv_tcp_send(SSL *recvssl, int sendfd)
227 {
228         static unsigned char buf[65536];
229         unsigned char *ptr;
230         ssize_t tcpdone;
231         size_t todo;
232         int recvfd, ssldone;
233
234         recvfd = SSL_get_fd(recvssl);
235         PJDLOG_ASSERT(recvfd >= 0);
236         pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd);
237         for (;;) {
238                 ssldone = SSL_read(recvssl, buf, sizeof(buf));
239                 pjdlog_debug(2, "%s: SSL_read() returned %d", __func__,
240                     ssldone);
241                 if (ssl_check_error(recvssl, ssldone) == -1)
242                         break;
243                 todo = (size_t)ssldone;
244                 ptr = buf;
245                 do {
246                         tcpdone = send(sendfd, ptr, todo, MSG_NOSIGNAL);
247                         pjdlog_debug(2, "%s: send() returned %zd", __func__,
248                             tcpdone);
249                         if (tcpdone == 0) {
250                                 pjdlog_debug(1, "Connection terminated.");
251                                 exit(0);
252                         } else if (tcpdone == -1) {
253                                 if (errno == EINTR || errno == ENOBUFS)
254                                         continue;
255                                 if (errno == EAGAIN) {
256                                         (void)wait_for_fd(sendfd, -1);
257                                         continue;
258                                 }
259                                 pjdlog_exit(EX_TEMPFAIL, "send() failed");
260                         }
261                         todo -= tcpdone;
262                         ptr += tcpdone;
263                 } while (todo > 0);
264         }
265         pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd);
266 }
267
268 static void
269 tls_loop(int sockfd, SSL *tcpssl)
270 {
271         fd_set fds;
272         int maxfd, tcpfd;
273
274         tcpfd = SSL_get_fd(tcpssl);
275         PJDLOG_ASSERT(tcpfd >= 0);
276
277         for (;;) {
278                 FD_ZERO(&fds);
279                 FD_SET(sockfd, &fds);
280                 FD_SET(tcpfd, &fds);
281                 maxfd = MAX(sockfd, tcpfd);
282
283                 PJDLOG_ASSERT(maxfd + 1 <= (int)FD_SETSIZE);
284                 if (select(maxfd + 1, &fds, NULL, NULL, NULL) == -1) {
285                         if (errno == EINTR)
286                                 continue;
287                         pjdlog_exit(EX_TEMPFAIL, "select() failed");
288                 }
289                 if (FD_ISSET(sockfd, &fds))
290                         tcp_recv_ssl_send(sockfd, tcpssl);
291                 if (FD_ISSET(tcpfd, &fds))
292                         ssl_recv_tcp_send(tcpssl, sockfd);
293         }
294 }
295
296 static void
297 tls_certificate_verify(SSL *ssl, const char *fingerprint)
298 {
299         unsigned char md[EVP_MAX_MD_SIZE];
300         char mdstr[sizeof("SHA256=") - 1 + EVP_MAX_MD_SIZE * 3];
301         char *mdstrp;
302         unsigned int i, mdsize;
303         X509 *cert;
304
305         if (fingerprint[0] == '\0') {
306                 pjdlog_debug(1, "No fingerprint verification requested.");
307                 return;
308         }
309
310         cert = SSL_get_peer_certificate(ssl);
311         if (cert == NULL)
312                 pjdlog_exitx(EX_TEMPFAIL, "No peer certificate received.");
313
314         if (X509_digest(cert, EVP_sha256(), md, &mdsize) != 1)
315                 pjdlog_exitx(EX_TEMPFAIL, "X509_digest() failed.");
316         PJDLOG_ASSERT(mdsize <= EVP_MAX_MD_SIZE);
317
318         X509_free(cert);
319
320         (void)strlcpy(mdstr, "SHA256=", sizeof(mdstr));
321         mdstrp = mdstr + strlen(mdstr);
322         for (i = 0; i < mdsize; i++) {
323                 PJDLOG_VERIFY(mdstrp + 3 <= mdstr + sizeof(mdstr));
324                 (void)sprintf(mdstrp, "%02hhX:", md[i]);
325                 mdstrp += 3;
326         }
327         /* Clear last colon. */
328         mdstrp[-1] = '\0';
329         if (strcasecmp(mdstr, fingerprint) != 0) {
330                 pjdlog_exitx(EX_NOPERM,
331                     "Finger print doesn't match. Received \"%s\", expected \"%s\"",
332                     mdstr, fingerprint);
333         }
334 }
335
336 static void
337 tls_exec_client(const char *user, int startfd, const char *srcaddr,
338     const char *dstaddr, const char *fingerprint, const char *defport,
339     int timeout, int debuglevel)
340 {
341         struct proto_conn *tcp;
342         char *saddr, *daddr;
343         SSL_CTX *sslctx;
344         SSL *ssl;
345         long ret;
346         int sockfd, tcpfd;
347         uint8_t connected;
348
349         pjdlog_debug_set(debuglevel);
350         pjdlog_prefix_set("[TLS sandbox] (client) ");
351 #ifdef HAVE_SETPROCTITLE
352         setproctitle("[TLS sandbox] (client) ");
353 #endif
354         proto_set("tcp:port", defport);
355
356         sockfd = startfd;
357
358         /* Change tls:// to tcp://. */
359         if (srcaddr == NULL) {
360                 saddr = NULL;
361         } else {
362                 saddr = strdup(srcaddr);
363                 if (saddr == NULL)
364                         pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory.");
365                 bcopy("tcp://", saddr, 6);
366         }
367         daddr = strdup(dstaddr);
368         if (daddr == NULL)
369                 pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory.");
370         bcopy("tcp://", daddr, 6);
371
372         /* Establish TCP connection. */
373         if (proto_connect(saddr, daddr, timeout, &tcp) == -1)
374                 exit(EX_TEMPFAIL);
375
376         SSL_load_error_strings();
377         SSL_library_init();
378
379         /*
380          * TODO: On FreeBSD we could move this below sandbox() once libc and
381          *       libcrypto use sysctl kern.arandom to obtain random data
382          *       instead of /dev/urandom and friends.
383          */
384         sslctx = SSL_CTX_new(TLSv1_client_method());
385         if (sslctx == NULL)
386                 pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed.");
387
388         if (sandbox(user, true, "proto_tls client: %s", dstaddr) != 0)
389                 pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS client.");
390         pjdlog_debug(1, "Privileges successfully dropped.");
391
392         SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
393
394         /* Load CA certs. */
395         /* TODO */
396         //SSL_CTX_load_verify_locations(sslctx, cacerts_file, NULL);
397
398         ssl = SSL_new(sslctx);
399         if (ssl == NULL)
400                 pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed.");
401
402         tcpfd = proto_descriptor(tcp);
403
404         block(tcpfd);
405
406         if (SSL_set_fd(ssl, tcpfd) != 1)
407                 pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed.");
408
409         ret = SSL_connect(ssl);
410         ssl_check_error(ssl, (int)ret);
411
412         nonblock(sockfd);
413         nonblock(tcpfd);
414
415         tls_certificate_verify(ssl, fingerprint);
416
417         /*
418          * The following byte is send to make proto_connect_wait() to work.
419          */
420         connected = 1;
421         for (;;) {
422                 switch (send(sockfd, &connected, sizeof(connected), 0)) {
423                 case -1:
424                         if (errno == EINTR || errno == ENOBUFS)
425                                 continue;
426                         if (errno == EAGAIN) {
427                                 (void)wait_for_fd(sockfd, -1);
428                                 continue;
429                         }
430                         pjdlog_exit(EX_TEMPFAIL, "send() failed");
431                 case 0:
432                         pjdlog_debug(1, "Connection terminated.");
433                         exit(0);
434                 case 1:
435                         break;
436                 }
437                 break;
438         }
439
440         tls_loop(sockfd, ssl);
441 }
442
443 static void
444 tls_call_exec_client(struct proto_conn *sock, const char *srcaddr,
445     const char *dstaddr, int timeout)
446 {
447         char *timeoutstr, *startfdstr, *debugstr;
448         int startfd;
449
450         /* Declare that we are receiver. */
451         proto_recv(sock, NULL, 0);
452
453         if (pjdlog_mode_get() == PJDLOG_MODE_STD)
454                 startfd = 3;
455         else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */
456                 startfd = 0;
457
458         if (proto_descriptor(sock) != startfd) {
459                 /* Move socketpair descriptor to descriptor number startfd. */
460                 if (dup2(proto_descriptor(sock), startfd) == -1)
461                         pjdlog_exit(EX_OSERR, "dup2() failed");
462                 proto_close(sock);
463         } else {
464                 /*
465                  * The FD_CLOEXEC is cleared by dup2(2), so when we not
466                  * call it, we have to clear it by hand in case it is set.
467                  */
468                 if (fcntl(startfd, F_SETFD, 0) == -1)
469                         pjdlog_exit(EX_OSERR, "fcntl() failed");
470         }
471
472         closefrom(startfd + 1);
473
474         if (asprintf(&startfdstr, "%d", startfd) == -1)
475                 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
476         if (timeout == -1)
477                 timeout = TLS_DEFAULT_TIMEOUT;
478         if (asprintf(&timeoutstr, "%d", timeout) == -1)
479                 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
480         if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1)
481                 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
482
483         execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls",
484             proto_get("user"), "client", startfdstr,
485             srcaddr == NULL ? "" : srcaddr, dstaddr,
486             proto_get("tls:fingerprint"), proto_get("tcp:port"), timeoutstr,
487             debugstr, NULL);
488         pjdlog_exit(EX_SOFTWARE, "execl() failed");
489 }
490
491 static int
492 tls_connect(const char *srcaddr, const char *dstaddr, int timeout, void **ctxp)
493 {
494         struct tls_ctx *tlsctx;
495         struct proto_conn *sock;
496         pid_t pid;
497         int error;
498
499         PJDLOG_ASSERT(srcaddr == NULL || srcaddr[0] != '\0');
500         PJDLOG_ASSERT(dstaddr != NULL);
501         PJDLOG_ASSERT(timeout >= -1);
502         PJDLOG_ASSERT(ctxp != NULL);
503
504         if (strncmp(dstaddr, "tls://", 6) != 0)
505                 return (-1);
506         if (srcaddr != NULL && strncmp(srcaddr, "tls://", 6) != 0)
507                 return (-1);
508
509         if (proto_connect(NULL, "socketpair://", -1, &sock) == -1)
510                 return (errno);
511
512 #if 0
513         /*
514          * We use rfork() with the following flags to disable SIGCHLD
515          * delivery upon the sandbox process exit.
516          */
517         pid = rfork(RFFDG | RFPROC | RFTSIGZMB | RFTSIGFLAGS(0));
518 #else
519         /*
520          * We don't use rfork() to be able to log information about sandbox
521          * process exiting.
522          */
523         pid = fork();
524 #endif
525         switch (pid) {
526         case -1:
527                 /* Failure. */
528                 error = errno;
529                 proto_close(sock);
530                 return (error);
531         case 0:
532                 /* Child. */
533                 pjdlog_prefix_set("[TLS sandbox] (client) ");
534 #ifdef HAVE_SETPROCTITLE
535                 setproctitle("[TLS sandbox] (client) ");
536 #endif
537                 tls_call_exec_client(sock, srcaddr, dstaddr, timeout);
538                 /* NOTREACHED */
539         default:
540                 /* Parent. */
541                 tlsctx = calloc(1, sizeof(*tlsctx));
542                 if (tlsctx == NULL) {
543                         error = errno;
544                         proto_close(sock);
545                         (void)kill(pid, SIGKILL);
546                         return (error);
547                 }
548                 proto_send(sock, NULL, 0);
549                 tlsctx->tls_sock = sock;
550                 tlsctx->tls_tcp = NULL;
551                 tlsctx->tls_side = TLS_SIDE_CLIENT;
552                 tlsctx->tls_wait_called = false;
553                 tlsctx->tls_magic = TLS_CTX_MAGIC;
554                 if (timeout >= 0) {
555                         error = tls_connect_wait(tlsctx, timeout);
556                         if (error != 0) {
557                                 (void)kill(pid, SIGKILL);
558                                 tls_close(tlsctx);
559                                 return (error);
560                         }
561                 }
562                 *ctxp = tlsctx;
563                 return (0);
564         }
565 }
566
567 static int
568 tls_connect_wait(void *ctx, int timeout)
569 {
570         struct tls_ctx *tlsctx = ctx;
571         int error, sockfd;
572         uint8_t connected;
573
574         PJDLOG_ASSERT(tlsctx != NULL);
575         PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
576         PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT);
577         PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
578         PJDLOG_ASSERT(!tlsctx->tls_wait_called);
579         PJDLOG_ASSERT(timeout >= 0);
580
581         sockfd = proto_descriptor(tlsctx->tls_sock);
582         error = wait_for_fd(sockfd, timeout);
583         if (error != 0)
584                 return (error);
585
586         for (;;) {
587                 switch (recv(sockfd, &connected, sizeof(connected),
588                     MSG_WAITALL)) {
589                 case -1:
590                         if (errno == EINTR || errno == ENOBUFS)
591                                 continue;
592                         error = errno;
593                         break;
594                 case 0:
595                         pjdlog_debug(1, "Connection terminated.");
596                         error = ENOTCONN;
597                         break;
598                 case 1:
599                         tlsctx->tls_wait_called = true;
600                         break;
601                 }
602                 break;
603         }
604
605         return (error);
606 }
607
608 static int
609 tls_server(const char *lstaddr, void **ctxp)
610 {
611         struct proto_conn *tcp;
612         struct tls_ctx *tlsctx;
613         char *laddr;
614         int error;
615
616         if (strncmp(lstaddr, "tls://", 6) != 0)
617                 return (-1);
618
619         tlsctx = malloc(sizeof(*tlsctx));
620         if (tlsctx == NULL) {
621                 pjdlog_warning("Unable to allocate memory.");
622                 return (ENOMEM);
623         }
624
625         laddr = strdup(lstaddr);
626         if (laddr == NULL) {
627                 free(tlsctx);
628                 pjdlog_warning("Unable to allocate memory.");
629                 return (ENOMEM);
630         }
631         bcopy("tcp://", laddr, 6);
632
633         if (proto_server(laddr, &tcp) == -1) {
634                 error = errno;
635                 free(tlsctx);
636                 free(laddr);
637                 return (error);
638         }
639         free(laddr);
640
641         tlsctx->tls_sock = NULL;
642         tlsctx->tls_tcp = tcp;
643         tlsctx->tls_side = TLS_SIDE_SERVER_LISTEN;
644         tlsctx->tls_wait_called = true;
645         tlsctx->tls_magic = TLS_CTX_MAGIC;
646         *ctxp = tlsctx;
647
648         return (0);
649 }
650
651 static void
652 tls_exec_server(const char *user, int startfd, const char *privkey,
653     const char *cert, int debuglevel)
654 {
655         SSL_CTX *sslctx;
656         SSL *ssl;
657         int sockfd, tcpfd, ret;
658
659         pjdlog_debug_set(debuglevel);
660         pjdlog_prefix_set("[TLS sandbox] (server) ");
661 #ifdef HAVE_SETPROCTITLE
662         setproctitle("[TLS sandbox] (server) ");
663 #endif
664
665         sockfd = startfd;
666         tcpfd = startfd + 1;
667
668         SSL_load_error_strings();
669         SSL_library_init();
670
671         sslctx = SSL_CTX_new(TLSv1_server_method());
672         if (sslctx == NULL)
673                 pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed.");
674
675         SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
676
677         ssl = SSL_new(sslctx);
678         if (ssl == NULL)
679                 pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed.");
680
681         if (SSL_use_RSAPrivateKey_file(ssl, privkey, SSL_FILETYPE_PEM) != 1) {
682                 ssl_log_errors();
683                 pjdlog_exitx(EX_CONFIG,
684                     "SSL_use_RSAPrivateKey_file(%s) failed.", privkey);
685         }
686
687         if (SSL_use_certificate_file(ssl, cert, SSL_FILETYPE_PEM) != 1) {
688                 ssl_log_errors();
689                 pjdlog_exitx(EX_CONFIG, "SSL_use_certificate_file(%s) failed.",
690                     cert);
691         }
692
693         if (sandbox(user, true, "proto_tls server") != 0)
694                 pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS server.");
695         pjdlog_debug(1, "Privileges successfully dropped.");
696
697         nonblock(sockfd);
698         nonblock(tcpfd);
699
700         if (SSL_set_fd(ssl, tcpfd) != 1)
701                 pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed.");
702
703         ret = SSL_accept(ssl);
704         ssl_check_error(ssl, ret);
705
706         tls_loop(sockfd, ssl);
707 }
708
709 static void
710 tls_call_exec_server(struct proto_conn *sock, struct proto_conn *tcp)
711 {
712         int startfd, sockfd, tcpfd, safefd;
713         char *startfdstr, *debugstr;
714
715         if (pjdlog_mode_get() == PJDLOG_MODE_STD)
716                 startfd = 3;
717         else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */
718                 startfd = 0;
719
720         /* Declare that we are receiver. */
721         proto_send(sock, NULL, 0);
722
723         sockfd = proto_descriptor(sock);
724         tcpfd = proto_descriptor(tcp);
725
726         safefd = MAX(sockfd, tcpfd);
727         safefd = MAX(safefd, startfd);
728         safefd++;
729
730         /* Move sockfd and tcpfd to safe numbers first. */
731         if (dup2(sockfd, safefd) == -1)
732                 pjdlog_exit(EX_OSERR, "dup2() failed");
733         proto_close(sock);
734         sockfd = safefd;
735         if (dup2(tcpfd, safefd + 1) == -1)
736                 pjdlog_exit(EX_OSERR, "dup2() failed");
737         proto_close(tcp);
738         tcpfd = safefd + 1;
739
740         /* Move socketpair descriptor to descriptor number startfd. */
741         if (dup2(sockfd, startfd) == -1)
742                 pjdlog_exit(EX_OSERR, "dup2() failed");
743         (void)close(sockfd);
744         /* Move tcp descriptor to descriptor number startfd + 1. */
745         if (dup2(tcpfd, startfd + 1) == -1)
746                 pjdlog_exit(EX_OSERR, "dup2() failed");
747         (void)close(tcpfd);
748
749         closefrom(startfd + 2);
750
751         /*
752          * Even if FD_CLOEXEC was set on descriptors before dup2(), it should
753          * have been cleared on dup2(), but better be safe than sorry.
754          */
755         if (fcntl(startfd, F_SETFD, 0) == -1)
756                 pjdlog_exit(EX_OSERR, "fcntl() failed");
757         if (fcntl(startfd + 1, F_SETFD, 0) == -1)
758                 pjdlog_exit(EX_OSERR, "fcntl() failed");
759
760         if (asprintf(&startfdstr, "%d", startfd) == -1)
761                 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
762         if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1)
763                 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
764
765         execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls",
766             proto_get("user"), "server", startfdstr, proto_get("tls:keyfile"),
767             proto_get("tls:certfile"), debugstr, NULL);
768         pjdlog_exit(EX_SOFTWARE, "execl() failed");
769 }
770
771 static int
772 tls_accept(void *ctx, void **newctxp)
773 {
774         struct tls_ctx *tlsctx = ctx;
775         struct tls_ctx *newtlsctx;
776         struct proto_conn *sock, *tcp;
777         pid_t pid;
778         int error;
779
780         PJDLOG_ASSERT(tlsctx != NULL);
781         PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
782         PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_SERVER_LISTEN);
783
784         if (proto_connect(NULL, "socketpair://", -1, &sock) == -1)
785                 return (errno);
786
787         /* Accept TCP connection. */
788         if (proto_accept(tlsctx->tls_tcp, &tcp) == -1) {
789                 error = errno;
790                 proto_close(sock);
791                 return (error);
792         }
793
794         pid = fork();
795         switch (pid) {
796         case -1:
797                 /* Failure. */
798                 error = errno;
799                 proto_close(sock);
800                 return (error);
801         case 0:
802                 /* Child. */
803                 pjdlog_prefix_set("[TLS sandbox] (server) ");
804 #ifdef HAVE_SETPROCTITLE
805                 setproctitle("[TLS sandbox] (server) ");
806 #endif
807                 /* Close listen socket. */
808                 proto_close(tlsctx->tls_tcp);
809                 tls_call_exec_server(sock, tcp);
810                 /* NOTREACHED */
811                 PJDLOG_ABORT("Unreachable.");
812         default:
813                 /* Parent. */
814                 newtlsctx = calloc(1, sizeof(*tlsctx));
815                 if (newtlsctx == NULL) {
816                         error = errno;
817                         proto_close(sock);
818                         proto_close(tcp);
819                         (void)kill(pid, SIGKILL);
820                         return (error);
821                 }
822                 proto_local_address(tcp, newtlsctx->tls_laddr,
823                     sizeof(newtlsctx->tls_laddr));
824                 PJDLOG_ASSERT(strncmp(newtlsctx->tls_laddr, "tcp://", 6) == 0);
825                 bcopy("tls://", newtlsctx->tls_laddr, 6);
826                 *strrchr(newtlsctx->tls_laddr, ':') = '\0';
827                 proto_remote_address(tcp, newtlsctx->tls_raddr,
828                     sizeof(newtlsctx->tls_raddr));
829                 PJDLOG_ASSERT(strncmp(newtlsctx->tls_raddr, "tcp://", 6) == 0);
830                 bcopy("tls://", newtlsctx->tls_raddr, 6);
831                 *strrchr(newtlsctx->tls_raddr, ':') = '\0';
832                 proto_close(tcp);
833                 proto_recv(sock, NULL, 0);
834                 newtlsctx->tls_sock = sock;
835                 newtlsctx->tls_tcp = NULL;
836                 newtlsctx->tls_wait_called = true;
837                 newtlsctx->tls_side = TLS_SIDE_SERVER_WORK;
838                 newtlsctx->tls_magic = TLS_CTX_MAGIC;
839                 *newctxp = newtlsctx;
840                 return (0);
841         }
842 }
843
844 static int
845 tls_wrap(int fd, bool client, void **ctxp)
846 {
847         struct tls_ctx *tlsctx;
848         struct proto_conn *sock;
849         int error;
850
851         tlsctx = calloc(1, sizeof(*tlsctx));
852         if (tlsctx == NULL)
853                 return (errno);
854
855         if (proto_wrap("socketpair", client, fd, &sock) == -1) {
856                 error = errno;
857                 free(tlsctx);
858                 return (error);
859         }
860
861         tlsctx->tls_sock = sock;
862         tlsctx->tls_tcp = NULL;
863         tlsctx->tls_wait_called = (client ? false : true);
864         tlsctx->tls_side = (client ? TLS_SIDE_CLIENT : TLS_SIDE_SERVER_WORK);
865         tlsctx->tls_magic = TLS_CTX_MAGIC;
866         *ctxp = tlsctx;
867
868         return (0);
869 }
870
871 static int
872 tls_send(void *ctx, const unsigned char *data, size_t size, int fd)
873 {
874         struct tls_ctx *tlsctx = ctx;
875
876         PJDLOG_ASSERT(tlsctx != NULL);
877         PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
878         PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT ||
879             tlsctx->tls_side == TLS_SIDE_SERVER_WORK);
880         PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
881         PJDLOG_ASSERT(tlsctx->tls_wait_called);
882         PJDLOG_ASSERT(fd == -1);
883
884         if (proto_send(tlsctx->tls_sock, data, size) == -1)
885                 return (errno);
886
887         return (0);
888 }
889
890 static int
891 tls_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
892 {
893         struct tls_ctx *tlsctx = ctx;
894
895         PJDLOG_ASSERT(tlsctx != NULL);
896         PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
897         PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT ||
898             tlsctx->tls_side == TLS_SIDE_SERVER_WORK);
899         PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
900         PJDLOG_ASSERT(tlsctx->tls_wait_called);
901         PJDLOG_ASSERT(fdp == NULL);
902
903         if (proto_recv(tlsctx->tls_sock, data, size) == -1)
904                 return (errno);
905
906         return (0);
907 }
908
909 static int
910 tls_descriptor(const void *ctx)
911 {
912         const struct tls_ctx *tlsctx = ctx;
913
914         PJDLOG_ASSERT(tlsctx != NULL);
915         PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
916
917         switch (tlsctx->tls_side) {
918         case TLS_SIDE_CLIENT:
919         case TLS_SIDE_SERVER_WORK:
920                 PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
921
922                 return (proto_descriptor(tlsctx->tls_sock));
923         case TLS_SIDE_SERVER_LISTEN:
924                 PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
925
926                 return (proto_descriptor(tlsctx->tls_tcp));
927         default:
928                 PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
929         }
930 }
931
932 static bool
933 tcp_address_match(const void *ctx, const char *addr)
934 {
935         const struct tls_ctx *tlsctx = ctx;
936
937         PJDLOG_ASSERT(tlsctx != NULL);
938         PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
939
940         return (strcmp(tlsctx->tls_raddr, addr) == 0);
941 }
942
943 static void
944 tls_local_address(const void *ctx, char *addr, size_t size)
945 {
946         const struct tls_ctx *tlsctx = ctx;
947
948         PJDLOG_ASSERT(tlsctx != NULL);
949         PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
950         PJDLOG_ASSERT(tlsctx->tls_wait_called);
951
952         switch (tlsctx->tls_side) {
953         case TLS_SIDE_CLIENT:
954                 PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
955
956                 PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size);
957                 break;
958         case TLS_SIDE_SERVER_WORK:
959                 PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
960
961                 PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_laddr, size) < size);
962                 break;
963         case TLS_SIDE_SERVER_LISTEN:
964                 PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
965
966                 proto_local_address(tlsctx->tls_tcp, addr, size);
967                 PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0);
968                 /* Replace tcp:// prefix with tls:// */
969                 bcopy("tls://", addr, 6);
970                 break;
971         default:
972                 PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
973         }
974 }
975
976 static void
977 tls_remote_address(const void *ctx, char *addr, size_t size)
978 {
979         const struct tls_ctx *tlsctx = ctx;
980
981         PJDLOG_ASSERT(tlsctx != NULL);
982         PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
983         PJDLOG_ASSERT(tlsctx->tls_wait_called);
984
985         switch (tlsctx->tls_side) {
986         case TLS_SIDE_CLIENT:
987                 PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
988
989                 PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size);
990                 break;
991         case TLS_SIDE_SERVER_WORK:
992                 PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
993
994                 PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_raddr, size) < size);
995                 break;
996         case TLS_SIDE_SERVER_LISTEN:
997                 PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
998
999                 proto_remote_address(tlsctx->tls_tcp, addr, size);
1000                 PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0);
1001                 /* Replace tcp:// prefix with tls:// */
1002                 bcopy("tls://", addr, 6);
1003                 break;
1004         default:
1005                 PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
1006         }
1007 }
1008
1009 static void
1010 tls_close(void *ctx)
1011 {
1012         struct tls_ctx *tlsctx = ctx;
1013
1014         PJDLOG_ASSERT(tlsctx != NULL);
1015         PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
1016
1017         if (tlsctx->tls_sock != NULL) {
1018                 proto_close(tlsctx->tls_sock);
1019                 tlsctx->tls_sock = NULL;
1020         }
1021         if (tlsctx->tls_tcp != NULL) {
1022                 proto_close(tlsctx->tls_tcp);
1023                 tlsctx->tls_tcp = NULL;
1024         }
1025         tlsctx->tls_side = 0;
1026         tlsctx->tls_magic = 0;
1027         free(tlsctx);
1028 }
1029
1030 static int
1031 tls_exec(int argc, char *argv[])
1032 {
1033
1034         PJDLOG_ASSERT(argc > 3);
1035         PJDLOG_ASSERT(strcmp(argv[0], "tls") == 0);
1036
1037         pjdlog_init(atoi(argv[3]) == 0 ? PJDLOG_MODE_SYSLOG : PJDLOG_MODE_STD);
1038
1039         if (strcmp(argv[2], "client") == 0) {
1040                 if (argc != 10)
1041                         return (EINVAL);
1042                 tls_exec_client(argv[1], atoi(argv[3]),
1043                     argv[4][0] == '\0' ? NULL : argv[4], argv[5], argv[6],
1044                     argv[7], atoi(argv[8]), atoi(argv[9]));
1045         } else if (strcmp(argv[2], "server") == 0) {
1046                 if (argc != 7)
1047                         return (EINVAL);
1048                 tls_exec_server(argv[1], atoi(argv[3]), argv[4], argv[5],
1049                     atoi(argv[6]));
1050         }
1051         return (EINVAL);
1052 }
1053
1054 static struct proto tls_proto = {
1055         .prt_name = "tls",
1056         .prt_connect = tls_connect,
1057         .prt_connect_wait = tls_connect_wait,
1058         .prt_server = tls_server,
1059         .prt_accept = tls_accept,
1060         .prt_wrap = tls_wrap,
1061         .prt_send = tls_send,
1062         .prt_recv = tls_recv,
1063         .prt_descriptor = tls_descriptor,
1064         .prt_address_match = tcp_address_match,
1065         .prt_local_address = tls_local_address,
1066         .prt_remote_address = tls_remote_address,
1067         .prt_close = tls_close,
1068         .prt_exec = tls_exec
1069 };
1070
1071 static __constructor void
1072 tls_ctor(void)
1073 {
1074
1075         proto_register(&tls_proto, false);
1076 }