Home | History | Annotate | Line # | Download | only in npf
lpm.c revision 1.4
      1 /*-
      2  * Copyright (c) 2016 Mindaugas Rasiukevicius <rmind at noxt eu>
      3  * All rights reserved.
      4  *
      5  * Redistribution and use in source and binary forms, with or without
      6  * modification, are permitted provided that the following conditions
      7  * are met:
      8  * 1. Redistributions of source code must retain the above copyright
      9  *    notice, this list of conditions and the following disclaimer.
     10  * 2. Redistributions in binary form must reproduce the above copyright
     11  *    notice, this list of conditions and the following disclaimer in the
     12  *    documentation and/or other materials provided with the distribution.
     13  *
     14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
     15  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
     16  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
     17  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
     18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
     19  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
     20  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
     21  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
     22  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
     23  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
     24  * SUCH DAMAGE.
     25  */
     26 
/*
 * TODO: Simple linear scan over the prefix lengths for now (works well
 * enough with a few prefixes).  A better algorithm is TBD.
 */
     31 
     32 #if defined(_KERNEL)
     33 #include <sys/cdefs.h>
     34 __KERNEL_RCSID(0, "$NetBSD: lpm.c,v 1.4 2017/06/01 02:45:14 chs Exp $");
     35 
     36 #include <sys/param.h>
     37 #include <sys/types.h>
     38 #include <sys/malloc.h>
     39 #include <sys/kmem.h>
     40 #else
     41 #include <sys/socket.h>
     42 #include <arpa/inet.h>
     43 
     44 #include <stdio.h>
     45 #include <stdlib.h>
     46 #include <stdbool.h>
     47 #include <stddef.h>
     48 #include <string.h>
     49 #include <strings.h>
     50 #include <errno.h>
     51 #include <assert.h>
     52 #define kmem_alloc(a, b) malloc(a)
     53 #define kmem_free(a, b) free(a)
     54 #define kmem_zalloc(a, b) calloc(a, 1)
     55 #endif
     56 
     57 #include "lpm.h"
     58 
/*
 * Table geometry.  Prefix lengths are in bits; addresses are processed
 * as arrays of 32-bit words (IPv4: 1 word, IPv6: 4 words).
 */
#define	LPM_MAX_PREFIX		(128)			/* longest prefix, in bits */
#define	LPM_MAX_WORDS		(LPM_MAX_PREFIX >> 5)	/* 32-bit words per address */
#define	LPM_TO_WORDS(x)		((x) >> 2)		/* bytes -> 32-bit words */
#define	LPM_HASH_STEP		(8)			/* spare buckets kept above nitems */

/*
 * Debug-only assertion.  Like assert(3) under NDEBUG, the non-DEBUG
 * form must not evaluate its argument (previously it expanded to the
 * bare expression, so side effects still happened).
 */
