libmp: initial attempt at constant time code, faster reductions for special primes (for ecc)

introduce MPtimesafe flag to request time invariant computation
disables normalization so significant digits are not leaked.
front
cinap_lenrek 2015-11-21 09:39:59 +01:00
parent b677ab0c59
commit 38e1e5272f
32 changed files with 660 additions and 229 deletions

View File

@ -22,7 +22,10 @@ struct mpint
enum
{
MPstatic= 0x01,
MPstatic= 0x01, /* static constant */
MPnorm= 0x02, /* normalization status */
MPtimesafe= 0x04, /* request time invariant computation */
Dbytes= sizeof(mpdigit), /* bytes per digit */
Dbits= Dbytes*8 /* bits per digit */
};
@ -32,7 +35,7 @@ void mpsetminbits(int n); /* newly created mpint's get at least n bits */
mpint* mpnew(int n); /* create a new mpint with at least n bits */
void mpfree(mpint *b);
void mpbits(mpint *b, int n); /* ensure that b has at least n bits */
void mpnorm(mpint *b); /* dump leading zeros */
mpint* mpnorm(mpint *b); /* dump leading zeros */
mpint* mpcopy(mpint *b);
void mpassign(mpint *old, mpint *new);
@ -47,8 +50,10 @@ int mpfmt(Fmt*);
char* mptoa(mpint*, int, char*, int);
mpint* letomp(uchar*, uint, mpint*); /* byte array, little-endian */
int mptole(mpint*, uchar*, uint, uchar**);
void mptolel(mpint *b, uchar *p, int n);
mpint* betomp(uchar*, uint, mpint*); /* byte array, big-endian */
int mptobe(mpint*, uchar*, uint, uchar**);
void mptober(mpint *b, uchar *p, int n);
uint mptoui(mpint*); /* unsigned int */
mpint* uitomp(uint, mpint*);
int mptoi(mpint*); /* int */
@ -71,12 +76,20 @@ void mpmul(mpint *b1, mpint *b2, mpint *prod); /* prod = b1*b2 */
void mpexp(mpint *b, mpint *e, mpint *m, mpint *res); /* res = b**e mod m */
void mpmod(mpint *b, mpint *m, mpint *remainder); /* remainder = b mod m */
/* modular arithmetic, time invariant when 0≤b1≤m-1 and 0≤b2≤m-1 */
void mpmodadd(mpint *b1, mpint *b2, mpint *m, mpint *sum); /* sum = b1+b2 % m */
void mpmodsub(mpint *b1, mpint *b2, mpint *m, mpint *diff); /* diff = b1-b2 % m */
void mpmodmul(mpint *b1, mpint *b2, mpint *m, mpint *prod); /* prod = b1*b2 % m */
/* quotient = dividend/divisor, remainder = dividend % divisor */
void mpdiv(mpint *dividend, mpint *divisor, mpint *quotient, mpint *remainder);
/* return neg, 0, pos as b1-b2 is neg, 0, pos */
int mpcmp(mpint *b1, mpint *b2);
/* res = s != 0 ? b1 : b2 */
void mpsel(int s, mpint *b1, mpint *b2, mpint *res);
/* extended gcd return d, x, and y, s.t. d = gcd(a,b) and ax+by = d */
void mpextendedgcd(mpint *a, mpint *b, mpint *d, mpint *x, mpint *y);
@ -106,12 +119,14 @@ void mpvecdigmuladd(mpdigit *b, int n, mpdigit m, mpdigit *p);
/* prereq: p has room for n+1 digits */
int mpvecdigmulsub(mpdigit *b, int n, mpdigit m, mpdigit *p);
/* p[0:alen*blen-1] = a[0:alen-1] * b[0:blen-1] */
/* p[0:alen+blen-1] = a[0:alen-1] * b[0:blen-1] */
/* prereq: alen >= blen, p has room for m*n digits */
void mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p);
void mpvectsmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p);
/* sign of a - b or zero if the same */
int mpveccmp(mpdigit *a, int alen, mpdigit *b, int blen);
int mpvectscmp(mpdigit *a, int alen, mpdigit *b, int blen);
/* divide the 2 digit dividend by the one digit divisor and stick in quotient */
/* we assume that the result is one digit - overflow is all 1's */

View File

