]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - sys/dev/hyperv/hvsock/hv_sock.c
hvsock: Fix a typo in a source code comment
[FreeBSD/FreeBSD.git] / sys / dev / hyperv / hvsock / hv_sock.c
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause
3  *
4  * Copyright (c) 2020 Microsoft Corp.
5  * All rights reserved.
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  * 1. Redistributions of source code must retain the above copyright
11  *    notice unmodified, this list of conditions, and the following
12  *    disclaimer.
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in the
15  *    documentation and/or other materials provided with the distribution.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
18  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
19  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
20  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
21  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
22  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
23  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
24  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
26  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27  */
28
29 #include <sys/param.h>
30 #include <sys/bus.h>
31 #include <sys/domain.h>
32 #include <sys/lock.h>
33 #include <sys/kernel.h>
34 #include <sys/types.h>
35 #include <sys/malloc.h>
36 #include <sys/module.h>
37 #include <sys/mutex.h>
38 #include <sys/proc.h>
39 #include <sys/protosw.h>
40 #include <sys/socket.h>
41 #include <sys/sysctl.h>
42 #include <sys/sysproto.h>
43 #include <sys/systm.h>
44 #include <sys/sockbuf.h>
45 #include <sys/sx.h>
46 #include <sys/uio.h>
47
48 #include <net/vnet.h>
49
50 #include <dev/hyperv/vmbus/vmbus_reg.h>
51
52 #include "hv_sock.h"
53
54 #define HVSOCK_DBG_NONE                 0x0
55 #define HVSOCK_DBG_INFO                 0x1
56 #define HVSOCK_DBG_ERR                  0x2
57 #define HVSOCK_DBG_VERBOSE              0x3
58
59
60 SYSCTL_NODE(_net, OID_AUTO, hvsock, CTLFLAG_RD, 0, "HyperV socket");
61
62 static int hvs_dbg_level;
63 SYSCTL_INT(_net_hvsock, OID_AUTO, hvs_dbg_level, CTLFLAG_RWTUN, &hvs_dbg_level,
64     0, "hyperv socket debug level: 0 = none, 1 = info, 2 = error, 3 = verbose");
65
66
67 #define HVSOCK_DBG(level, ...) do {                                     \
68         if (hvs_dbg_level >= (level))                                   \
69                 printf(__VA_ARGS__);                                    \
70         } while (0)
71
72 MALLOC_DEFINE(M_HVSOCK, "hyperv_socket", "hyperv socket control structures");
73
74 static int hvs_dom_probe(void);
75
76 /* The MTU is 16KB per host side's design */
77 #define HVSOCK_MTU_SIZE         (1024 * 16)
78 #define HVSOCK_SEND_BUF_SZ      (PAGE_SIZE - sizeof(struct vmpipe_proto_header))
79
80 #define HVSOCK_HEADER_LEN       (sizeof(struct hvs_pkt_header))
81
82 #define HVSOCK_PKT_LEN(payload_len)     (HVSOCK_HEADER_LEN + \
83                                          roundup2(payload_len, 8) + \
84                                          sizeof(uint64_t))
85
86 /*
87  * HyperV Transport sockets
88  */
89 static struct protosw hv_socket_protosw = {
90         .pr_type =              SOCK_STREAM,
91         .pr_protocol =          HYPERV_SOCK_PROTO_TRANS,
92         .pr_flags =             PR_CONNREQUIRED,
93         .pr_attach =            hvs_trans_attach,
94         .pr_bind =              hvs_trans_bind,
95         .pr_listen =            hvs_trans_listen,
96         .pr_accept =            hvs_trans_accept,
97         .pr_connect =           hvs_trans_connect,
98         .pr_peeraddr =          hvs_trans_peeraddr,
99         .pr_sockaddr =          hvs_trans_sockaddr,
100         .pr_soreceive =         hvs_trans_soreceive,
101         .pr_sosend =            hvs_trans_sosend,
102         .pr_disconnect =        hvs_trans_disconnect,
103         .pr_close =             hvs_trans_close,
104         .pr_detach =            hvs_trans_detach,
105         .pr_shutdown =          hvs_trans_shutdown,
106         .pr_abort =             hvs_trans_abort,
107 };
108
109 static struct domain            hv_socket_domain = {
110         .dom_family =           AF_HYPERV,
111         .dom_name =             "hyperv",
112         .dom_probe =            hvs_dom_probe,
113         .dom_nprotosw =         1,
114         .dom_protosw =          { &hv_socket_protosw },
115 };
116
117 DOMAIN_SET(hv_socket_);
118
119 #define MAX_PORT                        ((uint32_t)0xFFFFFFFF)
120 #define MIN_PORT                        ((uint32_t)0x0)
121
122 /* 00000000-facb-11e6-bd58-64006a7986d3 */
123 static const struct hyperv_guid srv_id_template = {
124         .hv_guid = {
125             0x00, 0x00, 0x00, 0x00, 0xcb, 0xfa, 0xe6, 0x11,
126             0xbd, 0x58, 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3 }
127 };
128
129 static int              hvsock_br_callback(void *, int, void *);
130 static uint32_t         hvsock_canread_check(struct hvs_pcb *);
131 static uint32_t         hvsock_canwrite_check(struct hvs_pcb *);
132 static int              hvsock_send_data(struct vmbus_channel *chan,
133     struct uio *uio, uint32_t to_write, struct sockbuf *sb);
134
135
136
137 /* Globals */
138 static struct sx                hvs_trans_socks_sx;
139 static struct mtx               hvs_trans_socks_mtx;
140 static LIST_HEAD(, hvs_pcb)     hvs_trans_bound_socks;
141 static LIST_HEAD(, hvs_pcb)     hvs_trans_connected_socks;
142 static uint32_t                 previous_auto_bound_port;
143
144 static void
145 hvsock_print_guid(struct hyperv_guid *guid)
146 {
147         unsigned char *p = (unsigned char *)guid;
148
149         HVSOCK_DBG(HVSOCK_DBG_INFO,
150             "0x%x-0x%x-0x%x-0x%x-0x%x-0x%x-0x%x-0x%x-0x%x-0x%x-0x%x\n",
151             *(unsigned int *)p,
152             *((unsigned short *) &p[4]),
153             *((unsigned short *) &p[6]),
154             p[8], p[9], p[10], p[11], p[12], p[13], p[14], p[15]);
155 }
156
157 static bool
158 is_valid_srv_id(const struct hyperv_guid *id)
159 {
160         return !memcmp(&id->hv_guid[4],
161             &srv_id_template.hv_guid[4], sizeof(struct hyperv_guid) - 4);
162 }
163
164 static unsigned int
165 get_port_by_srv_id(const struct hyperv_guid *srv_id)
166 {
167         return *((const unsigned int *)srv_id);
168 }
169
170 static void
171 set_port_by_srv_id(struct hyperv_guid *srv_id, unsigned int port)
172 {
173         *((unsigned int *)srv_id) = port;
174 }
175
176
177 static void
178 __hvs_remove_pcb_from_list(struct hvs_pcb *pcb, unsigned char list)
179 {
180         struct hvs_pcb *p = NULL;
181
182         HVSOCK_DBG(HVSOCK_DBG_VERBOSE, "%s: pcb is %p\n", __func__, pcb);
183
184         if (!pcb)
185                 return;
186
187         if (list & HVS_LIST_BOUND) {
188                 LIST_FOREACH(p, &hvs_trans_bound_socks, bound_next)
189                         if  (p == pcb)
190                                 LIST_REMOVE(p, bound_next);
191         }
192
193         if (list & HVS_LIST_CONNECTED) {
194                 LIST_FOREACH(p, &hvs_trans_connected_socks, connected_next)
195                         if (p == pcb)
196                                 LIST_REMOVE(pcb, connected_next);
197         }
198 }
199
200 static void
201 __hvs_remove_socket_from_list(struct socket *so, unsigned char list)
202 {
203         struct hvs_pcb *pcb = so2hvspcb(so);
204
205         HVSOCK_DBG(HVSOCK_DBG_VERBOSE, "%s: pcb is %p\n", __func__, pcb);
206
207         __hvs_remove_pcb_from_list(pcb, list);
208 }
209
210 static void
211 __hvs_insert_socket_on_list(struct socket *so, unsigned char list)
212 {
213         struct hvs_pcb *pcb = so2hvspcb(so);
214
215         if (list & HVS_LIST_BOUND)
216                 LIST_INSERT_HEAD(&hvs_trans_bound_socks,
217                    pcb, bound_next);
218
219         if (list & HVS_LIST_CONNECTED)
220                 LIST_INSERT_HEAD(&hvs_trans_connected_socks,
221                    pcb, connected_next);
222 }
223
224 void
225 hvs_remove_socket_from_list(struct socket *so, unsigned char list)
226 {
227         if (!so || !so->so_pcb) {
228                 HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
229                     "%s: socket or so_pcb is null\n", __func__);
230                 return;
231         }
232
233         mtx_lock(&hvs_trans_socks_mtx);
234         __hvs_remove_socket_from_list(so, list);
235         mtx_unlock(&hvs_trans_socks_mtx);
236 }
237
238 static void
239 hvs_insert_socket_on_list(struct socket *so, unsigned char list)
240 {
241         if (!so || !so->so_pcb) {
242                 HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
243                     "%s: socket or so_pcb is null\n", __func__);
244                 return;
245         }
246
247         mtx_lock(&hvs_trans_socks_mtx);
248         __hvs_insert_socket_on_list(so, list);
249         mtx_unlock(&hvs_trans_socks_mtx);
250 }
251
252 static struct socket *
253 __hvs_find_socket_on_list(struct sockaddr_hvs *addr, unsigned char list)
254 {
255         struct hvs_pcb *p = NULL;
256
257         if (list & HVS_LIST_BOUND)
258                 LIST_FOREACH(p, &hvs_trans_bound_socks, bound_next)
259                         if (p->so != NULL &&
260                             addr->hvs_port == p->local_addr.hvs_port)
261                                 return p->so;
262
263         if (list & HVS_LIST_CONNECTED)
264                 LIST_FOREACH(p, &hvs_trans_connected_socks, connected_next)
265                         if (p->so != NULL &&
266                             addr->hvs_port == p->local_addr.hvs_port)
267                                 return p->so;
268
269         return NULL;
270 }
271
272 static struct socket *
273 hvs_find_socket_on_list(struct sockaddr_hvs *addr, unsigned char list)
274 {
275         struct socket *s = NULL;
276
277         mtx_lock(&hvs_trans_socks_mtx);
278         s = __hvs_find_socket_on_list(addr, list);
279         mtx_unlock(&hvs_trans_socks_mtx);
280
281         return s;
282 }
283
284 static inline void
285 hvs_addr_set(struct sockaddr_hvs *addr, unsigned int port)
286 {
287         memset(addr, 0, sizeof(*addr));
288         addr->sa_family = AF_HYPERV;
289         addr->sa_len = sizeof(*addr);
290         addr->hvs_port = port;
291 }
292
293 void
294 hvs_addr_init(struct sockaddr_hvs *addr, const struct hyperv_guid *svr_id)
295 {
296         hvs_addr_set(addr, get_port_by_srv_id(svr_id));
297 }
298
299 int
300 hvs_trans_lock(void)
301 {
302         sx_xlock(&hvs_trans_socks_sx);
303         return (0);
304 }
305
306 void
307 hvs_trans_unlock(void)
308 {
309         sx_xunlock(&hvs_trans_socks_sx);
310 }
311
312 static int
313 hvs_dom_probe(void)
314 {
315
316         /* Don't even give us a chance to attach on non-HyperV. */
317         if (vm_guest != VM_GUEST_HV)
318                 return (ENXIO);
319         return (0);
320 }
321
322 static void
323 hvs_trans_init(void *arg __unused)
324 {
325
326         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
327             "%s: HyperV Socket hvs_trans_init called\n", __func__);
328
329         /* Initialize Globals */
330         previous_auto_bound_port = MAX_PORT;
331         sx_init(&hvs_trans_socks_sx, "hvs_trans_sock_sx");
332         mtx_init(&hvs_trans_socks_mtx,
333             "hvs_trans_socks_mtx", NULL, MTX_DEF);
334         LIST_INIT(&hvs_trans_bound_socks);
335         LIST_INIT(&hvs_trans_connected_socks);
336 }
337 SYSINIT(hvs_trans_init, SI_SUB_PROTO_DOMAIN, SI_ORDER_THIRD,
338     hvs_trans_init, NULL);
339
340 /*
341  * Called in two cases:
342  * 1) When user calls socket();
343  * 2) When we accept new incoming conneciton and call sonewconn().
344  */
345 int
346 hvs_trans_attach(struct socket *so, int proto, struct thread *td)
347 {
348         struct hvs_pcb *pcb = so2hvspcb(so);
349
350         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
351             "%s: HyperV Socket hvs_trans_attach called\n", __func__);
352
353         if (so->so_type != SOCK_STREAM)
354                 return (ESOCKTNOSUPPORT);
355
356         if (proto != 0 && proto != HYPERV_SOCK_PROTO_TRANS)
357                 return (EPROTONOSUPPORT);
358
359         if (pcb != NULL)
360                 return (EISCONN);
361         pcb = malloc(sizeof(struct hvs_pcb), M_HVSOCK, M_NOWAIT | M_ZERO);
362         if (pcb == NULL)
363                 return (ENOMEM);
364
365         pcb->so = so;
366         so->so_pcb = (void *)pcb;
367
368         return (0);
369 }
370
371 void
372 hvs_trans_detach(struct socket *so)
373 {
374         struct hvs_pcb *pcb;
375
376         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
377             "%s: HyperV Socket hvs_trans_detach called\n", __func__);
378
379         (void) hvs_trans_lock();
380         pcb = so2hvspcb(so);
381         if (pcb == NULL) {
382                 hvs_trans_unlock();
383                 return;
384         }
385
386         if (SOLISTENING(so)) {
387                 bzero(pcb, sizeof(*pcb));
388                 free(pcb, M_HVSOCK);
389         }
390
391         so->so_pcb = NULL;
392
393         hvs_trans_unlock();
394 }
395
396 int
397 hvs_trans_bind(struct socket *so, struct sockaddr *addr, struct thread *td)
398 {
399         struct hvs_pcb *pcb = so2hvspcb(so);
400         struct sockaddr_hvs *sa = (struct sockaddr_hvs *) addr;
401         int error = 0;
402
403         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
404             "%s: HyperV Socket hvs_trans_bind called\n", __func__);
405
406         if (sa == NULL) {
407                 return (EINVAL);
408         }
409
410         if (pcb == NULL) {
411                 return (EINVAL);
412         }
413
414         if (sa->sa_family != AF_HYPERV) {
415                 HVSOCK_DBG(HVSOCK_DBG_ERR,
416                     "%s: Not supported, sa_family is %u\n",
417                     __func__, sa->sa_family);
418                 return (EAFNOSUPPORT);
419         }
420         if (sa->sa_len != sizeof(*sa)) {
421                 HVSOCK_DBG(HVSOCK_DBG_ERR,
422                     "%s: Not supported, sa_len is %u\n",
423                     __func__, sa->sa_len);
424                 return (EINVAL);
425         }
426
427         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
428             "%s: binding port = 0x%x\n", __func__, sa->hvs_port);
429
430         mtx_lock(&hvs_trans_socks_mtx);
431         if (__hvs_find_socket_on_list(sa,
432             HVS_LIST_BOUND | HVS_LIST_CONNECTED)) {
433                 error = EADDRINUSE;
434         } else {
435                 /*
436                  * The address is available for us to bind.
437                  * Add socket to the bound list.
438                  */
439                 hvs_addr_set(&pcb->local_addr, sa->hvs_port);
440                 hvs_addr_set(&pcb->remote_addr, HVADDR_PORT_ANY);
441                 __hvs_insert_socket_on_list(so, HVS_LIST_BOUND);
442         }
443         mtx_unlock(&hvs_trans_socks_mtx);
444
445         return (error);
446 }
447
448 int
449 hvs_trans_listen(struct socket *so, int backlog, struct thread *td)
450 {
451         struct hvs_pcb *pcb = so2hvspcb(so);
452         struct socket *bound_so;
453         int error;
454
455         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
456             "%s: HyperV Socket hvs_trans_listen called\n", __func__);
457
458         if (pcb == NULL)
459                 return (EINVAL);
460
461         /* Check if the address is already bound and it was by us. */
462         bound_so = hvs_find_socket_on_list(&pcb->local_addr, HVS_LIST_BOUND);
463         if (bound_so == NULL || bound_so != so) {
464                 HVSOCK_DBG(HVSOCK_DBG_ERR,
465                     "%s: Address not bound or not by us.\n", __func__);
466                 return (EADDRNOTAVAIL);
467         }
468
469         SOCK_LOCK(so);
470         error = solisten_proto_check(so);
471         if (error == 0)
472                 solisten_proto(so, backlog);
473         SOCK_UNLOCK(so);
474
475         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
476             "%s: HyperV Socket listen error = %d\n", __func__, error);
477         return (error);
478 }
479
480 int
481 hvs_trans_accept(struct socket *so, struct sockaddr *sa)
482 {
483         struct hvs_pcb *pcb = so2hvspcb(so);
484
485         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
486             "%s: HyperV Socket hvs_trans_accept called\n", __func__);
487
488         if (pcb == NULL)
489                 return (EINVAL);
490
491         memcpy(sa, &pcb->remote_addr, pcb->remote_addr.sa_len);
492
493         return (0);
494 }
495
496 int
497 hvs_trans_connect(struct socket *so, struct sockaddr *nam, struct thread *td)
498 {
499         struct hvs_pcb *pcb = so2hvspcb(so);
500         struct sockaddr_hvs *raddr = (struct sockaddr_hvs *)nam;
501         bool found_auto_bound_port = false;
502         int i, error = 0;
503
504         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
505             "%s: HyperV Socket hvs_trans_connect called, remote port is %x\n",
506             __func__, raddr->hvs_port);
507
508         if (pcb == NULL)
509                 return (EINVAL);
510
511         /* Verify the remote address */
512         if (raddr == NULL)
513                 return (EINVAL);
514         if (raddr->sa_family != AF_HYPERV)
515                 return (EAFNOSUPPORT);
516         if (raddr->sa_len != sizeof(*raddr))
517                 return (EINVAL);
518
519         mtx_lock(&hvs_trans_socks_mtx);
520         if (so->so_state &
521             (SS_ISCONNECTED|SS_ISDISCONNECTING|SS_ISCONNECTING)) {
522                         HVSOCK_DBG(HVSOCK_DBG_ERR,
523                             "%s: socket connect in progress\n",
524                             __func__);
525                         error = EINPROGRESS;
526                         goto out;
527         }
528
529         /*
530          * Find an available port for us to auto bind the local
531          * address.
532          */
533         hvs_addr_set(&pcb->local_addr, 0);
534
535         for (i = previous_auto_bound_port - 1;
536             i != previous_auto_bound_port; i --) {
537                 if (i == MIN_PORT)
538                         i = MAX_PORT;
539
540                 pcb->local_addr.hvs_port = i;
541
542                 if (__hvs_find_socket_on_list(&pcb->local_addr,
543                     HVS_LIST_BOUND | HVS_LIST_CONNECTED) == NULL) {
544                         found_auto_bound_port = true;
545                         previous_auto_bound_port = i;
546                         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
547                             "%s: found local bound port is %x\n",
548                             __func__, pcb->local_addr.hvs_port);
549                         break;
550                 }
551         }
552
553         if (found_auto_bound_port == true) {
554                 /* Found available port for auto bound, put on list */
555                 __hvs_insert_socket_on_list(so, HVS_LIST_BOUND);
556                 /* Set VM service ID */
557                 pcb->vm_srv_id = srv_id_template;
558                 set_port_by_srv_id(&pcb->vm_srv_id, pcb->local_addr.hvs_port);
559                 /* Set host service ID and remote port */
560                 pcb->host_srv_id = srv_id_template;
561                 set_port_by_srv_id(&pcb->host_srv_id, raddr->hvs_port);
562                 hvs_addr_set(&pcb->remote_addr, raddr->hvs_port);
563
564                 /* Change the socket state to SS_ISCONNECTING */
565                 soisconnecting(so);
566         } else {
567                 HVSOCK_DBG(HVSOCK_DBG_ERR,
568                     "%s: No local port available for auto bound\n",
569                     __func__);
570                 error = EADDRINUSE;
571         }
572
573         HVSOCK_DBG(HVSOCK_DBG_INFO, "Connect vm_srv_id is ");
574         hvsock_print_guid(&pcb->vm_srv_id);
575         HVSOCK_DBG(HVSOCK_DBG_INFO, "Connect host_srv_id is ");
576         hvsock_print_guid(&pcb->host_srv_id);
577
578 out:
579         mtx_unlock(&hvs_trans_socks_mtx);
580
581         if (found_auto_bound_port == true)
582                  vmbus_req_tl_connect(&pcb->vm_srv_id, &pcb->host_srv_id);
583
584         return (error);
585 }
586
587 int
588 hvs_trans_disconnect(struct socket *so)
589 {
590         struct hvs_pcb *pcb;
591
592         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
593             "%s: HyperV Socket hvs_trans_disconnect called\n", __func__);
594
595         (void) hvs_trans_lock();
596         pcb = so2hvspcb(so);
597         if (pcb == NULL) {
598                 hvs_trans_unlock();
599                 return (EINVAL);
600         }
601
602         /* If socket is already disconnected, skip this */
603         if ((so->so_state & SS_ISDISCONNECTED) == 0)
604                 soisdisconnecting(so);
605
606         hvs_trans_unlock();
607
608         return (0);
609 }
610
611 struct hvs_callback_arg {
612         struct uio *uio;
613         struct sockbuf *sb;
614 };
615
616 int
617 hvs_trans_soreceive(struct socket *so, struct sockaddr **paddr,
618     struct uio *uio, struct mbuf **mp0, struct mbuf **controlp, int *flagsp)
619 {
620         struct hvs_pcb *pcb = so2hvspcb(so);
621         struct sockbuf *sb;
622         ssize_t orig_resid;
623         uint32_t canread, to_read;
624         int flags, error = 0;
625         struct hvs_callback_arg cbarg;
626
627         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
628             "%s: HyperV Socket hvs_trans_soreceive called\n", __func__);
629
630         if (so->so_type != SOCK_STREAM)
631                 return (EINVAL);
632         if (pcb == NULL)
633                 return (EINVAL);
634
635         if (flagsp != NULL)
636                 flags = *flagsp &~ MSG_EOR;
637         else
638                 flags = 0;
639
640         if (flags & MSG_PEEK)
641                 return (EOPNOTSUPP);
642
643         /* If no space to copy out anything */
644         if (uio->uio_resid == 0 || uio->uio_rw != UIO_READ)
645                 return (EINVAL);
646
647         orig_resid = uio->uio_resid;
648
649         /* Prevent other readers from entering the socket. */
650         error = SOCK_IO_RECV_LOCK(so, SBLOCKWAIT(flags));
651         if (error) {
652                 HVSOCK_DBG(HVSOCK_DBG_ERR,
653                     "%s: soiolock returned error = %d\n", __func__, error);
654                 return (error);
655         }
656
657         sb = &so->so_rcv;
658         SOCKBUF_LOCK(sb);
659
660         cbarg.uio = uio;
661         cbarg.sb = sb;
662         /*
663          * If the socket is closing, there might still be some data
664          * in rx br to read. However we need to make sure
665          * the channel is still open.
666          */
667         if ((sb->sb_state & SBS_CANTRCVMORE) &&
668             (so->so_state & SS_ISDISCONNECTED)) {
669                 /* Other thread already closed the channel */
670                 error = EPIPE;
671                 goto out;
672         }
673
674         while (true) {
675                 while (uio->uio_resid > 0 &&
676                     (canread = hvsock_canread_check(pcb)) > 0) {
677                         to_read = MIN(canread, uio->uio_resid);
678                         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
679                             "%s: to_read = %u, skip = %u\n", __func__, to_read,
680                             (unsigned int)(sizeof(struct hvs_pkt_header) +
681                             pcb->recv_data_off));
682
683                         error = vmbus_chan_recv_peek_call(pcb->chan, to_read,
684                             sizeof(struct hvs_pkt_header) + pcb->recv_data_off,
685                             hvsock_br_callback, (void *)&cbarg);
686                         /*
687                          * It is possible socket is disconnected becasue
688                          * we released lock in hvsock_br_callback. So we
689                          * need to check the state to make sure it is not
690                          * disconnected.
691                          */
692                         if (error || so->so_state & SS_ISDISCONNECTED) {
693                                 break;
694                         }
695
696                         pcb->recv_data_len -= to_read;
697                         pcb->recv_data_off += to_read;
698                 }
699
700                 if (error)
701                         break;
702
703                 /* Abort if socket has reported problems. */
704                 if (so->so_error) {
705                         if (so->so_error == ESHUTDOWN &&
706                             orig_resid > uio->uio_resid) {
707                                 /*
708                                  * Although we got a FIN, we also received
709                                  * some data in this round. Delivery it
710                                  * to user.
711                                  */
712                                 error = 0;
713                         } else {
714                                 if (so->so_error != ESHUTDOWN)
715                                         error = so->so_error;
716                         }
717
718                         break;
719                 }
720
721                 /* Cannot received more. */
722                 if (sb->sb_state & SBS_CANTRCVMORE)
723                         break;
724
725                 /* We are done if buffer has been filled */
726                 if (uio->uio_resid == 0)
727                         break;
728
729                 if (!(flags & MSG_WAITALL) && orig_resid > uio->uio_resid)
730                         break;
731
732                 /* Buffer ring is empty and we shall not block */
733                 if ((so->so_state & SS_NBIO) ||
734                     (flags & (MSG_DONTWAIT|MSG_NBIO))) {
735                         if (orig_resid == uio->uio_resid) {
736                                 /* We have not read anything */
737                                 error = EAGAIN;
738                         }
739                         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
740                             "%s: non blocked read return, error %d.\n",
741                             __func__, error);
742                         break;
743                 }
744
745                 /*
746                  * Wait and block until (more) data comes in.
747                  * Note: Drops the sockbuf lock during wait.
748                  */
749                 error = sbwait(so, SO_RCV);
750
751                 if (error)
752                         break;
753
754                 HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
755                     "%s: wake up from sbwait, read available is %u\n",
756                     __func__, vmbus_chan_read_available(pcb->chan));
757         }
758
759 out:
760         SOCKBUF_UNLOCK(sb);
761         SOCK_IO_RECV_UNLOCK(so);
762
763         /* We received a FIN in this call */
764         if (so->so_error == ESHUTDOWN) {
765                 if (so->so_snd.sb_state & SBS_CANTSENDMORE) {
766                         /* Send has already closed */
767                         soisdisconnecting(so);
768                 } else {
769                         /* Just close the receive side */
770                         socantrcvmore(so);
771                 }
772         }
773
774         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
775             "%s: returning error = %d, so_error = %d\n",
776             __func__, error, so->so_error);
777
778         return (error);
779 }
780
781 int
782 hvs_trans_sosend(struct socket *so, struct sockaddr *addr, struct uio *uio,
783     struct mbuf *top, struct mbuf *controlp, int flags, struct thread *td)
784 {
785         struct hvs_pcb *pcb = so2hvspcb(so);
786         struct sockbuf *sb;
787         ssize_t orig_resid;
788         uint32_t canwrite, to_write;
789         int error = 0;
790
791         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
792             "%s: HyperV Socket hvs_trans_sosend called, uio_resid = %zd\n",
793             __func__, uio->uio_resid);
794
795         if (so->so_type != SOCK_STREAM)
796                 return (EINVAL);
797         if (pcb == NULL)
798                 return (EINVAL);
799
800         /* If nothing to send */
801         if (uio->uio_resid == 0 || uio->uio_rw != UIO_WRITE)
802                 return (EINVAL);
803
804         orig_resid = uio->uio_resid;
805
806         /* Prevent other writers from entering the socket. */
807         error = SOCK_IO_SEND_LOCK(so, SBLOCKWAIT(flags));
808         if (error) {
809                 HVSOCK_DBG(HVSOCK_DBG_ERR,
810                     "%s: soiolocak returned error = %d\n", __func__, error);
811                 return (error);
812         }
813
814         sb = &so->so_snd;
815         SOCKBUF_LOCK(sb);
816
817         if ((sb->sb_state & SBS_CANTSENDMORE) ||
818             so->so_error == ESHUTDOWN) {
819                 error = EPIPE;
820                 goto out;
821         }
822
823         while (uio->uio_resid > 0) {
824                 canwrite = hvsock_canwrite_check(pcb);
825                 if (canwrite == 0) {
826                         /* We have sent some data */
827                         if (orig_resid > uio->uio_resid)
828                                 break;
829                         /*
830                          * We have not sent any data and it is
831                          * non-blocked io
832                          */
833                         if (so->so_state & SS_NBIO ||
834                             (flags & (MSG_NBIO | MSG_DONTWAIT)) != 0) {
835                                 error = EWOULDBLOCK;
836                                 break;
837                         } else {
838                                 /*
839                                  * We are here because there is no space on
840                                  * send buffer ring. Signal the other side
841                                  * to read and free more space.
842                                  * Sleep wait until space avaiable to send
843                                  * Note: Drops the sockbuf lock during wait.
844                                  */
845                                 error = sbwait(so, SO_SND);
846
847                                 if (error)
848                                         break;
849
850                                 HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
851                                     "%s: wake up from sbwait, space avail on "
852                                     "tx ring is %u\n",
853                                     __func__,
854                                     vmbus_chan_write_available(pcb->chan));
855
856                                 continue;
857                         }
858                 }
859                 to_write = MIN(canwrite, uio->uio_resid);
860                 to_write = MIN(to_write, HVSOCK_SEND_BUF_SZ);
861
862                 HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
863                     "%s: canwrite is %u, to_write = %u\n", __func__,
864                     canwrite, to_write);
865                 error = hvsock_send_data(pcb->chan, uio, to_write, sb);
866
867                 if (error)
868                         break;
869         }
870
871 out:
872         SOCKBUF_UNLOCK(sb);
873         SOCK_IO_SEND_UNLOCK(so);
874
875         return (error);
876 }
877
878 int
879 hvs_trans_peeraddr(struct socket *so, struct sockaddr *sa)
880 {
881         struct hvs_pcb *pcb = so2hvspcb(so);
882
883         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
884             "%s: HyperV Socket hvs_trans_peeraddr called\n", __func__);
885
886         if (pcb == NULL)
887                 return (EINVAL);
888
889         memcpy(sa, &pcb->remote_addr, pcb->remote_addr.sa_len);
890
891         return (0);
892 }
893
894 int
895 hvs_trans_sockaddr(struct socket *so, struct sockaddr *sa)
896 {
897         struct hvs_pcb *pcb = so2hvspcb(so);
898
899         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
900             "%s: HyperV Socket hvs_trans_sockaddr called\n", __func__);
901
902         if (pcb == NULL)
903                 return (EINVAL);
904
905         memcpy(sa, &pcb->local_addr, pcb->local_addr.sa_len);
906
907         return (0);
908 }
909
910 void
911 hvs_trans_close(struct socket *so)
912 {
913         struct hvs_pcb *pcb;
914
915         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
916             "%s: HyperV Socket hvs_trans_close called\n", __func__);
917
918         (void) hvs_trans_lock();
919         pcb = so2hvspcb(so);
920         if (!pcb) {
921                 hvs_trans_unlock();
922                 return;
923         }
924
925         if (so->so_state & SS_ISCONNECTED) {
926                 /* Send a FIN to peer */
927                 HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
928                     "%s: hvs_trans_close sending a FIN to host\n", __func__);
929                 (void) hvsock_send_data(pcb->chan, NULL, 0, NULL);
930         }
931
932         if (so->so_state &
933             (SS_ISCONNECTED|SS_ISCONNECTING|SS_ISDISCONNECTING))
934                 soisdisconnected(so);
935
936         pcb->chan = NULL;
937         pcb->so = NULL;
938
939         if (SOLISTENING(so)) {
940                 mtx_lock(&hvs_trans_socks_mtx);
941                 /* Remove from bound list */
942                 __hvs_remove_socket_from_list(so, HVS_LIST_BOUND);
943                 mtx_unlock(&hvs_trans_socks_mtx);
944         }
945
946         hvs_trans_unlock();
947
948         return;
949 }
950
951 void
952 hvs_trans_abort(struct socket *so)
953 {
954         struct hvs_pcb *pcb = so2hvspcb(so);
955
956         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
957             "%s: HyperV Socket hvs_trans_abort called\n", __func__);
958
959         (void) hvs_trans_lock();
960         if (pcb == NULL) {
961                 hvs_trans_unlock();
962                 return;
963         }
964
965         if (SOLISTENING(so)) {
966                 mtx_lock(&hvs_trans_socks_mtx);
967                 /* Remove from bound list */
968                 __hvs_remove_socket_from_list(so, HVS_LIST_BOUND);
969                 mtx_unlock(&hvs_trans_socks_mtx);
970         }
971
972         if (so->so_state & SS_ISCONNECTED) {
973                 (void) sodisconnect(so);
974         }
975         hvs_trans_unlock();
976
977         return;
978 }
979
980 int
981 hvs_trans_shutdown(struct socket *so)
982 {
983         struct hvs_pcb *pcb = so2hvspcb(so);
984         struct sockbuf *sb;
985
986         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
987             "%s: HyperV Socket hvs_trans_shutdown called\n", __func__);
988
989         if (pcb == NULL)
990                 return (EINVAL);
991
992         /*
993          * Only get called with the shutdown method is SHUT_WR or
994          * SHUT_RDWR.
995          * When the method is SHUT_RD or SHUT_RDWR, the caller
996          * already set the SBS_CANTRCVMORE on receive side socket
997          * buffer.
998          */
999         if ((so->so_rcv.sb_state & SBS_CANTRCVMORE) == 0) {
1000                 /*
1001                  * SHUT_WR only case.
1002                  * Receive side is still open. Just close
1003                  * the send side.
1004                  */
1005                 socantsendmore(so);
1006         } else {
1007                 /* SHUT_RDWR case */
1008                 if (so->so_state & SS_ISCONNECTED) {
1009                         /* Send a FIN to peer */
1010                         sb = &so->so_snd;
1011                         SOCKBUF_LOCK(sb);
1012                         (void) hvsock_send_data(pcb->chan, NULL, 0, sb);
1013                         SOCKBUF_UNLOCK(sb);
1014
1015                         soisdisconnecting(so);
1016                 }
1017         }
1018
1019         return (0);
1020 }
1021
1022 /* In the VM, we support Hyper-V Sockets with AF_HYPERV, and the endpoint is
1023  * <port> (see struct sockaddr_hvs).
1024  *
1025  * On the host, Hyper-V Sockets are supported by Winsock AF_HYPERV:
1026  * https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-
1027  * guide/make-integration-service, and the endpoint is <VmID, ServiceId> with
1028  * the below sockaddr:
1029  *
1030  * struct SOCKADDR_HV
1031  * {
1032  *    ADDRESS_FAMILY Family;
1033  *    USHORT Reserved;
1034  *    GUID VmId;
1035  *    GUID ServiceId;
1036  * };
1037  * Note: VmID is not used by FreeBSD VM and actually it isn't transmitted via
1038  * VMBus, because here it's obvious the host and the VM can easily identify
1039  * each other. Though the VmID is useful on the host, especially in the case
1040  * of Windows container, FreeBSD VM doesn't need it at all.
1041  *
1042  * To be compatible with similar infrastructure in Linux VMs, we have
1043  * to limit the available GUID space of SOCKADDR_HV so that we can create
1044  * a mapping between FreeBSD AF_HYPERV port and SOCKADDR_HV Service GUID.
1045  * The rule of writing Hyper-V Sockets apps on the host and in FreeBSD VM is:
1046  *
1047  ****************************************************************************
1048  * The only valid Service GUIDs, from the perspectives of both the host and *
1049  * FreeBSD VM, that can be connected by the other end, must conform to this *
1050  * format: <port>-facb-11e6-bd58-64006a7986d3.                              *
1051  ****************************************************************************
1052  *
1053  * When we write apps on the host to connect(), the GUID ServiceID is used.
1054  * When we write apps in FreeBSD VM to connect(), we only need to specify the
1055  * port and the driver will form the GUID and use that to request the host.
1056  *
1057  * From the perspective of FreeBSD VM, the remote ephemeral port (i.e. the
1058  * auto-generated remote port for a connect request initiated by the host's
1059  * connect()) is set to HVADDR_PORT_UNKNOWN, which is not realy used on the
1060  * FreeBSD guest.
1061  */
1062
1063 /*
1064  * Older HyperV hosts (vmbus version 'VMBUS_VERSION_WIN10' or before)
1065  * restricts HyperV socket ring buffer size to six 4K pages. Newer
1066  * HyperV hosts doen't have this limit.
1067  */
1068 #define HVS_RINGBUF_RCV_SIZE    (PAGE_SIZE * 6)
1069 #define HVS_RINGBUF_SND_SIZE    (PAGE_SIZE * 6)
1070 #define HVS_RINGBUF_MAX_SIZE    (PAGE_SIZE * 64)
1071
1072 struct hvsock_sc {
1073         device_t                dev;
1074         struct hvs_pcb          *pcb;
1075         struct vmbus_channel    *channel;
1076 };
1077
1078 static bool
1079 hvsock_chan_readable(struct vmbus_channel *chan)
1080 {
1081         uint32_t readable = vmbus_chan_read_available(chan);
1082
1083         return (readable >= HVSOCK_PKT_LEN(0));
1084 }
1085
1086 static void
1087 hvsock_chan_cb(struct vmbus_channel *chan, void *context)
1088 {
1089         struct hvs_pcb *pcb = (struct hvs_pcb *) context;
1090         struct socket *so;
1091         uint32_t canwrite;
1092
1093         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1094             "%s: host send us a wakeup on rb data, pcb = %p\n",
1095             __func__, pcb);
1096
1097         /*
1098          * Check if the socket is still attached and valid.
1099          * Here we know channel is still open. Need to make
1100          * sure the socket has not been closed or freed.
1101          */
1102         (void) hvs_trans_lock();
1103         so = hsvpcb2so(pcb);
1104
1105         if (pcb->chan != NULL && so != NULL) {
1106                 /*
1107                  * Wake up reader if there are data to read.
1108                  */
1109                 SOCKBUF_LOCK(&(so)->so_rcv);
1110
1111                 HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1112                     "%s: read available = %u\n", __func__,
1113                     vmbus_chan_read_available(pcb->chan));
1114
1115                 if (hvsock_chan_readable(pcb->chan))
1116                         sorwakeup_locked(so);
1117                 else
1118                         SOCKBUF_UNLOCK(&(so)->so_rcv);
1119
1120                 /*
1121                  * Wake up sender if space becomes available to write.
1122                  */
1123                 SOCKBUF_LOCK(&(so)->so_snd);
1124                 canwrite = hvsock_canwrite_check(pcb);
1125
1126                 HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1127                     "%s: canwrite = %u\n", __func__, canwrite);
1128
1129                 if (canwrite > 0) {
1130                         sowwakeup_locked(so);
1131                 } else {
1132                         SOCKBUF_UNLOCK(&(so)->so_snd);
1133                 }
1134         }
1135
1136         hvs_trans_unlock();
1137
1138         return;
1139 }
1140
1141 static int
1142 hvsock_br_callback(void *datap, int cplen, void *cbarg)
1143 {
1144         struct hvs_callback_arg *arg = (struct hvs_callback_arg *)cbarg;
1145         struct uio *uio = arg->uio;
1146         struct sockbuf *sb = arg->sb;
1147         int error = 0;
1148
1149         if (cbarg == NULL || datap == NULL)
1150                 return (EINVAL);
1151
1152         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1153             "%s: called, uio_rw = %s, uio_resid = %zd, cplen = %u, "
1154             "datap = %p\n",
1155             __func__, (uio->uio_rw == UIO_READ) ? "read from br":"write to br",
1156             uio->uio_resid, cplen, datap);
1157
1158         if (sb)
1159                 SOCKBUF_UNLOCK(sb);
1160
1161         error = uiomove(datap, cplen, uio);
1162
1163         if (sb)
1164                 SOCKBUF_LOCK(sb);
1165
1166         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1167             "%s: after uiomove, uio_resid = %zd, error = %d\n",
1168             __func__, uio->uio_resid, error);
1169
1170         return (error);
1171 }
1172
1173 static int
1174 hvsock_send_data(struct vmbus_channel *chan, struct uio *uio,
1175     uint32_t to_write, struct sockbuf *sb)
1176 {
1177         struct hvs_pkt_header hvs_pkt;
1178         int hvs_pkthlen, hvs_pktlen, pad_pktlen, hlen, error = 0;
1179         uint64_t pad = 0;
1180         struct iovec iov[3];
1181         struct hvs_callback_arg cbarg;
1182
1183         if (chan == NULL)
1184                 return (ENOTCONN);
1185
1186         hlen = sizeof(struct vmbus_chanpkt_hdr);
1187         hvs_pkthlen = sizeof(struct hvs_pkt_header);
1188         hvs_pktlen = hvs_pkthlen + to_write;
1189         pad_pktlen = VMBUS_CHANPKT_TOTLEN(hvs_pktlen);
1190
1191         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1192             "%s: hlen = %u, hvs_pkthlen = %u, hvs_pktlen = %u, "
1193             "pad_pktlen = %u, data_len = %u\n",
1194             __func__, hlen, hvs_pkthlen, hvs_pktlen, pad_pktlen, to_write);
1195
1196         hvs_pkt.chan_pkt_hdr.cph_type = VMBUS_CHANPKT_TYPE_INBAND;
1197         hvs_pkt.chan_pkt_hdr.cph_flags = 0;
1198         VMBUS_CHANPKT_SETLEN(hvs_pkt.chan_pkt_hdr.cph_hlen, hlen);
1199         VMBUS_CHANPKT_SETLEN(hvs_pkt.chan_pkt_hdr.cph_tlen, pad_pktlen);
1200         hvs_pkt.chan_pkt_hdr.cph_xactid = 0;
1201
1202         hvs_pkt.vmpipe_pkt_hdr.vmpipe_pkt_type = 1;
1203         hvs_pkt.vmpipe_pkt_hdr.vmpipe_data_size = to_write;
1204
1205         cbarg.uio = uio;
1206         cbarg.sb = sb;
1207
1208         if (uio && to_write > 0) {
1209                 iov[0].iov_base = &hvs_pkt;
1210                 iov[0].iov_len = hvs_pkthlen;
1211                 iov[1].iov_base = NULL;
1212                 iov[1].iov_len = to_write;
1213                 iov[2].iov_base = &pad;
1214                 iov[2].iov_len = pad_pktlen - hvs_pktlen;
1215
1216                 error = vmbus_chan_iov_send(chan, iov, 3,
1217                     hvsock_br_callback, &cbarg);
1218         } else {
1219                 if (to_write == 0) {
1220                         iov[0].iov_base = &hvs_pkt;
1221                         iov[0].iov_len = hvs_pkthlen;
1222                         iov[1].iov_base = &pad;
1223                         iov[1].iov_len = pad_pktlen - hvs_pktlen;
1224                         error = vmbus_chan_iov_send(chan, iov, 2, NULL, NULL);
1225                 }
1226         }
1227
1228         if (error) {
1229                 HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1230                     "%s: error = %d\n", __func__, error);
1231         }
1232
1233         return (error);
1234 }
1235
1236 /*
1237  * Check if we have data on current ring buffer to read
1238  * or not. If not, advance the ring buffer read index to
1239  * next packet. Update the recev_data_len and recev_data_off
1240  * to new value.
1241  * Return the number of bytes can read.
1242  */
1243 static uint32_t
1244 hvsock_canread_check(struct hvs_pcb *pcb)
1245 {
1246         uint32_t advance;
1247         uint32_t tlen, hlen, dlen;
1248         uint32_t bytes_canread = 0;
1249         int error;
1250
1251         if (pcb == NULL || pcb->chan == NULL) {
1252                 pcb->so->so_error = EIO;
1253                 return (0);
1254         }
1255
1256         /* Still have data not read yet on current packet */
1257         if (pcb->recv_data_len > 0)
1258                 return (pcb->recv_data_len);
1259
1260         if (pcb->rb_init)
1261                 advance =
1262                     VMBUS_CHANPKT_GETLEN(pcb->hvs_pkt.chan_pkt_hdr.cph_tlen);
1263         else
1264                 advance = 0;
1265
1266         bytes_canread = vmbus_chan_read_available(pcb->chan);
1267
1268         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1269             "%s: bytes_canread on br = %u, advance = %u\n",
1270             __func__, bytes_canread, advance);
1271
1272         if (pcb->rb_init && bytes_canread == (advance + sizeof(uint64_t))) {
1273                 /*
1274                  * Nothing to read. Need to advance the rindex before
1275                  * calling sbwait, so host knows to wake us up when data
1276                  * is available to read on rb.
1277                  */
1278                 error = vmbus_chan_recv_idxadv(pcb->chan, advance);
1279                 if (error) {
1280                         HVSOCK_DBG(HVSOCK_DBG_ERR,
1281                             "%s: after calling vmbus_chan_recv_idxadv, "
1282                             "got error = %d\n",  __func__, error);
1283                         return (0);
1284                 } else {
1285                         pcb->rb_init = false;
1286                         pcb->recv_data_len = 0;
1287                         pcb->recv_data_off = 0;
1288                         bytes_canread = vmbus_chan_read_available(pcb->chan);
1289
1290                         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1291                             "%s: advanced %u bytes, "
1292                             " bytes_canread on br now = %u\n",
1293                             __func__, advance, bytes_canread);
1294
1295                         if (bytes_canread == 0)
1296                                 return (0);
1297                         else
1298                                 advance = 0;
1299                 }
1300         }
1301
1302         if (bytes_canread <
1303             advance + (sizeof(struct hvs_pkt_header) + sizeof(uint64_t)))
1304                 return (0);
1305
1306         error = vmbus_chan_recv_peek(pcb->chan, &pcb->hvs_pkt,
1307             sizeof(struct hvs_pkt_header), advance);
1308
1309         /* Don't have anything to read */
1310         if (error) {
1311                 HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1312                     "%s: after calling vmbus_chan_recv_peek, got error = %d\n",
1313                     __func__, error);
1314                 return (0);
1315         }
1316
1317         /*
1318          * We just read in a new packet header. Do some sanity checks.
1319          */
1320         tlen = VMBUS_CHANPKT_GETLEN(pcb->hvs_pkt.chan_pkt_hdr.cph_tlen);
1321         hlen = VMBUS_CHANPKT_GETLEN(pcb->hvs_pkt.chan_pkt_hdr.cph_hlen);
1322         dlen = pcb->hvs_pkt.vmpipe_pkt_hdr.vmpipe_data_size;
1323         if (__predict_false(hlen < sizeof(struct vmbus_chanpkt_hdr)) ||
1324             __predict_false(hlen > tlen) ||
1325             __predict_false(tlen < dlen + sizeof(struct hvs_pkt_header))) {
1326                 HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1327                     "invalid tlen(%u), hlen(%u) or dlen(%u)\n",
1328                     tlen, hlen, dlen);
1329                 pcb->so->so_error = EIO;
1330                 return (0);
1331         }
1332         if (pcb->rb_init == false)
1333                 pcb->rb_init = true;
1334
1335         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1336             "Got new pkt tlen(%u), hlen(%u) or dlen(%u)\n",
1337             tlen, hlen, dlen);
1338
1339         /* The other side has sent a close FIN */
1340         if (dlen == 0) {
1341                 HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1342                     "%s: Received FIN from other side\n", __func__);
1343                 /* inform the caller by seting so_error to ESHUTDOWN */
1344                 pcb->so->so_error = ESHUTDOWN;
1345         }
1346
1347         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1348             "%s: canread on receive ring is %u \n", __func__, dlen);
1349
1350         pcb->recv_data_len = dlen;
1351         pcb->recv_data_off = 0;
1352
1353         return (pcb->recv_data_len);
1354 }
1355
1356 static uint32_t
1357 hvsock_canwrite_check(struct hvs_pcb *pcb)
1358 {
1359         uint32_t writeable;
1360         uint32_t ret;
1361
1362         if (pcb == NULL || pcb->chan == NULL)
1363                 return (0);
1364
1365         writeable = vmbus_chan_write_available(pcb->chan);
1366
1367         /*
1368          * We must always reserve a 0-length-payload packet for the FIN.
1369          */
1370         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1371             "%s: writeable is %u, should be greater than %ju\n",
1372             __func__, writeable,
1373             (uintmax_t)(HVSOCK_PKT_LEN(1) + HVSOCK_PKT_LEN(0)));
1374
1375         if (writeable < HVSOCK_PKT_LEN(1) + HVSOCK_PKT_LEN(0)) {
1376                 /*
1377                  * The Tx ring seems full.
1378                  */
1379                 return (0);
1380         }
1381
1382         ret = writeable - HVSOCK_PKT_LEN(0) - HVSOCK_PKT_LEN(0);
1383
1384         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1385             "%s: available size is %u\n", __func__, rounddown2(ret, 8));
1386
1387         return (rounddown2(ret, 8));
1388 }
1389
1390 static void
1391 hvsock_set_chan_pending_send_size(struct vmbus_channel *chan)
1392 {
1393         vmbus_chan_set_pending_send_size(chan,
1394             HVSOCK_PKT_LEN(HVSOCK_SEND_BUF_SZ));
1395 }
1396
1397 static int
1398 hvsock_open_channel(struct vmbus_channel *chan, struct socket *so)
1399 {
1400         unsigned int rcvbuf, sndbuf;
1401         struct hvs_pcb *pcb = so2hvspcb(so);
1402         int ret;
1403
1404         if (vmbus_current_version < VMBUS_VERSION_WIN10_V5) {
1405                 sndbuf = HVS_RINGBUF_SND_SIZE;
1406                 rcvbuf = HVS_RINGBUF_RCV_SIZE;
1407         } else {
1408                 sndbuf = MAX(so->so_snd.sb_hiwat, HVS_RINGBUF_SND_SIZE);
1409                 sndbuf = MIN(sndbuf, HVS_RINGBUF_MAX_SIZE);
1410                 sndbuf = rounddown2(sndbuf, PAGE_SIZE);
1411                 rcvbuf = MAX(so->so_rcv.sb_hiwat, HVS_RINGBUF_RCV_SIZE);
1412                 rcvbuf = MIN(rcvbuf, HVS_RINGBUF_MAX_SIZE);
1413                 rcvbuf = rounddown2(rcvbuf, PAGE_SIZE);
1414         }
1415
1416         /*
1417          * Can only read whatever user provided size of data
1418          * from ring buffer. Turn off batched reading.
1419          */
1420         vmbus_chan_set_readbatch(chan, false);
1421
1422         ret = vmbus_chan_open(chan, sndbuf, rcvbuf, NULL, 0,
1423             hvsock_chan_cb, pcb);
1424
1425         if (ret != 0) {
1426                 HVSOCK_DBG(HVSOCK_DBG_ERR,
1427                     "%s: failed to open hvsock channel, sndbuf = %u, "
1428                     "rcvbuf = %u\n", __func__, sndbuf, rcvbuf);
1429         } else {
1430                 HVSOCK_DBG(HVSOCK_DBG_INFO,
1431                     "%s: hvsock channel opened, sndbuf = %u, i"
1432                     "rcvbuf = %u\n", __func__, sndbuf, rcvbuf);
1433                 /*
1434                  * Se the pending send size so to receive wakeup
1435                  * signals from host when there is enough space on
1436                  * rx buffer ring to write.
1437                  */
1438                 hvsock_set_chan_pending_send_size(chan);
1439         }
1440
1441         return ret;
1442 }
1443
1444 /*
1445  * Guest is listening passively on the socket. Open channel and
1446  * create a new socket for the conneciton.
1447  */
1448 static void
1449 hvsock_open_conn_passive(struct vmbus_channel *chan, struct socket *so,
1450     struct hvsock_sc *sc)
1451 {
1452         struct socket *new_so;
1453         struct hvs_pcb *new_pcb, *pcb;
1454         int error;
1455
1456         /* Do nothing if socket is not listening */
1457         if (!SOLISTENING(so)) {
1458                 HVSOCK_DBG(HVSOCK_DBG_ERR,
1459                     "%s: socket is not a listening one\n", __func__);
1460                 return;
1461         }
1462
1463         /*
1464          * Create a new socket. This will call pru_attach to complete
1465          * the socket initialization and put the new socket onto
1466          * listening socket's sol_incomp list, waiting to be promoted
1467          * to sol_comp list.
1468          * The new socket created has ref count 0. There is no other
1469          * thread that changes the state of this new one at the
1470          * moment, so we don't need to hold its lock while opening
1471          * channel and filling out its pcb information.
1472          */
1473         new_so = sonewconn(so, 0);
1474         if (!new_so)
1475                 HVSOCK_DBG(HVSOCK_DBG_ERR,
1476                     "%s: creating new socket failed\n", __func__);
1477
1478         /*
1479          * Now open the vmbus channel. If it fails, the socket will be
1480          * on the listening socket's sol_incomp queue until it is
1481          * replaced and aborted.
1482          */
1483         error = hvsock_open_channel(chan, new_so);
1484         if (error) {
1485                 new_so->so_error = error;
1486                 return;
1487         }
1488
1489         pcb = so->so_pcb;
1490         new_pcb = new_so->so_pcb;
1491
1492         hvs_addr_set(&(new_pcb->local_addr), pcb->local_addr.hvs_port);
1493         /* Remote port is unknown to guest in this type of conneciton */
1494         hvs_addr_set(&(new_pcb->remote_addr), HVADDR_PORT_UNKNOWN);
1495         new_pcb->chan = chan;
1496         new_pcb->recv_data_len = 0;
1497         new_pcb->recv_data_off = 0;
1498         new_pcb->rb_init = false;
1499
1500         new_pcb->vm_srv_id = *vmbus_chan_guid_type(chan);
1501         new_pcb->host_srv_id = *vmbus_chan_guid_inst(chan);
1502
1503         hvs_insert_socket_on_list(new_so, HVS_LIST_CONNECTED);
1504
1505         sc->pcb = new_pcb;
1506
1507         /*
1508          * Change the socket state to SS_ISCONNECTED. This will promote
1509          * the socket to sol_comp queue and wake up the thread which
1510          * is accepting connection.
1511          */
1512         soisconnected(new_so);
1513 }
1514
1515
1516 /*
1517  * Guest is actively connecting to host.
1518  */
1519 static void
1520 hvsock_open_conn_active(struct vmbus_channel *chan, struct socket *so)
1521 {
1522         struct hvs_pcb *pcb;
1523         int error;
1524
1525         error = hvsock_open_channel(chan, so);
1526         if (error) {
1527                 so->so_error = error;
1528                 return;
1529         }
1530
1531         pcb = so->so_pcb;
1532         pcb->chan = chan;
1533         pcb->recv_data_len = 0;
1534         pcb->recv_data_off = 0;
1535         pcb->rb_init = false;
1536
1537         mtx_lock(&hvs_trans_socks_mtx);
1538         __hvs_remove_socket_from_list(so, HVS_LIST_BOUND);
1539         __hvs_insert_socket_on_list(so, HVS_LIST_CONNECTED);
1540         mtx_unlock(&hvs_trans_socks_mtx);
1541
1542         /*
1543          * Change the socket state to SS_ISCONNECTED. This will wake up
1544          * the thread sleeping in connect call.
1545          */
1546         soisconnected(so);
1547 }
1548
1549 static void
1550 hvsock_open_connection(struct vmbus_channel *chan, struct hvsock_sc *sc)
1551 {
1552         struct hyperv_guid *inst_guid, *type_guid;
1553         bool conn_from_host;
1554         struct sockaddr_hvs addr;
1555         struct socket *so;
1556         struct hvs_pcb *pcb;
1557
1558         type_guid = (struct hyperv_guid *) vmbus_chan_guid_type(chan);
1559         inst_guid = (struct hyperv_guid *) vmbus_chan_guid_inst(chan);
1560         conn_from_host = vmbus_chan_is_hvs_conn_from_host(chan);
1561
1562         HVSOCK_DBG(HVSOCK_DBG_INFO, "type_guid is ");
1563         hvsock_print_guid(type_guid);
1564         HVSOCK_DBG(HVSOCK_DBG_INFO, "inst_guid is ");
1565         hvsock_print_guid(inst_guid);
1566         HVSOCK_DBG(HVSOCK_DBG_INFO, "connection %s host\n",
1567             (conn_from_host == true ) ? "from" : "to");
1568
1569         /*
1570          * The listening port should be in [0, MAX_LISTEN_PORT]
1571          */
1572         if (!is_valid_srv_id(type_guid))
1573                 return;
1574
1575         /*
1576          * There should be a bound socket already created no matter
1577          * it is a passive or active connection.
1578          * For host initiated connection (passive on guest side),
1579          * the  type_guid contains the port which guest is bound and
1580          * listening.
1581          * For the guest initiated connection (active on guest side),
1582          * the inst_guid contains the port that guest has auto bound
1583          * to.
1584          */
1585         hvs_addr_init(&addr, conn_from_host ? type_guid : inst_guid);
1586         so = hvs_find_socket_on_list(&addr, HVS_LIST_BOUND);
1587         if (!so) {
1588                 HVSOCK_DBG(HVSOCK_DBG_ERR,
1589                     "%s: no bound socket found for port %u\n",
1590                     __func__, addr.hvs_port);
1591                 return;
1592         }
1593
1594         if (conn_from_host) {
1595                 hvsock_open_conn_passive(chan, so, sc);
1596         } else {
1597                 (void) hvs_trans_lock();
1598                 pcb = so->so_pcb;
1599                 if (pcb && pcb->so) {
1600                         sc->pcb = so2hvspcb(so);
1601                         hvsock_open_conn_active(chan, so);
1602                 } else {
1603                         HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1604                             "%s: channel detached before open\n", __func__);
1605                 }
1606                 hvs_trans_unlock();
1607         }
1608
1609 }
1610
1611 static int
1612 hvsock_probe(device_t dev)
1613 {
1614         struct vmbus_channel *channel = vmbus_get_channel(dev);
1615
1616         if (!channel || !vmbus_chan_is_hvs(channel)) {
1617                 HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1618                     "hvsock_probe called but not a hvsock channel id %u\n",
1619                     vmbus_chan_id(channel));
1620
1621                 return ENXIO;
1622         } else {
1623                 HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1624                     "hvsock_probe got a hvsock channel id %u\n",
1625                     vmbus_chan_id(channel));
1626
1627                 return BUS_PROBE_DEFAULT;
1628         }
1629 }
1630
1631 static int
1632 hvsock_attach(device_t dev)
1633 {
1634         struct vmbus_channel *channel = vmbus_get_channel(dev);
1635         struct hvsock_sc *sc = (struct hvsock_sc *)device_get_softc(dev);
1636
1637         HVSOCK_DBG(HVSOCK_DBG_VERBOSE, "hvsock_attach called.\n");
1638
1639         hvsock_open_connection(channel, sc);
1640
1641         /*
1642          * Always return success. On error the host will rescind the device
1643          * in 30 seconds and we can do cleanup at that time in
1644          * vmbus_chan_msgproc_chrescind().
1645          */
1646         return (0);
1647 }
1648
1649 static int
1650 hvsock_detach(device_t dev)
1651 {
1652         struct hvsock_sc *sc = (struct hvsock_sc *)device_get_softc(dev);
1653         struct socket *so;
1654         int retry;
1655
1656         if (bootverbose)
1657                 device_printf(dev, "hvsock_detach called.\n");
1658
1659         HVSOCK_DBG(HVSOCK_DBG_VERBOSE, "hvsock_detach called.\n");
1660
1661         if (sc->pcb != NULL) {
1662                 (void) hvs_trans_lock();
1663
1664                 so = hsvpcb2so(sc->pcb);
1665                 if (so) {
1666                         /* Close the connection */
1667                         if (so->so_state &
1668                             (SS_ISCONNECTED|SS_ISCONNECTING|SS_ISDISCONNECTING))
1669                                 soisdisconnected(so);
1670                 }
1671
1672                 mtx_lock(&hvs_trans_socks_mtx);
1673                 __hvs_remove_pcb_from_list(sc->pcb,
1674                     HVS_LIST_BOUND | HVS_LIST_CONNECTED);
1675                 mtx_unlock(&hvs_trans_socks_mtx);
1676
1677                 /*
1678                  * Close channel while no reader and sender are working
1679                  * on the buffer rings.
1680                  */
1681                 if (so) {
1682                         retry = 0;
1683                         while (SOCK_IO_RECV_LOCK(so, 0) == EWOULDBLOCK) {
1684                                 /*
1685                                  * Someone is reading, rx br is busy
1686                                  */
1687                                 soisdisconnected(so);
1688                                 DELAY(500);
1689                                 HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1690                                     "waiting for rx reader to exit, "
1691                                     "retry = %d\n", retry++);
1692                         }
1693                         retry = 0;
1694                         while (SOCK_IO_SEND_LOCK(so, 0) == EWOULDBLOCK) {
1695                                 /*
1696                                  * Someone is sending, tx br is busy
1697                                  */
1698                                 soisdisconnected(so);
1699                                 DELAY(500);
1700                                 HVSOCK_DBG(HVSOCK_DBG_VERBOSE,
1701                                     "waiting for tx sender to exit, "
1702                                     "retry = %d\n", retry++);
1703                         }
1704                 }
1705
1706
1707                 bzero(sc->pcb, sizeof(struct hvs_pcb));
1708                 free(sc->pcb, M_HVSOCK);
1709                 sc->pcb = NULL;
1710
1711                 if (so) {
1712                         SOCK_IO_RECV_UNLOCK(so);
1713                         SOCK_IO_SEND_UNLOCK(so);
1714                         so->so_pcb = NULL;
1715                 }
1716
1717                 hvs_trans_unlock();
1718         }
1719
1720         vmbus_chan_close(vmbus_get_channel(dev));
1721
1722         return (0);
1723 }
1724
1725 static device_method_t hvsock_methods[] = {
1726         /* Device interface */
1727         DEVMETHOD(device_probe, hvsock_probe),
1728         DEVMETHOD(device_attach, hvsock_attach),
1729         DEVMETHOD(device_detach, hvsock_detach),
1730         DEVMETHOD_END
1731 };
1732
1733 static driver_t hvsock_driver = {
1734         "hv_sock",
1735         hvsock_methods,
1736         sizeof(struct hvsock_sc)
1737 };
1738
1739 DRIVER_MODULE(hvsock, vmbus, hvsock_driver, NULL, NULL);
1740 MODULE_VERSION(hvsock, 1);
1741 MODULE_DEPEND(hvsock, vmbus, 1, 1, 1);