#ifdef DEBUG
#define	ASSERT	assert
#else
#define	ASSERT(x)	((void)0)
#endif
     69 
     70 typedef struct lpm_ent {
     71 	struct lpm_ent *next;
     72 	void *		val;
     73 	unsigned	len;
     74 	uint8_t		key[];
     75 } lpm_ent_t;
     76 
     77 typedef struct {
     78 	uint32_t	hashsize;
     79 	uint32_t	nitems;
     80 	lpm_ent_t **bucket;
     81 } lpm_hmap_t;
     82 
     83 struct lpm {
     84 	uint32_t	bitmask[LPM_MAX_WORDS];
     85 	void *		defval;
     86 	lpm_hmap_t	prefix[LPM_MAX_PREFIX + 1];
     87 };
     88 
     89 lpm_t *
     90 lpm_create(void)
     91 {
     92 	return kmem_zalloc(sizeof(lpm_t), KM_SLEEP);
     93 }
     94 
     95 void
     96 lpm_clear(lpm_t *lpm, lpm_dtor_t dtor, void *arg)
     97 {
     98 	for (unsigned n = 0; n <= LPM_MAX_PREFIX; n++) {
     99 		lpm_hmap_t *hmap = &lpm->prefix[n];
    100 
    101 		if (!hmap->hashsize) {
    102 			KASSERT(!hmap->bucket);
    103 			continue;
    104 		}
    105 		for (unsigned i = 0; i < hmap->hashsize; i++) {
    106 			lpm_ent_t *entry = hmap->bucket[i];
    107 
    108 			while (entry) {
    109 				lpm_ent_t *next = entry->next;
    110 
    111 				if (dtor) {
    112 					dtor(arg, entry->key,
    113 					    entry->len, entry->val);
    114 				}
    115 				kmem_free(entry,
    116 				    offsetof(lpm_ent_t, key[entry->len]));
    117 				entry = next;
    118 			}
    119 		}
    120 		kmem_free(hmap->bucket, hmap->hashsize * sizeof(lpm_ent_t *));
    121 		hmap->bucket = NULL;
    122 		hmap->hashsize = 0;
    123 		hmap->nitems = 0;
    124 	}
    125 	memset(lpm->bitmask, 0, sizeof(lpm->bitmask));
    126 	lpm->defval = NULL;
    127 }
    128 
    129 void
    130 lpm_destroy(lpm_t *lpm)
    131 {
    132 	lpm_clear(lpm, NULL, NULL);
    133 	kmem_free(lpm, sizeof(*lpm));
    134 }
    135 
/*
 * fnv1a_hash: 32-bit Fowler-Noll-Vo hash, FNV-1a variant, of the len
 * bytes at buf.  Each byte is XORed into the state, which is then
 * multiplied by the FNV prime.
 */