@ -1,6 +1,6 @@
.TH MP 2
.SH NAME
mpsetminbits, mpnew, mpfree, mpbits, mpnorm, mpcopy, mpassign, mprand, mpnrand, strtomp, mpfmt,mptoa, betomp, mptobe, letomp, mptole, mptoui, uitomp, mptoi, itomp, uvtomp, mptouv, vtomp, mptov, mpdigdiv, mpadd, mpsub, mpleft, mpright, mpmul, mpexp, mpmod, mpdiv, mpcmp, mpextendedgcd, mpinvert, mpsignif, mplowbits0, mpvecdigmuladd, mpvecdigmulsub, mpvecadd, mpvecsub, mpveccmp, mpvecmul, mpmagcmp, mpmagadd, mpmagsub, crtpre, crtin, crtout, crtprefree, crtresfree \- extended precision arithmetic
mpsetminbits, mpnew, mpfree, mpbits, mpnorm, mpcopy, mpassign, mprand, mpnrand, strtomp, mpfmt,mptoa, betomp, mptobe, mptober, letomp, mptole, mptolel, mptoui, uitomp, mptoi, itomp, uvtomp, mptouv, vtomp, mptov, mpdigdiv, mpadd, mpsub, mpleft, mpright, mpmul, mpexp, mpmod, mpmodadd, mpmodsub, mpmodmul, mpdiv, mpcmp, mpsel, mpextendedgcd, mpinvert, mpsignif, mplowbits0, mpvecdigmuladd, mpvecdigmulsub, mpvecadd, mpvecsub, mpveccmp, mpvecmul, mpmagcmp, mpmagadd, mpmagsub, crtpre, crtin, crtout, crtprefree, crtresfree \- extended precision arithmetic
.SH SYNOPSIS
.B #include <u.h>
.br
@ -22,7 +22,7 @@ void mpsetminbits(int n)
void mpbits(mpint *b, int n)
.PP
.B
void mpnorm(mpint *b)
mpint* mpnorm(mpint *b)
.PP
.B
mpint* mpcopy(mpint *b)
@ -52,12 +52,18 @@ mpint* betomp(uchar *buf, uint blen, mpint *b)
int mptobe(mpint *b, uchar *buf, uint blen, uchar **bufp)
.PP
.B
void mptober(mpint *b, uchar *buf, int blen)
.PP
.B
mpint* letomp(uchar *buf, uint blen, mpint *b)
.PP
.B
int mptole(mpint *b, uchar *buf, uint blen, uchar **bufp)
.PP
.B
void mptolel(mpint *b, uchar *buf, int blen)
.PP
.B
uint mptoui(mpint*)
.PP
.B
@ -115,12 +121,24 @@ void mpdiv(mpint *dividend, mpint *divisor, mpint *quotient,
mpint *remainder)
.PP
.B
void mpmodadd(mpint *b1, mpint *b2, mpint *m, mpint *sum)
.PP
.B
void mpmodsub(mpint *b1, mpint *b2, mpint *m, mpint *diff)
.PP
.B
void mpmodmul(mpint *b1, mpint *b2, mpint *m, mpint *prod)
.PP
.B
int mpcmp(mpint *b1, mpint *b2)
.PP
.B
int mpmagcmp(mpint *b1, mpint *b2)
.PP
.B
void mpsel(int s, mpint *b1, mpint *b2, mpint *res)
.PP
.B
void mpextendedgcd(mpint *a, mpint *b, mpint *d, mpint *x,
.br
.B
@ -383,6 +401,24 @@ deposited in the location pointed to by
Sign is ignored in these conversions, i.e., the byte
array version is always positive.
.PP
.I Mptober
and
.I mptolel
fill
.I blen
lower bytes of an
.I mpint
into a fixed length byte array.
.I Mptober
fills the bytes right adjusted in big endian order so that the least
significant byte is at
.I buf[blen-1]
while
.I mptolel
fills in little endian order; left adjusted; so that the least
significat byte is filled into
.IR buf[0] .
.PP
.IR Betomp ,
and
.I letomp
@ -486,6 +522,31 @@ is less than, equal to, or greater than
the same as
.I mpcmp
but ignores the sign and just compares magnitudes.
.TP
.I mpsel
assigns
.I b1
to
.I res
when
.I s
is not zero, otherwise
.I b2
is assigned to
.IR res .
.PD
.PP
Modular arithmetic:
.TF mpmodmul_
.TP
.I mpmodadd
.BR "sum = b1+b2 mod m" .
.TP
.I mpmodsub
.BR "diff = b1-b2 mod m" .
.TP
.I mpmodmul
.BR "prod = b1*b2 mod m" .
.PD
.PP
.I Mpextendedgcd
@ -564,8 +625,8 @@ We assume p has room for n+1 digits. It returns +1 is the result is positive an
-1 if negative.
.TP
.I mpvecmul
.BR "p[0:alen*blen] = a[0:alen-1] * b[0:blen-1]" .
We assume that p has room for alen*blen+1 digits.
.BR "p[0:alen+blen] = a[0:alen-1] * b[0:blen-1]" .
We assume that p has room for alen+blen+1 digits.
.TP
.I mpveccmp
This returns -1, 0, or +1 as a - b is negative, 0, or positive.
@ -576,6 +637,17 @@ This returns -1, 0, or +1 as a - b is negative, 0, or positive.
and
.I mpzero
are the constants 2, 1 and 0. These cannot be freed.
.SS "Time invariant computation"
.PP
In the field of cryptography, it is sometimes neccesary to implement
algorithms such that the runtime of the algorithm is not depdenent on
the input data. This library provides partial support for time
invariant computation with the
.I MPtimesafe
flag that can be set on input or destination operands to request timing
safe operation. The result of a timing safe operation will also have the
.I MPtimesafe
flag set and is not normalized.
.SS "Chinese remainder theorem
.PP
When computing in a non-prime modulus,

View File

@ -13,19 +13,12 @@ betomp(uchar *p, uint n, mpint *b)
b = mpnew(0);
setmalloctag(b, getcallerpc(&p));
}
// dump leading zeros
while(*p == 0 && n > 1){
p++;
n--;
}
// get the space
mpbits(b, n*8);
b->top = DIGITS(n*8);
m = b->top-1;
// first digit might not be Dbytes long
m = DIGITS(n*8);
b->top = m--;
b->sign = 1;
s = ((n-1)*8)%Dbits;
x = 0;
for(; n > 0; n--){
@ -37,6 +30,5 @@ betomp(uchar *p, uint n, mpint *b)
x = 0;
}
}
return b;
return mpnorm(b);
}

