summaryrefslogtreecommitdiff
path: root/powm.c
blob: 1f0b0c708cd0bf67efc429cbe368a89b0526f004 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
#include "powm.h"
#include <stdlib.h>
#include <stdint.h>
#include <stddef.h>
#include <string.h>
#include <stdio.h>
#include <assert.h>

/*
This code implements result = base ^ exp (mod m).

Internally, bignums are represented as arrays of limbs of type limb_t in
little-endian order (least significant limb first).

Currently, with gcc -O3, we measure around 2.3x the time libgmp needs for
its powm, which is not bad considering we don't use any assembler
optimizations.

The hot spot function is mp_mul_uint_add.

WITH_PREVENT_TIMING_ATTACKS should be defined so that powm() does not leak
information about the number of exponent bits set. The price is a dummy
multiplication for every clear exponent bit, which on average doubles the
number of multiplications (squarings are unaffected).
*/

/* Build-time switches:
   WITH_PREVENT_TIMING_ATTACKS: powm_internal() performs a dummy Montgomery
     multiplication for every clear exponent bit, so the number of
     multiplications does not depend on which exponent bits are set.
   WITHOUT_UNROLLING: when defined, mp_mul_uint_add() skips its 16-way
     unrolled loop and only uses the simple per-limb loop.
   WITH_ROUNDS: not referenced in this file (NOTE(review): possibly used by
     powm.h or a test harness - confirm, otherwise remove).
   KARATSUBA_THRESHOLD: operand length in limbs at or below which the
     Karatsuba routines fall back to schoolbook multiplication / mp_sqr(). */
#define WITH_PREVENT_TIMING_ATTACKS
//#define WITHOUT_UNROLLING
#define WITH_ROUNDS 128
#define KARATSUBA_THRESHOLD 16

/* Compares two limbs-long numbers, most significant limb first.
   Returns 0 if a == b, -1 if a < b and 1 if a > b. */
static int mp_cmp( limb_t const *a, limb_t const *b, int limbs )
{
  int i;

  for( i = limbs - 1; i >= 0; --i )
  {
    if( a[i] != b[i] )
      return ( a[i] > b[i] ) ? 1 : -1;
  }
  return 0;
}

/* In-place subtraction a -= b over the low `limbs` limbs of a.
   A borrow out of a[limbs-1] keeps propagating into a[limbs], a[limbs+1],
   ... until it is absorbed, so the caller must provide enough limbs above
   to borrow from (there is no bounds check). */
static void mp_sub( limb_t *a, limb_t const *b, int limbs )
{
  limb_t borrow = 0;
  int i;

  for( i = 0; i < limbs; ++i )
  {
    limb_t ai = a[i];
    limb_t diff = (limb_t)( ai - b[i] );
    limb_t out = (limb_t)( diff - borrow );

    /* borrow if either a[i] - b[i] or subtracting the old borrow wrapped */
    borrow = (limb_t)( ( diff > ai ) | ( out > diff ) );
    a[i] = out;
  }

  for( ; borrow; ++i )
  {
    limb_t ai = a[i];
    a[i] = (limb_t)( ai - 1 );
    borrow = ( ai == 0 );
  }
}

/* In-place addition a += b over the low `limbs` limbs of a.
   A carry out of a[limbs-1] is propagated into at most `flimbs` further
   limbs of a (a[limbs] .. a[limbs+flimbs-1]); whatever carry remains after
   that is dropped, i.e. the sum is taken modulo B^(limbs+flimbs). */
static void mp_add( limb_t *a, limb_t const *b, int limbs, int flimbs )
{
  dlimb_t carry = 0;
  int i;

  for( i = 0; i < limbs; ++i )
  {
    carry += (dlimb_t)a[i] + (dlimb_t)b[i];
    a[i] = (limb_t)carry;
    carry >>= 8*sizeof(limb_t);
  }

  for( ; carry && flimbs; --flimbs, ++i )
  {
    carry += (dlimb_t)a[i];
    a[i] = (limb_t)carry;
    carry >>= 8*sizeof(limb_t);
  }
}