static uint32_t
fnv1a_hash(const void *buf, size_t len)
{
	const uint32_t fnv_offset_basis = 2166136261U;
	const uint32_t fnv_prime = 16777619U;
	const uint8_t *bytes = buf;
	uint32_t h = fnv_offset_basis;

	for (size_t k = 0; k < len; k++) {
		h = (h ^ bytes[k]) * fnv_prime;
	}
	return h;
}
    151 
    152 static bool
    153 hashmap_rehash(lpm_hmap_t *hmap, uint32_t size)
    154 {
    155 	lpm_ent_t **bucket;
    156 	uint32_t hashsize;
    157 
    158 	for (hashsize = 1; hashsize < size; hashsize <<= 1) {
    159 		continue;
    160 	}
    161 	bucket = kmem_zalloc(hashsize * sizeof(lpm_ent_t *), KM_SLEEP);
    162 	for (unsigned n = 0; n < hmap->hashsize; n++) {
    163 		lpm_ent_t *list = hmap->bucket[n];
    164 
    165 		while (list) {
    166 			lpm_ent_t *entry = list;
    167 			uint32_t hash = fnv1a_hash(entry->key, entry->len);
    168 			const size_t i = hash & (hashsize - 1);
    169 
    170 			list = entry->next;
    171 			entry->next = bucket[i];
    172 			bucket[i] = entry;
    173 		}
    174 	}
    175 	if (hmap->bucket)
    176 		kmem_free(hmap->bucket, hmap->hashsize * sizeof(lpm_ent_t *));
    177 	hmap->bucket = bucket;
    178 	hmap->hashsize = hashsize;
    179 	return true;
    180 }
    181 
    182 static lpm_ent_t *
    183 hashmap_insert(lpm_hmap_t *hmap, const void *key, size_t len)
    184 {
    185 	const uint32_t target = hmap->nitems + LPM_HASH_STEP;
    186 	const size_t entlen = offsetof(lpm_ent_t, key[len]);
    187 	uint32_t hash, i;
    188 	lpm_ent_t *entry;
    189 
    190 	if (hmap->hashsize < target && !hashmap_rehash(hmap, target)) {
    191 		return NULL;
    192 	}
    193 
    194 	hash = fnv1a_hash(key, len);
    195 	i = hash & (hmap->hashsize - 1);
    196 	entry = hmap->bucket[i];
    197 	while (entry) {
    198 		if (entry->len == len && memcmp(entry->key, key, len) == 0) {
    199 			return entry;
    200 		}
    201 		entry = entry->next;
    202 	}
    203 
    204 	entry = kmem_alloc(entlen, KM_SLEEP);
    205 	memcpy(entry->key, key, len);
    206 	entry->next = hmap->bucket[i];
    207 	entry->len = len;
    208 
    209 	hmap->bucket[i] = entry;
    210 	hmap->nitems++;
    211 	return entry;
    212 }
    213 
    214 static lpm_ent_t *
    215 hashmap_lookup(lpm_hmap_t *hmap, const void *key, size_t len)
    216 {
    217 	const uint32_t hash = fnv1a_hash(key, len);
    218 	const uint32_t i = hash & (hmap->hashsize - 1);
    219 	lpm_ent_t *entry = hmap->bucket[i];
    220 
    221 	while (entry) {
    222 		if (entry->len == len && memcmp(entry->key, key, len) == 0) {
    223 			return entry;
    224 		}
    225 		entry = entry->next;
    226 	}
    227 	return NULL;
    228 }
    229 
    230 static int
    231 hashmap_remove(lpm_hmap_t *hmap, const void *key, size_t len)
    232 {
    233 	const uint32_t hash = fnv1a_hash(key, len);
    234 	const uint32_t i = hash & (hmap->hashsize - 1);
    235 	lpm_ent_t *prev = NULL, *entry = hmap->bucket[i];
    236 
    237 	while (entry) {
    238 		if (entry->len == len && memcmp(entry->key, key, len) == 0) {
    239 			if (prev) {
    240 				prev->next = entry->next;
    241 			} else {
    242 				hmap->bucket[i] = entry->next;
    243 			}
    244 			kmem_free(entry, offsetof(lpm_ent_t, key[len]));
    245 			return 0;
    246 		}
    247 		prev = entry;
    248 		entry = entry->next;
    249 	}
    250 	return -1;
    251 }
    252 
/*
 * compute_prefix: given the address and prefix length, compute and
 * return the address prefix.
 *
 * => addr points to nwords 32-bit words in network byte order; it need
 *    not be aligned.  prefix must have room for nwords words.
 * => Words fully covered by preflen are copied, the word containing
 *    the boundary keeps only its leading bits, and the rest are zeroed.
 */
