Home | History | Annotate | Line # | Download | only in npf
lpm.c revision 1.3
      1 /*-
      2  * Copyright (c) 2016 Mindaugas Rasiukevicius <rmind at noxt eu>
      3  * All rights reserved.
      4  *
      5  * Redistribution and use in source and binary forms, with or without
      6  * modification, are permitted provided that the following conditions
      7  * are met:
      8  * 1. Redistributions of source code must retain the above copyright
      9  *    notice, this list of conditions and the following disclaimer.
     10  * 2. Redistributions in binary form must reproduce the above copyright
     11  *    notice, this list of conditions and the following disclaimer in the
     12  *    documentation and/or other materials provided with the distribution.
     13  *
     14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
     15  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
     16  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
     17  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
     18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
     19  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
     20  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
     21  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
     22  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
     23  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
     24  * SUCH DAMAGE.
     25  */
     26 
     27 /*
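 * Longest-prefix match using one hash table per prefix length plus a
 * bitmask of the lengths in use; a lookup is a linear scan over the
 * populated lengths, longest first.  This works well enough while only a
 * few distinct prefix lengths are in use.
 * TODO: consider a better algorithm.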
     30  */
     31 
     32 #if defined(_KERNEL)
     33 #include <sys/cdefs.h>
     34 __KERNEL_RCSID(0, "$NetBSD: lpm.c,v 1.3 2016/12/26 21:16:06 rmind Exp $");
     35 
     36 #include <sys/param.h>
     37 #include <sys/types.h>
     38 #include <sys/malloc.h>
     39 #include <sys/kmem.h>
     40 #else
     41 #include <sys/socket.h>
     42 #include <arpa/inet.h>
     43 
     44 #include <stdio.h>
     45 #include <stdlib.h>
     46 #include <stdbool.h>
     47 #include <stddef.h>
     48 #include <string.h>
     49 #include <strings.h>
     50 #include <errno.h>
     51 #include <assert.h>
     52 #define kmem_alloc(a, b) malloc(a)
     53 #define kmem_free(a, b) free(a)
     54 #define kmem_zalloc(a, b) calloc(a, 1)
     55 #endif
     56 
     57 #include "lpm.h"
     58 
#define	LPM_MAX_PREFIX		(128)			/* longest prefix (IPv6), in bits */
#define	LPM_MAX_WORDS		(LPM_MAX_PREFIX >> 5)	/* 32-bit words in the longest address */
#define	LPM_TO_WORDS(x)		((x) >> 2)		/* address length in bytes -> words */
#define	LPM_HASH_STEP		(8)			/* spare buckets kept above nitems */

/*
 * ASSERT() is active only in DEBUG builds.  Otherwise it must discard its
 * argument completely (instead of leaving "(expr);" behind), so that the
 * asserted expression is never evaluated in non-DEBUG builds.
 */
#ifdef DEBUG
#define	ASSERT	assert
#else
#define	ASSERT(x)	((void)0)
#endif

/*
 * KASSERT() is a kernel facility, but lpm_clear() uses it unconditionally;
 * provide an assert()-based fallback so userland builds compile as well.
 */
#if !defined(_KERNEL) && !defined(KASSERT)
#define	KASSERT(x)	assert(x)
#endif
     69 
/*
 * lpm_ent_t: a single prefix entry, chained into a hash bucket.
 * The key (the masked address, `len' bytes) is stored inline in the
 * flexible array member, so each entry is a single allocation of
 * offsetof(lpm_ent_t, key[len]) bytes.
 */
typedef struct lpm_ent {
	struct lpm_ent *next;	/* next entry in the same hash bucket */
	void *		val;	/* caller's value associated with the prefix */
	unsigned	len;	/* key length in bytes (the address length) */
	uint8_t		key[];	/* masked address, see compute_prefix() */
} lpm_ent_t;

/*
 * lpm_hmap_t: a chained hash table holding the prefixes of one length.
 * hashsize is either 0 (no buckets, bucket == NULL) or a power of two.
 */
typedef struct {
	uint32_t	hashsize;	/* number of buckets */
	uint32_t	nitems;		/* entries counted by hashmap_insert(); drives growth */
	lpm_ent_t **bucket;		/* array of hashsize chain heads */
} lpm_hmap_t;

/*
 * struct lpm: the LPM table.  prefix[n] holds all prefixes of length n;
 * prefix[0] is not used for entries because the zero-length (default)
 * prefix is kept in defval.  Bit (0x80000000 >> ((n - 1) & 31)) of
 * bitmask[(n - 1) >> 5] is set when a prefix of length n is inserted and
 * is cleared only by lpm_clear(); lookups use it to skip unused lengths.
 */
struct lpm {
	uint32_t	bitmask[LPM_MAX_WORDS];
	void *		defval;
	lpm_hmap_t	prefix[LPM_MAX_PREFIX + 1];
};
     88 
     89 lpm_t *
     90 lpm_create(void)
     91 {
     92 	return kmem_zalloc(sizeof(lpm_t), KM_SLEEP);
     93 }
     94 
     95 void
     96 lpm_clear(lpm_t *lpm, lpm_dtor_t dtor, void *arg)
     97 {
     98 	for (unsigned n = 0; n <= LPM_MAX_PREFIX; n++) {
     99 		lpm_hmap_t *hmap = &lpm->prefix[n];
    100 
    101 		if (!hmap->hashsize) {
    102 			KASSERT(!hmap->bucket);
    103 			continue;
    104 		}
    105 		for (unsigned i = 0; i < hmap->hashsize; i++) {
    106 			lpm_ent_t *entry = hmap->bucket[i];
    107 
    108 			while (entry) {
    109 				lpm_ent_t *next = entry->next;
    110 
    111 				if (dtor) {
    112 					dtor(arg, entry->key,
    113 					    entry->len, entry->val);
    114 				}
    115 				kmem_free(entry,
    116 				    offsetof(lpm_ent_t, key[entry->len]));
    117 				entry = next;
    118 			}
    119 		}
    120 		kmem_free(hmap->bucket, hmap->hashsize * sizeof(lpm_ent_t *));
    121 		hmap->bucket = NULL;
    122 		hmap->hashsize = 0;
    123 		hmap->nitems = 0;
    124 	}
    125 	memset(lpm->bitmask, 0, sizeof(lpm->bitmask));
    126 	lpm->defval = NULL;
    127 }
    128 
    129 void
    130 lpm_destroy(lpm_t *lpm)
    131 {
    132 	lpm_clear(lpm, NULL, NULL);
    133 	kmem_free(lpm, sizeof(*lpm));
    134 }
    135 
/*
 * fnv1a_hash: 32-bit Fowler-Noll-Vo hash, FNV-1a variant
 * (XOR each octet into the state, then multiply by the FNV prime).
 */
static uint32_t
fnv1a_hash(const void *buf, size_t len)
{
	const uint8_t *octets = buf;
	uint32_t h = 2166136261U;	/* FNV-32 offset basis */

	for (size_t k = 0; k < len; k++) {
		h = (h ^ octets[k]) * 16777619U;	/* FNV-32 prime */
	}
	return h;
}
    151 
    152 static bool
    153 hashmap_rehash(lpm_hmap_t *hmap, uint32_t size)
    154 {
    155 	lpm_ent_t **bucket;
    156 	uint32_t hashsize;
    157 
    158 	for (hashsize = 1; hashsize < size; hashsize <<= 1) {
    159 		continue;
    160 	}
    161 	bucket = kmem_zalloc(hashsize * sizeof(lpm_ent_t *), KM_SLEEP);
    162 	if (bucket == NULL)
    163 		return false;
    164 	for (unsigned n = 0; n < hmap->hashsize; n++) {
    165 		lpm_ent_t *list = hmap->bucket[n];
    166 
    167 		while (list) {
    168 			lpm_ent_t *entry = list;
    169 			uint32_t hash = fnv1a_hash(entry->key, entry->len);
    170 			const size_t i = hash & (hashsize - 1);
    171 
    172 			list = entry->next;
    173 			entry->next = bucket[i];
    174 			bucket[i] = entry;
    175 		}
    176 	}
    177 	if (hmap->bucket)
    178 		kmem_free(hmap->bucket, hmap->hashsize * sizeof(lpm_ent_t *));
    179 	hmap->bucket = bucket;
    180 	hmap->hashsize = hashsize;
    181 	return true;
    182 }
    183 
    184 static lpm_ent_t *
    185 hashmap_insert(lpm_hmap_t *hmap, const void *key, size_t len)
    186 {
    187 	const uint32_t target = hmap->nitems + LPM_HASH_STEP;
    188 	const size_t entlen = offsetof(lpm_ent_t, key[len]);
    189 	uint32_t hash, i;
    190 	lpm_ent_t *entry;
    191 
    192 	if (hmap->hashsize < target && !hashmap_rehash(hmap, target)) {
    193 		return NULL;
    194 	}
    195 
    196 	hash = fnv1a_hash(key, len);
    197 	i = hash & (hmap->hashsize - 1);
    198 	entry = hmap->bucket[i];
    199 	while (entry) {
    200 		if (entry->len == len && memcmp(entry->key, key, len) == 0) {
    201 			return entry;
    202 		}
    203 		entry = entry->next;
    204 	}
    205 
    206 	if ((entry = kmem_alloc(entlen, KM_SLEEP)) == NULL)
    207 		return NULL;
    208 
    209 	memcpy(entry->key, key, len);
    210 	entry->next = hmap->bucket[i];
    211 	entry->len = len;
    212 
    213 	hmap->bucket[i] = entry;
    214 	hmap->nitems++;
    215 	return entry;
    216 }
    217 
    218 static lpm_ent_t *
    219 hashmap_lookup(lpm_hmap_t *hmap, const void *key, size_t len)
    220 {
    221 	const uint32_t hash = fnv1a_hash(key, len);
    222 	const uint32_t i = hash & (hmap->hashsize - 1);
    223 	lpm_ent_t *entry = hmap->bucket[i];
    224 
    225 	while (entry) {
    226 		if (entry->len == len && memcmp(entry->key, key, len) == 0) {
    227 			return entry;
    228 		}
    229 		entry = entry->next;
    230 	}
    231 	return NULL;
    232 }
    233 
    234 static int
    235 hashmap_remove(lpm_hmap_t *hmap, const void *key, size_t len)
    236 {
    237 	const uint32_t hash = fnv1a_hash(key, len);
    238 	const uint32_t i = hash & (hmap->hashsize - 1);
    239 	lpm_ent_t *prev = NULL, *entry = hmap->bucket[i];
    240 
    241 	while (entry) {
    242 		if (entry->len == len && memcmp(entry->key, key, len) == 0) {
    243 			if (prev) {
    244 				prev->next = entry->next;
    245 			} else {
    246 				hmap->bucket[i] = entry->next;
    247 			}
    248 			kmem_free(entry, offsetof(lpm_ent_t, key[len]));
    249 			return 0;
    250 		}
    251 		prev = entry;
    252 		entry = entry->next;
    253 	}
    254 	return -1;
    255 }
    256 
/*
 * compute_prefix: given the address and prefix length, compute the
 * address prefix into `prefix'.
 *
 * => addr points to nwords 32-bit words in network byte order; it does
 *    not need to be 32-bit aligned.
 * => prefix receives nwords words: the first preflen bits of the address
 *    with every remaining bit cleared.
 * => nwords must not exceed 4 (LPM_MAX_WORDS, an IPv6 address).
 */
static inline void
compute_prefix(const unsigned nwords, const void *addr,
    unsigned preflen, uint32_t *prefix)
{
	uint32_t words[4];

	/*
	 * Always copy into a properly aligned, properly typed local buffer:
	 * the caller's bytes may be unaligned or not a uint32_t object at
	 * all, so reading them through a uint32_t pointer is undefined.
	 */
	memcpy(words, addr, nwords * sizeof(uint32_t));

	for (unsigned i = 0; i < nwords; i++) {
		if (preflen >= 32) {
			prefix[i] = words[i];
			preflen -= 32;
		} else if (preflen > 0) {
			/* Partial word: keep its top preflen bits (network order). */
			prefix[i] = words[i] & htonl(0xffffffffU << (32 - preflen));
			preflen = 0;
		} else {
			prefix[i] = 0;
		}
	}
}
    287 
    288 /*
    289  * lpm_insert: insert the CIDR into the LPM table.
    290  *
    291  * => Returns zero on success and -1 on failure.
    292  */
    293 int
    294 lpm_insert(lpm_t *lpm, const void *addr,
    295     size_t len, unsigned preflen, void *val)
    296 {
    297 	const unsigned nwords = LPM_TO_WORDS(len);
    298 	uint32_t prefix[LPM_MAX_WORDS];
    299 	lpm_ent_t *entry;
    300 
    301 	if (preflen == 0) {
    302 		/* Default is a special case. */
    303 		lpm->defval = val;
    304 		return 0;
    305 	}
    306 	compute_prefix(nwords, addr, preflen, prefix);
    307 	entry = hashmap_insert(&lpm->prefix[preflen], prefix, len);
    308 	if (entry) {
    309 		const unsigned n = --preflen >> 5;
    310 		lpm->bitmask[n] |= 0x80000000U >> (preflen & 31);
    311 		entry->val = val;
    312 		return 0;
    313 	}
    314 	return -1;
    315 }
    316 
    317 /*
    318  * lpm_remove: remove the specified prefix.
    319  */
    320 int
    321 lpm_remove(lpm_t *lpm, const void *addr, size_t len, unsigned preflen)
    322 {
    323 	const unsigned nwords = LPM_TO_WORDS(len);
    324 	uint32_t prefix[LPM_MAX_WORDS];
    325 
    326 	if (preflen == 0) {
    327 		lpm->defval = NULL;
    328 		return 0;
    329 	}
    330 	compute_prefix(nwords, addr, preflen, prefix);
    331 	return hashmap_remove(&lpm->prefix[preflen], prefix, len);
    332 }
    333 
    334 /*
    335  * lpm_lookup: find the longest matching prefix given the IP address.
    336  *
    337  * => Returns the associated value on success or NULL on failure.
    338  */
    339 void *
    340 lpm_lookup(lpm_t *lpm, const void *addr, size_t len)
    341 {
    342 	const unsigned nwords = LPM_TO_WORDS(len);
    343 	unsigned i, n = nwords;
    344 	uint32_t prefix[LPM_MAX_WORDS];
    345 
    346 	while (n--) {
    347 		uint32_t bitmask = lpm->bitmask[n];
    348 
    349 		while ((i = ffs(bitmask)) != 0) {
    350 			const unsigned preflen = (32 * n) + (32 - --i);
    351 			lpm_hmap_t *hmap = &lpm->prefix[preflen];
    352 			lpm_ent_t *entry;
    353 
    354 			compute_prefix(nwords, addr, preflen, prefix);
    355 			entry = hashmap_lookup(hmap, prefix, len);
    356 			if (entry) {
    357 				return entry->val;
    358 			}
    359 			bitmask &= ~(1U << i);
    360 		}
    361 	}
    362 	return lpm->defval;
    363 }
    364 
    365 #if !defined(_KERNEL)
/*
 * lpm_strtobin: convert CIDR string to the binary IP address and mask.
 *
 * => Accepts "address" or "address/prefixlen", IPv4 or IPv6.
 * => addr must have room for 16 bytes; the address is stored in network
 *    byte order and *len is set to 4 (IPv4) or 16 (IPv6).
 * => Without an explicit prefix length, *preflen is the full address
 *    width (32 or 128).
 * => Returns 0 on success or -1 on failure: input too long, unparsable
 *    address, empty/non-numeric prefix length, or a prefix length longer
 *    than the address.
 */
int
lpm_strtobin(const char *cidr, void *addr, size_t *len, unsigned *preflen)
{
	/* Longest textual address plus "/128" and the terminator. */
	char buf[INET6_ADDRSTRLEN + sizeof("/128")];
	unsigned long plen = 0;
	bool has_plen = false;
	unsigned maxbits;
	char *p;
	int n;

	n = snprintf(buf, sizeof(buf), "%s", cidr);
	if (n < 0 || (size_t)n >= sizeof(buf)) {
		/* Too long to be valid; never parse a truncated copy. */
		return -1;
	}

	if ((p = strchr(buf, '/')) != NULL) {
		char *end;

		*p++ = '\0';
		/* Require plain decimal digits: strtoul() accepts "-1", " 8". */
		if (*p < '0' || *p > '9') {
			return -1;
		}
		plen = strtoul(p, &end, 10);
		if (*end != '\0') {
			return -1;
		}
		has_plen = true;
	}

	if (inet_pton(AF_INET6, buf, addr) == 1) {
		*len = 16;
	} else if (inet_pton(AF_INET, buf, addr) == 1) {
		*len = 4;
	} else {
		return -1;
	}

	maxbits = (unsigned)*len * 8;	/* 128 for IPv6, 32 for IPv4 */
	if (plen > maxbits) {
		/* Also catches strtoul() overflow (ULONG_MAX). */
		return -1;
	}
	*preflen = has_plen ? (unsigned)plen : maxbits;
	return 0;
}
    401 #endif
    402