]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - lib/libnv/msgio.c
MFV r345495:
[FreeBSD/FreeBSD.git] / lib / libnv / msgio.c
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3  *
4  * Copyright (c) 2013 The FreeBSD Foundation
5  * Copyright (c) 2013 Mariusz Zaborski <oshogbo@FreeBSD.org>
6  * All rights reserved.
7  *
8  * This software was developed by Pawel Jakub Dawidek under sponsorship from
9  * the FreeBSD Foundation.
10  *
11  * Redistribution and use in source and binary forms, with or without
12  * modification, are permitted provided that the following conditions
13  * are met:
14  * 1. Redistributions of source code must retain the above copyright
15  *    notice, this list of conditions and the following disclaimer.
16  * 2. Redistributions in binary form must reproduce the above copyright
17  *    notice, this list of conditions and the following disclaimer in the
18  *    documentation and/or other materials provided with the distribution.
19  *
20  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
21  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
24  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
26  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
27  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
28  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
29  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
30  * SUCH DAMAGE.
31  */
32
33 #include <sys/cdefs.h>
34 __FBSDID("$FreeBSD$");
35
36 #include <sys/param.h>
37 #include <sys/socket.h>
38
39 #include <errno.h>
40 #include <fcntl.h>
41 #include <stdbool.h>
42 #include <stdint.h>
43 #include <stdlib.h>
44 #include <string.h>
45 #include <unistd.h>
46
47 #ifdef HAVE_PJDLOG
48 #include <pjdlog.h>
49 #endif
50
51 #include "common_impl.h"
52 #include "msgio.h"
53
54 #ifndef HAVE_PJDLOG
55 #include <assert.h>
56 #define PJDLOG_ASSERT(...)              assert(__VA_ARGS__)
57 #define PJDLOG_RASSERT(expr, ...)       assert(expr)
58 #define PJDLOG_ABORT(...)               abort()
59 #endif
60
61 #define PKG_MAX_SIZE    (MCLBYTES / CMSG_SPACE(sizeof(int)) - 1)
62
63 static int
64 msghdr_add_fd(struct cmsghdr *cmsg, int fd)
65 {
66
67         PJDLOG_ASSERT(fd >= 0);
68
69         cmsg->cmsg_level = SOL_SOCKET;
70         cmsg->cmsg_type = SCM_RIGHTS;
71         cmsg->cmsg_len = CMSG_LEN(sizeof(fd));
72         bcopy(&fd, CMSG_DATA(cmsg), sizeof(fd));
73
74         return (0);
75 }
76
77 static int
78 msghdr_get_fd(struct cmsghdr *cmsg)
79 {
80         int fd;
81
82         if (cmsg == NULL || cmsg->cmsg_level != SOL_SOCKET ||
83             cmsg->cmsg_type != SCM_RIGHTS ||
84             cmsg->cmsg_len != CMSG_LEN(sizeof(fd))) {
85                 errno = EINVAL;
86                 return (-1);
87         }
88
89         bcopy(CMSG_DATA(cmsg), &fd, sizeof(fd));
90 #ifndef MSG_CMSG_CLOEXEC
91         /*
92          * If the MSG_CMSG_CLOEXEC flag is not available we cannot set the
93          * close-on-exec flag atomically, but we still want to set it for
94          * consistency.
95          */
96         (void) fcntl(fd, F_SETFD, FD_CLOEXEC);
97 #endif
98
99         return (fd);
100 }
101
102 static void
103 fd_wait(int fd, bool doread)
104 {
105         fd_set fds;
106
107         PJDLOG_ASSERT(fd >= 0);
108
109         FD_ZERO(&fds);
110         FD_SET(fd, &fds);
111         (void)select(fd + 1, doread ? &fds : NULL, doread ? NULL : &fds,
112             NULL, NULL);
113 }
114
115 static int
116 msg_recv(int sock, struct msghdr *msg)
117 {
118         int flags;
119
120         PJDLOG_ASSERT(sock >= 0);
121
122 #ifdef MSG_CMSG_CLOEXEC
123         flags = MSG_CMSG_CLOEXEC;
124 #else
125         flags = 0;
126 #endif
127
128         for (;;) {
129                 fd_wait(sock, true);
130                 if (recvmsg(sock, msg, flags) == -1) {
131                         if (errno == EINTR)
132                                 continue;
133                         return (-1);
134                 }
135                 break;
136         }
137
138         return (0);
139 }
140
141 static int
142 msg_send(int sock, const struct msghdr *msg)
143 {
144
145         PJDLOG_ASSERT(sock >= 0);
146
147         for (;;) {
148                 fd_wait(sock, false);
149                 if (sendmsg(sock, msg, 0) == -1) {
150                         if (errno == EINTR)
151                                 continue;
152                         return (-1);
153                 }
154                 break;
155         }
156
157         return (0);
158 }
159
160 /*
161  * MacOS/Linux do not define struct cmsgcred but we need to bootstrap libnv
162  * when building on non-FreeBSD systems. Since they are not used during
163  * bootstrap we can just omit these two functions there.
164  */
165 #ifndef __FreeBSD__
166 #warning "cred_send() not supported on non-FreeBSD systems"
167 #else
168 int
169 cred_send(int sock)
170 {
171         unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
172         struct msghdr msg;
173         struct cmsghdr *cmsg;
174         struct iovec iov;
175         uint8_t dummy;
176
177         bzero(credbuf, sizeof(credbuf));
178         bzero(&msg, sizeof(msg));
179         bzero(&iov, sizeof(iov));
180
181         /*
182          * XXX: We send one byte along with the control message, because
183          *      setting msg_iov to NULL only works if this is the first
184          *      packet send over the socket. Once we send some data we
185          *      won't be able to send credentials anymore. This is most
186          *      likely a kernel bug.
187          */
188         dummy = 0;
189         iov.iov_base = &dummy;
190         iov.iov_len = sizeof(dummy);
191
192         msg.msg_iov = &iov;
193         msg.msg_iovlen = 1;
194         msg.msg_control = credbuf;
195         msg.msg_controllen = sizeof(credbuf);
196
197         cmsg = CMSG_FIRSTHDR(&msg);
198         cmsg->cmsg_len = CMSG_LEN(sizeof(struct cmsgcred));
199         cmsg->cmsg_level = SOL_SOCKET;
200         cmsg->cmsg_type = SCM_CREDS;
201
202         if (msg_send(sock, &msg) == -1)
203                 return (-1);
204
205         return (0);
206 }
207
208 int
209 cred_recv(int sock, struct cmsgcred *cred)
210 {
211         unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
212         struct msghdr msg;
213         struct cmsghdr *cmsg;
214         struct iovec iov;
215         uint8_t dummy;
216
217         bzero(credbuf, sizeof(credbuf));
218         bzero(&msg, sizeof(msg));
219         bzero(&iov, sizeof(iov));
220
221         iov.iov_base = &dummy;
222         iov.iov_len = sizeof(dummy);
223
224         msg.msg_iov = &iov;
225         msg.msg_iovlen = 1;
226         msg.msg_control = credbuf;
227         msg.msg_controllen = sizeof(credbuf);
228
229         if (msg_recv(sock, &msg) == -1)
230                 return (-1);
231
232         cmsg = CMSG_FIRSTHDR(&msg);
233         if (cmsg == NULL ||
234             cmsg->cmsg_len != CMSG_LEN(sizeof(struct cmsgcred)) ||
235             cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDS) {
236                 errno = EINVAL;
237                 return (-1);
238         }
239         bcopy(CMSG_DATA(cmsg), cred, sizeof(*cred));
240
241         return (0);
242 }
243 #endif
244
245 static int
246 fd_package_send(int sock, const int *fds, size_t nfds)
247 {
248         struct msghdr msg;
249         struct cmsghdr *cmsg;
250         struct iovec iov;
251         unsigned int i;
252         int serrno, ret;
253         uint8_t dummy;
254
255         PJDLOG_ASSERT(sock >= 0);
256         PJDLOG_ASSERT(fds != NULL);
257         PJDLOG_ASSERT(nfds > 0);
258
259         bzero(&msg, sizeof(msg));
260
261         /*
262          * XXX: Look into cred_send function for more details.
263          */
264         dummy = 0;
265         iov.iov_base = &dummy;
266         iov.iov_len = sizeof(dummy);
267
268         msg.msg_iov = &iov;
269         msg.msg_iovlen = 1;
270         msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
271         msg.msg_control = calloc(1, msg.msg_controllen);
272         if (msg.msg_control == NULL)
273                 return (-1);
274
275         ret = -1;
276
277         for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL;
278             i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) {
279                 if (msghdr_add_fd(cmsg, fds[i]) == -1)
280                         goto end;
281         }
282
283         if (msg_send(sock, &msg) == -1)
284                 goto end;
285
286         ret = 0;
287 end:
288         serrno = errno;
289         free(msg.msg_control);
290         errno = serrno;
291         return (ret);
292 }
293
294 static int
295 fd_package_recv(int sock, int *fds, size_t nfds)
296 {
297         struct msghdr msg;
298         struct cmsghdr *cmsg;
299         unsigned int i;
300         int serrno, ret;
301         struct iovec iov;
302         uint8_t dummy;
303
304         PJDLOG_ASSERT(sock >= 0);
305         PJDLOG_ASSERT(nfds > 0);
306         PJDLOG_ASSERT(fds != NULL);
307
308         bzero(&msg, sizeof(msg));
309         bzero(&iov, sizeof(iov));
310
311         /*
312          * XXX: Look into cred_send function for more details.
313          */
314         iov.iov_base = &dummy;
315         iov.iov_len = sizeof(dummy);
316
317         msg.msg_iov = &iov;
318         msg.msg_iovlen = 1;
319         msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
320         msg.msg_control = calloc(1, msg.msg_controllen);
321         if (msg.msg_control == NULL)
322                 return (-1);
323
324         ret = -1;
325
326         if (msg_recv(sock, &msg) == -1)
327                 goto end;
328
329         for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL;
330             i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) {
331                 fds[i] = msghdr_get_fd(cmsg);
332                 if (fds[i] < 0)
333                         break;
334         }
335
336         if (cmsg != NULL || i < nfds) {
337                 int fd;
338
339                 /*
340                  * We need to close all received descriptors, even if we have
341                  * different control message (eg. SCM_CREDS) in between.
342                  */
343                 for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL;
344                     cmsg = CMSG_NXTHDR(&msg, cmsg)) {
345                         fd = msghdr_get_fd(cmsg);
346                         if (fd >= 0)
347                                 close(fd);
348                 }
349                 errno = EINVAL;
350                 goto end;
351         }
352
353         ret = 0;
354 end:
355         serrno = errno;
356         free(msg.msg_control);
357         errno = serrno;
358         return (ret);
359 }
360
361 int
362 fd_recv(int sock, int *fds, size_t nfds)
363 {
364         unsigned int i, step, j;
365         int ret, serrno;
366
367         if (nfds == 0 || fds == NULL) {
368                 errno = EINVAL;
369                 return (-1);
370         }
371
372         ret = i = step = 0;
373         while (i < nfds) {
374                 if (PKG_MAX_SIZE < nfds - i)
375                         step = PKG_MAX_SIZE;
376                 else
377                         step = nfds - i;
378                 ret = fd_package_recv(sock, fds + i, step);
379                 if (ret != 0) {
380                         /* Close all received descriptors. */
381                         serrno = errno;
382                         for (j = 0; j < i; j++)
383                                 close(fds[j]);
384                         errno = serrno;
385                         break;
386                 }
387                 i += step;
388         }
389
390         return (ret);
391 }
392
393 int
394 fd_send(int sock, const int *fds, size_t nfds)
395 {
396         unsigned int i, step;
397         int ret;
398
399         if (nfds == 0 || fds == NULL) {
400                 errno = EINVAL;
401                 return (-1);
402         }
403
404         ret = i = step = 0;
405         while (i < nfds) {
406                 if (PKG_MAX_SIZE < nfds - i)
407                         step = PKG_MAX_SIZE;
408                 else
409                         step = nfds - i;
410                 ret = fd_package_send(sock, fds + i, step);
411                 if (ret != 0)
412                         break;
413                 i += step;
414         }
415
416         return (ret);
417 }
418
419 int
420 buf_send(int sock, void *buf, size_t size)
421 {
422         ssize_t done;
423         unsigned char *ptr;
424
425         PJDLOG_ASSERT(sock >= 0);
426         PJDLOG_ASSERT(size > 0);
427         PJDLOG_ASSERT(buf != NULL);
428
429         ptr = buf;
430         do {
431                 fd_wait(sock, false);
432                 done = send(sock, ptr, size, 0);
433                 if (done == -1) {
434                         if (errno == EINTR)
435                                 continue;
436                         return (-1);
437                 } else if (done == 0) {
438                         errno = ENOTCONN;
439                         return (-1);
440                 }
441                 size -= done;
442                 ptr += done;
443         } while (size > 0);
444
445         return (0);
446 }
447
448 int
449 buf_recv(int sock, void *buf, size_t size)
450 {
451         ssize_t done;
452         unsigned char *ptr;
453
454         PJDLOG_ASSERT(sock >= 0);
455         PJDLOG_ASSERT(buf != NULL);
456
457         ptr = buf;
458         while (size > 0) {
459                 fd_wait(sock, true);
460                 done = recv(sock, ptr, size, 0);
461                 if (done == -1) {
462                         if (errno == EINTR)
463                                 continue;
464                         return (-1);
465                 } else if (done == 0) {
466                         errno = ENOTCONN;
467                         return (-1);
468                 }
469                 size -= done;
470                 ptr += done;
471         }
472
473         return (0);
474 }