2 * Copyright (c) 2004, 2005 Topspin Communications. All rights reserved.
3 * Copyright (c) 2006 Cisco Systems, Inc. All rights reserved.
5 * This software is available to you under a choice of one of two
6 * licenses. You may choose to be licensed under the terms of the GNU
7 * General Public License (GPL) Version 2, available from the file
8 * COPYING in the main directory of this source tree, or the
9 * OpenIB.org BSD license below:
11 * Redistribution and use in source and binary forms, with or
12 * without modification, are permitted provided that the following
15 * - Redistributions of source code must retain the above
16 * copyright notice, this list of conditions and the following
19 * - Redistributions in binary form must reproduce the above
20 * copyright notice, this list of conditions and the following
21 * disclaimer in the documentation and/or other materials
22 * provided with the distribution.
24 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
25 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
26 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
27 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
28 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
29 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
30 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
54 struct ibv_mem_node *parent;
55 struct ibv_mem_node *left, *right;
/* Root of the red-black tree of tracked memory ranges (see __mm_add/__mm_remove). */
60 static struct ibv_mem_node *mm_root;
/* Serializes all access to the range tree; taken in ibv_madvise_range(). */
61 static pthread_mutex_t mm_mutex = PTHREAD_MUTEX_INITIALIZER;
/* Non-zero when RDMAV_HUGEPAGES_SAFE is set in the environment (ibv_fork_init);
 * enables per-range page-size lookups via /proc/<pid>/smaps. */
63 static int huge_page_enabled;
/*
 * Scan an already-open /proc/<pid>/smaps stream for the "KernelPageSize:"
 * field of the current mapping.  Defaults to the global page_size when the
 * field is absent.  Per the comment below, the kernel prints this value in
 * kB, so it is presumably converted to bytes before returning — the
 * conversion/return lines are elided in this view (NOTE(review): confirm).
 */
