Home | History | Annotate | Line # | Download | only in npf
      1 /*-
      2  * Copyright (c) 2016 Mindaugas Rasiukevicius <rmind at noxt eu>
      3  * All rights reserved.
      4  *
      5  * Redistribution and use in source and binary forms, with or without
      6  * modification, are permitted provided that the following conditions
      7  * are met:
      8  * 1. Redistributions of source code must retain the above copyright
      9  *    notice, this list of conditions and the following disclaimer.
     10  * 2. Redistributions in binary form must reproduce the above copyright
     11  *    notice, this list of conditions and the following disclaimer in the
     12  *    documentation and/or other materials provided with the distribution.
     13  *
     14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
     15  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
     16  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
     17  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
     18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
     19  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
     20  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
     21  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
     22  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
     23  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
     24  * SUCH DAMAGE.
     25  */
     26 
     27 /*
     28  * Longest Prefix Match (LPM) library supporting IPv4 and IPv6.
     29  *
     30  * Algorithm:
     31  *
 * Each prefix length gets its own hash map, and every prefix length in
 * use is recorded in a bitmap.  On a lookup, we perform a linear scan of
 * the hash maps, visiting only the prefix lengths that have been added,
 * longest first.  Usually only a few distinct prefix lengths are in use,
 * and such a simple algorithm is very efficient.  With many distinct IPv6
 * prefix lengths, the linear scan might become a bottleneck.
     37  */
     38 
     39 #if defined(_KERNEL)
     40 #include <sys/cdefs.h>
     41 __KERNEL_RCSID(0, "$NetBSD: lpm.c,v 1.6 2019/06/12 14:36:32 christos Exp $");
     42 
     43 #include <sys/param.h>
     44 #include <sys/types.h>
     45 #include <sys/malloc.h>
     46 #include <sys/kmem.h>
     47 #else
     48 #include <sys/socket.h>
     49 #include <arpa/inet.h>
     50 
     51 #include <stdio.h>
     52 #include <stdlib.h>
     53 #include <stdbool.h>
     54 #include <stddef.h>
     55 #include <string.h>
     56 #include <strings.h>
     57 #include <errno.h>
     58 #include <assert.h>
     59 #define kmem_alloc(a, b) malloc(a)
     60 #define kmem_free(a, b) free(a)
     61 #define kmem_zalloc(a, b) calloc(a, 1)
     62 #endif
     63 
     64 #include "lpm.h"
     65 
/*
 * Sizing and conversion helpers:
 *
 * LPM_MAX_PREFIX	longest supported prefix length in bits (IPv6).
 * LPM_MAX_WORDS	32-bit words needed to hold LPM_MAX_PREFIX bits.
 * LPM_TO_WORDS(x)	address length in bytes -> number of 32-bit words.
 * LPM_HASH_STEP	bucket headroom: a map is grown whenever it has fewer
 *			than nitems + LPM_HASH_STEP buckets.
 * LPM_LEN_IDX(len)	defvals[] index: 0 for 4-byte, 1 for 16-byte keys.
 */
#define	LPM_MAX_PREFIX		(128)
#define	LPM_MAX_WORDS		(LPM_MAX_PREFIX / 32)
#define	LPM_TO_WORDS(x)		((x) / 4)
#define	LPM_HASH_STEP		(8)
#define	LPM_LEN_IDX(len)	((len) / 16)

/* Debug-only assertion; compiles to nothing unless DEBUG is defined. */
#if defined(DEBUG)
#define	ASSERT			assert
#else
#define	ASSERT(x)
#endif
     77 
/*
 * lpm_ent_t: one stored prefix, chained within a hash bucket.
 */
typedef struct lpm_ent {
	struct lpm_ent *next;	/* next entry in the same bucket chain */
	void *		val;	/* caller's value for this prefix */
	unsigned	len;	/* key length in bytes (4 or 16) */
	uint8_t		key[];	/* masked address prefix, stored inline */
} lpm_ent_t;