View File

@ -9,8 +9,10 @@ letomp(uchar *s, uint n, mpint *b)
int i=0, m = 0;
mpdigit x=0;
if(b == nil)
if(b == nil){
b = mpnew(0);
setmalloctag(b, getcallerpc(&s));
}
mpbits(b, 8*n);
for(; n > 0; n--){
x |= ((mpdigit)(*s++)) << i;
@ -24,5 +26,6 @@ letomp(uchar *s, uint n, mpint *b)
if(i > 0)
b->p[m++] = x;
b->top = m;
return b;
b->sign = 1;
return mpnorm(b);
}

View File

@ -6,12 +6,15 @@ FILES=\
mpfmt\
strtomp\
mptobe\
mptober\
mptole\
mptolel\
betomp\
letomp\
mpadd\
mpsub\
mpcmp\
mpsel\
mpfactorial\
mpmul\
mpleft\
@ -20,10 +23,12 @@ FILES=\
mpvecsub\
mpvecdigmuladd\
mpveccmp\
mpvectscmp\
mpdigdiv\
mpdiv\
mpexp\
mpmod\
mpmodop\
mpextendedgcd\
mpinvert\
mprand\

View File

@ -9,6 +9,8 @@ mpmagadd(mpint *b1, mpint *b2, mpint *sum)
int m, n;
mpint *t;
sum->flags |= (b1->flags | b2->flags) & MPtimesafe;
// get the sizes right
if(b2->top > b1->top){
t = b1;
@ -41,6 +43,7 @@ mpadd(mpint *b1, mpint *b2, mpint *sum)
int sign;
if(b1->sign != b2->sign){
assert(((b1->flags | b2->flags | sum->flags) & MPtimesafe) == 0);
if(b1->sign < 0)
mpmagsub(b2, b1, sum);
else

View File

@ -5,33 +5,27 @@
static mpdigit _mptwodata[1] = { 2 };
static mpint _mptwo =
{
1,
1,
1,
1, 1, 1,
_mptwodata,
MPstatic
MPstatic|MPnorm
};
mpint *mptwo = &_mptwo;
static mpdigit _mponedata[1] = { 1 };
static mpint _mpone =
{
1,
1,
1,
1, 1, 1,
_mponedata,
MPstatic
MPstatic|MPnorm
};
mpint *mpone = &_mpone;
static mpdigit _mpzerodata[1] = { 0 };
static mpint _mpzero =
{
1,
1,
0,
1, 1, 0,
_mpzerodata,
MPstatic
MPstatic|MPnorm
};
mpint *mpzero = &_mpzero;
@ -57,18 +51,17 @@ mpnew(int n)
if(n < 0)
sysfatal("mpsetminbits: n < 0");
b = mallocz(sizeof(mpint), 1);
setmalloctag(b, getcallerpc(&n));
if(b == nil)
sysfatal("mpnew: %r");
n = DIGITS(n);
if(n < mpmindigits)
n = mpmindigits;
b->p = (mpdigit*)mallocz(n*Dbytes, 1);
if(b->p == nil)
b = mallocz(sizeof(mpint) + n*Dbytes, 1);
if(b == nil)
sysfatal("mpnew: %r");
setmalloctag(b, getcallerpc(&n));
b->p = (mpdigit*)&b[1];
b->size = n;
b->sign = 1;
b->flags = MPnorm;
return b;
}
@ -83,16 +76,23 @@ mpbits(mpint *b, int m)
if(b->size >= n){
if(b->top >= n)
return;
memset(&b->p[b->top], 0, Dbytes*(n - b->top));
b->top = n;
return;
} else {
if(b->p == (mpdigit*)&b[1]){
b->p = (mpdigit*)mallocz(n*Dbytes, 0);
if(b->p == nil)
sysfatal("mpbits: %r");
memmove(b->p, &b[1], Dbytes*b->top);
memset(&b[1], 0, Dbytes*b->size);
} else {
b->p = (mpdigit*)realloc(b->p, n*Dbytes);
if(b->p == nil)
sysfatal("mpbits: %r");
}
b->size = n;
}
b->p = (mpdigit*)realloc(b->p, n*Dbytes);
if(b->p == nil)
sysfatal("mpbits: %r");
memset(&b->p[b->top], 0, Dbytes*(n - b->top));
b->size = n;
b->top = n;
b->flags &= ~MPnorm;
}
void
@ -102,22 +102,30 @@ mpfree(mpint *b)
return;
if(b->flags & MPstatic)
sysfatal("freeing mp constant");
memset(b->p, 0, b->size*Dbytes); // information hiding
free(b->p);
memset(b->p, 0, b->size*Dbytes);
if(b->p != (mpdigit*)&b[1])
free(b->p);
free(b);
}
void
mpint*
mpnorm(mpint *b)
{
int i;
if(b->flags & MPtimesafe){
assert(b->sign == 1);
b->flags &= ~MPnorm;
return b;
}
for(i = b->top-1; i >= 0; i--)
if(b->p[i] != 0)
break;
b->top = i+1;
if(b->top == 0)
b->sign = 1;
b->flags |= MPnorm;
return b;
}
mpint*
@ -126,8 +134,10 @@ mpcopy(mpint *old)
mpint *new;
new = mpnew(Dbits*old->size);
new->top = old->top;
setmalloctag(new, getcallerpc(&old));
new->sign = old->sign;
new->top = old->top;
new->flags = old->flags & ~MPstatic;
memmove(new->p, old->p, Dbytes*old->top);
return new;
}
@ -135,9 +145,14 @@ mpcopy(mpint *old)
void
mpassign(mpint *old, mpint *new)
{
if(new == nil || old == new)
return;
new->top = 0;
mpbits(new, Dbits*old->top);
new->sign = old->sign;
new->top = old->top;
new->flags &= ~MPnorm;
new->flags |= old->flags & ~MPstatic;
memmove(new->p, old->p, Dbytes*old->top);
}
@ -167,6 +182,7 @@ mplowbits0(mpint *n)
int k, bit, digit;
mpdigit d;
assert(n->flags & MPnorm);
if(n->top==0)
return 0;
k = 0;
@ -187,4 +203,3 @@ mplowbits0(mpint *n)
}
return k;
}

