/* mpn_toom33_mul -- Multiply {ap,an} and {bp,bn} where an and bn are close in
   size.  Or more accurately, bn <= an < (3/2)bn.

   Contributed to the GNU project by Torbjorn Granlund.
   Additional improvements by Marco Bodrato.

   THE FUNCTION IN THIS FILE IS INTERNAL WITH A MUTABLE INTERFACE.  IT IS ONLY
   SAFE TO REACH IT THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
   GUARANTEED THAT IT WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.

Copyright 2006, 2007, 2008, 2010 Free Software Foundation, Inc.

This file is part of the GNU MP Library.

The GNU MP Library is free software; you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 3 of the License, or (at your
option) any later version.

The GNU MP Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
License for more details.

You should have received a copy of the GNU Lesser General Public License
along with the GNU MP Library.  If not, see http://www.gnu.org/licenses/.  */
     27 
     28 
     29 #include "gmp.h"
     30 #include "gmp-impl.h"
     31 
/* Evaluate in: -1, 0, +1, +2, +inf

  <-s--><--n--><--n-->
   ____ ______ ______
  |_a2_|___a1_|___a0_|
   |b2_|___b1_|___b0_|
   <-t-><--n--><--n-->

  v0  =  a0         * b0          #   A(0)*B(0)
  v1  = (a0+ a1+ a2)*(b0+ b1+ b2) #   A(1)*B(1)      ah  <= 2  bh <= 2
  vm1 = (a0- a1+ a2)*(b0- b1+ b2) #  A(-1)*B(-1)    |ah| <= 1  bh <= 1
  v2  = (a0+2a1+4a2)*(b0+2b1+4b2) #   A(2)*B(2)      ah  <= 6  bh <= 6
  vinf=          a2 *         b2  # A(inf)*B(inf)
*/
     46 
     47 #if TUNE_PROGRAM_BUILD
     48 #define MAYBE_mul_basecase 1
     49 #define MAYBE_mul_toom33   1
     50 #else
     51 #define MAYBE_mul_basecase						\
     52   (MUL_TOOM33_THRESHOLD < 3 * MUL_TOOM22_THRESHOLD)
     53 #define MAYBE_mul_toom33						\
     54   (MUL_TOOM44_THRESHOLD >= 3 * MUL_TOOM33_THRESHOLD)
     55 #endif
     56 
     57 /* FIXME: TOOM33_MUL_N_REC is not quite right for a balanced
     58    multiplication at the infinity point. We may have
     59    MAYBE_mul_basecase == 0, and still get s just below
     60    MUL_TOOM22_THRESHOLD. If MUL_TOOM33_THRESHOLD == 7, we can even get
     61    s == 1 and mpn_toom22_mul will crash.
     62 */
     63 
     64 #define TOOM33_MUL_N_REC(p, a, b, n, ws)				\
     65   do {									\
     66     if (MAYBE_mul_basecase						\
     67 	&& BELOW_THRESHOLD (n, MUL_TOOM22_THRESHOLD))			\
     68       mpn_mul_basecase (p, a, n, b, n);					\
     69     else if (! MAYBE_mul_toom33						\
     70 	     || BELOW_THRESHOLD (n, MUL_TOOM33_THRESHOLD))		\
     71       mpn_toom22_mul (p, a, n, b, n, ws);				\
     72     else								\
     73       mpn_toom33_mul (p, a, n, b, n, ws);				\
     74   } while (0)
     75 
     76 void
     77 mpn_toom33_mul (mp_ptr pp,
     78 		mp_srcptr ap, mp_size_t an,
     79 		mp_srcptr bp, mp_size_t bn,
     80 		mp_ptr scratch)
     81 {
     82   mp_size_t n, s, t;
     83   int vm1_neg;
     84   mp_limb_t cy, vinf0;
     85   mp_ptr gp;
     86   mp_ptr as1, asm1, as2;
     87   mp_ptr bs1, bsm1, bs2;
     88 
     89 #define a0  ap
     90 #define a1  (ap + n)
     91 #define a2  (ap + 2*n)
     92 #define b0  bp
     93 #define b1  (bp + n)
     94 #define b2  (bp + 2*n)
     95 
     96   n = (an + 2) / (size_t) 3;
     97 
     98   s = an - 2 * n;
     99   t = bn - 2 * n;
    100 
    101   ASSERT (an >= bn);
    102 
    103   ASSERT (0 < s && s <= n);
    104   ASSERT (0 < t && t <= n);
    105 
    106   as1  = scratch + 4 * n + 4;
    107   asm1 = scratch + 2 * n + 2;
    108   as2 = pp + n + 1;
    109 
    110   bs1 = pp;
    111   bsm1 = scratch + 3 * n + 3; /* we need 4n+4 <= 4n+s+t */
    112   bs2 = pp + 2 * n + 2;
    113 
    114   gp = scratch;
    115 
    116   vm1_neg = 0;
    117 
    118   /* Compute as1 and asm1.  */
    119   cy = mpn_add (gp, a0, n, a2, s);
    120 #if HAVE_NATIVE_mpn_add_n_sub_n
    121   if (cy == 0 && mpn_cmp (gp, a1, n) < 0)
    122     {
    123       cy = mpn_add_n_sub_n (as1, asm1, a1, gp, n);
    124       as1[n] = cy >> 1;
    125       asm1[n] = 0;
    126       vm1_neg = 1;
    127     }
    128   else
    129     {
    130       mp_limb_t cy2;
    131       cy2 = mpn_add_n_sub_n (as1, asm1, gp, a1, n);
    132       as1[n] = cy + (cy2 >> 1);
    133       asm1[n] = cy - (cy2 & 1);
    134     }
    135 #else
    136   as1[n] = cy + mpn_add_n (as1, gp, a1, n);
    137   if (cy == 0 && mpn_cmp (gp, a1, n) < 0)
    138     {
    139       mpn_sub_n (asm1, a1, gp, n);
    140       asm1[n] = 0;
    141       vm1_neg = 1;
    142     }
    143   else
    144     {
    145       cy -= mpn_sub_n (asm1, gp, a1, n);
    146       asm1[n] = cy;
    147     }
    148 #endif
    149 
    150   /* Compute as2.  */
    151 #if HAVE_NATIVE_mpn_rsblsh1_n
    152   cy = mpn_add_n (as2, a2, as1, s);
    153   if (s != n)
    154     cy = mpn_add_1 (as2 + s, as1 + s, n - s, cy);
    155   cy += as1[n];
    156   cy = 2 * cy + mpn_rsblsh1_n (as2, a0, as2, n);
    157 #else
    158 #if HAVE_NATIVE_mpn_addlsh1_n
    159   cy  = mpn_addlsh1_n (as2, a1, a2, s);
    160   if (s != n)
    161     cy = mpn_add_1 (as2 + s, a1 + s, n - s, cy);
    162   cy = 2 * cy + mpn_addlsh1_n (as2, a0, as2, n);
    163 #else
    164   cy = mpn_add_n (as2, a2, as1, s);
    165   if (s != n)
    166     cy = mpn_add_1 (as2 + s, as1 + s, n - s, cy);
    167   cy += as1[n];
    168   cy = 2 * cy + mpn_lshift (as2, as2, n, 1);
    169   cy -= mpn_sub_n (as2, as2, a0, n);
    170 #endif
    171 #endif
    172   as2[n] = cy;
    173 
    174   /* Compute bs1 and bsm1.  */
    175   cy = mpn_add (gp, b0, n, b2, t);
    176 #if HAVE_NATIVE_mpn_add_n_sub_n
    177   if (cy == 0 && mpn_cmp (gp, b1, n) < 0)
    178     {
    179       cy = mpn_add_n_sub_n (bs1, bsm1, b1, gp, n);
    180       bs1[n] = cy >> 1;
    181       bsm1[n] = 0;
    182       vm1_neg ^= 1;
    183     }
    184   else
    185     {
    186       mp_limb_t cy2;
    187       cy2 = mpn_add_n_sub_n (bs1, bsm1, gp, b1, n);
    188       bs1[n] = cy + (cy2 >> 1);
    189       bsm1[n] = cy - (cy2 & 1);
    190     }
    191 #else
    192   bs1[n] = cy + mpn_add_n (bs1, gp, b1, n);
    193   if (cy == 0 && mpn_cmp (gp, b1, n) < 0)
    194     {
    195       mpn_sub_n (bsm1, b1, gp, n);
    196       bsm1[n] = 0;
    197       vm1_neg ^= 1;
    198     }
    199   else
    200     {
    201       cy -= mpn_sub_n (bsm1, gp, b1, n);
    202       bsm1[n] = cy;
    203     }
    204 #endif
    205 
    206   /* Compute bs2.  */
    207 #if HAVE_NATIVE_mpn_rsblsh1_n
    208   cy = mpn_add_n (bs2, b2, bs1, t);
    209   if (t != n)
    210     cy = mpn_add_1 (bs2 + t, bs1 + t, n - t, cy);
    211   cy += bs1[n];
    212   cy = 2 * cy + mpn_rsblsh1_n (bs2, b0, bs2, n);
    213 #else
    214 #if HAVE_NATIVE_mpn_addlsh1_n
    215   cy  = mpn_addlsh1_n (bs2, b1, b2, t);
    216   if (t != n)
    217     cy = mpn_add_1 (bs2 + t, b1 + t, n - t, cy);
    218   cy = 2 * cy + mpn_addlsh1_n (bs2, b0, bs2, n);
    219 #else
    220   cy  = mpn_add_n (bs2, bs1, b2, t);
    221   if (t != n)
    222     cy = mpn_add_1 (bs2 + t, bs1 + t, n - t, cy);
    223   cy += bs1[n];
    224   cy = 2 * cy + mpn_lshift (bs2, bs2, n, 1);
    225   cy -= mpn_sub_n (bs2, bs2, b0, n);
    226 #endif
    227 #endif
    228   bs2[n] = cy;
    229 
    230   ASSERT (as1[n] <= 2);
    231   ASSERT (bs1[n] <= 2);
    232   ASSERT (asm1[n] <= 1);
    233   ASSERT (bsm1[n] <= 1);
    234   ASSERT (as2[n] <= 6);
    235   ASSERT (bs2[n] <= 6);
    236 
    237 #define v0    pp				/* 2n */
    238 #define v1    (pp + 2 * n)			/* 2n+1 */
    239 #define vinf  (pp + 4 * n)			/* s+t */
    240 #define vm1   scratch				/* 2n+1 */
    241 #define v2    (scratch + 2 * n + 1)		/* 2n+2 */
    242 #define scratch_out  (scratch + 5 * n + 5)
    243 
    244   /* vm1, 2n+1 limbs */
    245 #ifdef SMALLER_RECURSION
    246   TOOM33_MUL_N_REC (vm1, asm1, bsm1, n, scratch_out);
    247   cy = 0;
    248   if (asm1[n] != 0)
    249     cy = bsm1[n] + mpn_add_n (vm1 + n, vm1 + n, bsm1, n);
    250   if (bsm1[n] != 0)
    251     cy += mpn_add_n (vm1 + n, vm1 + n, asm1, n);
    252   vm1[2 * n] = cy;
    253 #else
    254   TOOM33_MUL_N_REC (vm1, asm1, bsm1, n + 1, scratch_out);
    255 #endif
    256 
    257   TOOM33_MUL_N_REC (v2, as2, bs2, n + 1, scratch_out);	/* v2, 2n+1 limbs */
    258 
    259   /* vinf, s+t limbs */
    260   if (s > t)  mpn_mul (vinf, a2, s, b2, t);
    261   else        TOOM33_MUL_N_REC (vinf, a2, b2, s, scratch_out);
    262 
    263   vinf0 = vinf[0];				/* v1 overlaps with this */
    264 
    265 #ifdef SMALLER_RECURSION
    266   /* v1, 2n+1 limbs */
    267   TOOM33_MUL_N_REC (v1, as1, bs1, n, scratch_out);
    268   if (as1[n] == 1)
    269     {
    270       cy = bs1[n] + mpn_add_n (v1 + n, v1 + n, bs1, n);
    271     }
    272   else if (as1[n] != 0)
    273     {
    274 #if HAVE_NATIVE_mpn_addlsh1_n
    275       cy = 2 * bs1[n] + mpn_addlsh1_n (v1 + n, v1 + n, bs1, n);
    276 #else
    277       cy = 2 * bs1[n] + mpn_addmul_1 (v1 + n, bs1, n, CNST_LIMB(2));
    278 #endif
    279     }
    280   else
    281     cy = 0;
    282   if (bs1[n] == 1)
    283     {
    284       cy += mpn_add_n (v1 + n, v1 + n, as1, n);
    285     }
    286   else if (bs1[n] != 0)
    287     {
    288 #if HAVE_NATIVE_mpn_addlsh1_n
    289       cy += mpn_addlsh1_n (v1 + n, v1 + n, as1, n);
    290 #else
    291       cy += mpn_addmul_1 (v1 + n, as1, n, CNST_LIMB(2));
    292 #endif
    293     }
    294   v1[2 * n] = cy;
    295 #else
    296   cy = vinf[1];
    297   TOOM33_MUL_N_REC (v1, as1, bs1, n + 1, scratch_out);
    298   vinf[1] = cy;
    299 #endif
    300 
    301   TOOM33_MUL_N_REC (v0, ap, bp, n, scratch_out);	/* v0, 2n limbs */
    302 
    303   mpn_toom_interpolate_5pts (pp, v2, vm1, n, s + t, vm1_neg, vinf0);
    304 }
    305