/*
 * lpm_hmap_t: chained hash map holding the prefixes of one length.
 * An unallocated map (initially, or after lpm_clear) has hashsize == 0
 * and no bucket array; otherwise hashsize is a power of two, as produced
 * by hashmap_rehash.
 */
typedef struct {
	unsigned	hashsize;	/* number of buckets */
	unsigned	nitems;		/* entry count; drives table growth */
	lpm_ent_t **	bucket;		/* array of hashsize chain heads */
} lpm_hmap_t;

/*
 * struct lpm: the LPM table.
 */
struct lpm {
	/*
	 * One bit per prefix length 1..128 that may hold entries: length p
	 * is bit (0x80000000U >> ((p - 1) & 31)) of word ((p - 1) >> 5).
	 */
	uint32_t	bitmask[LPM_MAX_WORDS];
	/* Allocation flags given to lpm_create(), passed to kmem_*(). */
	int		flags;
	/* Values of the 0-length prefix: [0] for 4-byte, [1] for 16-byte. */
	void *		defvals[2];
	/* Hash maps indexed by prefix length; [0] is unused (see defvals). */
	lpm_hmap_t	prefix[LPM_MAX_PREFIX + 1];
};

/* All-zero key handed to the destructor for the default values. */
static const uint32_t zero_address[LPM_MAX_WORDS];
     99 
    100 lpm_t *
    101 lpm_create(int flags)
    102 {
    103 	lpm_t *lpm = kmem_zalloc(sizeof(*lpm), KM_SLEEP);
    104 	lpm->flags = flags;
    105 	return lpm;
    106 }
    107 
    108 void
    109 lpm_clear(lpm_t *lpm, lpm_dtor_t dtor, void *arg)
    110 {
    111 	for (unsigned n = 0; n <= LPM_MAX_PREFIX; n++) {
    112 		lpm_hmap_t *hmap = &lpm->prefix[n];
    113 
    114 		if (!hmap->hashsize) {
    115 			KASSERT(!hmap->bucket);
    116 			continue;
    117 		}
    118 		for (unsigned i = 0; i < hmap->hashsize; i++) {
    119 			lpm_ent_t *entry = hmap->bucket[i];
    120 
    121 			while (entry) {
    122 				lpm_ent_t *next = entry->next;
    123 
    124 				if (dtor) {
    125 					dtor(arg, entry->key,
    126 					    entry->len, entry->val);
    127 				}
    128 				kmem_free(entry,
    129 				    offsetof(lpm_ent_t, key[entry->len]));
    130 				entry = next;
    131 			}
    132 		}
    133 		kmem_free(hmap->bucket, hmap->hashsize * sizeof(lpm_ent_t *));
    134 		hmap->bucket = NULL;
    135 		hmap->hashsize = 0;
    136 		hmap->nitems = 0;
    137 	}
    138 	if (dtor) {
    139 		dtor(arg, zero_address, 4, lpm->defvals[0]);
    140 		dtor(arg, zero_address, 16, lpm->defvals[1]);
    141 	}
    142 	memset(lpm->bitmask, 0, sizeof(lpm->bitmask));
    143 	memset(lpm->defvals, 0, sizeof(lpm->defvals));
    144 }
    145 
    146 void
    147 lpm_destroy(lpm_t *lpm)
    148 {
    149 	lpm_clear(lpm, NULL, NULL);
    150 	kmem_free(lpm, sizeof(*lpm));
    151 }
    152 
/*
 * fnv1a_hash: 32-bit Fowler-Noll-Vo hash, FNV-1a variant: each octet is
 * XOR-ed into the state, which is then multiplied by the FNV prime.
 */
