]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - lib/libnv/msgio.c
Ensure that libnv can be used when kern.trap_enotcap=1.
[FreeBSD/FreeBSD.git] / lib / libnv / msgio.c
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3  *
4  * Copyright (c) 2013 The FreeBSD Foundation
5  * Copyright (c) 2013 Mariusz Zaborski <oshogbo@FreeBSD.org>
6  * All rights reserved.
7  *
8  * This software was developed by Pawel Jakub Dawidek under sponsorship from
9  * the FreeBSD Foundation.
10  *
11  * Redistribution and use in source and binary forms, with or without
12  * modification, are permitted provided that the following conditions
13  * are met:
14  * 1. Redistributions of source code must retain the above copyright
15  *    notice, this list of conditions and the following disclaimer.
16  * 2. Redistributions in binary form must reproduce the above copyright
17  *    notice, this list of conditions and the following disclaimer in the
18  *    documentation and/or other materials provided with the distribution.
19  *
20  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
21  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
24  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
26  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
27  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
28  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
29  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
30  * SUCH DAMAGE.
31  */
32
33 #include <sys/cdefs.h>
34 __FBSDID("$FreeBSD$");
35
36 #include <sys/param.h>
37 #include <sys/socket.h>
38
39 #include <errno.h>
40 #include <fcntl.h>
41 #include <stdbool.h>
42 #include <stdint.h>
43 #include <stdlib.h>
44 #include <string.h>
45 #include <unistd.h>
46
47 #ifdef HAVE_PJDLOG
48 #include <pjdlog.h>
49 #endif
50
51 #include "common_impl.h"
52 #include "msgio.h"
53
54 #ifndef HAVE_PJDLOG
55 #include <assert.h>
56 #define PJDLOG_ASSERT(...)              assert(__VA_ARGS__)
57 #define PJDLOG_RASSERT(expr, ...)       assert(expr)
58 #define PJDLOG_ABORT(...)               abort()
59 #endif
60
61 #define PKG_MAX_SIZE    (MCLBYTES / CMSG_SPACE(sizeof(int)) - 1)
62
63 static int
64 msghdr_add_fd(struct cmsghdr *cmsg, int fd)
65 {
66
67         PJDLOG_ASSERT(fd >= 0);
68
69         cmsg->cmsg_level = SOL_SOCKET;
70         cmsg->cmsg_type = SCM_RIGHTS;
71         cmsg->cmsg_len = CMSG_LEN(sizeof(fd));
72         bcopy(&fd, CMSG_DATA(cmsg), sizeof(fd));
73
74         return (0);
75 }
76
77 static int
78 msghdr_get_fd(struct cmsghdr *cmsg)
79 {
80         int fd;
81
82         if (cmsg == NULL || cmsg->cmsg_level != SOL_SOCKET ||
83             cmsg->cmsg_type != SCM_RIGHTS ||
84             cmsg->cmsg_len != CMSG_LEN(sizeof(fd))) {
85                 errno = EINVAL;
86                 return (-1);
87         }
88
89         bcopy(CMSG_DATA(cmsg), &fd, sizeof(fd));
90 #ifndef MSG_CMSG_CLOEXEC
91         /*
92          * If the MSG_CMSG_CLOEXEC flag is not available we cannot set the
93          * close-on-exec flag atomically, but we still want to set it for
94          * consistency.
95          */
96         (void) fcntl(fd, F_SETFD, FD_CLOEXEC);
97 #endif
98
99         return (fd);
100 }
101
102 static void
103 fd_wait(int fd, bool doread)
104 {
105         fd_set fds;
106
107         PJDLOG_ASSERT(fd >= 0);
108
109         FD_ZERO(&fds);
110         FD_SET(fd, &fds);
111         (void)select(fd + 1, doread ? &fds : NULL, doread ? NULL : &fds,
112             NULL, NULL);
113 }
114
115 static int
116 msg_recv(int sock, struct msghdr *msg)
117 {
118         int flags;
119
120         PJDLOG_ASSERT(sock >= 0);
121
122 #ifdef MSG_CMSG_CLOEXEC
123         flags = MSG_CMSG_CLOEXEC;
124 #else
125         flags = 0;
126 #endif
127
128         for (;;) {
129                 fd_wait(sock, true);
130                 if (recvmsg(sock, msg, flags) == -1) {
131                         if (errno == EINTR)
132                                 continue;
133                         return (-1);
134                 }
135                 break;
136         }
137
138         return (0);
139 }
140
141 static int
142 msg_send(int sock, const struct msghdr *msg)
143 {
144
145         PJDLOG_ASSERT(sock >= 0);
146
147         for (;;) {
148                 fd_wait(sock, false);
149                 if (sendmsg(sock, msg, 0) == -1) {
150                         if (errno == EINTR)
151                                 continue;
152                         return (-1);
153                 }
154                 break;
155         }
156
157         return (0);
158 }
159
160 int
161 cred_send(int sock)
162 {
163         unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
164         struct msghdr msg;
165         struct cmsghdr *cmsg;
166         struct iovec iov;
167         uint8_t dummy;
168
169         bzero(credbuf, sizeof(credbuf));
170         bzero(&msg, sizeof(msg));
171         bzero(&iov, sizeof(iov));
172
173         /*
174          * XXX: We send one byte along with the control message, because
175          *      setting msg_iov to NULL only works if this is the first
176          *      packet send over the socket. Once we send some data we
177          *      won't be able to send credentials anymore. This is most
178          *      likely a kernel bug.
179          */
180         dummy = 0;
181         iov.iov_base = &dummy;
182         iov.iov_len = sizeof(dummy);
183
184         msg.msg_iov = &iov;
185         msg.msg_iovlen = 1;
186         msg.msg_control = credbuf;
187         msg.msg_controllen = sizeof(credbuf);
188
189         cmsg = CMSG_FIRSTHDR(&msg);
190         cmsg->cmsg_len = CMSG_LEN(sizeof(struct cmsgcred));
191         cmsg->cmsg_level = SOL_SOCKET;
192         cmsg->cmsg_type = SCM_CREDS;
193
194         if (msg_send(sock, &msg) == -1)
195                 return (-1);
196
197         return (0);
198 }
199
200 int
201 cred_recv(int sock, struct cmsgcred *cred)
202 {
203         unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
204         struct msghdr msg;
205         struct cmsghdr *cmsg;
206         struct iovec iov;
207         uint8_t dummy;
208
209         bzero(credbuf, sizeof(credbuf));
210         bzero(&msg, sizeof(msg));
211         bzero(&iov, sizeof(iov));
212
213         iov.iov_base = &dummy;
214         iov.iov_len = sizeof(dummy);
215
216         msg.msg_iov = &iov;
217         msg.msg_iovlen = 1;
218         msg.msg_control = credbuf;
219         msg.msg_controllen = sizeof(credbuf);
220
221         if (msg_recv(sock, &msg) == -1)
222                 return (-1);
223
224         cmsg = CMSG_FIRSTHDR(&msg);
225         if (cmsg == NULL ||
226             cmsg->cmsg_len != CMSG_LEN(sizeof(struct cmsgcred)) ||
227             cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDS) {
228                 errno = EINVAL;
229                 return (-1);
230         }
231         bcopy(CMSG_DATA(cmsg), cred, sizeof(*cred));
232
233         return (0);
234 }
235
236 static int
237 fd_package_send(int sock, const int *fds, size_t nfds)
238 {
239         struct msghdr msg;
240         struct cmsghdr *cmsg;
241         struct iovec iov;
242         unsigned int i;
243         int serrno, ret;
244         uint8_t dummy;
245
246         PJDLOG_ASSERT(sock >= 0);
247         PJDLOG_ASSERT(fds != NULL);
248         PJDLOG_ASSERT(nfds > 0);
249
250         bzero(&msg, sizeof(msg));
251
252         /*
253          * XXX: Look into cred_send function for more details.
254          */
255         dummy = 0;
256         iov.iov_base = &dummy;
257         iov.iov_len = sizeof(dummy);
258
259         msg.msg_iov = &iov;
260         msg.msg_iovlen = 1;
261         msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
262         msg.msg_control = calloc(1, msg.msg_controllen);
263         if (msg.msg_control == NULL)
264                 return (-1);
265
266         ret = -1;
267
268         for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL;
269             i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) {
270                 if (msghdr_add_fd(cmsg, fds[i]) == -1)
271                         goto end;
272         }
273
274         if (msg_send(sock, &msg) == -1)
275                 goto end;
276
277         ret = 0;
278 end:
279         serrno = errno;
280         free(msg.msg_control);
281         errno = serrno;
282         return (ret);
283 }
284
285 static int
286 fd_package_recv(int sock, int *fds, size_t nfds)
287 {
288         struct msghdr msg;
289         struct cmsghdr *cmsg;
290         unsigned int i;
291         int serrno, ret;
292         struct iovec iov;
293         uint8_t dummy;
294
295         PJDLOG_ASSERT(sock >= 0);
296         PJDLOG_ASSERT(nfds > 0);
297         PJDLOG_ASSERT(fds != NULL);
298
299         bzero(&msg, sizeof(msg));
300         bzero(&iov, sizeof(iov));
301
302         /*
303          * XXX: Look into cred_send function for more details.
304          */
305         iov.iov_base = &dummy;
306         iov.iov_len = sizeof(dummy);
307
308         msg.msg_iov = &iov;
309         msg.msg_iovlen = 1;
310         msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
311         msg.msg_control = calloc(1, msg.msg_controllen);
312         if (msg.msg_control == NULL)
313                 return (-1);
314
315         ret = -1;
316
317         if (msg_recv(sock, &msg) == -1)
318                 goto end;
319
320         for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL;
321             i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) {
322                 fds[i] = msghdr_get_fd(cmsg);
323                 if (fds[i] < 0)
324                         break;
325         }
326
327         if (cmsg != NULL || i < nfds) {
328                 int fd;
329
330                 /*
331                  * We need to close all received descriptors, even if we have
332                  * different control message (eg. SCM_CREDS) in between.
333                  */
334                 for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL;
335                     cmsg = CMSG_NXTHDR(&msg, cmsg)) {
336                         fd = msghdr_get_fd(cmsg);
337                         if (fd >= 0)
338                                 close(fd);
339                 }
340                 errno = EINVAL;
341                 goto end;
342         }
343
344         ret = 0;
345 end:
346         serrno = errno;
347         free(msg.msg_control);
348         errno = serrno;
349         return (ret);
350 }
351
352 int
353 fd_recv(int sock, int *fds, size_t nfds)
354 {
355         unsigned int i, step, j;
356         int ret, serrno;
357
358         if (nfds == 0 || fds == NULL) {
359                 errno = EINVAL;
360                 return (-1);
361         }
362
363         ret = i = step = 0;
364         while (i < nfds) {
365                 if (PKG_MAX_SIZE < nfds - i)
366                         step = PKG_MAX_SIZE;
367                 else
368                         step = nfds - i;
369                 ret = fd_package_recv(sock, fds + i, step);
370                 if (ret != 0) {
371                         /* Close all received descriptors. */
372                         serrno = errno;
373                         for (j = 0; j < i; j++)
374                                 close(fds[j]);
375                         errno = serrno;
376                         break;
377                 }
378                 i += step;
379         }
380
381         return (ret);
382 }
383
384 int
385 fd_send(int sock, const int *fds, size_t nfds)
386 {
387         unsigned int i, step;
388         int ret;
389
390         if (nfds == 0 || fds == NULL) {
391                 errno = EINVAL;
392                 return (-1);
393         }
394
395         ret = i = step = 0;
396         while (i < nfds) {
397                 if (PKG_MAX_SIZE < nfds - i)
398                         step = PKG_MAX_SIZE;
399                 else
400                         step = nfds - i;
401                 ret = fd_package_send(sock, fds + i, step);
402                 if (ret != 0)
403                         break;
404                 i += step;
405         }
406
407         return (ret);
408 }
409
410 int
411 buf_send(int sock, void *buf, size_t size)
412 {
413         ssize_t done;
414         unsigned char *ptr;
415
416         PJDLOG_ASSERT(sock >= 0);
417         PJDLOG_ASSERT(size > 0);
418         PJDLOG_ASSERT(buf != NULL);
419
420         ptr = buf;
421         do {
422                 fd_wait(sock, false);
423                 done = send(sock, ptr, size, 0);
424                 if (done == -1) {
425                         if (errno == EINTR)
426                                 continue;
427                         return (-1);
428                 } else if (done == 0) {
429                         errno = ENOTCONN;
430                         return (-1);
431                 }
432                 size -= done;
433                 ptr += done;
434         } while (size > 0);
435
436         return (0);
437 }
438
439 int
440 buf_recv(int sock, void *buf, size_t size)
441 {
442         ssize_t done;
443         unsigned char *ptr;
444
445         PJDLOG_ASSERT(sock >= 0);
446         PJDLOG_ASSERT(buf != NULL);
447
448         ptr = buf;
449         while (size > 0) {
450                 fd_wait(sock, true);
451                 done = recv(sock, ptr, size, 0);
452                 if (done == -1) {
453                         if (errno == EINTR)
454                                 continue;
455                         return (-1);
456                 } else if (done == 0) {
457                         errno = ENOTCONN;
458                         return (-1);
459                 }
460                 size -= done;
461                 ptr += done;
462         }
463
464         return (0);
465 }