View File

@ -8,10 +8,14 @@ mpmagcmp(mpint *b1, mpint *b2)
{
int i;
i = b1->top - b2->top;
if(i)
return i;
i = b1->flags | b2->flags;
if(i & MPtimesafe)
return mpvectscmp(b1->p, b1->top, b2->p, b2->top);
if(i & MPnorm){
i = b1->top - b2->top;
if(i)
return i;
}
return mpveccmp(b1->p, b1->top, b2->p, b2->top);
}
@ -19,10 +23,8 @@ mpmagcmp(mpint *b1, mpint *b2)
int
mpcmp(mpint *b1, mpint *b2)
{
if(b1->sign != b2->sign)
return b1->sign - b2->sign;
if(b1->sign < 0)
return mpmagcmp(b2, b1);
else
return mpmagcmp(b1, b2);
int sign;
sign = (b1->sign - b2->sign) >> 1; // -1, 0, 1
return sign | (sign&1)-1 & mpmagcmp(b1, b2)*b1->sign;
}

View File

@ -13,10 +13,29 @@ mpdiv(mpint *dividend, mpint *divisor, mpint *quotient, mpint *remainder)
mpdigit qd, *up, *vp, *qp;
mpint *u, *v, *t;
assert(quotient != remainder);
assert(divisor->flags & MPnorm);
// divide bv zero
if(divisor->top == 0)
abort();
// division by one or small powers of two
if(divisor->top == 1 && (divisor->p[0] & divisor->p[0]-1) == 0){
vlong r = (vlong)dividend->sign * (dividend->p[0] & divisor->p[0]-1);
if(quotient != nil){
for(s = 0; ((divisor->p[0] >> s) & 1) == 0; s++)
;
mpright(dividend, s, quotient);
}
if(remainder != nil){
remainder->flags |= dividend->flags & MPtimesafe;
vtomp(r, remainder);
}
return;
}
assert((dividend->flags & MPtimesafe) == 0);
// quick check
if(mpmagcmp(dividend, divisor) < 0){
if(remainder != nil)
@ -95,12 +114,14 @@ mpdiv(mpint *dividend, mpint *divisor, mpint *quotient, mpint *remainder)
*up-- = 0;
}
if(qp != nil){
assert((quotient->flags & MPtimesafe) == 0);
mpnorm(quotient);
if(dividend->sign != divisor->sign)
quotient->sign = -1;
}
if(remainder != nil){
assert((remainder->flags & MPtimesafe) == 0);
mpright(u, s, remainder); // u is the remainder shifted
remainder->sign = dividend->sign;
}

View File

