]> CyberLeo.Net >> Repos - FreeBSD/releng/10.3.git/blob - contrib/openbsm/bin/auditdistd/proto_tls.c
- Copy stable/10@296371 to releng/10.3 in preparation for 10.3-RC1
[FreeBSD/releng/10.3.git] / contrib / openbsm / bin / auditdistd / proto_tls.c
1 /*-
2  * Copyright (c) 2011 The FreeBSD Foundation
3  * All rights reserved.
4  *
5  * This software was developed by Pawel Jakub Dawidek under sponsorship from
6  * the FreeBSD Foundation.
7  *
8  * Redistribution and use in source and binary forms, with or without
9  * modification, are permitted provided that the following conditions
10  * are met:
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in the
15  *    documentation and/or other materials provided with the distribution.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
18  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
21  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27  * SUCH DAMAGE.
28  */
29
30 #include <config/config.h>
31
32 #include <sys/param.h>  /* MAXHOSTNAMELEN */
33 #include <sys/socket.h>
34
35 #include <arpa/inet.h>
36
37 #include <netinet/in.h>
38 #include <netinet/tcp.h>
39
40 #include <errno.h>
41 #include <fcntl.h>
42 #include <netdb.h>
43 #include <signal.h>
44 #include <stdbool.h>
45 #include <stdint.h>
46 #include <stdio.h>
47 #include <string.h>
48 #include <unistd.h>
49
50 #include <openssl/err.h>
51 #include <openssl/ssl.h>
52
53 #include <compat/compat.h>
54 #ifndef HAVE_CLOSEFROM
55 #include <compat/closefrom.h>
56 #endif
57 #ifndef HAVE_STRLCPY
58 #include <compat/strlcpy.h>
59 #endif
60
61 #include "pjdlog.h"
62 #include "proto_impl.h"
63 #include "sandbox.h"
64 #include "subr.h"
65
66 #define TLS_CTX_MAGIC   0x715c7
67 struct tls_ctx {
68         int             tls_magic;
69         struct proto_conn *tls_sock;
70         struct proto_conn *tls_tcp;
71         char            tls_laddr[256];
72         char            tls_raddr[256];
73         int             tls_side;
74 #define TLS_SIDE_CLIENT         0
75 #define TLS_SIDE_SERVER_LISTEN  1
76 #define TLS_SIDE_SERVER_WORK    2
77         bool            tls_wait_called;
78 };
79
80 #define TLS_DEFAULT_TIMEOUT     30
81
82 static int tls_connect_wait(void *ctx, int timeout);
83 static void tls_close(void *ctx);
84
85 static void
86 block(int fd)
87 {
88         int flags;
89
90         flags = fcntl(fd, F_GETFL);
91         if (flags == -1)
92                 pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed");
93         flags &= ~O_NONBLOCK;
94         if (fcntl(fd, F_SETFL, flags) == -1)
95                 pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed");
96 }
97
98 static void
99 nonblock(int fd)
100 {
101         int flags;
102
103         flags = fcntl(fd, F_GETFL);
104         if (flags == -1)
105                 pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed");
106         flags |= O_NONBLOCK;
107         if (fcntl(fd, F_SETFL, flags) == -1)
108                 pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed");
109 }
110
111 static int
112 wait_for_fd(int fd, int timeout)
113 {
114         struct timeval tv;
115         fd_set fdset;
116         int error, ret;
117
118         error = 0;
119
120         for (;;) {
121                 FD_ZERO(&fdset);
122                 FD_SET(fd, &fdset);
123
124                 tv.tv_sec = timeout;
125                 tv.tv_usec = 0;
126
127                 ret = select(fd + 1, NULL, &fdset, NULL,
128                     timeout == -1 ? NULL : &tv);
129                 if (ret == 0) {
130                         error = ETIMEDOUT;
131                         break;
132                 } else if (ret == -1) {
133                         if (errno == EINTR)
134                                 continue;
135                         error = errno;
136                         break;
137                 }
138                 PJDLOG_ASSERT(ret > 0);
139                 PJDLOG_ASSERT(FD_ISSET(fd, &fdset));
140                 break;
141         }
142
143         return (error);
144 }
145
146 static void
147 ssl_log_errors(void)
148 {
149         unsigned long error;
150
151         while ((error = ERR_get_error()) != 0)
152                 pjdlog_error("SSL error: %s", ERR_error_string(error, NULL));
153 }
154
155 static int
156 ssl_check_error(SSL *ssl, int ret)
157 {
158         int error;
159
160         error = SSL_get_error(ssl, ret);
161
162         switch (error) {
163         case SSL_ERROR_NONE:
164                 return (0);
165         case SSL_ERROR_WANT_READ:
166                 pjdlog_debug(2, "SSL_ERROR_WANT_READ");
167                 return (-1);
168         case SSL_ERROR_WANT_WRITE:
169                 pjdlog_debug(2, "SSL_ERROR_WANT_WRITE");
170                 return (-1);
171         case SSL_ERROR_ZERO_RETURN:
172                 pjdlog_exitx(EX_OK, "Connection closed.");
173         case SSL_ERROR_SYSCALL:
174                 ssl_log_errors();
175                 pjdlog_exitx(EX_TEMPFAIL, "SSL I/O error.");
176         case SSL_ERROR_SSL:
177                 ssl_log_errors();
178                 pjdlog_exitx(EX_TEMPFAIL, "SSL protocol error.");
179         default:
180                 ssl_log_errors();
181                 pjdlog_exitx(EX_TEMPFAIL, "Unknown SSL error (%d).", error);
182         }
183 }
184
185 static void
186 tcp_recv_ssl_send(int recvfd, SSL *sendssl)
187 {
188         static unsigned char buf[65536];
189         ssize_t tcpdone;
190         int sendfd, ssldone;
191
192         sendfd = SSL_get_fd(sendssl);
193         PJDLOG_ASSERT(sendfd >= 0);
194         pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd);
195         for (;;) {
196                 tcpdone = recv(recvfd, buf, sizeof(buf), 0);
197                 pjdlog_debug(2, "%s: recv() returned %zd", __func__, tcpdone);
198                 if (tcpdone == 0) {
199                         pjdlog_debug(1, "Connection terminated.");
200                         exit(0);
201                 } else if (tcpdone == -1) {
202                         if (errno == EINTR)
203                                 continue;
204                         else if (errno == EAGAIN)
205                                 break;
206                         pjdlog_exit(EX_TEMPFAIL, "recv() failed");
207                 }
208                 for (;;) {
209                         ssldone = SSL_write(sendssl, buf, (int)tcpdone);
210                         pjdlog_debug(2, "%s: send() returned %d", __func__,
211                             ssldone);
212                         if (ssl_check_error(sendssl, ssldone) == -1) {
213                                 (void)wait_for_fd(sendfd, -1);
214                                 continue;
215                         }
216                         PJDLOG_ASSERT(ssldone == tcpdone);
217                         break;
218                 }
219         }
220         pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd);
221 }
222
223 static void
224 ssl_recv_tcp_send(SSL *recvssl, int sendfd)
225 {
226         static unsigned char buf[65536];
227         unsigned char *ptr;
228         ssize_t tcpdone;
229         size_t todo;
230         int recvfd, ssldone;
231
232         recvfd = SSL_get_fd(recvssl);
233         PJDLOG_ASSERT(recvfd >= 0);
234         pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd);
235         for (;;) {
236                 ssldone = SSL_read(recvssl, buf, sizeof(buf));
237                 pjdlog_debug(2, "%s: SSL_read() returned %d", __func__,
238                     ssldone);
239                 if (ssl_check_error(recvssl, ssldone) == -1)
240                         break;
241                 todo = (size_t)ssldone;
242                 ptr = buf;
243                 do {
244                         tcpdone = send(sendfd, ptr, todo, MSG_NOSIGNAL);
245                         pjdlog_debug(2, "%s: send() returned %zd", __func__,
246                             tcpdone);
247                         if (tcpdone == 0) {
248                                 pjdlog_debug(1, "Connection terminated.");
249                                 exit(0);
250                         } else if (tcpdone == -1) {
251                                 if (errno == EINTR || errno == ENOBUFS)
252                                         continue;
253                                 if (errno == EAGAIN) {
254                                         (void)wait_for_fd(sendfd, -1);
255                                         continue;
256                                 }
257                                 pjdlog_exit(EX_TEMPFAIL, "send() failed");
258                         }
259                         todo -= tcpdone;
260                         ptr += tcpdone;
261                 } while (todo > 0);
262         }
263         pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd);
264 }
265
266 static void
267 tls_loop(int sockfd, SSL *tcpssl)
268 {
269         fd_set fds;
270         int maxfd, tcpfd;
271
272         tcpfd = SSL_get_fd(tcpssl);
273         PJDLOG_ASSERT(tcpfd >= 0);
274
275         for (;;) {
276                 FD_ZERO(&fds);
277                 FD_SET(sockfd, &fds);
278                 FD_SET(tcpfd, &fds);
279                 maxfd = MAX(sockfd, tcpfd);
280
281                 PJDLOG_ASSERT(maxfd + 1 <= (int)FD_SETSIZE);
282                 if (select(maxfd + 1, &fds, NULL, NULL, NULL) == -1) {
283                         if (errno == EINTR)
284                                 continue;
285                         pjdlog_exit(EX_TEMPFAIL, "select() failed");
286                 }
287                 if (FD_ISSET(sockfd, &fds))
288                         tcp_recv_ssl_send(sockfd, tcpssl);
289                 if (FD_ISSET(tcpfd, &fds))
290                         ssl_recv_tcp_send(tcpssl, sockfd);
291         }
292 }
293
294 static void
295 tls_certificate_verify(SSL *ssl, const char *fingerprint)
296 {
297         unsigned char md[EVP_MAX_MD_SIZE];
298         char mdstr[sizeof("SHA256=") - 1 + EVP_MAX_MD_SIZE * 3];
299         char *mdstrp;
300         unsigned int i, mdsize;
301         X509 *cert;
302
303         if (fingerprint[0] == '\0') {
304                 pjdlog_debug(1, "No fingerprint verification requested.");
305                 return;
306         }
307
308         cert = SSL_get_peer_certificate(ssl);
309         if (cert == NULL)
310                 pjdlog_exitx(EX_TEMPFAIL, "No peer certificate received.");
311
312         if (X509_digest(cert, EVP_sha256(), md, &mdsize) != 1)
313                 pjdlog_exitx(EX_TEMPFAIL, "X509_digest() failed.");
314         PJDLOG_ASSERT(mdsize <= EVP_MAX_MD_SIZE);
315
316         X509_free(cert);
317
318         (void)strlcpy(mdstr, "SHA256=", sizeof(mdstr));
319         mdstrp = mdstr + strlen(mdstr);
320         for (i = 0; i < mdsize; i++) {
321                 PJDLOG_VERIFY(mdstrp + 3 <= mdstr + sizeof(mdstr));
322                 (void)sprintf(mdstrp, "%02hhX:", md[i]);
323                 mdstrp += 3;
324         }
325         /* Clear last colon. */
326         mdstrp[-1] = '\0';
327         if (strcasecmp(mdstr, fingerprint) != 0) {
328                 pjdlog_exitx(EX_NOPERM,
329                     "Finger print doesn't match. Received \"%s\", expected \"%s\"",
330                     mdstr, fingerprint);
331         }
332 }
333
334 static void
335 tls_exec_client(const char *user, int startfd, const char *srcaddr,
336     const char *dstaddr, const char *fingerprint, const char *defport,
337     int timeout, int debuglevel)
338 {
339         struct proto_conn *tcp;
340         char *saddr, *daddr;
341         SSL_CTX *sslctx;
342         SSL *ssl;
343         long ret;
344         int sockfd, tcpfd;
345         uint8_t connected;
346
347         pjdlog_debug_set(debuglevel);
348         pjdlog_prefix_set("[TLS sandbox] (client) ");
349 #ifdef HAVE_SETPROCTITLE
350         setproctitle("[TLS sandbox] (client) ");
351 #endif
352         proto_set("tcp:port", defport);
353
354         sockfd = startfd;
355
356         /* Change tls:// to tcp://. */
357         if (srcaddr == NULL) {
358                 saddr = NULL;
359         } else {
360                 saddr = strdup(srcaddr);
361                 if (saddr == NULL)
362                         pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory.");
363                 bcopy("tcp://", saddr, 6);
364         }
365         daddr = strdup(dstaddr);
366         if (daddr == NULL)
367                 pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory.");
368         bcopy("tcp://", daddr, 6);
369
370         /* Establish TCP connection. */
371         if (proto_connect(saddr, daddr, timeout, &tcp) == -1)
372                 exit(EX_TEMPFAIL);
373
374         SSL_load_error_strings();
375         SSL_library_init();
376
377         /*
378          * TODO: On FreeBSD we could move this below sandbox() once libc and
379          *       libcrypto use sysctl kern.arandom to obtain random data
380          *       instead of /dev/urandom and friends.
381          */
382         sslctx = SSL_CTX_new(TLSv1_client_method());
383         if (sslctx == NULL)
384                 pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed.");
385
386         if (sandbox(user, true, "proto_tls client: %s", dstaddr) != 0)
387                 pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS client.");
388         pjdlog_debug(1, "Privileges successfully dropped.");
389
390         SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
391
392         /* Load CA certs. */
393         /* TODO */
394         //SSL_CTX_load_verify_locations(sslctx, cacerts_file, NULL);
395
396         ssl = SSL_new(sslctx);
397         if (ssl == NULL)
398                 pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed.");
399
400         tcpfd = proto_descriptor(tcp);
401
402         block(tcpfd);
403
404         if (SSL_set_fd(ssl, tcpfd) != 1)
405                 pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed.");
406
407         ret = SSL_connect(ssl);
408         ssl_check_error(ssl, (int)ret);
409
410         nonblock(sockfd);
411         nonblock(tcpfd);
412
413         tls_certificate_verify(ssl, fingerprint);
414
415         /*
416          * The following byte is send to make proto_connect_wait() to work.
417          */
418         connected = 1;
419         for (;;) {
420                 switch (send(sockfd, &connected, sizeof(connected), 0)) {
421                 case -1:
422                         if (errno == EINTR || errno == ENOBUFS)
423                                 continue;
424                         if (errno == EAGAIN) {
425                                 (void)wait_for_fd(sockfd, -1);
426                                 continue;
427                         }
428                         pjdlog_exit(EX_TEMPFAIL, "send() failed");
429                 case 0:
430                         pjdlog_debug(1, "Connection terminated.");
431                         exit(0);
432                 case 1:
433                         break;
434                 }
435                 break;
436         }
437
438         tls_loop(sockfd, ssl);
439 }
440
441 static void
442 tls_call_exec_client(struct proto_conn *sock, const char *srcaddr,
443     const char *dstaddr, int timeout)
444 {
445         char *timeoutstr, *startfdstr, *debugstr;
446         int startfd;
447
448         /* Declare that we are receiver. */
449         proto_recv(sock, NULL, 0);
450
451         if (pjdlog_mode_get() == PJDLOG_MODE_STD)
452                 startfd = 3;
453         else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */
454                 startfd = 0;
455
456         if (proto_descriptor(sock) != startfd) {
457                 /* Move socketpair descriptor to descriptor number startfd. */
458                 if (dup2(proto_descriptor(sock), startfd) == -1)
459                         pjdlog_exit(EX_OSERR, "dup2() failed");
460                 proto_close(sock);
461         } else {
462                 /*
463                  * The FD_CLOEXEC is cleared by dup2(2), so when we not
464                  * call it, we have to clear it by hand in case it is set.
465                  */
466                 if (fcntl(startfd, F_SETFD, 0) == -1)
467                         pjdlog_exit(EX_OSERR, "fcntl() failed");
468         }
469
470         closefrom(startfd + 1);
471
472         if (asprintf(&startfdstr, "%d", startfd) == -1)
473                 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
474         if (timeout == -1)
475                 timeout = TLS_DEFAULT_TIMEOUT;
476         if (asprintf(&timeoutstr, "%d", timeout) == -1)
477                 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
478         if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1)
479                 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
480
481         execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls",
482             proto_get("user"), "client", startfdstr,
483             srcaddr == NULL ? "" : srcaddr, dstaddr,
484             proto_get("tls:fingerprint"), proto_get("tcp:port"), timeoutstr,
485             debugstr, NULL);
486         pjdlog_exit(EX_SOFTWARE, "execl() failed");
487 }
488
489 static int
490 tls_connect(const char *srcaddr, const char *dstaddr, int timeout, void **ctxp)
491 {
492         struct tls_ctx *tlsctx;
493         struct proto_conn *sock;
494         pid_t pid;
495         int error;
496
497         PJDLOG_ASSERT(srcaddr == NULL || srcaddr[0] != '\0');
498         PJDLOG_ASSERT(dstaddr != NULL);
499         PJDLOG_ASSERT(timeout >= -1);
500         PJDLOG_ASSERT(ctxp != NULL);
501
502         if (strncmp(dstaddr, "tls://", 6) != 0)
503                 return (-1);
504         if (srcaddr != NULL && strncmp(srcaddr, "tls://", 6) != 0)
505                 return (-1);
506
507         if (proto_connect(NULL, "socketpair://", -1, &sock) == -1)
508                 return (errno);
509
510 #if 0
511         /*
512          * We use rfork() with the following flags to disable SIGCHLD
513          * delivery upon the sandbox process exit.
514          */
515         pid = rfork(RFFDG | RFPROC | RFTSIGZMB | RFTSIGFLAGS(0));
516 #else
517         /*
518          * We don't use rfork() to be able to log information about sandbox
519          * process exiting.
520          */
521         pid = fork();
522 #endif
523         switch (pid) {
524         case -1:
525                 /* Failure. */
526                 error = errno;
527                 proto_close(sock);
528                 return (error);
529         case 0:
530                 /* Child. */
531                 pjdlog_prefix_set("[TLS sandbox] (client) ");
532 #ifdef HAVE_SETPROCTITLE
533                 setproctitle("[TLS sandbox] (client) ");
534 #endif
535                 tls_call_exec_client(sock, srcaddr, dstaddr, timeout);
536                 /* NOTREACHED */
537         default:
538                 /* Parent. */
539                 tlsctx = calloc(1, sizeof(*tlsctx));
540                 if (tlsctx == NULL) {
541                         error = errno;
542                         proto_close(sock);
543                         (void)kill(pid, SIGKILL);
544                         return (error);
545                 }
546                 proto_send(sock, NULL, 0);
547                 tlsctx->tls_sock = sock;
548                 tlsctx->tls_tcp = NULL;
549                 tlsctx->tls_side = TLS_SIDE_CLIENT;
550                 tlsctx->tls_wait_called = false;
551                 tlsctx->tls_magic = TLS_CTX_MAGIC;
552                 if (timeout >= 0) {
553                         error = tls_connect_wait(tlsctx, timeout);
554                         if (error != 0) {
555                                 (void)kill(pid, SIGKILL);
556                                 tls_close(tlsctx);
557                                 return (error);
558                         }
559                 }
560                 *ctxp = tlsctx;
561                 return (0);
562         }
563 }
564
565 static int
566 tls_connect_wait(void *ctx, int timeout)
567 {
568         struct tls_ctx *tlsctx = ctx;
569         int error, sockfd;
570         uint8_t connected;
571
572         PJDLOG_ASSERT(tlsctx != NULL);
573         PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
574         PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT);
575         PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
576         PJDLOG_ASSERT(!tlsctx->tls_wait_called);
577         PJDLOG_ASSERT(timeout >= 0);
578
579         sockfd = proto_descriptor(tlsctx->tls_sock);
580         error = wait_for_fd(sockfd, timeout);
581         if (error != 0)
582                 return (error);
583
584         for (;;) {
585                 switch (recv(sockfd, &connected, sizeof(connected),
586                     MSG_WAITALL)) {
587                 case -1:
588                         if (errno == EINTR || errno == ENOBUFS)
589                                 continue;
590                         error = errno;
591                         break;
592                 case 0:
593                         pjdlog_debug(1, "Connection terminated.");
594                         error = ENOTCONN;
595                         break;
596                 case 1:
597                         tlsctx->tls_wait_called = true;
598                         break;
599                 }
600                 break;
601         }
602
603         return (error);
604 }
605
606 static int
607 tls_server(const char *lstaddr, void **ctxp)
608 {
609         struct proto_conn *tcp;
610         struct tls_ctx *tlsctx;
611         char *laddr;
612         int error;
613
614         if (strncmp(lstaddr, "tls://", 6) != 0)
615                 return (-1);
616
617         tlsctx = malloc(sizeof(*tlsctx));
618         if (tlsctx == NULL) {
619                 pjdlog_warning("Unable to allocate memory.");
620                 return (ENOMEM);
621         }
622
623         laddr = strdup(lstaddr);
624         if (laddr == NULL) {
625                 free(tlsctx);
626                 pjdlog_warning("Unable to allocate memory.");
627                 return (ENOMEM);
628         }
629         bcopy("tcp://", laddr, 6);
630
631         if (proto_server(laddr, &tcp) == -1) {
632                 error = errno;
633                 free(tlsctx);
634                 free(laddr);
635                 return (error);
636         }
637         free(laddr);
638
639         tlsctx->tls_sock = NULL;
640         tlsctx->tls_tcp = tcp;
641         tlsctx->tls_side = TLS_SIDE_SERVER_LISTEN;
642         tlsctx->tls_wait_called = true;
643         tlsctx->tls_magic = TLS_CTX_MAGIC;
644         *ctxp = tlsctx;
645
646         return (0);
647 }
648
649 static void
650 tls_exec_server(const char *user, int startfd, const char *privkey,
651     const char *cert, int debuglevel)
652 {
653         SSL_CTX *sslctx;
654         SSL *ssl;
655         int sockfd, tcpfd, ret;
656
657         pjdlog_debug_set(debuglevel);
658         pjdlog_prefix_set("[TLS sandbox] (server) ");
659 #ifdef HAVE_SETPROCTITLE
660         setproctitle("[TLS sandbox] (server) ");
661 #endif
662
663         sockfd = startfd;
664         tcpfd = startfd + 1;
665
666         SSL_load_error_strings();
667         SSL_library_init();
668
669         sslctx = SSL_CTX_new(TLSv1_server_method());
670         if (sslctx == NULL)
671                 pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed.");
672
673         SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
674
675         ssl = SSL_new(sslctx);
676         if (ssl == NULL)
677                 pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed.");
678
679         if (SSL_use_RSAPrivateKey_file(ssl, privkey, SSL_FILETYPE_PEM) != 1) {
680                 ssl_log_errors();
681                 pjdlog_exitx(EX_CONFIG,
682                     "SSL_use_RSAPrivateKey_file(%s) failed.", privkey);
683         }
684
685         if (SSL_use_certificate_file(ssl, cert, SSL_FILETYPE_PEM) != 1) {
686                 ssl_log_errors();
687                 pjdlog_exitx(EX_CONFIG, "SSL_use_certificate_file(%s) failed.",
688                     cert);
689         }
690
691         if (sandbox(user, true, "proto_tls server") != 0)
692                 pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS server.");
693         pjdlog_debug(1, "Privileges successfully dropped.");
694
695         nonblock(sockfd);
696         nonblock(tcpfd);
697
698         if (SSL_set_fd(ssl, tcpfd) != 1)
699                 pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed.");
700
701         ret = SSL_accept(ssl);
702         ssl_check_error(ssl, ret);
703
704         tls_loop(sockfd, ssl);
705 }
706
707 static void
708 tls_call_exec_server(struct proto_conn *sock, struct proto_conn *tcp)
709 {
710         int startfd, sockfd, tcpfd, safefd;
711         char *startfdstr, *debugstr;
712
713         if (pjdlog_mode_get() == PJDLOG_MODE_STD)
714                 startfd = 3;
715         else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */
716                 startfd = 0;
717
718         /* Declare that we are receiver. */
719         proto_send(sock, NULL, 0);
720
721         sockfd = proto_descriptor(sock);
722         tcpfd = proto_descriptor(tcp);
723
724         safefd = MAX(sockfd, tcpfd);
725         safefd = MAX(safefd, startfd);
726         safefd++;
727
728         /* Move sockfd and tcpfd to safe numbers first. */
729         if (dup2(sockfd, safefd) == -1)
730                 pjdlog_exit(EX_OSERR, "dup2() failed");
731         proto_close(sock);
732         sockfd = safefd;
733         if (dup2(tcpfd, safefd + 1) == -1)
734                 pjdlog_exit(EX_OSERR, "dup2() failed");
735         proto_close(tcp);
736         tcpfd = safefd + 1;
737
738         /* Move socketpair descriptor to descriptor number startfd. */
739         if (dup2(sockfd, startfd) == -1)
740                 pjdlog_exit(EX_OSERR, "dup2() failed");
741         (void)close(sockfd);
742         /* Move tcp descriptor to descriptor number startfd + 1. */
743         if (dup2(tcpfd, startfd + 1) == -1)
744                 pjdlog_exit(EX_OSERR, "dup2() failed");
745         (void)close(tcpfd);
746
747         closefrom(startfd + 2);
748
749         /*
750          * Even if FD_CLOEXEC was set on descriptors before dup2(), it should
751          * have been cleared on dup2(), but better be safe than sorry.
752          */
753         if (fcntl(startfd, F_SETFD, 0) == -1)
754                 pjdlog_exit(EX_OSERR, "fcntl() failed");
755         if (fcntl(startfd + 1, F_SETFD, 0) == -1)
756                 pjdlog_exit(EX_OSERR, "fcntl() failed");
757
758         if (asprintf(&startfdstr, "%d", startfd) == -1)
759                 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
760         if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1)
761                 pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
762
763         execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls",
764             proto_get("user"), "server", startfdstr, proto_get("tls:keyfile"),
765             proto_get("tls:certfile"), debugstr, NULL);
766         pjdlog_exit(EX_SOFTWARE, "execl() failed");
767 }
768
769 static int
770 tls_accept(void *ctx, void **newctxp)
771 {
772         struct tls_ctx *tlsctx = ctx;
773         struct tls_ctx *newtlsctx;
774         struct proto_conn *sock, *tcp;
775         pid_t pid;
776         int error;
777
778         PJDLOG_ASSERT(tlsctx != NULL);
779         PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
780         PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_SERVER_LISTEN);
781
782         if (proto_connect(NULL, "socketpair://", -1, &sock) == -1)
783                 return (errno);
784
785         /* Accept TCP connection. */
786         if (proto_accept(tlsctx->tls_tcp, &tcp) == -1) {
787                 error = errno;
788                 proto_close(sock);
789                 return (error);
790         }
791
792         pid = fork();
793         switch (pid) {
794         case -1:
795                 /* Failure. */
796                 error = errno;
797                 proto_close(sock);
798                 return (error);
799         case 0:
800                 /* Child. */
801                 pjdlog_prefix_set("[TLS sandbox] (server) ");
802 #ifdef HAVE_SETPROCTITLE
803                 setproctitle("[TLS sandbox] (server) ");
804 #endif
805                 /* Close listen socket. */
806                 proto_close(tlsctx->tls_tcp);
807                 tls_call_exec_server(sock, tcp);
808                 /* NOTREACHED */
809                 PJDLOG_ABORT("Unreachable.");
810         default:
811                 /* Parent. */
812                 newtlsctx = calloc(1, sizeof(*tlsctx));
813                 if (newtlsctx == NULL) {
814                         error = errno;
815                         proto_close(sock);
816                         proto_close(tcp);
817                         (void)kill(pid, SIGKILL);
818                         return (error);
819                 }
820                 proto_local_address(tcp, newtlsctx->tls_laddr,
821                     sizeof(newtlsctx->tls_laddr));
822                 PJDLOG_ASSERT(strncmp(newtlsctx->tls_laddr, "tcp://", 6) == 0);
823                 bcopy("tls://", newtlsctx->tls_laddr, 6);
824                 *strrchr(newtlsctx->tls_laddr, ':') = '\0';
825                 proto_remote_address(tcp, newtlsctx->tls_raddr,
826                     sizeof(newtlsctx->tls_raddr));
827                 PJDLOG_ASSERT(strncmp(newtlsctx->tls_raddr, "tcp://", 6) == 0);
828                 bcopy("tls://", newtlsctx->tls_raddr, 6);
829                 *strrchr(newtlsctx->tls_raddr, ':') = '\0';
830                 proto_close(tcp);
831                 proto_recv(sock, NULL, 0);
832                 newtlsctx->tls_sock = sock;
833                 newtlsctx->tls_tcp = NULL;
834                 newtlsctx->tls_wait_called = true;
835                 newtlsctx->tls_side = TLS_SIDE_SERVER_WORK;
836                 newtlsctx->tls_magic = TLS_CTX_MAGIC;
837                 *newctxp = newtlsctx;
838                 return (0);
839         }
840 }
841
842 static int
843 tls_wrap(int fd, bool client, void **ctxp)
844 {
845         struct tls_ctx *tlsctx;
846         struct proto_conn *sock;
847         int error;
848
849         tlsctx = calloc(1, sizeof(*tlsctx));
850         if (tlsctx == NULL)
851                 return (errno);
852
853         if (proto_wrap("socketpair", client, fd, &sock) == -1) {
854                 error = errno;
855                 free(tlsctx);
856                 return (error);
857         }
858
859         tlsctx->tls_sock = sock;
860         tlsctx->tls_tcp = NULL;
861         tlsctx->tls_wait_called = (client ? false : true);
862         tlsctx->tls_side = (client ? TLS_SIDE_CLIENT : TLS_SIDE_SERVER_WORK);
863         tlsctx->tls_magic = TLS_CTX_MAGIC;
864         *ctxp = tlsctx;
865
866         return (0);
867 }
868
869 static int
870 tls_send(void *ctx, const unsigned char *data, size_t size, int fd)
871 {
872         struct tls_ctx *tlsctx = ctx;
873
874         PJDLOG_ASSERT(tlsctx != NULL);
875         PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
876         PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT ||
877             tlsctx->tls_side == TLS_SIDE_SERVER_WORK);
878         PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
879         PJDLOG_ASSERT(tlsctx->tls_wait_called);
880         PJDLOG_ASSERT(fd == -1);
881
882         if (proto_send(tlsctx->tls_sock, data, size) == -1)
883                 return (errno);
884
885         return (0);
886 }
887
888 static int
889 tls_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
890 {
891         struct tls_ctx *tlsctx = ctx;
892
893         PJDLOG_ASSERT(tlsctx != NULL);
894         PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
895         PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT ||
896             tlsctx->tls_side == TLS_SIDE_SERVER_WORK);
897         PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
898         PJDLOG_ASSERT(tlsctx->tls_wait_called);
899         PJDLOG_ASSERT(fdp == NULL);
900
901         if (proto_recv(tlsctx->tls_sock, data, size) == -1)
902                 return (errno);
903
904         return (0);
905 }
906
907 static int
908 tls_descriptor(const void *ctx)
909 {
910         const struct tls_ctx *tlsctx = ctx;
911
912         PJDLOG_ASSERT(tlsctx != NULL);
913         PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
914
915         switch (tlsctx->tls_side) {
916         case TLS_SIDE_CLIENT:
917         case TLS_SIDE_SERVER_WORK:
918                 PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
919
920                 return (proto_descriptor(tlsctx->tls_sock));
921         case TLS_SIDE_SERVER_LISTEN:
922                 PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
923
924                 return (proto_descriptor(tlsctx->tls_tcp));
925         default:
926                 PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
927         }
928 }
929
930 static bool
931 tcp_address_match(const void *ctx, const char *addr)
932 {
933         const struct tls_ctx *tlsctx = ctx;
934
935         PJDLOG_ASSERT(tlsctx != NULL);
936         PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
937
938         return (strcmp(tlsctx->tls_raddr, addr) == 0);
939 }
940
941 static void
942 tls_local_address(const void *ctx, char *addr, size_t size)
943 {
944         const struct tls_ctx *tlsctx = ctx;
945
946         PJDLOG_ASSERT(tlsctx != NULL);
947         PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
948         PJDLOG_ASSERT(tlsctx->tls_wait_called);
949
950         switch (tlsctx->tls_side) {
951         case TLS_SIDE_CLIENT:
952                 PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
953
954                 PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size);
955                 break;
956         case TLS_SIDE_SERVER_WORK:
957                 PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
958
959                 PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_laddr, size) < size);
960                 break;
961         case TLS_SIDE_SERVER_LISTEN:
962                 PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
963
964                 proto_local_address(tlsctx->tls_tcp, addr, size);
965                 PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0);
966                 /* Replace tcp:// prefix with tls:// */
967                 bcopy("tls://", addr, 6);
968                 break;
969         default:
970                 PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
971         }
972 }
973
974 static void
975 tls_remote_address(const void *ctx, char *addr, size_t size)
976 {
977         const struct tls_ctx *tlsctx = ctx;
978
979         PJDLOG_ASSERT(tlsctx != NULL);
980         PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
981         PJDLOG_ASSERT(tlsctx->tls_wait_called);
982
983         switch (tlsctx->tls_side) {
984         case TLS_SIDE_CLIENT:
985                 PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
986
987                 PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size);
988                 break;
989         case TLS_SIDE_SERVER_WORK:
990                 PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
991
992                 PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_raddr, size) < size);
993                 break;
994         case TLS_SIDE_SERVER_LISTEN:
995                 PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
996
997                 proto_remote_address(tlsctx->tls_tcp, addr, size);
998                 PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0);
999                 /* Replace tcp:// prefix with tls:// */
1000                 bcopy("tls://", addr, 6);
1001                 break;
1002         default:
1003                 PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
1004         }
1005 }
1006
1007 static void
1008 tls_close(void *ctx)
1009 {
1010         struct tls_ctx *tlsctx = ctx;
1011
1012         PJDLOG_ASSERT(tlsctx != NULL);
1013         PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
1014
1015         if (tlsctx->tls_sock != NULL) {
1016                 proto_close(tlsctx->tls_sock);
1017                 tlsctx->tls_sock = NULL;
1018         }
1019         if (tlsctx->tls_tcp != NULL) {
1020                 proto_close(tlsctx->tls_tcp);
1021                 tlsctx->tls_tcp = NULL;
1022         }
1023         tlsctx->tls_side = 0;
1024         tlsctx->tls_magic = 0;
1025         free(tlsctx);
1026 }
1027
1028 static int
1029 tls_exec(int argc, char *argv[])
1030 {
1031
1032         PJDLOG_ASSERT(argc > 3);
1033         PJDLOG_ASSERT(strcmp(argv[0], "tls") == 0);
1034
1035         pjdlog_init(atoi(argv[3]) == 0 ? PJDLOG_MODE_SYSLOG : PJDLOG_MODE_STD);
1036
1037         if (strcmp(argv[2], "client") == 0) {
1038                 if (argc != 10)
1039                         return (EINVAL);
1040                 tls_exec_client(argv[1], atoi(argv[3]),
1041                     argv[4][0] == '\0' ? NULL : argv[4], argv[5], argv[6],
1042                     argv[7], atoi(argv[8]), atoi(argv[9]));
1043         } else if (strcmp(argv[2], "server") == 0) {
1044                 if (argc != 7)
1045                         return (EINVAL);
1046                 tls_exec_server(argv[1], atoi(argv[3]), argv[4], argv[5],
1047                     atoi(argv[6]));
1048         }
1049         return (EINVAL);
1050 }
1051
1052 static struct proto tls_proto = {
1053         .prt_name = "tls",
1054         .prt_connect = tls_connect,
1055         .prt_connect_wait = tls_connect_wait,
1056         .prt_server = tls_server,
1057         .prt_accept = tls_accept,
1058         .prt_wrap = tls_wrap,
1059         .prt_send = tls_send,
1060         .prt_recv = tls_recv,
1061         .prt_descriptor = tls_descriptor,
1062         .prt_address_match = tcp_address_match,
1063         .prt_local_address = tls_local_address,
1064         .prt_remote_address = tls_remote_address,
1065         .prt_close = tls_close,
1066         .prt_exec = tls_exec
1067 };
1068
1069 static __constructor void
1070 tls_ctor(void)
1071 {
1072
1073         proto_register(&tls_proto, false);
1074 }