/*
  Name:     imath.c
  Purpose:  Arbitrary precision integer arithmetic routines.
  Author:   M. J. Fromberger

  Copyright (C) 2002-2007 Michael J. Fromberger, All Rights Reserved.

  Permission is hereby granted, free of charge, to any person obtaining a copy
  of this software and associated documentation files (the "Software"), to deal
  in the Software without restriction, including without limitation the rights
  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  copies of the Software, and to permit persons to whom the Software is
  furnished to do so, subject to the following conditions:

  The above copyright notice and this permission notice shall be included in
  all copies or substantial portions of the Software.

  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  SOFTWARE.
25 */ 26 27 #include "imath.h" 28 29 #include <assert.h> 30 #include <ctype.h> 31 #include <stdlib.h> 32 #include <string.h> 33 34 const mp_result MP_OK = 0; /* no error, all is well */ 35 const mp_result MP_FALSE = 0; /* boolean false */ 36 const mp_result MP_TRUE = -1; /* boolean true */ 37 const mp_result MP_MEMORY = -2; /* out of memory */ 38 const mp_result MP_RANGE = -3; /* argument out of range */ 39 const mp_result MP_UNDEF = -4; /* result undefined */ 40 const mp_result MP_TRUNC = -5; /* output truncated */ 41 const mp_result MP_BADARG = -6; /* invalid null argument */ 42 const mp_result MP_MINERR = -6; 43 44 const mp_sign MP_NEG = 1; /* value is strictly negative */ 45 const mp_sign MP_ZPOS = 0; /* value is non-negative */ 46 47 static const char *s_unknown_err = "unknown result code"; 48 static const char *s_error_msg[] = {"error code 0", "boolean true", 49 "out of memory", "argument out of range", 50 "result undefined", "output truncated", 51 "invalid argument", NULL}; 52 53 /* The ith entry of this table gives the value of log_i(2). 54 55 An integer value n requires ceil(log_i(n)) digits to be represented 56 in base i. Since it is easy to compute lg(n), by counting bits, we 57 can compute log_i(n) = lg(n) * log_i(2). 58 59 The use of this table eliminates a dependency upon linkage against 60 the standard math libraries. 61 62 If MP_MAX_RADIX is increased, this table should be expanded too. 
63 */ 64 static const double s_log2[] = { 65 0.000000000, 0.000000000, 1.000000000, 0.630929754, /* (D)(D) 2 3 */ 66 0.500000000, 0.430676558, 0.386852807, 0.356207187, /* 4 5 6 7 */ 67 0.333333333, 0.315464877, 0.301029996, 0.289064826, /* 8 9 10 11 */ 68 0.278942946, 0.270238154, 0.262649535, 0.255958025, /* 12 13 14 15 */ 69 0.250000000, 0.244650542, 0.239812467, 0.235408913, /* 16 17 18 19 */ 70 0.231378213, 0.227670249, 0.224243824, 0.221064729, /* 20 21 22 23 */ 71 0.218104292, 0.215338279, 0.212746054, 0.210309918, /* 24 25 26 27 */ 72 0.208014598, 0.205846832, 0.203795047, 0.201849087, /* 28 29 30 31 */ 73 0.200000000, 0.198239863, 0.196561632, 0.194959022, /* 32 33 34 35 */ 74 0.193426404, /* 36 */ 75 }; 76 77 /* Return the number of digits needed to represent a static value */ 78 #define MP_VALUE_DIGITS(V) \ 79 ((sizeof(V) + (sizeof(mp_digit) - 1)) / sizeof(mp_digit)) 80 81 /* Round precision P to nearest word boundary */ 82 static inline mp_size s_round_prec(mp_size P) { return 2 * ((P + 1) / 2); } 83 84 /* Set array P of S digits to zero */ 85 static inline void ZERO(mp_digit *P, mp_size S) { 86 mp_size i__ = S * sizeof(mp_digit); 87 mp_digit *p__ = P; 88 memset(p__, 0, i__); 89 } 90 91 /* Copy S digits from array P to array Q */ 92 static inline void COPY(mp_digit *P, mp_digit *Q, mp_size S) { 93 mp_size i__ = S * sizeof(mp_digit); 94 mp_digit *p__ = P; 95 mp_digit *q__ = Q; 96 memcpy(q__, p__, i__); 97 } 98 99 /* Reverse N elements of unsigned char in A. */ 100 static inline void REV(unsigned char *A, int N) { 101 unsigned char *u_ = A; 102 unsigned char *v_ = u_ + N - 1; 103 while (u_ < v_) { 104 unsigned char xch = *u_; 105 *u_++ = *v_; 106 *v_-- = xch; 107 } 108 } 109 110 /* Strip leading zeroes from z_ in-place. */ 111 static inline void CLAMP(mp_int z_) { 112 mp_size uz_ = MP_USED(z_); 113 mp_digit *dz_ = MP_DIGITS(z_) + uz_ - 1; 114 while (uz_ > 1 && (*dz_-- == 0)) --uz_; 115 z_->used = uz_; 116 } 117 118 /* Select min/max. 
*/ 119 static inline int MIN(int A, int B) { return (B < A ? B : A); } 120 static inline mp_size MAX(mp_size A, mp_size B) { return (B > A ? B : A); } 121 122 /* Exchange lvalues A and B of type T, e.g. 123 SWAP(int, x, y) where x and y are variables of type int. */ 124 #define SWAP(T, A, B) \ 125 do { \ 126 T t_ = (A); \ 127 A = (B); \ 128 B = t_; \ 129 } while (0) 130 131 /* Declare a block of N temporary mpz_t values. 132 These values are initialized to zero. 133 You must add CLEANUP_TEMP() at the end of the function. 134 Use TEMP(i) to access a pointer to the ith value. 135 */ 136 #define DECLARE_TEMP(N) \ 137 struct { \ 138 mpz_t value[(N)]; \ 139 int len; \ 140 mp_result err; \ 141 } temp_ = { \ 142 .len = (N), \ 143 .err = MP_OK, \ 144 }; \ 145 do { \ 146 for (int i = 0; i < temp_.len; i++) { \ 147 mp_int_init(TEMP(i)); \ 148 } \ 149 } while (0) 150 151 /* Clear all allocated temp values. */ 152 #define CLEANUP_TEMP() \ 153 CLEANUP: \ 154 do { \ 155 for (int i = 0; i < temp_.len; i++) { \ 156 mp_int_clear(TEMP(i)); \ 157 } \ 158 if (temp_.err != MP_OK) { \ 159 return temp_.err; \ 160 } \ 161 } while (0) 162 163 /* A pointer to the kth temp value. */ 164 #define TEMP(K) (temp_.value + (K)) 165 166 /* Evaluate E, an expression of type mp_result expected to return MP_OK. If 167 the value is not MP_OK, the error is cached and control resumes at the 168 cleanup handler, which returns it. 169 */ 170 #define REQUIRE(E) \ 171 do { \ 172 temp_.err = (E); \ 173 if (temp_.err != MP_OK) goto CLEANUP; \ 174 } while (0) 175 176 /* Compare value to zero. */ 177 static inline int CMPZ(mp_int Z) { 178 if (Z->used == 1 && Z->digits[0] == 0) return 0; 179 return (Z->sign == MP_NEG) ? -1 : 1; 180 } 181 182 static inline mp_word UPPER_HALF(mp_word W) { return (W >> MP_DIGIT_BIT); } 183 static inline mp_digit LOWER_HALF(mp_word W) { return (mp_digit)(W); } 184 185 /* Report whether the highest-order bit of W is 1. 
*/ 186 static inline bool HIGH_BIT_SET(mp_word W) { 187 return (W >> (MP_WORD_BIT - 1)) != 0; 188 } 189 190 /* Report whether adding W + V will carry out. */ 191 static inline bool ADD_WILL_OVERFLOW(mp_word W, mp_word V) { 192 return ((MP_WORD_MAX - V) < W); 193 } 194 195 /* Default number of digits allocated to a new mp_int */ 196 static mp_size default_precision = 8; 197 198 void mp_int_default_precision(mp_size size) { 199 assert(size > 0); 200 default_precision = size; 201 } 202 203 /* Minimum number of digits to invoke recursive multiply */ 204 static mp_size multiply_threshold = 32; 205 206 void mp_int_multiply_threshold(mp_size thresh) { 207 assert(thresh >= sizeof(mp_word)); 208 multiply_threshold = thresh; 209 } 210 211 /* Allocate a buffer of (at least) num digits, or return 212 NULL if that couldn't be done. */ 213 static mp_digit *s_alloc(mp_size num); 214 215 /* Release a buffer of digits allocated by s_alloc(). */ 216 static void s_free(void *ptr); 217 218 /* Insure that z has at least min digits allocated, resizing if 219 necessary. Returns true if successful, false if out of memory. */ 220 static bool s_pad(mp_int z, mp_size min); 221 222 /* Ensure Z has at least N digits allocated. */ 223 static inline mp_result GROW(mp_int Z, mp_size N) { 224 return s_pad(Z, N) ? 
MP_OK : MP_MEMORY; 225 } 226 227 /* Fill in a "fake" mp_int on the stack with a given value */ 228 static void s_fake(mp_int z, mp_small value, mp_digit vbuf[]); 229 static void s_ufake(mp_int z, mp_usmall value, mp_digit vbuf[]); 230 231 /* Compare two runs of digits of given length, returns <0, 0, >0 */ 232 static int s_cdig(mp_digit *da, mp_digit *db, mp_size len); 233 234 /* Pack the unsigned digits of v into array t */ 235 static int s_uvpack(mp_usmall v, mp_digit t[]); 236 237 /* Compare magnitudes of a and b, returns <0, 0, >0 */ 238 static int s_ucmp(mp_int a, mp_int b); 239 240 /* Compare magnitudes of a and v, returns <0, 0, >0 */ 241 static int s_vcmp(mp_int a, mp_small v); 242 static int s_uvcmp(mp_int a, mp_usmall uv); 243 244 /* Unsigned magnitude addition; assumes dc is big enough. 245 Carry out is returned (no memory allocated). */ 246 static mp_digit s_uadd(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a, 247 mp_size size_b); 248 249 /* Unsigned magnitude subtraction. Assumes dc is big enough. */ 250 static void s_usub(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a, 251 mp_size size_b); 252 253 /* Unsigned recursive multiplication. Assumes dc is big enough. */ 254 static int s_kmul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a, 255 mp_size size_b); 256 257 /* Unsigned magnitude multiplication. Assumes dc is big enough. */ 258 static void s_umul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a, 259 mp_size size_b); 260 261 /* Unsigned recursive squaring. Assumes dc is big enough. */ 262 static int s_ksqr(mp_digit *da, mp_digit *dc, mp_size size_a); 263 264 /* Unsigned magnitude squaring. Assumes dc is big enough. */ 265 static void s_usqr(mp_digit *da, mp_digit *dc, mp_size size_a); 266 267 /* Single digit addition. Assumes a is big enough. */ 268 static void s_dadd(mp_int a, mp_digit b); 269 270 /* Single digit multiplication. Assumes a is big enough. 
*/ 271 static void s_dmul(mp_int a, mp_digit b); 272 273 /* Single digit multiplication on buffers; assumes dc is big enough. */ 274 static void s_dbmul(mp_digit *da, mp_digit b, mp_digit *dc, mp_size size_a); 275 276 /* Single digit division. Replaces a with the quotient, 277 returns the remainder. */ 278 static mp_digit s_ddiv(mp_int a, mp_digit b); 279 280 /* Quick division by a power of 2, replaces z (no allocation) */ 281 static void s_qdiv(mp_int z, mp_size p2); 282 283 /* Quick remainder by a power of 2, replaces z (no allocation) */ 284 static void s_qmod(mp_int z, mp_size p2); 285 286 /* Quick multiplication by a power of 2, replaces z. 287 Allocates if necessary; returns false in case this fails. */ 288 static int s_qmul(mp_int z, mp_size p2); 289 290 /* Quick subtraction from a power of 2, replaces z. 291 Allocates if necessary; returns false in case this fails. */ 292 static int s_qsub(mp_int z, mp_size p2); 293 294 /* Return maximum k such that 2^k divides z. */ 295 static int s_dp2k(mp_int z); 296 297 /* Return k >= 0 such that z = 2^k, or -1 if there is no such k. */ 298 static int s_isp2(mp_int z); 299 300 /* Set z to 2^k. May allocate; returns false in case this fails. */ 301 static int s_2expt(mp_int z, mp_small k); 302 303 /* Normalize a and b for division, returns normalization constant */ 304 static int s_norm(mp_int a, mp_int b); 305 306 /* Compute constant mu for Barrett reduction, given modulus m, result 307 replaces z, m is untouched. */ 308 static mp_result s_brmu(mp_int z, mp_int m); 309 310 /* Reduce a modulo m, using Barrett's algorithm. */ 311 static int s_reduce(mp_int x, mp_int m, mp_int mu, mp_int q1, mp_int q2); 312 313 /* Modular exponentiation, using Barrett reduction */ 314 static mp_result s_embar(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c); 315 316 /* Unsigned magnitude division. Assumes |a| > |b|. Allocates temporaries; 317 overwrites a with quotient, b with remainder. 
*/ 318 static mp_result s_udiv_knuth(mp_int a, mp_int b); 319 320 /* Compute the number of digits in radix r required to represent the given 321 value. Does not account for sign flags, terminators, etc. */ 322 static int s_outlen(mp_int z, mp_size r); 323 324 /* Guess how many digits of precision will be needed to represent a radix r 325 value of the specified number of digits. Returns a value guaranteed to be 326 no smaller than the actual number required. */ 327 static mp_size s_inlen(int len, mp_size r); 328 329 /* Convert a character to a digit value in radix r, or 330 -1 if out of range */ 331 static int s_ch2val(char c, int r); 332 333 /* Convert a digit value to a character */ 334 static char s_val2ch(int v, int caps); 335 336 /* Take 2's complement of a buffer in place */ 337 static void s_2comp(unsigned char *buf, int len); 338 339 /* Convert a value to binary, ignoring sign. On input, *limpos is the bound on 340 how many bytes should be written to buf; on output, *limpos is set to the 341 number of bytes actually written. */ 342 static mp_result s_tobin(mp_int z, unsigned char *buf, int *limpos, int pad); 343 344 /* Multiply X by Y into Z, ignoring signs. Requires that Z have enough storage 345 preallocated to hold the result. */ 346 static inline void UMUL(mp_int X, mp_int Y, mp_int Z) { 347 mp_size ua_ = MP_USED(X); 348 mp_size ub_ = MP_USED(Y); 349 mp_size o_ = ua_ + ub_; 350 ZERO(MP_DIGITS(Z), o_); 351 (void)s_kmul(MP_DIGITS(X), MP_DIGITS(Y), MP_DIGITS(Z), ua_, ub_); 352 Z->used = o_; 353 CLAMP(Z); 354 } 355 356 /* Square X into Z. Requires that Z have enough storage to hold the result. 
*/ 357 static inline void USQR(mp_int X, mp_int Z) { 358 mp_size ua_ = MP_USED(X); 359 mp_size o_ = ua_ + ua_; 360 ZERO(MP_DIGITS(Z), o_); 361 (void)s_ksqr(MP_DIGITS(X), MP_DIGITS(Z), ua_); 362 Z->used = o_; 363 CLAMP(Z); 364 } 365 366 mp_result mp_int_init(mp_int z) { 367 if (z == NULL) return MP_BADARG; 368 369 z->single = 0; 370 z->digits = &(z->single); 371 z->alloc = 1; 372 z->used = 1; 373 z->sign = MP_ZPOS; 374 375 return MP_OK; 376 } 377 378 mp_int mp_int_alloc(void) { 379 mp_int out = malloc(sizeof(mpz_t)); 380 381 if (out != NULL) mp_int_init(out); 382 383 return out; 384 } 385 386 mp_result mp_int_init_size(mp_int z, mp_size prec) { 387 assert(z != NULL); 388 389 if (prec == 0) { 390 prec = default_precision; 391 } else if (prec == 1) { 392 return mp_int_init(z); 393 } else { 394 prec = s_round_prec(prec); 395 } 396 397 z->digits = s_alloc(prec); 398 if (MP_DIGITS(z) == NULL) return MP_MEMORY; 399 400 z->digits[0] = 0; 401 z->used = 1; 402 z->alloc = prec; 403 z->sign = MP_ZPOS; 404 405 return MP_OK; 406 } 407 408 mp_result mp_int_init_copy(mp_int z, mp_int old) { 409 assert(z != NULL && old != NULL); 410 411 mp_size uold = MP_USED(old); 412 if (uold == 1) { 413 mp_int_init(z); 414 } else { 415 mp_size target = MAX(uold, default_precision); 416 mp_result res = mp_int_init_size(z, target); 417 if (res != MP_OK) return res; 418 } 419 420 z->used = uold; 421 z->sign = old->sign; 422 COPY(MP_DIGITS(old), MP_DIGITS(z), uold); 423 424 return MP_OK; 425 } 426 427 mp_result mp_int_init_value(mp_int z, mp_small value) { 428 mpz_t vtmp; 429 mp_digit vbuf[MP_VALUE_DIGITS(value)]; 430 431 s_fake(&vtmp, value, vbuf); 432 return mp_int_init_copy(z, &vtmp); 433 } 434 435 mp_result mp_int_init_uvalue(mp_int z, mp_usmall uvalue) { 436 mpz_t vtmp; 437 mp_digit vbuf[MP_VALUE_DIGITS(uvalue)]; 438 439 s_ufake(&vtmp, uvalue, vbuf); 440 return mp_int_init_copy(z, &vtmp); 441 } 442 443 mp_result mp_int_set_value(mp_int z, mp_small value) { 444 mpz_t vtmp; 445 mp_digit 
vbuf[MP_VALUE_DIGITS(value)]; 446 447 s_fake(&vtmp, value, vbuf); 448 return mp_int_copy(&vtmp, z); 449 } 450 451 mp_result mp_int_set_uvalue(mp_int z, mp_usmall uvalue) { 452 mpz_t vtmp; 453 mp_digit vbuf[MP_VALUE_DIGITS(uvalue)]; 454 455 s_ufake(&vtmp, uvalue, vbuf); 456 return mp_int_copy(&vtmp, z); 457 } 458 459 void mp_int_clear(mp_int z) { 460 if (z == NULL) return; 461 462 if (MP_DIGITS(z) != NULL) { 463 if (MP_DIGITS(z) != &(z->single)) s_free(MP_DIGITS(z)); 464 465 z->digits = NULL; 466 } 467 } 468 469 void mp_int_free(mp_int z) { 470 assert(z != NULL); 471 472 mp_int_clear(z); 473 free(z); /* note: NOT s_free() */ 474 } 475 476 mp_result mp_int_copy(mp_int a, mp_int c) { 477 assert(a != NULL && c != NULL); 478 479 if (a != c) { 480 mp_size ua = MP_USED(a); 481 mp_digit *da, *dc; 482 483 if (!s_pad(c, ua)) return MP_MEMORY; 484 485 da = MP_DIGITS(a); 486 dc = MP_DIGITS(c); 487 COPY(da, dc, ua); 488 489 c->used = ua; 490 c->sign = a->sign; 491 } 492 493 return MP_OK; 494 } 495 496 void mp_int_swap(mp_int a, mp_int c) { 497 if (a != c) { 498 mpz_t tmp = *a; 499 500 *a = *c; 501 *c = tmp; 502 503 if (MP_DIGITS(a) == &(c->single)) a->digits = &(a->single); 504 if (MP_DIGITS(c) == &(a->single)) c->digits = &(c->single); 505 } 506 } 507 508 void mp_int_zero(mp_int z) { 509 assert(z != NULL); 510 511 z->digits[0] = 0; 512 z->used = 1; 513 z->sign = MP_ZPOS; 514 } 515 516 mp_result mp_int_abs(mp_int a, mp_int c) { 517 assert(a != NULL && c != NULL); 518 519 mp_result res; 520 if ((res = mp_int_copy(a, c)) != MP_OK) return res; 521 522 c->sign = MP_ZPOS; 523 return MP_OK; 524 } 525 526 mp_result mp_int_neg(mp_int a, mp_int c) { 527 assert(a != NULL && c != NULL); 528 529 mp_result res; 530 if ((res = mp_int_copy(a, c)) != MP_OK) return res; 531 532 if (CMPZ(c) != 0) c->sign = 1 - MP_SIGN(a); 533 534 return MP_OK; 535 } 536 537 mp_result mp_int_add(mp_int a, mp_int b, mp_int c) { 538 assert(a != NULL && b != NULL && c != NULL); 539 540 mp_size ua = MP_USED(a); 541 
mp_size ub = MP_USED(b); 542 mp_size max = MAX(ua, ub); 543 544 if (MP_SIGN(a) == MP_SIGN(b)) { 545 /* Same sign -- add magnitudes, preserve sign of addends */ 546 if (!s_pad(c, max)) return MP_MEMORY; 547 548 mp_digit carry = s_uadd(MP_DIGITS(a), MP_DIGITS(b), MP_DIGITS(c), ua, ub); 549 mp_size uc = max; 550 551 if (carry) { 552 if (!s_pad(c, max + 1)) return MP_MEMORY; 553 554 c->digits[max] = carry; 555 ++uc; 556 } 557 558 c->used = uc; 559 c->sign = a->sign; 560 561 } else { 562 /* Different signs -- subtract magnitudes, preserve sign of greater */ 563 int cmp = s_ucmp(a, b); /* magnitude comparison, sign ignored */ 564 565 /* Set x to max(a, b), y to min(a, b) to simplify later code. 566 A special case yields zero for equal magnitudes. 567 */ 568 mp_int x, y; 569 if (cmp == 0) { 570 mp_int_zero(c); 571 return MP_OK; 572 } else if (cmp < 0) { 573 x = b; 574 y = a; 575 } else { 576 x = a; 577 y = b; 578 } 579 580 if (!s_pad(c, MP_USED(x))) return MP_MEMORY; 581 582 /* Subtract smaller from larger */ 583 s_usub(MP_DIGITS(x), MP_DIGITS(y), MP_DIGITS(c), MP_USED(x), MP_USED(y)); 584 c->used = x->used; 585 CLAMP(c); 586 587 /* Give result the sign of the larger */ 588 c->sign = x->sign; 589 } 590 591 return MP_OK; 592 } 593 594 mp_result mp_int_add_value(mp_int a, mp_small value, mp_int c) { 595 mpz_t vtmp; 596 mp_digit vbuf[MP_VALUE_DIGITS(value)]; 597 598 s_fake(&vtmp, value, vbuf); 599 600 return mp_int_add(a, &vtmp, c); 601 } 602 603 mp_result mp_int_sub(mp_int a, mp_int b, mp_int c) { 604 assert(a != NULL && b != NULL && c != NULL); 605 606 mp_size ua = MP_USED(a); 607 mp_size ub = MP_USED(b); 608 mp_size max = MAX(ua, ub); 609 610 if (MP_SIGN(a) != MP_SIGN(b)) { 611 /* Different signs -- add magnitudes and keep sign of a */ 612 if (!s_pad(c, max)) return MP_MEMORY; 613 614 mp_digit carry = s_uadd(MP_DIGITS(a), MP_DIGITS(b), MP_DIGITS(c), ua, ub); 615 mp_size uc = max; 616 617 if (carry) { 618 if (!s_pad(c, max + 1)) return MP_MEMORY; 619 620 c->digits[max] = 
carry; 621 ++uc; 622 } 623 624 c->used = uc; 625 c->sign = a->sign; 626 627 } else { 628 /* Same signs -- subtract magnitudes */ 629 if (!s_pad(c, max)) return MP_MEMORY; 630 mp_int x, y; 631 mp_sign osign; 632 633 int cmp = s_ucmp(a, b); 634 if (cmp >= 0) { 635 x = a; 636 y = b; 637 osign = MP_ZPOS; 638 } else { 639 x = b; 640 y = a; 641 osign = MP_NEG; 642 } 643 644 if (MP_SIGN(a) == MP_NEG && cmp != 0) osign = 1 - osign; 645 646 s_usub(MP_DIGITS(x), MP_DIGITS(y), MP_DIGITS(c), MP_USED(x), MP_USED(y)); 647 c->used = x->used; 648 CLAMP(c); 649 650 c->sign = osign; 651 } 652 653 return MP_OK; 654 } 655 656 mp_result mp_int_sub_value(mp_int a, mp_small value, mp_int c) { 657 mpz_t vtmp; 658 mp_digit vbuf[MP_VALUE_DIGITS(value)]; 659 660 s_fake(&vtmp, value, vbuf); 661 662 return mp_int_sub(a, &vtmp, c); 663 } 664 665 mp_result mp_int_mul(mp_int a, mp_int b, mp_int c) { 666 assert(a != NULL && b != NULL && c != NULL); 667 668 /* If either input is zero, we can shortcut multiplication */ 669 if (mp_int_compare_zero(a) == 0 || mp_int_compare_zero(b) == 0) { 670 mp_int_zero(c); 671 return MP_OK; 672 } 673 674 /* Output is positive if inputs have same sign, otherwise negative */ 675 mp_sign osign = (MP_SIGN(a) == MP_SIGN(b)) ? MP_ZPOS : MP_NEG; 676 677 /* If the output is not identical to any of the inputs, we'll write the 678 results directly; otherwise, allocate a temporary space. 
*/ 679 mp_size ua = MP_USED(a); 680 mp_size ub = MP_USED(b); 681 mp_size osize = MAX(ua, ub); 682 osize = 4 * ((osize + 1) / 2); 683 684 mp_digit *out; 685 mp_size p = 0; 686 if (c == a || c == b) { 687 p = MAX(s_round_prec(osize), default_precision); 688 689 if ((out = s_alloc(p)) == NULL) return MP_MEMORY; 690 } else { 691 if (!s_pad(c, osize)) return MP_MEMORY; 692 693 out = MP_DIGITS(c); 694 } 695 ZERO(out, osize); 696 697 if (!s_kmul(MP_DIGITS(a), MP_DIGITS(b), out, ua, ub)) return MP_MEMORY; 698 699 /* If we allocated a new buffer, get rid of whatever memory c was already 700 using, and fix up its fields to reflect that. 701 */ 702 if (out != MP_DIGITS(c)) { 703 if ((void *)MP_DIGITS(c) != (void *)c) s_free(MP_DIGITS(c)); 704 c->digits = out; 705 c->alloc = p; 706 } 707 708 c->used = osize; /* might not be true, but we'll fix it ... */ 709 CLAMP(c); /* ... right here */ 710 c->sign = osign; 711 712 return MP_OK; 713 } 714 715 mp_result mp_int_mul_value(mp_int a, mp_small value, mp_int c) { 716 mpz_t vtmp; 717 mp_digit vbuf[MP_VALUE_DIGITS(value)]; 718 719 s_fake(&vtmp, value, vbuf); 720 721 return mp_int_mul(a, &vtmp, c); 722 } 723 724 mp_result mp_int_mul_pow2(mp_int a, mp_small p2, mp_int c) { 725 assert(a != NULL && c != NULL && p2 >= 0); 726 727 mp_result res = mp_int_copy(a, c); 728 if (res != MP_OK) return res; 729 730 if (s_qmul(c, (mp_size)p2)) { 731 return MP_OK; 732 } else { 733 return MP_MEMORY; 734 } 735 } 736 737 mp_result mp_int_sqr(mp_int a, mp_int c) { 738 assert(a != NULL && c != NULL); 739 740 /* Get a temporary buffer big enough to hold the result */ 741 mp_size osize = (mp_size)4 * ((MP_USED(a) + 1) / 2); 742 mp_size p = 0; 743 mp_digit *out; 744 if (a == c) { 745 p = s_round_prec(osize); 746 p = MAX(p, default_precision); 747 748 if ((out = s_alloc(p)) == NULL) return MP_MEMORY; 749 } else { 750 if (!s_pad(c, osize)) return MP_MEMORY; 751 752 out = MP_DIGITS(c); 753 } 754 ZERO(out, osize); 755 756 s_ksqr(MP_DIGITS(a), out, MP_USED(a)); 
757 758 /* Get rid of whatever memory c was already using, and fix up its fields to 759 reflect the new digit array it's using 760 */ 761 if (out != MP_DIGITS(c)) { 762 if ((void *)MP_DIGITS(c) != (void *)c) s_free(MP_DIGITS(c)); 763 c->digits = out; 764 c->alloc = p; 765 } 766 767 c->used = osize; /* might not be true, but we'll fix it ... */ 768 CLAMP(c); /* ... right here */ 769 c->sign = MP_ZPOS; 770 771 return MP_OK; 772 } 773 774 mp_result mp_int_div(mp_int a, mp_int b, mp_int q, mp_int r) { 775 assert(a != NULL && b != NULL && q != r); 776 777 int cmp; 778 mp_result res = MP_OK; 779 mp_int qout, rout; 780 mp_sign sa = MP_SIGN(a); 781 mp_sign sb = MP_SIGN(b); 782 if (CMPZ(b) == 0) { 783 return MP_UNDEF; 784 } else if ((cmp = s_ucmp(a, b)) < 0) { 785 /* If |a| < |b|, no division is required: 786 q = 0, r = a 787 */ 788 if (r && (res = mp_int_copy(a, r)) != MP_OK) return res; 789 790 if (q) mp_int_zero(q); 791 792 return MP_OK; 793 } else if (cmp == 0) { 794 /* If |a| = |b|, no division is required: 795 q = 1 or -1, r = 0 796 */ 797 if (r) mp_int_zero(r); 798 799 if (q) { 800 mp_int_zero(q); 801 q->digits[0] = 1; 802 803 if (sa != sb) q->sign = MP_NEG; 804 } 805 806 return MP_OK; 807 } 808 809 /* When |a| > |b|, real division is required. We need someplace to store 810 quotient and remainder, but q and r are allowed to be NULL or to overlap 811 with the inputs. 
812 */ 813 DECLARE_TEMP(2); 814 int lg; 815 if ((lg = s_isp2(b)) < 0) { 816 if (q && b != q) { 817 REQUIRE(mp_int_copy(a, q)); 818 qout = q; 819 } else { 820 REQUIRE(mp_int_copy(a, TEMP(0))); 821 qout = TEMP(0); 822 } 823 824 if (r && a != r) { 825 REQUIRE(mp_int_copy(b, r)); 826 rout = r; 827 } else { 828 REQUIRE(mp_int_copy(b, TEMP(1))); 829 rout = TEMP(1); 830 } 831 832 REQUIRE(s_udiv_knuth(qout, rout)); 833 } else { 834 if (q) REQUIRE(mp_int_copy(a, q)); 835 if (r) REQUIRE(mp_int_copy(a, r)); 836 837 if (q) s_qdiv(q, (mp_size)lg); 838 qout = q; 839 if (r) s_qmod(r, (mp_size)lg); 840 rout = r; 841 } 842 843 /* Recompute signs for output */ 844 if (rout) { 845 rout->sign = sa; 846 if (CMPZ(rout) == 0) rout->sign = MP_ZPOS; 847 } 848 if (qout) { 849 qout->sign = (sa == sb) ? MP_ZPOS : MP_NEG; 850 if (CMPZ(qout) == 0) qout->sign = MP_ZPOS; 851 } 852 853 if (q) REQUIRE(mp_int_copy(qout, q)); 854 if (r) REQUIRE(mp_int_copy(rout, r)); 855 CLEANUP_TEMP(); 856 return res; 857 } 858 859 mp_result mp_int_mod(mp_int a, mp_int m, mp_int c) { 860 DECLARE_TEMP(1); 861 mp_int out = (m == c) ? 
TEMP(0) : c; 862 REQUIRE(mp_int_div(a, m, NULL, out)); 863 if (CMPZ(out) < 0) { 864 REQUIRE(mp_int_add(out, m, c)); 865 } else { 866 REQUIRE(mp_int_copy(out, c)); 867 } 868 CLEANUP_TEMP(); 869 return MP_OK; 870 } 871 872 mp_result mp_int_div_value(mp_int a, mp_small value, mp_int q, mp_small *r) { 873 mpz_t vtmp; 874 mp_digit vbuf[MP_VALUE_DIGITS(value)]; 875 s_fake(&vtmp, value, vbuf); 876 877 DECLARE_TEMP(1); 878 REQUIRE(mp_int_div(a, &vtmp, q, TEMP(0))); 879 880 if (r) (void)mp_int_to_int(TEMP(0), r); /* can't fail */ 881 882 CLEANUP_TEMP(); 883 return MP_OK; 884 } 885 886 mp_result mp_int_div_pow2(mp_int a, mp_small p2, mp_int q, mp_int r) { 887 assert(a != NULL && p2 >= 0 && q != r); 888 889 mp_result res = MP_OK; 890 if (q != NULL && (res = mp_int_copy(a, q)) == MP_OK) { 891 s_qdiv(q, (mp_size)p2); 892 } 893 894 if (res == MP_OK && r != NULL && (res = mp_int_copy(a, r)) == MP_OK) { 895 s_qmod(r, (mp_size)p2); 896 } 897 898 return res; 899 } 900 901 mp_result mp_int_expt(mp_int a, mp_small b, mp_int c) { 902 assert(c != NULL); 903 if (b < 0) return MP_RANGE; 904 905 DECLARE_TEMP(1); 906 REQUIRE(mp_int_copy(a, TEMP(0))); 907 908 (void)mp_int_set_value(c, 1); 909 unsigned int v = labs(b); 910 while (v != 0) { 911 if (v & 1) { 912 REQUIRE(mp_int_mul(c, TEMP(0), c)); 913 } 914 915 v >>= 1; 916 if (v == 0) break; 917 918 REQUIRE(mp_int_sqr(TEMP(0), TEMP(0))); 919 } 920 921 CLEANUP_TEMP(); 922 return MP_OK; 923 } 924 925 mp_result mp_int_expt_value(mp_small a, mp_small b, mp_int c) { 926 assert(c != NULL); 927 if (b < 0) return MP_RANGE; 928 929 DECLARE_TEMP(1); 930 REQUIRE(mp_int_set_value(TEMP(0), a)); 931 932 (void)mp_int_set_value(c, 1); 933 unsigned int v = labs(b); 934 while (v != 0) { 935 if (v & 1) { 936 REQUIRE(mp_int_mul(c, TEMP(0), c)); 937 } 938 939 v >>= 1; 940 if (v == 0) break; 941 942 REQUIRE(mp_int_sqr(TEMP(0), TEMP(0))); 943 } 944 945 CLEANUP_TEMP(); 946 return MP_OK; 947 } 948 949 mp_result mp_int_expt_full(mp_int a, mp_int b, mp_int c) { 950 
assert(a != NULL && b != NULL && c != NULL); 951 if (MP_SIGN(b) == MP_NEG) return MP_RANGE; 952 953 DECLARE_TEMP(1); 954 REQUIRE(mp_int_copy(a, TEMP(0))); 955 956 (void)mp_int_set_value(c, 1); 957 for (unsigned ix = 0; ix < MP_USED(b); ++ix) { 958 mp_digit d = b->digits[ix]; 959 960 for (unsigned jx = 0; jx < MP_DIGIT_BIT; ++jx) { 961 if (d & 1) { 962 REQUIRE(mp_int_mul(c, TEMP(0), c)); 963 } 964 965 d >>= 1; 966 if (d == 0 && ix + 1 == MP_USED(b)) break; 967 REQUIRE(mp_int_sqr(TEMP(0), TEMP(0))); 968 } 969 } 970 971 CLEANUP_TEMP(); 972 return MP_OK; 973 } 974 975 int mp_int_compare(mp_int a, mp_int b) { 976 assert(a != NULL && b != NULL); 977 978 mp_sign sa = MP_SIGN(a); 979 if (sa == MP_SIGN(b)) { 980 int cmp = s_ucmp(a, b); 981 982 /* If they're both zero or positive, the normal comparison applies; if both 983 negative, the sense is reversed. */ 984 if (sa == MP_ZPOS) { 985 return cmp; 986 } else { 987 return -cmp; 988 } 989 } else if (sa == MP_ZPOS) { 990 return 1; 991 } else { 992 return -1; 993 } 994 } 995 996 int mp_int_compare_unsigned(mp_int a, mp_int b) { 997 assert(a != NULL && b != NULL); 998 999 return s_ucmp(a, b); 1000 } 1001 1002 int mp_int_compare_zero(mp_int z) { 1003 assert(z != NULL); 1004 1005 if (MP_USED(z) == 1 && z->digits[0] == 0) { 1006 return 0; 1007 } else if (MP_SIGN(z) == MP_ZPOS) { 1008 return 1; 1009 } else { 1010 return -1; 1011 } 1012 } 1013 1014 int mp_int_compare_value(mp_int z, mp_small value) { 1015 assert(z != NULL); 1016 1017 mp_sign vsign = (value < 0) ? MP_NEG : MP_ZPOS; 1018 if (vsign == MP_SIGN(z)) { 1019 int cmp = s_vcmp(z, value); 1020 1021 return (vsign == MP_ZPOS) ? cmp : -cmp; 1022 } else { 1023 return (value < 0) ? 
1 : -1; 1024 } 1025 } 1026 1027 int mp_int_compare_uvalue(mp_int z, mp_usmall uv) { 1028 assert(z != NULL); 1029 1030 if (MP_SIGN(z) == MP_NEG) { 1031 return -1; 1032 } else { 1033 return s_uvcmp(z, uv); 1034 } 1035 } 1036 1037 mp_result mp_int_exptmod(mp_int a, mp_int b, mp_int m, mp_int c) { 1038 assert(a != NULL && b != NULL && c != NULL && m != NULL); 1039 1040 /* Zero moduli and negative exponents are not considered. */ 1041 if (CMPZ(m) == 0) return MP_UNDEF; 1042 if (CMPZ(b) < 0) return MP_RANGE; 1043 1044 mp_size um = MP_USED(m); 1045 DECLARE_TEMP(3); 1046 REQUIRE(GROW(TEMP(0), 2 * um)); 1047 REQUIRE(GROW(TEMP(1), 2 * um)); 1048 1049 mp_int s; 1050 if (c == b || c == m) { 1051 REQUIRE(GROW(TEMP(2), 2 * um)); 1052 s = TEMP(2); 1053 } else { 1054 s = c; 1055 } 1056 1057 REQUIRE(mp_int_mod(a, m, TEMP(0))); 1058 REQUIRE(s_brmu(TEMP(1), m)); 1059 REQUIRE(s_embar(TEMP(0), b, m, TEMP(1), s)); 1060 REQUIRE(mp_int_copy(s, c)); 1061 1062 CLEANUP_TEMP(); 1063 return MP_OK; 1064 } 1065 1066 mp_result mp_int_exptmod_evalue(mp_int a, mp_small value, mp_int m, mp_int c) { 1067 mpz_t vtmp; 1068 mp_digit vbuf[MP_VALUE_DIGITS(value)]; 1069 1070 s_fake(&vtmp, value, vbuf); 1071 1072 return mp_int_exptmod(a, &vtmp, m, c); 1073 } 1074 1075 mp_result mp_int_exptmod_bvalue(mp_small value, mp_int b, mp_int m, mp_int c) { 1076 mpz_t vtmp; 1077 mp_digit vbuf[MP_VALUE_DIGITS(value)]; 1078 1079 s_fake(&vtmp, value, vbuf); 1080 1081 return mp_int_exptmod(&vtmp, b, m, c); 1082 } 1083 1084 mp_result mp_int_exptmod_known(mp_int a, mp_int b, mp_int m, mp_int mu, 1085 mp_int c) { 1086 assert(a && b && m && c); 1087 1088 /* Zero moduli and negative exponents are not considered. 
*/ 1089 if (CMPZ(m) == 0) return MP_UNDEF; 1090 if (CMPZ(b) < 0) return MP_RANGE; 1091 1092 DECLARE_TEMP(2); 1093 mp_size um = MP_USED(m); 1094 REQUIRE(GROW(TEMP(0), 2 * um)); 1095 1096 mp_int s; 1097 if (c == b || c == m) { 1098 REQUIRE(GROW(TEMP(1), 2 * um)); 1099 s = TEMP(1); 1100 } else { 1101 s = c; 1102 } 1103 1104 REQUIRE(mp_int_mod(a, m, TEMP(0))); 1105 REQUIRE(s_embar(TEMP(0), b, m, mu, s)); 1106 REQUIRE(mp_int_copy(s, c)); 1107 1108 CLEANUP_TEMP(); 1109 return MP_OK; 1110 } 1111 1112 mp_result mp_int_redux_const(mp_int m, mp_int c) { 1113 assert(m != NULL && c != NULL && m != c); 1114 1115 return s_brmu(c, m); 1116 } 1117 1118 mp_result mp_int_invmod(mp_int a, mp_int m, mp_int c) { 1119 assert(a != NULL && m != NULL && c != NULL); 1120 1121 if (CMPZ(a) == 0 || CMPZ(m) <= 0) return MP_RANGE; 1122 1123 DECLARE_TEMP(2); 1124 1125 REQUIRE(mp_int_egcd(a, m, TEMP(0), TEMP(1), NULL)); 1126 1127 if (mp_int_compare_value(TEMP(0), 1) != 0) { 1128 REQUIRE(MP_UNDEF); 1129 } 1130 1131 /* It is first necessary to constrain the value to the proper range */ 1132 REQUIRE(mp_int_mod(TEMP(1), m, TEMP(1))); 1133 1134 /* Now, if 'a' was originally negative, the value we have is actually the 1135 magnitude of the negative representative; to get the positive value we 1136 have to subtract from the modulus. Otherwise, the value is okay as it 1137 stands. 
1138 */ 1139 if (MP_SIGN(a) == MP_NEG) { 1140 REQUIRE(mp_int_sub(m, TEMP(1), c)); 1141 } else { 1142 REQUIRE(mp_int_copy(TEMP(1), c)); 1143 } 1144 1145 CLEANUP_TEMP(); 1146 return MP_OK; 1147 } 1148 1149 /* Binary GCD algorithm due to Josef Stein, 1961 */ 1150 mp_result mp_int_gcd(mp_int a, mp_int b, mp_int c) { 1151 assert(a != NULL && b != NULL && c != NULL); 1152 1153 int ca = CMPZ(a); 1154 int cb = CMPZ(b); 1155 if (ca == 0 && cb == 0) { 1156 return MP_UNDEF; 1157 } else if (ca == 0) { 1158 return mp_int_abs(b, c); 1159 } else if (cb == 0) { 1160 return mp_int_abs(a, c); 1161 } 1162 1163 DECLARE_TEMP(3); 1164 REQUIRE(mp_int_copy(a, TEMP(0))); 1165 REQUIRE(mp_int_copy(b, TEMP(1))); 1166 1167 TEMP(0)->sign = MP_ZPOS; 1168 TEMP(1)->sign = MP_ZPOS; 1169 1170 int k = 0; 1171 { /* Divide out common factors of 2 from u and v */ 1172 int div2_u = s_dp2k(TEMP(0)); 1173 int div2_v = s_dp2k(TEMP(1)); 1174 1175 k = MIN(div2_u, div2_v); 1176 s_qdiv(TEMP(0), (mp_size)k); 1177 s_qdiv(TEMP(1), (mp_size)k); 1178 } 1179 1180 if (mp_int_is_odd(TEMP(0))) { 1181 REQUIRE(mp_int_neg(TEMP(1), TEMP(2))); 1182 } else { 1183 REQUIRE(mp_int_copy(TEMP(0), TEMP(2))); 1184 } 1185 1186 for (;;) { 1187 s_qdiv(TEMP(2), s_dp2k(TEMP(2))); 1188 1189 if (CMPZ(TEMP(2)) > 0) { 1190 REQUIRE(mp_int_copy(TEMP(2), TEMP(0))); 1191 } else { 1192 REQUIRE(mp_int_neg(TEMP(2), TEMP(1))); 1193 } 1194 1195 REQUIRE(mp_int_sub(TEMP(0), TEMP(1), TEMP(2))); 1196 1197 if (CMPZ(TEMP(2)) == 0) break; 1198 } 1199 1200 REQUIRE(mp_int_abs(TEMP(0), c)); 1201 if (!s_qmul(c, (mp_size)k)) REQUIRE(MP_MEMORY); 1202 1203 CLEANUP_TEMP(); 1204 return MP_OK; 1205 } 1206 1207 /* This is the binary GCD algorithm again, but this time we keep track of the 1208 elementary matrix operations as we go, so we can get values x and y 1209 satisfying c = ax + by. 
1210 */ 1211 mp_result mp_int_egcd(mp_int a, mp_int b, mp_int c, mp_int x, mp_int y) { 1212 assert(a != NULL && b != NULL && c != NULL && (x != NULL || y != NULL)); 1213 1214 mp_result res = MP_OK; 1215 int ca = CMPZ(a); 1216 int cb = CMPZ(b); 1217 if (ca == 0 && cb == 0) { 1218 return MP_UNDEF; 1219 } else if (ca == 0) { 1220 if ((res = mp_int_abs(b, c)) != MP_OK) return res; 1221 mp_int_zero(x); 1222 (void)mp_int_set_value(y, 1); 1223 return MP_OK; 1224 } else if (cb == 0) { 1225 if ((res = mp_int_abs(a, c)) != MP_OK) return res; 1226 (void)mp_int_set_value(x, 1); 1227 mp_int_zero(y); 1228 return MP_OK; 1229 } 1230 1231 /* Initialize temporaries: 1232 A:0, B:1, C:2, D:3, u:4, v:5, ou:6, ov:7 */ 1233 DECLARE_TEMP(8); 1234 REQUIRE(mp_int_set_value(TEMP(0), 1)); 1235 REQUIRE(mp_int_set_value(TEMP(3), 1)); 1236 REQUIRE(mp_int_copy(a, TEMP(4))); 1237 REQUIRE(mp_int_copy(b, TEMP(5))); 1238 1239 /* We will work with absolute values here */ 1240 TEMP(4)->sign = MP_ZPOS; 1241 TEMP(5)->sign = MP_ZPOS; 1242 1243 int k = 0; 1244 { /* Divide out common factors of 2 from u and v */ 1245 int div2_u = s_dp2k(TEMP(4)), div2_v = s_dp2k(TEMP(5)); 1246 1247 k = MIN(div2_u, div2_v); 1248 s_qdiv(TEMP(4), k); 1249 s_qdiv(TEMP(5), k); 1250 } 1251 1252 REQUIRE(mp_int_copy(TEMP(4), TEMP(6))); 1253 REQUIRE(mp_int_copy(TEMP(5), TEMP(7))); 1254 1255 for (;;) { 1256 while (mp_int_is_even(TEMP(4))) { 1257 s_qdiv(TEMP(4), 1); 1258 1259 if (mp_int_is_odd(TEMP(0)) || mp_int_is_odd(TEMP(1))) { 1260 REQUIRE(mp_int_add(TEMP(0), TEMP(7), TEMP(0))); 1261 REQUIRE(mp_int_sub(TEMP(1), TEMP(6), TEMP(1))); 1262 } 1263 1264 s_qdiv(TEMP(0), 1); 1265 s_qdiv(TEMP(1), 1); 1266 } 1267 1268 while (mp_int_is_even(TEMP(5))) { 1269 s_qdiv(TEMP(5), 1); 1270 1271 if (mp_int_is_odd(TEMP(2)) || mp_int_is_odd(TEMP(3))) { 1272 REQUIRE(mp_int_add(TEMP(2), TEMP(7), TEMP(2))); 1273 REQUIRE(mp_int_sub(TEMP(3), TEMP(6), TEMP(3))); 1274 } 1275 1276 s_qdiv(TEMP(2), 1); 1277 s_qdiv(TEMP(3), 1); 1278 } 1279 1280 if 
(mp_int_compare(TEMP(4), TEMP(5)) >= 0) { 1281 REQUIRE(mp_int_sub(TEMP(4), TEMP(5), TEMP(4))); 1282 REQUIRE(mp_int_sub(TEMP(0), TEMP(2), TEMP(0))); 1283 REQUIRE(mp_int_sub(TEMP(1), TEMP(3), TEMP(1))); 1284 } else { 1285 REQUIRE(mp_int_sub(TEMP(5), TEMP(4), TEMP(5))); 1286 REQUIRE(mp_int_sub(TEMP(2), TEMP(0), TEMP(2))); 1287 REQUIRE(mp_int_sub(TEMP(3), TEMP(1), TEMP(3))); 1288 } 1289 1290 if (CMPZ(TEMP(4)) == 0) { 1291 if (x) REQUIRE(mp_int_copy(TEMP(2), x)); 1292 if (y) REQUIRE(mp_int_copy(TEMP(3), y)); 1293 if (c) { 1294 if (!s_qmul(TEMP(5), k)) { 1295 REQUIRE(MP_MEMORY); 1296 } 1297 REQUIRE(mp_int_copy(TEMP(5), c)); 1298 } 1299 1300 break; 1301 } 1302 } 1303 1304 CLEANUP_TEMP(); 1305 return MP_OK; 1306 } 1307 1308 mp_result mp_int_lcm(mp_int a, mp_int b, mp_int c) { 1309 assert(a != NULL && b != NULL && c != NULL); 1310 1311 /* Since a * b = gcd(a, b) * lcm(a, b), we can compute 1312 lcm(a, b) = (a / gcd(a, b)) * b. 1313 1314 This formulation insures everything works even if the input 1315 variables share space. 1316 */ 1317 DECLARE_TEMP(1); 1318 REQUIRE(mp_int_gcd(a, b, TEMP(0))); 1319 REQUIRE(mp_int_div(a, TEMP(0), TEMP(0), NULL)); 1320 REQUIRE(mp_int_mul(TEMP(0), b, TEMP(0))); 1321 REQUIRE(mp_int_copy(TEMP(0), c)); 1322 1323 CLEANUP_TEMP(); 1324 return MP_OK; 1325 } 1326 1327 bool mp_int_divisible_value(mp_int a, mp_small v) { 1328 mp_small rem = 0; 1329 1330 if (mp_int_div_value(a, v, NULL, &rem) != MP_OK) { 1331 return false; 1332 } 1333 return rem == 0; 1334 } 1335 1336 int mp_int_is_pow2(mp_int z) { 1337 assert(z != NULL); 1338 1339 return s_isp2(z); 1340 } 1341 1342 /* Implementation of Newton's root finding method, based loosely on a patch 1343 contributed by Hal Finkel <half (at) halssoftware.com> 1344 modified by M. J. Fromberger. 
1345 */ 1346 mp_result mp_int_root(mp_int a, mp_small b, mp_int c) { 1347 assert(a != NULL && c != NULL && b > 0); 1348 1349 if (b == 1) { 1350 return mp_int_copy(a, c); 1351 } 1352 bool flips = false; 1353 if (MP_SIGN(a) == MP_NEG) { 1354 if (b % 2 == 0) { 1355 return MP_UNDEF; /* root does not exist for negative a with even b */ 1356 } else { 1357 flips = true; 1358 } 1359 } 1360 1361 DECLARE_TEMP(5); 1362 REQUIRE(mp_int_copy(a, TEMP(0))); 1363 REQUIRE(mp_int_copy(a, TEMP(1))); 1364 TEMP(0)->sign = MP_ZPOS; 1365 TEMP(1)->sign = MP_ZPOS; 1366 1367 for (;;) { 1368 REQUIRE(mp_int_expt(TEMP(1), b, TEMP(2))); 1369 1370 if (mp_int_compare_unsigned(TEMP(2), TEMP(0)) <= 0) break; 1371 1372 REQUIRE(mp_int_sub(TEMP(2), TEMP(0), TEMP(2))); 1373 REQUIRE(mp_int_expt(TEMP(1), b - 1, TEMP(3))); 1374 REQUIRE(mp_int_mul_value(TEMP(3), b, TEMP(3))); 1375 REQUIRE(mp_int_div(TEMP(2), TEMP(3), TEMP(4), NULL)); 1376 REQUIRE(mp_int_sub(TEMP(1), TEMP(4), TEMP(4))); 1377 1378 if (mp_int_compare_unsigned(TEMP(1), TEMP(4)) == 0) { 1379 REQUIRE(mp_int_sub_value(TEMP(4), 1, TEMP(4))); 1380 } 1381 REQUIRE(mp_int_copy(TEMP(4), TEMP(1))); 1382 } 1383 1384 REQUIRE(mp_int_copy(TEMP(1), c)); 1385 1386 /* If the original value of a was negative, flip the output sign. */ 1387 if (flips) (void)mp_int_neg(c, c); /* cannot fail */ 1388 1389 CLEANUP_TEMP(); 1390 return MP_OK; 1391 } 1392 1393 mp_result mp_int_to_int(mp_int z, mp_small *out) { 1394 assert(z != NULL); 1395 1396 /* Make sure the value is representable as a small integer */ 1397 mp_sign sz = MP_SIGN(z); 1398 if ((sz == MP_ZPOS && mp_int_compare_value(z, MP_SMALL_MAX) > 0) || 1399 mp_int_compare_value(z, MP_SMALL_MIN) < 0) { 1400 return MP_RANGE; 1401 } 1402 1403 mp_usmall uz = MP_USED(z); 1404 mp_digit *dz = MP_DIGITS(z) + uz - 1; 1405 mp_small uv = 0; 1406 while (uz > 0) { 1407 uv <<= MP_DIGIT_BIT / 2; 1408 uv = (uv << (MP_DIGIT_BIT / 2)) | *dz--; 1409 --uz; 1410 } 1411 1412 if (out) *out = (mp_small)((sz == MP_NEG) ? 
-uv : uv); 1413 1414 return MP_OK; 1415 } 1416 1417 mp_result mp_int_to_uint(mp_int z, mp_usmall *out) { 1418 assert(z != NULL); 1419 1420 /* Make sure the value is representable as an unsigned small integer */ 1421 mp_size sz = MP_SIGN(z); 1422 if (sz == MP_NEG || mp_int_compare_uvalue(z, MP_USMALL_MAX) > 0) { 1423 return MP_RANGE; 1424 } 1425 1426 mp_size uz = MP_USED(z); 1427 mp_digit *dz = MP_DIGITS(z) + uz - 1; 1428 mp_usmall uv = 0; 1429 1430 while (uz > 0) { 1431 uv <<= MP_DIGIT_BIT / 2; 1432 uv = (uv << (MP_DIGIT_BIT / 2)) | *dz--; 1433 --uz; 1434 } 1435 1436 if (out) *out = uv; 1437 1438 return MP_OK; 1439 } 1440 1441 mp_result mp_int_to_string(mp_int z, mp_size radix, char *str, int limit) { 1442 assert(z != NULL && str != NULL && limit >= 2); 1443 assert(radix >= MP_MIN_RADIX && radix <= MP_MAX_RADIX); 1444 1445 int cmp = 0; 1446 if (CMPZ(z) == 0) { 1447 *str++ = s_val2ch(0, 1); 1448 } else { 1449 mp_result res; 1450 mpz_t tmp; 1451 char *h, *t; 1452 1453 if ((res = mp_int_init_copy(&tmp, z)) != MP_OK) return res; 1454 1455 if (MP_SIGN(z) == MP_NEG) { 1456 *str++ = '-'; 1457 --limit; 1458 } 1459 h = str; 1460 1461 /* Generate digits in reverse order until finished or limit reached */ 1462 for (/* */; limit > 0; --limit) { 1463 mp_digit d; 1464 1465 if ((cmp = CMPZ(&tmp)) == 0) break; 1466 1467 d = s_ddiv(&tmp, (mp_digit)radix); 1468 *str++ = s_val2ch(d, 1); 1469 } 1470 t = str - 1; 1471 1472 /* Put digits back in correct output order */ 1473 while (h < t) { 1474 char tc = *h; 1475 *h++ = *t; 1476 *t-- = tc; 1477 } 1478 1479 mp_int_clear(&tmp); 1480 } 1481 1482 *str = '\0'; 1483 if (cmp == 0) { 1484 return MP_OK; 1485 } else { 1486 return MP_TRUNC; 1487 } 1488 } 1489 1490 mp_result mp_int_string_len(mp_int z, mp_size radix) { 1491 assert(z != NULL); 1492 assert(radix >= MP_MIN_RADIX && radix <= MP_MAX_RADIX); 1493 1494 int len = s_outlen(z, radix) + 1; /* for terminator */ 1495 1496 /* Allow for sign marker on negatives */ 1497 if (MP_SIGN(z) == MP_NEG) 
len += 1; 1498 1499 return len; 1500 } 1501 1502 /* Read zero-terminated string into z */ 1503 mp_result mp_int_read_string(mp_int z, mp_size radix, const char *str) { 1504 return mp_int_read_cstring(z, radix, str, NULL); 1505 } 1506 1507 mp_result mp_int_read_cstring(mp_int z, mp_size radix, const char *str, 1508 char **end) { 1509 assert(z != NULL && str != NULL); 1510 assert(radix >= MP_MIN_RADIX && radix <= MP_MAX_RADIX); 1511 1512 /* Skip leading whitespace */ 1513 while (isspace((unsigned char)*str)) ++str; 1514 1515 /* Handle leading sign tag (+/-, positive default) */ 1516 switch (*str) { 1517 case '-': 1518 z->sign = MP_NEG; 1519 ++str; 1520 break; 1521 case '+': 1522 ++str; /* fallthrough */ 1523 default: 1524 z->sign = MP_ZPOS; 1525 break; 1526 } 1527 1528 /* Skip leading zeroes */ 1529 int ch; 1530 while ((ch = s_ch2val(*str, radix)) == 0) ++str; 1531 1532 /* Make sure there is enough space for the value */ 1533 if (!s_pad(z, s_inlen(strlen(str), radix))) return MP_MEMORY; 1534 1535 z->used = 1; 1536 z->digits[0] = 0; 1537 1538 while (*str != '\0' && ((ch = s_ch2val(*str, radix)) >= 0)) { 1539 s_dmul(z, (mp_digit)radix); 1540 s_dadd(z, (mp_digit)ch); 1541 ++str; 1542 } 1543 1544 CLAMP(z); 1545 1546 /* Override sign for zero, even if negative specified. 
*/ 1547 if (CMPZ(z) == 0) z->sign = MP_ZPOS; 1548 1549 if (end != NULL) *end = (char *)str; 1550 1551 /* Return a truncation error if the string has unprocessed characters 1552 remaining, so the caller can tell if the whole string was done */ 1553 if (*str != '\0') { 1554 return MP_TRUNC; 1555 } else { 1556 return MP_OK; 1557 } 1558 } 1559 1560 mp_result mp_int_count_bits(mp_int z) { 1561 assert(z != NULL); 1562 1563 mp_size uz = MP_USED(z); 1564 if (uz == 1 && z->digits[0] == 0) return 1; 1565 1566 --uz; 1567 mp_size nbits = uz * MP_DIGIT_BIT; 1568 mp_digit d = z->digits[uz]; 1569 1570 while (d != 0) { 1571 d >>= 1; 1572 ++nbits; 1573 } 1574 1575 return nbits; 1576 } 1577 1578 mp_result mp_int_to_binary(mp_int z, unsigned char *buf, int limit) { 1579 static const int PAD_FOR_2C = 1; 1580 1581 assert(z != NULL && buf != NULL); 1582 1583 int limpos = limit; 1584 mp_result res = s_tobin(z, buf, &limpos, PAD_FOR_2C); 1585 1586 if (MP_SIGN(z) == MP_NEG) s_2comp(buf, limpos); 1587 1588 return res; 1589 } 1590 1591 mp_result mp_int_read_binary(mp_int z, unsigned char *buf, int len) { 1592 assert(z != NULL && buf != NULL && len > 0); 1593 1594 /* Figure out how many digits are needed to represent this value */ 1595 mp_size need = ((len * CHAR_BIT) + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT; 1596 if (!s_pad(z, need)) return MP_MEMORY; 1597 1598 mp_int_zero(z); 1599 1600 /* If the high-order bit is set, take the 2's complement before reading the 1601 value (it will be restored afterward) */ 1602 if (buf[0] >> (CHAR_BIT - 1)) { 1603 z->sign = MP_NEG; 1604 s_2comp(buf, len); 1605 } 1606 1607 mp_digit *dz = MP_DIGITS(z); 1608 unsigned char *tmp = buf; 1609 for (int i = len; i > 0; --i, ++tmp) { 1610 s_qmul(z, (mp_size)CHAR_BIT); 1611 *dz |= *tmp; 1612 } 1613 1614 /* Restore 2's complement if we took it before */ 1615 if (MP_SIGN(z) == MP_NEG) s_2comp(buf, len); 1616 1617 return MP_OK; 1618 } 1619 1620 mp_result mp_int_binary_len(mp_int z) { 1621 mp_result res = mp_int_count_bits(z); 
1622 if (res <= 0) return res; 1623 1624 int bytes = mp_int_unsigned_len(z); 1625 1626 /* If the highest-order bit falls exactly on a byte boundary, we need to pad 1627 with an extra byte so that the sign will be read correctly when reading it 1628 back in. */ 1629 if (bytes * CHAR_BIT == res) ++bytes; 1630 1631 return bytes; 1632 } 1633 1634 mp_result mp_int_to_unsigned(mp_int z, unsigned char *buf, int limit) { 1635 static const int NO_PADDING = 0; 1636 1637 assert(z != NULL && buf != NULL); 1638 1639 return s_tobin(z, buf, &limit, NO_PADDING); 1640 } 1641 1642 mp_result mp_int_read_unsigned(mp_int z, unsigned char *buf, int len) { 1643 assert(z != NULL && buf != NULL && len > 0); 1644 1645 /* Figure out how many digits are needed to represent this value */ 1646 mp_size need = ((len * CHAR_BIT) + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT; 1647 if (!s_pad(z, need)) return MP_MEMORY; 1648 1649 mp_int_zero(z); 1650 1651 unsigned char *tmp = buf; 1652 for (int i = len; i > 0; --i, ++tmp) { 1653 (void)s_qmul(z, CHAR_BIT); 1654 *MP_DIGITS(z) |= *tmp; 1655 } 1656 1657 return MP_OK; 1658 } 1659 1660 mp_result mp_int_unsigned_len(mp_int z) { 1661 mp_result res = mp_int_count_bits(z); 1662 if (res <= 0) return res; 1663 1664 int bytes = (res + (CHAR_BIT - 1)) / CHAR_BIT; 1665 return bytes; 1666 } 1667 1668 const char *mp_error_string(mp_result res) { 1669 if (res > 0) return s_unknown_err; 1670 1671 res = -res; 1672 int ix; 1673 for (ix = 0; ix < res && s_error_msg[ix] != NULL; ++ix) 1674 ; 1675 1676 if (s_error_msg[ix] != NULL) { 1677 return s_error_msg[ix]; 1678 } else { 1679 return s_unknown_err; 1680 } 1681 } 1682 1683 /*------------------------------------------------------------------------*/ 1684 /* Private functions for internal use. These make assumptions. 
*/ 1685 1686 #if DEBUG 1687 static const mp_digit fill = (mp_digit)0xdeadbeefabad1dea; 1688 #endif 1689 1690 static mp_digit *s_alloc(mp_size num) { 1691 mp_digit *out = malloc(num * sizeof(mp_digit)); 1692 assert(out != NULL); 1693 1694 #if DEBUG 1695 for (mp_size ix = 0; ix < num; ++ix) out[ix] = fill; 1696 #endif 1697 return out; 1698 } 1699 1700 static mp_digit *s_realloc(mp_digit *old, mp_size osize, mp_size nsize) { 1701 #if DEBUG 1702 mp_digit *new = s_alloc(nsize); 1703 assert(new != NULL); 1704 1705 for (mp_size ix = 0; ix < nsize; ++ix) new[ix] = fill; 1706 memcpy(new, old, osize * sizeof(mp_digit)); 1707 #else 1708 mp_digit *new = realloc(old, nsize * sizeof(mp_digit)); 1709 assert(new != NULL); 1710 #endif 1711 1712 return new; 1713 } 1714 1715 static void s_free(void *ptr) { free(ptr); } 1716 1717 static bool s_pad(mp_int z, mp_size min) { 1718 if (MP_ALLOC(z) < min) { 1719 mp_size nsize = s_round_prec(min); 1720 mp_digit *tmp; 1721 1722 if (z->digits == &(z->single)) { 1723 if ((tmp = s_alloc(nsize)) == NULL) return false; 1724 tmp[0] = z->single; 1725 } else if ((tmp = s_realloc(MP_DIGITS(z), MP_ALLOC(z), nsize)) == NULL) { 1726 return false; 1727 } 1728 1729 z->digits = tmp; 1730 z->alloc = nsize; 1731 } 1732 1733 return true; 1734 } 1735 1736 /* Note: This will not work correctly when value == MP_SMALL_MIN */ 1737 static void s_fake(mp_int z, mp_small value, mp_digit vbuf[]) { 1738 mp_usmall uv = (mp_usmall)(value < 0) ? 
-value : value; 1739 s_ufake(z, uv, vbuf); 1740 if (value < 0) z->sign = MP_NEG; 1741 } 1742 1743 static void s_ufake(mp_int z, mp_usmall value, mp_digit vbuf[]) { 1744 mp_size ndig = (mp_size)s_uvpack(value, vbuf); 1745 1746 z->used = ndig; 1747 z->alloc = MP_VALUE_DIGITS(value); 1748 z->sign = MP_ZPOS; 1749 z->digits = vbuf; 1750 } 1751 1752 static int s_cdig(mp_digit *da, mp_digit *db, mp_size len) { 1753 mp_digit *dat = da + len - 1, *dbt = db + len - 1; 1754 1755 for (/* */; len != 0; --len, --dat, --dbt) { 1756 if (*dat > *dbt) { 1757 return 1; 1758 } else if (*dat < *dbt) { 1759 return -1; 1760 } 1761 } 1762 1763 return 0; 1764 } 1765 1766 static int s_uvpack(mp_usmall uv, mp_digit t[]) { 1767 int ndig = 0; 1768 1769 if (uv == 0) 1770 t[ndig++] = 0; 1771 else { 1772 while (uv != 0) { 1773 t[ndig++] = (mp_digit)uv; 1774 uv >>= MP_DIGIT_BIT / 2; 1775 uv >>= MP_DIGIT_BIT / 2; 1776 } 1777 } 1778 1779 return ndig; 1780 } 1781 1782 static int s_ucmp(mp_int a, mp_int b) { 1783 mp_size ua = MP_USED(a), ub = MP_USED(b); 1784 1785 if (ua > ub) { 1786 return 1; 1787 } else if (ub > ua) { 1788 return -1; 1789 } else { 1790 return s_cdig(MP_DIGITS(a), MP_DIGITS(b), ua); 1791 } 1792 } 1793 1794 static int s_vcmp(mp_int a, mp_small v) { 1795 mp_usmall uv = (v < 0) ? 
-(mp_usmall)v : (mp_usmall)v; 1796 return s_uvcmp(a, uv); 1797 } 1798 1799 static int s_uvcmp(mp_int a, mp_usmall uv) { 1800 mpz_t vtmp; 1801 mp_digit vdig[MP_VALUE_DIGITS(uv)]; 1802 1803 s_ufake(&vtmp, uv, vdig); 1804 return s_ucmp(a, &vtmp); 1805 } 1806 1807 static mp_digit s_uadd(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a, 1808 mp_size size_b) { 1809 mp_size pos; 1810 mp_word w = 0; 1811 1812 /* Insure that da is the longer of the two to simplify later code */ 1813 if (size_b > size_a) { 1814 SWAP(mp_digit *, da, db); 1815 SWAP(mp_size, size_a, size_b); 1816 } 1817 1818 /* Add corresponding digits until the shorter number runs out */ 1819 for (pos = 0; pos < size_b; ++pos, ++da, ++db, ++dc) { 1820 w = w + (mp_word)*da + (mp_word)*db; 1821 *dc = LOWER_HALF(w); 1822 w = UPPER_HALF(w); 1823 } 1824 1825 /* Propagate carries as far as necessary */ 1826 for (/* */; pos < size_a; ++pos, ++da, ++dc) { 1827 w = w + *da; 1828 1829 *dc = LOWER_HALF(w); 1830 w = UPPER_HALF(w); 1831 } 1832 1833 /* Return carry out */ 1834 return (mp_digit)w; 1835 } 1836 1837 static void s_usub(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a, 1838 mp_size size_b) { 1839 mp_size pos; 1840 mp_word w = 0; 1841 1842 /* We assume that |a| >= |b| so this should definitely hold */ 1843 assert(size_a >= size_b); 1844 1845 /* Subtract corresponding digits and propagate borrow */ 1846 for (pos = 0; pos < size_b; ++pos, ++da, ++db, ++dc) { 1847 w = ((mp_word)MP_DIGIT_MAX + 1 + /* MP_RADIX */ 1848 (mp_word)*da) - 1849 w - (mp_word)*db; 1850 1851 *dc = LOWER_HALF(w); 1852 w = (UPPER_HALF(w) == 0); 1853 } 1854 1855 /* Finish the subtraction for remaining upper digits of da */ 1856 for (/* */; pos < size_a; ++pos, ++da, ++dc) { 1857 w = ((mp_word)MP_DIGIT_MAX + 1 + /* MP_RADIX */ 1858 (mp_word)*da) - 1859 w; 1860 1861 *dc = LOWER_HALF(w); 1862 w = (UPPER_HALF(w) == 0); 1863 } 1864 1865 /* If there is a borrow out at the end, it violates the precondition */ 1866 assert(w == 0); 1867 } 
1868 1869 static int s_kmul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a, 1870 mp_size size_b) { 1871 mp_size bot_size; 1872 1873 /* Make sure b is the smaller of the two input values */ 1874 if (size_b > size_a) { 1875 SWAP(mp_digit *, da, db); 1876 SWAP(mp_size, size_a, size_b); 1877 } 1878 1879 /* Insure that the bottom is the larger half in an odd-length split; the code 1880 below relies on this being true. 1881 */ 1882 bot_size = (size_a + 1) / 2; 1883 1884 /* If the values are big enough to bother with recursion, use the Karatsuba 1885 algorithm to compute the product; otherwise use the normal multiplication 1886 algorithm 1887 */ 1888 if (multiply_threshold && size_a >= multiply_threshold && size_b > bot_size) { 1889 mp_digit *t1, *t2, *t3, carry; 1890 1891 mp_digit *a_top = da + bot_size; 1892 mp_digit *b_top = db + bot_size; 1893 1894 mp_size at_size = size_a - bot_size; 1895 mp_size bt_size = size_b - bot_size; 1896 mp_size buf_size = 2 * bot_size; 1897 1898 /* Do a single allocation for all three temporary buffers needed; each 1899 buffer must be big enough to hold the product of two bottom halves, and 1900 one buffer needs space for the completed product; twice the space is 1901 plenty. 
1902 */ 1903 if ((t1 = s_alloc(4 * buf_size)) == NULL) return 0; 1904 t2 = t1 + buf_size; 1905 t3 = t2 + buf_size; 1906 ZERO(t1, 4 * buf_size); 1907 1908 /* t1 and t2 are initially used as temporaries to compute the inner product 1909 (a1 + a0)(b1 + b0) = a1b1 + a1b0 + a0b1 + a0b0 1910 */ 1911 carry = s_uadd(da, a_top, t1, bot_size, at_size); /* t1 = a1 + a0 */ 1912 t1[bot_size] = carry; 1913 1914 carry = s_uadd(db, b_top, t2, bot_size, bt_size); /* t2 = b1 + b0 */ 1915 t2[bot_size] = carry; 1916 1917 (void)s_kmul(t1, t2, t3, bot_size + 1, bot_size + 1); /* t3 = t1 * t2 */ 1918 1919 /* Now we'll get t1 = a0b0 and t2 = a1b1, and subtract them out so that 1920 we're left with only the pieces we want: t3 = a1b0 + a0b1 1921 */ 1922 ZERO(t1, buf_size); 1923 ZERO(t2, buf_size); 1924 (void)s_kmul(da, db, t1, bot_size, bot_size); /* t1 = a0 * b0 */ 1925 (void)s_kmul(a_top, b_top, t2, at_size, bt_size); /* t2 = a1 * b1 */ 1926 1927 /* Subtract out t1 and t2 to get the inner product */ 1928 s_usub(t3, t1, t3, buf_size + 2, buf_size); 1929 s_usub(t3, t2, t3, buf_size + 2, buf_size); 1930 1931 /* Assemble the output value */ 1932 COPY(t1, dc, buf_size); 1933 carry = s_uadd(t3, dc + bot_size, dc + bot_size, buf_size + 1, buf_size); 1934 assert(carry == 0); 1935 1936 carry = 1937 s_uadd(t2, dc + 2 * bot_size, dc + 2 * bot_size, buf_size, buf_size); 1938 assert(carry == 0); 1939 1940 s_free(t1); /* note t2 and t3 are just internal pointers to t1 */ 1941 } else { 1942 s_umul(da, db, dc, size_a, size_b); 1943 } 1944 1945 return 1; 1946 } 1947 1948 static void s_umul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a, 1949 mp_size size_b) { 1950 mp_size a, b; 1951 mp_word w; 1952 1953 for (a = 0; a < size_a; ++a, ++dc, ++da) { 1954 mp_digit *dct = dc; 1955 mp_digit *dbt = db; 1956 1957 if (*da == 0) continue; 1958 1959 w = 0; 1960 for (b = 0; b < size_b; ++b, ++dbt, ++dct) { 1961 w = (mp_word)*da * (mp_word)*dbt + w + (mp_word)*dct; 1962 1963 *dct = LOWER_HALF(w); 1964 w = 
UPPER_HALF(w); 1965 } 1966 1967 *dct = (mp_digit)w; 1968 } 1969 } 1970 1971 static int s_ksqr(mp_digit *da, mp_digit *dc, mp_size size_a) { 1972 if (multiply_threshold && size_a > multiply_threshold) { 1973 mp_size bot_size = (size_a + 1) / 2; 1974 mp_digit *a_top = da + bot_size; 1975 mp_digit *t1, *t2, *t3, carry; 1976 mp_size at_size = size_a - bot_size; 1977 mp_size buf_size = 2 * bot_size; 1978 1979 if ((t1 = s_alloc(4 * buf_size)) == NULL) return 0; 1980 t2 = t1 + buf_size; 1981 t3 = t2 + buf_size; 1982 ZERO(t1, 4 * buf_size); 1983 1984 (void)s_ksqr(da, t1, bot_size); /* t1 = a0 ^ 2 */ 1985 (void)s_ksqr(a_top, t2, at_size); /* t2 = a1 ^ 2 */ 1986 1987 (void)s_kmul(da, a_top, t3, bot_size, at_size); /* t3 = a0 * a1 */ 1988 1989 /* Quick multiply t3 by 2, shifting left (can't overflow) */ 1990 { 1991 int i, top = bot_size + at_size; 1992 mp_word w, save = 0; 1993 1994 for (i = 0; i < top; ++i) { 1995 w = t3[i]; 1996 w = (w << 1) | save; 1997 t3[i] = LOWER_HALF(w); 1998 save = UPPER_HALF(w); 1999 } 2000 t3[i] = LOWER_HALF(save); 2001 } 2002 2003 /* Assemble the output value */ 2004 COPY(t1, dc, 2 * bot_size); 2005 carry = s_uadd(t3, dc + bot_size, dc + bot_size, buf_size + 1, buf_size); 2006 assert(carry == 0); 2007 2008 carry = 2009 s_uadd(t2, dc + 2 * bot_size, dc + 2 * bot_size, buf_size, buf_size); 2010 assert(carry == 0); 2011 2012 s_free(t1); /* note that t2 and t2 are internal pointers only */ 2013 2014 } else { 2015 s_usqr(da, dc, size_a); 2016 } 2017 2018 return 1; 2019 } 2020 2021 static void s_usqr(mp_digit *da, mp_digit *dc, mp_size size_a) { 2022 mp_size i, j; 2023 mp_word w; 2024 2025 for (i = 0; i < size_a; ++i, dc += 2, ++da) { 2026 mp_digit *dct = dc, *dat = da; 2027 2028 if (*da == 0) continue; 2029 2030 /* Take care of the first digit, no rollover */ 2031 w = (mp_word)*dat * (mp_word)*dat + (mp_word)*dct; 2032 *dct = LOWER_HALF(w); 2033 w = UPPER_HALF(w); 2034 ++dat; 2035 ++dct; 2036 2037 for (j = i + 1; j < size_a; ++j, ++dat, ++dct) { 2038 
mp_word t = (mp_word)*da * (mp_word)*dat; 2039 mp_word u = w + (mp_word)*dct, ov = 0; 2040 2041 /* Check if doubling t will overflow a word */ 2042 if (HIGH_BIT_SET(t)) ov = 1; 2043 2044 w = t + t; 2045 2046 /* Check if adding u to w will overflow a word */ 2047 if (ADD_WILL_OVERFLOW(w, u)) ov = 1; 2048 2049 w += u; 2050 2051 *dct = LOWER_HALF(w); 2052 w = UPPER_HALF(w); 2053 if (ov) { 2054 w += MP_DIGIT_MAX; /* MP_RADIX */ 2055 ++w; 2056 } 2057 } 2058 2059 w = w + *dct; 2060 *dct = (mp_digit)w; 2061 while ((w = UPPER_HALF(w)) != 0) { 2062 ++dct; 2063 w = w + *dct; 2064 *dct = LOWER_HALF(w); 2065 } 2066 2067 assert(w == 0); 2068 } 2069 } 2070 2071 static void s_dadd(mp_int a, mp_digit b) { 2072 mp_word w = 0; 2073 mp_digit *da = MP_DIGITS(a); 2074 mp_size ua = MP_USED(a); 2075 2076 w = (mp_word)*da + b; 2077 *da++ = LOWER_HALF(w); 2078 w = UPPER_HALF(w); 2079 2080 for (ua -= 1; ua > 0; --ua, ++da) { 2081 w = (mp_word)*da + w; 2082 2083 *da = LOWER_HALF(w); 2084 w = UPPER_HALF(w); 2085 } 2086 2087 if (w) { 2088 *da = (mp_digit)w; 2089 a->used += 1; 2090 } 2091 } 2092 2093 static void s_dmul(mp_int a, mp_digit b) { 2094 mp_word w = 0; 2095 mp_digit *da = MP_DIGITS(a); 2096 mp_size ua = MP_USED(a); 2097 2098 while (ua > 0) { 2099 w = (mp_word)*da * b + w; 2100 *da++ = LOWER_HALF(w); 2101 w = UPPER_HALF(w); 2102 --ua; 2103 } 2104 2105 if (w) { 2106 *da = (mp_digit)w; 2107 a->used += 1; 2108 } 2109 } 2110 2111 static void s_dbmul(mp_digit *da, mp_digit b, mp_digit *dc, mp_size size_a) { 2112 mp_word w = 0; 2113 2114 while (size_a > 0) { 2115 w = (mp_word)*da++ * (mp_word)b + w; 2116 2117 *dc++ = LOWER_HALF(w); 2118 w = UPPER_HALF(w); 2119 --size_a; 2120 } 2121 2122 if (w) *dc = LOWER_HALF(w); 2123 } 2124 2125 static mp_digit s_ddiv(mp_int a, mp_digit b) { 2126 mp_word w = 0, qdigit; 2127 mp_size ua = MP_USED(a); 2128 mp_digit *da = MP_DIGITS(a) + ua - 1; 2129 2130 for (/* */; ua > 0; --ua, --da) { 2131 w = (w << MP_DIGIT_BIT) | *da; 2132 2133 if (w >= b) { 2134 qdigit = 
w / b; 2135 w = w % b; 2136 } else { 2137 qdigit = 0; 2138 } 2139 2140 *da = (mp_digit)qdigit; 2141 } 2142 2143 CLAMP(a); 2144 return (mp_digit)w; 2145 } 2146 2147 static void s_qdiv(mp_int z, mp_size p2) { 2148 mp_size ndig = p2 / MP_DIGIT_BIT, nbits = p2 % MP_DIGIT_BIT; 2149 mp_size uz = MP_USED(z); 2150 2151 if (ndig) { 2152 mp_size mark; 2153 mp_digit *to, *from; 2154 2155 if (ndig >= uz) { 2156 mp_int_zero(z); 2157 return; 2158 } 2159 2160 to = MP_DIGITS(z); 2161 from = to + ndig; 2162 2163 for (mark = ndig; mark < uz; ++mark) { 2164 *to++ = *from++; 2165 } 2166 2167 z->used = uz - ndig; 2168 } 2169 2170 if (nbits) { 2171 mp_digit d = 0, *dz, save; 2172 mp_size up = MP_DIGIT_BIT - nbits; 2173 2174 uz = MP_USED(z); 2175 dz = MP_DIGITS(z) + uz - 1; 2176 2177 for (/* */; uz > 0; --uz, --dz) { 2178 save = *dz; 2179 2180 *dz = (*dz >> nbits) | (d << up); 2181 d = save; 2182 } 2183 2184 CLAMP(z); 2185 } 2186 2187 if (MP_USED(z) == 1 && z->digits[0] == 0) z->sign = MP_ZPOS; 2188 } 2189 2190 static void s_qmod(mp_int z, mp_size p2) { 2191 mp_size start = p2 / MP_DIGIT_BIT + 1, rest = p2 % MP_DIGIT_BIT; 2192 mp_size uz = MP_USED(z); 2193 mp_digit mask = (1u << rest) - 1; 2194 2195 if (start <= uz) { 2196 z->used = start; 2197 z->digits[start - 1] &= mask; 2198 CLAMP(z); 2199 } 2200 } 2201 2202 static int s_qmul(mp_int z, mp_size p2) { 2203 mp_size uz, need, rest, extra, i; 2204 mp_digit *from, *to, d; 2205 2206 if (p2 == 0) return 1; 2207 2208 uz = MP_USED(z); 2209 need = p2 / MP_DIGIT_BIT; 2210 rest = p2 % MP_DIGIT_BIT; 2211 2212 /* Figure out if we need an extra digit at the top end; this occurs if the 2213 topmost `rest' bits of the high-order digit of z are not zero, meaning 2214 they will be shifted off the end if not preserved */ 2215 extra = 0; 2216 if (rest != 0) { 2217 mp_digit *dz = MP_DIGITS(z) + uz - 1; 2218 2219 if ((*dz >> (MP_DIGIT_BIT - rest)) != 0) extra = 1; 2220 } 2221 2222 if (!s_pad(z, uz + need + extra)) return 0; 2223 2224 /* If we need to shift 
by whole digits, do that in one pass, then 2225 to back and shift by partial digits. 2226 */ 2227 if (need > 0) { 2228 from = MP_DIGITS(z) + uz - 1; 2229 to = from + need; 2230 2231 for (i = 0; i < uz; ++i) *to-- = *from--; 2232 2233 ZERO(MP_DIGITS(z), need); 2234 uz += need; 2235 } 2236 2237 if (rest) { 2238 d = 0; 2239 for (i = need, from = MP_DIGITS(z) + need; i < uz; ++i, ++from) { 2240 mp_digit save = *from; 2241 2242 *from = (*from << rest) | (d >> (MP_DIGIT_BIT - rest)); 2243 d = save; 2244 } 2245 2246 d >>= (MP_DIGIT_BIT - rest); 2247 if (d != 0) { 2248 *from = d; 2249 uz += extra; 2250 } 2251 } 2252 2253 z->used = uz; 2254 CLAMP(z); 2255 2256 return 1; 2257 } 2258 2259 /* Compute z = 2^p2 - |z|; requires that 2^p2 >= |z| 2260 The sign of the result is always zero/positive. 2261 */ 2262 static int s_qsub(mp_int z, mp_size p2) { 2263 mp_digit hi = (1u << (p2 % MP_DIGIT_BIT)), *zp; 2264 mp_size tdig = (p2 / MP_DIGIT_BIT), pos; 2265 mp_word w = 0; 2266 2267 if (!s_pad(z, tdig + 1)) return 0; 2268 2269 for (pos = 0, zp = MP_DIGITS(z); pos < tdig; ++pos, ++zp) { 2270 w = ((mp_word)MP_DIGIT_MAX + 1) - w - (mp_word)*zp; 2271 2272 *zp = LOWER_HALF(w); 2273 w = UPPER_HALF(w) ? 
0 : 1; 2274 } 2275 2276 w = ((mp_word)MP_DIGIT_MAX + 1 + hi) - w - (mp_word)*zp; 2277 *zp = LOWER_HALF(w); 2278 2279 assert(UPPER_HALF(w) != 0); /* no borrow out should be possible */ 2280 2281 z->sign = MP_ZPOS; 2282 CLAMP(z); 2283 2284 return 1; 2285 } 2286 2287 static int s_dp2k(mp_int z) { 2288 int k = 0; 2289 mp_digit *dp = MP_DIGITS(z), d; 2290 2291 if (MP_USED(z) == 1 && *dp == 0) return 1; 2292 2293 while (*dp == 0) { 2294 k += MP_DIGIT_BIT; 2295 ++dp; 2296 } 2297 2298 d = *dp; 2299 while ((d & 1) == 0) { 2300 d >>= 1; 2301 ++k; 2302 } 2303 2304 return k; 2305 } 2306 2307 static int s_isp2(mp_int z) { 2308 mp_size uz = MP_USED(z), k = 0; 2309 mp_digit *dz = MP_DIGITS(z), d; 2310 2311 while (uz > 1) { 2312 if (*dz++ != 0) return -1; 2313 k += MP_DIGIT_BIT; 2314 --uz; 2315 } 2316 2317 d = *dz; 2318 while (d > 1) { 2319 if (d & 1) return -1; 2320 ++k; 2321 d >>= 1; 2322 } 2323 2324 return (int)k; 2325 } 2326 2327 static int s_2expt(mp_int z, mp_small k) { 2328 mp_size ndig, rest; 2329 mp_digit *dz; 2330 2331 ndig = (k + MP_DIGIT_BIT) / MP_DIGIT_BIT; 2332 rest = k % MP_DIGIT_BIT; 2333 2334 if (!s_pad(z, ndig)) return 0; 2335 2336 dz = MP_DIGITS(z); 2337 ZERO(dz, ndig); 2338 *(dz + ndig - 1) = (1u << rest); 2339 z->used = ndig; 2340 2341 return 1; 2342 } 2343 2344 static int s_norm(mp_int a, mp_int b) { 2345 mp_digit d = b->digits[MP_USED(b) - 1]; 2346 int k = 0; 2347 2348 while (d < (1u << (mp_digit)(MP_DIGIT_BIT - 1))) { /* d < (MP_RADIX / 2) */ 2349 d <<= 1; 2350 ++k; 2351 } 2352 2353 /* These multiplications can't fail */ 2354 if (k != 0) { 2355 (void)s_qmul(a, (mp_size)k); 2356 (void)s_qmul(b, (mp_size)k); 2357 } 2358 2359 return k; 2360 } 2361 2362 static mp_result s_brmu(mp_int z, mp_int m) { 2363 mp_size um = MP_USED(m) * 2; 2364 2365 if (!s_pad(z, um)) return MP_MEMORY; 2366 2367 s_2expt(z, MP_DIGIT_BIT * um); 2368 return mp_int_div(z, m, z, NULL); 2369 } 2370 2371 static int s_reduce(mp_int x, mp_int m, mp_int mu, mp_int q1, mp_int q2) { 2372 mp_size 
um = MP_USED(m), umb_p1, umb_m1; 2373 2374 umb_p1 = (um + 1) * MP_DIGIT_BIT; 2375 umb_m1 = (um - 1) * MP_DIGIT_BIT; 2376 2377 if (mp_int_copy(x, q1) != MP_OK) return 0; 2378 2379 /* Compute q2 = floor((floor(x / b^(k-1)) * mu) / b^(k+1)) */ 2380 s_qdiv(q1, umb_m1); 2381 UMUL(q1, mu, q2); 2382 s_qdiv(q2, umb_p1); 2383 2384 /* Set x = x mod b^(k+1) */ 2385 s_qmod(x, umb_p1); 2386 2387 /* Now, q is a guess for the quotient a / m. 2388 Compute x - q * m mod b^(k+1), replacing x. This may be off 2389 by a factor of 2m, but no more than that. 2390 */ 2391 UMUL(q2, m, q1); 2392 s_qmod(q1, umb_p1); 2393 (void)mp_int_sub(x, q1, x); /* can't fail */ 2394 2395 /* The result may be < 0; if it is, add b^(k+1) to pin it in the proper 2396 range. */ 2397 if ((CMPZ(x) < 0) && !s_qsub(x, umb_p1)) return 0; 2398 2399 /* If x > m, we need to back it off until it is in range. This will be 2400 required at most twice. */ 2401 if (mp_int_compare(x, m) >= 0) { 2402 (void)mp_int_sub(x, m, x); 2403 if (mp_int_compare(x, m) >= 0) { 2404 (void)mp_int_sub(x, m, x); 2405 } 2406 } 2407 2408 /* At this point, x has been properly reduced. */ 2409 return 1; 2410 } 2411 2412 /* Perform modular exponentiation using Barrett's method, where mu is the 2413 reduction constant for m. Assumes a < m, b > 0. 
*/ 2414 static mp_result s_embar(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c) { 2415 mp_digit umu = MP_USED(mu); 2416 mp_digit *db = MP_DIGITS(b); 2417 mp_digit *dbt = db + MP_USED(b) - 1; 2418 2419 DECLARE_TEMP(3); 2420 REQUIRE(GROW(TEMP(0), 4 * umu)); 2421 REQUIRE(GROW(TEMP(1), 4 * umu)); 2422 REQUIRE(GROW(TEMP(2), 4 * umu)); 2423 ZERO(TEMP(0)->digits, TEMP(0)->alloc); 2424 ZERO(TEMP(1)->digits, TEMP(1)->alloc); 2425 ZERO(TEMP(2)->digits, TEMP(2)->alloc); 2426 2427 (void)mp_int_set_value(c, 1); 2428 2429 /* Take care of low-order digits */ 2430 while (db < dbt) { 2431 mp_digit d = *db; 2432 2433 for (int i = MP_DIGIT_BIT; i > 0; --i, d >>= 1) { 2434 if (d & 1) { 2435 /* The use of a second temporary avoids allocation */ 2436 UMUL(c, a, TEMP(0)); 2437 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) { 2438 REQUIRE(MP_MEMORY); 2439 } 2440 mp_int_copy(TEMP(0), c); 2441 } 2442 2443 USQR(a, TEMP(0)); 2444 assert(MP_SIGN(TEMP(0)) == MP_ZPOS); 2445 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) { 2446 REQUIRE(MP_MEMORY); 2447 } 2448 assert(MP_SIGN(TEMP(0)) == MP_ZPOS); 2449 mp_int_copy(TEMP(0), a); 2450 } 2451 2452 ++db; 2453 } 2454 2455 /* Take care of highest-order digit */ 2456 mp_digit d = *dbt; 2457 for (;;) { 2458 if (d & 1) { 2459 UMUL(c, a, TEMP(0)); 2460 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) { 2461 REQUIRE(MP_MEMORY); 2462 } 2463 mp_int_copy(TEMP(0), c); 2464 } 2465 2466 d >>= 1; 2467 if (!d) break; 2468 2469 USQR(a, TEMP(0)); 2470 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) { 2471 REQUIRE(MP_MEMORY); 2472 } 2473 (void)mp_int_copy(TEMP(0), a); 2474 } 2475 2476 CLEANUP_TEMP(); 2477 return MP_OK; 2478 } 2479 2480 /* Division of nonnegative integers 2481 2482 This function implements division algorithm for unsigned multi-precision 2483 integers. The algorithm is based on Algorithm D from Knuth's "The Art of 2484 Computer Programming", 3rd ed. 1998, pg 272-273. 
2485 2486 We diverge from Knuth's algorithm in that we do not perform the subtraction 2487 from the remainder until we have determined that we have the correct 2488 quotient digit. This makes our algorithm less efficient that Knuth because 2489 we might have to perform multiple multiplication and comparison steps before 2490 the subtraction. The advantage is that it is easy to implement and ensure 2491 correctness without worrying about underflow from the subtraction. 2492 2493 inputs: u a n+m digit integer in base b (b is 2^MP_DIGIT_BIT) 2494 v a n digit integer in base b (b is 2^MP_DIGIT_BIT) 2495 n >= 1 2496 m >= 0 2497 outputs: u / v stored in u 2498 u % v stored in v 2499 */ 2500 static mp_result s_udiv_knuth(mp_int u, mp_int v) { 2501 /* Force signs to positive */ 2502 u->sign = MP_ZPOS; 2503 v->sign = MP_ZPOS; 2504 2505 /* Use simple division algorithm when v is only one digit long */ 2506 if (MP_USED(v) == 1) { 2507 mp_digit d, rem; 2508 d = v->digits[0]; 2509 rem = s_ddiv(u, d); 2510 mp_int_set_value(v, rem); 2511 return MP_OK; 2512 } 2513 2514 /* Algorithm D 2515 2516 The n and m variables are defined as used by Knuth. 2517 u is an n digit number with digits u_{n-1}..u_0. 2518 v is an n+m digit number with digits from v_{m+n-1}..v_0. 2519 We require that n > 1 and m >= 0 2520 */ 2521 mp_size n = MP_USED(v); 2522 mp_size m = MP_USED(u) - n; 2523 assert(n > 1); 2524 /* assert(m >= 0) follows because m is unsigned. */ 2525 2526 /* D1: Normalize. 2527 The normalization step provides the necessary condition for Theorem B, 2528 which states that the quotient estimate for q_j, call it qhat 2529 2530 qhat = u_{j+n}u_{j+n-1} / v_{n-1} 2531 2532 is bounded by 2533 2534 qhat - 2 <= q_j <= qhat. 2535 2536 That is, qhat is always greater than the actual quotient digit q, 2537 and it is never more than two larger than the actual quotient digit. 2538 */ 2539 int k = s_norm(u, v); 2540 2541 /* Extend size of u by one if needed. 
2542 2543 The algorithm begins with a value of u that has one more digit of input. 2544 The normalization step sets u_{m+n}..u_0 = 2^k * u_{m+n-1}..u_0. If the 2545 multiplication did not increase the number of digits of u, we need to add 2546 a leading zero here. 2547 */ 2548 if (k == 0 || MP_USED(u) != m + n + 1) { 2549 if (!s_pad(u, m + n + 1)) return MP_MEMORY; 2550 u->digits[m + n] = 0; 2551 u->used = m + n + 1; 2552 } 2553 2554 /* Add a leading 0 to v. 2555 2556 The multiplication in step D4 multiplies qhat * 0v_{n-1}..v_0. We need to 2557 add the leading zero to v here to ensure that the multiplication will 2558 produce the full n+1 digit result. 2559 */ 2560 if (!s_pad(v, n + 1)) return MP_MEMORY; 2561 v->digits[n] = 0; 2562 2563 /* Initialize temporary variables q and t. 2564 q allocates space for m+1 digits to store the quotient digits 2565 t allocates space for n+1 digits to hold the result of q_j*v 2566 */ 2567 DECLARE_TEMP(2); 2568 REQUIRE(GROW(TEMP(0), m + 1)); 2569 REQUIRE(GROW(TEMP(1), n + 1)); 2570 2571 /* D2: Initialize j */ 2572 int j = m; 2573 mpz_t r; 2574 r.digits = MP_DIGITS(u) + j; /* The contents of r are shared with u */ 2575 r.used = n + 1; 2576 r.sign = MP_ZPOS; 2577 r.alloc = MP_ALLOC(u); 2578 ZERO(TEMP(1)->digits, TEMP(1)->alloc); 2579 2580 /* Calculate the m+1 digits of the quotient result */ 2581 for (; j >= 0; j--) { 2582 /* D3: Calculate q' */ 2583 /* r->digits is aligned to position j of the number u */ 2584 mp_word pfx, qhat; 2585 pfx = r.digits[n]; 2586 pfx <<= MP_DIGIT_BIT / 2; 2587 pfx <<= MP_DIGIT_BIT / 2; 2588 pfx |= r.digits[n - 1]; /* pfx = u_{j+n}{j+n-1} */ 2589 2590 qhat = pfx / v->digits[n - 1]; 2591 /* Check to see if qhat > b, and decrease qhat if so. 
2592 Theorem B guarantess that qhat is at most 2 larger than the 2593 actual value, so it is possible that qhat is greater than 2594 the maximum value that will fit in a digit */ 2595 if (qhat > MP_DIGIT_MAX) qhat = MP_DIGIT_MAX; 2596 2597 /* D4,D5,D6: Multiply qhat * v and test for a correct value of q 2598 2599 We proceed a bit different than the way described by Knuth. This way is 2600 simpler but less efficent. Instead of doing the multiply and subtract 2601 then checking for underflow, we first do the multiply of qhat * v and 2602 see if it is larger than the current remainder r. If it is larger, we 2603 decrease qhat by one and try again. We may need to decrease qhat one 2604 more time before we get a value that is smaller than r. 2605 2606 This way is less efficent than Knuth because we do more multiplies, but 2607 we do not need to worry about underflow this way. 2608 */ 2609 /* t = qhat * v */ 2610 s_dbmul(MP_DIGITS(v), (mp_digit)qhat, TEMP(1)->digits, n + 1); 2611 TEMP(1)->used = n + 1; 2612 CLAMP(TEMP(1)); 2613 2614 /* Clamp r for the comparison. Comparisons do not like leading zeros. */ 2615 CLAMP(&r); 2616 if (s_ucmp(TEMP(1), &r) > 0) { /* would the remainder be negative? */ 2617 qhat -= 1; /* try a smaller q */ 2618 s_dbmul(MP_DIGITS(v), (mp_digit)qhat, TEMP(1)->digits, n + 1); 2619 TEMP(1)->used = n + 1; 2620 CLAMP(TEMP(1)); 2621 if (s_ucmp(TEMP(1), &r) > 0) { /* would the remainder be negative? */ 2622 assert(qhat > 0); 2623 qhat -= 1; /* try a smaller q */ 2624 s_dbmul(MP_DIGITS(v), (mp_digit)qhat, TEMP(1)->digits, n + 1); 2625 TEMP(1)->used = n + 1; 2626 CLAMP(TEMP(1)); 2627 } 2628 assert(s_ucmp(TEMP(1), &r) <= 0 && "The mathematics failed us."); 2629 } 2630 /* Unclamp r. The D algorithm expects r = u_{j+n}..u_j to always be n+1 2631 digits long. */ 2632 r.used = n + 1; 2633 2634 /* D4: Multiply and subtract 2635 2636 Note: The multiply was completed above so we only need to subtract here. 
2637 */ 2638 s_usub(r.digits, TEMP(1)->digits, r.digits, r.used, TEMP(1)->used); 2639 2640 /* D5: Test remainder 2641 2642 Note: Not needed because we always check that qhat is the correct value 2643 before performing the subtract. Value cast to mp_digit to prevent 2644 warning, qhat has been clamped to MP_DIGIT_MAX 2645 */ 2646 TEMP(0)->digits[j] = (mp_digit)qhat; 2647 2648 /* D6: Add back 2649 Note: Not needed because we always check that qhat is the correct value 2650 before performing the subtract. 2651 */ 2652 2653 /* D7: Loop on j */ 2654 r.digits--; 2655 ZERO(TEMP(1)->digits, TEMP(1)->alloc); 2656 } 2657 2658 /* Get rid of leading zeros in q */ 2659 TEMP(0)->used = m + 1; 2660 CLAMP(TEMP(0)); 2661 2662 /* Denormalize the remainder */ 2663 CLAMP(u); /* use u here because the r.digits pointer is off-by-one */ 2664 if (k != 0) s_qdiv(u, k); 2665 2666 mp_int_copy(u, v); /* ok: 0 <= r < v */ 2667 mp_int_copy(TEMP(0), u); /* ok: q <= u */ 2668 2669 CLEANUP_TEMP(); 2670 return MP_OK; 2671 } 2672 2673 static int s_outlen(mp_int z, mp_size r) { 2674 assert(r >= MP_MIN_RADIX && r <= MP_MAX_RADIX); 2675 2676 mp_result bits = mp_int_count_bits(z); 2677 double raw = (double)bits * s_log2[r]; 2678 2679 return (int)(raw + 0.999999); 2680 } 2681 2682 static mp_size s_inlen(int len, mp_size r) { 2683 double raw = (double)len / s_log2[r]; 2684 mp_size bits = (mp_size)(raw + 0.5); 2685 2686 return (mp_size)((bits + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT) + 1; 2687 } 2688 2689 static int s_ch2val(char c, int r) { 2690 int out; 2691 2692 /* 2693 * In some locales, isalpha() accepts characters outside the range A-Z, 2694 * producing out<0 or out>=36. The "out >= r" check will always catch 2695 * out>=36. Though nothing explicitly catches out<0, our caller reacts the 2696 * same way to every negative return value. 
2697 */ 2698 if (isdigit((unsigned char)c)) 2699 out = c - '0'; 2700 else if (r > 10 && isalpha((unsigned char)c)) 2701 out = toupper((unsigned char)c) - 'A' + 10; 2702 else 2703 return -1; 2704 2705 return (out >= r) ? -1 : out; 2706 } 2707 2708 static char s_val2ch(int v, int caps) { 2709 assert(v >= 0); 2710 2711 if (v < 10) { 2712 return v + '0'; 2713 } else { 2714 char out = (v - 10) + 'a'; 2715 2716 if (caps) { 2717 return toupper((unsigned char)out); 2718 } else { 2719 return out; 2720 } 2721 } 2722 } 2723 2724 static void s_2comp(unsigned char *buf, int len) { 2725 unsigned short s = 1; 2726 2727 for (int i = len - 1; i >= 0; --i) { 2728 unsigned char c = ~buf[i]; 2729 2730 s = c + s; 2731 c = s & UCHAR_MAX; 2732 s >>= CHAR_BIT; 2733 2734 buf[i] = c; 2735 } 2736 2737 /* last carry out is ignored */ 2738 } 2739 2740 static mp_result s_tobin(mp_int z, unsigned char *buf, int *limpos, int pad) { 2741 int pos = 0, limit = *limpos; 2742 mp_size uz = MP_USED(z); 2743 mp_digit *dz = MP_DIGITS(z); 2744 2745 while (uz > 0 && pos < limit) { 2746 mp_digit d = *dz++; 2747 int i; 2748 2749 for (i = sizeof(mp_digit); i > 0 && pos < limit; --i) { 2750 buf[pos++] = (unsigned char)d; 2751 d >>= CHAR_BIT; 2752 2753 /* Don't write leading zeroes */ 2754 if (d == 0 && uz == 1) i = 0; /* exit loop without signaling truncation */ 2755 } 2756 2757 /* Detect truncation (loop exited with pos >= limit) */ 2758 if (i > 0) break; 2759 2760 --uz; 2761 } 2762 2763 if (pad != 0 && (buf[pos - 1] >> (CHAR_BIT - 1))) { 2764 if (pos < limit) { 2765 buf[pos++] = 0; 2766 } else { 2767 uz = 1; 2768 } 2769 } 2770 2771 /* Digits are in reverse order, fix that */ 2772 REV(buf, pos); 2773 2774 /* Return the number of bytes actually written */ 2775 *limpos = pos; 2776 2777 return (uz == 0) ? MP_OK : MP_TRUNC; 2778 } 2779 2780 /* Here there be dragons */ 2781