/* result[0 .. limbs) = a - b, for limbs-long a and b.
   Intended for a >= b: a borrow out of the top limb is not propagated
   anywhere, so the difference is taken modulo B^limbs. */
static void mp_sub_mod( limb_t * result, limb_t const *a, limb_t const * b, int limbs )
{
  limb_t borrow = 0;
  int i;

  for( i = 0; i < limbs; ++i )
  {
    limb_t x = a[i];
    limb_t diff = (limb_t)( x - b[i] );
    limb_t out = (limb_t)( diff - borrow );

    borrow = (limb_t)( ( diff > x ) | ( out > diff ) );
    result[i] = out;
  }
}

/* Two's-complement negation in place: p = -p (mod B^limbs).
   Uses -x == ~x + 1; the +1 only keeps rippling upward while the limbs
   produced so far are all zero. */
static void mp_negate( limb_t * p, int limbs )
{
  limb_t carry = 1;
  int i;

  for( i = 0; i < limbs; ++i )
  {
    p[i] = (limb_t)( (limb_t)~p[i] + carry );
    carry &= ( p[i] == 0 );
  }
}

/* Multiply-accumulate: result[0 .. limbs) += a[0 .. limbs) * fac.

   The carry out of the top limb is added into result[limbs],
   result[limbs+1], ... until it is absorbed, so the caller must guarantee
   that result has enough limbs above result[limbs-1] to take the carry
   (there is no bounds check).

   Each step evaluates result[i] + a[i]*fac + carry in a dlimb_t. Assuming
   dlimb_t is twice as wide as limb_t (W bits each), as the shifts by
   8*sizeof(limb_t) imply, this is at most
   (2^W-1) + (2^W-1) + (2^W-1)^2 = 2^(2W)-1 and cannot overflow.

   This is the hot spot of the exponentiation. */
static void mp_mul_uint_add( limb_t *result, limb_t const *a, limb_t fac, int limbs )
{
  /* running accumulator; also links the unrolled loop to the tail loops */
  dlimb_t acc15 = 0;
#ifndef WITHOUT_UNROLLING
  dlimb_t acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0;
  dlimb_t acc4 = 0, acc5 = 0, acc6 = 0, acc7 = 0;
  dlimb_t acc8 = 0, acc9 = 0, acc10= 0, acc11= 0;
  dlimb_t acc12= 0, acc13= 0, acc14= 0;

  /* 16 limbs per iteration. Round ACC takes the high half of the previous
     round's accumulator (CARRY) as its carry-in; round 0 takes it from
     acc15, i.e. from the last round of the previous iteration. */
  while( limbs >= 16 )
  {
#define MUL_UINT_ADD_ROUND(ACC,CARRY) \
    acc##ACC = ( acc##CARRY >> 8*sizeof(limb_t) ) + (dlimb_t)result[ACC] + (dlimb_t)a[ACC] * (dlimb_t)fac; \
    result[ACC] = (limb_t)acc##ACC;

    MUL_UINT_ADD_ROUND(0,15)
    MUL_UINT_ADD_ROUND(1,0)
    MUL_UINT_ADD_ROUND(2,1)
    MUL_UINT_ADD_ROUND(3,2)
    MUL_UINT_ADD_ROUND(4,3)
    MUL_UINT_ADD_ROUND(5,4)
    MUL_UINT_ADD_ROUND(6,5)
    MUL_UINT_ADD_ROUND(7,6)
    MUL_UINT_ADD_ROUND(8,7)
    MUL_UINT_ADD_ROUND(9,8)
    MUL_UINT_ADD_ROUND(10,9)
    MUL_UINT_ADD_ROUND(11,10)
    MUL_UINT_ADD_ROUND(12,11)
    MUL_UINT_ADD_ROUND(13,12)
    MUL_UINT_ADD_ROUND(14,13)
    MUL_UINT_ADD_ROUND(15,14)

    result += 16;
    a += 16;
    limbs -= 16;
  }
  /* acc15 holds the last round's full double-limb sum (0 if the loop did
     not run); keep only its carry for the tail loop */
  acc15 >>= 8*sizeof(limb_t);
#endif
  /* remaining limbs (all of them when unrolling is disabled) */
  while( limbs-- )
  {
    acc15 += (dlimb_t)*result + (dlimb_t)*(a++) * (dlimb_t)fac;
    *(result++) = (limb_t)acc15;
    acc15 >>= 8*sizeof(limb_t);
  }
  /* propagate the final carry into the limbs above */
  while( acc15 )
  {
    acc15 += (dlimb_t)*result;
    *(result++) = (limb_t)acc15;
    acc15 >>= 8*sizeof(limb_t);
  }
}

