contrib/ofed/librdmacm/rsocket.c
1 /*
2  * Copyright (c) 2008-2014 Intel Corporation.  All rights reserved.
3  *
4  * This software is available to you under a choice of one of two
5  * licenses.  You may choose to be licensed under the terms of the GNU
6  * General Public License (GPL) Version 2, available from the file
7  * COPYING in the main directory of this source tree, or the
8  * OpenIB.org BSD license below:
9  *
10  *     Redistribution and use in source and binary forms, with or
11  *     without modification, are permitted provided that the following
12  *     conditions are met:
13  *
14  *      - Redistributions of source code must retain the above
15  *        copyright notice, this list of conditions and the following
16  *        disclaimer.
17  *
18  *      - Redistributions in binary form must reproduce the above
19  *        copyright notice, this list of conditions and the following
20  *        disclaimer in the documentation and/or other materials
21  *        provided with the distribution.
22  *
23  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
26  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
27  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
28  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
29  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30  * SOFTWARE.
31  *
32  */
33 #define _GNU_SOURCE
34 #include <config.h>
35
36 #include <sys/types.h>
37 #include <sys/socket.h>
38 #include <sys/time.h>
39 #include <infiniband/endian.h>
40 #include <stdarg.h>
41 #include <netdb.h>
42 #include <unistd.h>
43 #include <fcntl.h>
44 #include <stdio.h>
45 #include <stddef.h>
46 #include <string.h>
47 #include <netinet/tcp.h>
48 #include <sys/epoll.h>
49 #include <search.h>
50 #include <byteswap.h>
51 #include <util/compiler.h>
52
53 #include <rdma/rdma_cma.h>
54 #include <rdma/rdma_verbs.h>
55 #include <rdma/rsocket.h>
56 #include "cma.h"
57 #include "indexer.h"
58
59 #define RS_OLAP_START_SIZE 2048
60 #define RS_MAX_TRANSFER 65536
61 #define RS_SNDLOWAT 2048
62 #define RS_QP_MIN_SIZE 16
63 #define RS_QP_MAX_SIZE 0xFFFE
64 #define RS_QP_CTRL_SIZE 4       /* must be power of 2 */
65 #define RS_CONN_RETRIES 6
66 #define RS_SGL_SIZE 2
67 static struct index_map idm;
68 static pthread_mutex_t mut = PTHREAD_MUTEX_INITIALIZER;
69
70 struct rsocket;
71
72 enum {
73         RS_SVC_NOOP,
74         RS_SVC_ADD_DGRAM,
75         RS_SVC_REM_DGRAM,
76         RS_SVC_ADD_KEEPALIVE,
77         RS_SVC_REM_KEEPALIVE,
78         RS_SVC_MOD_KEEPALIVE
79 };
80
81 struct rs_svc_msg {
82         uint32_t cmd;
83         uint32_t status;
84         struct rsocket *rs;
85 };
86
87 struct rs_svc {
88         pthread_t id;
89         int sock[2];
90         int cnt;
91         int size;
92         int context_size;
93         void *(*run)(void *svc);
94         struct rsocket **rss;
95         void *contexts;
96 };
97
98 static struct pollfd *udp_svc_fds;
99 static void *udp_svc_run(void *arg);
100 static struct rs_svc udp_svc = {
101         .context_size = sizeof(*udp_svc_fds),
102         .run = udp_svc_run
103 };
104 static uint32_t *tcp_svc_timeouts;
105 static void *tcp_svc_run(void *arg);
106 static struct rs_svc tcp_svc = {
107         .context_size = sizeof(*tcp_svc_timeouts),
108         .run = tcp_svc_run
109 };
110
111 static uint16_t def_iomap_size = 0;
112 static uint16_t def_inline = 64;
113 static uint16_t def_sqsize = 384;
114 static uint16_t def_rqsize = 384;
115 static uint32_t def_mem = (1 << 17);
116 static uint32_t def_wmem = (1 << 17);
117 static uint32_t polling_time = 10;
118
119 /*
120  * Immediate data format is determined by the upper bits
121  * bit 31: message type, 0 - data, 1 - control
122  * bit 30: buffers updated, 0 - target, 1 - direct-receive
123  * bit 29: more data, 0 - end of transfer, 1 - more data available
124  *
125  * for data transfers:
126  * bits [28:0]: bytes transferred
127  * for control messages:
128  * SGL, CTRL
129  * bits [28:0]: receive credits granted
130  * IOMAP_SGL
131  * bits [28:16]: reserved, bits [15:0]: index
132  */
133
134 enum {
135         RS_OP_DATA,
136         RS_OP_RSVD_DATA_MORE,
137         RS_OP_WRITE, /* opcode is not transmitted over the network */
138         RS_OP_RSVD_DRA_MORE,
139         RS_OP_SGL,
140         RS_OP_RSVD,
141         RS_OP_IOMAP_SGL,
142         RS_OP_CTRL
143 };
144 #define rs_msg_set(op, data)  ((op << 29) | (uint32_t) (data))
145 #define rs_msg_op(imm_data)   (imm_data >> 29)
146 #define rs_msg_data(imm_data) (imm_data & 0x1FFFFFFF)
147 #define RS_MSG_SIZE           sizeof(uint32_t)
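
/*
 * Editor's example (not part of the original source): the macros above pack
 * the 3-bit opcode and the 29-bit payload described in the comment block into
 * one 32-bit immediate-data word.  The values below are arbitrary and only
 * illustrate the encoding; the snippet assumes the enum and macros defined
 * above.
 */
#if 0   /* illustrative only */
static void example_imm_data_encoding(void)
{
        uint32_t imm;

        /* A data transfer of 4096 bytes: opcode 0, length in bits 28:0. */
        imm = rs_msg_set(RS_OP_DATA, 4096);     /* 0x00001000 */
        assert(rs_msg_op(imm) == RS_OP_DATA);
        assert(rs_msg_data(imm) == 4096);

        /* An SGL update granting 100 receive credits: opcode 4 in bits 31:29. */
        imm = rs_msg_set(RS_OP_SGL, 100);       /* 0x80000064 */
        assert(rs_msg_op(imm) == RS_OP_SGL);
        assert(rs_msg_data(imm) == 100);
}
#endif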
148
149 #define RS_WR_ID_FLAG_RECV (((uint64_t) 1) << 63)
150 #define RS_WR_ID_FLAG_MSG_SEND (((uint64_t) 1) << 62) /* See RS_OPT_MSG_SEND */
151 #define rs_send_wr_id(data) ((uint64_t) data)
152 #define rs_recv_wr_id(data) (RS_WR_ID_FLAG_RECV | (uint64_t) data)
153 #define rs_wr_is_recv(wr_id) (wr_id & RS_WR_ID_FLAG_RECV)
154 #define rs_wr_is_msg_send(wr_id) (wr_id & RS_WR_ID_FLAG_MSG_SEND)
155 #define rs_wr_data(wr_id) ((uint32_t) wr_id)
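
/*
 * Editor's example (not part of the original source): work request IDs carry
 * a direction flag in the top bit and a 32-bit payload in the low bits, so
 * one completion handler can tell receives from sends and recover whatever
 * value was posted.  Illustrative values only.
 */
#if 0   /* illustrative only */
static void example_wr_id_encoding(void)
{
        uint64_t wr_id;

        /* A receive posted for message slot 7. */
        wr_id = rs_recv_wr_id(7);
        assert(rs_wr_is_recv(wr_id));
        assert(rs_wr_data(wr_id) == 7);

        /* A send posted with an immediate-data word as its payload. */
        wr_id = rs_send_wr_id(rs_msg_set(RS_OP_DATA, 4096));
        assert(!rs_wr_is_recv(wr_id));
        assert(rs_msg_op(rs_wr_data(wr_id)) == RS_OP_DATA);
}
#endif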
156
157 enum {
158         RS_CTRL_DISCONNECT,
159         RS_CTRL_KEEPALIVE,
160         RS_CTRL_SHUTDOWN
161 };
162
163 struct rs_msg {
164         uint32_t op;
165         uint32_t data;
166 };
167
168 struct ds_qp;
169
170 struct ds_rmsg {
171         struct ds_qp    *qp;
172         uint32_t        offset;
173         uint32_t        length;
174 };
175
176 struct ds_smsg {
177         struct ds_smsg  *next;
178 };
179
180 struct rs_sge {
181         uint64_t addr;
182         uint32_t key;
183         uint32_t length;
184 };
185
186 struct rs_iomap {
187         uint64_t offset;
188         struct rs_sge sge;
189 };
190
191 struct rs_iomap_mr {
192         uint64_t offset;
193         struct ibv_mr *mr;
194         dlist_entry entry;
195         _Atomic(int) refcnt;
196         int index;      /* -1 if mapping is local and not in iomap_list */
197 };
198
199 #define RS_MAX_CTRL_MSG    (sizeof(struct rs_sge))
200 #define rs_host_is_net()   (__BYTE_ORDER == __BIG_ENDIAN)
201 #define RS_CONN_FLAG_NET   (1 << 0)
202 #define RS_CONN_FLAG_IOMAP (1 << 1)
203
204 struct rs_conn_data {
205         uint8_t           version;
206         uint8_t           flags;
207         __be16            credits;
208         uint8_t           reserved[3];
209         uint8_t           target_iomap_size;
210         struct rs_sge     target_sgl;
211         struct rs_sge     data_buf;
212 };
213
214 struct rs_conn_private_data {
215         union {
216                 struct rs_conn_data             conn_data;
217                 struct {
218                         struct ib_connect_hdr   ib_hdr;
219                         struct rs_conn_data     conn_data;
220                 } af_ib;
221         };
222 };
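
/*
 * Editor's note (illustrative, not part of the original source): the
 * connection metadata above travels as raw bytes in the CM private data, so
 * its layout is fixed.  On the usual ABIs (no unexpected padding) the sizes
 * work out as checked below; the AF_IB variant simply prefixes the same
 * structure with a struct ib_connect_hdr.
 */
#if 0   /* illustrative only */
_Static_assert(sizeof(struct rs_sge) == 16, "rs_sge is 16 bytes on the wire");
_Static_assert(sizeof(struct rs_conn_data) == 40,
               "rs_conn_data is 40 bytes on the wire");
#endif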
223
224 /*
225  * rsocket states are ordered as passive, connecting, connected, disconnected.
226  */
227 enum rs_state {
228         rs_init,
229         rs_bound           =                0x0001,
230         rs_listening       =                0x0002,
231         rs_opening         =                0x0004,
232         rs_resolving_addr  = rs_opening |   0x0010,
233         rs_resolving_route = rs_opening |   0x0020,
234         rs_connecting      = rs_opening |   0x0040,
235         rs_accepting       = rs_opening |   0x0080,
236         rs_connected       =                0x0100,
237         rs_writable        =                0x0200,
238         rs_readable        =                0x0400,
239         rs_connect_rdwr    = rs_connected | rs_readable | rs_writable,
240         rs_connect_error   =                0x0800,
241         rs_disconnected    =                0x1000,
242         rs_error           =                0x2000,
243 };
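
/*
 * Editor's example (not part of the original source): rs_state values are
 * bit masks, so whole groups of states can be tested with one mask, and the
 * ordering noted in the comment above keeps every pre-connected value
 * numerically below rs_connected.  Illustrative checks only, assuming the
 * enum above; the helper name is hypothetical.
 */
#if 0   /* illustrative only */
static void example_state_tests(int state)
{
        int opening, half_closed, not_yet_connected;

        /* True for rs_resolving_addr/route, rs_connecting and rs_accepting. */
        opening = (state & rs_opening) != 0;

        /* rs_connect_rdwr sets rs_connected plus both direction bits, so a
         * half-closed socket can drop rs_readable or rs_writable while the
         * connection bit stays set. */
        half_closed = (state & rs_connected) && !(state & rs_writable);

        /* The ordering makes this the "not yet connected" test used later in
         * this file (e.g. in rs_set_nonblocking()). */
        not_yet_connected = state < rs_connected;

        (void) opening; (void) half_closed; (void) not_yet_connected;
}
#endif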
244
245 #define RS_OPT_SWAP_SGL   (1 << 0)
246 /*
247  * iWarp does not support RDMA write with immediate data.  For iWarp, we
248  * transfer rsocket messages as inline sends.
249  */
250 #define RS_OPT_MSG_SEND   (1 << 1)
251 #define RS_OPT_SVC_ACTIVE (1 << 2)
252
253 union socket_addr {
254         struct sockaddr         sa;
255         struct sockaddr_in      sin;
256         struct sockaddr_in6     sin6;
257 };
258
259 struct ds_header {
260         uint8_t           version;
261         uint8_t           length;
262         __be16            port;
263         union {
264                 __be32  ipv4;
265                 struct {
266                         __be32 flowinfo;
267                         uint8_t  addr[16];
268                 } ipv6;
269         } addr;
270 };
271
272 #define DS_IPV4_HDR_LEN  8
273 #define DS_IPV6_HDR_LEN 24
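
/*
 * Editor's note (illustrative, not part of the original source): the datagram
 * header is sent as raw bytes, and the two constants above are simply the
 * number of bytes used for each address family: the 4-byte fixed part plus
 * either a 4-byte IPv4 address or a 4-byte flowinfo word and a 16-byte IPv6
 * address.
 */
#if 0   /* illustrative only */
_Static_assert(DS_IPV4_HDR_LEN == offsetof(struct ds_header, addr) + 4,
               "version + length + port + IPv4 address");
_Static_assert(DS_IPV6_HDR_LEN == offsetof(struct ds_header, addr) + 4 + 16,
               "version + length + port + flowinfo + IPv6 address");
#endif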
274
275 struct ds_dest {
276         union socket_addr addr; /* must be first */
277         struct ds_qp      *qp;
278         struct ibv_ah     *ah;
279         uint32_t           qpn;
280 };
281
282 struct ds_qp {
283         dlist_entry       list;
284         struct rsocket    *rs;
285         struct rdma_cm_id *cm_id;
286         struct ds_header  hdr;
287         struct ds_dest    dest;
288
289         struct ibv_mr     *smr;
290         struct ibv_mr     *rmr;
291         uint8_t           *rbuf;
292
293         int               cq_armed;
294 };
295
296 struct rsocket {
297         int               type;
298         int               index;
299         fastlock_t        slock;
300         fastlock_t        rlock;
301         fastlock_t        cq_lock;
302         fastlock_t        cq_wait_lock;
303         fastlock_t        map_lock; /* acquire slock first if needed */
304
305         union {
306                 /* data stream */
307                 struct {
308                         struct rdma_cm_id *cm_id;
309                         uint64_t          tcp_opts;
310                         unsigned int      keepalive_time;
311
312                         unsigned int      ctrl_seqno;
313                         unsigned int      ctrl_max_seqno;
314                         uint16_t          sseq_no;
315                         uint16_t          sseq_comp;
316                         uint16_t          rseq_no;
317                         uint16_t          rseq_comp;
318
319                         int               remote_sge;
320                         struct rs_sge     remote_sgl;
321                         struct rs_sge     remote_iomap;
322
323                         struct ibv_mr     *target_mr;
324                         int               target_sge;
325                         int               target_iomap_size;
326                         void              *target_buffer_list;
327                         volatile struct rs_sge    *target_sgl;
328                         struct rs_iomap   *target_iomap;
329
330                         int               rbuf_msg_index;
331                         int               rbuf_bytes_avail;
332                         int               rbuf_free_offset;
333                         int               rbuf_offset;
334                         struct ibv_mr     *rmr;
335                         uint8_t           *rbuf;
336
337                         int               sbuf_bytes_avail;
338                         struct ibv_mr     *smr;
339                         struct ibv_sge    ssgl[2];
340                 };
341                 /* datagram */
342                 struct {
343                         struct ds_qp      *qp_list;
344                         void              *dest_map;
345                         struct ds_dest    *conn_dest;
346
347                         int               udp_sock;
348                         int               epfd;
349                         int               rqe_avail;
350                         struct ds_smsg    *smsg_free;
351                 };
352         };
353
354         int               opts;
355         int               fd_flags;
356         uint64_t          so_opts;
357         uint64_t          ipv6_opts;
358         void              *optval;
359         size_t            optlen;
360         int               state;
361         int               cq_armed;
362         int               retries;
363         int               err;
364
365         int               sqe_avail;
366         uint32_t          sbuf_size;
367         uint16_t          sq_size;
368         uint16_t          sq_inline;
369
370         uint32_t          rbuf_size;
371         uint16_t          rq_size;
372         int               rmsg_head;
373         int               rmsg_tail;
374         union {
375                 struct rs_msg     *rmsg;
376                 struct ds_rmsg    *dmsg;
377         };
378
379         uint8_t           *sbuf;
380         struct rs_iomap_mr *remote_iomappings;
381         dlist_entry       iomap_list;
382         dlist_entry       iomap_queue;
383         int               iomap_pending;
384         int               unack_cqe;
385 };
386
387 #define DS_UDP_TAG 0x55555555
388
389 struct ds_udp_header {
390         __be32            tag;
391         uint8_t           version;
392         uint8_t           op;
393         uint8_t           length;
394         uint8_t           reserved;
395         __be32            qpn;  /* lower 8-bits reserved */
396         union {
397                 __be32   ipv4;
398                 uint8_t  ipv6[16];
399         } addr;
400 };
401
402 #define DS_UDP_IPV4_HDR_LEN 16
403 #define DS_UDP_IPV6_HDR_LEN 28
404
405 #define ds_next_qp(qp) container_of((qp)->list.next, struct ds_qp, list)
406
407 static void write_all(int fd, const void *msg, size_t len)
408 {
409         // FIXME: if fd is a socket this really needs to handle EINTR and other conditions.
410         ssize_t rc = write(fd, msg, len);
411         assert(rc == len);
412 }
413
414 static void read_all(int fd, void *msg, size_t len)
415 {
416         // FIXME: if fd is a socket this really needs to handle EINTR and other conditions.
417         ssize_t rc = read(fd, msg, len);
418         assert(rc == len);
419 }
420
421 static void ds_insert_qp(struct rsocket *rs, struct ds_qp *qp)
422 {
423         if (!rs->qp_list)
424                 dlist_init(&qp->list);
425         else
426                 dlist_insert_head(&qp->list, &rs->qp_list->list);
427         rs->qp_list = qp;
428 }
429
430 static void ds_remove_qp(struct rsocket *rs, struct ds_qp *qp)
431 {
432         if (qp->list.next != &qp->list) {
433                 rs->qp_list = ds_next_qp(qp);
434                 dlist_remove(&qp->list);
435         } else {
436                 rs->qp_list = NULL;
437         }
438 }
439
440 static int rs_notify_svc(struct rs_svc *svc, struct rsocket *rs, int cmd)
441 {
442         struct rs_svc_msg msg;
443         int ret;
444
445         pthread_mutex_lock(&mut);
446         if (!svc->cnt) {
447                 ret = socketpair(AF_UNIX, SOCK_STREAM, 0, svc->sock);
448                 if (ret)
449                         goto unlock;
450
451                 ret = pthread_create(&svc->id, NULL, svc->run, svc);
452                 if (ret) {
453                         ret = ERR(ret);
454                         goto closepair;
455                 }
456         }
457
458         msg.cmd = cmd;
459         msg.status = EINVAL;
460         msg.rs = rs;
461         write_all(svc->sock[0], &msg, sizeof msg);
462         read_all(svc->sock[0], &msg, sizeof msg);
463         ret = rdma_seterrno(msg.status);
464         if (svc->cnt)
465                 goto unlock;
466
467         pthread_join(svc->id, NULL);
468 closepair:
469         close(svc->sock[0]);
470         close(svc->sock[1]);
471 unlock:
472         pthread_mutex_unlock(&mut);
473         return ret;
474 }
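
/*
 * Editor's sketch (not part of the original source): rs_notify_svc() starts
 * the service thread on first use and then performs a synchronous
 * request/response over the socketpair: it writes an rs_svc_msg on sock[0]
 * and blocks until the service thread echoes the message back with status
 * filled in.  The real service loops (udp_svc_run/tcp_svc_run, defined later
 * in this file) are assumed to follow roughly the pattern below; the helper
 * name and the handling shown are hypothetical.
 */
#if 0   /* illustrative only */
static void example_svc_loop_iteration(struct rs_svc *svc)
{
        struct rs_svc_msg msg;

        read_all(svc->sock[1], &msg, sizeof(msg));
        switch (msg.cmd) {
        case RS_SVC_ADD_DGRAM:
                /* grow svc->rss/svc->contexts and start watching msg.rs */
                msg.status = 0;
                break;
        default:
                msg.status = EINVAL;
                break;
        }
        write_all(svc->sock[1], &msg, sizeof(msg));
}
#endif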
475
476 static int ds_compare_addr(const void *dst1, const void *dst2)
477 {
478         const struct sockaddr *sa1, *sa2;
479         size_t len;
480
481         sa1 = (const struct sockaddr *) dst1;
482         sa2 = (const struct sockaddr *) dst2;
483
484         len = (sa1->sa_family == AF_INET6 && sa2->sa_family == AF_INET6) ?
485               sizeof(struct sockaddr_in6) : sizeof(struct sockaddr_in);
486         return memcmp(dst1, dst2, len);
487 }
488
489 static int rs_value_to_scale(int value, int bits)
490 {
491         return value <= (1 << (bits - 1)) ?
492                value : (1 << (bits - 1)) | (value >> bits);
493 }
494
495 static int rs_scale_to_value(int value, int bits)
496 {
497         return value <= (1 << (bits - 1)) ?
498                value : (value & ~(1 << (bits - 1))) << bits;
499 }
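
/*
 * Editor's example (not part of the original source): the helpers above
 * compress a value into a small "scale" field for the wire.  For the 8-bit
 * case used by the iomap code, values up to 128 are stored as-is; larger
 * values set the top bit and keep only value >> 8, so the conversion can
 * round down.
 */
#if 0   /* illustrative only */
static void example_iomap_scaling(void)
{
        /* 16384 survives the round trip: scale = 0x80 | (16384 >> 8) = 0xC0. */
        assert(rs_value_to_scale(16384, 8) == 0xC0);
        assert(rs_scale_to_value(0xC0, 8) == 16384);

        /* 200 does not: it rounds down to 128, the largest value below it
         * that the encoding still represents exactly. */
        assert(rs_value_to_scale(200, 8) == 0x80);
        assert(rs_scale_to_value(0x80, 8) == 128);
}
#endif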
500
501 /* gcc (roughly 5 and later) no longer allows a (void) cast to suppress
502    -Wunused-result on fscanf(), but this wrapper does.  Ignoring the result is
503    OK here (though unfriendly to the user) since the library has a sane default. */
504 #define failable_fscanf(f, fmt, ...)                                           \
505         {                                                                      \
506                 int rc = fscanf(f, fmt, __VA_ARGS__);                          \
507                 (void) rc;                                                     \
508         }
509
510 static void rs_configure(void)
511 {
512         FILE *f;
513         static int init;
514
515         if (init)
516                 return;
517
518         pthread_mutex_lock(&mut);
519         if (init)
520                 goto out;
521
522         if (ucma_init())
523                 goto out;
524         ucma_ib_init();
525
526         if ((f = fopen(RS_CONF_DIR "/polling_time", "r"))) {
527                 failable_fscanf(f, "%u", &polling_time);
528                 fclose(f);
529         }
530
531         if ((f = fopen(RS_CONF_DIR "/inline_default", "r"))) {
532                 failable_fscanf(f, "%hu", &def_inline);
533                 fclose(f);
534         }
535
536         if ((f = fopen(RS_CONF_DIR "/sqsize_default", "r"))) {
537                 failable_fscanf(f, "%hu", &def_sqsize);
538                 fclose(f);
539         }
540
541         if ((f = fopen(RS_CONF_DIR "/rqsize_default", "r"))) {
542                 failable_fscanf(f, "%hu", &def_rqsize);
543                 fclose(f);
544         }
545
546         if ((f = fopen(RS_CONF_DIR "/mem_default", "r"))) {
547                 failable_fscanf(f, "%u", &def_mem);
548                 fclose(f);
549
550                 if (def_mem < 1)
551                         def_mem = 1;
552         }
553
554         if ((f = fopen(RS_CONF_DIR "/wmem_default", "r"))) {
555                 failable_fscanf(f, "%u", &def_wmem);
556                 fclose(f);
557                 if (def_wmem < RS_SNDLOWAT)
558                         def_wmem = RS_SNDLOWAT << 1;
559         }
560
561         if ((f = fopen(RS_CONF_DIR "/iomap_size", "r"))) {
562                 failable_fscanf(f, "%hu", &def_iomap_size);
563                 fclose(f);
564
565                 /* round to supported values */
566                 def_iomap_size = (uint8_t) rs_value_to_scale(
567                         (uint16_t) rs_scale_to_value(def_iomap_size, 8), 8);
568         }
569         init = 1;
570 out:
571         pthread_mutex_unlock(&mut);
572 }
573
574 static int rs_insert(struct rsocket *rs, int index)
575 {
576         pthread_mutex_lock(&mut);
577         rs->index = idm_set(&idm, index, rs);
578         pthread_mutex_unlock(&mut);
579         return rs->index;
580 }
581
582 static void rs_remove(struct rsocket *rs)
583 {
584         pthread_mutex_lock(&mut);
585         idm_clear(&idm, rs->index);
586         pthread_mutex_unlock(&mut);
587 }
588
589 /* We only inherit from listening sockets */
590 static struct rsocket *rs_alloc(struct rsocket *inherited_rs, int type)
591 {
592         struct rsocket *rs;
593
594         rs = calloc(1, sizeof(*rs));
595         if (!rs)
596                 return NULL;
597
598         rs->type = type;
599         rs->index = -1;
600         if (type == SOCK_DGRAM) {
601                 rs->udp_sock = -1;
602                 rs->epfd = -1;
603         }
604
605         if (inherited_rs) {
606                 rs->sbuf_size = inherited_rs->sbuf_size;
607                 rs->rbuf_size = inherited_rs->rbuf_size;
608                 rs->sq_inline = inherited_rs->sq_inline;
609                 rs->sq_size = inherited_rs->sq_size;
610                 rs->rq_size = inherited_rs->rq_size;
611                 if (type == SOCK_STREAM) {
612                         rs->ctrl_max_seqno = inherited_rs->ctrl_max_seqno;
613                         rs->target_iomap_size = inherited_rs->target_iomap_size;
614                 }
615         } else {
616                 rs->sbuf_size = def_wmem;
617                 rs->rbuf_size = def_mem;
618                 rs->sq_inline = def_inline;
619                 rs->sq_size = def_sqsize;
620                 rs->rq_size = def_rqsize;
621                 if (type == SOCK_STREAM) {
622                         rs->ctrl_max_seqno = RS_QP_CTRL_SIZE;
623                         rs->target_iomap_size = def_iomap_size;
624                 }
625         }
626         fastlock_init(&rs->slock);
627         fastlock_init(&rs->rlock);
628         fastlock_init(&rs->cq_lock);
629         fastlock_init(&rs->cq_wait_lock);
630         fastlock_init(&rs->map_lock);
631         dlist_init(&rs->iomap_list);
632         dlist_init(&rs->iomap_queue);
633         return rs;
634 }
635
636 static int rs_set_nonblocking(struct rsocket *rs, int arg)
637 {
638         struct ds_qp *qp;
639         int ret = 0;
640
641         if (rs->type == SOCK_STREAM) {
642                 if (rs->cm_id->recv_cq_channel)
643                         ret = fcntl(rs->cm_id->recv_cq_channel->fd, F_SETFL, arg);
644
645                 if (!ret && rs->state < rs_connected)
646                         ret = fcntl(rs->cm_id->channel->fd, F_SETFL, arg);
647         } else {
648                 ret = fcntl(rs->epfd, F_SETFL, arg);
649                 if (!ret && rs->qp_list) {
650                         qp = rs->qp_list;
651                         do {
652                                 ret = fcntl(qp->cm_id->recv_cq_channel->fd,
653                                             F_SETFL, arg);
654                                 qp = ds_next_qp(qp);
655                         } while (qp != rs->qp_list && !ret);
656                 }
657         }
658
659         return ret;
660 }
661
662 static void rs_set_qp_size(struct rsocket *rs)
663 {
664         uint16_t max_size;
665
666         max_size = min(ucma_max_qpsize(rs->cm_id), RS_QP_MAX_SIZE);
667
668         if (rs->sq_size > max_size)
669                 rs->sq_size = max_size;
670         else if (rs->sq_size < RS_QP_MIN_SIZE)
671                 rs->sq_size = RS_QP_MIN_SIZE;
672
673         if (rs->rq_size > max_size)
674                 rs->rq_size = max_size;
675         else if (rs->rq_size < RS_QP_MIN_SIZE)
676                 rs->rq_size = RS_QP_MIN_SIZE;
677 }
678
679 static void ds_set_qp_size(struct rsocket *rs)
680 {
681         uint16_t max_size;
682
683         max_size = min(ucma_max_qpsize(NULL), RS_QP_MAX_SIZE);
684
685         if (rs->sq_size > max_size)
686                 rs->sq_size = max_size;
687         if (rs->rq_size > max_size)
688                 rs->rq_size = max_size;
689
690         if (rs->rq_size > (rs->rbuf_size / RS_SNDLOWAT))
691                 rs->rq_size = rs->rbuf_size / RS_SNDLOWAT;
692         else
693                 rs->rbuf_size = rs->rq_size * RS_SNDLOWAT;
694
695         if (rs->sq_size > (rs->sbuf_size / RS_SNDLOWAT))
696                 rs->sq_size = rs->sbuf_size / RS_SNDLOWAT;
697         else
698                 rs->sbuf_size = rs->sq_size * RS_SNDLOWAT;
699 }
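
/*
 * Editor's note (illustrative, not part of the original source): datagram
 * sockets carve both buffers into fixed RS_SNDLOWAT (2048 byte) message
 * slots, so the queue depths and buffer sizes are reconciled against each
 * other above.  With the built-in defaults (and assuming ucma_max_qpsize()
 * allows that many WRs), sbuf_size = rbuf_size = 128 KB and
 * sq_size = rq_size = 384, so 128 KB / 2 KB = 64 slots caps both queues at
 * 64 entries; a smaller queue would instead shrink the matching buffer to
 * queue_size * 2 KB.
 */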
700
701 static int rs_init_bufs(struct rsocket *rs)
702 {
703         uint32_t total_rbuf_size, total_sbuf_size;
704         size_t len;
705
706         rs->rmsg = calloc(rs->rq_size + 1, sizeof(*rs->rmsg));
707         if (!rs->rmsg)
708                 return ERR(ENOMEM);
709
710         total_sbuf_size = rs->sbuf_size;
711         if (rs->sq_inline < RS_MAX_CTRL_MSG)
712                 total_sbuf_size += RS_MAX_CTRL_MSG * RS_QP_CTRL_SIZE;
713         rs->sbuf = calloc(total_sbuf_size, 1);
714         if (!rs->sbuf)
715                 return ERR(ENOMEM);
716
717         rs->smr = rdma_reg_msgs(rs->cm_id, rs->sbuf, total_sbuf_size);
718         if (!rs->smr)
719                 return -1;
720
721         len = sizeof(*rs->target_sgl) * RS_SGL_SIZE +
722               sizeof(*rs->target_iomap) * rs->target_iomap_size;
723         rs->target_buffer_list = malloc(len);
724         if (!rs->target_buffer_list)
725                 return ERR(ENOMEM);
726
727         rs->target_mr = rdma_reg_write(rs->cm_id, rs->target_buffer_list, len);
728         if (!rs->target_mr)
729                 return -1;
730
731         memset(rs->target_buffer_list, 0, len);
732         rs->target_sgl = rs->target_buffer_list;
733         if (rs->target_iomap_size)
734                 rs->target_iomap = (struct rs_iomap *) (rs->target_sgl + RS_SGL_SIZE);
735
736         total_rbuf_size = rs->rbuf_size;
737         if (rs->opts & RS_OPT_MSG_SEND)
738                 total_rbuf_size += rs->rq_size * RS_MSG_SIZE;
739         rs->rbuf = calloc(total_rbuf_size, 1);
740         if (!rs->rbuf)
741                 return ERR(ENOMEM);
742
743         rs->rmr = rdma_reg_write(rs->cm_id, rs->rbuf, total_rbuf_size);
744         if (!rs->rmr)
745                 return -1;
746
747         rs->ssgl[0].addr = rs->ssgl[1].addr = (uintptr_t) rs->sbuf;
748         rs->sbuf_bytes_avail = rs->sbuf_size;
749         rs->ssgl[0].lkey = rs->ssgl[1].lkey = rs->smr->lkey;
750
751         rs->rbuf_free_offset = rs->rbuf_size >> 1;
752         rs->rbuf_bytes_avail = rs->rbuf_size >> 1;
753         rs->sqe_avail = rs->sq_size - rs->ctrl_max_seqno;
754         rs->rseq_comp = rs->rq_size >> 1;
755         return 0;
756 }
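
/*
 * Editor's note (illustrative, not part of the original source): when the QP
 * cannot send at least RS_MAX_CTRL_MSG (16) bytes inline, rs_init_bufs()
 * over-allocates the send buffer so control messages get registered memory of
 * their own:
 *
 *   sbuf: [ data, sbuf_size bytes | RS_QP_CTRL_SIZE ctrl slots, 16 bytes each ]
 *
 * rs_get_ctrl_buf(), later in this file, hands the control slots out
 * round-robin using the low bits of ctrl_seqno.
 */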
757
758 static int ds_init_bufs(struct ds_qp *qp)
759 {
760         qp->rbuf = calloc(qp->rs->rbuf_size + sizeof(struct ibv_grh), 1);
761         if (!qp->rbuf)
762                 return ERR(ENOMEM);
763
764         qp->smr = rdma_reg_msgs(qp->cm_id, qp->rs->sbuf, qp->rs->sbuf_size);
765         if (!qp->smr)
766                 return -1;
767
768         qp->rmr = rdma_reg_msgs(qp->cm_id, qp->rbuf, qp->rs->rbuf_size +
769                                                      sizeof(struct ibv_grh));
770         if (!qp->rmr)
771                 return -1;
772
773         return 0;
774 }
775
776 /*
777  * If a user is waiting on a datagram rsocket through poll or select, then
778  * we need the first completion to generate an event on the related epoll fd
779  * in order to signal the user.  We arm the CQ on creation for this purpose
780  */
781 static int rs_create_cq(struct rsocket *rs, struct rdma_cm_id *cm_id)
782 {
783         cm_id->recv_cq_channel = ibv_create_comp_channel(cm_id->verbs);
784         if (!cm_id->recv_cq_channel)
785                 return -1;
786
787         cm_id->recv_cq = ibv_create_cq(cm_id->verbs, rs->sq_size + rs->rq_size,
788                                        cm_id, cm_id->recv_cq_channel, 0);
789         if (!cm_id->recv_cq)
790                 goto err1;
791
792         if (rs->fd_flags & O_NONBLOCK) {
793                 if (fcntl(cm_id->recv_cq_channel->fd, F_SETFL, O_NONBLOCK))
794                         goto err2;
795         }
796
797         ibv_req_notify_cq(cm_id->recv_cq, 0);
798         cm_id->send_cq_channel = cm_id->recv_cq_channel;
799         cm_id->send_cq = cm_id->recv_cq;
800         return 0;
801
802 err2:
803         ibv_destroy_cq(cm_id->recv_cq);
804         cm_id->recv_cq = NULL;
805 err1:
806         ibv_destroy_comp_channel(cm_id->recv_cq_channel);
807         cm_id->recv_cq_channel = NULL;
808         return -1;
809 }
810
811 static inline int rs_post_recv(struct rsocket *rs)
812 {
813         struct ibv_recv_wr wr, *bad;
814         struct ibv_sge sge;
815
816         wr.next = NULL;
817         if (!(rs->opts & RS_OPT_MSG_SEND)) {
818                 wr.wr_id = rs_recv_wr_id(0);
819                 wr.sg_list = NULL;
820                 wr.num_sge = 0;
821         } else {
822                 wr.wr_id = rs_recv_wr_id(rs->rbuf_msg_index);
823                 sge.addr = (uintptr_t) rs->rbuf + rs->rbuf_size +
824                            (rs->rbuf_msg_index * RS_MSG_SIZE);
825                 sge.length = RS_MSG_SIZE;
826                 sge.lkey = rs->rmr->lkey;
827
828                 wr.sg_list = &sge;
829                 wr.num_sge = 1;
830                 if (++rs->rbuf_msg_index == rs->rq_size)
831                         rs->rbuf_msg_index = 0;
832         }
833
834         return rdma_seterrno(ibv_post_recv(rs->cm_id->qp, &wr, &bad));
835 }
836
837 static inline int ds_post_recv(struct rsocket *rs, struct ds_qp *qp, uint32_t offset)
838 {
839         struct ibv_recv_wr wr, *bad;
840         struct ibv_sge sge[2];
841
842         sge[0].addr = (uintptr_t) qp->rbuf + rs->rbuf_size;
843         sge[0].length = sizeof(struct ibv_grh);
844         sge[0].lkey = qp->rmr->lkey;
845         sge[1].addr = (uintptr_t) qp->rbuf + offset;
846         sge[1].length = RS_SNDLOWAT;
847         sge[1].lkey = qp->rmr->lkey;
848
849         wr.wr_id = rs_recv_wr_id(offset);
850         wr.next = NULL;
851         wr.sg_list = sge;
852         wr.num_sge = 2;
853
854         return rdma_seterrno(ibv_post_recv(qp->cm_id->qp, &wr, &bad));
855 }
856
857 static int rs_create_ep(struct rsocket *rs)
858 {
859         struct ibv_qp_init_attr qp_attr;
860         int i, ret;
861
862         rs_set_qp_size(rs);
863         if (rs->cm_id->verbs->device->transport_type == IBV_TRANSPORT_IWARP)
864                 rs->opts |= RS_OPT_MSG_SEND;
865         ret = rs_create_cq(rs, rs->cm_id);
866         if (ret)
867                 return ret;
868
869         memset(&qp_attr, 0, sizeof qp_attr);
870         qp_attr.qp_context = rs;
871         qp_attr.send_cq = rs->cm_id->send_cq;
872         qp_attr.recv_cq = rs->cm_id->recv_cq;
873         qp_attr.qp_type = IBV_QPT_RC;
874         qp_attr.sq_sig_all = 1;
875         qp_attr.cap.max_send_wr = rs->sq_size;
876         qp_attr.cap.max_recv_wr = rs->rq_size;
877         qp_attr.cap.max_send_sge = 2;
878         qp_attr.cap.max_recv_sge = 1;
879         qp_attr.cap.max_inline_data = rs->sq_inline;
880
881         ret = rdma_create_qp(rs->cm_id, NULL, &qp_attr);
882         if (ret)
883                 return ret;
884
885         rs->sq_inline = qp_attr.cap.max_inline_data;
886         if ((rs->opts & RS_OPT_MSG_SEND) && (rs->sq_inline < RS_MSG_SIZE))
887                 return ERR(ENOTSUP);
888
889         ret = rs_init_bufs(rs);
890         if (ret)
891                 return ret;
892
893         for (i = 0; i < rs->rq_size; i++) {
894                 ret = rs_post_recv(rs);
895                 if (ret)
896                         return ret;
897         }
898         return 0;
899 }
900
901 static void rs_release_iomap_mr(struct rs_iomap_mr *iomr)
902 {
903         if (atomic_fetch_sub(&iomr->refcnt, 1) != 1)
904                 return;
905
906         dlist_remove(&iomr->entry);
907         ibv_dereg_mr(iomr->mr);
908         if (iomr->index >= 0)
909                 iomr->mr = NULL;
910         else
911                 free(iomr);
912 }
913
914 static void rs_free_iomappings(struct rsocket *rs)
915 {
916         struct rs_iomap_mr *iomr;
917
918         while (!dlist_empty(&rs->iomap_list)) {
919                 iomr = container_of(rs->iomap_list.next,
920                                     struct rs_iomap_mr, entry);
921                 riounmap(rs->index, iomr->mr->addr, iomr->mr->length);
922         }
923         while (!dlist_empty(&rs->iomap_queue)) {
924                 iomr = container_of(rs->iomap_queue.next,
925                                     struct rs_iomap_mr, entry);
926                 riounmap(rs->index, iomr->mr->addr, iomr->mr->length);
927         }
928 }
929
930 static void ds_free_qp(struct ds_qp *qp)
931 {
932         if (qp->smr)
933                 rdma_dereg_mr(qp->smr);
934
935         if (qp->rbuf) {
936                 if (qp->rmr)
937                         rdma_dereg_mr(qp->rmr);
938                 free(qp->rbuf);
939         }
940
941         if (qp->cm_id) {
942                 if (qp->cm_id->qp) {
943                         tdelete(&qp->dest.addr, &qp->rs->dest_map, ds_compare_addr);
944                         epoll_ctl(qp->rs->epfd, EPOLL_CTL_DEL,
945                                   qp->cm_id->recv_cq_channel->fd, NULL);
946                         rdma_destroy_qp(qp->cm_id);
947                 }
948                 rdma_destroy_id(qp->cm_id);
949         }
950
951         free(qp);
952 }
953
954 static void ds_free(struct rsocket *rs)
955 {
956         struct ds_qp *qp;
957
958         if (rs->udp_sock >= 0)
959                 close(rs->udp_sock);
960
961         if (rs->index >= 0)
962                 rs_remove(rs);
963
964         if (rs->dmsg)
965                 free(rs->dmsg);
966
967         while ((qp = rs->qp_list)) {
968                 ds_remove_qp(rs, qp);
969                 ds_free_qp(qp);
970         }
971
972         if (rs->epfd >= 0)
973                 close(rs->epfd);
974
975         if (rs->sbuf)
976                 free(rs->sbuf);
977
978         tdestroy(rs->dest_map, free);
979         fastlock_destroy(&rs->map_lock);
980         fastlock_destroy(&rs->cq_wait_lock);
981         fastlock_destroy(&rs->cq_lock);
982         fastlock_destroy(&rs->rlock);
983         fastlock_destroy(&rs->slock);
984         free(rs);
985 }
986
987 static void rs_free(struct rsocket *rs)
988 {
989         if (rs->type == SOCK_DGRAM) {
990                 ds_free(rs);
991                 return;
992         }
993
994         if (rs->rmsg)
995                 free(rs->rmsg);
996
997         if (rs->sbuf) {
998                 if (rs->smr)
999                         rdma_dereg_mr(rs->smr);
1000                 free(rs->sbuf);
1001         }
1002
1003         if (rs->rbuf) {
1004                 if (rs->rmr)
1005                         rdma_dereg_mr(rs->rmr);
1006                 free(rs->rbuf);
1007         }
1008
1009         if (rs->target_buffer_list) {
1010                 if (rs->target_mr)
1011                         rdma_dereg_mr(rs->target_mr);
1012                 free(rs->target_buffer_list);
1013         }
1014
1015         if (rs->cm_id) {
1016                 rs_free_iomappings(rs);
1017                 if (rs->cm_id->qp) {
1018                         ibv_ack_cq_events(rs->cm_id->recv_cq, rs->unack_cqe);
1019                         rdma_destroy_qp(rs->cm_id);
1020                 }
1021                 rdma_destroy_id(rs->cm_id);
1022         }
1023
1024         if (rs->index >= 0)
1025                 rs_remove(rs);
1026
1027         fastlock_destroy(&rs->map_lock);
1028         fastlock_destroy(&rs->cq_wait_lock);
1029         fastlock_destroy(&rs->cq_lock);
1030         fastlock_destroy(&rs->rlock);
1031         fastlock_destroy(&rs->slock);
1032         free(rs);
1033 }
1034
1035 static size_t rs_conn_data_offset(struct rsocket *rs)
1036 {
1037         return (rs->cm_id->route.addr.src_addr.sa_family == AF_IB) ?
1038                 sizeof(struct ib_connect_hdr) : 0;
1039 }
1040
1041 static void rs_format_conn_data(struct rsocket *rs, struct rs_conn_data *conn)
1042 {
1043         conn->version = 1;
1044         conn->flags = RS_CONN_FLAG_IOMAP |
1045                       (rs_host_is_net() ? RS_CONN_FLAG_NET : 0);
1046         conn->credits = htobe16(rs->rq_size);
1047         memset(conn->reserved, 0, sizeof conn->reserved);
1048         conn->target_iomap_size = (uint8_t) rs_value_to_scale(rs->target_iomap_size, 8);
1049
1050         conn->target_sgl.addr = (__force uint64_t)htobe64((uintptr_t) rs->target_sgl);
1051         conn->target_sgl.length = (__force uint32_t)htobe32(RS_SGL_SIZE);
1052         conn->target_sgl.key = (__force uint32_t)htobe32(rs->target_mr->rkey);
1053
1054         conn->data_buf.addr = (__force uint64_t)htobe64((uintptr_t) rs->rbuf);
1055         conn->data_buf.length = (__force uint32_t)htobe32(rs->rbuf_size >> 1);
1056         conn->data_buf.key = (__force uint32_t)htobe32(rs->rmr->rkey);
1057 }
1058
1059 static void rs_save_conn_data(struct rsocket *rs, struct rs_conn_data *conn)
1060 {
1061         rs->remote_sgl.addr = be64toh((__force __be64)conn->target_sgl.addr);
1062         rs->remote_sgl.length = be32toh((__force __be32)conn->target_sgl.length);
1063         rs->remote_sgl.key = be32toh((__force __be32)conn->target_sgl.key);
1064         rs->remote_sge = 1;
1065         if ((rs_host_is_net() && !(conn->flags & RS_CONN_FLAG_NET)) ||
1066             (!rs_host_is_net() && (conn->flags & RS_CONN_FLAG_NET)))
1067                 rs->opts = RS_OPT_SWAP_SGL;
1068
1069         if (conn->flags & RS_CONN_FLAG_IOMAP) {
1070                 rs->remote_iomap.addr = rs->remote_sgl.addr +
1071                                         sizeof(rs->remote_sgl) * rs->remote_sgl.length;
1072                 rs->remote_iomap.length = rs_scale_to_value(conn->target_iomap_size, 8);
1073                 rs->remote_iomap.key = rs->remote_sgl.key;
1074         }
1075
1076         rs->target_sgl[0].addr = be64toh((__force __be64)conn->data_buf.addr);
1077         rs->target_sgl[0].length = be32toh((__force __be32)conn->data_buf.length);
1078         rs->target_sgl[0].key = be32toh((__force __be32)conn->data_buf.key);
1079
1080         rs->sseq_comp = be16toh(conn->credits);
1081 }
1082
1083 static int ds_init(struct rsocket *rs, int domain)
1084 {
1085         rs->udp_sock = socket(domain, SOCK_DGRAM, 0);
1086         if (rs->udp_sock < 0)
1087                 return rs->udp_sock;
1088
1089         rs->epfd = epoll_create(2);
1090         if (rs->epfd < 0)
1091                 return rs->epfd;
1092
1093         return 0;
1094 }
1095
1096 static int ds_init_ep(struct rsocket *rs)
1097 {
1098         struct ds_smsg *msg;
1099         int i, ret;
1100
1101         ds_set_qp_size(rs);
1102
1103         rs->sbuf = calloc(rs->sq_size, RS_SNDLOWAT);
1104         if (!rs->sbuf)
1105                 return ERR(ENOMEM);
1106
1107         rs->dmsg = calloc(rs->rq_size + 1, sizeof(*rs->dmsg));
1108         if (!rs->dmsg)
1109                 return ERR(ENOMEM);
1110
1111         rs->sqe_avail = rs->sq_size;
1112         rs->rqe_avail = rs->rq_size;
1113
1114         rs->smsg_free = (struct ds_smsg *) rs->sbuf;
1115         msg = rs->smsg_free;
1116         for (i = 0; i < rs->sq_size - 1; i++) {
1117                 msg->next = (void *) msg + RS_SNDLOWAT;
1118                 msg = msg->next;
1119         }
1120         msg->next = NULL;
1121
1122         ret = rs_notify_svc(&udp_svc, rs, RS_SVC_ADD_DGRAM);
1123         if (ret)
1124                 return ret;
1125
1126         rs->state = rs_readable | rs_writable;
1127         return 0;
1128 }
1129
1130 int rsocket(int domain, int type, int protocol)
1131 {
1132         struct rsocket *rs;
1133         int index, ret;
1134
1135         if ((domain != AF_INET && domain != AF_INET6 && domain != AF_IB) ||
1136             ((type != SOCK_STREAM) && (type != SOCK_DGRAM)) ||
1137             (type == SOCK_STREAM && protocol && protocol != IPPROTO_TCP) ||
1138             (type == SOCK_DGRAM && protocol && protocol != IPPROTO_UDP))
1139                 return ERR(ENOTSUP);
1140
1141         rs_configure();
1142         rs = rs_alloc(NULL, type);
1143         if (!rs)
1144                 return ERR(ENOMEM);
1145
1146         if (type == SOCK_STREAM) {
1147                 ret = rdma_create_id(NULL, &rs->cm_id, rs, RDMA_PS_TCP);
1148                 if (ret)
1149                         goto err;
1150
1151                 rs->cm_id->route.addr.src_addr.sa_family = domain;
1152                 index = rs->cm_id->channel->fd;
1153         } else {
1154                 ret = ds_init(rs, domain);
1155                 if (ret)
1156                         goto err;
1157
1158                 index = rs->udp_sock;
1159         }
1160
1161         ret = rs_insert(rs, index);
1162         if (ret < 0)
1163                 goto err;
1164
1165         return rs->index;
1166
1167 err:
1168         rs_free(rs);
1169         return ret;
1170 }
1171
1172 int rbind(int socket, const struct sockaddr *addr, socklen_t addrlen)
1173 {
1174         struct rsocket *rs;
1175         int ret;
1176
1177         rs = idm_lookup(&idm, socket);
1178         if (!rs)
1179                 return ERR(EBADF);
1180         if (rs->type == SOCK_STREAM) {
1181                 ret = rdma_bind_addr(rs->cm_id, (struct sockaddr *) addr);
1182                 if (!ret)
1183                         rs->state = rs_bound;
1184         } else {
1185                 if (rs->state == rs_init) {
1186                         ret = ds_init_ep(rs);
1187                         if (ret)
1188                                 return ret;
1189                 }
1190                 ret = bind(rs->udp_sock, addr, addrlen);
1191         }
1192         return ret;
1193 }
1194
1195 int rlisten(int socket, int backlog)
1196 {
1197         struct rsocket *rs;
1198         int ret;
1199
1200         rs = idm_lookup(&idm, socket);
1201         if (!rs)
1202                 return ERR(EBADF);
1203
1204         if (rs->state != rs_listening) {
1205                 ret = rdma_listen(rs->cm_id, backlog);
1206                 if (!ret)
1207                         rs->state = rs_listening;
1208         } else {
1209                 ret = 0;
1210         }
1211         return ret;
1212 }
1213
1214 /*
1215  * Nonblocking is usually not inherited between sockets, but we need to
1216  * inherit it here to establish the connection only.  This is needed to
1217  * prevent rdma_accept from blocking until the remote side finishes
1218  * establishing the connection.  If we were to allow rdma_accept to block,
1219  * then a single thread cannot establish a connection with itself, or
1220  * two threads which try to connect to each other can deadlock trying to
1221  * form a connection.
1222  *
1223  * Data transfers on the new socket remain blocking unless the user
1224  * specifies otherwise through rfcntl.
1225  */
1226 int raccept(int socket, struct sockaddr *addr, socklen_t *addrlen)
1227 {
1228         struct rsocket *rs, *new_rs;
1229         struct rdma_conn_param param;
1230         struct rs_conn_data *creq, cresp;
1231         int ret;
1232
1233         rs = idm_lookup(&idm, socket);
1234         if (!rs)
1235                 return ERR(EBADF);
1236         new_rs = rs_alloc(rs, rs->type);
1237         if (!new_rs)
1238                 return ERR(ENOMEM);
1239
1240         ret = rdma_get_request(rs->cm_id, &new_rs->cm_id);
1241         if (ret)
1242                 goto err;
1243
1244         ret = rs_insert(new_rs, new_rs->cm_id->channel->fd);
1245         if (ret < 0)
1246                 goto err;
1247
1248         creq = (struct rs_conn_data *)
1249                (new_rs->cm_id->event->param.conn.private_data + rs_conn_data_offset(rs));
1250         if (creq->version != 1) {
1251                 ret = ERR(ENOTSUP);
1252                 goto err;
1253         }
1254
1255         if (rs->fd_flags & O_NONBLOCK)
1256                 fcntl(new_rs->cm_id->channel->fd, F_SETFL, O_NONBLOCK);
1257
1258         ret = rs_create_ep(new_rs);
1259         if (ret)
1260                 goto err;
1261
1262         rs_save_conn_data(new_rs, creq);
1263         param = new_rs->cm_id->event->param.conn;
1264         rs_format_conn_data(new_rs, &cresp);
1265         param.private_data = &cresp;
1266         param.private_data_len = sizeof cresp;
1267         ret = rdma_accept(new_rs->cm_id, &param);
1268         if (!ret)
1269                 new_rs->state = rs_connect_rdwr;
1270         else if (errno == EAGAIN || errno == EWOULDBLOCK)
1271                 new_rs->state = rs_accepting;
1272         else
1273                 goto err;
1274
1275         if (addr && addrlen)
1276                 rgetpeername(new_rs->index, addr, addrlen);
1277         return new_rs->index;
1278
1279 err:
1280         rs_free(new_rs);
1281         return ret;
1282 }
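
/*
 * Editor's example (not part of the original source): a minimal passive-side
 * sketch of the calls defined in this file.  Error handling and the usual
 * <netinet/in.h> include are omitted, the port number is arbitrary, and
 * rrecv()/rclose() are declared in <rdma/rsocket.h>.
 */
#if 0   /* illustrative only */
static void example_rsocket_server(void)
{
        struct sockaddr_in sin = { .sin_family = AF_INET,
                                   .sin_port = htons(7471) };
        char buf[64];
        int lfd, fd;

        lfd = rsocket(AF_INET, SOCK_STREAM, 0);
        rbind(lfd, (struct sockaddr *) &sin, sizeof(sin));
        rlisten(lfd, 1);
        fd = raccept(lfd, NULL, NULL);          /* blocks for one connection */
        rrecv(fd, buf, sizeof(buf), 0);
        rclose(fd);
        rclose(lfd);
}
#endif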
1283
1284 static int rs_do_connect(struct rsocket *rs)
1285 {
1286         struct rdma_conn_param param;
1287         struct rs_conn_private_data cdata;
1288         struct rs_conn_data *creq, *cresp;
1289         int to, ret;
1290
1291         switch (rs->state) {
1292         case rs_init:
1293         case rs_bound:
1294 resolve_addr:
1295                 to = 1000 << rs->retries++;
1296                 ret = rdma_resolve_addr(rs->cm_id, NULL,
1297                                         &rs->cm_id->route.addr.dst_addr, to);
1298                 if (!ret)
1299                         goto resolve_route;
1300                 if (errno == EAGAIN || errno == EWOULDBLOCK)
1301                         rs->state = rs_resolving_addr;
1302                 break;
1303         case rs_resolving_addr:
1304                 ret = ucma_complete(rs->cm_id);
1305                 if (ret) {
1306                         if (errno == ETIMEDOUT && rs->retries <= RS_CONN_RETRIES)
1307                                 goto resolve_addr;
1308                         break;
1309                 }
1310
1311                 rs->retries = 0;
1312 resolve_route:
1313                 to = 1000 << rs->retries++;
1314                 if (rs->optval) {
1315                         ret = rdma_set_option(rs->cm_id,  RDMA_OPTION_IB,
1316                                               RDMA_OPTION_IB_PATH, rs->optval,
1317                                               rs->optlen);
1318                         free(rs->optval);
1319                         rs->optval = NULL;
1320                         if (!ret) {
1321                                 rs->state = rs_resolving_route;
1322                                 goto resolving_route;
1323                         }
1324                 } else {
1325                         ret = rdma_resolve_route(rs->cm_id, to);
1326                         if (!ret)
1327                                 goto do_connect;
1328                 }
1329                 if (errno == EAGAIN || errno == EWOULDBLOCK)
1330                         rs->state = rs_resolving_route;
1331                 break;
1332         case rs_resolving_route:
1333 resolving_route:
1334                 ret = ucma_complete(rs->cm_id);
1335                 if (ret) {
1336                         if (errno == ETIMEDOUT && rs->retries <= RS_CONN_RETRIES)
1337                                 goto resolve_route;
1338                         break;
1339                 }
1340 do_connect:
1341                 ret = rs_create_ep(rs);
1342                 if (ret)
1343                         break;
1344
1345                 memset(&param, 0, sizeof param);
1346                 creq = (void *) &cdata + rs_conn_data_offset(rs);
1347                 rs_format_conn_data(rs, creq);
1348                 param.private_data = (void *) creq - rs_conn_data_offset(rs);
1349                 param.private_data_len = sizeof(*creq) + rs_conn_data_offset(rs);
1350                 param.flow_control = 1;
1351                 param.retry_count = 7;
1352                 param.rnr_retry_count = 7;
1353                 /* work-around: iWarp issues RDMA read during connection */
1354                 if (rs->opts & RS_OPT_MSG_SEND)
1355                         param.initiator_depth = 1;
1356                 rs->retries = 0;
1357
1358                 ret = rdma_connect(rs->cm_id, &param);
1359                 if (!ret)
1360                         goto connected;
1361                 if (errno == EAGAIN || errno == EWOULDBLOCK)
1362                         rs->state = rs_connecting;
1363                 break;
1364         case rs_connecting:
1365                 ret = ucma_complete(rs->cm_id);
1366                 if (ret)
1367                         break;
1368 connected:
1369                 cresp = (struct rs_conn_data *) rs->cm_id->event->param.conn.private_data;
1370                 if (cresp->version != 1) {
1371                         ret = ERR(ENOTSUP);
1372                         break;
1373                 }
1374
1375                 rs_save_conn_data(rs, cresp);
1376                 rs->state = rs_connect_rdwr;
1377                 break;
1378         case rs_accepting:
1379                 if (!(rs->fd_flags & O_NONBLOCK))
1380                         fcntl(rs->cm_id->channel->fd, F_SETFL, 0);
1381
1382                 ret = ucma_complete(rs->cm_id);
1383                 if (ret)
1384                         break;
1385
1386                 rs->state = rs_connect_rdwr;
1387                 break;
1388         default:
1389                 ret = ERR(EINVAL);
1390                 break;
1391         }
1392
1393         if (ret) {
1394                 if (errno == EAGAIN || errno == EWOULDBLOCK) {
1395                         errno = EINPROGRESS;
1396                 } else {
1397                         rs->state = rs_connect_error;
1398                         rs->err = errno;
1399                 }
1400         }
1401         return ret;
1402 }
1403
1404 static int rs_any_addr(const union socket_addr *addr)
1405 {
1406         if (addr->sa.sa_family == AF_INET) {
1407                 return (addr->sin.sin_addr.s_addr == htobe32(INADDR_ANY) ||
1408                         addr->sin.sin_addr.s_addr == htobe32(INADDR_LOOPBACK));
1409         } else {
1410                 return (!memcmp(&addr->sin6.sin6_addr, &in6addr_any, 16) ||
1411                         !memcmp(&addr->sin6.sin6_addr, &in6addr_loopback, 16));
1412         }
1413 }
1414
1415 static int ds_get_src_addr(struct rsocket *rs,
1416                            const struct sockaddr *dest_addr, socklen_t dest_len,
1417                            union socket_addr *src_addr, socklen_t *src_len)
1418 {
1419         int sock, ret;
1420         __be16 port;
1421
1422         *src_len = sizeof(*src_addr);
1423         ret = getsockname(rs->udp_sock, &src_addr->sa, src_len);
1424         if (ret || !rs_any_addr(src_addr))
1425                 return ret;
1426
1427         port = src_addr->sin.sin_port;
1428         sock = socket(dest_addr->sa_family, SOCK_DGRAM, 0);
1429         if (sock < 0)
1430                 return sock;
1431
1432         ret = connect(sock, dest_addr, dest_len);
1433         if (ret)
1434                 goto out;
1435
1436         *src_len = sizeof(*src_addr);
1437         ret = getsockname(sock, &src_addr->sa, src_len);
1438         src_addr->sin.sin_port = port;
1439 out:
1440         close(sock);
1441         return ret;
1442 }
1443
1444 static void ds_format_hdr(struct ds_header *hdr, union socket_addr *addr)
1445 {
1446         if (addr->sa.sa_family == AF_INET) {
1447                 hdr->version = 4;
1448                 hdr->length = DS_IPV4_HDR_LEN;
1449                 hdr->port = addr->sin.sin_port;
1450                 hdr->addr.ipv4 = addr->sin.sin_addr.s_addr;
1451         } else {
1452                 hdr->version = 6;
1453                 hdr->length = DS_IPV6_HDR_LEN;
1454                 hdr->port = addr->sin6.sin6_port;
1455                 hdr->addr.ipv6.flowinfo = addr->sin6.sin6_flowinfo;
1456                 memcpy(&hdr->addr.ipv6.addr, &addr->sin6.sin6_addr, 16);
1457         }
1458 }
1459
1460 static int ds_add_qp_dest(struct ds_qp *qp, union socket_addr *addr,
1461                           socklen_t addrlen)
1462 {
1463         struct ibv_port_attr port_attr;
1464         struct ibv_ah_attr attr;
1465         int ret;
1466
1467         memcpy(&qp->dest.addr, addr, addrlen);
1468         qp->dest.qp = qp;
1469         qp->dest.qpn = qp->cm_id->qp->qp_num;
1470
1471         ret = ibv_query_port(qp->cm_id->verbs, qp->cm_id->port_num, &port_attr);
1472         if (ret)
1473                 return ret;
1474
1475         memset(&attr, 0, sizeof attr);
1476         attr.dlid = port_attr.lid;
1477         attr.port_num = qp->cm_id->port_num;
1478         qp->dest.ah = ibv_create_ah(qp->cm_id->pd, &attr);
1479         if (!qp->dest.ah)
1480                 return ERR(ENOMEM);
1481
1482         tsearch(&qp->dest.addr, &qp->rs->dest_map, ds_compare_addr);
1483         return 0;
1484 }
1485
1486 static int ds_create_qp(struct rsocket *rs, union socket_addr *src_addr,
1487                         socklen_t addrlen, struct ds_qp **new_qp)
1488 {
1489         struct ds_qp *qp;
1490         struct ibv_qp_init_attr qp_attr;
1491         struct epoll_event event;
1492         int i, ret;
1493
1494         qp = calloc(1, sizeof(*qp));
1495         if (!qp)
1496                 return ERR(ENOMEM);
1497
1498         qp->rs = rs;
1499         ret = rdma_create_id(NULL, &qp->cm_id, qp, RDMA_PS_UDP);
1500         if (ret)
1501                 goto err;
1502
1503         ds_format_hdr(&qp->hdr, src_addr);
1504         ret = rdma_bind_addr(qp->cm_id, &src_addr->sa);
1505         if (ret)
1506                 goto err;
1507
1508         ret = ds_init_bufs(qp);
1509         if (ret)
1510                 goto err;
1511
1512         ret = rs_create_cq(rs, qp->cm_id);
1513         if (ret)
1514                 goto err;
1515
1516         memset(&qp_attr, 0, sizeof qp_attr);
1517         qp_attr.qp_context = qp;
1518         qp_attr.send_cq = qp->cm_id->send_cq;
1519         qp_attr.recv_cq = qp->cm_id->recv_cq;
1520         qp_attr.qp_type = IBV_QPT_UD;
1521         qp_attr.sq_sig_all = 1;
1522         qp_attr.cap.max_send_wr = rs->sq_size;
1523         qp_attr.cap.max_recv_wr = rs->rq_size;
1524         qp_attr.cap.max_send_sge = 1;
1525         qp_attr.cap.max_recv_sge = 2;
1526         qp_attr.cap.max_inline_data = rs->sq_inline;
1527         ret = rdma_create_qp(qp->cm_id, NULL, &qp_attr);
1528         if (ret)
1529                 goto err;
1530
1531         rs->sq_inline = qp_attr.cap.max_inline_data;
1532         ret = ds_add_qp_dest(qp, src_addr, addrlen);
1533         if (ret)
1534                 goto err;
1535
1536         event.events = EPOLLIN;
1537         event.data.ptr = qp;
1538         ret = epoll_ctl(rs->epfd,  EPOLL_CTL_ADD,
1539                         qp->cm_id->recv_cq_channel->fd, &event);
1540         if (ret)
1541                 goto err;
1542
1543         for (i = 0; i < rs->rq_size; i++) {
1544                 ret = ds_post_recv(rs, qp, i * RS_SNDLOWAT);
1545                 if (ret)
1546                         goto err;
1547         }
1548
1549         ds_insert_qp(rs, qp);
1550         *new_qp = qp;
1551         return 0;
1552 err:
1553         ds_free_qp(qp);
1554         return ret;
1555 }
1556
1557 static int ds_get_qp(struct rsocket *rs, union socket_addr *src_addr,
1558                      socklen_t addrlen, struct ds_qp **qp)
1559 {
1560         if (rs->qp_list) {
1561                 *qp = rs->qp_list;
1562                 do {
1563                         if (!ds_compare_addr(rdma_get_local_addr((*qp)->cm_id),
1564                                              src_addr))
1565                                 return 0;
1566
1567                         *qp = ds_next_qp(*qp);
1568                 } while (*qp != rs->qp_list);
1569         }
1570
1571         return ds_create_qp(rs, src_addr, addrlen, qp);
1572 }
1573
1574 static int ds_get_dest(struct rsocket *rs, const struct sockaddr *addr,
1575                        socklen_t addrlen, struct ds_dest **dest)
1576 {
1577         union socket_addr src_addr;
1578         socklen_t src_len;
1579         struct ds_qp *qp;
1580         struct ds_dest **tdest, *new_dest;
1581         int ret = 0;
1582
1583         fastlock_acquire(&rs->map_lock);
1584         tdest = tfind(addr, &rs->dest_map, ds_compare_addr);
1585         if (tdest)
1586                 goto found;
1587
1588         ret = ds_get_src_addr(rs, addr, addrlen, &src_addr, &src_len);
1589         if (ret)
1590                 goto out;
1591
1592         ret = ds_get_qp(rs, &src_addr, src_len, &qp);
1593         if (ret)
1594                 goto out;
1595
1596         tdest = tfind(addr, &rs->dest_map, ds_compare_addr);
1597         if (!tdest) {
1598                 new_dest = calloc(1, sizeof(*new_dest));
1599                 if (!new_dest) {
1600                         ret = ERR(ENOMEM);
1601                         goto out;
1602                 }
1603
1604                 memcpy(&new_dest->addr, addr, addrlen);
1605                 new_dest->qp = qp;
1606                 tdest = tsearch(&new_dest->addr, &rs->dest_map, ds_compare_addr);
1607         }
1608
1609 found:
1610         *dest = *tdest;
1611 out:
1612         fastlock_release(&rs->map_lock);
1613         return ret;
1614 }
1615
1616 int rconnect(int socket, const struct sockaddr *addr, socklen_t addrlen)
1617 {
1618         struct rsocket *rs;
1619         int ret;
1620
1621         rs = idm_lookup(&idm, socket);
1622         if (!rs)
1623                 return ERR(EBADF);
1624         if (rs->type == SOCK_STREAM) {
1625                 memcpy(&rs->cm_id->route.addr.dst_addr, addr, addrlen);
1626                 ret = rs_do_connect(rs);
1627         } else {
1628                 if (rs->state == rs_init) {
1629                         ret = ds_init_ep(rs);
1630                         if (ret)
1631                                 return ret;
1632                 }
1633
1634                 fastlock_acquire(&rs->slock);
1635                 ret = connect(rs->udp_sock, addr, addrlen);
1636                 if (!ret)
1637                         ret = ds_get_dest(rs, addr, addrlen, &rs->conn_dest);
1638                 fastlock_release(&rs->slock);
1639         }
1640         return ret;
1641 }
1642
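/*
 * Editor's illustrative sketch (not part of rsocket.c): connecting a blocking
 * SOCK_STREAM rsocket through the public API implemented in this file.  The
 * function name and its host/port arguments are placeholders.
 */
static int example_stream_client(const char *host, const char *port)
{
        struct addrinfo hints, *res;
        int fd, ret;

        memset(&hints, 0, sizeof hints);
        hints.ai_socktype = SOCK_STREAM;
        if (getaddrinfo(host, port, &hints, &res))
                return -1;

        /* returns an rsocket index, usable only with the r* calls */
        fd = rsocket(res->ai_family, res->ai_socktype, res->ai_protocol);
        if (fd >= 0) {
                /* for SOCK_STREAM this resolves the route, creates the QP,
                 * and exchanges connection parameters before returning */
                ret = rconnect(fd, res->ai_addr, res->ai_addrlen);
                if (ret) {
                        rclose(fd);
                        fd = -1;
                }
        }
        freeaddrinfo(res);
        return fd;
}
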
1643 static void *rs_get_ctrl_buf(struct rsocket *rs)
1644 {
1645         return rs->sbuf + rs->sbuf_size +
1646                 RS_MAX_CTRL_MSG * (rs->ctrl_seqno & (RS_QP_CTRL_SIZE - 1));
1647 }
1648
1649 static int rs_post_msg(struct rsocket *rs, uint32_t msg)
1650 {
1651         struct ibv_send_wr wr, *bad;
1652         struct ibv_sge sge;
1653
1654         wr.wr_id = rs_send_wr_id(msg);
1655         wr.next = NULL;
1656         if (!(rs->opts & RS_OPT_MSG_SEND)) {
1657                 wr.sg_list = NULL;
1658                 wr.num_sge = 0;
1659                 wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
1660                 wr.send_flags = 0;
1661                 wr.imm_data = htobe32(msg);
1662         } else {
1663                 sge.addr = (uintptr_t) &msg;
1664                 sge.lkey = 0;
1665                 sge.length = sizeof msg;
1666                 wr.sg_list = &sge;
1667                 wr.num_sge = 1;
1668                 wr.opcode = IBV_WR_SEND;
1669                 wr.send_flags = IBV_SEND_INLINE;
1670         }
1671
1672         return rdma_seterrno(ibv_post_send(rs->cm_id->qp, &wr, &bad));
1673 }
1674
1675 static int rs_post_write(struct rsocket *rs,
1676                          struct ibv_sge *sgl, int nsge,
1677                          uint32_t wr_data, int flags,
1678                          uint64_t addr, uint32_t rkey)
1679 {
1680         struct ibv_send_wr wr, *bad;
1681
1682         wr.wr_id = rs_send_wr_id(wr_data);
1683         wr.next = NULL;
1684         wr.sg_list = sgl;
1685         wr.num_sge = nsge;
1686         wr.opcode = IBV_WR_RDMA_WRITE;
1687         wr.send_flags = flags;
1688         wr.wr.rdma.remote_addr = addr;
1689         wr.wr.rdma.rkey = rkey;
1690
1691         return rdma_seterrno(ibv_post_send(rs->cm_id->qp, &wr, &bad));
1692 }
1693
1694 static int rs_post_write_msg(struct rsocket *rs,
1695                          struct ibv_sge *sgl, int nsge,
1696                          uint32_t msg, int flags,
1697                          uint64_t addr, uint32_t rkey)
1698 {
1699         struct ibv_send_wr wr, *bad;
1700         struct ibv_sge sge;
1701         int ret;
1702
1703         wr.next = NULL;
1704         if (!(rs->opts & RS_OPT_MSG_SEND)) {
1705                 wr.wr_id = rs_send_wr_id(msg);
1706                 wr.sg_list = sgl;
1707                 wr.num_sge = nsge;
1708                 wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
1709                 wr.send_flags = flags;
1710                 wr.imm_data = htobe32(msg);
1711                 wr.wr.rdma.remote_addr = addr;
1712                 wr.wr.rdma.rkey = rkey;
1713
1714                 return rdma_seterrno(ibv_post_send(rs->cm_id->qp, &wr, &bad));
1715         } else {
1716                 ret = rs_post_write(rs, sgl, nsge, msg, flags, addr, rkey);
1717                 if (!ret) {
1718                         wr.wr_id = rs_send_wr_id(rs_msg_set(rs_msg_op(msg), 0)) |
1719                                    RS_WR_ID_FLAG_MSG_SEND;
1720                         sge.addr = (uintptr_t) &msg;
1721                         sge.lkey = 0;
1722                         sge.length = sizeof msg;
1723                         wr.sg_list = &sge;
1724                         wr.num_sge = 1;
1725                         wr.opcode = IBV_WR_SEND;
1726                         wr.send_flags = IBV_SEND_INLINE;
1727
1728                         ret = rdma_seterrno(ibv_post_send(rs->cm_id->qp, &wr, &bad));
1729                 }
1730                 return ret;
1731         }
1732 }
1733
1734 static int ds_post_send(struct rsocket *rs, struct ibv_sge *sge,
1735                         uint32_t wr_data)
1736 {
1737         struct ibv_send_wr wr, *bad;
1738
1739         wr.wr_id = rs_send_wr_id(wr_data);
1740         wr.next = NULL;
1741         wr.sg_list = sge;
1742         wr.num_sge = 1;
1743         wr.opcode = IBV_WR_SEND;
1744         wr.send_flags = (sge->length <= rs->sq_inline) ? IBV_SEND_INLINE : 0;
1745         wr.wr.ud.ah = rs->conn_dest->ah;
1746         wr.wr.ud.remote_qpn = rs->conn_dest->qpn;
1747         wr.wr.ud.remote_qkey = RDMA_UDP_QKEY;
1748
1749         return rdma_seterrno(ibv_post_send(rs->conn_dest->qp->cm_id->qp, &wr, &bad));
1750 }
1751
1752 /*
1753  * Update the cached target SGL entry before posting the send; otherwise the
1754  * remote side may overwrite the entry (via a credit update) before we adjust it.
1755  */
1756 static int rs_write_data(struct rsocket *rs,
1757                          struct ibv_sge *sgl, int nsge,
1758                          uint32_t length, int flags)
1759 {
1760         uint64_t addr;
1761         uint32_t rkey;
1762
1763         rs->sseq_no++;
1764         rs->sqe_avail--;
1765         if (rs->opts & RS_OPT_MSG_SEND)
1766                 rs->sqe_avail--;
1767         rs->sbuf_bytes_avail -= length;
1768
1769         addr = rs->target_sgl[rs->target_sge].addr;
1770         rkey = rs->target_sgl[rs->target_sge].key;
1771
1772         rs->target_sgl[rs->target_sge].addr += length;
1773         rs->target_sgl[rs->target_sge].length -= length;
1774
1775         if (!rs->target_sgl[rs->target_sge].length) {
1776                 if (++rs->target_sge == RS_SGL_SIZE)
1777                         rs->target_sge = 0;
1778         }
1779
1780         return rs_post_write_msg(rs, sgl, nsge, rs_msg_set(RS_OP_DATA, length),
1781                                  flags, addr, rkey);
1782 }
1783
1784 static int rs_write_direct(struct rsocket *rs, struct rs_iomap *iom, uint64_t offset,
1785                            struct ibv_sge *sgl, int nsge, uint32_t length, int flags)
1786 {
1787         uint64_t addr;
1788
1789         rs->sqe_avail--;
1790         rs->sbuf_bytes_avail -= length;
1791
1792         addr = iom->sge.addr + offset - iom->offset;
1793         return rs_post_write(rs, sgl, nsge, rs_msg_set(RS_OP_WRITE, length),
1794                              flags, addr, iom->sge.key);
1795 }
1796
1797 static int rs_write_iomap(struct rsocket *rs, struct rs_iomap_mr *iomr,
1798                           struct ibv_sge *sgl, int nsge, int flags)
1799 {
1800         uint64_t addr;
1801
1802         rs->sseq_no++;
1803         rs->sqe_avail--;
1804         if (rs->opts & RS_OPT_MSG_SEND)
1805                 rs->sqe_avail--;
1806         rs->sbuf_bytes_avail -= sizeof(struct rs_iomap);
1807
1808         addr = rs->remote_iomap.addr + iomr->index * sizeof(struct rs_iomap);
1809         return rs_post_write_msg(rs, sgl, nsge, rs_msg_set(RS_OP_IOMAP_SGL, iomr->index),
1810                                  flags, addr, rs->remote_iomap.key);
1811 }
1812
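/*
 * Bytes between the current send staging position (ssgl[0].addr) and the end
 * of the send buffer; a copy larger than this must wrap and use two SGEs.
 */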
1813 static uint32_t rs_sbuf_left(struct rsocket *rs)
1814 {
1815         return (uint32_t) (((uint64_t) (uintptr_t) &rs->sbuf[rs->sbuf_size]) -
1816                            rs->ssgl[0].addr);
1817 }
1818
1819 static void rs_send_credits(struct rsocket *rs)
1820 {
1821         struct ibv_sge ibsge;
1822         struct rs_sge sge, *sge_buf;
1823         int flags;
1824
1825         rs->ctrl_seqno++;
1826         rs->rseq_comp = rs->rseq_no + (rs->rq_size >> 1);
1827         if (rs->rbuf_bytes_avail >= (rs->rbuf_size >> 1)) {
1828                 if (rs->opts & RS_OPT_MSG_SEND)
1829                         rs->ctrl_seqno++;
1830
1831                 if (!(rs->opts & RS_OPT_SWAP_SGL)) {
1832                         sge.addr = (uintptr_t) &rs->rbuf[rs->rbuf_free_offset];
1833                         sge.key = rs->rmr->rkey;
1834                         sge.length = rs->rbuf_size >> 1;
1835                 } else {
1836                         sge.addr = bswap_64((uintptr_t) &rs->rbuf[rs->rbuf_free_offset]);
1837                         sge.key = bswap_32(rs->rmr->rkey);
1838                         sge.length = bswap_32(rs->rbuf_size >> 1);
1839                 }
1840
1841                 if (rs->sq_inline < sizeof sge) {
1842                         sge_buf = rs_get_ctrl_buf(rs);
1843                         memcpy(sge_buf, &sge, sizeof sge);
1844                         ibsge.addr = (uintptr_t) sge_buf;
1845                         ibsge.lkey = rs->smr->lkey;
1846                         flags = 0;
1847                 } else {
1848                         ibsge.addr = (uintptr_t) &sge;
1849                         ibsge.lkey = 0;
1850                         flags = IBV_SEND_INLINE;
1851                 }
1852                 ibsge.length = sizeof(sge);
1853
1854                 rs_post_write_msg(rs, &ibsge, 1,
1855                         rs_msg_set(RS_OP_SGL, rs->rseq_no + rs->rq_size), flags,
1856                         rs->remote_sgl.addr + rs->remote_sge * sizeof(struct rs_sge),
1857                         rs->remote_sgl.key);
1858
1859                 rs->rbuf_bytes_avail -= rs->rbuf_size >> 1;
1860                 rs->rbuf_free_offset += rs->rbuf_size >> 1;
1861                 if (rs->rbuf_free_offset >= rs->rbuf_size)
1862                         rs->rbuf_free_offset = 0;
1863                 if (++rs->remote_sge == rs->remote_sgl.length)
1864                         rs->remote_sge = 0;
1865         } else {
1866                 rs_post_msg(rs, rs_msg_set(RS_OP_SGL, rs->rseq_no + rs->rq_size));
1867         }
1868 }
1869
1870 static inline int rs_ctrl_avail(struct rsocket *rs)
1871 {
1872         return rs->ctrl_seqno != rs->ctrl_max_seqno;
1873 }
1874
1875 /* Protocols that do not support RDMA write with immediate may require 2 msgs */
1876 static inline int rs_2ctrl_avail(struct rsocket *rs)
1877 {
1878         return (int)((rs->ctrl_seqno + 1) - rs->ctrl_max_seqno) < 0;
1879 }
1880
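/*
 * A credit update is issued once at least half of the receive buffer can be
 * advertised again, or once half of the receive queue has been consumed since
 * the last update, provided a control send slot is free and we are connected.
 */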
1881 static int rs_give_credits(struct rsocket *rs)
1882 {
1883         if (!(rs->opts & RS_OPT_MSG_SEND)) {
1884                 return ((rs->rbuf_bytes_avail >= (rs->rbuf_size >> 1)) ||
1885                         ((short) ((short) rs->rseq_no - (short) rs->rseq_comp) >= 0)) &&
1886                        rs_ctrl_avail(rs) && (rs->state & rs_connected);
1887         } else {
1888                 return ((rs->rbuf_bytes_avail >= (rs->rbuf_size >> 1)) ||
1889                         ((short) ((short) rs->rseq_no - (short) rs->rseq_comp) >= 0)) &&
1890                        rs_2ctrl_avail(rs) && (rs->state & rs_connected);
1891         }
1892 }
1893
1894 static void rs_update_credits(struct rsocket *rs)
1895 {
1896         if (rs_give_credits(rs))
1897                 rs_send_credits(rs);
1898 }
1899
1900 static int rs_poll_cq(struct rsocket *rs)
1901 {
1902         struct ibv_wc wc;
1903         uint32_t msg;
1904         int ret, rcnt = 0;
1905
1906         while ((ret = ibv_poll_cq(rs->cm_id->recv_cq, 1, &wc)) > 0) {
1907                 if (rs_wr_is_recv(wc.wr_id)) {
1908                         if (wc.status != IBV_WC_SUCCESS)
1909                                 continue;
1910                         rcnt++;
1911
1912                         if (wc.wc_flags & IBV_WC_WITH_IMM) {
1913                                 msg = be32toh(wc.imm_data);
1914                         } else {
1915                                 msg = ((uint32_t *) (rs->rbuf + rs->rbuf_size))
1916                                         [rs_wr_data(wc.wr_id)];
1917
1918                         }
1919                         switch (rs_msg_op(msg)) {
1920                         case RS_OP_SGL:
1921                                 rs->sseq_comp = (uint16_t) rs_msg_data(msg);
1922                                 break;
1923                         case RS_OP_IOMAP_SGL:
1924                                 /* The iomap was updated, that's nice to know. */
1925                                 break;
1926                         case RS_OP_CTRL:
1927                                 if (rs_msg_data(msg) == RS_CTRL_DISCONNECT) {
1928                                         rs->state = rs_disconnected;
1929                                         return 0;
1930                                 } else if (rs_msg_data(msg) == RS_CTRL_SHUTDOWN) {
1931                                         if (rs->state & rs_writable) {
1932                                                 rs->state &= ~rs_readable;
1933                                         } else {
1934                                                 rs->state = rs_disconnected;
1935                                                 return 0;
1936                                         }
1937                                 }
1938                                 break;
1939                         case RS_OP_WRITE:
1940                                 /* We really shouldn't be here. */
1941                                 break;
1942                         default:
1943                                 rs->rmsg[rs->rmsg_tail].op = rs_msg_op(msg);
1944                                 rs->rmsg[rs->rmsg_tail].data = rs_msg_data(msg);
1945                                 if (++rs->rmsg_tail == rs->rq_size + 1)
1946                                         rs->rmsg_tail = 0;
1947                                 break;
1948                         }
1949                 } else {
1950                         switch (rs_msg_op(rs_wr_data(wc.wr_id))) {
1951                         case RS_OP_SGL:
1952                                 rs->ctrl_max_seqno++;
1953                                 break;
1954                         case RS_OP_CTRL:
1955                                 rs->ctrl_max_seqno++;
1956                                 if (rs_msg_data(rs_wr_data(wc.wr_id)) == RS_CTRL_DISCONNECT)
1957                                         rs->state = rs_disconnected;
1958                                 break;
1959                         case RS_OP_IOMAP_SGL:
1960                                 rs->sqe_avail++;
1961                                 if (!rs_wr_is_msg_send(wc.wr_id))
1962                                         rs->sbuf_bytes_avail += sizeof(struct rs_iomap);
1963                                 break;
1964                         default:
1965                                 rs->sqe_avail++;
1966                                 rs->sbuf_bytes_avail += rs_msg_data(rs_wr_data(wc.wr_id));
1967                                 break;
1968                         }
1969                         if (wc.status != IBV_WC_SUCCESS && (rs->state & rs_connected)) {
1970                                 rs->state = rs_error;
1971                                 rs->err = EIO;
1972                         }
1973                 }
1974         }
1975
1976         if (rs->state & rs_connected) {
1977                 while (!ret && rcnt--)
1978                         ret = rs_post_recv(rs);
1979
1980                 if (ret) {
1981                         rs->state = rs_error;
1982                         rs->err = errno;
1983                 }
1984         }
1985         return ret;
1986 }
1987
1988 static int rs_get_cq_event(struct rsocket *rs)
1989 {
1990         struct ibv_cq *cq;
1991         void *context;
1992         int ret;
1993
1994         if (!rs->cq_armed)
1995                 return 0;
1996
1997         ret = ibv_get_cq_event(rs->cm_id->recv_cq_channel, &cq, &context);
1998         if (!ret) {
1999                 if (++rs->unack_cqe >= rs->sq_size + rs->rq_size) {
2000                         ibv_ack_cq_events(rs->cm_id->recv_cq, rs->unack_cqe);
2001                         rs->unack_cqe = 0;
2002                 }
2003                 rs->cq_armed = 0;
2004         } else if (!(errno == EAGAIN || errno == EINTR)) {
2005                 rs->state = rs_error;
2006         }
2007
2008         return ret;
2009 }
2010
2011 /*
2012  * Although we serialize rsend and rrecv calls with respect to themselves,
2013  * both calls may run simultaneously and need to poll the CQ for completions.
2014  * We need to serialize access to the CQ, but rsend and rrecv need to
2015  * allow each other to make forward progress.
2016  *
2017  * For example, rsend may need to wait for credits from the remote side,
2018  * which could be stalled until the remote process calls rrecv.  This should
2019  * not block rrecv from receiving data from the remote side, however.
2020  *
2021  * We handle this by using two locks.  The cq_lock protects against polling
2022  * the CQ and processing completions.  The cq_wait_lock serializes access to
2023  * waiting on the CQ.
2024  */
2025 static int rs_process_cq(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs))
2026 {
2027         int ret;
2028
2029         fastlock_acquire(&rs->cq_lock);
2030         do {
2031                 rs_update_credits(rs);
2032                 ret = rs_poll_cq(rs);
2033                 if (test(rs)) {
2034                         ret = 0;
2035                         break;
2036                 } else if (ret) {
2037                         break;
2038                 } else if (nonblock) {
2039                         ret = ERR(EWOULDBLOCK);
2040                 } else if (!rs->cq_armed) {
2041                         ibv_req_notify_cq(rs->cm_id->recv_cq, 0);
2042                         rs->cq_armed = 1;
2043                 } else {
2044                         rs_update_credits(rs);
2045                         fastlock_acquire(&rs->cq_wait_lock);
2046                         fastlock_release(&rs->cq_lock);
2047
2048                         ret = rs_get_cq_event(rs);
2049                         fastlock_release(&rs->cq_wait_lock);
2050                         fastlock_acquire(&rs->cq_lock);
2051                 }
2052         } while (!ret);
2053
2054         rs_update_credits(rs);
2055         fastlock_release(&rs->cq_lock);
2056         return ret;
2057 }
2058
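/*
 * Editor's illustrative sketch of the two-lock scheme described above,
 * reduced to bare pthread mutexes.  poll_lock plays the role of cq_lock,
 * wait_lock the role of cq_wait_lock, and done()/wait_event() stand in for
 * the test function and rs_get_cq_event(); none of these are rsocket symbols.
 */
static int example_two_lock_wait(pthread_mutex_t *poll_lock,
                                 pthread_mutex_t *wait_lock,
                                 int (*done)(void *), int (*wait_event)(void *),
                                 void *arg)
{
        int ret = 0;

        pthread_mutex_lock(poll_lock);
        while (!done(arg) && !ret) {
                /* Take the wait lock, then drop the poll lock before blocking,
                 * so the other path (rsend vs. rrecv) can keep polling the CQ
                 * and make forward progress while we sleep. */
                pthread_mutex_lock(wait_lock);
                pthread_mutex_unlock(poll_lock);
                ret = wait_event(arg);
                pthread_mutex_unlock(wait_lock);
                pthread_mutex_lock(poll_lock);
        }
        pthread_mutex_unlock(poll_lock);
        return ret;
}
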
2059 static int rs_get_comp(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs))
2060 {
2061         struct timeval s, e;
2062         uint32_t poll_time = 0;
2063         int ret;
2064
2065         do {
2066                 ret = rs_process_cq(rs, 1, test);
2067                 if (!ret || nonblock || errno != EWOULDBLOCK)
2068                         return ret;
2069
2070                 if (!poll_time)
2071                         gettimeofday(&s, NULL);
2072
2073                 gettimeofday(&e, NULL);
2074                 poll_time = (e.tv_sec - s.tv_sec) * 1000000 +
2075                             (e.tv_usec - s.tv_usec) + 1;
2076         } while (poll_time <= polling_time);
2077
2078         ret = rs_process_cq(rs, 0, test);
2079         return ret;
2080 }
2081
2082 static int ds_valid_recv(struct ds_qp *qp, struct ibv_wc *wc)
2083 {
2084         struct ds_header *hdr;
2085
2086         hdr = (struct ds_header *) (qp->rbuf + rs_wr_data(wc->wr_id));
2087         return ((wc->byte_len >= sizeof(struct ibv_grh) + DS_IPV4_HDR_LEN) &&
2088                 ((hdr->version == 4 && hdr->length == DS_IPV4_HDR_LEN) ||
2089                  (hdr->version == 6 && hdr->length == DS_IPV6_HDR_LEN)));
2090 }
2091
2092 /*
2093  * Poll all CQs associated with a datagram rsocket.  We need to drop any
2094  * received messages that we do not have room to store.  To limit drops,
2095  * we only poll if we have room to store the receive or we need a send
2096  * buffer.  To ensure fairness, we poll the CQs round robin, remembering
2097  * where we left off.
2098  */
2099 static void ds_poll_cqs(struct rsocket *rs)
2100 {
2101         struct ds_qp *qp;
2102         struct ds_smsg *smsg;
2103         struct ds_rmsg *rmsg;
2104         struct ibv_wc wc;
2105         int ret, cnt;
2106
2107         if (!(qp = rs->qp_list))
2108                 return;
2109
2110         do {
2111                 cnt = 0;
2112                 do {
2113                         ret = ibv_poll_cq(qp->cm_id->recv_cq, 1, &wc);
2114                         if (ret <= 0) {
2115                                 qp = ds_next_qp(qp);
2116                                 continue;
2117                         }
2118
2119                         if (rs_wr_is_recv(wc.wr_id)) {
2120                                 if (rs->rqe_avail && wc.status == IBV_WC_SUCCESS &&
2121                                     ds_valid_recv(qp, &wc)) {
2122                                         rs->rqe_avail--;
2123                                         rmsg = &rs->dmsg[rs->rmsg_tail];
2124                                         rmsg->qp = qp;
2125                                         rmsg->offset = rs_wr_data(wc.wr_id);
2126                                         rmsg->length = wc.byte_len - sizeof(struct ibv_grh);
2127                                         if (++rs->rmsg_tail == rs->rq_size + 1)
2128                                                 rs->rmsg_tail = 0;
2129                                 } else {
2130                                         ds_post_recv(rs, qp, rs_wr_data(wc.wr_id));
2131                                 }
2132                         } else {
2133                                 smsg = (struct ds_smsg *) (rs->sbuf + rs_wr_data(wc.wr_id));
2134                                 smsg->next = rs->smsg_free;
2135                                 rs->smsg_free = smsg;
2136                                 rs->sqe_avail++;
2137                         }
2138
2139                         qp = ds_next_qp(qp);
2140                         if (!rs->rqe_avail && rs->sqe_avail) {
2141                                 rs->qp_list = qp;
2142                                 return;
2143                         }
2144                         cnt++;
2145                 } while (qp != rs->qp_list);
2146         } while (cnt);
2147 }
2148
2149 static void ds_req_notify_cqs(struct rsocket *rs)
2150 {
2151         struct ds_qp *qp;
2152
2153         if (!(qp = rs->qp_list))
2154                 return;
2155
2156         do {
2157                 if (!qp->cq_armed) {
2158                         ibv_req_notify_cq(qp->cm_id->recv_cq, 0);
2159                         qp->cq_armed = 1;
2160                 }
2161                 qp = ds_next_qp(qp);
2162         } while (qp != rs->qp_list);
2163 }
2164
2165 static int ds_get_cq_event(struct rsocket *rs)
2166 {
2167         struct epoll_event event;
2168         struct ds_qp *qp;
2169         struct ibv_cq *cq;
2170         void *context;
2171         int ret;
2172
2173         if (!rs->cq_armed)
2174                 return 0;
2175
2176         ret = epoll_wait(rs->epfd, &event, 1, -1);
2177         if (ret <= 0)
2178                 return ret;
2179
2180         qp = event.data.ptr;
2181         ret = ibv_get_cq_event(qp->cm_id->recv_cq_channel, &cq, &context);
2182         if (!ret) {
2183                 ibv_ack_cq_events(qp->cm_id->recv_cq, 1);
2184                 qp->cq_armed = 0;
2185                 rs->cq_armed = 0;
2186         }
2187
2188         return ret;
2189 }
2190
2191 static int ds_process_cqs(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs))
2192 {
2193         int ret = 0;
2194
2195         fastlock_acquire(&rs->cq_lock);
2196         do {
2197                 ds_poll_cqs(rs);
2198                 if (test(rs)) {
2199                         ret = 0;
2200                         break;
2201                 } else if (nonblock) {
2202                         ret = ERR(EWOULDBLOCK);
2203                 } else if (!rs->cq_armed) {
2204                         ds_req_notify_cqs(rs);
2205                         rs->cq_armed = 1;
2206                 } else {
2207                         fastlock_acquire(&rs->cq_wait_lock);
2208                         fastlock_release(&rs->cq_lock);
2209
2210                         ret = ds_get_cq_event(rs);
2211                         fastlock_release(&rs->cq_wait_lock);
2212                         fastlock_acquire(&rs->cq_lock);
2213                 }
2214         } while (!ret);
2215
2216         fastlock_release(&rs->cq_lock);
2217         return ret;
2218 }
2219
2220 static int ds_get_comp(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs))
2221 {
2222         struct timeval s, e;
2223         uint32_t poll_time = 0;
2224         int ret;
2225
2226         do {
2227                 ret = ds_process_cqs(rs, 1, test);
2228                 if (!ret || nonblock || errno != EWOULDBLOCK)
2229                         return ret;
2230
2231                 if (!poll_time)
2232                         gettimeofday(&s, NULL);
2233
2234                 gettimeofday(&e, NULL);
2235                 poll_time = (e.tv_sec - s.tv_sec) * 1000000 +
2236                             (e.tv_usec - s.tv_usec) + 1;
2237         } while (poll_time <= polling_time);
2238
2239         ret = ds_process_cqs(rs, 0, test);
2240         return ret;
2241 }
2242
2243 static int rs_nonblocking(struct rsocket *rs, int flags)
2244 {
2245         return (rs->fd_flags & O_NONBLOCK) || (flags & MSG_DONTWAIT);
2246 }
2247
2248 static int rs_is_cq_armed(struct rsocket *rs)
2249 {
2250         return rs->cq_armed;
2251 }
2252
2253 static int rs_poll_all(struct rsocket *rs)
2254 {
2255         return 1;
2256 }
2257
2258 /*
2259  * We use hardware flow control to prevent overrunning the remote
2260  * receive queue.  However, data transfers still require space in
2261  * the remote rmsg queue, or we risk losing the notification that data
2262  * has been transferred.
2263  *
2264  * Be careful with race conditions in the check below.  The target SGL
2265  * may be updated by a remote RDMA write.
2266  */
2267 static int rs_can_send(struct rsocket *rs)
2268 {
2269         if (!(rs->opts & RS_OPT_MSG_SEND)) {
2270                 return rs->sqe_avail && (rs->sbuf_bytes_avail >= RS_SNDLOWAT) &&
2271                        (rs->sseq_no != rs->sseq_comp) &&
2272                        (rs->target_sgl[rs->target_sge].length != 0);
2273         } else {
2274                 return (rs->sqe_avail >= 2) && (rs->sbuf_bytes_avail >= RS_SNDLOWAT) &&
2275                        (rs->sseq_no != rs->sseq_comp) &&
2276                        (rs->target_sgl[rs->target_sge].length != 0);
2277         }
2278 }
2279
2280 static int ds_can_send(struct rsocket *rs)
2281 {
2282         return rs->sqe_avail;
2283 }
2284
2285 static int ds_all_sends_done(struct rsocket *rs)
2286 {
2287         return rs->sqe_avail == rs->sq_size;
2288 }
2289
2290 static int rs_conn_can_send(struct rsocket *rs)
2291 {
2292         return rs_can_send(rs) || !(rs->state & rs_writable);
2293 }
2294
2295 static int rs_conn_can_send_ctrl(struct rsocket *rs)
2296 {
2297         return rs_ctrl_avail(rs) || !(rs->state & rs_connected);
2298 }
2299
2300 static int rs_have_rdata(struct rsocket *rs)
2301 {
2302         return (rs->rmsg_head != rs->rmsg_tail);
2303 }
2304
2305 static int rs_conn_have_rdata(struct rsocket *rs)
2306 {
2307         return rs_have_rdata(rs) || !(rs->state & rs_readable);
2308 }
2309
2310 static int rs_conn_all_sends_done(struct rsocket *rs)
2311 {
2312         return ((((int) rs->ctrl_max_seqno) - ((int) rs->ctrl_seqno)) +
2313                 rs->sqe_avail == rs->sq_size) ||
2314                !(rs->state & rs_connected);
2315 }
2316
2317 static void ds_set_src(struct sockaddr *addr, socklen_t *addrlen,
2318                        struct ds_header *hdr)
2319 {
2320         union socket_addr sa;
2321
2322         memset(&sa, 0, sizeof sa);
2323         if (hdr->version == 4) {
2324                 if (*addrlen > sizeof(sa.sin))
2325                         *addrlen = sizeof(sa.sin);
2326
2327                 sa.sin.sin_family = AF_INET;
2328                 sa.sin.sin_port = hdr->port;
2329                 sa.sin.sin_addr.s_addr = hdr->addr.ipv4;
2330         } else {
2331                 if (*addrlen > sizeof(sa.sin6))
2332                         *addrlen = sizeof(sa.sin6);
2333
2334                 sa.sin6.sin6_family = AF_INET6;
2335                 sa.sin6.sin6_port = hdr->port;
2336                 sa.sin6.sin6_flowinfo = hdr->addr.ipv6.flowinfo;
2337                 memcpy(&sa.sin6.sin6_addr, &hdr->addr.ipv6.addr, 16);
2338         }
2339         memcpy(addr, &sa, *addrlen);
2340 }
2341
2342 static ssize_t ds_recvfrom(struct rsocket *rs, void *buf, size_t len, int flags,
2343                            struct sockaddr *src_addr, socklen_t *addrlen)
2344 {
2345         struct ds_rmsg *rmsg;
2346         struct ds_header *hdr;
2347         int ret;
2348
2349         if (!(rs->state & rs_readable))
2350                 return ERR(EINVAL);
2351
2352         if (!rs_have_rdata(rs)) {
2353                 ret = ds_get_comp(rs, rs_nonblocking(rs, flags),
2354                                   rs_have_rdata);
2355                 if (ret)
2356                         return ret;
2357         }
2358
2359         rmsg = &rs->dmsg[rs->rmsg_head];
2360         hdr = (struct ds_header *) (rmsg->qp->rbuf + rmsg->offset);
2361         if (len > rmsg->length - hdr->length)
2362                 len = rmsg->length - hdr->length;
2363
2364         memcpy(buf, (void *) hdr + hdr->length, len);
2365         if (addrlen)
2366                 ds_set_src(src_addr, addrlen, hdr);
2367
2368         if (!(flags & MSG_PEEK)) {
2369                 ds_post_recv(rs, rmsg->qp, rmsg->offset);
2370                 if (++rs->rmsg_head == rs->rq_size + 1)
2371                         rs->rmsg_head = 0;
2372                 rs->rqe_avail++;
2373         }
2374
2375         return len;
2376 }
2377
2378 static ssize_t rs_peek(struct rsocket *rs, void *buf, size_t len)
2379 {
2380         size_t left = len;
2381         uint32_t end_size, rsize;
2382         int rmsg_head, rbuf_offset;
2383
2384         rmsg_head = rs->rmsg_head;
2385         rbuf_offset = rs->rbuf_offset;
2386
2387         for (; left && (rmsg_head != rs->rmsg_tail); left -= rsize) {
2388                 if (left < rs->rmsg[rmsg_head].data) {
2389                         rsize = left;
2390                 } else {
2391                         rsize = rs->rmsg[rmsg_head].data;
2392                         if (++rmsg_head == rs->rq_size + 1)
2393                                 rmsg_head = 0;
2394                 }
2395
2396                 end_size = rs->rbuf_size - rbuf_offset;
2397                 if (rsize > end_size) {
2398                         memcpy(buf, &rs->rbuf[rbuf_offset], end_size);
2399                         rbuf_offset = 0;
2400                         buf += end_size;
2401                         rsize -= end_size;
2402                         left -= end_size;
2403                 }
2404                 memcpy(buf, &rs->rbuf[rbuf_offset], rsize);
2405                 rbuf_offset += rsize;
2406                 buf += rsize;
2407         }
2408
2409         return len - left;
2410 }
2411
2412 /*
2413  * Continue to receive any queued data even if the remote side has disconnected.
2414  */
2415 ssize_t rrecv(int socket, void *buf, size_t len, int flags)
2416 {
2417         struct rsocket *rs;
2418         size_t left = len;
2419         uint32_t end_size, rsize;
2420         int ret = 0;
2421
2422         rs = idm_at(&idm, socket);
2423         if (rs->type == SOCK_DGRAM) {
2424                 fastlock_acquire(&rs->rlock);
2425                 ret = ds_recvfrom(rs, buf, len, flags, NULL, NULL);
2426                 fastlock_release(&rs->rlock);
2427                 return ret;
2428         }
2429
2430         if (rs->state & rs_opening) {
2431                 ret = rs_do_connect(rs);
2432                 if (ret) {
2433                         if (errno == EINPROGRESS)
2434                                 errno = EAGAIN;
2435                         return ret;
2436                 }
2437         }
2438         fastlock_acquire(&rs->rlock);
2439         do {
2440                 if (!rs_have_rdata(rs)) {
2441                         ret = rs_get_comp(rs, rs_nonblocking(rs, flags),
2442                                           rs_conn_have_rdata);
2443                         if (ret)
2444                                 break;
2445                 }
2446
2447                 if (flags & MSG_PEEK) {
2448                         left = len - rs_peek(rs, buf, left);
2449                         break;
2450                 }
2451
2452                 for (; left && rs_have_rdata(rs); left -= rsize) {
2453                         if (left < rs->rmsg[rs->rmsg_head].data) {
2454                                 rsize = left;
2455                                 rs->rmsg[rs->rmsg_head].data -= left;
2456                         } else {
2457                                 rs->rseq_no++;
2458                                 rsize = rs->rmsg[rs->rmsg_head].data;
2459                                 if (++rs->rmsg_head == rs->rq_size + 1)
2460                                         rs->rmsg_head = 0;
2461                         }
2462
2463                         end_size = rs->rbuf_size - rs->rbuf_offset;
2464                         if (rsize > end_size) {
2465                                 memcpy(buf, &rs->rbuf[rs->rbuf_offset], end_size);
2466                                 rs->rbuf_offset = 0;
2467                                 buf += end_size;
2468                                 rsize -= end_size;
2469                                 left -= end_size;
2470                                 rs->rbuf_bytes_avail += end_size;
2471                         }
2472                         memcpy(buf, &rs->rbuf[rs->rbuf_offset], rsize);
2473                         rs->rbuf_offset += rsize;
2474                         buf += rsize;
2475                         rs->rbuf_bytes_avail += rsize;
2476                 }
2477
2478         } while (left && (flags & MSG_WAITALL) && (rs->state & rs_readable));
2479
2480         fastlock_release(&rs->rlock);
2481         return (ret && left == len) ? ret : len - left;
2482 }
2483
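/*
 * Editor's illustrative sketch (not part of rsocket.c): draining exactly
 * 'len' bytes from a connected stream rsocket despite short reads.  On a
 * blocking rsocket, rrecv() with MSG_WAITALL achieves the same result.
 */
static ssize_t example_recv_all(int rfd, void *buf, size_t len)
{
        size_t got = 0;
        ssize_t ret;

        while (got < len) {
                ret = rrecv(rfd, (char *) buf + got, len - got, 0);
                if (ret <= 0)           /* 0 means the peer closed the stream */
                        return got ? (ssize_t) got : ret;
                got += ret;
        }
        return got;
}
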
2484 ssize_t rrecvfrom(int socket, void *buf, size_t len, int flags,
2485                   struct sockaddr *src_addr, socklen_t *addrlen)
2486 {
2487         struct rsocket *rs;
2488         int ret;
2489
2490         rs = idm_at(&idm, socket);
2491         if (rs->type == SOCK_DGRAM) {
2492                 fastlock_acquire(&rs->rlock);
2493                 ret = ds_recvfrom(rs, buf, len, flags, src_addr, addrlen);
2494                 fastlock_release(&rs->rlock);
2495                 return ret;
2496         }
2497
2498         ret = rrecv(socket, buf, len, flags);
2499         if (ret > 0 && src_addr)
2500                 rgetpeername(socket, src_addr, addrlen);
2501
2502         return ret;
2503 }
2504
2505 /*
2506  * Simple, straightforward implementation for now that only tries to fill
2507  * in the first vector.
2508  */
2509 static ssize_t rrecvv(int socket, const struct iovec *iov, int iovcnt, int flags)
2510 {
2511         return rrecv(socket, iov[0].iov_base, iov[0].iov_len, flags);
2512 }
2513
2514 ssize_t rrecvmsg(int socket, struct msghdr *msg, int flags)
2515 {
2516         if (msg->msg_control && msg->msg_controllen)
2517                 return ERR(ENOTSUP);
2518
2519         return rrecvv(socket, msg->msg_iov, (int) msg->msg_iovlen, msg->msg_flags);
2520 }
2521
2522 ssize_t rread(int socket, void *buf, size_t count)
2523 {
2524         return rrecv(socket, buf, count, 0);
2525 }
2526
2527 ssize_t rreadv(int socket, const struct iovec *iov, int iovcnt)
2528 {
2529         return rrecvv(socket, iov, iovcnt, 0);
2530 }
2531
2532 static int rs_send_iomaps(struct rsocket *rs, int flags)
2533 {
2534         struct rs_iomap_mr *iomr;
2535         struct ibv_sge sge;
2536         struct rs_iomap iom;
2537         int ret;
2538
2539         fastlock_acquire(&rs->map_lock);
2540         while (!dlist_empty(&rs->iomap_queue)) {
2541                 if (!rs_can_send(rs)) {
2542                         ret = rs_get_comp(rs, rs_nonblocking(rs, flags),
2543                                           rs_conn_can_send);
2544                         if (ret)
2545                                 break;
2546                         if (!(rs->state & rs_writable)) {
2547                                 ret = ERR(ECONNRESET);
2548                                 break;
2549                         }
2550                 }
2551
2552                 iomr = container_of(rs->iomap_queue.next, struct rs_iomap_mr, entry);
2553                 if (!(rs->opts & RS_OPT_SWAP_SGL)) {
2554                         iom.offset = iomr->offset;
2555                         iom.sge.addr = (uintptr_t) iomr->mr->addr;
2556                         iom.sge.length = iomr->mr->length;
2557                         iom.sge.key = iomr->mr->rkey;
2558                 } else {
2559                         iom.offset = bswap_64(iomr->offset);
2560                         iom.sge.addr = bswap_64((uintptr_t) iomr->mr->addr);
2561                         iom.sge.length = bswap_32(iomr->mr->length);
2562                         iom.sge.key = bswap_32(iomr->mr->rkey);
2563                 }
2564
2565                 if (rs->sq_inline >= sizeof iom) {
2566                         sge.addr = (uintptr_t) &iom;
2567                         sge.length = sizeof iom;
2568                         sge.lkey = 0;
2569                         ret = rs_write_iomap(rs, iomr, &sge, 1, IBV_SEND_INLINE);
2570                 } else if (rs_sbuf_left(rs) >= sizeof iom) {
2571                         memcpy((void *) (uintptr_t) rs->ssgl[0].addr, &iom, sizeof iom);
2572                         rs->ssgl[0].length = sizeof iom;
2573                         ret = rs_write_iomap(rs, iomr, rs->ssgl, 1, 0);
2574                         if (rs_sbuf_left(rs) > sizeof iom)
2575                                 rs->ssgl[0].addr += sizeof iom;
2576                         else
2577                                 rs->ssgl[0].addr = (uintptr_t) rs->sbuf;
2578                 } else {
2579                         rs->ssgl[0].length = rs_sbuf_left(rs);
2580                         memcpy((void *) (uintptr_t) rs->ssgl[0].addr, &iom,
2581                                 rs->ssgl[0].length);
2582                         rs->ssgl[1].length = sizeof iom - rs->ssgl[0].length;
2583                         memcpy(rs->sbuf, ((void *) &iom) + rs->ssgl[0].length,
2584                                rs->ssgl[1].length);
2585                         ret = rs_write_iomap(rs, iomr, rs->ssgl, 2, 0);
2586                         rs->ssgl[0].addr = (uintptr_t) rs->sbuf + rs->ssgl[1].length;
2587                 }
2588                 dlist_remove(&iomr->entry);
2589                 dlist_insert_tail(&iomr->entry, &rs->iomap_list);
2590                 if (ret)
2591                         break;
2592         }
2593
2594         rs->iomap_pending = !dlist_empty(&rs->iomap_queue);
2595         fastlock_release(&rs->map_lock);
2596         return ret;
2597 }
2598
2599 static ssize_t ds_sendv_udp(struct rsocket *rs, const struct iovec *iov,
2600                             int iovcnt, int flags, uint8_t op)
2601 {
2602         struct ds_udp_header hdr;
2603         struct msghdr msg;
2604         struct iovec miov[8];
2605         ssize_t ret;
2606
2607         if (iovcnt > 7)         /* miov[0] is reserved for the rsocket UDP header */
2608                 return ERR(ENOTSUP);
2609
2610         hdr.tag = htobe32(DS_UDP_TAG);
2611         hdr.version = rs->conn_dest->qp->hdr.version;
2612         hdr.op = op;
2613         hdr.reserved = 0;
2614         hdr.qpn = htobe32(rs->conn_dest->qp->cm_id->qp->qp_num & 0xFFFFFF);
2615         if (rs->conn_dest->qp->hdr.version == 4) {
2616                 hdr.length = DS_UDP_IPV4_HDR_LEN;
2617                 hdr.addr.ipv4 = rs->conn_dest->qp->hdr.addr.ipv4;
2618         } else {
2619                 hdr.length = DS_UDP_IPV6_HDR_LEN;
2620                 memcpy(hdr.addr.ipv6, &rs->conn_dest->qp->hdr.addr.ipv6, 16);
2621         }
2622
2623         miov[0].iov_base = &hdr;
2624         miov[0].iov_len = hdr.length;
2625         if (iov && iovcnt)
2626                 memcpy(&miov[1], iov, sizeof(*iov) * iovcnt);
2627
2628         memset(&msg, 0, sizeof msg);
2629         msg.msg_name = &rs->conn_dest->addr;
2630         msg.msg_namelen = ucma_addrlen(&rs->conn_dest->addr.sa);
2631         msg.msg_iov = miov;
2632         msg.msg_iovlen = iovcnt + 1;
2633         ret = sendmsg(rs->udp_sock, &msg, flags);
2634         return ret > 0 ? ret - hdr.length : ret;
2635 }
2636
2637 static ssize_t ds_send_udp(struct rsocket *rs, const void *buf, size_t len,
2638                            int flags, uint8_t op)
2639 {
2640         struct iovec iov;
2641         if (buf && len) {
2642                 iov.iov_base = (void *) buf;
2643                 iov.iov_len = len;
2644                 return ds_sendv_udp(rs, &iov, 1, flags, op);
2645         } else {
2646                 return ds_sendv_udp(rs, NULL, 0, flags, op);
2647         }
2648 }
2649
2650 static ssize_t dsend(struct rsocket *rs, const void *buf, size_t len, int flags)
2651 {
2652         struct ds_smsg *msg;
2653         struct ibv_sge sge;
2654         uint64_t offset;
2655         int ret = 0;
2656
2657         if (!rs->conn_dest->ah)
2658                 return ds_send_udp(rs, buf, len, flags, RS_OP_DATA);
2659
2660         if (!ds_can_send(rs)) {
2661                 ret = ds_get_comp(rs, rs_nonblocking(rs, flags), ds_can_send);
2662                 if (ret)
2663                         return ret;
2664         }
2665
2666         msg = rs->smsg_free;
2667         rs->smsg_free = msg->next;
2668         rs->sqe_avail--;
2669
2670         memcpy((void *) msg, &rs->conn_dest->qp->hdr, rs->conn_dest->qp->hdr.length);
2671         memcpy((void *) msg + rs->conn_dest->qp->hdr.length, buf, len);
2672         sge.addr = (uintptr_t) msg;
2673         sge.length = rs->conn_dest->qp->hdr.length + len;
2674         sge.lkey = rs->conn_dest->qp->smr->lkey;
2675         offset = (uint8_t *) msg - rs->sbuf;
2676
2677         ret = ds_post_send(rs, &sge, offset);
2678         return ret ? ret : len;
2679 }
2680
2681 /*
2682  * We overlap sending the data by posting a small work request immediately,
2683  * then increasing the size of the send on each iteration.
2684  */
2685 ssize_t rsend(int socket, const void *buf, size_t len, int flags)
2686 {
2687         struct rsocket *rs;
2688         struct ibv_sge sge;
2689         size_t left = len;
2690         uint32_t xfer_size, olen = RS_OLAP_START_SIZE;
2691         int ret = 0;
2692
2693         rs = idm_at(&idm, socket);
2694         if (rs->type == SOCK_DGRAM) {
2695                 fastlock_acquire(&rs->slock);
2696                 ret = dsend(rs, buf, len, flags);
2697                 fastlock_release(&rs->slock);
2698                 return ret;
2699         }
2700
2701         if (rs->state & rs_opening) {
2702                 ret = rs_do_connect(rs);
2703                 if (ret) {
2704                         if (errno == EINPROGRESS)
2705                                 errno = EAGAIN;
2706                         return ret;
2707                 }
2708         }
2709
2710         fastlock_acquire(&rs->slock);
2711         if (rs->iomap_pending) {
2712                 ret = rs_send_iomaps(rs, flags);
2713                 if (ret)
2714                         goto out;
2715         }
2716         for (; left; left -= xfer_size, buf += xfer_size) {
2717                 if (!rs_can_send(rs)) {
2718                         ret = rs_get_comp(rs, rs_nonblocking(rs, flags),
2719                                           rs_conn_can_send);
2720                         if (ret)
2721                                 break;
2722                         if (!(rs->state & rs_writable)) {
2723                                 ret = ERR(ECONNRESET);
2724                                 break;
2725                         }
2726                 }
2727
2728                 if (olen < left) {
2729                         xfer_size = olen;
2730                         if (olen < RS_MAX_TRANSFER)
2731                                 olen <<= 1;
2732                 } else {
2733                         xfer_size = left;
2734                 }
2735
2736                 if (xfer_size > rs->sbuf_bytes_avail)
2737                         xfer_size = rs->sbuf_bytes_avail;
2738                 if (xfer_size > rs->target_sgl[rs->target_sge].length)
2739                         xfer_size = rs->target_sgl[rs->target_sge].length;
2740
2741                 if (xfer_size <= rs->sq_inline) {
2742                         sge.addr = (uintptr_t) buf;
2743                         sge.length = xfer_size;
2744                         sge.lkey = 0;
2745                         ret = rs_write_data(rs, &sge, 1, xfer_size, IBV_SEND_INLINE);
2746                 } else if (xfer_size <= rs_sbuf_left(rs)) {
2747                         memcpy((void *) (uintptr_t) rs->ssgl[0].addr, buf, xfer_size);
2748                         rs->ssgl[0].length = xfer_size;
2749                         ret = rs_write_data(rs, rs->ssgl, 1, xfer_size, 0);
2750                         if (xfer_size < rs_sbuf_left(rs))
2751                                 rs->ssgl[0].addr += xfer_size;
2752                         else
2753                                 rs->ssgl[0].addr = (uintptr_t) rs->sbuf;
2754                 } else {
2755                         rs->ssgl[0].length = rs_sbuf_left(rs);
2756                         memcpy((void *) (uintptr_t) rs->ssgl[0].addr, buf,
2757                                 rs->ssgl[0].length);
2758                         rs->ssgl[1].length = xfer_size - rs->ssgl[0].length;
2759                         memcpy(rs->sbuf, buf + rs->ssgl[0].length, rs->ssgl[1].length);
2760                         ret = rs_write_data(rs, rs->ssgl, 2, xfer_size, 0);
2761                         rs->ssgl[0].addr = (uintptr_t) rs->sbuf + rs->ssgl[1].length;
2762                 }
2763                 if (ret)
2764                         break;
2765         }
2766 out:
2767         fastlock_release(&rs->slock);
2768
2769         return (ret && left == len) ? ret : len - left;
2770 }
2771
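/*
 * Editor's illustrative sketch (not part of rsocket.c): the transfer-size
 * schedule produced by the overlapped-send loop above for a given length,
 * starting at RS_OLAP_START_SIZE and doubling up to RS_MAX_TRANSFER (before
 * any clamping against available send buffer space or the target SGL).
 */
static void example_olap_schedule(size_t len)
{
        uint32_t xfer_size, olen = RS_OLAP_START_SIZE;
        size_t left;

        for (left = len; left; left -= xfer_size) {
                if (olen < left) {
                        xfer_size = olen;
                        if (olen < RS_MAX_TRANSFER)
                                olen <<= 1;
                } else {
                        xfer_size = left;
                }
                printf("post %u bytes\n", xfer_size);
        }
}
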
2772 ssize_t rsendto(int socket, const void *buf, size_t len, int flags,
2773                 const struct sockaddr *dest_addr, socklen_t addrlen)
2774 {
2775         struct rsocket *rs;
2776         int ret;
2777
2778         rs = idm_at(&idm, socket);
2779         if (rs->type == SOCK_STREAM) {
2780                 if (dest_addr || addrlen)
2781                         return ERR(EISCONN);
2782
2783                 return rsend(socket, buf, len, flags);
2784         }
2785
2786         if (rs->state == rs_init) {
2787                 ret = ds_init_ep(rs);
2788                 if (ret)
2789                         return ret;
2790         }
2791
2792         fastlock_acquire(&rs->slock);
2793         if (!rs->conn_dest || ds_compare_addr(dest_addr, &rs->conn_dest->addr)) {
2794                 ret = ds_get_dest(rs, dest_addr, addrlen, &rs->conn_dest);
2795                 if (ret)
2796                         goto out;
2797         }
2798
2799         ret = dsend(rs, buf, len, flags);
2800 out:
2801         fastlock_release(&rs->slock);
2802         return ret;
2803 }
2804
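/*
 * Editor's illustrative sketch (not part of rsocket.c): one datagram round
 * trip over a SOCK_DGRAM rsocket.  'dst' is assumed to be a resolved IPv4 or
 * IPv6 destination, and the 8 KB reply buffer is an arbitrary choice.
 */
static ssize_t example_dgram_roundtrip(const struct sockaddr *dst, socklen_t dlen,
                                       const void *msg, size_t len)
{
        char reply[8192];
        ssize_t ret;
        int rfd;

        rfd = rsocket(dst->sa_family, SOCK_DGRAM, 0);
        if (rfd < 0)
                return rfd;

        /* rsendto() lazily initializes the UD endpoint and resolves the
         * destination on first use (see ds_get_dest() above) */
        ret = rsendto(rfd, msg, len, 0, dst, dlen);
        if (ret >= 0)
                ret = rrecvfrom(rfd, reply, sizeof reply, 0, NULL, NULL);

        rclose(rfd);
        return ret;
}
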
2805 static void rs_copy_iov(void *dst, const struct iovec **iov, size_t *offset, size_t len)
2806 {
2807         size_t size;
2808
2809         while (len) {
2810                 size = (*iov)->iov_len - *offset;
2811                 if (size > len) {
2812                         memcpy(dst, (*iov)->iov_base + *offset, len);
2813                         *offset += len;
2814                         break;
2815                 }
2816
2817                 memcpy(dst, (*iov)->iov_base + *offset, size);
2818                 len -= size;
2819                 dst += size;
2820                 (*iov)++;
2821                 *offset = 0;
2822         }
2823 }
2824
2825 static ssize_t rsendv(int socket, const struct iovec *iov, int iovcnt, int flags)
2826 {
2827         struct rsocket *rs;
2828         const struct iovec *cur_iov;
2829         size_t left, len, offset = 0;
2830         uint32_t xfer_size, olen = RS_OLAP_START_SIZE;
2831         int i, ret = 0;
2832
2833         rs = idm_at(&idm, socket);
2834         if (rs->state & rs_opening) {
2835                 ret = rs_do_connect(rs);
2836                 if (ret) {
2837                         if (errno == EINPROGRESS)
2838                                 errno = EAGAIN;
2839                         return ret;
2840                 }
2841         }
2842
2843         cur_iov = iov;
2844         len = iov[0].iov_len;
2845         for (i = 1; i < iovcnt; i++)
2846                 len += iov[i].iov_len;
2847         left = len;
2848
2849         fastlock_acquire(&rs->slock);
2850         if (rs->iomap_pending) {
2851                 ret = rs_send_iomaps(rs, flags);
2852                 if (ret)
2853                         goto out;
2854         }
2855         for (; left; left -= xfer_size) {
2856                 if (!rs_can_send(rs)) {
2857                         ret = rs_get_comp(rs, rs_nonblocking(rs, flags),
2858                                           rs_conn_can_send);
2859                         if (ret)
2860                                 break;
2861                         if (!(rs->state & rs_writable)) {
2862                                 ret = ERR(ECONNRESET);
2863                                 break;
2864                         }
2865                 }
2866
2867                 if (olen < left) {
2868                         xfer_size = olen;
2869                         if (olen < RS_MAX_TRANSFER)
2870                                 olen <<= 1;
2871                 } else {
2872                         xfer_size = left;
2873                 }
2874
2875                 if (xfer_size > rs->sbuf_bytes_avail)
2876                         xfer_size = rs->sbuf_bytes_avail;
2877                 if (xfer_size > rs->target_sgl[rs->target_sge].length)
2878                         xfer_size = rs->target_sgl[rs->target_sge].length;
2879
2880                 if (xfer_size <= rs_sbuf_left(rs)) {
2881                         rs_copy_iov((void *) (uintptr_t) rs->ssgl[0].addr,
2882                                     &cur_iov, &offset, xfer_size);
2883                         rs->ssgl[0].length = xfer_size;
2884                         ret = rs_write_data(rs, rs->ssgl, 1, xfer_size,
2885                                             xfer_size <= rs->sq_inline ? IBV_SEND_INLINE : 0);
2886                         if (xfer_size < rs_sbuf_left(rs))
2887                                 rs->ssgl[0].addr += xfer_size;
2888                         else
2889                                 rs->ssgl[0].addr = (uintptr_t) rs->sbuf;
2890                 } else {
2891                         rs->ssgl[0].length = rs_sbuf_left(rs);
2892                         rs_copy_iov((void *) (uintptr_t) rs->ssgl[0].addr, &cur_iov,
2893                                     &offset, rs->ssgl[0].length);
2894                         rs->ssgl[1].length = xfer_size - rs->ssgl[0].length;
2895                         rs_copy_iov(rs->sbuf, &cur_iov, &offset, rs->ssgl[1].length);
2896                         ret = rs_write_data(rs, rs->ssgl, 2, xfer_size,
2897                                             xfer_size <= rs->sq_inline ? IBV_SEND_INLINE : 0);
2898                         rs->ssgl[0].addr = (uintptr_t) rs->sbuf + rs->ssgl[1].length;
2899                 }
2900                 if (ret)
2901                         break;
2902         }
2903 out:
2904         fastlock_release(&rs->slock);
2905
2906         return (ret && left == len) ? ret : len - left;
2907 }
2908
2909 ssize_t rsendmsg(int socket, const struct msghdr *msg, int flags)
2910 {
2911         if (msg->msg_control && msg->msg_controllen)
2912                 return ERR(ENOTSUP);
2913
2914         return rsendv(socket, msg->msg_iov, (int) msg->msg_iovlen, flags);
2915 }
2916
2917 ssize_t rwrite(int socket, const void *buf, size_t count)
2918 {
2919         return rsend(socket, buf, count, 0);
2920 }
2921
2922 ssize_t rwritev(int socket, const struct iovec *iov, int iovcnt)
2923 {
2924         return rsendv(socket, iov, iovcnt, 0);
2925 }
2926
2927 static struct pollfd *rs_fds_alloc(nfds_t nfds)
2928 {
2929         static __thread struct pollfd *rfds;
2930         static __thread nfds_t rnfds;
2931
2932         if (nfds > rnfds) {
2933                 if (rfds)
2934                         free(rfds);
2935
2936                 rfds = malloc(sizeof(*rfds) * nfds);
2937                 rnfds = rfds ? nfds : 0;
2938         }
2939
2940         return rfds;
2941 }
2942
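/*
 * Poll helper for a single rsocket: advance an in-progress connection if
 * needed, reap CQ completions, and translate the rsocket state into poll
 * revents (POLLIN/POLLOUT/POLLHUP/POLLERR).  Listening rsockets are handled
 * by polling the CM event channel directly.
 */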
2943 static int rs_poll_rs(struct rsocket *rs, int events,
2944                       int nonblock, int (*test)(struct rsocket *rs))
2945 {
2946         struct pollfd fds;
2947         short revents;
2948         int ret;
2949
2950 check_cq:
2951         if ((rs->type == SOCK_STREAM) && ((rs->state & rs_connected) ||
2952              (rs->state == rs_disconnected) || (rs->state & rs_error))) {
2953                 rs_process_cq(rs, nonblock, test);
2954
2955                 revents = 0;
2956                 if ((events & POLLIN) && rs_conn_have_rdata(rs))
2957                         revents |= POLLIN;
2958                 if ((events & POLLOUT) && rs_can_send(rs))
2959                         revents |= POLLOUT;
2960                 if (!(rs->state & rs_connected)) {
2961                         if (rs->state == rs_disconnected)
2962                                 revents |= POLLHUP;
2963                         else
2964                                 revents |= POLLERR;
2965                 }
2966
2967                 return revents;
2968         } else if (rs->type == SOCK_DGRAM) {
2969                 ds_process_cqs(rs, nonblock, test);
2970
2971                 revents = 0;
2972                 if ((events & POLLIN) && rs_have_rdata(rs))
2973                         revents |= POLLIN;
2974                 if ((events & POLLOUT) && ds_can_send(rs))
2975                         revents |= POLLOUT;
2976
2977                 return revents;
2978         }
2979
2980         if (rs->state == rs_listening) {
2981                 fds.fd = rs->cm_id->channel->fd;
2982                 fds.events = events;
2983                 fds.revents = 0;
2984                 poll(&fds, 1, 0);
2985                 return fds.revents;
2986         }
2987
2988         if (rs->state & rs_opening) {
2989                 ret = rs_do_connect(rs);
2990                 if (ret && (errno == EINPROGRESS)) {
2991                         errno = 0;
2992                 } else {
2993                         goto check_cq;
2994                 }
2995         }
2996
2997         if (rs->state == rs_connect_error) {
2998                 revents = 0;
2999                 if (events & POLLOUT)
3000                         revents |= POLLOUT;
3001                 if (events & POLLIN)
3002                         revents |= POLLIN;
3003                 revents |= POLLERR;
3004                 return revents;
3005         }
3006
3007         return 0;
3008 }
3009
3010 static int rs_poll_check(struct pollfd *fds, nfds_t nfds)
3011 {
3012         struct rsocket *rs;
3013         int i, cnt = 0;
3014
3015         for (i = 0; i < nfds; i++) {
3016                 rs = idm_lookup(&idm, fds[i].fd);
3017                 if (rs)
3018                         fds[i].revents = rs_poll_rs(rs, fds[i].events, 1, rs_poll_all);
3019                 else
3020                         poll(&fds[i], 1, 0);
3021
3022                 if (fds[i].revents)
3023                         cnt++;
3024         }
3025         return cnt;
3026 }
3027
3028 static int rs_poll_arm(struct pollfd *rfds, struct pollfd *fds, nfds_t nfds)
3029 {
3030         struct rsocket *rs;
3031         int i;
3032
3033         for (i = 0; i < nfds; i++) {
3034                 rs = idm_lookup(&idm, fds[i].fd);
3035                 if (rs) {
3036                         fds[i].revents = rs_poll_rs(rs, fds[i].events, 0, rs_is_cq_armed);
3037                         if (fds[i].revents)
3038                                 return 1;
3039
3040                         if (rs->type == SOCK_STREAM) {
3041                                 if (rs->state >= rs_connected)
3042                                         rfds[i].fd = rs->cm_id->recv_cq_channel->fd;
3043                                 else
3044                                         rfds[i].fd = rs->cm_id->channel->fd;
3045                         } else {
3046                                 rfds[i].fd = rs->epfd;
3047                         }
3048                         rfds[i].events = POLLIN;
3049                 } else {
3050                         rfds[i].fd = fds[i].fd;
3051                         rfds[i].events = fds[i].events;
3052                 }
3053                 rfds[i].revents = 0;
3054         }
3055         return 0;
3056 }
3057
3058 static int rs_poll_events(struct pollfd *rfds, struct pollfd *fds, nfds_t nfds)
3059 {
3060         struct rsocket *rs;
3061         int i, cnt = 0;
3062
3063         for (i = 0; i < nfds; i++) {
3064                 if (!rfds[i].revents)
3065                         continue;
3066
3067                 rs = idm_lookup(&idm, fds[i].fd);
3068                 if (rs) {
3069                         fastlock_acquire(&rs->cq_wait_lock);
3070                         if (rs->type == SOCK_STREAM)
3071                                 rs_get_cq_event(rs);
3072                         else
3073                                 ds_get_cq_event(rs);
3074                         fastlock_release(&rs->cq_wait_lock);
3075                         fds[i].revents = rs_poll_rs(rs, fds[i].events, 1, rs_poll_all);
3076                 } else {
3077                         fds[i].revents = rfds[i].revents;
3078                 }
3079                 if (fds[i].revents)
3080                         cnt++;
3081         }
3082         return cnt;
3083 }
3084
3085 /*
3086  * We need to poll *all* fds that the user specifies at least once.
3087  * Note that we may receive events on an rsocket that may not be reported
3088  * to the user (e.g. connection events or credit updates).  Process those
3089  * events, then return to polling until we find ones of interest.
3090  */
3091 int rpoll(struct pollfd *fds, nfds_t nfds, int timeout)
3092 {
3093         struct timeval s, e;
3094         struct pollfd *rfds;
3095         uint32_t poll_time = 0;
3096         int ret;
3097
3098         do {
3099                 ret = rs_poll_check(fds, nfds);
3100                 if (ret || !timeout)
3101                         return ret;
3102
3103                 if (!poll_time)
3104                         gettimeofday(&s, NULL);
3105
3106                 gettimeofday(&e, NULL);
3107                 poll_time = (e.tv_sec - s.tv_sec) * 1000000 +
3108                             (e.tv_usec - s.tv_usec) + 1;
3109         } while (poll_time <= polling_time);
3110
3111         rfds = rs_fds_alloc(nfds);
3112         if (!rfds)
3113                 return ERR(ENOMEM);
3114
3115         do {
3116                 ret = rs_poll_arm(rfds, fds, nfds);
3117                 if (ret)
3118                         break;
3119
3120                 ret = poll(rfds, nfds, timeout);
3121                 if (ret <= 0)
3122                         break;
3123
3124                 ret = rs_poll_events(rfds, fds, nfds);
3125         } while (!ret);
3126
3127         return ret;
3128 }
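
/*
 * Usage sketch (illustrative only, kept out of the build with #if 0): how an
 * application might wait for incoming data on a connected rsocket with
 * rpoll() and then read it with rrecv().  The descriptor "rsfd" and the
 * 5-second timeout are hypothetical.
 */
#if 0
static ssize_t example_wait_and_read(int rsfd, void *buf, size_t len)
{
        struct pollfd pfd;
        int ret;

        pfd.fd = rsfd;                  /* fd returned by rsocket()/raccept() */
        pfd.events = POLLIN;
        pfd.revents = 0;

        ret = rpoll(&pfd, 1, 5000);     /* wait up to 5 seconds */
        if (ret <= 0)
                return ret;             /* 0 on timeout, -1 on error */

        if (pfd.revents & POLLIN)
                return rrecv(rsfd, buf, len, 0);

        return -1;                      /* POLLERR/POLLHUP without data */
}
#endif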
3129
3130 static struct pollfd *
3131 rs_select_to_poll(int *nfds, fd_set *readfds, fd_set *writefds, fd_set *exceptfds)
3132 {
3133         struct pollfd *fds;
3134         int fd, i = 0;
3135
3136         fds = calloc(*nfds, sizeof(*fds));
3137         if (!fds)
3138                 return NULL;
3139
3140         for (fd = 0; fd < *nfds; fd++) {
3141                 if (readfds && FD_ISSET(fd, readfds)) {
3142                         fds[i].fd = fd;
3143                         fds[i].events = POLLIN;
3144                 }
3145
3146                 if (writefds && FD_ISSET(fd, writefds)) {
3147                         fds[i].fd = fd;
3148                         fds[i].events |= POLLOUT;
3149                 }
3150
3151                 if (exceptfds && FD_ISSET(fd, exceptfds))
3152                         fds[i].fd = fd;
3153
3154                 if (fds[i].fd)
3155                         i++;
3156         }
3157
3158         *nfds = i;
3159         return fds;
3160 }
3161
3162 static int
3163 rs_poll_to_select(int nfds, struct pollfd *fds, fd_set *readfds,
3164                   fd_set *writefds, fd_set *exceptfds)
3165 {
3166         int i, cnt = 0;
3167
3168         for (i = 0; i < nfds; i++) {
3169                 if (readfds && (fds[i].revents & (POLLIN | POLLHUP))) {
3170                         FD_SET(fds[i].fd, readfds);
3171                         cnt++;
3172                 }
3173
3174                 if (writefds && (fds[i].revents & POLLOUT)) {
3175                         FD_SET(fds[i].fd, writefds);
3176                         cnt++;
3177                 }
3178
3179                 if (exceptfds && (fds[i].revents & ~(POLLIN | POLLOUT))) {
3180                         FD_SET(fds[i].fd, exceptfds);
3181                         cnt++;
3182                 }
3183         }
3184         return cnt;
3185 }
3186
3187 static int rs_convert_timeout(struct timeval *timeout)
3188 {
3189         return !timeout ? -1 :
3190                 timeout->tv_sec * 1000 + timeout->tv_usec / 1000;
3191 }
3192
3193 int rselect(int nfds, fd_set *readfds, fd_set *writefds,
3194             fd_set *exceptfds, struct timeval *timeout)
3195 {
3196         struct pollfd *fds;
3197         int ret;
3198
3199         fds = rs_select_to_poll(&nfds, readfds, writefds, exceptfds);
3200         if (!fds)
3201                 return ERR(ENOMEM);
3202
3203         ret = rpoll(fds, nfds, rs_convert_timeout(timeout));
3204
3205         if (readfds)
3206                 FD_ZERO(readfds);
3207         if (writefds)
3208                 FD_ZERO(writefds);
3209         if (exceptfds)
3210                 FD_ZERO(exceptfds);
3211
3212         if (ret > 0)
3213                 ret = rs_poll_to_select(nfds, fds, readfds, writefds, exceptfds);
3214
3215         free(fds);
3216         return ret;
3217 }
3218
3219 /*
3220  * For graceful disconnect, notify the remote side that we're
3221  * disconnecting and wait until all outstanding sends complete, provided
3222  * that the remote side has not sent a disconnect message.
3223  */
3224 int rshutdown(int socket, int how)
3225 {
3226         struct rsocket *rs;
3227         int ctrl, ret = 0;
3228
3229         rs = idm_lookup(&idm, socket);
3230         if (!rs)
3231                 return ERR(EBADF);
3232         if (rs->opts & RS_OPT_SVC_ACTIVE)
3233                 rs_notify_svc(&tcp_svc, rs, RS_SVC_REM_KEEPALIVE);
3234
3235         if (rs->fd_flags & O_NONBLOCK)
3236                 rs_set_nonblocking(rs, 0);
3237
3238         if (rs->state & rs_connected) {
3239                 if (how == SHUT_RDWR) {
3240                         ctrl = RS_CTRL_DISCONNECT;
3241                         rs->state &= ~(rs_readable | rs_writable);
3242                 } else if (how == SHUT_WR) {
3243                         rs->state &= ~rs_writable;
3244                         ctrl = (rs->state & rs_readable) ?
3245                                 RS_CTRL_SHUTDOWN : RS_CTRL_DISCONNECT;
3246                 } else {
3247                         rs->state &= ~rs_readable;
3248                         if (rs->state & rs_writable)
3249                                 goto out;
3250                         ctrl = RS_CTRL_DISCONNECT;
3251                 }
3252                 if (!rs_ctrl_avail(rs)) {
3253                         ret = rs_process_cq(rs, 0, rs_conn_can_send_ctrl);
3254                         if (ret)
3255                                 goto out;
3256                 }
3257
3258                 if ((rs->state & rs_connected) && rs_ctrl_avail(rs)) {
3259                         rs->ctrl_seqno++;
3260                         ret = rs_post_msg(rs, rs_msg_set(RS_OP_CTRL, ctrl));
3261                 }
3262         }
3263
3264         if (rs->state & rs_connected)
3265                 rs_process_cq(rs, 0, rs_conn_all_sends_done);
3266
3267 out:
3268         if ((rs->fd_flags & O_NONBLOCK) && (rs->state & rs_connected))
3269                 rs_set_nonblocking(rs, rs->fd_flags);
3270
3271         if (rs->state & rs_disconnected) {
3272                 /* Generate event by flushing receives to unblock rpoll */
3273                 ibv_req_notify_cq(rs->cm_id->recv_cq, 0);
3274                 ucma_shutdown(rs->cm_id);
3275         }
3276
3277         return ret;
3278 }
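
/*
 * Usage sketch (illustrative only, kept out of the build with #if 0): one way
 * an application might drive the graceful disconnect described above: half-
 * close the send side, drain whatever the peer still has in flight, then
 * close.  The descriptor "rsfd" is hypothetical.
 */
#if 0
static void example_graceful_close(int rsfd)
{
        char buf[4096];
        ssize_t n;

        /* Tell the peer we are done sending; receives remain open. */
        rshutdown(rsfd, SHUT_WR);

        /* Drain until the peer closes its side (rrecv() returns 0). */
        do {
                n = rrecv(rsfd, buf, sizeof(buf), 0);
        } while (n > 0);

        rclose(rsfd);
}
#endif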
3279
3280 static void ds_shutdown(struct rsocket *rs)
3281 {
3282         if (rs->opts & RS_OPT_SVC_ACTIVE)
3283                 rs_notify_svc(&udp_svc, rs, RS_SVC_REM_DGRAM);
3284
3285         if (rs->fd_flags & O_NONBLOCK)
3286                 rs_set_nonblocking(rs, 0);
3287
3288         rs->state &= ~(rs_readable | rs_writable);
3289         ds_process_cqs(rs, 0, ds_all_sends_done);
3290
3291         if (rs->fd_flags & O_NONBLOCK)
3292                 rs_set_nonblocking(rs, rs->fd_flags);
3293 }
3294
3295 int rclose(int socket)
3296 {
3297         struct rsocket *rs;
3298
3299         rs = idm_lookup(&idm, socket);
3300         if (!rs)
3301                 return EBADF;
3302         if (rs->type == SOCK_STREAM) {
3303                 if (rs->state & rs_connected)
3304                         rshutdown(socket, SHUT_RDWR);
3305                 else if (rs->opts & RS_OPT_SVC_ACTIVE)
3306                         rs_notify_svc(&tcp_svc, rs, RS_SVC_REM_KEEPALIVE);
3307         } else {
3308                 ds_shutdown(rs);
3309         }
3310
3311         rs_free(rs);
3312         return 0;
3313 }
3314
3315 static void rs_copy_addr(struct sockaddr *dst, struct sockaddr *src, socklen_t *len)
3316 {
3317         socklen_t size;
3318
3319         if (src->sa_family == AF_INET) {
3320                 size = min_t(socklen_t, *len, sizeof(struct sockaddr_in));
3321                 *len = sizeof(struct sockaddr_in);
3322         } else {
3323                 size = min_t(socklen_t, *len, sizeof(struct sockaddr_in6));
3324                 *len = sizeof(struct sockaddr_in6);
3325         }
3326         memcpy(dst, src, size);
3327 }
3328
3329 int rgetpeername(int socket, struct sockaddr *addr, socklen_t *addrlen)
3330 {
3331         struct rsocket *rs;
3332
3333         rs = idm_lookup(&idm, socket);
3334         if (!rs)
3335                 return ERR(EBADF);
3336         if (rs->type == SOCK_STREAM) {
3337                 rs_copy_addr(addr, rdma_get_peer_addr(rs->cm_id), addrlen);
3338                 return 0;
3339         } else {
3340                 return getpeername(rs->udp_sock, addr, addrlen);
3341         }
3342 }
3343
3344 int rgetsockname(int socket, struct sockaddr *addr, socklen_t *addrlen)
3345 {
3346         struct rsocket *rs;
3347
3348         rs = idm_lookup(&idm, socket);
3349         if (!rs)
3350                 return ERR(EBADF);
3351         if (rs->type == SOCK_STREAM) {
3352                 rs_copy_addr(addr, rdma_get_local_addr(rs->cm_id), addrlen);
3353                 return 0;
3354         } else {
3355                 return getsockname(rs->udp_sock, addr, addrlen);
3356         }
3357 }
3358
3359 static int rs_set_keepalive(struct rsocket *rs, int on)
3360 {
3361         FILE *f;
3362         int ret;
3363
3364         if ((on && (rs->opts & RS_OPT_SVC_ACTIVE)) ||
3365             (!on && !(rs->opts & RS_OPT_SVC_ACTIVE)))
3366                 return 0;
3367
3368         if (on) {
3369                 if (!rs->keepalive_time) {
3370                         if ((f = fopen("/proc/sys/net/ipv4/tcp_keepalive_time", "r"))) {
3371                                 if (fscanf(f, "%u", &rs->keepalive_time) != 1)
3372                                         rs->keepalive_time = 7200;
3373                                 fclose(f);
3374                         } else {
3375                                 rs->keepalive_time = 7200;
3376                         }
3377                 }
3378                 ret = rs_notify_svc(&tcp_svc, rs, RS_SVC_ADD_KEEPALIVE);
3379         } else {
3380                 ret = rs_notify_svc(&tcp_svc, rs, RS_SVC_REM_KEEPALIVE);
3381         }
3382
3383         return ret;
3384 }
3385
3386 int rsetsockopt(int socket, int level, int optname,
3387                 const void *optval, socklen_t optlen)
3388 {
3389         struct rsocket *rs;
3390         int ret, opt_on = 0;
3391         uint64_t *opts = NULL;
3392
3393         ret = ERR(ENOTSUP);
3394         rs = idm_lookup(&idm, socket);
3395         if (!rs)
3396                 return ERR(EBADF);
3397         if (rs->type == SOCK_DGRAM && level != SOL_RDMA) {
3398                 ret = setsockopt(rs->udp_sock, level, optname, optval, optlen);
3399                 if (ret)
3400                         return ret;
3401         }
3402
3403         switch (level) {
3404         case SOL_SOCKET:
3405                 opts = &rs->so_opts;
3406                 switch (optname) {
3407                 case SO_REUSEADDR:
3408                         if (rs->type == SOCK_STREAM) {
3409                                 ret = rdma_set_option(rs->cm_id, RDMA_OPTION_ID,
3410                                                       RDMA_OPTION_ID_REUSEADDR,
3411                                                       (void *) optval, optlen);
3412                                 if (ret && ((errno == ENOSYS) || ((rs->state != rs_init) &&
3413                                     rs->cm_id->context &&
3414                                     (rs->cm_id->verbs->device->transport_type == IBV_TRANSPORT_IB))))
3415                                         ret = 0;
3416                         }
3417                         opt_on = *(int *) optval;
3418                         break;
3419                 case SO_RCVBUF:
3420                         if ((rs->type == SOCK_STREAM && !rs->rbuf) ||
3421                             (rs->type == SOCK_DGRAM && !rs->qp_list))
3422                                 rs->rbuf_size = (*(uint32_t *) optval) << 1;
3423                         ret = 0;
3424                         break;
3425                 case SO_SNDBUF:
3426                         if (!rs->sbuf)
3427                                 rs->sbuf_size = (*(uint32_t *) optval) << 1;
3428                         if (rs->sbuf_size < RS_SNDLOWAT)
3429                                 rs->sbuf_size = RS_SNDLOWAT << 1;
3430                         ret = 0;
3431                         break;
3432                 case SO_LINGER:
3433                         /* Invert value so default so_opt = 0 is on */
3434                         /* Invert the value so the default so_opt bit of 0 means linger on */
3435                         opt_on = !((struct linger *) optval)->l_onoff;
3436                         break;
3437                 case SO_KEEPALIVE:
3438                         ret = rs_set_keepalive(rs, *(int *) optval);
3439                         opt_on = rs->opts & RS_OPT_SVC_ACTIVE;
3440                         break;
3441                 case SO_OOBINLINE:
3442                         opt_on = *(int *) optval;
3443                         ret = 0;
3444                         break;
3445                 default:
3446                         break;
3447                 }
3448                 break;
3449         case IPPROTO_TCP:
3450                 opts = &rs->tcp_opts;
3451                 switch (optname) {
3452                 case TCP_KEEPCNT:
3453                 case TCP_KEEPINTVL:
3454                         ret = 0;   /* N/A - we're using a reliable connection */
3455                         break;
3456                 case TCP_KEEPIDLE:
3457                         if (*(int *) optval <= 0) {
3458                                 ret = ERR(EINVAL);
3459                                 break;
3460                         }
3461                         rs->keepalive_time = *(int *) optval;
3462                         ret = (rs->opts & RS_OPT_SVC_ACTIVE) ?
3463                               rs_notify_svc(&tcp_svc, rs, RS_SVC_MOD_KEEPALIVE) : 0;
3464                         break;
3465                 case TCP_NODELAY:
3466                         opt_on = *(int *) optval;
3467                         ret = 0;
3468                         break;
3469                 case TCP_MAXSEG:
3470                         ret = 0;
3471                         break;
3472                 default:
3473                         break;
3474                 }
3475                 break;
3476         case IPPROTO_IPV6:
3477                 opts = &rs->ipv6_opts;
3478                 switch (optname) {
3479                 case IPV6_V6ONLY:
3480                         if (rs->type == SOCK_STREAM) {
3481                                 ret = rdma_set_option(rs->cm_id, RDMA_OPTION_ID,
3482                                                       RDMA_OPTION_ID_AFONLY,
3483                                                       (void *) optval, optlen);
3484                         }
3485                         opt_on = *(int *) optval;
3486                         break;
3487                 default:
3488                         break;
3489                 }
3490                 break;
3491         case SOL_RDMA:
3492                 if (rs->state >= rs_opening) {
3493                         ret = ERR(EINVAL);
3494                         break;
3495                 }
3496
3497                 switch (optname) {
3498                 case RDMA_SQSIZE:
3499                         rs->sq_size = min_t(uint32_t, (*(uint32_t *)optval),
3500                                             RS_QP_MAX_SIZE);
3501                         ret = 0;
3502                         break;
3503                 case RDMA_RQSIZE:
3504                         rs->rq_size = min_t(uint32_t, (*(uint32_t *)optval),
3505                                             RS_QP_MAX_SIZE);
3506                         ret = 0;
3507                         break;
3508                 case RDMA_INLINE:
3509                         rs->sq_inline = min_t(uint32_t, *(uint32_t *)optval,
3510                                               RS_QP_MAX_SIZE);
3511                         ret = 0;
3512                         break;
3513                 case RDMA_IOMAPSIZE:
3514                         rs->target_iomap_size = (uint16_t) rs_scale_to_value(
3515                                 (uint8_t) rs_value_to_scale(*(int *) optval, 8), 8);
3516                         ret = 0;
3517                         break;
3518                 case RDMA_ROUTE:
3519                         if ((rs->optval = malloc(optlen))) {
3520                                 memcpy(rs->optval, optval, optlen);
3521                                 rs->optlen = optlen;
3522                                 ret = 0;
3523                         } else {
3524                                 ret = ERR(ENOMEM);
3525                         }
3526                         break;
3527                 default:
3528                         break;
3529                 }
3530                 break;
3531         default:
3532                 break;
3533         }
3534
3535         if (!ret && opts) {
3536                 if (opt_on)
3537                         *opts |= (1 << optname);
3538                 else
3539                         *opts &= ~(1 << optname);
3540         }
3541
3542         return ret;
3543 }
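
/*
 * Usage sketch (illustrative only, kept out of the build with #if 0): tuning
 * the rsocket-specific SOL_RDMA options on a freshly created rsocket, before
 * rconnect()/rlisten() (the code above rejects them once the rsocket starts
 * opening).  The queue depths and inline size below are arbitrary examples.
 */
#if 0
static void example_tune_rdma_options(int rsfd)
{
        uint32_t sq_size = 64;          /* send queue depth */
        uint32_t rq_size = 64;          /* receive queue depth */
        uint32_t inline_size = 64;      /* max inline send size, in bytes */

        rsetsockopt(rsfd, SOL_RDMA, RDMA_SQSIZE, &sq_size, sizeof(sq_size));
        rsetsockopt(rsfd, SOL_RDMA, RDMA_RQSIZE, &rq_size, sizeof(rq_size));
        rsetsockopt(rsfd, SOL_RDMA, RDMA_INLINE, &inline_size, sizeof(inline_size));
}
#endif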
3544
3545 static void rs_convert_sa_path(struct ibv_sa_path_rec *sa_path,
3546                                struct ibv_path_data *path_data)
3547 {
3548         uint32_t fl_hop;
3549
3550         memset(path_data, 0, sizeof(*path_data));
3551         path_data->path.dgid = sa_path->dgid;
3552         path_data->path.sgid = sa_path->sgid;
3553         path_data->path.dlid = sa_path->dlid;
3554         path_data->path.slid = sa_path->slid;
3555         fl_hop = be32toh(sa_path->flow_label) << 8;
3556         path_data->path.flowlabel_hoplimit = htobe32(fl_hop | sa_path->hop_limit);
3557         path_data->path.tclass = sa_path->traffic_class;
3558         path_data->path.reversible_numpath = sa_path->reversible << 7 | 1;
3559         path_data->path.pkey = sa_path->pkey;
3560         path_data->path.qosclass_sl = htobe16(sa_path->sl);
3561         path_data->path.mtu = sa_path->mtu | 2 << 6;    /* selector 2: exactly this value */
3562         path_data->path.rate = sa_path->rate | 2 << 6;
3563         path_data->path.packetlifetime = sa_path->packet_life_time | 2 << 6;
3564         path_data->flags = sa_path->preference;
3565 }
3566
3567 int rgetsockopt(int socket, int level, int optname,
3568                 void *optval, socklen_t *optlen)
3569 {
3570         struct rsocket *rs;
3571         void *opt;
3572         struct ibv_sa_path_rec *path_rec;
3573         struct ibv_path_data path_data;
3574         socklen_t len;
3575         int ret = 0;
3576         int num_paths;
3577
3578         rs = idm_lookup(&idm, socket);
3579         if (!rs)
3580                 return ERR(EBADF);
3581         switch (level) {
3582         case SOL_SOCKET:
3583                 switch (optname) {
3584                 case SO_REUSEADDR:
3585                 case SO_KEEPALIVE:
3586                 case SO_OOBINLINE:
3587                         *((int *) optval) = !!(rs->so_opts & (1 << optname));
3588                         *optlen = sizeof(int);
3589                         break;
3590                 case SO_RCVBUF:
3591                         *((int *) optval) = rs->rbuf_size;
3592                         *optlen = sizeof(int);
3593                         break;
3594                 case SO_SNDBUF:
3595                         *((int *) optval) = rs->sbuf_size;
3596                         *optlen = sizeof(int);
3597                         break;
3598                 case SO_LINGER:
3599                         /* Value is inverted so the default so_opt bit of 0 means linger on */
3600                         ((struct linger *) optval)->l_onoff =
3601                                         !(rs->so_opts & (1 << optname));
3602                         ((struct linger *) optval)->l_linger = 0;
3603                         *optlen = sizeof(struct linger);
3604                         break;
3605                 case SO_ERROR:
3606                         *((int *) optval) = rs->err;
3607                         *optlen = sizeof(int);
3608                         rs->err = 0;
3609                         break;
3610                 default:
3611                         ret = ENOTSUP;
3612                         break;
3613                 }
3614                 break;
3615         case IPPROTO_TCP:
3616                 switch (optname) {
3617                 case TCP_KEEPCNT:
3618                 case TCP_KEEPINTVL:
3619                         *((int *) optval) = 1;   /* N/A */
                             *optlen = sizeof(int);
3620                         break;
3621                 case TCP_KEEPIDLE:
3622                         *((int *) optval) = (int) rs->keepalive_time;
3623                         *optlen = sizeof(int);
3624                         break;
3625                 case TCP_NODELAY:
3626                         *((int *) optval) = !!(rs->tcp_opts & (1 << optname));
3627                         *optlen = sizeof(int);
3628                         break;
3629                 case TCP_MAXSEG:
3630                         *((int *) optval) = (rs->cm_id && rs->cm_id->route.num_paths) ?
3631                                             1 << (7 + rs->cm_id->route.path_rec->mtu) :
3632                                             2048;
3633                         *optlen = sizeof(int);
3634                         break;
3635                 default:
3636                         ret = ENOTSUP;
3637                         break;
3638                 }
3639                 break;
3640         case IPPROTO_IPV6:
3641                 switch (optname) {
3642                 case IPV6_V6ONLY:
3643                         *((int *) optval) = !!(rs->ipv6_opts & (1 << optname));
3644                         *optlen = sizeof(int);
3645                         break;
3646                 default:
3647                         ret = ENOTSUP;
3648                         break;
3649                 }
3650                 break;
3651         case SOL_RDMA:
3652                 switch (optname) {
3653                 case RDMA_SQSIZE:
3654                         *((int *) optval) = rs->sq_size;
3655                         *optlen = sizeof(int);
3656                         break;
3657                 case RDMA_RQSIZE:
3658                         *((int *) optval) = rs->rq_size;
3659                         *optlen = sizeof(int);
3660                         break;
3661                 case RDMA_INLINE:
3662                         *((int *) optval) = rs->sq_inline;
3663                         *optlen = sizeof(int);
3664                         break;
3665                 case RDMA_IOMAPSIZE:
3666                         *((int *) optval) = rs->target_iomap_size;
3667                         *optlen = sizeof(int);
3668                         break;
3669                 case RDMA_ROUTE:
3670                         if (rs->optval) {
3671                                 if (*optlen < rs->optlen) {
3672                                         ret = EINVAL;
3673                                 } else {
3674                                         memcpy(optval, rs->optval, rs->optlen);
3675                                         *optlen = rs->optlen;
3676                                 }
3677                         } else {
3678                                 if (*optlen < sizeof(path_data)) {
3679                                         ret = EINVAL;
3680                                 } else {
3681                                         len = 0;
3682                                         opt = optval;
3683                                         path_rec = rs->cm_id->route.path_rec;
3684                                         num_paths = 0;
3685                                         while (len + sizeof(path_data) <= *optlen &&
3686                                                num_paths < rs->cm_id->route.num_paths) {
3687                                                 rs_convert_sa_path(path_rec, &path_data);
3688                                                 memcpy(opt, &path_data, sizeof(path_data));
3689                                                 len += sizeof(path_data);
3690                                                 opt += sizeof(path_data);
3691                                                 path_rec++;
3692                                                 num_paths++;
3693                                         }
3694                                         *optlen = len;
3695                                         ret = 0;
3696                                 }
3697                         }
3698                         break;
3699                 default:
3700                         ret = ENOTSUP;
3701                         break;
3702                 }
3703                 break;
3704         default:
3705                 ret = ENOTSUP;
3706                 break;
3707         }
3708
3709         return rdma_seterrno(ret);
3710 }
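
/*
 * Usage sketch (illustrative only, kept out of the build with #if 0):
 * checking the result of a non-blocking rconnect() the same way as with
 * regular sockets: wait for POLLOUT with rpoll(), then read SO_ERROR.
 * The descriptor "rsfd" and the 5-second timeout are hypothetical.
 */
#if 0
static int example_finish_nonblocking_connect(int rsfd)
{
        struct pollfd pfd;
        socklen_t len = sizeof(int);
        int err = 0;

        pfd.fd = rsfd;
        pfd.events = POLLOUT;
        pfd.revents = 0;
        if (rpoll(&pfd, 1, 5000) <= 0)
                return -1;              /* timeout or polling error */

        if (rgetsockopt(rsfd, SOL_SOCKET, SO_ERROR, &err, &len))
                return -1;
        return err;                     /* 0 on success, else the errno */
}
#endif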
3711
3712 int rfcntl(int socket, int cmd, ... /* arg */ )
3713 {
3714         struct rsocket *rs;
3715         va_list args;
3716         int param;
3717         int ret = 0;
3718
3719         rs = idm_lookup(&idm, socket);
3720         if (!rs)
3721                 return ERR(EBADF);
3722         va_start(args, cmd);
3723         switch (cmd) {
3724         case F_GETFL:
3725                 ret = rs->fd_flags;
3726                 break;
3727         case F_SETFL:
3728                 param = va_arg(args, int);
3729                 if ((rs->fd_flags & O_NONBLOCK) != (param & O_NONBLOCK))
3730                         ret = rs_set_nonblocking(rs, param & O_NONBLOCK);
3731
3732                 if (!ret)
3733                         rs->fd_flags = param;
3734                 break;
3735         default:
3736                 ret = ERR(ENOTSUP);
3737                 break;
3738         }
3739         va_end(args);
3740         return ret;
3741 }
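
/*
 * Usage sketch (illustrative only, kept out of the build with #if 0):
 * switching an rsocket to non-blocking mode with rfcntl(), mirroring the
 * standard fcntl() idiom.  The descriptor "rsfd" is hypothetical.
 */
#if 0
static int example_set_nonblocking(int rsfd)
{
        int flags;

        flags = rfcntl(rsfd, F_GETFL);
        if (flags < 0)
                return flags;
        return rfcntl(rsfd, F_SETFL, flags | O_NONBLOCK);
}
#endif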
3742
3743 static struct rs_iomap_mr *rs_get_iomap_mr(struct rsocket *rs)
3744 {
3745         int i;
3746
3747         if (!rs->remote_iomappings) {
3748                 rs->remote_iomappings = calloc(rs->remote_iomap.length,
3749                                                sizeof(*rs->remote_iomappings));
3750                 if (!rs->remote_iomappings)
3751                         return NULL;
3752
3753                 for (i = 0; i < rs->remote_iomap.length; i++)
3754                         rs->remote_iomappings[i].index = i;
3755         }
3756
3757         for (i = 0; i < rs->remote_iomap.length; i++) {
3758                 if (!rs->remote_iomappings[i].mr)
3759                         return &rs->remote_iomappings[i];
3760         }
3761         return NULL;
3762 }
3763
3764 /*
3765  * If an offset is given, we map to it.  If offset is -1, then we map the
3766  * offset to the address of buf.  We do not check for conflicts, which must
3767  * be fixed at some point.
3768  */
3769 off_t riomap(int socket, void *buf, size_t len, int prot, int flags, off_t offset)
3770 {
3771         struct rsocket *rs;
3772         struct rs_iomap_mr *iomr;
3773         int access = IBV_ACCESS_LOCAL_WRITE;
3774
3775         rs = idm_at(&idm, socket);
3776         if (!rs->cm_id->pd || (prot & ~(PROT_WRITE | PROT_NONE)))
3777                 return ERR(EINVAL);
3778
3779         fastlock_acquire(&rs->map_lock);
3780         if (prot & PROT_WRITE) {
3781                 iomr = rs_get_iomap_mr(rs);
3782                 access |= IBV_ACCESS_REMOTE_WRITE;
3783         } else {
3784                 iomr = calloc(1, sizeof(*iomr));
3785                 iomr->index = -1;
3786         }
3787         if (!iomr) {
3788                 offset = ERR(ENOMEM);
3789                 goto out;
3790         }
3791
3792         iomr->mr = ibv_reg_mr(rs->cm_id->pd, buf, len, access);
3793         if (!iomr->mr) {
3794                 if (iomr->index < 0)
3795                         free(iomr);
3796                 offset = -1;
3797                 goto out;
3798         }
3799
3800         if (offset == -1)
3801                 offset = (uintptr_t) buf;
3802         iomr->offset = offset;
3803         atomic_store(&iomr->refcnt, 1);
3804
3805         if (iomr->index >= 0) {
3806                 dlist_insert_tail(&iomr->entry, &rs->iomap_queue);
3807                 rs->iomap_pending = 1;
3808         } else {
3809                 dlist_insert_tail(&iomr->entry, &rs->iomap_list);
3810         }
3811 out:
3812         fastlock_release(&rs->map_lock);
3813         return offset;
3814 }
3815
3816 int riounmap(int socket, void *buf, size_t len)
3817 {
3818         struct rsocket *rs;
3819         struct rs_iomap_mr *iomr;
3820         dlist_entry *entry;
3821         int ret = 0;
3822
3823         rs = idm_at(&idm, socket);
3824         fastlock_acquire(&rs->map_lock);
3825
3826         for (entry = rs->iomap_list.next; entry != &rs->iomap_list;
3827              entry = entry->next) {
3828                 iomr = container_of(entry, struct rs_iomap_mr, entry);
3829                 if (iomr->mr->addr == buf && iomr->mr->length == len) {
3830                         rs_release_iomap_mr(iomr);
3831                         goto out;
3832                 }
3833         }
3834
3835         for (entry = rs->iomap_queue.next; entry != &rs->iomap_queue;
3836              entry = entry->next) {
3837                 iomr = container_of(entry, struct rs_iomap_mr, entry);
3838                 if (iomr->mr->addr == buf && iomr->mr->length == len) {
3839                         rs_release_iomap_mr(iomr);
3840                         goto out;
3841                 }
3842         }
3843         ret = ERR(EINVAL);
3844 out:
3845         fastlock_release(&rs->map_lock);
3846         return ret;
3847 }
3848
3849 static struct rs_iomap *rs_find_iomap(struct rsocket *rs, off_t offset)
3850 {
3851         int i;
3852
3853         for (i = 0; i < rs->target_iomap_size; i++) {
3854                 if (offset >= rs->target_iomap[i].offset &&
3855                     offset < rs->target_iomap[i].offset + rs->target_iomap[i].sge.length)
3856                         return &rs->target_iomap[i];
3857         }
3858         return NULL;
3859 }
3860
3861 size_t riowrite(int socket, const void *buf, size_t count, off_t offset, int flags)
3862 {
3863         struct rsocket *rs;
3864         struct rs_iomap *iom = NULL;
3865         struct ibv_sge sge;
3866         size_t left = count;
3867         uint32_t xfer_size, olen = RS_OLAP_START_SIZE;
3868         int ret = 0;
3869
3870         rs = idm_at(&idm, socket);
3871         fastlock_acquire(&rs->slock);
3872         if (rs->iomap_pending) {
3873                 ret = rs_send_iomaps(rs, flags);
3874                 if (ret)
3875                         goto out;
3876         }
3877         for (; left; left -= xfer_size, buf += xfer_size, offset += xfer_size) {
3878                 if (!iom || offset > iom->offset + iom->sge.length) {
3879                         iom = rs_find_iomap(rs, offset);
3880                         if (!iom)
3881                                 break;
3882                 }
3883
3884                 if (!rs_can_send(rs)) {
3885                         ret = rs_get_comp(rs, rs_nonblocking(rs, flags),
3886                                           rs_conn_can_send);
3887                         if (ret)
3888                                 break;
3889                         if (!(rs->state & rs_writable)) {
3890                                 ret = ERR(ECONNRESET);
3891                                 break;
3892                         }
3893                 }
3894
3895                 if (olen < left) {
3896                         xfer_size = olen;
3897                         if (olen < RS_MAX_TRANSFER)
3898                                 olen <<= 1;
3899                 } else {
3900                         xfer_size = left;
3901                 }
3902
3903                 if (xfer_size > rs->sbuf_bytes_avail)
3904                         xfer_size = rs->sbuf_bytes_avail;
3905                 if (xfer_size > iom->offset + iom->sge.length - offset)
3906                         xfer_size = iom->offset + iom->sge.length - offset;
3907
3908                 if (xfer_size <= rs->sq_inline) {
3909                         sge.addr = (uintptr_t) buf;
3910                         sge.length = xfer_size;
3911                         sge.lkey = 0;
3912                         ret = rs_write_direct(rs, iom, offset, &sge, 1,
3913                                               xfer_size, IBV_SEND_INLINE);
3914                 } else if (xfer_size <= rs_sbuf_left(rs)) {
3915                         memcpy((void *) (uintptr_t) rs->ssgl[0].addr, buf, xfer_size);
3916                         rs->ssgl[0].length = xfer_size;
3917                         ret = rs_write_direct(rs, iom, offset, rs->ssgl, 1, xfer_size, 0);
3918                         if (xfer_size < rs_sbuf_left(rs))
3919                                 rs->ssgl[0].addr += xfer_size;
3920                         else
3921                                 rs->ssgl[0].addr = (uintptr_t) rs->sbuf;
3922                 } else {
3923                         rs->ssgl[0].length = rs_sbuf_left(rs);
3924                         memcpy((void *) (uintptr_t) rs->ssgl[0].addr, buf,
3925                                 rs->ssgl[0].length);
3926                         rs->ssgl[1].length = xfer_size - rs->ssgl[0].length;
3927                         memcpy(rs->sbuf, buf + rs->ssgl[0].length, rs->ssgl[1].length);
3928                         ret = rs_write_direct(rs, iom, offset, rs->ssgl, 2, xfer_size, 0);
3929                         rs->ssgl[0].addr = (uintptr_t) rs->sbuf + rs->ssgl[1].length;
3930                 }
3931                 if (ret)
3932                         break;
3933         }
3934 out:
3935         fastlock_release(&rs->slock);
3936
3937         return (ret && left == count) ? ret : count - left;
3938 }
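
/*
 * Usage sketch (illustrative only, kept out of the build with #if 0): the two
 * halves of the direct-placement API.  One peer exposes a buffer with
 * riomap(); the other then places data into it with riowrite() at the agreed
 * offset.  This assumes both rsockets enabled I/O mapping (RDMA_IOMAPSIZE)
 * before connecting; the buffer, length, and offset 0 are hypothetical.
 */
#if 0
/* Target side: expose "dst" for remote placement at offset 0. */
static off_t example_expose_buffer(int rsfd, void *dst, size_t len)
{
        return riomap(rsfd, dst, len, PROT_WRITE, 0, 0);
}

/* Initiator side: write "src" directly into the peer's mapped buffer. */
static size_t example_direct_write(int rsfd, const void *src, size_t len)
{
        return riowrite(rsfd, src, len, 0, 0);
}
#endif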
3939
3940 /****************************************************************************
3941  * Service Processing Threads
3942  ****************************************************************************/
3943
3944 static int rs_svc_grow_sets(struct rs_svc *svc, int grow_size)
3945 {
3946         struct rsocket **rss;
3947         void *set, *contexts;
3948
3949         set = calloc(svc->size + grow_size, sizeof(*rss) + svc->context_size);
3950         if (!set)
3951                 return ENOMEM;
3952
3953         svc->size += grow_size;
3954         rss = set;
3955         contexts = set + sizeof(*rss) * svc->size;
3956         if (svc->cnt) {
3957                 memcpy(rss, svc->rss, sizeof(*rss) * (svc->cnt + 1));
3958                 memcpy(contexts, svc->contexts, svc->context_size * (svc->cnt + 1));
3959         }
3960
3961         free(svc->rss);
3962         svc->rss = rss;
3963         svc->contexts = contexts;
3964         return 0;
3965 }
3966
3967 /*
3968  * Index 0 is reserved for the service's communication socket.
3969  */
3970 static int rs_svc_add_rs(struct rs_svc *svc, struct rsocket *rs)
3971 {
3972         int ret;
3973
3974         if (svc->cnt >= svc->size - 1) {
3975                 ret = rs_svc_grow_sets(svc, 4);
3976                 if (ret)
3977                         return ret;
3978         }
3979
3980         svc->rss[++svc->cnt] = rs;
3981         return 0;
3982 }
3983
3984 static int rs_svc_index(struct rs_svc *svc, struct rsocket *rs)
3985 {
3986         int i;
3987
3988         for (i = 1; i <= svc->cnt; i++) {
3989                 if (svc->rss[i] == rs)
3990                         return i;
3991         }
3992         return -1;
3993 }
3994
3995 static int rs_svc_rm_rs(struct rs_svc *svc, struct rsocket *rs)
3996 {
3997         int i;
3998
3999         if ((i = rs_svc_index(svc, rs)) >= 0) {
4000                 svc->rss[i] = svc->rss[svc->cnt];
4001                 memcpy(svc->contexts + i * svc->context_size,
4002                        svc->contexts + svc->cnt * svc->context_size,
4003                        svc->context_size);
4004                 svc->cnt--;
4005                 return 0;
4006         }
4007         return EBADF;
4008 }
4009
4010 static void udp_svc_process_sock(struct rs_svc *svc)
4011 {
4012         struct rs_svc_msg msg;
4013
4014         read_all(svc->sock[1], &msg, sizeof msg);
4015         switch (msg.cmd) {
4016         case RS_SVC_ADD_DGRAM:
4017                 msg.status = rs_svc_add_rs(svc, msg.rs);
4018                 if (!msg.status) {
4019                         msg.rs->opts |= RS_OPT_SVC_ACTIVE;
4020                         udp_svc_fds = svc->contexts;
4021                         udp_svc_fds[svc->cnt].fd = msg.rs->udp_sock;
4022                         udp_svc_fds[svc->cnt].events = POLLIN;
4023                         udp_svc_fds[svc->cnt].revents = 0;
4024                 }
4025                 break;
4026         case RS_SVC_REM_DGRAM:
4027                 msg.status = rs_svc_rm_rs(svc, msg.rs);
4028                 if (!msg.status)
4029                         msg.rs->opts &= ~RS_OPT_SVC_ACTIVE;
4030                 break;
4031         case RS_SVC_NOOP:
4032                 msg.status = 0;
4033                 break;
4034         default:
4035                 break;
4036         }
4037
4038         write_all(svc->sock[1], &msg, sizeof msg);
4039 }
4040
4041 static uint8_t udp_svc_sgid_index(struct ds_dest *dest, union ibv_gid *sgid)
4042 {
4043         union ibv_gid gid;
4044         int i;
4045
4046         for (i = 0; i < 16; i++) {
4047                 ibv_query_gid(dest->qp->cm_id->verbs, dest->qp->cm_id->port_num,
4048                               i, &gid);
4049                 if (!memcmp(sgid, &gid, sizeof gid))
4050                         return i;
4051         }
4052         return 0;
4053 }
4054
4055 static uint8_t udp_svc_path_bits(struct ds_dest *dest)
4056 {
4057         struct ibv_port_attr attr;
4058
4059         if (!ibv_query_port(dest->qp->cm_id->verbs, dest->qp->cm_id->port_num, &attr))
4060                 return (uint8_t) ((1 << attr.lmc) - 1);
4061         return 0x7f;
4062 }
4063
4064 static void udp_svc_create_ah(struct rsocket *rs, struct ds_dest *dest, uint32_t qpn)
4065 {
4066         union socket_addr saddr;
4067         struct rdma_cm_id *id;
4068         struct ibv_ah_attr attr;
4069         int ret;
4070
4071         if (dest->ah) {
4072                 fastlock_acquire(&rs->slock);
4073                 ibv_destroy_ah(dest->ah);
4074                 dest->ah = NULL;
4075                 fastlock_release(&rs->slock);
4076         }
4077
4078         ret = rdma_create_id(NULL, &id, NULL, dest->qp->cm_id->ps);
4079         if (ret)
4080                 return;
4081
4082         memcpy(&saddr, rdma_get_local_addr(dest->qp->cm_id),
4083                ucma_addrlen(rdma_get_local_addr(dest->qp->cm_id)));
4084         if (saddr.sa.sa_family == AF_INET)
4085                 saddr.sin.sin_port = 0;
4086         else
4087                 saddr.sin6.sin6_port = 0;
4088         ret = rdma_resolve_addr(id, &saddr.sa, &dest->addr.sa, 2000);
4089         if (ret)
4090                 goto out;
4091
4092         ret = rdma_resolve_route(id, 2000);
4093         if (ret)
4094                 goto out;
4095
4096         memset(&attr, 0, sizeof attr);
4097         if (id->route.path_rec->hop_limit > 1) {
4098                 attr.is_global = 1;
4099                 attr.grh.dgid = id->route.path_rec->dgid;
4100                 attr.grh.flow_label = be32toh(id->route.path_rec->flow_label);
4101                 attr.grh.sgid_index = udp_svc_sgid_index(dest, &id->route.path_rec->sgid);
4102                 attr.grh.hop_limit = id->route.path_rec->hop_limit;
4103                 attr.grh.traffic_class = id->route.path_rec->traffic_class;
4104         }
4105         attr.dlid = be16toh(id->route.path_rec->dlid);
4106         attr.sl = id->route.path_rec->sl;
4107         attr.src_path_bits = be16toh(id->route.path_rec->slid) & udp_svc_path_bits(dest);
4108         attr.static_rate = id->route.path_rec->rate;
4109         attr.port_num = id->port_num;
4110
4111         fastlock_acquire(&rs->slock);
4112         dest->qpn = qpn;
4113         dest->ah = ibv_create_ah(dest->qp->cm_id->pd, &attr);
4114         fastlock_release(&rs->slock);
4115 out:
4116         rdma_destroy_id(id);
4117 }
4118
4119 static int udp_svc_valid_udp_hdr(struct ds_udp_header *udp_hdr,
4120                                  union socket_addr *addr)
4121 {
4122         return (udp_hdr->tag == htobe32(DS_UDP_TAG)) &&
4123                 ((udp_hdr->version == 4 && addr->sa.sa_family == AF_INET &&
4124                   udp_hdr->length == DS_UDP_IPV4_HDR_LEN) ||
4125                  (udp_hdr->version == 6 && addr->sa.sa_family == AF_INET6 &&
4126                   udp_hdr->length == DS_UDP_IPV6_HDR_LEN));
4127 }
4128
4129 static void udp_svc_forward(struct rsocket *rs, void *buf, size_t len,
4130                             union socket_addr *src)
4131 {
4132         struct ds_header hdr;
4133         struct ds_smsg *msg;
4134         struct ibv_sge sge;
4135         uint64_t offset;
4136
4137         if (!ds_can_send(rs)) {
4138                 if (ds_get_comp(rs, 0, ds_can_send))
4139                         return;
4140         }
4141
4142         msg = rs->smsg_free;
4143         rs->smsg_free = msg->next;
4144         rs->sqe_avail--;
4145
4146         ds_format_hdr(&hdr, src);
4147         memcpy((void *) msg, &hdr, hdr.length);
4148         memcpy((void *) msg + hdr.length, buf, len);
4149         sge.addr = (uintptr_t) msg;
4150         sge.length = hdr.length + len;
4151         sge.lkey = rs->conn_dest->qp->smr->lkey;
4152         offset = (uint8_t *) msg - rs->sbuf;
4153
4154         ds_post_send(rs, &sge, offset);
4155 }
4156
4157 static void udp_svc_process_rs(struct rsocket *rs)
4158 {
4159         static uint8_t buf[RS_SNDLOWAT];
4160         struct ds_dest *dest, *cur_dest;
4161         struct ds_udp_header *udp_hdr;
4162         union socket_addr addr;
4163         socklen_t addrlen = sizeof addr;
4164         int len, ret;
4165         uint32_t qpn;
4166
4167         ret = recvfrom(rs->udp_sock, buf, sizeof buf, 0, &addr.sa, &addrlen);
4168         if (ret < DS_UDP_IPV4_HDR_LEN)
4169                 return;
4170
4171         udp_hdr = (struct ds_udp_header *) buf;
4172         if (!udp_svc_valid_udp_hdr(udp_hdr, &addr))
4173                 return;
4174
4175         len = ret - udp_hdr->length;
4176         qpn = be32toh(udp_hdr->qpn) & 0xFFFFFF;
4177
4178         udp_hdr->tag = (__force __be32)be32toh(udp_hdr->tag);
4179         udp_hdr->qpn = (__force __be32)qpn;
4180
4181         ret = ds_get_dest(rs, &addr.sa, addrlen, &dest);
4182         if (ret)
4183                 return;
4184
4185         if (udp_hdr->op == RS_OP_DATA) {
4186                 fastlock_acquire(&rs->slock);
4187                 cur_dest = rs->conn_dest;
4188                 rs->conn_dest = dest;
4189                 ds_send_udp(rs, NULL, 0, 0, RS_OP_CTRL);
4190                 rs->conn_dest = cur_dest;
4191                 fastlock_release(&rs->slock);
4192         }
4193
4194         if (!dest->ah || (dest->qpn != qpn))
4195                 udp_svc_create_ah(rs, dest, qpn);
4196
4197         /* TODO: handle the case where the dest's local IP address doesn't match the UDP IP */
4198         if (udp_hdr->op == RS_OP_DATA) {
4199                 fastlock_acquire(&rs->slock);
4200                 cur_dest = rs->conn_dest;
4201                 rs->conn_dest = &dest->qp->dest;
4202                 udp_svc_forward(rs, buf + udp_hdr->length, len, &addr);
4203                 rs->conn_dest = cur_dest;
4204                 fastlock_release(&rs->slock);
4205         }
4206 }
4207
4208 static void *udp_svc_run(void *arg)
4209 {
4210         struct rs_svc *svc = arg;
4211         struct rs_svc_msg msg;
4212         int i, ret;
4213
4214         ret = rs_svc_grow_sets(svc, 4);
4215         if (ret) {
4216                 msg.status = ret;
4217                 write_all(svc->sock[1], &msg, sizeof msg);
4218                 return (void *) (uintptr_t) ret;
4219         }
4220
4221         udp_svc_fds = svc->contexts;
4222         udp_svc_fds[0].fd = svc->sock[1];
4223         udp_svc_fds[0].events = POLLIN;
4224         do {
4225                 for (i = 0; i <= svc->cnt; i++)
4226                         udp_svc_fds[i].revents = 0;
4227
4228                 poll(udp_svc_fds, svc->cnt + 1, -1);
4229                 if (udp_svc_fds[0].revents)
4230                         udp_svc_process_sock(svc);
4231
4232                 for (i = 1; i <= svc->cnt; i++) {
4233                         if (udp_svc_fds[i].revents)
4234                                 udp_svc_process_rs(svc->rss[i]);
4235                 }
4236         } while (svc->cnt >= 1);
4237
4238         return NULL;
4239 }
4240
4241 static uint32_t rs_get_time(void)
4242 {
4243         struct timeval now;
4244
4245         memset(&now, 0, sizeof now);
4246         gettimeofday(&now, NULL);
4247         return (uint32_t) now.tv_sec;
4248 }
4249
4250 static void tcp_svc_process_sock(struct rs_svc *svc)
4251 {
4252         struct rs_svc_msg msg;
4253         int i;
4254
4255         read_all(svc->sock[1], &msg, sizeof msg);
4256         switch (msg.cmd) {
4257         case RS_SVC_ADD_KEEPALIVE:
4258                 msg.status = rs_svc_add_rs(svc, msg.rs);
4259                 if (!msg.status) {
4260                         msg.rs->opts |= RS_OPT_SVC_ACTIVE;
4261                         tcp_svc_timeouts = svc->contexts;
4262                         tcp_svc_timeouts[svc->cnt] = rs_get_time() +
4263                                                      msg.rs->keepalive_time;
4264                 }
4265                 break;
4266         case RS_SVC_REM_KEEPALIVE:
4267                 msg.status = rs_svc_rm_rs(svc, msg.rs);
4268                 if (!msg.status)
4269                         msg.rs->opts &= ~RS_OPT_SVC_ACTIVE;
4270                 break;
4271         case RS_SVC_MOD_KEEPALIVE:
4272                 i = rs_svc_index(svc, msg.rs);
4273                 if (i >= 0) {
4274                         tcp_svc_timeouts[i] = rs_get_time() + msg.rs->keepalive_time;
4275                         msg.status = 0;
4276                 } else {
4277                         msg.status = EBADF;
4278                 }
4279                 break;
4280         case RS_SVC_NOOP:
4281                 msg.status = 0;
4282                 break;
4283         default:
4284                 break;
4285         }
4286         write_all(svc->sock[1], &msg, sizeof msg);
4287 }
4288
4289 /*
4290  * Send a 0-byte RDMA write with immediate data as the keep-alive message.
4291  * This avoids the need for the receiving side to do any acknowledgment.
4292  */
4293 static void tcp_svc_send_keepalive(struct rsocket *rs)
4294 {
4295         fastlock_acquire(&rs->cq_lock);
4296         if (rs_ctrl_avail(rs) && (rs->state & rs_connected)) {
4297                 rs->ctrl_seqno++;
4298                 rs_post_write(rs, NULL, 0, rs_msg_set(RS_OP_CTRL, RS_CTRL_KEEPALIVE),
4299                               0, (uintptr_t) NULL, (uintptr_t) NULL);
4300         }
4301         fastlock_release(&rs->cq_lock);
4302 }
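
/*
 * Usage sketch (illustrative only, kept out of the build with #if 0): how an
 * application might enable the keep-alive service on a connected rsocket,
 * using the standard socket options handled by rsetsockopt() above.  The
 * 60-second idle time is an arbitrary example.
 */
#if 0
static void example_enable_keepalive(int rsfd)
{
        int on = 1;
        int idle = 60;          /* seconds between keep-alive probes */

        rsetsockopt(rsfd, SOL_SOCKET, SO_KEEPALIVE, &on, sizeof(on));
        rsetsockopt(rsfd, IPPROTO_TCP, TCP_KEEPIDLE, &idle, sizeof(idle));
}
#endif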
4303
4304 static void *tcp_svc_run(void *arg)
4305 {
4306         struct rs_svc *svc = arg;
4307         struct rs_svc_msg msg;
4308         struct pollfd fds;
4309         uint32_t now, next_timeout;
4310         int i, ret, timeout;
4311
4312         ret = rs_svc_grow_sets(svc, 16);
4313         if (ret) {
4314                 msg.status = ret;
4315                 write_all(svc->sock[1], &msg, sizeof msg);
4316                 return (void *) (uintptr_t) ret;
4317         }
4318
4319         tcp_svc_timeouts = svc->contexts;
4320         fds.fd = svc->sock[1];
4321         fds.events = POLLIN;
4322         timeout = -1;
4323         do {
4324                 poll(&fds, 1, timeout * 1000);
4325                 if (fds.revents)
4326                         tcp_svc_process_sock(svc);
4327
4328                 now = rs_get_time();
4329                 next_timeout = ~0;
4330                 for (i = 1; i <= svc->cnt; i++) {
4331                         if (tcp_svc_timeouts[i] <= now) {
4332                                 tcp_svc_send_keepalive(svc->rss[i]);
4333                                 tcp_svc_timeouts[i] =
4334                                         now + svc->rss[i]->keepalive_time;
4335                         }
4336                         if (tcp_svc_timeouts[i] < next_timeout)
4337                                 next_timeout = tcp_svc_timeouts[i];
4338                 }
4339                 timeout = (int) (next_timeout - now);
4340         } while (svc->cnt >= 1);
4341
4342         return NULL;
4343 }