Home | History | Annotate | Line # | Download | only in npf
lpm.c revision 1.4.12.1
      1 /*-
      2  * Copyright (c) 2016 Mindaugas Rasiukevicius <rmind at noxt eu>
      3  * All rights reserved.
      4  *
      5  * Redistribution and use in source and binary forms, with or without
      6  * modification, are permitted provided that the following conditions
      7  * are met:
      8  * 1. Redistributions of source code must retain the above copyright
      9  *    notice, this list of conditions and the following disclaimer.
     10  * 2. Redistributions in binary form must reproduce the above copyright
     11  *    notice, this list of conditions and the following disclaimer in the
     12  *    documentation and/or other materials provided with the distribution.
     13  *
     14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
     15  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
     16  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
     17  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
     18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
     19  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
     20  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
     21  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
     22  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
     23  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
     24  * SUCH DAMAGE.
     25  */
     26 
/*
 * Longest Prefix Match (LPM) library supporting IPv4 and IPv6.
 *
 * Algorithm:
 *
 * Each prefix length gets its own hash map, and the set of prefix
 * lengths that have been added is recorded in a bitmap.  On a lookup,
 * we perform a linear scan of the hash maps, from the longest prefix
 * length to the shortest, visiting only the lengths that have been
 * added.  Usually, only a few distinct prefix lengths are used, so this
 * simple algorithm is very efficient.  With many distinct IPv6 prefix
 * lengths, the linear scan might become a bottleneck.
 */
     38 
     39 #if defined(_KERNEL)
     40 #include <sys/cdefs.h>
     41 __KERNEL_RCSID(0, "$NetBSD: lpm.c,v 1.4.12.1 2019/06/10 22:09:46 christos Exp $");
     42 
     43 #include <sys/param.h>
     44 #include <sys/types.h>
     45 #include <sys/malloc.h>
     46 #include <sys/kmem.h>
     47 #else
     48 #include <sys/socket.h>
     49 #include <arpa/inet.h>
     50 
     51 #include <stdio.h>
     52 #include <stdlib.h>
     53 #include <stdbool.h>
     54 #include <stddef.h>
     55 #include <string.h>
     56 #include <strings.h>
     57 #include <errno.h>
     58 #include <assert.h>
     59 #define kmem_alloc(a, b) malloc(a)
     60 #define kmem_free(a, b) free(a)
     61 #define kmem_zalloc(a, b) calloc(a, 1)
     62 #endif
     63 
     64 #include "lpm.h"
     65 
/* Longest supported prefix length in bits; prefix lengths run 0..128. */
#define	LPM_MAX_PREFIX		(128)

/* Number of 32-bit words needed to hold LPM_MAX_PREFIX bits. */
#define	LPM_MAX_WORDS		(LPM_MAX_PREFIX >> 5)

/* Convert an address length in bytes into a count of 32-bit words. */
#define	LPM_TO_WORDS(x)		((x) >> 2)

/* Head-room added to the item count when sizing a hash table. */
#define	LPM_HASH_STEP		(8)

/* Map an address length in bytes to a defvals[] slot: 4 -> 0, 16 -> 1. */
#define	LPM_LEN_IDX(len)	((len) >> 4)

/* ASSERT() is only active in DEBUG builds; otherwise it expands to nothing. */
#if defined(DEBUG)
#define	ASSERT			assert
#else
#define	ASSERT(x)
#endif
     77 
/*
 * A single hash map entry.  The key bytes are stored inline after the
 * fixed header, so an entry occupies offsetof(lpm_ent_t, key[len]) bytes.
 */
typedef struct lpm_ent lpm_ent_t;

