]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - contrib/ofed/librdmacm/preload.c
MFV r362143:
[FreeBSD/FreeBSD.git] / contrib / ofed / librdmacm / preload.c
1 /*
2  * Copyright (c) 2011-2012 Intel Corporation.  All rights reserved.
3  *
4  * This software is available to you under a choice of one of two
5  * licenses.  You may choose to be licensed under the terms of the GNU
6  * General Public License (GPL) Version 2, available from the file
7  * COPYING in the main directory of this source tree, or the
8  * OpenIB.org BSD license below:
9  *
10  *     Redistribution and use in source and binary forms, with or
11  *     without modification, are permitted provided that the following
12  *     conditions are met:
13  *
14  *      - Redistributions of source code must retain the above
15  *        copyright notice, this list of conditions and the following
16  *        disclaimer.
17  *
18  *      - Redistributions in binary form must reproduce the above
19  *        copyright notice, this list of conditions and the following
20  *        disclaimer in the documentation and/or other materials
21  *        provided with the distribution.
22  *
23  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
26  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
27  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
28  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
29  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30  * SOFTWARE.
31  *
32  */
33 #define _GNU_SOURCE
34 #include <config.h>
35
36 #include <sys/types.h>
37 #include <sys/socket.h>
38 #include <sys/uio.h>
39 #include <sys/stat.h>
40 #include <sys/mman.h>
41 #include <stdarg.h>
42 #include <dlfcn.h>
43 #include <netdb.h>
44 #include <unistd.h>
45 #include <fcntl.h>
46 #include <string.h>
47 #include <netinet/tcp.h>
48 #include <unistd.h>
49 #include <semaphore.h>
50 #include <ctype.h>
51 #include <stdlib.h>
52 #include <stdio.h>
53
54 #include <rdma/rdma_cma.h>
55 #include <rdma/rdma_verbs.h>
56 #include <rdma/rsocket.h>
57 #include "cma.h"
58 #include "indexer.h"
59
60 struct socket_calls {
61         int (*socket)(int domain, int type, int protocol);
62         int (*bind)(int socket, const struct sockaddr *addr, socklen_t addrlen);
63         int (*listen)(int socket, int backlog);
64         int (*accept)(int socket, struct sockaddr *addr, socklen_t *addrlen);
65         int (*connect)(int socket, const struct sockaddr *addr, socklen_t addrlen);
66         ssize_t (*recv)(int socket, void *buf, size_t len, int flags);
67         ssize_t (*recvfrom)(int socket, void *buf, size_t len, int flags,
68                             struct sockaddr *src_addr, socklen_t *addrlen);
69         ssize_t (*recvmsg)(int socket, struct msghdr *msg, int flags);
70         ssize_t (*read)(int socket, void *buf, size_t count);
71         ssize_t (*readv)(int socket, const struct iovec *iov, int iovcnt);
72         ssize_t (*send)(int socket, const void *buf, size_t len, int flags);
73         ssize_t (*sendto)(int socket, const void *buf, size_t len, int flags,
74                           const struct sockaddr *dest_addr, socklen_t addrlen);
75         ssize_t (*sendmsg)(int socket, const struct msghdr *msg, int flags);
76         ssize_t (*write)(int socket, const void *buf, size_t count);
77         ssize_t (*writev)(int socket, const struct iovec *iov, int iovcnt);
78         int (*poll)(struct pollfd *fds, nfds_t nfds, int timeout);
79         int (*shutdown)(int socket, int how);
80         int (*close)(int socket);
81         int (*getpeername)(int socket, struct sockaddr *addr, socklen_t *addrlen);
82         int (*getsockname)(int socket, struct sockaddr *addr, socklen_t *addrlen);
83         int (*setsockopt)(int socket, int level, int optname,
84                           const void *optval, socklen_t optlen);
85         int (*getsockopt)(int socket, int level, int optname,
86                           void *optval, socklen_t *optlen);
87         int (*fcntl)(int socket, int cmd, ... /* arg */);
88         int (*dup2)(int oldfd, int newfd);
89         ssize_t (*sendfile)(int out_fd, int in_fd, off_t *offset, size_t count);
90         int (*fxstat)(int ver, int fd, struct stat *buf);
91 };
92
93 static struct socket_calls real;
94 static struct socket_calls rs;
95
96 static struct index_map idm;
97 static pthread_mutex_t mut = PTHREAD_MUTEX_INITIALIZER;
98
99 static int sq_size;
100 static int rq_size;
101 static int sq_inline;
102 static int fork_support;
103
104 enum fd_type {
105         fd_normal,
106         fd_rsocket
107 };
108
109 enum fd_fork_state {
110         fd_ready,
111         fd_fork,
112         fd_fork_listen,
113         fd_fork_active,
114         fd_fork_passive
115 };
116
117 struct fd_info {
118         enum fd_type type;
119         enum fd_fork_state state;
120         int fd;
121         int dupfd;
122         _Atomic(int) refcnt;
123 };
124
125 struct config_entry {
126         char *name;
127         int domain;
128         int type;
129         int protocol;
130 };
131
132 static struct config_entry *config;
133 static int config_cnt;
134
135 static void free_config(void)
136 {
137         while (config_cnt)
138                 free(config[--config_cnt].name);
139
140         free(config);
141 }
142
143 /*
144  * Config file format:
145  * # Starting '#' indicates comment
146  * # wild card values are supported using '*'
147  * # domain - *, INET, INET6, IB
148  * # type - *, STREAM, DGRAM
149  * # protocol - *, TCP, UDP
150  * program_name domain type protocol
151  */
152 static void scan_config(void)
153 {
154         struct config_entry *new_config;
155         FILE *fp;
156         char line[120], prog[64], dom[16], type[16], proto[16];
157
158         fp = fopen(RS_CONF_DIR "/preload_config", "r");
159         if (!fp)
160                 return;
161
162         while (fgets(line, sizeof(line), fp)) {
163                 if (line[0] == '#')
164                         continue;
165
166                 if (sscanf(line, "%64s%16s%16s%16s", prog, dom, type, proto) != 4)
167                         continue;
168
169                 new_config = realloc(config, (config_cnt + 1) *
170                                              sizeof(struct config_entry));
171                 if (!new_config)
172                         break;
173
174                 config = new_config;
175                 memset(&config[config_cnt], 0, sizeof(struct config_entry));
176
177                 if (!strcasecmp(dom, "INET") ||
178                     !strcasecmp(dom, "AF_INET") ||
179                     !strcasecmp(dom, "PF_INET")) {
180                         config[config_cnt].domain = AF_INET;
181                 } else if (!strcasecmp(dom, "INET6") ||
182                            !strcasecmp(dom, "AF_INET6") ||
183                            !strcasecmp(dom, "PF_INET6")) {
184                         config[config_cnt].domain = AF_INET6;
185                 } else if (!strcasecmp(dom, "IB") ||
186                            !strcasecmp(dom, "AF_IB") ||
187                            !strcasecmp(dom, "PF_IB")) {
188                         config[config_cnt].domain = AF_IB;
189                 } else if (strcmp(dom, "*")) {
190                         continue;
191                 }
192
193                 if (!strcasecmp(type, "STREAM") ||
194                     !strcasecmp(type, "SOCK_STREAM")) {
195                         config[config_cnt].type = SOCK_STREAM;
196                 } else if (!strcasecmp(type, "DGRAM") ||
197                            !strcasecmp(type, "SOCK_DGRAM")) {
198                         config[config_cnt].type = SOCK_DGRAM;
199                 } else if (strcmp(type, "*")) {
200                         continue;
201                 }
202
203                 if (!strcasecmp(proto, "TCP") ||
204                     !strcasecmp(proto, "IPPROTO_TCP")) {
205                         config[config_cnt].protocol = IPPROTO_TCP;
206                 } else if (!strcasecmp(proto, "UDP") ||
207                            !strcasecmp(proto, "IPPROTO_UDP")) {
208                         config[config_cnt].protocol = IPPROTO_UDP;
209                 } else if (strcmp(proto, "*")) {
210                         continue;
211                 }
212
213                 if (strcmp(prog, "*")) {
214                     if (!(config[config_cnt].name = strdup(prog)))
215                             continue;
216                 }
217
218                 config_cnt++;
219         }
220
221         fclose(fp);
222         if (config_cnt)
223                 atexit(free_config);
224 }
225
226 static int intercept_socket(int domain, int type, int protocol)
227 {
228         int i;
229
230         if (!config_cnt)
231                 return 1;
232
233         if (!protocol) {
234                 if (type == SOCK_STREAM)
235                         protocol = IPPROTO_TCP;
236                 else if (type == SOCK_DGRAM)
237                         protocol = IPPROTO_UDP;
238         }
239
240         for (i = 0; i < config_cnt; i++) {
241                 if ((!config[i].name ||
242                      !strncasecmp(config[i].name, program_invocation_short_name,
243                                   strlen(config[i].name))) &&
244                     (!config[i].domain || config[i].domain == domain) &&
245                     (!config[i].type || config[i].type == type) &&
246                     (!config[i].protocol || config[i].protocol == protocol))
247                         return 1;
248         }
249
250         return 0;
251 }
252
253 static int fd_open(void)
254 {
255         struct fd_info *fdi;
256         int ret, index;
257
258         fdi = calloc(1, sizeof(*fdi));
259         if (!fdi)
260                 return ERR(ENOMEM);
261
262         index = open("/dev/null", O_RDONLY);
263         if (index < 0) {
264                 ret = index;
265                 goto err1;
266         }
267
268         fdi->dupfd = -1;
269         atomic_store(&fdi->refcnt, 1);
270         pthread_mutex_lock(&mut);
271         ret = idm_set(&idm, index, fdi);
272         pthread_mutex_unlock(&mut);
273         if (ret < 0)
274                 goto err2;
275
276         return index;
277
278 err2:
279         real.close(index);
280 err1:
281         free(fdi);
282         return ret;
283 }
284
285 static void fd_store(int index, int fd, enum fd_type type, enum fd_fork_state state)
286 {
287         struct fd_info *fdi;
288
289         fdi = idm_at(&idm, index);
290         fdi->fd = fd;
291         fdi->type = type;
292         fdi->state = state;
293 }
294
295 static inline enum fd_type fd_get(int index, int *fd)
296 {
297         struct fd_info *fdi;
298
299         fdi = idm_lookup(&idm, index);
300         if (fdi) {
301                 *fd = fdi->fd;
302                 return fdi->type;
303
304         } else {
305                 *fd = index;
306                 return fd_normal;
307         }
308 }
309
310 static inline int fd_getd(int index)
311 {
312         struct fd_info *fdi;
313
314         fdi = idm_lookup(&idm, index);
315         return fdi ? fdi->fd : index;
316 }
317
318 static inline enum fd_fork_state fd_gets(int index)
319 {
320         struct fd_info *fdi;
321
322         fdi = idm_lookup(&idm, index);
323         return fdi ? fdi->state : fd_ready;
324 }
325
326 static inline enum fd_type fd_gett(int index)
327 {
328         struct fd_info *fdi;
329
330         fdi = idm_lookup(&idm, index);
331         return fdi ? fdi->type : fd_normal;
332 }
333
334 static enum fd_type fd_close(int index, int *fd)
335 {
336         struct fd_info *fdi;
337         enum fd_type type;
338
339         fdi = idm_lookup(&idm, index);
340         if (fdi) {
341                 idm_clear(&idm, index);
342                 *fd = fdi->fd;
343                 type = fdi->type;
344                 real.close(index);
345                 free(fdi);
346         } else {
347                 *fd = index;
348                 type = fd_normal;
349         }
350         return type;
351 }
352
353 static void getenv_options(void)
354 {
355         char *var;
356
357         var = getenv("RS_SQ_SIZE");
358         if (var)
359                 sq_size = atoi(var);
360
361         var = getenv("RS_RQ_SIZE");
362         if (var)
363                 rq_size = atoi(var);
364
365         var = getenv("RS_INLINE");
366         if (var)
367                 sq_inline = atoi(var);
368
369         var = getenv("RDMAV_FORK_SAFE");
370         if (var)
371                 fork_support = atoi(var);
372 }
373
374 static void init_preload(void)
375 {
376         static int init;
377
378         /* Quick check without lock */
379         if (init)
380                 return;
381
382         pthread_mutex_lock(&mut);
383         if (init)
384                 goto out;
385
386         real.socket = dlsym(RTLD_NEXT, "socket");
387         real.bind = dlsym(RTLD_NEXT, "bind");
388         real.listen = dlsym(RTLD_NEXT, "listen");
389         real.accept = dlsym(RTLD_NEXT, "accept");
390         real.connect = dlsym(RTLD_NEXT, "connect");
391         real.recv = dlsym(RTLD_NEXT, "recv");
392         real.recvfrom = dlsym(RTLD_NEXT, "recvfrom");
393         real.recvmsg = dlsym(RTLD_NEXT, "recvmsg");
394         real.read = dlsym(RTLD_NEXT, "read");
395         real.readv = dlsym(RTLD_NEXT, "readv");
396         real.send = dlsym(RTLD_NEXT, "send");
397         real.sendto = dlsym(RTLD_NEXT, "sendto");
398         real.sendmsg = dlsym(RTLD_NEXT, "sendmsg");
399         real.write = dlsym(RTLD_NEXT, "write");
400         real.writev = dlsym(RTLD_NEXT, "writev");
401         real.poll = dlsym(RTLD_NEXT, "poll");
402         real.shutdown = dlsym(RTLD_NEXT, "shutdown");
403         real.close = dlsym(RTLD_NEXT, "close");
404         real.getpeername = dlsym(RTLD_NEXT, "getpeername");
405         real.getsockname = dlsym(RTLD_NEXT, "getsockname");
406         real.setsockopt = dlsym(RTLD_NEXT, "setsockopt");
407         real.getsockopt = dlsym(RTLD_NEXT, "getsockopt");
408         real.fcntl = dlsym(RTLD_NEXT, "fcntl");
409         real.dup2 = dlsym(RTLD_NEXT, "dup2");
410         real.sendfile = dlsym(RTLD_NEXT, "sendfile");
411         real.fxstat = dlsym(RTLD_NEXT, "__fxstat");
412
413         rs.socket = dlsym(RTLD_DEFAULT, "rsocket");
414         rs.bind = dlsym(RTLD_DEFAULT, "rbind");
415         rs.listen = dlsym(RTLD_DEFAULT, "rlisten");
416         rs.accept = dlsym(RTLD_DEFAULT, "raccept");
417         rs.connect = dlsym(RTLD_DEFAULT, "rconnect");
418         rs.recv = dlsym(RTLD_DEFAULT, "rrecv");
419         rs.recvfrom = dlsym(RTLD_DEFAULT, "rrecvfrom");
420         rs.recvmsg = dlsym(RTLD_DEFAULT, "rrecvmsg");
421         rs.read = dlsym(RTLD_DEFAULT, "rread");
422         rs.readv = dlsym(RTLD_DEFAULT, "rreadv");
423         rs.send = dlsym(RTLD_DEFAULT, "rsend");
424         rs.sendto = dlsym(RTLD_DEFAULT, "rsendto");
425         rs.sendmsg = dlsym(RTLD_DEFAULT, "rsendmsg");
426         rs.write = dlsym(RTLD_DEFAULT, "rwrite");
427         rs.writev = dlsym(RTLD_DEFAULT, "rwritev");
428         rs.poll = dlsym(RTLD_DEFAULT, "rpoll");
429         rs.shutdown = dlsym(RTLD_DEFAULT, "rshutdown");
430         rs.close = dlsym(RTLD_DEFAULT, "rclose");
431         rs.getpeername = dlsym(RTLD_DEFAULT, "rgetpeername");
432         rs.getsockname = dlsym(RTLD_DEFAULT, "rgetsockname");
433         rs.setsockopt = dlsym(RTLD_DEFAULT, "rsetsockopt");
434         rs.getsockopt = dlsym(RTLD_DEFAULT, "rgetsockopt");
435         rs.fcntl = dlsym(RTLD_DEFAULT, "rfcntl");
436
437         getenv_options();
438         scan_config();
439         init = 1;
440 out:
441         pthread_mutex_unlock(&mut);
442 }
443
444 /*
445  * We currently only handle copying a few common values.
446  */
447 static int copysockopts(int dfd, int sfd, struct socket_calls *dapi,
448                         struct socket_calls *sapi)
449 {
450         socklen_t len;
451         int param, ret;
452
453         ret = sapi->fcntl(sfd, F_GETFL);
454         if (ret > 0)
455                 ret = dapi->fcntl(dfd, F_SETFL, ret);
456         if (ret)
457                 return ret;
458
459         len = sizeof param;
460         ret = sapi->getsockopt(sfd, SOL_SOCKET, SO_REUSEADDR, &param, &len);
461         if (param && !ret)
462                 ret = dapi->setsockopt(dfd, SOL_SOCKET, SO_REUSEADDR, &param, len);
463         if (ret)
464                 return ret;
465
466         len = sizeof param;
467         ret = sapi->getsockopt(sfd, IPPROTO_TCP, TCP_NODELAY, &param, &len);
468         if (param && !ret)
469                 ret = dapi->setsockopt(dfd, IPPROTO_TCP, TCP_NODELAY, &param, len);
470         if (ret)
471                 return ret;
472
473         return 0;
474 }
475
476 /*
477  * Convert between an rsocket and a normal socket.
478  */
479 static int transpose_socket(int socket, enum fd_type new_type)
480 {
481         socklen_t len = 0;
482         int sfd, dfd, param, ret;
483         struct socket_calls *sapi, *dapi;
484
485         sfd = fd_getd(socket);
486         if (new_type == fd_rsocket) {
487                 dapi = &rs;
488                 sapi = &real;
489         } else {
490                 dapi = &real;
491                 sapi = &rs;
492         }
493
494         ret = sapi->getsockname(sfd, NULL, &len);
495         if (ret)
496                 return ret;
497
498         param = (len == sizeof(struct sockaddr_in6)) ? PF_INET6 : PF_INET;
499         dfd = dapi->socket(param, SOCK_STREAM, 0);
500         if (dfd < 0)
501                 return dfd;
502
503         ret = copysockopts(dfd, sfd, dapi, sapi);
504         if (ret)
505                 goto err;
506
507         fd_store(socket, dfd, new_type, fd_ready);
508         return dfd;
509
510 err:
511         dapi->close(dfd);
512         return ret;
513 }
514
515 /*
516  * Use defaults on failure.
517  */
518 static void set_rsocket_options(int rsocket)
519 {
520         if (sq_size)
521                 rsetsockopt(rsocket, SOL_RDMA, RDMA_SQSIZE, &sq_size, sizeof sq_size);
522
523         if (rq_size)
524                 rsetsockopt(rsocket, SOL_RDMA, RDMA_RQSIZE, &rq_size, sizeof rq_size);
525
526         if (sq_inline)
527                 rsetsockopt(rsocket, SOL_RDMA, RDMA_INLINE, &sq_inline, sizeof sq_inline);
528 }
529
530 int socket(int domain, int type, int protocol)
531 {
532         static __thread int recursive;
533         int index, ret;
534
535         init_preload();
536
537         if (recursive || !intercept_socket(domain, type, protocol))
538                 goto real;
539
540         index = fd_open();
541         if (index < 0)
542                 return index;
543
544         if (fork_support && (domain == PF_INET || domain == PF_INET6) &&
545             (type == SOCK_STREAM) && (!protocol || protocol == IPPROTO_TCP)) {
546                 ret = real.socket(domain, type, protocol);
547                 if (ret < 0)
548                         return ret;
549                 fd_store(index, ret, fd_normal, fd_fork);
550                 return index;
551         }
552
553         recursive = 1;
554         ret = rsocket(domain, type, protocol);
555         recursive = 0;
556         if (ret >= 0) {
557                 fd_store(index, ret, fd_rsocket, fd_ready);
558                 set_rsocket_options(ret);
559                 return index;
560         }
561         fd_close(index, &ret);
562 real:
563         return real.socket(domain, type, protocol);
564 }
565
566 int bind(int socket, const struct sockaddr *addr, socklen_t addrlen)
567 {
568         int fd;
569         return (fd_get(socket, &fd) == fd_rsocket) ?
570                 rbind(fd, addr, addrlen) : real.bind(fd, addr, addrlen);
571 }
572
573 int listen(int socket, int backlog)
574 {
575         int fd, ret;
576         if (fd_get(socket, &fd) == fd_rsocket) {
577                 ret = rlisten(fd, backlog);
578         } else {
579                 ret = real.listen(fd, backlog);
580                 if (!ret && fd_gets(socket) == fd_fork)
581                         fd_store(socket, fd, fd_normal, fd_fork_listen);
582         }
583         return ret;
584 }
585
586 int accept(int socket, struct sockaddr *addr, socklen_t *addrlen)
587 {
588         int fd, index, ret;
589
590         if (fd_get(socket, &fd) == fd_rsocket) {
591                 index = fd_open();
592                 if (index < 0)
593                         return index;
594
595                 ret = raccept(fd, addr, addrlen);
596                 if (ret < 0) {
597                         fd_close(index, &fd);
598                         return ret;
599                 }
600
601                 fd_store(index, ret, fd_rsocket, fd_ready);
602                 return index;
603         } else if (fd_gets(socket) == fd_fork_listen) {
604                 index = fd_open();
605                 if (index < 0)
606                         return index;
607
608                 ret = real.accept(fd, addr, addrlen);
609                 if (ret < 0) {
610                         fd_close(index, &fd);
611                         return ret;
612                 }
613
614                 fd_store(index, ret, fd_normal, fd_fork_passive);
615                 return index;
616         } else {
617                 return real.accept(fd, addr, addrlen);
618         }
619 }
620
621 /*
622  * We can't fork RDMA connections and pass them from the parent to the child
623  * process.  Instead, we need to establish the RDMA connection after calling
624  * fork.  To do this, we delay establishing the RDMA connection until we try
625  * to send/receive on the server side.
626  */
627 static void fork_active(int socket)
628 {
629         struct sockaddr_storage addr;
630         int sfd, dfd, ret;
631         socklen_t len;
632         uint32_t msg;
633         long flags;
634
635         sfd = fd_getd(socket);
636
637         flags = real.fcntl(sfd, F_GETFL);
638         real.fcntl(sfd, F_SETFL, 0);
639         ret = real.recv(sfd, &msg, sizeof msg, MSG_PEEK);
640         real.fcntl(sfd, F_SETFL, flags);
641         if ((ret != sizeof msg) || msg)
642                 goto err1;
643
644         len = sizeof addr;
645         ret = real.getpeername(sfd, (struct sockaddr *) &addr, &len);
646         if (ret)
647                 goto err1;
648
649         dfd = rsocket(addr.ss_family, SOCK_STREAM, 0);
650         if (dfd < 0)
651                 goto err1;
652
653         ret = rconnect(dfd, (struct sockaddr *) &addr, len);
654         if (ret)
655                 goto err2;
656
657         set_rsocket_options(dfd);
658         copysockopts(dfd, sfd, &rs, &real);
659         real.shutdown(sfd, SHUT_RDWR);
660         real.close(sfd);
661         fd_store(socket, dfd, fd_rsocket, fd_ready);
662         return;
663
664 err2:
665         rclose(dfd);
666 err1:
667         fd_store(socket, sfd, fd_normal, fd_ready);
668 }
669
670 /*
671  * The server will start listening for the new connection, then send a
672  * message to the active side when the listen is ready.  This does leave
673  * fork unsupported in the following case: the server is nonblocking and
674  * calls select/poll waiting to receive data from the client.
675  */
676 static void fork_passive(int socket)
677 {
678         struct sockaddr_in6 sin6;
679         sem_t *sem;
680         int lfd, sfd, dfd, ret, param;
681         socklen_t len;
682         uint32_t msg;
683
684         sfd = fd_getd(socket);
685
686         len = sizeof sin6;
687         ret = real.getsockname(sfd, (struct sockaddr *) &sin6, &len);
688         if (ret)
689                 goto out;
690         sin6.sin6_flowinfo = 0;
691         sin6.sin6_scope_id = 0;
692         memset(&sin6.sin6_addr, 0, sizeof sin6.sin6_addr);
693
694         sem = sem_open("/rsocket_fork", O_CREAT | O_RDWR,
695                        S_IRWXU | S_IRWXG, 1);
696         if (sem == SEM_FAILED) {
697                 ret = -1;
698                 goto out;
699         }
700
701         lfd = rsocket(sin6.sin6_family, SOCK_STREAM, 0);
702         if (lfd < 0) {
703                 ret = lfd;
704                 goto sclose;
705         }
706
707         param = 1;
708         rsetsockopt(lfd, SOL_SOCKET, SO_REUSEADDR, &param, sizeof param);
709
710         sem_wait(sem);
711         ret = rbind(lfd, (struct sockaddr *) &sin6, sizeof sin6);
712         if (ret)
713                 goto lclose;
714
715         ret = rlisten(lfd, 1);
716         if (ret)
717                 goto lclose;
718
719         msg = 0;
720         len = real.write(sfd, &msg, sizeof msg);
721         if (len != sizeof msg)
722                 goto lclose;
723
724         dfd = raccept(lfd, NULL, NULL);
725         if (dfd < 0) {
726                 ret  = dfd;
727                 goto lclose;
728         }
729
730         set_rsocket_options(dfd);
731         copysockopts(dfd, sfd, &rs, &real);
732         real.shutdown(sfd, SHUT_RDWR);
733         real.close(sfd);
734         fd_store(socket, dfd, fd_rsocket, fd_ready);
735
736 lclose:
737         rclose(lfd);
738         sem_post(sem);
739 sclose:
740         sem_close(sem);
741 out:
742         if (ret)
743                 fd_store(socket, sfd, fd_normal, fd_ready);
744 }
745
746 static inline enum fd_type fd_fork_get(int index, int *fd)
747 {
748         struct fd_info *fdi;
749
750         fdi = idm_lookup(&idm, index);
751         if (fdi) {
752                 if (fdi->state == fd_fork_passive)
753                         fork_passive(index);
754                 else if (fdi->state == fd_fork_active)
755                         fork_active(index);
756                 *fd = fdi->fd;
757                 return fdi->type;
758
759         } else {
760                 *fd = index;
761                 return fd_normal;
762         }
763 }
764
765 int connect(int socket, const struct sockaddr *addr, socklen_t addrlen)
766 {
767         int fd, ret;
768
769         if (fd_get(socket, &fd) == fd_rsocket) {
770                 ret = rconnect(fd, addr, addrlen);
771                 if (!ret || errno == EINPROGRESS)
772                         return ret;
773
774                 ret = transpose_socket(socket, fd_normal);
775                 if (ret < 0)
776                         return ret;
777
778                 rclose(fd);
779                 fd = ret;
780         } else if (fd_gets(socket) == fd_fork) {
781                 fd_store(socket, fd, fd_normal, fd_fork_active);
782         }
783
784         return real.connect(fd, addr, addrlen);
785 }
786
787 ssize_t recv(int socket, void *buf, size_t len, int flags)
788 {
789         int fd;
790         return (fd_fork_get(socket, &fd) == fd_rsocket) ?
791                 rrecv(fd, buf, len, flags) : real.recv(fd, buf, len, flags);
792 }
793
794 ssize_t recvfrom(int socket, void *buf, size_t len, int flags,
795                  struct sockaddr *src_addr, socklen_t *addrlen)
796 {
797         int fd;
798         return (fd_fork_get(socket, &fd) == fd_rsocket) ?
799                 rrecvfrom(fd, buf, len, flags, src_addr, addrlen) :
800                 real.recvfrom(fd, buf, len, flags, src_addr, addrlen);
801 }
802
803 ssize_t recvmsg(int socket, struct msghdr *msg, int flags)
804 {
805         int fd;
806         return (fd_fork_get(socket, &fd) == fd_rsocket) ?
807                 rrecvmsg(fd, msg, flags) : real.recvmsg(fd, msg, flags);
808 }
809
810 ssize_t read(int socket, void *buf, size_t count)
811 {
812         int fd;
813         init_preload();
814         return (fd_fork_get(socket, &fd) == fd_rsocket) ?
815                 rread(fd, buf, count) : real.read(fd, buf, count);
816 }
817
818 ssize_t readv(int socket, const struct iovec *iov, int iovcnt)
819 {
820         int fd;
821         init_preload();
822         return (fd_fork_get(socket, &fd) == fd_rsocket) ?
823                 rreadv(fd, iov, iovcnt) : real.readv(fd, iov, iovcnt);
824 }
825
826 ssize_t send(int socket, const void *buf, size_t len, int flags)
827 {
828         int fd;
829         return (fd_fork_get(socket, &fd) == fd_rsocket) ?
830                 rsend(fd, buf, len, flags) : real.send(fd, buf, len, flags);
831 }
832
833 ssize_t sendto(int socket, const void *buf, size_t len, int flags,
834                 const struct sockaddr *dest_addr, socklen_t addrlen)
835 {
836         int fd;
837         return (fd_fork_get(socket, &fd) == fd_rsocket) ?
838                 rsendto(fd, buf, len, flags, dest_addr, addrlen) :
839                 real.sendto(fd, buf, len, flags, dest_addr, addrlen);
840 }
841
842 ssize_t sendmsg(int socket, const struct msghdr *msg, int flags)
843 {
844         int fd;
845         return (fd_fork_get(socket, &fd) == fd_rsocket) ?
846                 rsendmsg(fd, msg, flags) : real.sendmsg(fd, msg, flags);
847 }
848
849 ssize_t write(int socket, const void *buf, size_t count)
850 {
851         int fd;
852         init_preload();
853         return (fd_fork_get(socket, &fd) == fd_rsocket) ?
854                 rwrite(fd, buf, count) : real.write(fd, buf, count);
855 }
856
857 ssize_t writev(int socket, const struct iovec *iov, int iovcnt)
858 {
859         int fd;
860         init_preload();
861         return (fd_fork_get(socket, &fd) == fd_rsocket) ?
862                 rwritev(fd, iov, iovcnt) : real.writev(fd, iov, iovcnt);
863 }
864
865 static struct pollfd *fds_alloc(nfds_t nfds)
866 {
867         static __thread struct pollfd *rfds;
868         static __thread nfds_t rnfds;
869
870         if (nfds > rnfds) {
871                 if (rfds)
872                         free(rfds);
873
874                 rfds = malloc(sizeof(*rfds) * nfds);
875                 rnfds = rfds ? nfds : 0;
876         }
877
878         return rfds;
879 }
880
881 int poll(struct pollfd *fds, nfds_t nfds, int timeout)
882 {
883         struct pollfd *rfds;
884         int i, ret;
885
886         init_preload();
887         for (i = 0; i < nfds; i++) {
888                 if (fd_gett(fds[i].fd) == fd_rsocket)
889                         goto use_rpoll;
890         }
891
892         return real.poll(fds, nfds, timeout);
893
894 use_rpoll:
895         rfds = fds_alloc(nfds);
896         if (!rfds)
897                 return ERR(ENOMEM);
898
899         for (i = 0; i < nfds; i++) {
900                 rfds[i].fd = fd_getd(fds[i].fd);
901                 rfds[i].events = fds[i].events;
902                 rfds[i].revents = 0;
903         }
904
905         ret = rpoll(rfds, nfds, timeout);
906
907         for (i = 0; i < nfds; i++)
908                 fds[i].revents = rfds[i].revents;
909
910         return ret;
911 }
912
913 static void select_to_rpoll(struct pollfd *fds, int *nfds,
914                             fd_set *readfds, fd_set *writefds, fd_set *exceptfds)
915 {
916         int fd, events, i = 0;
917
918         for (fd = 0; fd < *nfds; fd++) {
919                 events = (readfds && FD_ISSET(fd, readfds)) ? POLLIN : 0;
920                 if (writefds && FD_ISSET(fd, writefds))
921                         events |= POLLOUT;
922
923                 if (events || (exceptfds && FD_ISSET(fd, exceptfds))) {
924                         fds[i].fd = fd_getd(fd);
925                         fds[i++].events = events;
926                 }
927         }
928
929         *nfds = i;
930 }
931
932 static int rpoll_to_select(struct pollfd *fds, int nfds,
933                            fd_set *readfds, fd_set *writefds, fd_set *exceptfds)
934 {
935         int fd, rfd, i, cnt = 0;
936
937         for (i = 0, fd = 0; i < nfds; fd++) {
938                 rfd = fd_getd(fd);
939                 if (rfd != fds[i].fd)
940                         continue;
941
942                 if (readfds && (fds[i].revents & POLLIN)) {
943                         FD_SET(fd, readfds);
944                         cnt++;
945                 }
946
947                 if (writefds && (fds[i].revents & POLLOUT)) {
948                         FD_SET(fd, writefds);
949                         cnt++;
950                 }
951
952                 if (exceptfds && (fds[i].revents & ~(POLLIN | POLLOUT))) {
953                         FD_SET(fd, exceptfds);
954                         cnt++;
955                 }
956                 i++;
957         }
958
959         return cnt;
960 }
961
962 static int rs_convert_timeout(struct timeval *timeout)
963 {
964         return !timeout ? -1 : timeout->tv_sec * 1000 + timeout->tv_usec / 1000;
965 }
966
967 int select(int nfds, fd_set *readfds, fd_set *writefds,
968            fd_set *exceptfds, struct timeval *timeout)
969 {
970         struct pollfd *fds;
971         int ret;
972
973         fds = fds_alloc(nfds);
974         if (!fds)
975                 return ERR(ENOMEM);
976
977         select_to_rpoll(fds, &nfds, readfds, writefds, exceptfds);
978         ret = rpoll(fds, nfds, rs_convert_timeout(timeout));
979
980         if (readfds)
981                 FD_ZERO(readfds);
982         if (writefds)
983                 FD_ZERO(writefds);
984         if (exceptfds)
985                 FD_ZERO(exceptfds);
986
987         if (ret > 0)
988                 ret = rpoll_to_select(fds, nfds, readfds, writefds, exceptfds);
989
990         return ret;
991 }
992
993 int shutdown(int socket, int how)
994 {
995         int fd;
996         return (fd_get(socket, &fd) == fd_rsocket) ?
997                 rshutdown(fd, how) : real.shutdown(fd, how);
998 }
999
1000 int close(int socket)
1001 {
1002         struct fd_info *fdi;
1003         int ret;
1004
1005         init_preload();
1006         fdi = idm_lookup(&idm, socket);
1007         if (!fdi)
1008                 return real.close(socket);
1009
1010         if (fdi->dupfd != -1) {
1011                 ret = close(fdi->dupfd);
1012                 if (ret)
1013                         return ret;
1014         }
1015
1016         if (atomic_fetch_sub(&fdi->refcnt, 1) != 1)
1017                 return 0;
1018
1019         idm_clear(&idm, socket);
1020         real.close(socket);
1021         ret = (fdi->type == fd_rsocket) ? rclose(fdi->fd) : real.close(fdi->fd);
1022         free(fdi);
1023         return ret;
1024 }
1025
1026 int getpeername(int socket, struct sockaddr *addr, socklen_t *addrlen)
1027 {
1028         int fd;
1029         return (fd_get(socket, &fd) == fd_rsocket) ?
1030                 rgetpeername(fd, addr, addrlen) :
1031                 real.getpeername(fd, addr, addrlen);
1032 }
1033
1034 int getsockname(int socket, struct sockaddr *addr, socklen_t *addrlen)
1035 {
1036         int fd;
1037         init_preload();
1038         return (fd_get(socket, &fd) == fd_rsocket) ?
1039                 rgetsockname(fd, addr, addrlen) :
1040                 real.getsockname(fd, addr, addrlen);
1041 }
1042
1043 int setsockopt(int socket, int level, int optname,
1044                 const void *optval, socklen_t optlen)
1045 {
1046         int fd;
1047         return (fd_get(socket, &fd) == fd_rsocket) ?
1048                 rsetsockopt(fd, level, optname, optval, optlen) :
1049                 real.setsockopt(fd, level, optname, optval, optlen);
1050 }
1051
1052 int getsockopt(int socket, int level, int optname,
1053                 void *optval, socklen_t *optlen)
1054 {
1055         int fd;
1056         return (fd_get(socket, &fd) == fd_rsocket) ?
1057                 rgetsockopt(fd, level, optname, optval, optlen) :
1058                 real.getsockopt(fd, level, optname, optval, optlen);
1059 }
1060
1061 int fcntl(int socket, int cmd, ... /* arg */)
1062 {
1063         va_list args;
1064         long lparam;
1065         void *pparam;
1066         int fd, ret;
1067
1068         init_preload();
1069         va_start(args, cmd);
1070         switch (cmd) {
1071         case F_GETFD:
1072         case F_GETFL:
1073         case F_GETOWN:
1074         case F_GETSIG:
1075         case F_GETLEASE:
1076                 ret = (fd_get(socket, &fd) == fd_rsocket) ?
1077                         rfcntl(fd, cmd) : real.fcntl(fd, cmd);
1078                 break;
1079         case F_DUPFD:
1080         /*case F_DUPFD_CLOEXEC:*/
1081         case F_SETFD:
1082         case F_SETFL:
1083         case F_SETOWN:
1084         case F_SETSIG:
1085         case F_SETLEASE:
1086         case F_NOTIFY:
1087                 lparam = va_arg(args, long);
1088                 ret = (fd_get(socket, &fd) == fd_rsocket) ?
1089                         rfcntl(fd, cmd, lparam) : real.fcntl(fd, cmd, lparam);
1090                 break;
1091         default:
1092                 pparam = va_arg(args, void *);
1093                 ret = (fd_get(socket, &fd) == fd_rsocket) ?
1094                         rfcntl(fd, cmd, pparam) : real.fcntl(fd, cmd, pparam);
1095                 break;
1096         }
1097         va_end(args);
1098         return ret;
1099 }
1100
1101 /*
1102  * dup2 is not thread safe
1103  */
1104 int dup2(int oldfd, int newfd)
1105 {
1106         struct fd_info *oldfdi, *newfdi;
1107         int ret;
1108
1109         init_preload();
1110         oldfdi = idm_lookup(&idm, oldfd);
1111         if (oldfdi) {
1112                 if (oldfdi->state == fd_fork_passive)
1113                         fork_passive(oldfd);
1114                 else if (oldfdi->state == fd_fork_active)
1115                         fork_active(oldfd);
1116         }
1117
1118         newfdi = idm_lookup(&idm, newfd);
1119         if (newfdi) {
1120                  /* newfd cannot have been dup'ed directly */
1121                 if (atomic_load(&newfdi->refcnt) > 1)
1122                         return ERR(EBUSY);
1123                 close(newfd);
1124         }
1125
1126         ret = real.dup2(oldfd, newfd);
1127         if (!oldfdi || ret != newfd)
1128                 return ret;
1129
1130         newfdi = calloc(1, sizeof(*newfdi));
1131         if (!newfdi) {
1132                 close(newfd);
1133                 return ERR(ENOMEM);
1134         }
1135
1136         pthread_mutex_lock(&mut);
1137         idm_set(&idm, newfd, newfdi);
1138         pthread_mutex_unlock(&mut);
1139
1140         newfdi->fd = oldfdi->fd;
1141         newfdi->type = oldfdi->type;
1142         if (oldfdi->dupfd != -1) {
1143                 newfdi->dupfd = oldfdi->dupfd;
1144                 oldfdi = idm_lookup(&idm, oldfdi->dupfd);
1145         } else {
1146                 newfdi->dupfd = oldfd;
1147         }
1148         atomic_store(&newfdi->refcnt, 1);
1149         atomic_fetch_add(&oldfdi->refcnt, 1);
1150         return newfd;
1151 }
1152
1153 ssize_t sendfile(int out_fd, int in_fd, off_t *offset, size_t count)
1154 {
1155         void *file_addr;
1156         int fd;
1157         size_t ret;
1158
1159         if (fd_get(out_fd, &fd) != fd_rsocket)
1160                 return real.sendfile(fd, in_fd, offset, count);
1161
1162         file_addr = mmap(NULL, count, PROT_READ, 0, in_fd, offset ? *offset : 0);
1163         if (file_addr == (void *) -1)
1164                 return -1;
1165
1166         ret = rwrite(fd, file_addr, count);
1167         if ((ret > 0) && offset)
1168                 lseek(in_fd, ret, SEEK_CUR);
1169         munmap(file_addr, count);
1170         return ret;
1171 }
1172
1173 int __fxstat(int ver, int socket, struct stat *buf)
1174 {
1175         int fd, ret;
1176
1177         init_preload();
1178         if (fd_get(socket, &fd) == fd_rsocket) {
1179                 ret = real.fxstat(ver, socket, buf);
1180                 if (!ret)
1181                         buf->st_mode = (buf->st_mode & ~S_IFMT) | __S_IFSOCK;
1182         } else {
1183                 ret = real.fxstat(ver, fd, buf);
1184         }
1185         return ret;
1186 }