static uint32_t
fnv1a_hash(const void *buf, size_t len)
{
	const uint32_t fnv_offset_basis = 2166136261U;
	const uint32_t fnv_prime = 16777619U;
	const uint8_t *octets = buf;
	uint32_t state = fnv_offset_basis;

	for (size_t k = 0; k < len; k++) {
		state = (state ^ octets[k]) * fnv_prime;
	}
	return state;
}
    168 
    169 static bool
    170 hashmap_rehash(lpm_hmap_t *hmap, unsigned size, int flags)
    171 {
    172 	lpm_ent_t **bucket;
    173 	unsigned hashsize;
    174 
    175 	for (hashsize = 1; hashsize < size; hashsize <<= 1) {
    176 		continue;
    177 	}
    178 	bucket = kmem_zalloc(hashsize * sizeof(lpm_ent_t *), flags);
    179 	if (bucket == NULL)
    180 		return false;
    181 	for (unsigned n = 0; n < hmap->hashsize; n++) {
    182 		lpm_ent_t *list = hmap->bucket[n];
    183 
    184 		while (list) {
    185 			lpm_ent_t *entry = list;
    186 			uint32_t hash = fnv1a_hash(entry->key, entry->len);
    187 			const unsigned i = hash & (hashsize - 1);
    188 
    189 			list = entry->next;
    190 			entry->next = bucket[i];
    191 			bucket[i] = entry;
    192 		}
    193 	}
    194 	if (hmap->bucket)
    195 		kmem_free(hmap->bucket, hmap->hashsize * sizeof(lpm_ent_t *));
    196 	hmap->bucket = bucket;
    197 	hmap->hashsize = hashsize;
    198 	return true;
    199 }
    200 
    201 static lpm_ent_t *
    202 hashmap_insert(lpm_hmap_t *hmap, const void *key, size_t len, int flags)
    203 {
    204 	const unsigned target = hmap->nitems + LPM_HASH_STEP;
    205 	const size_t entlen = offsetof(lpm_ent_t, key[len]);
    206 	uint32_t hash, i;
    207 	lpm_ent_t *entry;
    208 
    209 	if (hmap->hashsize < target && !hashmap_rehash(hmap, target, flags)) {
    210 		return NULL;
    211 	}
    212 
    213 	hash = fnv1a_hash(key, len);
    214 	i = hash & (hmap->hashsize - 1);
    215 	entry = hmap->bucket[i];
    216 	while (entry) {
    217 		if (entry->len == len && memcmp(entry->key, key, len) == 0) {
    218 			return entry;
    219 		}
    220 		entry = entry->next;
    221 	}
    222 
    223 	if ((entry = kmem_alloc(entlen, flags)) != NULL) {
    224 		memcpy(entry->key, key, len);
    225 		entry->next = hmap->bucket[i];
    226 		entry->len = len;
    227 
    228 		hmap->bucket[i] = entry;
    229 		hmap->nitems++;
    230 	}
    231 	return entry;
    232 }
    233 
    234 static lpm_ent_t *
    235 hashmap_lookup(lpm_hmap_t *hmap, const void *key, size_t len)
    236 {
    237 	const uint32_t hash = fnv1a_hash(key, len);
    238 	const unsigned i = hash & (hmap->hashsize - 1);
    239 	lpm_ent_t *entry;
    240 
    241 	if (hmap->hashsize == 0) {
    242 		return NULL;
    243 	}
    244 	entry = hmap->bucket[i];
    245 
    246 	while (entry) {
    247 		if (entry->len == len && memcmp(entry->key, key, len) == 0) {
    248 			return entry;
    249 		}
    250 		entry = entry->next;
    251 	}
    252 	return NULL;
    253 }
    254 
    255 static int
    256 hashmap_remove(lpm_hmap_t *hmap, const void *key, size_t len)
    257 {
    258 	const uint32_t hash = fnv1a_hash(key, len);
    259 	const unsigned i = hash & (hmap->hashsize - 1);
    260 	lpm_ent_t *prev = NULL, *entry;
    261 
    262 	if (hmap->hashsize == 0) {
    263 		return -1;
    264 	}
    265 	entry = hmap->bucket[i];
    266 
    267 	while (entry) {
    268 		if (entry->len == len && memcmp(entry->key, key, len) == 0) {
    269 			if (prev) {
    270 				prev->next = entry->next;
    271 			} else {
    272 				hmap->bucket[i] = entry->next;
    273 			}
    274 			kmem_free(entry, offsetof(lpm_ent_t, key[len]));
    275 			return 0;
    276 		}
    277 		prev = entry;
    278 		entry = entry->next;
    279 	}
    280 	return -1;
    281 }
    282 