@ -13,6 +13,9 @@ mpeuclid(mpint *a, mpint *b, mpint *d, mpint *x, mpint *y)
{
mpint *tmp, *x0, *x1, *x2, *y0, *y1, *y2, *q, *r;
assert((a->flags&b->flags) & MPnorm);
assert(((a->flags|b->flags|d->flags|x->flags|y->flags) & MPtimesafe) == 0);
if(a->sign<0 || b->sign<0)
sysfatal("mpeuclid: negative arg");

View File

@ -22,6 +22,10 @@ mpexp(mpint *b, mpint *e, mpint *m, mpint *res)
mpdigit d, bit;
int i, j;
assert(m->flags & MPnorm);
assert((e->flags & MPtimesafe) == 0);
res->flags |= b->flags & MPtimesafe;
i = mpcmp(e,mpzero);
if(i==0){
mpassign(mpone, res);

View File

@ -5,7 +5,7 @@
// extended binary gcd
//
// For a anv b it solves, v = gcd(a,b) and finds x and y s.t.
// For a and b it solves, v = gcd(a,b) and finds x and y s.t.
// ax + by = v
//
// Handbook of Applied Cryptography, Menezes et al, 1997, pg 608.
@ -15,6 +15,9 @@ mpextendedgcd(mpint *a, mpint *b, mpint *v, mpint *x, mpint *y)
mpint *u, *A, *B, *C, *D;
int g;
assert((a->flags&b->flags) & MPnorm);
assert(((a->flags|b->flags|v->flags|x->flags|y->flags) & MPtimesafe) == 0);
if(a->sign < 0 || b->sign < 0){
mpassign(mpzero, v);
mpassign(mpzero, y);

View File

@ -102,6 +102,7 @@ to10(mpint *b, char *buf, int len)
return -1;
d = mpcopy(b);
mpnorm(d);
r = mpnew(0);
billion = uitomp(1000000000, nil);
out = buf+len;
@ -128,15 +129,20 @@ int
mpfmt(Fmt *fmt)
{
mpint *b;
char *p;
char *p, f;
b = va_arg(fmt->args, mpint*);
if(b == nil)
return fmtstrcpy(fmt, "*");
f = b->flags;
b->flags &= ~MPtimesafe;
p = mptoa(b, fmt->prec, nil, 0);
fmt->flags &= ~FmtPrec;
b->flags = f;
if(p == nil)
return fmtstrcpy(fmt, "*");
else{

View File

@ -15,8 +15,8 @@ mpleft(mpint *b, int shift, mpint *res)
return;
}
// a negative left shift is a right shift
if(shift < 0){
// a zero or negative left shift is a right shift
if(shift <= 0){
mpright(b, -shift, res);
return;
}
@ -46,7 +46,6 @@ mpleft(mpint *b, int shift, mpint *res)
for(i = 0; i < d; i++)
res->p[i] = 0;
// normalize
while(res->top > 0 && res->p[res->top-1] == 0)
res->top--;
res->flags |= b->flags & MPtimesafe;
mpnorm(res);
}

View File

@ -2,14 +2,100 @@
#include <mp.h>
#include "dat.h"
// remainder = b mod m
//
// knuth, vol 2, pp 398-400
void
mpmod(mpint *b, mpint *m, mpint *remainder)
mpmod(mpint *x, mpint *n, mpint *r)
{
mpdiv(b, m, nil, remainder);
if(remainder->sign < 0)
mpadd(m, remainder, remainder);
static int busy;
static mpint *p, *m, *c, *v;
mpdigit q[32], t[64], d;
int sign, k, s, qn, tn;
sign = x->sign;
assert(n->flags & MPnorm);
if(n->top < 2 || n->top > nelem(q) || (x->top-n->top) > nelem(q))
goto hard;
/*
* check if n = 2**k - c where c has few power of two factors
* above the lowest digit.
*/
for(k = n->top-1; k > 0; k--){
d = n->p[k] >> 1;
if((d+1 & d) != 0)
goto hard;
}
d = n->p[n->top-1];
for(s = 0; (d & (mpdigit)1<<Dbits-1) == 0; s++)
d <<= 1;
/* lo(x) = x[0:k-1], hi(x) = x[k:xn-1] */
k = n->top;
while(_tas(&busy))
;
if(p == nil || mpmagcmp(n, p) != 0){
if(m == nil){
m = mpnew(0);
c = mpnew(0);
p = mpnew(0);
}
mpassign(n, p);
mpleft(n, s, m);
mpleft(mpone, k*Dbits, c);
mpsub(c, m, c);
}
mpleft(x, s, r);
if(r->top <= k){
mpbits(r, (k+1)*Dbits);
r->top = k+1;
}
/* q = hi(r) */
qn = r->top - k;
memmove(q, r->p+k, qn*Dbytes);
/* r = lo(r) */
r->top = k;
do {
/* t = q*c */
tn = qn + c->top;
memset(t, 0, tn*Dbytes);
mpvecmul(q, qn, c->p, c->top, t);
/* q = hi(t) */
qn = tn - k;
if(qn <= 0) qn = 0;
else memmove(q, t+k, qn*Dbytes);
/* r += lo(t) */
if(tn > k)
tn = k;
mpvecadd(r->p, k, t, tn, r->p);
/* if(r >= m) r -= m */
mpvecsub(r->p, k+1, m->p, k, t), d = t[k];
for(tn = 0; tn < k; tn++)
r->p[tn] = (r->p[tn] & d) | (t[tn] & ~d);
} while(qn > 0);
busy = 0;
if(s != 0)
mpright(r, s, r);
else
mpnorm(r);
goto done;
hard:
mpdiv(x, n, nil, r);
done:
if(sign < 0)
mpmagsub(n, r, r);
}

View File

@ -0,0 +1,96 @@
#include <u.h>
#include <libc.h>
#include <mp.h>
/* operands need to have m->top+1 digits of space and satisfy 0 ≤ a ≤ m-1 */
static mpint*
modarg(mpint *a, mpint *m)
{
if(a->size <= m->top || a->sign < 0 || mpmagcmp(a, m) >= 0){
a = mpcopy(a);
mpmod(a, m, a);
mpbits(a, Dbits*(m->top+1));
a->top = m->top;
} else if(a->top < m->top){
memset(&a->p[a->top], 0, (m->top - a->top)*Dbytes);
}
return a;
}
void
mpmodadd(mpint *b1, mpint *b2, mpint *m, mpint *sum)
{
mpint *a, *b;
mpdigit d;
int i, j;
a = modarg(b1, m);
b = modarg(b2, m);
sum->flags |= (a->flags | b->flags) & MPtimesafe;
mpbits(sum, Dbits*2*(m->top+1));
mpvecadd(a->p, m->top, b->p, m->top, sum->p);
mpvecsub(sum->p, m->top+1, m->p, m->top, sum->p+m->top+1);
d = sum->p[2*m->top+1];
for(i = 0, j = m->top+1; i < m->top; i++, j++)
sum->p[i] = (sum->p[i] & d) | (sum->p[j] & ~d);
sum->top = m->top;
sum->sign = 1;
mpnorm(sum);
if(a != b1)
mpfree(a);
if(b != b2)
mpfree(b);
}
void
mpmodsub(mpint *b1, mpint *b2, mpint *m, mpint *diff)
{
mpint *a, *b;
mpdigit d;
int i, j;
a = modarg(b1, m);
b = modarg(b2, m);
diff->flags |= (a->flags | b->flags) & MPtimesafe;
mpbits(diff, Dbits*2*(m->top+1));
a->p[m->top] = 0;
mpvecsub(a->p, m->top+1, b->p, m->top, diff->p);
mpvecadd(diff->p, m->top, m->p, m->top, diff->p+m->top+1);
d = ~diff->p[m->top];
for(i = 0, j = m->top+1; i < m->top; i++, j++)
diff->p[i] = (diff->p[i] & d) | (diff->p[j] & ~d);
diff->top = m->top;
diff->sign = 1;
mpnorm(diff);
if(a != b1)
mpfree(a);
if(b != b2)
mpfree(b);
}
void
mpmodmul(mpint *b1, mpint *b2, mpint *m, mpint *prod)
{
mpint *a, *b;
a = modarg(b1, m);
b = modarg(b2, m);
mpmul(a, b, prod);
mpmod(prod, m, prod);
if(a != b1)
mpfree(a);
if(b != b2)
mpfree(b);
}

View File

@ -113,10 +113,6 @@ mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
a = b;
b = t;
}
if(blen == 0){
memset(p, 0, Dbytes*(alen+blen));
return;
}
if(alen >= KARATSUBAMIN && blen > 1){
// O(n^1.585)
@ -131,25 +127,49 @@ mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
}
}
void
mpvectsmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
{
int i;
mpdigit *t;
if(alen < blen){
i = alen;
alen = blen;
blen = i;
t = a;
a = b;
b = t;
}
if(blen == 0)
return;
for(i = 0; i < blen; i++)
mpvecdigmuladd(a, alen, b[i], &p[i]);
}
void
mpmul(mpint *b1, mpint *b2, mpint *prod)
{
mpint *oprod;
oprod = nil;
oprod = prod;
if(prod == b1 || prod == b2){
oprod = prod;
prod = mpnew(0);
prod->flags = oprod->flags;
}
prod->flags |= (b1->flags | b2->flags) & MPtimesafe;
prod->top = 0;
mpbits(prod, (b1->top+b2->top+1)*Dbits);
mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p);
if(prod->flags & MPtimesafe)
mpvectsmul(b1->p, b1->top, b2->p, b2->top, prod->p);
else
mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p);
prod->top = b1->top+b2->top+1;
prod->sign = b1->sign*b2->sign;
mpnorm(prod);
if(oprod != nil){
if(oprod != prod){
mpassign(prod, oprod);
mpfree(prod);
}

View File

@ -16,8 +16,10 @@ mpnrand(mpint *n, void (*gen)(uchar*, int), mpint *b)
mpleft(mpone, bits, m);
mpsub(m, mpone, m);
if(b == nil)
if(b == nil){
b = mpnew(bits);
setmalloctag(b, getcallerpc(&n));
}
/* m = m - (m % n) */
mpmod(m, n, b);

View File

@ -6,37 +6,32 @@
mpint*
mprand(int bits, void (*gen)(uchar*, int), mpint *b)
{
int n, m;
mpdigit mask;
int n, m;
uchar *p;
n = DIGITS(bits);
if(b == nil)
if(b == nil){
b = mpnew(bits);
else
setmalloctag(b, getcallerpc(&bits));
}else
mpbits(b, bits);
p = malloc(n*Dbytes);
if(p == nil)
return nil;
sysfatal("mprand: %r");
(*gen)(p, n*Dbytes);
betomp(p, n*Dbytes, b);
free(p);
// make sure we don't give too many bits
m = bits%Dbits;
n--;
if(m > 0){
mask = 1;
mask <<= m;
mask--;
b->p[n] &= mask;
}
if(m == 0)
return b;
for(; n >= 0; n--)
if(b->p[n] != 0)
break;
b->top = n+1;
b->sign = 1;
return b;
mask = 1;
mask <<= m;
mask--;
b->p[n-1] &= mask;
return mpnorm(b);
}

View File

@ -23,12 +23,16 @@ mpright(mpint *b, int shift, mpint *res)
if(res != b)
mpbits(res, b->top*Dbits - shift);
else if(shift == 0)
return;
d = shift/Dbits;
r = shift - d*Dbits;
l = Dbits - r;
// shift all the bits out == zero
if(d>=b->top){
res->sign = 1;
res->top = 0;
return;
}
@ -46,9 +50,8 @@ mpright(mpint *b, int shift, mpint *res)
}
res->p[i++] = last>>r;
}
while(i > 0 && res->p[i-1] == 0)
i--;
res->top = i;
if(i==0)
res->sign = 1;
res->flags |= b->flags & MPtimesafe;
mpnorm(res);
}