/* Schoolbook multiply-accumulate: result += a * b for two limbs-long
   operands, one mp_mul_uint_add() row per limb of b (row i is added at
   result + i). result must have enough limbs above to absorb the carries. */
static void mp_mul_oper_add( limb_t * result, limb_t const *a, limb_t const *b, int limbs )
{
  int row;

  for( row = 0; row < limbs; ++row )
    mp_mul_uint_add( result + row, a, b[row], limbs );
}

/* Squaring: result += a^2 for the limbs-long number a.

   Optimized mp_mul_oper_add for a == b: each cross product a[i]*a[j]
   (i < j) is computed once and added twice, each diagonal a[i]^2 once.
   Row i starts at result + 2*i. Carries are propagated into the limbs
   above, so result must have room for them (mp_mul_kara_square passes
   2*limbs zeroed limbs). */
static void mp_sqr( limb_t *result, limb_t const * a, int limbs )
{
  while( limbs-- ) {
    limb_t fac = *(a++), *dest = result;
    limb_t const *src = a;
    int limb = limbs;

    /* diagonal term a[i]^2 at position 2*i */
    dlimb_t acc = (dlimb_t)*dest + (dlimb_t)fac * (dlimb_t)fac;

    *(dest++) = (limb_t)acc;
    acc >>= 8*sizeof(limb_t);

    /* Cross terms 2*a[i]*a[j], j > i, at position i + j.
       Invariant: acc < 2^(W+1) at the top of each iteration (W = limb bits). */
    while( limb-- )
    {
      dlimb_t subresult = (dlimb_t)fac * (dlimb_t)*(src++);
      /* 2*subresult needs one bit more than dlimb_t has: keep the bit
         shifted out in `carry` and re-add it one limb higher below */
      dlimb_t carry = subresult >> (16*sizeof(limb_t)-1);
      dlimb_t doubled = subresult << 1;

      acc += (dlimb_t)*dest;          /* < 2^(W+2): cannot wrap */
      acc += doubled;
      /* This addition can wrap, e.g. when the previous carry left
         acc >= 2^W and doubled is just below 2^(2W). Count that lost bit
         too; the exact sum is below 2^(2W+1), so carry ends up 0 or 1. */
      carry += ( acc < doubled );

      *(dest++) = (limb_t)acc;
      acc = ( acc >> 8*sizeof(limb_t) ) + ( carry << 8*sizeof(limb_t) );
    }

    /* propagate the remaining carry into the limbs above */
    while( acc )
    {
      acc += (dlimb_t)*dest;
      *(dest++) = (limb_t)acc;
      acc >>= 8*sizeof(limb_t);
    }
    result += 2;
  }
}

/* Squares the len-limb number a: p[0 .. 2*len) = a^2 (p is overwritten).

   Even lengths above KARATSUBA_THRESHOLD use Karatsuba (toom2.2) with
   a = a0 + a1*B^n (B = 2^(limb bits), n = len/2):
     a^2 = a1^2*B^(2n) + (a0^2 + a1^2 - (a0-a1)^2)*B^n + a0^2
   (a0-a1)^2 is computed from |a0-a1| at p+n and negated (two's complement
   over the upper limbs), so adding a0^2 and a1^2 yields the middle term.

   Odd lengths and lengths at or below the threshold use mp_sqr(). The
   half split needs len == 2*n exactly; otherwise a[len-1] would be ignored.

   scratch must provide at least 2*len limbs (each level uses len limbs and
   hands scratch + len to the half-length level below). p must not overlap
   a or scratch, since p is cleared before a is read. */