/*
 * compute_prefix: mask the address down to its first preflen bits.
 *
 * => addr and prefix are nwords 32-bit words (at most 4, the size of the
 *    local copy buffer) holding the address in network byte order; every
 *    bit after the first preflen bits is cleared in the result.
 * => addr need not be 4-byte aligned: misaligned input is first copied
 *    into a local buffer.
 */
static inline void
compute_prefix(const unsigned nwords, const uint32_t *addr,
    unsigned preflen, uint32_t *prefix)
{
	uint32_t aligned[4];

	if (((uintptr_t)addr & 3) != 0) {
		memcpy(aligned, addr, nwords * 4);
		addr = aligned;
	}
	for (unsigned w = 0; w < nwords; w++) {
		if (preflen >= 32) {
			/* Word lies entirely inside the prefix. */
			prefix[w] = addr[w];
			preflen -= 32;
		} else if (preflen > 0) {
			/* Word straddles the prefix boundary. */
			prefix[w] = addr[w] &
			    htonl(0xffffffff << (32 - preflen));
			preflen = 0;
		} else {
			/* Word lies entirely past the prefix. */
			prefix[w] = 0;
		}
	}
}
    313 
    314 /*
    315  * lpm_insert: insert the CIDR into the LPM table.
    316  *
    317  * => Returns zero on success and -1 on failure.
    318  */
    319 int
    320 lpm_insert(lpm_t *lpm, const void *addr,
    321     size_t len, unsigned preflen, void *val)
    322 {
    323 	const unsigned nwords = LPM_TO_WORDS(len);
    324 	uint32_t prefix[LPM_MAX_WORDS];
    325 	lpm_ent_t *entry;
    326 	KASSERT(len == 4 || len == 16);
    327 
    328 	if (preflen == 0) {
    329 		/* 0-length prefix is a special case. */
    330 		lpm->defvals[LPM_LEN_IDX(len)] = val;
    331 		return 0;
    332 	}
    333 	compute_prefix(nwords, addr, preflen, prefix);
    334 	entry = hashmap_insert(&lpm->prefix[preflen], prefix, len, lpm->flags);
    335 	if (entry) {
    336 		const unsigned n = --preflen >> 5;
    337 		lpm->bitmask[n] |= 0x80000000U >> (preflen & 31);
    338 		entry->val = val;
    339 		return 0;
    340 	}
    341 	return -1;
    342 }
    343 
    344 /*
    345  * lpm_remove: remove the specified prefix.
    346  */
    347 int
    348 lpm_remove(lpm_t *lpm, const void *addr, size_t len, unsigned preflen)
    349 {
    350 	const unsigned nwords = LPM_TO_WORDS(len);
    351 	uint32_t prefix[LPM_MAX_WORDS];
    352 	KASSERT(len == 4 || len == 16);
    353 
    354 	if (preflen == 0) {
    355 		lpm->defvals[LPM_LEN_IDX(len)] = NULL;
    356 		return 0;
    357 	}
    358 	compute_prefix(nwords, addr, preflen, prefix);
    359 	return hashmap_remove(&lpm->prefix[preflen], prefix, len);
    360 }
    361 
    362 /*
    363  * lpm_lookup: find the longest matching prefix given the IP address.
    364  *
    365  * => Returns the associated value on success or NULL on failure.
    366  */
    367 void *
    368 lpm_lookup(lpm_t *lpm, const void *addr, size_t len)
    369 {
    370 	const unsigned nwords = LPM_TO_WORDS(len);
    371 	unsigned i, n = nwords;
    372 	uint32_t prefix[LPM_MAX_WORDS];
    373 
    374 	while (n--) {
    375 		uint32_t bitmask = lpm->bitmask[n];
    376 
    377 		while ((i = ffs(bitmask)) != 0) {
    378 			const unsigned preflen = (32 * n) + (32 - --i);
    379 			lpm_hmap_t *hmap = &lpm->prefix[preflen];
    380 			lpm_ent_t *entry;
    381 
    382 			compute_prefix(nwords, addr, preflen, prefix);
    383 			entry = hashmap_lookup(hmap, prefix, len);
    384 			if (entry) {
    385 				return entry->val;
    386 			}
    387 			bitmask &= ~(1U << i);
    388 		}
    389 	}
    390 	return lpm->defvals[LPM_LEN_IDX(len)];
    391 }
    392 
    393 /*
    394  * lpm_lookup_prefix: return the value associated with a prefix
    395  *
    396  * => Returns the associated value on success or NULL on failure.
    397  */
    398 void *
    399 lpm_lookup_prefix(lpm_t *lpm, const void *addr, size_t len, unsigned preflen)
    400 {
    401 	const unsigned nwords = LPM_TO_WORDS(len);
    402 	uint32_t prefix[LPM_MAX_WORDS];
    403 	lpm_ent_t *entry;
    404 	KASSERT(len == 4 || len == 16);
    405 
    406 	if (preflen == 0) {
    407 		return lpm->defvals[LPM_LEN_IDX(len)];
    408 	}
    409 	compute_prefix(nwords, addr, preflen, prefix);
    410 	entry = hashmap_lookup(&lpm->prefix[preflen], prefix, len);
    411 	if (entry) {
    412 		return entry->val;
    413 	}
    414 	return NULL;
    415 }
    416 
    417 #if !defined(_KERNEL)