View File

@ -0,0 +1,42 @@
#include "os.h"
#include <mp.h>
#include "dat.h"
// res = s != 0 ? b1 : b2
void
mpsel(int s, mpint *b1, mpint *b2, mpint *res)
{
mpdigit d;
int n, m, i;
res->flags |= (b1->flags | b2->flags) & MPtimesafe;
if((res->flags & MPtimesafe) == 0){
mpassign(s ? b1 : b2, res);
return;
}
res->flags &= ~MPnorm;
n = b1->top;
m = b2->top;
mpbits(res, Dbits*(n >= m ? n : m));
res->top = n >= m ? n : m;
s = (-s^s|s)>>(sizeof(s)*8-1);
res->sign = (b1->sign & s) | (b2->sign & ~s);
d = -((mpdigit)s & 1);
i = 0;
while(i < n && i < m){
res->p[i] = (b1->p[i] & d) | (b2->p[i] & ~d);
i++;
}
while(i < n){
res->p[i] = b1->p[i] & d;
i++;
}
while(i < m){
res->p[i] = b2->p[i] & ~d;
i++;
}
}

View File

@ -11,12 +11,15 @@ mpmagsub(mpint *b1, mpint *b2, mpint *diff)
// get the sizes right
if(mpmagcmp(b1, b2) < 0){
assert(((b1->flags | b2->flags | diff->flags) & MPtimesafe) == 0);
sign = -1;
t = b1;
b1 = b2;
b2 = t;
} else
} else {
diff->flags |= (b1->flags | b2->flags) & MPtimesafe;
sign = 1;
}
n = b1->top;
m = b2->top;
if(m == 0){
@ -39,6 +42,7 @@ mpsub(mpint *b1, mpint *b2, mpint *diff)
int sign;
if(b1->sign != b2->sign){
assert(((b1->flags | b2->flags | diff->flags) & MPtimesafe) == 0);
sign = b1->sign;
mpmagadd(b1, b2, diff);
diff->sign = sign;

View File

@ -2,57 +2,31 @@
#include <mp.h>
#include "dat.h"
// convert an mpint into a big endian byte array (most significant byte first)
// convert an mpint into a big endian byte array (most significant byte first; left adjusted)
// return number of bytes converted
// if p == nil, allocate and result array
int
mptobe(mpint *b, uchar *p, uint n, uchar **pp)
{
int i, j, suppress;
mpdigit x;
uchar *e, *s, c;
int m;
m = (mpsignif(b)+7)/8;
if(m == 0)
m++;
if(p == nil){
n = (b->top+1)*Dbytes;
n = m;
p = malloc(n);
if(p == nil)
sysfatal("mptobe: %r");
setmalloctag(p, getcallerpc(&b));
} else {
if(n < m)
return -1;
if(n > m)
memset(p+m, 0, n-m);
}
if(p == nil)
return -1;
if(pp != nil)
*pp = p;
memset(p, 0, n);
// special case 0
if(b->top == 0){
if(n < 1)
return -1;
else
return 1;
}
s = p;
e = s+n;
suppress = 1;
for(i = b->top-1; i >= 0; i--){
x = b->p[i];
for(j = Dbits-8; j >= 0; j -= 8){
c = x>>j;
if(c == 0 && suppress)
continue;
if(p >= e)
return -1;
*p++ = c;
suppress = 0;
}
}
// guarantee at least one byte
if(s == p){
if(p >= e)
return -1;
*p++ = 0;
}
return p - s;
mptober(b, p, m);
return m;
}

View File

@ -0,0 +1,34 @@
#include "os.h"
#include <mp.h>
#include "dat.h"
void
mptober(mpint *b, uchar *p, int n)
{
int i, j, m;
mpdigit x;
memset(p, 0, n);
p += n;
m = b->top*Dbytes;
if(m < n)
n = m;
i = 0;
while(n >= Dbytes){
n -= Dbytes;
x = b->p[i++];
for(j = 0; j < Dbytes; j++){
*--p = x;
x >>= 8;
}
}
if(n > 0){
x = b->p[i];
for(j = 0; j < n; j++){
*--p = x;
x >>= 8;
}
}
}

View File

@ -10,17 +10,15 @@
mpint*
itomp(int i, mpint *b)
{
if(b == nil)
if(b == nil){
b = mpnew(0);
mpassign(mpzero, b);
if(i != 0)
b->top = 1;
if(i < 0){
b->sign = -1;
*b->p = -i;
} else
*b->p = i;
return b;
setmalloctag(b, getcallerpc(&i));
}
b->sign = (i >> (sizeof(i)*8 - 1)) | 1;
i *= b->sign;
*b->p = i;
b->top = 1;
return mpnorm(b);
}
int

View File

@ -3,52 +3,26 @@
#include "dat.h"
// convert an mpint into a little endian byte array (least significant byte first)
// return number of bytes converted
// if p == nil, allocate and result array
int
mptole(mpint *b, uchar *p, uint n, uchar **pp)
{
int i, j;
mpdigit x;
uchar *e, *s;
int m;
m = (mpsignif(b)+7)/8;
if(m == 0)
m++;
if(p == nil){
n = (b->top+1)*Dbytes;
n = m;
p = malloc(n);
}
if(p == nil)
sysfatal("mptole: %r");
setmalloctag(p, getcallerpc(&b));
} else if(n < m)
return -1;
if(pp != nil)
*pp = p;
if(p == nil)
return -1;
memset(p, 0, n);
// special case 0
if(b->top == 0){
if(n < 1)
return -1;
else
return 0;
}
s = p;
e = s+n;
for(i = 0; i < b->top-1; i++){
x = b->p[i];
for(j = 0; j < Dbytes; j++){
if(p >= e)
return -1;
*p++ = x;
x >>= 8;
}
}
x = b->p[i];
while(x > 0){
if(p >= e)
return -1;
*p++ = x;
x >>= 8;
}
return p - s;
mptolel(b, p, n);
return m;
}

View File

@ -0,0 +1,33 @@
#include "os.h"
#include <mp.h>
#include "dat.h"
void
mptolel(mpint *b, uchar *p, int n)
{
int i, j, m;
mpdigit x;
memset(p, 0, n);
m = b->top*Dbytes;
if(m < n)
n = m;
i = 0;
while(n >= Dbytes){
n -= Dbytes;
x = b->p[i++];
for(j = 0; j < Dbytes; j++){
*p++ = x;
x >>= 8;
}
}
if(n > 0){
x = b->p[i];
for(j = 0; j < n; j++){
*p++ = x;
x >>= 8;
}
}
}

View File

@ -10,13 +10,14 @@
mpint*
uitomp(uint i, mpint *b)
{
if(b == nil)
if(b == nil){
b = mpnew(0);
mpassign(mpzero, b);
if(i != 0)
b->top = 1;
setmalloctag(b, getcallerpc(&i));
}
*b->p = i;
return b;
b->top = 1;
b->sign = 1;
return mpnorm(b);
}
uint

View File

@ -13,19 +13,18 @@ uvtomp(uvlong v, mpint *b)
{
int s;
if(b == nil)
if(b == nil){
b = mpnew(VLDIGITS*sizeof(mpdigit));
else
setmalloctag(b, getcallerpc(&v));
}else
mpbits(b, VLDIGITS*sizeof(mpdigit));
mpassign(mpzero, b);
if(v == 0)
return b;
for(s = 0; s < VLDIGITS && v != 0; s++){
b->sign = 1;
for(s = 0; s < VLDIGITS; s++){
b->p[s] = v;
v >>= sizeof(mpdigit)*8;
}
b->top = s;
return b;
return mpnorm(b);
}
uvlong
@ -37,7 +36,6 @@ mptouv(mpint *b)
if(b->top == 0)
return 0LL;
mpnorm(b);
if(b->top > VLDIGITS)
return MAXVLONG;

View File

@ -14,24 +14,19 @@ vtomp(vlong v, mpint *b)
int s;
uvlong uv;
if(b == nil)
if(b == nil){
b = mpnew(VLDIGITS*sizeof(mpdigit));
else
setmalloctag(b, getcallerpc(&v));
}else
mpbits(b, VLDIGITS*sizeof(mpdigit));
mpassign(mpzero, b);
if(v == 0)
return b;
if(v < 0){
b->sign = -1;
uv = -v;
} else
uv = v;
for(s = 0; s < VLDIGITS && uv != 0; s++){
b->sign = (v >> (sizeof(v)*8 - 1)) | 1;
uv = v * b->sign;
for(s = 0; s < VLDIGITS; s++){
b->p[s] = uv;
uv >>= sizeof(mpdigit)*8;
}
b->top = s;
return b;
return mpnorm(b);
}
vlong
@ -43,7 +38,6 @@ mptov(mpint *b)
if(b->top == 0)
return 0LL;
mpnorm(b);
if(b->top > VLDIGITS){
if(b->sign > 0)
return (vlong)MAXVLONG;

View File

@ -0,0 +1,34 @@
#include "os.h"
#include <mp.h>
#include "dat.h"
int
mpvectscmp(mpdigit *a, int alen, mpdigit *b, int blen)
{
mpdigit x, y, z, v;
int m, p;
if(alen > blen){
v = 0;
while(alen > blen)
v |= a[--alen];
m = p = (-v^v|v)>>Dbits-1;
} else if(blen > alen){
v = 0;
while(blen > alen)
v |= b[--blen];
m = (-v^v|v)>>Dbits-1;
p = m^1;
} else
m = p = 0;
while(alen-- > 0){
x = a[alen];
y = b[alen];
z = x - y;
x = ~x;
v = ((-z^z|z)>>Dbits-1) & ~m;
p = ((~(x&y|x&z|y&z)>>Dbits-1) & v) | (p & ~v);
m |= v;
}
return (p-m) | m;
}

View File

@ -50,7 +50,6 @@ from16(char *a, mpint *b)
int i;
mpdigit x;
b->top = 0;
for(p = a; *p; p++)
if(tab.t16[*(uchar*)p] == INVAL)
break;
@ -157,8 +156,10 @@ strtomp(char *a, char **pp, int base, mpint *b)
int sign;
char *e;
if(b == nil)
if(b == nil){
b = mpnew(0);
setmalloctag(b, getcallerpc(&a));
}
if(tab.inited == 0)
init();
@ -196,10 +197,9 @@ strtomp(char *a, char **pp, int base, mpint *b)
if(e == a)
return nil;
mpnorm(b);
b->sign = sign;
if(pp != nil)
*pp = e;
return b;
b->sign = sign;
return mpnorm(b);
}