static void mp_mul_kara_square( limb_t* p, limb_t const *a, int len, limb_t *scratch )
{
  memset( p, 0, 2 * len * sizeof( limb_t ));
  /* TODO(perf): odd lengths could recurse on len-1 limbs and add the top
     limb's products separately instead of squaring schoolbook-style. */
  if( len <= KARATSUBA_THRESHOLD || ( len & 1 ) )
    mp_sqr( p, a, len );
  else
  {
    int n = len / 2;

    /* scratch[0 .. n) = |a0 - a1|; the sign is irrelevant for a square */
    if( mp_cmp( a, a + n, n ) > 0 )
      mp_sub_mod( scratch, a, a + n, n );
    else
      mp_sub_mod( scratch, a + n, a, n );

    /* p[n .. 4n) = -(a0-a1)^2 modulo B^(3n) */
    mp_mul_kara_square( p + n, scratch, n, scratch + len );
    mp_negate( p + n, len + n );

    /* add a1^2 at B^(2n) and at B^n */
    mp_mul_kara_square( scratch, a + n, n, scratch + len );
    mp_add( p + len, scratch, len, 0 );
    mp_add( p + n, scratch, len, n );

    /* add a0^2 at B^n and at B^0 */
    mp_mul_kara_square( scratch, a, n, scratch + len );
    mp_add( p + n, scratch, len, n );
    mp_add( p, scratch, len, len );
  }
}

/* Multiplies the len-limb numbers a and b: p[0 .. 2*len) = a * b
   (p is overwritten).

   Even lengths above KARATSUBA_THRESHOLD use Karatsuba (toom2.2) with
   a = a0 + a1*B^n, b = b0 + b1*B^n (B = 2^(limb bits), n = len/2):
     a*b = a1*b1*B^(2n) + (a0*b0 + a1*b1 - (a0-a1)*(b0-b1))*B^n + a0*b0
   |a0-a1|*|b0-b1| is placed at p+n and negated (two's complement over the
   upper limbs) when (a0-a1)*(b0-b1) is positive; a0*b0 and a1*b1 are then
   added at their offsets.

   Odd lengths and lengths at or below the threshold use the schoolbook
   mp_mul_oper_add(). The half split needs len == 2*n exactly; otherwise
   a[len-1] and b[len-1] would be ignored.

   scratch must provide at least 2*len limbs (each level uses len limbs and
   hands scratch + len to the half-length level below). p must not overlap
   a, b or scratch, since p is cleared before a and b are read. */
static void mp_mul_kara( limb_t* p, limb_t const *a, limb_t const *b, int len, limb_t *scratch )
{
  memset( p, 0, 2 * len * sizeof( limb_t ));
  /* TODO(perf): odd lengths could recurse on len-1 limbs and add the top
     limbs' products separately instead of multiplying schoolbook-style. */
  if( len <= KARATSUBA_THRESHOLD || ( len & 1 ) )
    mp_mul_oper_add( p, a, b, len );
  else
  {
    /* sign: 1 if (a0-a1)*(b0-b1) is negative, 0 if positive
       (either value is fine when the product is 0) */
    int sign = 0, n = len / 2;

    /* scratch[0 .. n) = |a0 - a1| */
    if( mp_cmp( a, a + n, n ) > 0 )
      mp_sub_mod( scratch, a, a + n, n );
    else
    {
      mp_sub_mod( scratch, a + n, a, n );
      sign = 1;
    }

    /* scratch[n .. 2n) = |b0 - b1| */
    if( mp_cmp( b, b + n, n ) > 0 )
      mp_sub_mod( scratch + n, b, b + n, n );
    else
    {
      mp_sub_mod( scratch + n, b + n, b, n );
      sign ^= 1;
    }

    /* p[n .. 4n) = -(a0-a1)*(b0-b1) modulo B^(3n): the magnitude is
       negated only when the true product is positive */
    mp_mul_kara( p + n, scratch, scratch + n, n, scratch + len );
    if( !sign )
      mp_negate( p + n, len + n );

    /* add a1*b1 at B^(2n) and at B^n */
    mp_mul_kara( scratch, a + n, b + n, n, scratch + len );
    mp_add( p + len, scratch, len, 0 );
    mp_add( p + n, scratch, len, n );

    /* add a0*b0 at B^n and at B^0 */
    mp_mul_kara( scratch, a, b, n, scratch + len );
    mp_add( p + n, scratch, len, n );
    mp_add( p, scratch, len, len );
  }
}