/*
 * lpm_strtobin: convert CIDR string to the binary IP address and mask.
 *
 * => The string is "address" or "address/preflen".  The address may be
 *    IPv6 or IPv4.  preflen must be plain decimal digits and no larger
 *    than 128 (IPv6) or 32 (IPv4).  Without "/preflen", the full address
 *    length (128 or 32) is used.
 * => addr must have room for 16 bytes.  The address will be in the network
 *    byte order, and *len is set to 16 or 4.
 * => Outputs are written only on success.
 * => Returns 0 on success or -1 on failure (over-long or malformed input,
 *    invalid address or out-of-range prefix length).
 */
int
lpm_strtobin(const char *cidr, void *addr, size_t *len, unsigned *preflen)
{
	/* Longest address text plus "/128"; INET6_ADDRSTRLEN counts the NUL. */
	char buf[INET6_ADDRSTRLEN + 4];
	unsigned long plen = 0;
	bool have_plen = false;
	size_t alen;
	char *p;
	int n;

	/* Reject input that does not fit instead of silently truncating it. */
	n = snprintf(buf, sizeof(buf), "%s", cidr);
	if (n < 0 || (size_t)n >= sizeof(buf)) {
		return -1;
	}

	if ((p = strchr(buf, '/')) != NULL) {
		char *end;

		*p++ = '\0';
		/* Require decimal digits only: no sign, blanks or junk. */
		if (*p < '0' || *p > '9') {
			return -1;
		}
		errno = 0;
		plen = strtoul(p, &end, 10);
		if (errno != 0 || *end != '\0') {
			return -1;
		}
		have_plen = true;
	}

	if (inet_pton(AF_INET6, buf, addr) == 1) {
		alen = 16;
	} else if (inet_pton(AF_INET, buf, addr) == 1) {
		alen = 4;
	} else {
		return -1;
	}

	if (!have_plen) {
		plen = alen * 8;
	} else if (plen > alen * 8) {
		return -1;
	}
	*len = alen;
	*preflen = (unsigned)plen;
	return 0;
}
    453 #endif
    454