1 #include <openssl/e_os2.h> 2 #include <stddef.h> 3 #include <sys/types.h> 4 #include <string.h> 5 #include <openssl/bn.h> 6 #include <openssl/err.h> 7 #include <openssl/rsaerr.h> 8 #include "internal/numbers.h" 9 #include "internal/constant_time.h" 10 #include "bn_local.h" 11 12 # if BN_BYTES == 8 13 typedef uint64_t limb_t; 14 # if defined(__SIZEOF_INT128__) && __SIZEOF_INT128__ == 16 15 /* nonstandard; implemented by gcc on 64-bit platforms */ 16 typedef __uint128_t limb2_t; 17 # define HAVE_LIMB2_T 18 # endif 19 # define LIMB_BIT_SIZE 64 20 # define LIMB_BYTE_SIZE 8 21 # elif BN_BYTES == 4 22 typedef uint32_t limb_t; 23 typedef uint64_t limb2_t; 24 # define LIMB_BIT_SIZE 32 25 # define LIMB_BYTE_SIZE 4 26 # define HAVE_LIMB2_T 27 # else 28 # error "Not supported" 29 # endif 30 31 /* 32 * For multiplication we're using schoolbook multiplication, 33 * so if we have two numbers, each with 6 "digits" (words) 34 * the multiplication is calculated as follows: 35 * A B C D E F 36 * x I J K L M N 37 * -------------- 38 * N*F 39 * N*E 40 * N*D 41 * N*C 42 * N*B 43 * N*A 44 * M*F 45 * M*E 46 * M*D 47 * M*C 48 * M*B 49 * M*A 50 * L*F 51 * L*E 52 * L*D 53 * L*C 54 * L*B 55 * L*A 56 * K*F 57 * K*E 58 * K*D 59 * K*C 60 * K*B 61 * K*A 62 * J*F 63 * J*E 64 * J*D 65 * J*C 66 * J*B 67 * J*A 68 * I*F 69 * I*E 70 * I*D 71 * I*C 72 * I*B 73 * + I*A 74 * ========================== 75 * N*B N*D N*F 76 * + N*A N*C N*E 77 * + M*B M*D M*F 78 * + M*A M*C M*E 79 * + L*B L*D L*F 80 * + L*A L*C L*E 81 * + K*B K*D K*F 82 * + K*A K*C K*E 83 * + J*B J*D J*F 84 * + J*A J*C J*E 85 * + I*B I*D I*F 86 * + I*A I*C I*E 87 * 88 * 1+1 1+3 1+5 89 * 1+0 1+2 1+4 90 * 0+1 0+3 0+5 91 * 0+0 0+2 0+4 92 * 93 * 0 1 2 3 4 5 6 94 * which requires n^2 multiplications and 2n full length additions 95 * as we can keep every other result of limb multiplication in two separate 96 * limbs 97 */ 98 99 #if defined HAVE_LIMB2_T 100 static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b) 101 { 102 limb2_t t; 103 /* 104 * this is idiomatic code to tell compiler to use the native mul 105 * those three lines will actually compile to single instruction 106 */ 107 108 t = (limb2_t)a * b; 109 *hi = t >> LIMB_BIT_SIZE; 110 *lo = (limb_t)t; 111 } 112 #elif (BN_BYTES == 8) && (defined _MSC_VER) 113 /* https://learn.microsoft.com/en-us/cpp/intrinsics/umul128?view=msvc-170 */ 114 #pragma intrinsic(_umul128) 115 static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b) 116 { 117 *lo = _umul128(a, b, hi); 118 } 119 #else 120 /* 121 * if the compiler doesn't have either a 128bit data type nor a "return 122 * high 64 bits of multiplication" 123 */ 124 static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b) 125 { 126 limb_t a_low = (limb_t)(uint32_t)a; 127 limb_t a_hi = a >> 32; 128 limb_t b_low = (limb_t)(uint32_t)b; 129 limb_t b_hi = b >> 32; 130 131 limb_t p0 = a_low * b_low; 132 limb_t p1 = a_low * b_hi; 133 limb_t p2 = a_hi * b_low; 134 limb_t p3 = a_hi * b_hi; 135 136 uint32_t cy = (uint32_t)(((p0 >> 32) + (uint32_t)p1 + (uint32_t)p2) >> 32); 137 138 *lo = p0 + (p1 << 32) + (p2 << 32); 139 *hi = p3 + (p1 >> 32) + (p2 >> 32) + cy; 140 } 141 #endif 142 143 /* add two limbs with carry in, return carry out */ 144 static ossl_inline limb_t _add_limb(limb_t *ret, limb_t a, limb_t b, limb_t carry) 145 { 146 limb_t carry1, carry2, t; 147 /* 148 * `c = a + b; if (c < a)` is idiomatic code that makes compilers 149 * use add with carry on assembly level 150 */ 151 152 *ret = a + carry; 153 if (*ret < a) 154 carry1 = 1; 155 else 156 carry1 = 0; 157 158 t = *ret; 159 *ret = t + b; 160 if (*ret < t) 161 carry2 = 1; 162 else 163 carry2 = 0; 164 165 return carry1 + carry2; 166 } 167 168 /* 169 * add two numbers of the same size, return overflow 170 * 171 * add a to b, place result in ret; all arrays need to be n limbs long 172 * return overflow from addition (0 or 1) 173 */ 174 static ossl_inline limb_t add(limb_t *ret, limb_t *a, limb_t *b, size_t n) 175 { 176 limb_t c = 0; 177 ossl_ssize_t i; 178 179 for(i = n - 1; i > -1; i--) 180 c = _add_limb(&ret[i], a[i], b[i], c); 181 182 return c; 183 } 184 185 /* 186 * return number of limbs necessary for temporary values 187 * when multiplying numbers n limbs large 188 */ 189 static ossl_inline size_t mul_limb_numb(size_t n) 190 { 191 return 2 * n * 2; 192 } 193 194 /* 195 * multiply two numbers of the same size 196 * 197 * multiply a by b, place result in ret; a and b need to be n limbs long 198 * ret needs to be 2*n limbs long, tmp needs to be mul_limb_numb(n) limbs 199 * long 200 */ 201 static void limb_mul(limb_t *ret, limb_t *a, limb_t *b, size_t n, limb_t *tmp) 202 { 203 limb_t *r_odd, *r_even; 204 size_t i, j, k; 205 206 r_odd = tmp; 207 r_even = &tmp[2 * n]; 208 209 memset(ret, 0, 2 * n * sizeof(limb_t)); 210 211 for (i = 0; i < n; i++) { 212 for (k = 0; k < i + n + 1; k++) { 213 r_even[k] = 0; 214 r_odd[k] = 0; 215 } 216 for (j = 0; j < n; j++) { 217 /* 218 * place results from even and odd limbs in separate arrays so that 219 * we don't have to calculate overflow every time we get individual 220 * limb multiplication result 221 */ 222 if (j % 2 == 0) 223 _mul_limb(&r_even[i + j], &r_even[i + j + 1], a[i], b[j]); 224 else 225 _mul_limb(&r_odd[i + j], &r_odd[i + j + 1], a[i], b[j]); 226 } 227 /* 228 * skip the least significant limbs when adding multiples of 229 * more significant limbs (they're zero anyway) 230 */ 231 add(ret, ret, r_even, n + i + 1); 232 add(ret, ret, r_odd, n + i + 1); 233 } 234 } 235 236 /* modifies the value in place by performing a right shift by one bit */ 237 static ossl_inline void rshift1(limb_t *val, size_t n) 238 { 239 limb_t shift_in = 0, shift_out = 0; 240 size_t i; 241 242 for (i = 0; i < n; i++) { 243 shift_out = val[i] & 1; 244 val[i] = shift_in << (LIMB_BIT_SIZE - 1) | (val[i] >> 1); 245 shift_in = shift_out; 246 } 247 } 248 249 /* extend the LSB of flag to all bits of limb */ 250 static ossl_inline limb_t mk_mask(limb_t flag) 251 { 252 flag |= flag << 1; 253 flag |= flag << 2; 254 flag |= flag << 4; 255 flag |= flag << 8; 256 flag |= flag << 16; 257 #if (LIMB_BYTE_SIZE == 8) 258 flag |= flag << 32; 259 #endif 260 return flag; 261 } 262 263 /* 264 * copy from either a or b to ret based on flag 265 * when flag == 0, then copies from b 266 * when flag == 1, then copies from a 267 */ 268 static ossl_inline void cselect(limb_t flag, limb_t *ret, limb_t *a, limb_t *b, size_t n) 269 { 270 /* 271 * would be more efficient with non volatile mask, but then gcc 272 * generates code with jumps 273 */ 274 volatile limb_t mask; 275 size_t i; 276 277 mask = mk_mask(flag); 278 for (i = 0; i < n; i++) { 279 #if (LIMB_BYTE_SIZE == 8) 280 ret[i] = constant_time_select_64(mask, a[i], b[i]); 281 #else 282 ret[i] = constant_time_select_32(mask, a[i], b[i]); 283 #endif 284 } 285 } 286 287 static limb_t _sub_limb(limb_t *ret, limb_t a, limb_t b, limb_t borrow) 288 { 289 limb_t borrow1, borrow2, t; 290 /* 291 * while it doesn't look constant-time, this is idiomatic code 292 * to tell compilers to use the carry bit from subtraction 293 */ 294 295 *ret = a - borrow; 296 if (*ret > a) 297 borrow1 = 1; 298 else 299 borrow1 = 0; 300 301 t = *ret; 302 *ret = t - b; 303 if (*ret > t) 304 borrow2 = 1; 305 else 306 borrow2 = 0; 307 308 return borrow1 + borrow2; 309 } 310 311 /* 312 * place the result of a - b into ret, return the borrow bit. 313 * All arrays need to be n limbs long 314 */ 315 static limb_t sub(limb_t *ret, limb_t *a, limb_t *b, size_t n) 316 { 317 limb_t borrow = 0; 318 ossl_ssize_t i; 319 320 for (i = n - 1; i > -1; i--) 321 borrow = _sub_limb(&ret[i], a[i], b[i], borrow); 322 323 return borrow; 324 } 325 326 /* return the number of limbs necessary to allocate for the mod() tmp operand */ 327 static ossl_inline size_t mod_limb_numb(size_t anum, size_t modnum) 328 { 329 return (anum + modnum) * 3; 330 } 331 332 /* 333 * calculate a % mod, place the result in ret 334 * size of a is defined by anum, size of ret and mod is modnum, 335 * size of tmp is returned by mod_limb_numb() 336 */ 337 static void mod(limb_t *ret, limb_t *a, size_t anum, limb_t *mod, 338 size_t modnum, limb_t *tmp) 339 { 340 limb_t *atmp, *modtmp, *rettmp; 341 limb_t res; 342 size_t i; 343 344 memset(tmp, 0, mod_limb_numb(anum, modnum) * LIMB_BYTE_SIZE); 345 346 atmp = tmp; 347 modtmp = &tmp[anum + modnum]; 348 rettmp = &tmp[(anum + modnum) * 2]; 349 350 for (i = modnum; i <modnum + anum; i++) 351 atmp[i] = a[i-modnum]; 352 353 for (i = 0; i < modnum; i++) 354 modtmp[i] = mod[i]; 355 356 for (i = 0; i < anum * LIMB_BIT_SIZE; i++) { 357 rshift1(modtmp, anum + modnum); 358 res = sub(rettmp, atmp, modtmp, anum+modnum); 359 cselect(res, atmp, atmp, rettmp, anum+modnum); 360 } 361 362 memcpy(ret, &atmp[anum], sizeof(limb_t) * modnum); 363 } 364 365 /* necessary size of tmp for a _mul_add_limb() call with provided anum */ 366 static ossl_inline size_t _mul_add_limb_numb(size_t anum) 367 { 368 return 2 * (anum + 1); 369 } 370 371 /* multiply a by m, add to ret, return carry */ 372 static limb_t _mul_add_limb(limb_t *ret, limb_t *a, size_t anum, 373 limb_t m, limb_t *tmp) 374 { 375 limb_t carry = 0; 376 limb_t *r_odd, *r_even; 377 size_t i; 378 379 memset(tmp, 0, sizeof(limb_t) * (anum + 1) * 2); 380 381 r_odd = tmp; 382 r_even = &tmp[anum + 1]; 383 384 for (i = 0; i < anum; i++) { 385 /* 386 * place the results from even and odd limbs in separate arrays 387 * so that we have to worry about carry just once 388 */ 389 if (i % 2 == 0) 390 _mul_limb(&r_even[i], &r_even[i + 1], a[i], m); 391 else 392 _mul_limb(&r_odd[i], &r_odd[i + 1], a[i], m); 393 } 394 /* assert: add() carry here will be equal zero */ 395 add(r_even, r_even, r_odd, anum + 1); 396 /* 397 * while here it will not overflow as the max value from multiplication 398 * is -2 while max overflow from addition is 1, so the max value of 399 * carry is -1 (i.e. max int) 400 */ 401 carry = add(ret, ret, &r_even[1], anum) + r_even[0]; 402 403 return carry; 404 } 405 406 static ossl_inline size_t mod_montgomery_limb_numb(size_t modnum) 407 { 408 return modnum * 2 + _mul_add_limb_numb(modnum); 409 } 410 411 /* 412 * calculate a % mod, place result in ret 413 * assumes that a is in Montgomery form with the R (Montgomery modulus) being 414 * smallest power of two big enough to fit mod and that's also a power 415 * of the count of number of bits in limb_t (B). 416 * For calculation, we also need n', such that mod * n' == -1 mod B. 417 * anum must be <= 2 * modnum 418 * ret needs to be modnum words long 419 * tmp needs to be mod_montgomery_limb_numb(modnum) limbs long 420 */ 421 static void mod_montgomery(limb_t *ret, limb_t *a, size_t anum, limb_t *mod, 422 size_t modnum, limb_t ni0, limb_t *tmp) 423 { 424 limb_t carry, v; 425 limb_t *res, *rp, *tmp2; 426 ossl_ssize_t i; 427 428 res = tmp; 429 /* 430 * for intermediate result we need an integer twice as long as modulus 431 * but keep the input in the least significant limbs 432 */ 433 memset(res, 0, sizeof(limb_t) * (modnum * 2)); 434 memcpy(&res[modnum * 2 - anum], a, sizeof(limb_t) * anum); 435 rp = &res[modnum]; 436 tmp2 = &res[modnum * 2]; 437 438 carry = 0; 439 440 /* add multiples of the modulus to the value until R divides it cleanly */ 441 for (i = modnum; i > 0; i--, rp--) { 442 v = _mul_add_limb(rp, mod, modnum, rp[modnum - 1] * ni0, tmp2); 443 v = v + carry + rp[-1]; 444 carry |= (v != rp[-1]); 445 carry &= (v <= rp[-1]); 446 rp[-1] = v; 447 } 448 449 /* perform the final reduction by mod... */ 450 carry -= sub(ret, rp, mod, modnum); 451 452 /* ...conditionally */ 453 cselect(carry, ret, rp, ret, modnum); 454 } 455 456 /* allocated buffer should be freed afterwards */ 457 static void BN_to_limb(const BIGNUM *bn, limb_t *buf, size_t limbs) 458 { 459 int i; 460 int real_limbs = (BN_num_bytes(bn) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE; 461 limb_t *ptr = buf + (limbs - real_limbs); 462 463 for (i = 0; i < real_limbs; i++) 464 ptr[i] = bn->d[real_limbs - i - 1]; 465 } 466 467 #if LIMB_BYTE_SIZE == 8 468 static ossl_inline uint64_t be64(uint64_t host) 469 { 470 const union { 471 long one; 472 char little; 473 } is_endian = { 1 }; 474 475 if (is_endian.little) { 476 uint64_t big = 0; 477 478 big |= (host & 0xff00000000000000) >> 56; 479 big |= (host & 0x00ff000000000000) >> 40; 480 big |= (host & 0x0000ff0000000000) >> 24; 481 big |= (host & 0x000000ff00000000) >> 8; 482 big |= (host & 0x00000000ff000000) << 8; 483 big |= (host & 0x0000000000ff0000) << 24; 484 big |= (host & 0x000000000000ff00) << 40; 485 big |= (host & 0x00000000000000ff) << 56; 486 return big; 487 } else { 488 return host; 489 } 490 } 491 492 #else 493 /* Not all platforms have htobe32(). */ 494 static ossl_inline uint32_t be32(uint32_t host) 495 { 496 const union { 497 long one; 498 char little; 499 } is_endian = { 1 }; 500 501 if (is_endian.little) { 502 uint32_t big = 0; 503 504 big |= (host & 0xff000000) >> 24; 505 big |= (host & 0x00ff0000) >> 8; 506 big |= (host & 0x0000ff00) << 8; 507 big |= (host & 0x000000ff) << 24; 508 return big; 509 } else { 510 return host; 511 } 512 } 513 #endif 514 515 /* 516 * We assume that intermediate, possible_arg2, blinding, and ctx are used 517 * similar to BN_BLINDING_invert_ex() arguments. 518 * to_mod is RSA modulus. 519 * buf and num is the serialization buffer and its length. 520 * 521 * Here we use classic/Montgomery multiplication and modulo. After the calculation finished 522 * we serialize the new structure instead of BIGNUMs taking endianness into account. 523 */ 524 int ossl_bn_rsa_do_unblind(const BIGNUM *intermediate, 525 const BN_BLINDING *blinding, 526 const BIGNUM *possible_arg2, 527 const BIGNUM *to_mod, BN_CTX *ctx, 528 unsigned char *buf, int num) 529 { 530 limb_t *l_im = NULL, *l_mul = NULL, *l_mod = NULL; 531 limb_t *l_ret = NULL, *l_tmp = NULL, l_buf; 532 size_t l_im_count = 0, l_mul_count = 0, l_size = 0, l_mod_count = 0; 533 size_t l_tmp_count = 0; 534 int ret = 0; 535 size_t i; 536 unsigned char *tmp; 537 const BIGNUM *arg1 = intermediate; 538 const BIGNUM *arg2 = (possible_arg2 == NULL) ? blinding->Ai : possible_arg2; 539 540 l_im_count = (BN_num_bytes(arg1) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE; 541 l_mul_count = (BN_num_bytes(arg2) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE; 542 l_mod_count = (BN_num_bytes(to_mod) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE; 543 544 l_size = l_im_count > l_mul_count ? l_im_count : l_mul_count; 545 l_im = OPENSSL_zalloc(l_size * LIMB_BYTE_SIZE); 546 l_mul = OPENSSL_zalloc(l_size * LIMB_BYTE_SIZE); 547 l_mod = OPENSSL_zalloc(l_mod_count * LIMB_BYTE_SIZE); 548 549 if ((l_im == NULL) || (l_mul == NULL) || (l_mod == NULL)) 550 goto err; 551 552 BN_to_limb(arg1, l_im, l_size); 553 BN_to_limb(arg2, l_mul, l_size); 554 BN_to_limb(to_mod, l_mod, l_mod_count); 555 556 l_ret = OPENSSL_malloc(2 * l_size * LIMB_BYTE_SIZE); 557 558 if (blinding->m_ctx != NULL) { 559 l_tmp_count = mul_limb_numb(l_size) > mod_montgomery_limb_numb(l_mod_count) ? 560 mul_limb_numb(l_size) : mod_montgomery_limb_numb(l_mod_count); 561 l_tmp = OPENSSL_malloc(l_tmp_count * LIMB_BYTE_SIZE); 562 } else { 563 l_tmp_count = mul_limb_numb(l_size) > mod_limb_numb(2 * l_size, l_mod_count) ? 564 mul_limb_numb(l_size) : mod_limb_numb(2 * l_size, l_mod_count); 565 l_tmp = OPENSSL_malloc(l_tmp_count * LIMB_BYTE_SIZE); 566 } 567 568 if ((l_ret == NULL) || (l_tmp == NULL)) 569 goto err; 570 571 if (blinding->m_ctx != NULL) { 572 limb_mul(l_ret, l_im, l_mul, l_size, l_tmp); 573 mod_montgomery(l_ret, l_ret, 2 * l_size, l_mod, l_mod_count, 574 blinding->m_ctx->n0[0], l_tmp); 575 } else { 576 limb_mul(l_ret, l_im, l_mul, l_size, l_tmp); 577 mod(l_ret, l_ret, 2 * l_size, l_mod, l_mod_count, l_tmp); 578 } 579 580 /* modulus size in bytes can be equal to num but after limbs conversion it becomes bigger */ 581 if (num < BN_num_bytes(to_mod)) { 582 BNerr(BN_F_OSSL_BN_RSA_DO_UNBLIND, ERR_R_PASSED_INVALID_ARGUMENT); 583 goto err; 584 } 585 586 memset(buf, 0, num); 587 tmp = buf + num - BN_num_bytes(to_mod); 588 for (i = 0; i < l_mod_count; i++) { 589 #if LIMB_BYTE_SIZE == 8 590 l_buf = be64(l_ret[i]); 591 #else 592 l_buf = be32(l_ret[i]); 593 #endif 594 if (i == 0) { 595 int delta = LIMB_BYTE_SIZE - ((l_mod_count * LIMB_BYTE_SIZE) - num); 596 597 memcpy(tmp, ((char *)&l_buf) + LIMB_BYTE_SIZE - delta, delta); 598 tmp += delta; 599 } else { 600 memcpy(tmp, &l_buf, LIMB_BYTE_SIZE); 601 tmp += LIMB_BYTE_SIZE; 602 } 603 } 604 ret = num; 605 606 err: 607 OPENSSL_free(l_im); 608 OPENSSL_free(l_mul); 609 OPENSSL_free(l_mod); 610 OPENSSL_free(l_tmp); 611 OPENSSL_free(l_ret); 612 613 return ret; 614 } 615