1 /* $NetBSD: aes_neon_subr.c,v 1.8 2022/06/26 17:52:54 riastradh Exp $ */ 2 3 /*- 4 * Copyright (c) 2020 The NetBSD Foundation, Inc. 5 * All rights reserved. 6 * 7 * Redistribution and use in source and binary forms, with or without 8 * modification, are permitted provided that the following conditions 9 * are met: 10 * 1. Redistributions of source code must retain the above copyright 11 * notice, this list of conditions and the following disclaimer. 12 * 2. Redistributions in binary form must reproduce the above copyright 13 * notice, this list of conditions and the following disclaimer in the 14 * documentation and/or other materials provided with the distribution. 15 * 16 * THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS 17 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED 18 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 19 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS 20 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 26 * POSSIBILITY OF SUCH DAMAGE. 27 */ 28 29 #include <sys/cdefs.h> 30 __KERNEL_RCSID(1, "$NetBSD: aes_neon_subr.c,v 1.8 2022/06/26 17:52:54 riastradh Exp $"); 31 32 #ifdef _KERNEL 33 #include <sys/systm.h> 34 #include <lib/libkern/libkern.h> 35 #else 36 #include <assert.h> 37 #include <inttypes.h> 38 #include <stdio.h> 39 #include <string.h> 40 #define KASSERT assert 41 #endif 42 43 #include <crypto/aes/arch/arm/aes_neon.h> 44 45 #include "aes_neon_impl.h" 46 47 static inline uint8x16_t 48 loadblock(const void *in) 49 { 50 return vld1q_u8(in); 51 } 52 53 static inline void 54 storeblock(void *out, uint8x16_t block) 55 { 56 vst1q_u8(out, block); 57 } 58 59 void 60 aes_neon_enc(const struct aesenc *enc, const uint8_t in[static 16], 61 uint8_t out[static 16], uint32_t nrounds) 62 { 63 uint8x16_t block; 64 65 block = loadblock(in); 66 block = aes_neon_enc1(enc, block, nrounds); 67 storeblock(out, block); 68 } 69 70 void 71 aes_neon_dec(const struct aesdec *dec, const uint8_t in[static 16], 72 uint8_t out[static 16], uint32_t nrounds) 73 { 74 uint8x16_t block; 75 76 block = loadblock(in); 77 block = aes_neon_dec1(dec, block, nrounds); 78 storeblock(out, block); 79 } 80 81 void 82 aes_neon_cbc_enc(const struct aesenc *enc, const uint8_t in[static 16], 83 uint8_t out[static 16], size_t nbytes, uint8_t iv[static 16], 84 uint32_t nrounds) 85 { 86 uint8x16_t cv; 87 88 KASSERT(nbytes); 89 90 cv = loadblock(iv); 91 for (; nbytes; nbytes -= 16, in += 16, out += 16) { 92 cv ^= loadblock(in); 93 cv = aes_neon_enc1(enc, cv, nrounds); 94 storeblock(out, cv); 95 } 96 storeblock(iv, cv); 97 } 98 99 void 100 aes_neon_cbc_dec(const struct aesdec *dec, const uint8_t in[static 16], 101 uint8_t out[static 16], size_t nbytes, uint8_t iv[static 16], 102 uint32_t nrounds) 103 { 104 uint8x16_t iv0, cv, b; 105 106 KASSERT(nbytes); 107 KASSERT(nbytes % 16 == 0); 108 109 iv0 = loadblock(iv); 110 cv = loadblock(in + nbytes - 16); 111 storeblock(iv, cv); 112 113 if (nbytes % 32) { 114 KASSERT(nbytes % 32 == 16); 115 b = aes_neon_dec1(dec, cv, nrounds); 116 if ((nbytes -= 16) == 0) 117 goto out; 118 cv = loadblock(in + nbytes - 16); 119 storeblock(out + nbytes, cv ^ b); 120 } 121 122 for (;;) { 123 uint8x16x2_t b2; 124 125 KASSERT(nbytes >= 32); 126 127 b2.val[1] = cv; 128 b2.val[0] = cv = loadblock(in + nbytes - 32); 129 b2 = aes_neon_dec2(dec, b2, nrounds); 130 storeblock(out + nbytes - 16, cv ^ b2.val[1]); 131 if ((nbytes -= 32) == 0) { 132 b = b2.val[0]; 133 goto out; 134 } 135 cv = loadblock(in + nbytes - 16); 136 storeblock(out + nbytes, cv ^ b2.val[0]); 137 } 138 139 out: storeblock(out, b ^ iv0); 140 } 141 142 static inline uint8x16_t 143 aes_neon_xts_update(uint8x16_t t8) 144 { 145 const int32x4_t zero = vdupq_n_s32(0); 146 /* (0x87,1,1,1) */ 147 const uint32x4_t carry = vsetq_lane_u32(0x87, vdupq_n_u32(1), 0); 148 int32x4_t t, t_; 149 uint32x4_t mask; 150 151 t = vreinterpretq_s32_u8(t8); 152 mask = vcltq_s32(t, zero); /* -1 if high bit set else 0 */ 153 mask = vextq_u32(mask, mask, 3); /* rotate quarters */ 154 t_ = vshlq_n_s32(t, 1); /* shift */ 155 t_ ^= carry & mask; 156 157 return vreinterpretq_u8_s32(t_); 158 } 159 160 static int 161 aes_neon_xts_update_selftest(void) 162 { 163 static const struct { 164 uint8_t in[16], out[16]; 165 } cases[] = { 166 [0] = { {1,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0}, 167 {2,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0} }, 168 [1] = { {0,0,0,0x80, 0,0,0,0, 0,0,0,0, 0,0,0,0}, 169 {0,0,0,0, 1,0,0,0, 0,0,0,0, 0,0,0,0} }, 170 [2] = { {0,0,0,0, 0,0,0,0x80, 0,0,0,0, 0,0,0,0}, 171 {0,0,0,0, 0,0,0,0, 1,0,0,0, 0,0,0,0} }, 172 [3] = { {0,0,0,0, 0,0,0,0, 0,0,0,0x80, 0,0,0,0}, 173 {0,0,0,0, 0,0,0,0, 0,0,0,0, 1,0,0,0} }, 174 [4] = { {0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0x80}, 175 {0x87,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0} }, 176 [5] = { {0,0,0,0, 0,0,0,0x80, 0,0,0,0, 0,0,0,0x80}, 177 {0x87,0,0,0, 0,0,0,0, 1,0,0,0, 0,0,0,0} }, 178 }; 179 unsigned i; 180 uint8_t t[16]; 181 int result = 0; 182 183 for (i = 0; i < sizeof(cases)/sizeof(cases[0]); i++) { 184 storeblock(t, aes_neon_xts_update(loadblock(cases[i].in))); 185 if (memcmp(t, cases[i].out, 16)) { 186 char buf[3*16 + 1]; 187 unsigned j; 188 189 for (j = 0; j < 16; j++) { 190 snprintf(buf + 3*j, sizeof(buf) - 3*j, 191 " %02hhx", t[j]); 192 } 193 printf("%s %u: %s\n", __func__, i, buf); 194 result = -1; 195 } 196 } 197 198 return result; 199 } 200 201 void 202 aes_neon_xts_enc(const struct aesenc *enc, const uint8_t in[static 16], 203 uint8_t out[static 16], size_t nbytes, uint8_t tweak[static 16], 204 uint32_t nrounds) 205 { 206 uint8x16_t t, b; 207 208 KASSERT(nbytes); 209 KASSERT(nbytes % 16 == 0); 210 211 t = loadblock(tweak); 212 if (nbytes % 32) { 213 KASSERT(nbytes % 32 == 16); 214 b = t ^ loadblock(in); 215 b = aes_neon_enc1(enc, b, nrounds); 216 storeblock(out, t ^ b); 217 t = aes_neon_xts_update(t); 218 nbytes -= 16; 219 in += 16; 220 out += 16; 221 } 222 for (; nbytes; nbytes -= 32, in += 32, out += 32) { 223 uint8x16_t t1; 224 uint8x16x2_t b2; 225 226 t1 = aes_neon_xts_update(t); 227 b2.val[0] = t ^ loadblock(in); 228 b2.val[1] = t1 ^ loadblock(in + 16); 229 b2 = aes_neon_enc2(enc, b2, nrounds); 230 storeblock(out, b2.val[0] ^ t); 231 storeblock(out + 16, b2.val[1] ^ t1); 232 233 t = aes_neon_xts_update(t1); 234 } 235 storeblock(tweak, t); 236 } 237 238 void 239 aes_neon_xts_dec(const struct aesdec *dec, const uint8_t in[static 16], 240 uint8_t out[static 16], size_t nbytes, uint8_t tweak[static 16], 241 uint32_t nrounds) 242 { 243 uint8x16_t t, b; 244 245 KASSERT(nbytes); 246 KASSERT(nbytes % 16 == 0); 247 248 t = loadblock(tweak); 249 if (nbytes % 32) { 250 KASSERT(nbytes % 32 == 16); 251 b = t ^ loadblock(in); 252 b = aes_neon_dec1(dec, b, nrounds); 253 storeblock(out, t ^ b); 254 t = aes_neon_xts_update(t); 255 nbytes -= 16; 256 in += 16; 257 out += 16; 258 } 259 for (; nbytes; nbytes -= 32, in += 32, out += 32) { 260 uint8x16_t t1; 261 uint8x16x2_t b2; 262 263 t1 = aes_neon_xts_update(t); 264 b2.val[0] = t ^ loadblock(in); 265 b2.val[1] = t1 ^ loadblock(in + 16); 266 b2 = aes_neon_dec2(dec, b2, nrounds); 267 storeblock(out, b2.val[0] ^ t); 268 storeblock(out + 16, b2.val[1] ^ t1); 269 270 t = aes_neon_xts_update(t1); 271 } 272 storeblock(tweak, t); 273 } 274 275 void 276 aes_neon_cbcmac_update1(const struct aesenc *enc, const uint8_t in[static 16], 277 size_t nbytes, uint8_t auth0[static 16], uint32_t nrounds) 278 { 279 uint8x16_t auth; 280 281 KASSERT(nbytes); 282 KASSERT(nbytes % 16 == 0); 283 284 auth = loadblock(auth0); 285 for (; nbytes; nbytes -= 16, in += 16) 286 auth = aes_neon_enc1(enc, auth ^ loadblock(in), nrounds); 287 storeblock(auth0, auth); 288 } 289 290 void 291 aes_neon_ccm_enc1(const struct aesenc *enc, const uint8_t in[static 16], 292 uint8_t out[static 16], size_t nbytes, uint8_t authctr[static 32], 293 uint32_t nrounds) 294 { 295 /* (0,0,0,1) */ 296 const uint32x4_t ctr32_inc = vsetq_lane_u32(1, vdupq_n_u32(0), 3); 297 uint8x16_t auth, ptxt, ctr_be; 298 uint32x4_t ctr; 299 300 KASSERT(nbytes); 301 KASSERT(nbytes % 16 == 0); 302 303 auth = loadblock(authctr); 304 ctr_be = loadblock(authctr + 16); 305 ctr = vreinterpretq_u32_u8(vrev32q_u8(ctr_be)); 306 for (; nbytes; nbytes -= 16, in += 16, out += 16) { 307 uint8x16x2_t b2; 308 ptxt = loadblock(in); 309 ctr = vaddq_u32(ctr, ctr32_inc); 310 ctr_be = vrev32q_u8(vreinterpretq_u8_u32(ctr)); 311 312 b2.val[0] = auth ^ ptxt; 313 b2.val[1] = ctr_be; 314 b2 = aes_neon_enc2(enc, b2, nrounds); 315 auth = b2.val[0]; 316 storeblock(out, ptxt ^ b2.val[1]); 317 } 318 storeblock(authctr, auth); 319 storeblock(authctr + 16, ctr_be); 320 } 321 322 void 323 aes_neon_ccm_dec1(const struct aesenc *enc, const uint8_t in[static 16], 324 uint8_t out[static 16], size_t nbytes, uint8_t authctr[static 32], 325 uint32_t nrounds) 326 { 327 /* (0,0,0,1) */ 328 const uint32x4_t ctr32_inc = vsetq_lane_u32(1, vdupq_n_u32(0), 3); 329 uint8x16_t auth, ctr_be, ptxt, pad; 330 uint32x4_t ctr; 331 332 KASSERT(nbytes); 333 KASSERT(nbytes % 16 == 0); 334 335 ctr_be = loadblock(authctr + 16); 336 ctr = vreinterpretq_u32_u8(vrev32q_u8(ctr_be)); 337 ctr = vaddq_u32(ctr, ctr32_inc); 338 ctr_be = vrev32q_u8(vreinterpretq_u8_u32(ctr)); 339 pad = aes_neon_enc1(enc, ctr_be, nrounds); 340 auth = loadblock(authctr); 341 for (;; in += 16, out += 16) { 342 uint8x16x2_t b2; 343 344 ptxt = loadblock(in) ^ pad; 345 auth ^= ptxt; 346 storeblock(out, ptxt); 347 348 if ((nbytes -= 16) == 0) 349 break; 350 351 ctr = vaddq_u32(ctr, ctr32_inc); 352 ctr_be = vrev32q_u8(vreinterpretq_u8_u32(ctr)); 353 b2.val[0] = auth; 354 b2.val[1] = ctr_be; 355 b2 = aes_neon_enc2(enc, b2, nrounds); 356 auth = b2.val[0]; 357 pad = b2.val[1]; 358 } 359 auth = aes_neon_enc1(enc, auth, nrounds); 360 storeblock(authctr, auth); 361 storeblock(authctr + 16, ctr_be); 362 } 363 364 int 365 aes_neon_selftest(void) 366 { 367 368 if (aes_neon_xts_update_selftest()) 369 return -1; 370 371 return 0; 372 } 373