#include "powm.h"

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* This code implements result = base ^ exp (mod m) using Montgomery
   multiplication.

   Internally bignums are represented by limbs of type limb_t ordered in
   little endian, i.e. least significant limb first (LLF).

   Currently, with gcc -O3, we measure around 2.3x the time libgmp needs for
   its powm, which is not so bad considering we don't use any assembler
   optimization. The hot spot function is mp_mul_uint_add.

   Define WITH_PREVENT_TIMING_ATTACKS so that powm() does not leak
   information about the number of bits set in the exponent, at the cost of
   roughly doubling the average run time.
*/

#define WITH_PREVENT_TIMING_ATTACKS
/* Uncomment to disable the hand-unrolled inner loop of mp_mul_uint_add. */
//#define WITHOUT_UNROLLING
#define WITH_ROUNDS 128
/* Operands of at most this many limbs are multiplied with the schoolbook
   base case instead of being split by Karatsuba. */
#define KARATSUBA_THRESHOLD 16

/* Compares a and b, both limbs limbs long, starting at the most significant
   limb. Returns 0 if a and b are equal, -1 if a < b, 1 if a > b. */
static int mp_cmp( limb_t const *a, limb_t const *b, int limbs )
{
    while( limbs-- )
    {
        if( a[limbs] < b[limbs] )
            return -1;
        if( a[limbs] > b[limbs] )
            return 1;
    }
    return 0;
}

/* Subtracts b from a in place (a -= b) over limbs limbs, then keeps
   propagating the borrow into a[limbs], a[limbs + 1], ... until it is
   absorbed. The caller must guarantee that those higher limbs exist, i.e.
   that the full value stored at a is >= b. */
static void mp_sub( limb_t *a, limb_t const *b, int limbs )
{
    int borrow = 0, borrow_temp;

    while( limbs-- )
    {
        limb_t temp = *a - *(b++);
        borrow_temp = temp > *a;                  /* this limb wrapped: a < b */
        *a = temp - borrow;
        borrow = borrow_temp | ( *(a++) > temp ); /* taking the old borrow wrapped */
    }
    while( borrow )
    {
        limb_t temp = *a - borrow;
        borrow = temp > *a;
        *(a++) = temp;
    }
}

/* Adds b to a in place (a += b) over limbs limbs, then propagates the carry
   through at most flimbs further limbs of a (a[limbs] .. a[limbs+flimbs-1]).
   A carry still left after those limbs is discarded. */
static void mp_add( limb_t *a, limb_t const *b, int limbs, int flimbs )
{
    dlimb_t acc = 0;

    while( limbs-- )
    {
        acc += (dlimb_t)*a + (dlimb_t)*(b++);
        *(a++) = (limb_t)acc;
        acc >>= 8*sizeof(limb_t);
    }
    while( acc && flimbs-- )
    {
        acc += (dlimb_t)*a;
        *(a++) = (limb_t)acc;
        acc >>= 8*sizeof(limb_t);
    }
}

