]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - contrib/lib9p/transport/socket.c
Import lib9p 9d5aee77bcc1bf0e79b0a3bfefff5fdf2146283c.
[FreeBSD/FreeBSD.git] / contrib / lib9p / transport / socket.c
1 /*
2  * Copyright 2016 Jakub Klama <jceel@FreeBSD.org>
3  * All rights reserved
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted providing that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
16  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
18  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
20  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
21  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
22  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
23  * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
24  * POSSIBILITY OF SUCH DAMAGE.
25  *
26  */
27
28 #include <stdlib.h>
29 #include <errno.h>
30 #include <string.h>
31 #include <unistd.h>
32 #include <pthread.h>
33 #include <assert.h>
34 #include <sys/types.h>
35 #ifdef __APPLE__
36 # include "../apple_endian.h"
37 #else
38 # include <sys/endian.h>
39 #endif
40 #include <sys/socket.h>
41 #include <sys/event.h>
42 #include <sys/uio.h>
43 #include <netdb.h>
44 #include "../lib9p.h"
45 #include "../lib9p_impl.h"
46 #include "../log.h"
47 #include "socket.h"
48
49 struct l9p_socket_softc
50 {
51         struct l9p_connection *ls_conn;
52         struct sockaddr ls_sockaddr;
53         socklen_t ls_socklen;
54         pthread_t ls_thread;
55         int ls_fd;
56 };
57
58 static int l9p_socket_readmsg(struct l9p_socket_softc *, void **, size_t *);
59 static int l9p_socket_get_response_buffer(struct l9p_request *,
60     struct iovec *, size_t *, void *);
61 static int l9p_socket_send_response(struct l9p_request *, const struct iovec *,
62     const size_t, const size_t, void *);
63 static void l9p_socket_drop_response(struct l9p_request *, const struct iovec *,
64     size_t, void *);
65 static void *l9p_socket_thread(void *);
66 static ssize_t xread(int, void *, size_t);
67 static ssize_t xwrite(int, void *, size_t);
68
69 int
70 l9p_start_server(struct l9p_server *server, const char *host, const char *port)
71 {
72         struct addrinfo *res, *res0, hints;
73         struct kevent kev[2];
74         struct kevent event[2];
75         int err, kq, i, val, evs, nsockets = 0;
76         int sockets[2];
77
78         memset(&hints, 0, sizeof(hints));
79         hints.ai_family = PF_UNSPEC;
80         hints.ai_socktype = SOCK_STREAM;
81         err = getaddrinfo(host, port, &hints, &res0);
82
83         if (err)
84                 return (-1);
85
86         for (res = res0; res; res = res->ai_next) {
87                 int s = socket(res->ai_family, res->ai_socktype,
88                     res->ai_protocol);
89
90                 val = 1;
91                 setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val));
92
93                 if (s < 0)
94                         continue;
95
96                 if (bind(s, res->ai_addr, res->ai_addrlen) < 0) {
97                         close(s);
98                         continue;
99                 }
100
101                 sockets[nsockets] = s;
102                 EV_SET(&kev[nsockets++], s, EVFILT_READ, EV_ADD | EV_ENABLE, 0,
103                     0, 0);
104                 listen(s, 10);
105         }
106
107         if (nsockets < 1) {
108                 L9P_LOG(L9P_ERROR, "bind(): %s", strerror(errno));
109                 return(-1);
110         }
111
112         kq = kqueue();
113
114         if (kevent(kq, kev, nsockets, NULL, 0, NULL) < 0) {
115                 L9P_LOG(L9P_ERROR, "kevent(): %s", strerror(errno));
116                 return (-1);
117         }
118
119         for (;;) {
120                 evs = kevent(kq, NULL, 0, event, nsockets, NULL);
121                 if (evs < 0) {
122                         if (errno == EINTR)
123                                 continue;
124
125                         L9P_LOG(L9P_ERROR, "kevent(): %s", strerror(errno));
126                         return (-1);
127                 }
128
129                 for (i = 0; i < evs; i++) {
130                         struct sockaddr client_addr;
131                         socklen_t client_addr_len = sizeof(client_addr);
132                         int news = accept((int)event[i].ident, &client_addr,
133                             &client_addr_len);
134
135                         if (news < 0) {
136                                 L9P_LOG(L9P_WARNING, "accept(): %s",
137                                     strerror(errno));
138                                 continue;
139                         }
140
141                         l9p_socket_accept(server, news, &client_addr,
142                             client_addr_len);
143                 }
144         }
145
146 }
147
148 void
149 l9p_socket_accept(struct l9p_server *server, int conn_fd,
150     struct sockaddr *client_addr, socklen_t client_addr_len)
151 {
152         struct l9p_socket_softc *sc;
153         struct l9p_connection *conn;
154         char host[NI_MAXHOST + 1];
155         char serv[NI_MAXSERV + 1];
156         int err;
157
158         err = getnameinfo(client_addr, client_addr_len, host, NI_MAXHOST, serv,
159             NI_MAXSERV, NI_NUMERICHOST | NI_NUMERICSERV);
160
161         if (err != 0) {
162                 L9P_LOG(L9P_WARNING, "cannot look up client name: %s",
163                     gai_strerror(err));
164         } else {
165                 L9P_LOG(L9P_INFO, "new connection from %s:%s", host, serv);
166         }
167
168         if (l9p_connection_init(server, &conn) != 0) {
169                 L9P_LOG(L9P_ERROR, "cannot create new connection");
170                 return;
171         }
172
173         sc = l9p_calloc(1, sizeof(*sc));
174         sc->ls_conn = conn;
175         sc->ls_fd = conn_fd;
176
177         /*
178          * Fill in transport handler functions and aux argument.
179          */
180         conn->lc_lt.lt_aux = sc;
181         conn->lc_lt.lt_get_response_buffer = l9p_socket_get_response_buffer;
182         conn->lc_lt.lt_send_response = l9p_socket_send_response;
183         conn->lc_lt.lt_drop_response = l9p_socket_drop_response;
184
185         err = pthread_create(&sc->ls_thread, NULL, l9p_socket_thread, sc);
186         if (err) {
187                 L9P_LOG(L9P_ERROR,
188                     "pthread_create (for connection from %s:%s): error %s",
189                     host, serv, strerror(err));
190                 l9p_connection_close(sc->ls_conn);
191                 free(sc);
192         }
193 }
194
195 static void *
196 l9p_socket_thread(void *arg)
197 {
198         struct l9p_socket_softc *sc = (struct l9p_socket_softc *)arg;
199         struct iovec iov;
200         void *buf;
201         size_t length;
202
203         for (;;) {
204                 if (l9p_socket_readmsg(sc, &buf, &length) != 0)
205                         break;
206
207                 iov.iov_base = buf;
208                 iov.iov_len = length;
209                 l9p_connection_recv(sc->ls_conn, &iov, 1, NULL);
210                 free(buf);
211         }
212
213         L9P_LOG(L9P_INFO, "connection closed");
214         l9p_connection_close(sc->ls_conn);
215         free(sc);
216         return (NULL);
217 }
218
219 static int
220 l9p_socket_readmsg(struct l9p_socket_softc *sc, void **buf, size_t *size)
221 {
222         uint32_t msize;
223         size_t toread;
224         ssize_t ret;
225         void *buffer;
226         int fd = sc->ls_fd;
227
228         assert(fd > 0);
229
230         buffer = l9p_malloc(sizeof(uint32_t));
231
232         ret = xread(fd, buffer, sizeof(uint32_t));
233         if (ret < 0) {
234                 L9P_LOG(L9P_ERROR, "read(): %s", strerror(errno));
235                 return (-1);
236         }
237
238         if (ret != sizeof(uint32_t)) {
239                 if (ret == 0)
240                         L9P_LOG(L9P_DEBUG, "%p: EOF", (void *)sc->ls_conn);
241                 else
242                         L9P_LOG(L9P_ERROR,
243                             "short read: %zd bytes of %zd expected",
244                             ret, sizeof(uint32_t));
245                 return (-1);
246         }
247
248         msize = le32toh(*(uint32_t *)buffer);
249         toread = msize - sizeof(uint32_t);
250         buffer = l9p_realloc(buffer, msize);
251
252         ret = xread(fd, (char *)buffer + sizeof(uint32_t), toread);
253         if (ret < 0) {
254                 L9P_LOG(L9P_ERROR, "read(): %s", strerror(errno));
255                 return (-1);
256         }
257
258         if (ret != (ssize_t)toread) {
259                 L9P_LOG(L9P_ERROR, "short read: %zd bytes of %zd expected",
260                     ret, toread);
261                 return (-1);
262         }
263
264         *size = msize;
265         *buf = buffer;
266         L9P_LOG(L9P_INFO, "%p: read complete message, buf=%p size=%d",
267             (void *)sc->ls_conn, buffer, msize);
268
269         return (0);
270 }
271
272 static int
273 l9p_socket_get_response_buffer(struct l9p_request *req, struct iovec *iov,
274     size_t *niovp, void *arg __unused)
275 {
276         size_t size = req->lr_conn->lc_msize;
277         void *buf;
278
279         buf = l9p_malloc(size);
280         iov[0].iov_base = buf;
281         iov[0].iov_len = size;
282
283         *niovp = 1;
284         return (0);
285 }
286
287 static int
288 l9p_socket_send_response(struct l9p_request *req __unused,
289     const struct iovec *iov, const size_t niov __unused, const size_t iolen,
290     void *arg)
291 {
292         struct l9p_socket_softc *sc = (struct l9p_socket_softc *)arg;
293
294         assert(sc->ls_fd >= 0);
295
296         L9P_LOG(L9P_DEBUG, "%p: sending reply, buf=%p, size=%d", arg,
297             iov[0].iov_base, iolen);
298
299         if (xwrite(sc->ls_fd, iov[0].iov_base, iolen) != (int)iolen) {
300                 L9P_LOG(L9P_ERROR, "short write: %s", strerror(errno));
301                 return (-1);
302         }
303
304         free(iov[0].iov_base);
305         return (0);
306 }
307
308 static void
309 l9p_socket_drop_response(struct l9p_request *req __unused,
310     const struct iovec *iov, size_t niov __unused, void *arg __unused)
311 {
312
313         L9P_LOG(L9P_DEBUG, "%p: drop buf=%p", arg, iov[0].iov_base);
314         free(iov[0].iov_base);
315 }
316
317 static ssize_t
318 xread(int fd, void *buf, size_t count)
319 {
320         size_t done = 0;
321         ssize_t ret;
322
323         while (done < count) {
324                 ret = read(fd, (char *)buf + done, count - done);
325                 if (ret < 0) {
326                         if (errno == EINTR)
327                                 continue;
328
329                         return (-1);
330                 }
331
332                 if (ret == 0)
333                         return ((ssize_t)done);
334
335                 done += (size_t)ret;
336         }
337
338         return ((ssize_t)done);
339 }
340
341 static ssize_t
342 xwrite(int fd, void *buf, size_t count)
343 {
344         size_t done = 0;
345         ssize_t ret;
346
347         while (done < count) {
348                 ret = write(fd, (char *)buf + done, count - done);
349                 if (ret < 0) {
350                         if (errno == EINTR)
351                                 continue;
352
353                         return (-1);
354                 }
355
356                 if (ret == 0)
357                         return ((ssize_t)done);
358
359                 done += (size_t)ret;
360         }
361
362         return ((ssize_t)done);
363 }