struct lpm_ent {
	lpm_ent_t *	next;	/* next entry in the same bucket chain */
	void *		val;	/* value supplied by the caller */
	unsigned	len;	/* length of key[] in bytes */
	uint8_t		key[];	/* the (masked) address prefix */
};
     84 
     85 typedef struct {
     86 	unsigned	hashsize;
     87 	unsigned	nitems;
     88 	lpm_ent_t **	bucket;
     89 } lpm_hmap_t;
     90 
     91 struct lpm {
     92 	uint32_t	bitmask[LPM_MAX_WORDS];
     93 	void *		defvals[2];
     94 	lpm_hmap_t	prefix[LPM_MAX_PREFIX + 1];
     95 };
     96 
     97 static const uint32_t zero_address[LPM_MAX_WORDS];
     98 
     99 lpm_t *
    100 lpm_create(void)
    101 {
    102 	return kmem_zalloc(sizeof(lpm_t), KM_SLEEP);
    103 }
    104 
    105 void
    106 lpm_clear(lpm_t *lpm, lpm_dtor_t dtor, void *arg)
    107 {
    108 	for (unsigned n = 0; n <= LPM_MAX_PREFIX; n++) {
    109 		lpm_hmap_t *hmap = &lpm->prefix[n];
    110 
    111 		if (!hmap->hashsize) {
    112 			KASSERT(!hmap->bucket);
    113 			continue;
    114 		}
    115 		for (unsigned i = 0; i < hmap->hashsize; i++) {
    116 			lpm_ent_t *entry = hmap->bucket[i];
    117 
    118 			while (entry) {
    119 				lpm_ent_t *next = entry->next;
    120 
    121 				if (dtor) {
    122 					dtor(arg, entry->key,
    123 					    entry->len, entry->val);
    124 				}
    125 				kmem_free(entry,
    126 				    offsetof(lpm_ent_t, key[entry->len]));
    127 				entry = next;
    128 			}
    129 		}
    130 		kmem_free(hmap->bucket, hmap->hashsize * sizeof(lpm_ent_t *));
    131 		hmap->bucket = NULL;
    132 		hmap->hashsize = 0;
    133 		hmap->nitems = 0;
    134 	}
    135 	if (dtor) {
    136 		dtor(arg, zero_address, 4, lpm->defvals[0]);
    137 		dtor(arg, zero_address, 16, lpm->defvals[1]);
    138 	}
    139 	memset(lpm->bitmask, 0, sizeof(lpm->bitmask));
    140 	memset(lpm->defvals, 0, sizeof(lpm->defvals));
    141 }
    142 
    143 void
    144 lpm_destroy(lpm_t *lpm)
    145 {
    146 	lpm_clear(lpm, NULL, NULL);
    147 	kmem_free(lpm, sizeof(*lpm));
    148 }
    149 
/*
 * fnv1a_hash: 32-bit Fowler-Noll-Vo hash (FNV-1a variant) of the given
 * buffer: for each byte, XOR it into the hash, then multiply by the prime.
 */
static uint32_t
fnv1a_hash(const void *buf, size_t len)
{
	const uint8_t *bytes = buf;
	uint32_t h = 2166136261UL;	/* FNV offset basis */

	for (size_t i = 0; i < len; i++) {
		h = (h ^ bytes[i]) * 16777619U;	/* FNV prime */
	}
	return h;
}
    165 
    166 static bool
    167 hashmap_rehash(lpm_hmap_t *hmap, unsigned size)
    168 {
    169 	lpm_ent_t **bucket;
    170 	unsigned hashsize;
    171 
    172 	for (hashsize = 1; hashsize < size; hashsize <<= 1) {
    173 		continue;
    174 	}
    175 	bucket = kmem_zalloc(hashsize * sizeof(lpm_ent_t *), KM_SLEEP);
    176 	for (unsigned n = 0; n < hmap->hashsize; n++) {
    177 		lpm_ent_t *list = hmap->bucket[n];
    178 
    179 		while (list) {
    180 			lpm_ent_t *entry = list;
    181 			uint32_t hash = fnv1a_hash(entry->key, entry->len);
    182 			const unsigned i = hash & (hashsize - 1);
    183 
    184 			list = entry->next;
    185 			entry->next = bucket[i];
    186 			bucket[i] = entry;
    187 		}
    188 	}
    189 	if (hmap->bucket)
    190 		kmem_free(hmap->bucket, hmap->hashsize * sizeof(lpm_ent_t *));
    191 	hmap->bucket = bucket;
    192 	hmap->hashsize = hashsize;
    193 	return true;
    194 }
    195 
    196 static lpm_ent_t *
    197 hashmap_insert(lpm_hmap_t *hmap, const void *key, size_t len)
    198 {
    199 	const unsigned target = hmap->nitems + LPM_HASH_STEP;
    200 	const size_t entlen = offsetof(lpm_ent_t, key[len]);
    201 	uint32_t hash, i;
    202 	lpm_ent_t *entry;
    203 
    204 	if (hmap->hashsize < target && !hashmap_rehash(hmap, target)) {
    205 		return NULL;
    206 	}
    207 
    208 	hash = fnv1a_hash(key, len);
    209 	i = hash & (hmap->hashsize - 1);
    210 	entry = hmap->bucket[i];
    211 	while (entry) {
    212 		if (entry->len == len && memcmp(entry->key, key, len) == 0) {
    213 			return entry;
    214 		}
    215 		entry = entry->next;
    216 	}
    217 
    218 	if ((entry = kmem_alloc(entlen, KM_SLEEP)) != NULL) {
    219 		memcpy(entry->key, key, len);
    220 		entry->next = hmap->bucket[i];
    221 		entry->len = len;
    222 
    223 		hmap->bucket[i] = entry;
    224 		hmap->nitems++;
    225 	}
    226 	return entry;
    227 }
    228 
    229 static lpm_ent_t *
    230 hashmap_lookup(lpm_hmap_t *hmap, const void *key, size_t len)
    231 {
    232 	const uint32_t hash = fnv1a_hash(key, len);
    233 	const unsigned i = hash & (hmap->hashsize - 1);
    234 	lpm_ent_t *entry;
    235 
    236 	if (hmap->hashsize == 0) {
    237 		return NULL;
    238 	}
    239 	entry = hmap->bucket[i];
    240 
    241 	while (entry) {
    242 		if (entry->len == len && memcmp(entry->key, key, len) == 0) {
    243 			return entry;
    244 		}
    245 		entry = entry->next;
    246 	}
    247 	return NULL;
    248 }
    249 
    250 static int
    251 hashmap_remove(lpm_hmap_t *hmap, const void *key, size_t len)
    252 {
    253 	const uint32_t hash = fnv1a_hash(key, len);
    254 	const unsigned i = hash & (hmap->hashsize - 1);
    255 	lpm_ent_t *prev = NULL, *entry;
    256 
    257 	if (hmap->hashsize == 0) {
    258 		return -1;
    259 	}
    260 	entry = hmap->bucket[i];
    261 
    262 	while (entry) {
    263 		if (entry->len == len && memcmp(entry->key, key, len) == 0) {
    264 			if (prev) {
    265 				prev->next = entry->next;
    266 			} else {
    267 				hmap->bucket[i] = entry->next;
    268 			}
    269 			kmem_free(entry, offsetof(lpm_ent_t, key[len]));
    270 			return 0;
    271 		}
    272 		prev = entry;
    273 		entry = entry->next;
    274 	}
    275 	return -1;
    276 }
    277 