/* Subtract b from a and store in result.
Expects nothing to borrow.*/ static void mp_sub_mod( limb_t * result, limb_t const *a, limb_t const * b, int limbs ) { int borrow = 0, borrow_temp; while( limbs-- ) { limb_t temp = *a - *(b++); borrow_temp = temp > *(a++); *result = temp - borrow; borrow = borrow_temp | (*(result++) > temp ); } } /* Fast negate */ static void mp_negate( limb_t * p, int limbs ) { /* Only as long as we find 0 along the way, we need to propagate the carry of -x == !x+1 */ int carry = 1; while( limbs-- ) { limb_t v = carry + ( *p ^ (limb_t)-1 ); if(v) carry = 0; *(p++)=v; } } /* Multiplies a with fac, adds to result. result is guaranteed to be initialized with enough limbs prepended to take the carry */ static void mp_mul_uint_add( limb_t *result, limb_t const *a, limb_t fac, int limbs ) { dlimb_t acc15 = 0; #ifndef WITHOUT_UNROLLING dlimb_t acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0; dlimb_t acc4 = 0, acc5 = 0, acc6 = 0, acc7 = 0; dlimb_t acc8 = 0, acc9 = 0, acc10= 0, acc11= 0; dlimb_t acc12= 0, acc13= 0, acc14= 0; while( limbs >= 16 ) { #define MUL_UINT_ADD_ROUND(ACC,CARRY) \ acc##ACC = ( acc##CARRY >> 8*sizeof(limb_t) ) + (dlimb_t)result[ACC] + (dlimb_t)a[ACC] * (dlimb_t)fac; \ result[ACC] = (limb_t)acc##ACC; MUL_UINT_ADD_ROUND(0,15) MUL_UINT_ADD_ROUND(1,0) MUL_UINT_ADD_ROUND(2,1) MUL_UINT_ADD_ROUND(3,2) MUL_UINT_ADD_ROUND(4,3) MUL_UINT_ADD_ROUND(5,4) MUL_UINT_ADD_ROUND(6,5) MUL_UINT_ADD_ROUND(7,6) MUL_UINT_ADD_ROUND(8,7) MUL_UINT_ADD_ROUND(9,8) MUL_UINT_ADD_ROUND(10,9) MUL_UINT_ADD_ROUND(11,10) MUL_UINT_ADD_ROUND(12,11) MUL_UINT_ADD_ROUND(13,12) MUL_UINT_ADD_ROUND(14,13) MUL_UINT_ADD_ROUND(15,14) result += 16; a += 16; limbs -= 16; } acc15 >>= 8*sizeof(limb_t); #endif while( limbs-- ) { acc15 += (dlimb_t)*result + (dlimb_t)*(a++) * (dlimb_t)fac; *(result++) = (limb_t)acc15; acc15 >>= 8*sizeof(limb_t); } while( acc15 ) { acc15 += (dlimb_t)*result; *(result++) = (limb_t)acc15; acc15 >>= 8*sizeof(limb_t); } } /* Multiplies a and b, adds to result, base case version. 
*/ static void mp_mul_oper_add( limb_t * result, limb_t const *a, limb_t const *b, int limbs ) { int limb = limbs; while( limb-- ) mp_mul_uint_add( result++, a, *(b++), limbs ); } /* Optimized mp_mul_oper_add for a == b, i.e. squaring */ static void mp_sqr( limb_t *result, limb_t const * a, int limbs ) { while( limbs-- ) { limb_t fac = *(a++), *dest = result; limb_t const *src = a; int limb = limbs; dlimb_t acc = (dlimb_t)*dest + (dlimb_t)fac * (dlimb_t)fac; *(dest++) = (limb_t)acc; acc >>= 8*sizeof(limb_t); while( limb-- ) { dlimb_t subresult = (dlimb_t)fac * (dlimb_t)*(src++); int carry = !!( subresult >> (16*sizeof(limb_t)-1)); acc += 2 * subresult + (dlimb_t)*dest; *(dest++) = (limb_t)acc; acc >>= 8*sizeof(limb_t); acc += (dlimb_t)carry << 8*sizeof(limb_t); } while( acc ) { acc += (dlimb_t)*dest; *(dest++) = (limb_t)acc; acc >>= 8*sizeof(limb_t); } result += 2; } } /* Optimized karatsuba (toom2.2) for a == b, i.e. squaring */ static void mp_mul_kara_square( limb_t* p, limb_t const *a, int len, limb_t *scratch ) { memset( p, 0, 2 * len * sizeof( limb_t )); if( len <= KARATSUBA_THRESHOLD ) mp_sqr( p, a, len ); else { int n = len / 2; if( mp_cmp( a, a + n, n ) > 0 ) mp_sub_mod( scratch, a, a + n, n ); else mp_sub_mod( scratch, a + n, a, n ); mp_mul_kara_square( p + n, scratch, n, scratch + len ); mp_negate( p + n, len + n ); mp_mul_kara_square( scratch, a + n, n, scratch + len ); mp_add( p + len, scratch, len, 0 ); mp_add( p + n, scratch, len, n ); mp_mul_kara_square( scratch, a, n, scratch + len ); mp_add( p + n, scratch, len, n ); mp_add( p, scratch, len, len ); } } /* karatsuba (toom2.2), generic */ static void mp_mul_kara( limb_t* p, limb_t const *a, limb_t const *b, int len, limb_t *scratch ) { memset( p, 0, 2 * len * sizeof( limb_t )); if( len <= KARATSUBA_THRESHOLD ) mp_mul_oper_add( p, a, b, len ); else { int sign = 0, n = len / 2; if( mp_cmp( a, a + n, n ) > 0 ) mp_sub_mod( scratch, a, a + n, n ); else { mp_sub_mod( scratch, a + n, a, n ); sign = 1; } if( 
mp_cmp( b, b + n, n ) > 0 ) mp_sub_mod( scratch + n, b, b + n, n ); else { mp_sub_mod( scratch + n, b + n, b, n ); sign ^= 1; } mp_mul_kara( p + n, scratch, scratch + n, n, scratch + len ); if( !sign ) mp_negate( p + n, len + n ); mp_mul_kara( scratch, a + n, b + n, n, scratch + len ); mp_add( p + len, scratch, len, 0 ); mp_add( p + n, scratch, len, n ); mp_mul_kara( scratch, a, b, n, scratch + len ); mp_add( p + n, scratch, len, n ); mp_add( p, scratch, len, len ); } } /* Multiply a and b, store in a, work in montgomery domain to achieve multiply, a needs to be reduced by k * modulus (with k being the unknown integer) until all lower limbs are 0, allowing exact division (via implicit shift) by 2^r if !do_mul, we convert from montgomery domain back to modulus domain and thus only multiply by 1 before reducing */ static void redc( limb_t * a, const limb_t * b, int limbs, const limb_t * modulus, limb_t r_inverse, int do_mul ) { limb_t scratch[ 2 * ( limbs - KARATSUBA_THRESHOLD ) ]; limb_t temp[ 1 + 2 * limbs ]; int limb; /* Not necessary for transforming back */ if( do_mul ) { temp[2*limbs] = 0; if( a == b ) mp_mul_kara_square( temp, a, limbs, scratch ); else mp_mul_kara( temp, a, b, limbs, scratch ); } else { memset( temp + limbs, 0, (limbs + 1) * sizeof(limb_t)); memcpy( temp, a, limbs * sizeof(limb_t)); } /* m = p * ( m * R_1 ) % R */ for( limb = 0; limb < limbs; ++limb ) { limb_t k = temp[limb] * r_inverse; mp_mul_uint_add( temp+limb, modulus, k, limbs ); } /* the lower limbs of temp are now zero if necessary, reduce temp to fit in limbs */ if( temp[2*limbs] || mp_cmp( temp + limbs, modulus, limbs ) > 0 ) mp_sub( temp + limbs, modulus, limbs ); memcpy( a, temp + limbs, limbs * sizeof(limb_t) ); memset( scratch, 0, sizeof(scratch)); memset( temp, 0, sizeof(temp)); } /* calculate base ^ exponent modulo modulus */ void powm_internal( limb_t * result, const limb_t * base, uint32_t base_limbs, const limb_t * exponent, uint32_t exp_limbs, const limb_t * modulus, 
uint32_t mod_limbs, const limb_t * r_square, const limb_t r_inverse ) { limb_t acc[mod_limbs]; #ifdef WITH_PREVENT_TIMING_ATTACKS limb_t dummy[mod_limbs]; #endif int first = 0, bit; int expbits = 8 * sizeof(limb_t) * exp_limbs; memset( result, 0, sizeof(limb_t) * mod_limbs ); memset( acc, 0, sizeof(limb_t) * mod_limbs ); memcpy( acc, base, sizeof(limb_t) * base_limbs ); //*acc = base; /* Transform base into montgomery domain */ redc( acc, r_square, mod_limbs, modulus, r_inverse, 1 ); /* mul in temp and if bit set in exponent, multiply into accumulator */ for( bit = 0; bit < expbits ; bit++ ) { int this_bit = bit % ( sizeof(limb_t) * 8 ); if( ( exponent[ bit / ( sizeof(limb_t) * 8 ) ] >> this_bit ) & 1 ) { if( first++ ) redc( result, acc, mod_limbs, modulus, r_inverse, 1 ); else memcpy( result, acc, sizeof(limb_t) * mod_limbs ); } #ifdef WITH_PREVENT_TIMING_ATTACKS else redc( dummy, acc, mod_limbs, modulus, r_inverse, 1 ); #endif redc( acc, acc, mod_limbs, modulus, r_inverse, 1 ); } /* Transform result in acc back into mod m domain */ if( first ) redc( result, 0, mod_limbs, modulus, r_inverse, 0 ); else /* base ^ 0 mod m = 1 */ { memset( result+1, 0, mod_limbs * sizeof(limb_t) ); *result = 1; } } uint32_t mp_import( limb_t *dest, uint8_t const * src, uint32_t len ) { int byte = 0; limb_t *d = dest; limb_t acc = 0; while( len-- ) { acc |= ((limb_t)src[len]) << ( 8 * byte ); if( ++byte == sizeof(limb_t)) { *(d++) = acc; acc = 0; byte = 0; } } if( byte ) *(d++) = acc; return ( d - dest ); } void mp_export( uint8_t * dest, const limb_t * src, uint32_t limb_size ) { while( limb_size ) { limb_t outlimb = src[--limb_size]; int cnt = sizeof( limb_t ); while( cnt ) *(dest++) = (uint8_t)( outlimb >> ( 8 * --cnt ) ); } }