/* Montgomery multiplication (REDC).

   With R = 2^(limbs * limb bits), computes a = a * b * R^-1 (mod modulus),
   i.e. multiplies in the Montgomery domain. k * modulus is added to the
   double-length product, limb by limb, until all lower limbs are 0. This
   allows exact division by R via an implicit shift (temp + limbs).

   If !do_mul, b is not used. a is converted from the Montgomery domain back
   to the modulus domain, i.e. multiplied by 1 before reducing.

   r_inverse must make temp[limb] + k * modulus[0] vanish modulo the limb
   base for k = temp[limb] * r_inverse, i.e. modulus[0] * r_inverse == -1
   (mod 2^limb bits). a and b hold `limbs` limbs (limbs >= 1: a zero-sized
   VLA is undefined behavior). a may equal b, which selects squaring. For
   inputs below the modulus the result is fully reduced (< modulus). */
static void redc( limb_t * a,
            const limb_t * b, int limbs,
            const limb_t * modulus, limb_t r_inverse, int do_mul )
{
  /* Karatsuba scratch: mp_mul_kara()/mp_mul_kara_square() use len limbs per
     level and pass scratch + len to the half-length level below, i.e. less
     than 2*len limbs in total. A bound of 2*(limbs - KARATSUBA_THRESHOLD) is
     too small for sizes that are not powers of two (48 limbs need 72) and
     is a non-positive VLA size for limbs <= KARATSUBA_THRESHOLD. */
  limb_t scratch[ 2 * limbs ];
  limb_t temp[ 1 + 2 * limbs ];
  int limb;

  if( do_mul )
  {
    /* temp[0 .. 2*limbs) = a * b; the extra top limb takes reduction carries */
    temp[2*limbs] = 0;
    if( a == b )
        mp_mul_kara_square( temp, a, limbs, scratch );
    else
        mp_mul_kara( temp, a, b, limbs, scratch );
  }
  else
  {
    /* transforming back: temp = a * 1, no multiplication needed */
    memset( temp + limbs, 0, (limbs + 1) * sizeof(limb_t));
    memcpy( temp, a, limbs * sizeof(limb_t));
  }

  /* Clear the low limbs one at a time. k is chosen so that adding
     k * modulus at offset `limb` zeroes temp[limb]. */
  for( limb = 0; limb < limbs; ++limb )
  {
    limb_t k = temp[limb] * r_inverse;
    mp_mul_uint_add( temp+limb, modulus, k, limbs );
  }

  /* temp + limbs now holds (T + K*modulus) / R, which is below 2*modulus
     for inputs below modulus. Subtract modulus once if the value is
     >= modulus. This includes the value being exactly modulus, which must
     become 0. A set top limb means the value exceeds R > modulus; the
     borrow from mp_sub then clears it. */
  if( temp[2*limbs] || mp_cmp( temp + limbs, modulus, limbs ) >= 0 )
    mp_sub( temp + limbs, modulus, limbs );

  memcpy( a, temp + limbs, limbs * sizeof(limb_t) );
  /* wipe intermediates (best effort: a plain memset of buffers that are
     about to go out of scope may be optimized away) */
  memset( scratch, 0, sizeof(scratch));
  memset( temp, 0, sizeof(temp));
}

/* calculate base ^ exponent modulo modulus

   All numbers are little-endian limb arrays, and result receives mod_limbs
   limbs. base must not have more limbs than the modulus
   (base_limbs <= mod_limbs).

   r_square and r_inverse are the Montgomery constants used by redc().
   r_square is presumably R^2 mod modulus (R = 2^(mod_limbs * limb bits)),
   so that the first redc() below moves base into the Montgomery domain.
   r_inverse must satisfy modulus[0] * r_inverse == -1 (mod 2^limb bits).
   TODO(review): confirm against the code that precomputes them.

   The exponent is scanned from its least significant bit upwards. acc holds
   base^(2^bit) in Montgomery form and is multiplied into result whenever
   that exponent bit is set. */