/*
 * compute_prefix: given the address (in network byte order, as 'nwords'
 * 32-bit words) and the prefix length in bits, store the address with
 * all bits beyond the prefix length cleared into 'prefix'.
 */
static inline void
compute_prefix(const unsigned nwords, const uint32_t *addr,
    unsigned preflen, uint32_t *prefix)
{
	uint32_t aligned[4];
	unsigned bits = preflen;	/* prefix bits still to be consumed */

	if (((uintptr_t)addr & 3) != 0) {
		/* Unaligned address: work on an aligned copy instead. */
		memcpy(aligned, addr, nwords * 4);
		addr = aligned;
	}

	for (unsigned w = 0; w < nwords; w++) {
		if (bits >= 32) {
			/* Word lies entirely within the prefix. */
			prefix[w] = addr[w];
			bits -= 32;
		} else if (bits != 0) {
			/* Prefix ends inside this word: keep the top bits. */
			const uint32_t mask = htonl(0xffffffff << (32 - bits));

			prefix[w] = addr[w] & mask;
			bits = 0;
		} else {
			/* Word lies entirely beyond the prefix. */
			prefix[w] = 0;
		}
	}
}
    308 
    309 /*
    310  * lpm_insert: insert the CIDR into the LPM table.
    311  *
    312  * => Returns zero on success and -1 on failure.
    313  */
    314 int
    315 lpm_insert(lpm_t *lpm, const void *addr,
    316     size_t len, unsigned preflen, void *val)
    317 {
    318 	const unsigned nwords = LPM_TO_WORDS(len);
    319 	uint32_t prefix[LPM_MAX_WORDS];
    320 	lpm_ent_t *entry;
    321 	KASSERT(len == 4 || len == 16);
    322 
    323 	if (preflen == 0) {
    324 		/* 0-length prefix is a special case. */
    325 		lpm->defvals[LPM_LEN_IDX(len)] = val;
    326 		return 0;
    327 	}
    328 	compute_prefix(nwords, addr, preflen, prefix);
    329 	entry = hashmap_insert(&lpm->prefix[preflen], prefix, len);
    330 	if (entry) {
    331 		const unsigned n = --preflen >> 5;
    332 		lpm->bitmask[n] |= 0x80000000U >> (preflen & 31);
    333 		entry->val = val;
    334 		return 0;
    335 	}
    336 	return -1;
    337 }
    338 
    339 /*
    340  * lpm_remove: remove the specified prefix.
    341  */
    342 int
    343 lpm_remove(lpm_t *lpm, const void *addr, size_t len, unsigned preflen)
    344 {
    345 	const unsigned nwords = LPM_TO_WORDS(len);
    346 	uint32_t prefix[LPM_MAX_WORDS];
    347 	KASSERT(len == 4 || len == 16);
    348 
    349 	if (preflen == 0) {
    350 		lpm->defvals[LPM_LEN_IDX(len)] = NULL;
    351 		return 0;
    352 	}
    353 	compute_prefix(nwords, addr, preflen, prefix);
    354 	return hashmap_remove(&lpm->prefix[preflen], prefix, len);
    355 }
    356 
    357 /*
    358  * lpm_lookup: find the longest matching prefix given the IP address.
    359  *
    360  * => Returns the associated value on success or NULL on failure.
    361  */
    362 void *
    363 lpm_lookup(lpm_t *lpm, const void *addr, size_t len)
    364 {
    365 	const unsigned nwords = LPM_TO_WORDS(len);
    366 	unsigned i, n = nwords;
    367 	uint32_t prefix[LPM_MAX_WORDS];
    368 
    369 	while (n--) {
    370 		uint32_t bitmask = lpm->bitmask[n];
    371 
    372 		while ((i = ffs(bitmask)) != 0) {
    373 			const unsigned preflen = (32 * n) + (32 - --i);
    374 			lpm_hmap_t *hmap = &lpm->prefix[preflen];
    375 			lpm_ent_t *entry;
    376 
    377 			compute_prefix(nwords, addr, preflen, prefix);
    378 			entry = hashmap_lookup(hmap, prefix, len);
    379 			if (entry) {
    380 				return entry->val;
    381 			}
    382 			bitmask &= ~(1U << i);
    383 		}
    384 	}
    385 	return lpm->defvals[LPM_LEN_IDX(len)];
    386 }
    387 
    388 /*
    389  * lpm_lookup_prefix: return the value associated with a prefix
    390  *
    391  * => Returns the associated value on success or NULL on failure.
    392  */
    393 void *
    394 lpm_lookup_prefix(lpm_t *lpm, const void *addr, size_t len, unsigned preflen)
    395 {
    396 	const unsigned nwords = LPM_TO_WORDS(len);
    397 	uint32_t prefix[LPM_MAX_WORDS];
    398 	lpm_ent_t *entry;
    399 	KASSERT(len == 4 || len == 16);
    400 
    401 	if (preflen == 0) {
    402 		return lpm->defvals[LPM_LEN_IDX(len)];
    403 	}
    404 	compute_prefix(nwords, addr, preflen, prefix);
    405 	entry = hashmap_lookup(&lpm->prefix[preflen], prefix, len);
    406 	if (entry) {
    407 		return entry->val;
    408 	}
    409 	return NULL;
    410 }
    411 
    412 #if !defined(_KERNEL)