66 static unsigned long smaps_page_size(FILE *file)
69 unsigned long size = page_size;
72 while (fgets(buf, sizeof(buf), file) != NULL) {
/* Skip every line of the record until the KernelPageSize: field. */
73 if (!strstr(buf, "KernelPageSize:"))
/* Field format is "KernelPageSize:  <n> kB"; skip the label, read the number. */
76 n = sscanf(buf, "%*s %lu", &size);
80 /* page size is printed in Kb */
/*
 * Determine the page size backing the mapping that contains `base` by
 * parsing this process's /proc/<pid>/smaps.  Falls back to the global
 * page_size if the lookup fails.  Needed because huge-page mappings must
 * be madvise()'d with huge-page granularity.
 * NOTE(review): error-handling and cleanup lines are elided in this view.
 */
89 static unsigned long get_page_size(void *base)
/* Default result if the smaps lookup does not find a matching range. */
91 unsigned long ret = page_size;
97 snprintf(buf, sizeof(buf), "/proc/%d/smaps", pid);
99 file = fopen(buf, "r" STREAM_CLOEXEC);
103 while (fgets(buf, sizeof(buf), file) != NULL) {
105 uintptr_t range_start, range_end;
/* Each smaps record starts with "<start>-<end> perms ..." in hex. */
107 n = sscanf(buf, "%" SCNxPTR "-%" SCNxPTR, &range_start, &range_end);
/* Found the mapping containing base; its KernelPageSize field follows
 * later in the same record, so continue parsing from this position. */
112 if ((uintptr_t) base >= range_start && (uintptr_t) base < range_end) {
113 ret = smaps_page_size(file);
/*
 * One-time initialization of fork() support: record the system page size,
 * probe that madvise(MADV_DONTFORK)/madvise(MADV_DOFORK) actually work on
 * this kernel, and allocate the sentinel root node of the range tree
 * covering the entire address space.
 * NOTE(review): error-return paths and some assignments (e.g. mm_root
 * start/refcnt) are elided in this view.
 */
124 int ibv_fork_init(void)
126 void *tmp, *tmp_aligned;
/* RDMAV_HUGEPAGES_SAFE opts in to per-range page-size detection via smaps. */
130 if (getenv("RDMAV_HUGEPAGES_SAFE"))
131 huge_page_enabled = 1;
139 page_size = sysconf(_SC_PAGESIZE);
/* Throwaway page-aligned allocation used only to probe madvise() support. */
143 if (posix_memalign(&tmp, page_size, page_size))
146 if (huge_page_enabled) {
147 size = get_page_size(tmp);
/* Align the probe buffer down to the (possibly huge) page boundary,
 * since madvise() operates on whole pages. */
148 tmp_aligned = (void *) ((uintptr_t) tmp & ~(size - 1));
/* Both advices must succeed for fork support to be usable. */
154 ret = madvise(tmp_aligned, size, MADV_DONTFORK) ||
155 madvise(tmp_aligned, size, MADV_DOFORK);
/* Sentinel root: a single black node whose range ends at UINTPTR_MAX,
 * so every later lookup finds some containing node to split. */
162 mm_root = malloc(sizeof *mm_root);
166 mm_root->parent = NULL;
167 mm_root->left = NULL;
168 mm_root->right = NULL;
169 mm_root->color = IBV_BLACK;
171 mm_root->end = UINTPTR_MAX;
/*
 * In-order predecessor of `node` in the range tree.  Visible here: when
 * there is no left subtree (that branch is elided in this view), climb
 * while `node` is its parent's left child; the first ancestor reached
 * from a right child is the predecessor.
 */
177 static struct ibv_mem_node *__mm_prev(struct ibv_mem_node *node)
184 while (node->parent && node == node->parent->left)
/*
 * In-order successor of `node` — mirror image of __mm_prev().  When there
 * is no right subtree (branch elided in this view), climb while `node` is
 * its parent's right child.
 */
193 static struct ibv_mem_node *__mm_next(struct ibv_mem_node *node)
200 while (node->parent && node == node->parent->right)
/*
 * Standard red-black right rotation about `node`: its left child (tmp)
 * takes node's place, and tmp's former right subtree becomes node's left
 * subtree.  NOTE(review): the tmp assignment, the root-update branch, and
 * the final relinking lines are elided in this view.
 */
209 static void __mm_rotate_right(struct ibv_mem_node *node)
211 struct ibv_mem_node *tmp;
/* Move tmp's right subtree under node's left slot. */
215 node->left = tmp->right;
217 node->left->parent = node;
/* Re-point the parent's child link (left or right) at tmp. */
220 if (node->parent->right == node)
221 node->parent->right = tmp;
223 node->parent->left = tmp;
227 tmp->parent = node->parent;
/*
 * Standard red-black left rotation about `node` — mirror of
 * __mm_rotate_right(): node's right child (tmp) takes its place.
 * NOTE(review): the tmp assignment, root-update branch, and final
 * relinking lines are elided in this view.
 */
233 static void __mm_rotate_left(struct ibv_mem_node *node)
235 struct ibv_mem_node *tmp;
/* Move tmp's left subtree under node's right slot. */
239 node->right = tmp->left;
241 node->right->parent = node;
/* Re-point the parent's child link (left or right) at tmp. */
244 if (node->parent->right == node)
245 node->parent->right = tmp;
247 node->parent->left = tmp;
251 tmp->parent = node->parent;
/*
 * Debug sanity check of the red-black invariants for the subtree rooted
 * at `node`.  Recursively verifies both children and (visible below)
 * rejects a red node with a red child.  The height comparison and return
 * statements are elided in this view; presumably it returns a black
 * height, 0 on violation — TODO confirm against full source.
 */
258 static int verify(struct ibv_mem_node *node)
265 hl = verify(node->left);
/* BUG FIX: previously verified node->left twice, so the right subtree was
 * never checked and hl == hr trivially.  Verify the right subtree. */
266 hr = verify(node->right);
273 if (node->color == IBV_RED) {
/* Red-black invariant: a red node may not have a red child. */
274 if (node->left && node->left->color != IBV_BLACK)
276 if (node->right && node->right->color != IBV_BLACK)
/*
 * Restore red-black invariants after inserting the red node `node`
 * (textbook RB insert fixup): while a red-red violation exists with the
 * parent, either recolor (red uncle) and move the violation up to the
 * grandparent, or rotate into the correct configuration and recolor.
 * NOTE(review): several statements (gp recolor, loop continuation) are
 * elided in this view.
 */
285 static void __mm_add_rebalance(struct ibv_mem_node *node)
287 struct ibv_mem_node *parent, *gp, *uncle;
289 while (node->parent && node->parent->color == IBV_RED) {
290 parent = node->parent;
291 gp = node->parent->parent;
/* Symmetric cases depending on which side of the grandparent the
 * parent hangs from; the uncle is the other child. */
293 if (parent == gp->left) {
/* Case 1: red uncle — recolor parent and uncle black, push the
 * violation up to the grandparent. */
296 if (uncle && uncle->color == IBV_RED) {
297 parent->color = IBV_BLACK;
298 uncle->color = IBV_BLACK;
/* Case 2: node is an inner child — rotate to reduce to case 3. */
303 if (node == parent->right) {
304 __mm_rotate_left(parent);
306 parent = node->parent;
/* Case 3: outer child — recolor and rotate the grandparent. */
309 parent->color = IBV_BLACK;
312 __mm_rotate_right(gp);
/* Mirror image of the above for parent == gp->right. */
317 if (uncle && uncle->color == IBV_RED) {
318 parent->color = IBV_BLACK;
319 uncle->color = IBV_BLACK;
324 if (node == parent->left) {
325 __mm_rotate_right(parent);
327 parent = node->parent;
330 parent->color = IBV_BLACK;
333 __mm_rotate_left(gp);
/* The root is always black. */
338 mm_root->color = IBV_BLACK;
/*
 * Insert `new` into the range tree as an ordinary BST insert keyed on
 * ->start, then color it red and rebalance.  NOTE(review): the descent
 * loop body and child-link assignments are elided in this view.
 */
341 static void __mm_add(struct ibv_mem_node *new)
343 struct ibv_mem_node *node, *parent = NULL;
/* Descend: go right when the current node's start is smaller. */
348 if (node->start < new->start)
/* Attach as the appropriate child of the leaf parent found above. */
354 if (parent->start < new->start)
359 new->parent = parent;
/* New nodes start red; fixup restores the invariants. */
363 new->color = IBV_RED;
364 __mm_add_rebalance(new);
/*
 * Unlink `node` from the range tree and rebalance — textbook red-black
 * deletion.  NOTE(review): many statements (tmp selection, else branches,
 * loop continuations, root updates) are elided in this view; comments
 * below describe only the visible structure.
 */
367 static void __mm_remove(struct ibv_mem_node *node)
369 struct ibv_mem_node *child, *parent, *sib, *tmp;
/* Two children: splice a boundary node `tmp` of one subtree into node's
 * position, taking over node's color; remember tmp's original color for
 * the fixup decision below. */
372 if (node->left && node->right) {
377 nodecol = tmp->color;
379 tmp->color = node->color;
/* tmp was deeper in the subtree: detach it from its parent first. */
381 if (tmp->parent != node) {
382 parent = tmp->parent;
383 parent->right = tmp->left;
385 tmp->left->parent = parent;
387 tmp->left = node->left;
388 node->left->parent = tmp;
/* Adopt node's other subtree and its position under node's parent. */
392 tmp->right = node->right;
393 node->right->parent = tmp;
395 tmp->parent = node->parent;
397 if (node->parent->left == node)
398 node->parent->left = tmp;
400 node->parent->right = tmp;
/* Zero or one child: splice node's only child (possibly NULL) into its
 * place directly. */
404 nodecol = node->color;
406 child = node->left ? node->left : node->right;
407 parent = node->parent;
410 child->parent = parent;
412 if (parent->left == node)
413 parent->left = child;
415 parent->right = child;
/* Removing a red node cannot violate any invariant — no fixup needed. */
422 if (nodecol == IBV_RED)
/* Delete fixup: push the "double black" up the tree until it reaches a
 * red node or the root, using the sibling's color configuration. */
425 while ((!child || child->color == IBV_BLACK) && child != mm_root) {
426 if (parent->left == child) {
/* Red sibling: rotate to get a black sibling, then re-examine. */
429 if (sib->color == IBV_RED) {
430 parent->color = IBV_RED;
431 sib->color = IBV_BLACK;
432 __mm_rotate_left(parent);
/* Black sibling with two black children: recolor and move up. */
436 if ((!sib->left || sib->left->color == IBV_BLACK) &&
437 (!sib->right || sib->right->color == IBV_BLACK)) {
438 sib->color = IBV_RED;
440 parent = child->parent;
/* Sibling's far child black: rotate sibling to make it red. */
442 if (!sib->right || sib->right->color == IBV_BLACK) {
444 sib->left->color = IBV_BLACK;
445 sib->color = IBV_RED;
446 __mm_rotate_right(sib);
/* Sibling's far child red: recolor and rotate parent — done. */
450 sib->color = parent->color;
451 parent->color = IBV_BLACK;
453 sib->right->color = IBV_BLACK;
454 __mm_rotate_left(parent);
/* Mirror image of the above for child being the right child. */
461 if (sib->color == IBV_RED) {
462 parent->color = IBV_RED;
463 sib->color = IBV_BLACK;
464 __mm_rotate_right(parent);
468 if ((!sib->left || sib->left->color == IBV_BLACK) &&
469 (!sib->right || sib->right->color == IBV_BLACK)) {
470 sib->color = IBV_RED;
472 parent = child->parent;
474 if (!sib->left || sib->left->color == IBV_BLACK) {
476 sib->right->color = IBV_BLACK;
477 sib->color = IBV_RED;
478 __mm_rotate_left(sib);
482 sib->color = parent->color;
483 parent->color = IBV_BLACK;
485 sib->left->color = IBV_BLACK;
486 __mm_rotate_right(parent);
/* The node where the fixup stopped becomes black. */
494 child->color = IBV_BLACK;
/*
 * Find the tree node whose range contains the address `start`.  Since the
 * sentinel root spans the whole address space, a containing node always
 * exists.  NOTE(review): the `end` parameter is not used by any visible
 * line here — confirm against full source.
 */
497 static struct ibv_mem_node *__mm_find_start(uintptr_t start, uintptr_t end)
499 struct ibv_mem_node *node = mm_root;
/* Node ranges are inclusive [start, end]. */
502 if (node->start <= start && node->end >= start)
505 if (node->start < start)
/*
 * Merge the adjacent range `node` into its predecessor `prev`: extend
 * prev to cover node's end and take node's refcount.  Presumably node is
 * then removed from the tree and freed, and prev returned — those lines
 * are elided in this view (NOTE(review): confirm).
 */
514 static struct ibv_mem_node *merge_ranges(struct ibv_mem_node *node,
515 struct ibv_mem_node *prev)
517 prev->end = node->end;
518 prev->refcnt = node->refcnt;
/*
 * Split `node` at `cut_line`: node keeps [start, cut_line - 1] and a new
 * node takes [cut_line, end] with the same refcount.  Returns the new
 * node (NULL on allocation failure, per the NULL initialization below).
 * NOTE(review): the malloc failure check and __mm_add of the new node are
 * elided in this view.
 */
524 static struct ibv_mem_node *split_range(struct ibv_mem_node *node,
527 struct ibv_mem_node *new_node = NULL;
529 new_node = malloc(sizeof *new_node);
532 new_node->start = cut_line;
533 new_node->end = node->end;
534 new_node->refcnt = node->refcnt;
/* Ranges are inclusive, so the low half ends one byte before the cut. */
535 node->end = cut_line - 1;
/*
 * Locate (and shape) the first tree node for the range [start, end]:
 * split the containing node so a node boundary falls exactly at `start`,
 * then merge with the previous node if, after applying `inc` (+1 for
 * DONTFORK, -1 for DOFORK), their refcounts would be equal.
 * NOTE(review): split_range failure handling is elided in this view.
 */
541 static struct ibv_mem_node *get_start_node(uintptr_t start, uintptr_t end,
544 struct ibv_mem_node *node, *tmp = NULL;
546 node = __mm_find_start(start, end);
/* Ensure a node boundary at `start`. */
547 if (node->start < start)
548 node = split_range(node, start);
/* Pre-merge with the predecessor whose refcount already matches the
 * value this node is about to have. */
550 tmp = __mm_prev(node);
551 if (tmp && tmp->refcnt == node->refcnt + inc)
552 node = merge_ranges(node, tmp);
/*
 * This function is called if madvise() fails to undo merging/splitting
 * operations performed on the node.
 */
561 static struct ibv_mem_node *undo_node(struct ibv_mem_node *node,
562 uintptr_t start, int inc)
564 struct ibv_mem_node *tmp = NULL;
/*
 * This condition can be true only if we merged this
 * node with the previous one, so we need to split them.
 */
570 if (start > node->start) {
571 tmp = split_range(node, start);
/* Re-merge with whichever neighbors now have an equal refcount, restoring
 * the tree's invariant that adjacent nodes differ in refcount.
 * NOTE(review): the refcnt -= inc adjustment is elided in this view. */
579 tmp = __mm_prev(node);
580 if (tmp && tmp->refcnt == node->refcnt)
581 node = merge_ranges(node, tmp);
583 tmp = __mm_next(node);
584 if (tmp && tmp->refcnt == node->refcnt)
585 node = merge_ranges(tmp, node);
/*
 * Core worker for ibv_dontfork_range()/ibv_dofork_range(): apply `advice`
 * (MADV_DONTFORK or MADV_DOFORK) to the page-aligned range covering
 * [base, base + size), reference-counting overlaps in the range tree so
 * a page shared by several registrations is madvise()'d only on the first
 * DONTFORK and last DOFORK.  On madvise() failure, previously-changed
 * nodes are rolled back with the opposite advice.
 * NOTE(review): numerous statements (refcnt updates, error returns, loop
 * continuations) are elided in this view.
 */
590 static int ibv_madvise_range(void *base, size_t size, int advice)
592 uintptr_t start, end;
593 struct ibv_mem_node *node, *tmp;
595 int rolling_back = 0;
597 unsigned long range_page_size;
/* Huge-page-safe mode must align to the mapping's real page size. */
602 if (huge_page_enabled)
603 range_page_size = get_page_size(base);
605 range_page_size = page_size;
/* Round start down and end up to page boundaries; ranges are inclusive. */
607 start = (uintptr_t) base & ~(range_page_size - 1);
608 end = ((uintptr_t) (base + size + range_page_size - 1) &
609 ~(range_page_size - 1)) - 1;
611 pthread_mutex_lock(&mm_mutex);
/* Refcount delta: +1 when marking DONTFORK, -1 when undoing it. */
613 inc = advice == MADV_DONTFORK ? 1 : -1;
615 node = get_start_node(start, end, inc);
/* Walk the nodes overlapping [start, end]. */
621 while (node && node->start <= end) {
/* Ensure a node boundary at `end` so we never touch bytes past it. */
622 if (node->end > end) {
623 if (!split_range(node, end + 1)) {
/* Only madvise() on a 0 -> 1 or 1 -> 0 refcount transition; other
 * transitions just adjust the count. */
629 if ((inc == -1 && node->refcnt == 1) ||
630 (inc == 1 && node->refcnt == 0)) {
/*
 * If this is the first time through the loop,
 * and we merged this node with the previous
 * one, then we only want to do the madvise()
 * on start ... node->end (rather than
 * starting at node->start).
 *
 * Otherwise we end up doing madvise() on
 * bigger region than we're being asked to,
 * and that may lead to a spurious failure.
 */
642 if (start > node->start)
643 ret = madvise((void *) start, node->end - start + 1,
646 ret = madvise((void *) node->start,
647 node->end - node->start + 1,
/* madvise() failed: undo this node's merge/split first. */
650 node = undo_node(node, start, inc);
652 if (rolling_back || !node)
655 /* madvise failed, roll back previous changes */
657 advice = advice == MADV_DONTFORK ?
658 MADV_DOFORK : MADV_DONTFORK;
659 tmp = __mm_prev(node);
660 if (!tmp || start > tmp->end)
668 node = __mm_next(node);
/* Final tidy-up: merge the last visited node with its predecessor if
 * their refcounts now match. */
672 tmp = __mm_prev(node);
673 if (tmp && node->refcnt == tmp->refcnt)
674 node = merge_ranges(node, tmp);
681 pthread_mutex_unlock(&mm_mutex);
/*
 * Public entry point: mark [base, base + size) MADV_DONTFORK (refcounted)
 * so child processes after fork() don't share these registered pages.
 * NOTE(review): a guard branch before this call appears to be elided.
 */
686 int ibv_dontfork_range(void *base, size_t size)
689 return ibv_madvise_range(base, size, MADV_DONTFORK);
/*
 * Public entry point: undo a prior ibv_dontfork_range() on
 * [base, base + size) by applying MADV_DOFORK (refcounted).
 * NOTE(review): a guard branch before this call appears to be elided.
 */
696 int ibv_dofork_range(void *base, size_t size)
699 return ibv_madvise_range(base, size, MADV_DOFORK);