void powm_internal( limb_t * result,
              const limb_t * base,     uint32_t base_limbs,
              const limb_t * exponent, uint32_t exp_limbs,
              const limb_t * modulus,  uint32_t mod_limbs,
              const limb_t * r_square, const limb_t r_inverse )
{
  limb_t acc[mod_limbs];
#ifdef WITH_PREVENT_TIMING_ATTACKS
  limb_t dummy[mod_limbs];
#endif
  int first = 0, bit;
  int expbits = 8 * sizeof(limb_t) * exp_limbs;

  /* acc only has room for mod_limbs limbs */
  assert( base_limbs <= mod_limbs );

  memset( result, 0, sizeof(limb_t) * mod_limbs );
  memset( acc,    0, sizeof(limb_t) * mod_limbs );
#ifdef WITH_PREVENT_TIMING_ATTACKS
  /* dummy is read by redc() as a multiplicand; give it defined contents */
  memset( dummy,  0, sizeof(limb_t) * mod_limbs );
#endif

  memcpy( acc, base, sizeof(limb_t) * base_limbs );

  /* Transform base into montgomery domain */
  redc( acc, r_square, mod_limbs, modulus, r_inverse, 1 );

  for( bit = 0; bit < expbits ; bit++ )
  {
    int this_bit = bit % ( sizeof(limb_t) * 8 );
    if( ( exponent[ bit / ( sizeof(limb_t) * 8 ) ] >> this_bit ) & 1 ) {
      /* result is implicitly 1 until the first set bit, so just copy acc */
      if( first++ )
        redc( result, acc, mod_limbs, modulus, r_inverse, 1 );
      else
        memcpy( result, acc, sizeof(limb_t) * mod_limbs );
    }
#ifdef WITH_PREVENT_TIMING_ATTACKS
    else
      /* same work as a real multiplication; the product is discarded */
      redc( dummy, acc, mod_limbs, modulus, r_inverse, 1 );
#endif
    /* acc = acc^2, i.e. advance to base^(2^(bit+1)) */
    redc( acc, acc, mod_limbs, modulus, r_inverse, 1 );
  }

  if( first )
  {
    /* Transform result back out of the montgomery domain */
    redc( result, NULL, mod_limbs, modulus, r_inverse, 0 );
  }
  else
  {
    /* exponent == 0: base ^ 0 = 1. result holds exactly mod_limbs limbs. */
    memset( result, 0, mod_limbs * sizeof(limb_t) );
    *result = 1;
  }

  /* Wipe the intermediate powers of base (best effort: as in redc(), a
     plain memset of buffers that go out of scope may be optimized away). */
  memset( acc, 0, sizeof(acc) );
#ifdef WITH_PREVENT_TIMING_ATTACKS
  memset( dummy, 0, sizeof(dummy) );
#endif
}

/* Converts the big-endian byte string src[0 .. len) into little-endian
   limbs at dest. The last byte of src becomes the lowest byte of dest[0].
   A partial top limb is zero-extended. dest must have room for
   ceil(len / sizeof(limb_t)) limbs. Returns the number of limbs written. */
uint32_t mp_import( limb_t *dest, uint8_t const * src, uint32_t len )
{
  uint32_t written = 0;
  limb_t limb = 0;
  unsigned shift = 0;

  while( len > 0 )
  {
    limb |= ( (limb_t)src[--len] ) << shift;
    shift += 8;
    if( shift == 8 * sizeof(limb_t) )
    {
      dest[written++] = limb;
      limb = 0;
      shift = 0;
    }
  }

  if( shift != 0 )
    dest[written++] = limb;

  return written;
}

void mp_export( uint8_t * dest, const limb_t * src, uint32_t limb_size )
{
  while( limb_size )
  {
    limb_t outlimb = src[--limb_size];
    int cnt = sizeof( limb_t );
    while( cnt )
      *(dest++) = (uint8_t)( outlimb >> ( 8 * --cnt ) );
  }
}