/*
 * lpm_strtobin: convert CIDR string to the binary IP address and mask.
 *
 * => 'cidr' is an IPv6 or IPv4 address, optionally followed by "/len".
 * => The address will be in the network byte order; 'addr' must have
 *    room for 16 bytes.
 * => On success, '*len' is the address length in bytes (16 or 4) and
 *    '*preflen' the prefix length in bits; without "/len" it is the full
 *    address width (128 or 32).
 * => Returns 0 on success or -1 on failure: the string is too long, the
 *    prefix length is not a plain decimal number or exceeds the address
 *    width, or the address does not parse.
 */
int
lpm_strtobin(const char *cidr, void *addr, size_t *len, unsigned *preflen)
{
	/* Room for the longest textual address plus a "/128" suffix. */
	char buf[INET6_ADDRSTRLEN + 4];
	const size_t slen = strlen(cidr);
	unsigned long plen = 0;
	bool has_plen = false;
	unsigned maxbits;
	size_t alen;
	char *slash;

	/* Reject over-long input rather than parsing a truncated copy. */
	if (slen >= sizeof(buf)) {
		return -1;
	}
	memcpy(buf, cidr, slen + 1);

	if ((slash = strchr(buf, '/')) != NULL) {
		const char *digits = slash + 1;
		char *end;

		/* strtoul() accepts signs and whitespace; we do not. */
		if (*digits < '0' || *digits > '9') {
			return -1;
		}
		errno = 0;
		plen = strtoul(digits, &end, 10);
		if (errno != 0 || *end != '\0') {
			return -1;
		}
		*slash = '\0';
		has_plen = true;
	}

	if (inet_pton(AF_INET6, buf, addr) == 1) {
		alen = 16;
	} else if (inet_pton(AF_INET, buf, addr) == 1) {
		alen = 4;
	} else {
		return -1;
	}

	/* The prefix length cannot exceed the width of the address. */
	maxbits = (unsigned)(alen * 8);
	if (!has_plen) {
		plen = maxbits;
	} else if (plen > maxbits) {
		return -1;
	}

	*len = alen;
	*preflen = (unsigned)plen;
	return 0;
}
    448 #endif
    449