static inline void
compute_prefix(const unsigned nwords, const uint32_t *addr,
    unsigned preflen, uint32_t *prefix)
{
	/*
	 * Read words through memcpy: this handles unaligned input and
	 * avoids type-punned access, with no fixed-size bounce buffer.
	 */
	const uint8_t *src = (const uint8_t *)addr;

	for (unsigned i = 0; i < nwords; i++) {
		uint32_t word;

		if (preflen == 0) {
			prefix[i] = 0;
			continue;
		}
		memcpy(&word, src + i * sizeof(word), sizeof(word));
		if (preflen < 32) {
			/* Keep the leading preflen bits (network order). */
			const uint32_t mask = htonl(0xffffffffU << (32 - preflen));

			prefix[i] = word & mask;
			preflen = 0;
		} else {
			prefix[i] = word;
			preflen -= 32;
		}
	}
}
    283 
    284 /*
    285  * lpm_insert: insert the CIDR into the LPM table.
    286  *
    287  * => Returns zero on success and -1 on failure.
    288  */
    289 int
    290 lpm_insert(lpm_t *lpm, const void *addr,
    291     size_t len, unsigned preflen, void *val)
    292 {
    293 	const unsigned nwords = LPM_TO_WORDS(len);
    294 	uint32_t prefix[LPM_MAX_WORDS];
    295 	lpm_ent_t *entry;
    296 
    297 	if (preflen == 0) {
    298 		/* Default is a special case. */
    299 		lpm->defval = val;
    300 		return 0;
    301 	}
    302 	compute_prefix(nwords, addr, preflen, prefix);
    303 	entry = hashmap_insert(&lpm->prefix[preflen], prefix, len);
    304 	if (entry) {
    305 		const unsigned n = --preflen >> 5;
    306 		lpm->bitmask[n] |= 0x80000000U >> (preflen & 31);
    307 		entry->val = val;
    308 		return 0;
    309 	}
    310 	return -1;
    311 }
    312 
    313 /*
    314  * lpm_remove: remove the specified prefix.
    315  */
    316 int
    317 lpm_remove(lpm_t *lpm, const void *addr, size_t len, unsigned preflen)
    318 {
    319 	const unsigned nwords = LPM_TO_WORDS(len);
    320 	uint32_t prefix[LPM_MAX_WORDS];
    321 
    322 	if (preflen == 0) {
    323 		lpm->defval = NULL;
    324 		return 0;
    325 	}
    326 	compute_prefix(nwords, addr, preflen, prefix);
    327 	return hashmap_remove(&lpm->prefix[preflen], prefix, len);
    328 }
    329 
    330 /*
    331  * lpm_lookup: find the longest matching prefix given the IP address.
    332  *
    333  * => Returns the associated value on success or NULL on failure.
    334  */
    335 void *
    336 lpm_lookup(lpm_t *lpm, const void *addr, size_t len)
    337 {
    338 	const unsigned nwords = LPM_TO_WORDS(len);
    339 	unsigned i, n = nwords;
    340 	uint32_t prefix[LPM_MAX_WORDS];
    341 
    342 	while (n--) {
    343 		uint32_t bitmask = lpm->bitmask[n];
    344 
    345 		while ((i = ffs(bitmask)) != 0) {
    346 			const unsigned preflen = (32 * n) + (32 - --i);
    347 			lpm_hmap_t *hmap = &lpm->prefix[preflen];
    348 			lpm_ent_t *entry;
    349 
    350 			compute_prefix(nwords, addr, preflen, prefix);
    351 			entry = hashmap_lookup(hmap, prefix, len);
    352 			if (entry) {
    353 				return entry->val;
    354 			}
    355 			bitmask &= ~(1U << i);
    356 		}
    357 	}
    358 	return lpm->defval;
    359 }
    360 
    361 #if !defined(_KERNEL)
    362 /*
    363  * lpm_strtobin: convert CIDR string to the binary IP address and mask.
    364  *
    365  * => The address will be in the network byte order.
    366  * => Returns 0 on success or -1 on failure.
    367  */
    368 int
    369 lpm_strtobin(const char *cidr, void *addr, size_t *len, unsigned *preflen)
    370 {
    371 	char *p, buf[INET6_ADDRSTRLEN];
    372 
    373 	strncpy(buf, cidr, sizeof(buf));
    374 	buf[sizeof(buf) - 1] = '\0';
    375 
    376 	if ((p = strchr(buf, '/')) != NULL) {
    377 		const ptrdiff_t off = p - buf;
    378 		*preflen = atoi(&buf[off + 1]);
    379 		buf[off] = '\0';
    380 	} else {
    381 		*preflen = LPM_MAX_PREFIX;
    382 	}
    383 
    384 	if (inet_pton(AF_INET6, buf, addr) == 1) {
    385 		*len = 16;
    386 		return 0;
    387 	}
    388 	if (inet_pton(AF_INET, buf, addr) == 1) {
    389 		if (*preflen == LPM_MAX_PREFIX) {
    390 			*preflen = 32;
    391 		}
    392 		*len = 4;
    393 		return 0;
    394 	}
    395 	return -1;
    396 }
    397 #endif
    398