From 38e1e5272fc9c66a00d702246813135452819ffe Mon Sep 17 00:00:00 2001 From: cinap_lenrek Date: Sat, 21 Nov 2015 09:39:59 +0100 Subject: [PATCH] libmp: initial attempt at constant time code, faster reductions for special primes (for ecc) introduce MPtimesafe flag to request time invariant computation disables normalization so significant digits are not leaked. --- sys/include/mp.h | 21 +++++- sys/man/2/mp | 80 ++++++++++++++++++++-- sys/src/libmp/port/betomp.c | 18 ++--- sys/src/libmp/port/letomp.c | 7 +- sys/src/libmp/port/mkfile | 5 ++ sys/src/libmp/port/mpadd.c | 3 + sys/src/libmp/port/mpaux.c | 75 ++++++++++++--------- sys/src/libmp/port/mpcmp.c | 22 ++++--- sys/src/libmp/port/mpdiv.c | 21 ++++++ sys/src/libmp/port/mpeuclid.c | 3 + sys/src/libmp/port/mpexp.c | 4 ++ sys/src/libmp/port/mpextendedgcd.c | 5 +- sys/src/libmp/port/mpfmt.c | 10 ++- sys/src/libmp/port/mpleft.c | 9 ++- sys/src/libmp/port/mpmod.c | 102 ++++++++++++++++++++++++++--- sys/src/libmp/port/mpmodop.c | 96 +++++++++++++++++++++++++++ sys/src/libmp/port/mpmul.c | 36 +++++++--- sys/src/libmp/port/mpnrand.c | 4 +- sys/src/libmp/port/mprand.c | 29 ++++---- sys/src/libmp/port/mpright.c | 11 ++-- sys/src/libmp/port/mpsel.c | 42 ++++++++++++ sys/src/libmp/port/mpsub.c | 6 +- sys/src/libmp/port/mptobe.c | 56 +++++----------- sys/src/libmp/port/mptober.c | 34 ++++++++++ sys/src/libmp/port/mptoi.c | 18 +++-- sys/src/libmp/port/mptole.c | 50 ++++---------- sys/src/libmp/port/mptolel.c | 33 ++++++++++ sys/src/libmp/port/mptoui.c | 11 ++-- sys/src/libmp/port/mptouv.c | 14 ++-- sys/src/libmp/port/mptov.c | 20 ++---- sys/src/libmp/port/mpvectscmp.c | 34 ++++++++++ sys/src/libmp/port/strtomp.c | 10 +-- 32 files changed, 660 insertions(+), 229 deletions(-) create mode 100644 sys/src/libmp/port/mpmodop.c create mode 100644 sys/src/libmp/port/mpsel.c create mode 100644 sys/src/libmp/port/mptober.c create mode 100644 sys/src/libmp/port/mptolel.c create mode 100644 sys/src/libmp/port/mpvectscmp.c diff --git a/sys/include/mp.h b/sys/include/mp.h index 14061adc7..b17df619c 100644 --- a/sys/include/mp.h +++ b/sys/include/mp.h @@ -22,7 +22,10 @@ struct mpint enum { - MPstatic= 0x01, + MPstatic= 0x01, /* static constant */ + MPnorm= 0x02, /* normalization status */ + MPtimesafe= 0x04, /* request time invariant computation */ + Dbytes= sizeof(mpdigit), /* bytes per digit */ Dbits= Dbytes*8 /* bits per digit */ }; @@ -32,7 +35,7 @@ void mpsetminbits(int n); /* newly created mpint's get at least n bits */ mpint* mpnew(int n); /* create a new mpint with at least n bits */ void mpfree(mpint *b); void mpbits(mpint *b, int n); /* ensure that b has at least n bits */ -void mpnorm(mpint *b); /* dump leading zeros */ +mpint* mpnorm(mpint *b); /* dump leading zeros */ mpint* mpcopy(mpint *b); void mpassign(mpint *old, mpint *new); @@ -47,8 +50,10 @@ int mpfmt(Fmt*); char* mptoa(mpint*, int, char*, int); mpint* letomp(uchar*, uint, mpint*); /* byte array, little-endian */ int mptole(mpint*, uchar*, uint, uchar**); +void mptolel(mpint *b, uchar *p, int n); mpint* betomp(uchar*, uint, mpint*); /* byte array, big-endian */ int mptobe(mpint*, uchar*, uint, uchar**); +void mptober(mpint *b, uchar *p, int n); uint mptoui(mpint*); /* unsigned int */ mpint* uitomp(uint, mpint*); int mptoi(mpint*); /* int */ @@ -71,12 +76,20 @@ void mpmul(mpint *b1, mpint *b2, mpint *prod); /* prod = b1*b2 */ void mpexp(mpint *b, mpint *e, mpint *m, mpint *res); /* res = b**e mod m */ void mpmod(mpint *b, mpint *m, mpint *remainder); /* remainder = b mod m */ +/* modular arithmetic, time invariant when 0≤b1≤m-1 and 0≤b2≤m-1 */ +void mpmodadd(mpint *b1, mpint *b2, mpint *m, mpint *sum); /* sum = b1+b2 % m */ +void mpmodsub(mpint *b1, mpint *b2, mpint *m, mpint *diff); /* diff = b1-b2 % m */ +void mpmodmul(mpint *b1, mpint *b2, mpint *m, mpint *prod); /* prod = b1*b2 % m */ + /* quotient = dividend/divisor, remainder = dividend % divisor */ void mpdiv(mpint *dividend, mpint *divisor, mpint *quotient, mpint *remainder); /* return neg, 0, pos as b1-b2 is neg, 0, pos */ int mpcmp(mpint *b1, mpint *b2); +/* res = s != 0 ? b1 : b2 */ +void mpsel(int s, mpint *b1, mpint *b2, mpint *res); + /* extended gcd return d, x, and y, s.t. d = gcd(a,b) and ax+by = d */ void mpextendedgcd(mpint *a, mpint *b, mpint *d, mpint *x, mpint *y); @@ -106,12 +119,14 @@ void mpvecdigmuladd(mpdigit *b, int n, mpdigit m, mpdigit *p); /* prereq: p has room for n+1 digits */ int mpvecdigmulsub(mpdigit *b, int n, mpdigit m, mpdigit *p); -/* p[0:alen*blen-1] = a[0:alen-1] * b[0:blen-1] */ +/* p[0:alen+blen-1] = a[0:alen-1] * b[0:blen-1] */ /* prereq: alen >= blen, p has room for m*n digits */ void mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p); +void mpvectsmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p); /* sign of a - b or zero if the same */ int mpveccmp(mpdigit *a, int alen, mpdigit *b, int blen); +int mpvectscmp(mpdigit *a, int alen, mpdigit *b, int blen); /* divide the 2 digit dividend by the one digit divisor and stick in quotient */ /* we assume that the result is one digit - overflow is all 1's */ diff --git a/sys/man/2/mp b/sys/man/2/mp index 5be4246c8..c562ccab4 100644 --- a/sys/man/2/mp +++ b/sys/man/2/mp @@ -1,6 +1,6 @@ .TH MP 2 .SH NAME -mpsetminbits, mpnew, mpfree, mpbits, mpnorm, mpcopy, mpassign, mprand, mpnrand, strtomp, mpfmt,mptoa, betomp, mptobe, letomp, mptole, mptoui, uitomp, mptoi, itomp, uvtomp, mptouv, vtomp, mptov, mpdigdiv, mpadd, mpsub, mpleft, mpright, mpmul, mpexp, mpmod, mpdiv, mpcmp, mpextendedgcd, mpinvert, mpsignif, mplowbits0, mpvecdigmuladd, mpvecdigmulsub, mpvecadd, mpvecsub, mpveccmp, mpvecmul, mpmagcmp, mpmagadd, mpmagsub, crtpre, crtin, crtout, crtprefree, crtresfree \- extended precision arithmetic +mpsetminbits, mpnew, mpfree, mpbits, mpnorm, mpcopy, mpassign, mprand, mpnrand, strtomp, mpfmt,mptoa, betomp, mptobe, mptober, letomp, mptole, mptolel, mptoui, uitomp, mptoi, itomp, uvtomp, mptouv, vtomp, mptov, mpdigdiv, mpadd, mpsub, mpleft, mpright, mpmul, mpexp, mpmod, mpmodadd, mpmodsub, mpmodmul, mpdiv, mpcmp, mpsel, mpextendedgcd, mpinvert, mpsignif, mplowbits0, mpvecdigmuladd, mpvecdigmulsub, mpvecadd, mpvecsub, mpveccmp, mpvecmul, mpmagcmp, mpmagadd, mpmagsub, crtpre, crtin, crtout, crtprefree, crtresfree \- extended precision arithmetic .SH SYNOPSIS .B #include .br @@ -22,7 +22,7 @@ void mpsetminbits(int n) void mpbits(mpint *b, int n) .PP .B -void mpnorm(mpint *b) +mpint* mpnorm(mpint *b) .PP .B mpint* mpcopy(mpint *b) @@ -52,12 +52,18 @@ mpint* betomp(uchar *buf, uint blen, mpint *b) int mptobe(mpint *b, uchar *buf, uint blen, uchar **bufp) .PP .B +void mptober(mpint *b, uchar *buf, int blen) +.PP +.B mpint* letomp(uchar *buf, uint blen, mpint *b) .PP .B int mptole(mpint *b, uchar *buf, uint blen, uchar **bufp) .PP .B +void mptolel(mpint *b, uchar *buf, int blen) +.PP +.B uint mptoui(mpint*) .PP .B @@ -115,12 +121,24 @@ void mpdiv(mpint *dividend, mpint *divisor, mpint *quotient, mpint *remainder) .PP .B +void mpmodadd(mpint *b1, mpint *b2, mpint *m, mpint *sum) +.PP +.B +void mpmodsub(mpint *b1, mpint *b2, mpint *m, mpint *diff) +.PP +.B +void mpmodmul(mpint *b1, mpint *b2, mpint *m, mpint *prod) +.PP +.B int mpcmp(mpint *b1, mpint *b2) .PP .B int mpmagcmp(mpint *b1, mpint *b2) .PP .B +void mpsel(int s, mpint *b1, mpint *b2, mpint *res) +.PP +.B void mpextendedgcd(mpint *a, mpint *b, mpint *d, mpint *x, .br .B @@ -383,6 +401,24 @@ deposited in the location pointed to by Sign is ignored in these conversions, i.e., the byte array version is always positive. .PP +.I Mptober +and +.I mptolel +fill +.I blen +lower bytes of an +.I mpint +into a fixed length byte array. +.I Mptober +fills the bytes right adjusted in big endian order so that the least +significant byte is at +.I buf[blen-1] +while +.I mptolel +fills in little endian order; left adjusted; so that the least +significat byte is filled into +.IR buf[0] . +.PP .IR Betomp , and .I letomp @@ -486,6 +522,31 @@ is less than, equal to, or greater than the same as .I mpcmp but ignores the sign and just compares magnitudes. +.TP +.I mpsel +assigns +.I b1 +to +.I res +when +.I s +is not zero, otherwise +.I b2 +is assigned to +.IR res . +.PD +.PP +Modular arithmetic: +.TF mpmodmul_ +.TP +.I mpmodadd +.BR "sum = b1+b2 mod m" . +.TP +.I mpmodsub +.BR "diff = b1-b2 mod m" . +.TP +.I mpmodmul +.BR "prod = b1*b2 mod m" . .PD .PP .I Mpextendedgcd @@ -564,8 +625,8 @@ We assume p has room for n+1 digits. It returns +1 is the result is positive an -1 if negative. .TP .I mpvecmul -.BR "p[0:alen*blen] = a[0:alen-1] * b[0:blen-1]" . -We assume that p has room for alen*blen+1 digits. +.BR "p[0:alen+blen] = a[0:alen-1] * b[0:blen-1]" . +We assume that p has room for alen+blen+1 digits. .TP .I mpveccmp This returns -1, 0, or +1 as a - b is negative, 0, or positive. @@ -576,6 +637,17 @@ This returns -1, 0, or +1 as a - b is negative, 0, or positive. and .I mpzero are the constants 2, 1 and 0. These cannot be freed. +.SS "Time invariant computation" +.PP +In the field of cryptography, it is sometimes neccesary to implement +algorithms such that the runtime of the algorithm is not depdenent on +the input data. This library provides partial support for time +invariant computation with the +.I MPtimesafe +flag that can be set on input or destination operands to request timing +safe operation. The result of a timing safe operation will also have the +.I MPtimesafe +flag set and is not normalized. .SS "Chinese remainder theorem .PP When computing in a non-prime modulus, diff --git a/sys/src/libmp/port/betomp.c b/sys/src/libmp/port/betomp.c index 9197f3a14..0830704ef 100644 --- a/sys/src/libmp/port/betomp.c +++ b/sys/src/libmp/port/betomp.c @@ -13,19 +13,12 @@ betomp(uchar *p, uint n, mpint *b) b = mpnew(0); setmalloctag(b, getcallerpc(&p)); } - - // dump leading zeros - while(*p == 0 && n > 1){ - p++; - n--; - } - - // get the space mpbits(b, n*8); - b->top = DIGITS(n*8); - m = b->top-1; - // first digit might not be Dbytes long + m = DIGITS(n*8); + b->top = m--; + b->sign = 1; + s = ((n-1)*8)%Dbits; x = 0; for(; n > 0; n--){ @@ -37,6 +30,5 @@ betomp(uchar *p, uint n, mpint *b) x = 0; } } - - return b; + return mpnorm(b); } diff --git a/sys/src/libmp/port/letomp.c b/sys/src/libmp/port/letomp.c index e23fed21e..d5cca241b 100644 --- a/sys/src/libmp/port/letomp.c +++ b/sys/src/libmp/port/letomp.c @@ -9,8 +9,10 @@ letomp(uchar *s, uint n, mpint *b) int i=0, m = 0; mpdigit x=0; - if(b == nil) + if(b == nil){ b = mpnew(0); + setmalloctag(b, getcallerpc(&s)); + } mpbits(b, 8*n); for(; n > 0; n--){ x |= ((mpdigit)(*s++)) << i; @@ -24,5 +26,6 @@ letomp(uchar *s, uint n, mpint *b) if(i > 0) b->p[m++] = x; b->top = m; - return b; + b->sign = 1; + return mpnorm(b); } diff --git a/sys/src/libmp/port/mkfile b/sys/src/libmp/port/mkfile index 76fa25dd7..b0bdf67d5 100644 --- a/sys/src/libmp/port/mkfile +++ b/sys/src/libmp/port/mkfile @@ -6,12 +6,15 @@ FILES=\ mpfmt\ strtomp\ mptobe\ + mptober\ mptole\ + mptolel\ betomp\ letomp\ mpadd\ mpsub\ mpcmp\ + mpsel\ mpfactorial\ mpmul\ mpleft\ @@ -20,10 +23,12 @@ FILES=\ mpvecsub\ mpvecdigmuladd\ mpveccmp\ + mpvectscmp\ mpdigdiv\ mpdiv\ mpexp\ mpmod\ + mpmodop\ mpextendedgcd\ mpinvert\ mprand\ diff --git a/sys/src/libmp/port/mpadd.c b/sys/src/libmp/port/mpadd.c index 6022a64ef..9a1ccde66 100644 --- a/sys/src/libmp/port/mpadd.c +++ b/sys/src/libmp/port/mpadd.c @@ -9,6 +9,8 @@ mpmagadd(mpint *b1, mpint *b2, mpint *sum) int m, n; mpint *t; + sum->flags |= (b1->flags | b2->flags) & MPtimesafe; + // get the sizes right if(b2->top > b1->top){ t = b1; @@ -41,6 +43,7 @@ mpadd(mpint *b1, mpint *b2, mpint *sum) int sign; if(b1->sign != b2->sign){ + assert(((b1->flags | b2->flags | sum->flags) & MPtimesafe) == 0); if(b1->sign < 0) mpmagsub(b2, b1, sum); else diff --git a/sys/src/libmp/port/mpaux.c b/sys/src/libmp/port/mpaux.c index 66f1524f0..eb70a9364 100644 --- a/sys/src/libmp/port/mpaux.c +++ b/sys/src/libmp/port/mpaux.c @@ -5,33 +5,27 @@ static mpdigit _mptwodata[1] = { 2 }; static mpint _mptwo = { - 1, - 1, - 1, + 1, 1, 1, _mptwodata, - MPstatic + MPstatic|MPnorm }; mpint *mptwo = &_mptwo; static mpdigit _mponedata[1] = { 1 }; static mpint _mpone = { - 1, - 1, - 1, + 1, 1, 1, _mponedata, - MPstatic + MPstatic|MPnorm }; mpint *mpone = &_mpone; static mpdigit _mpzerodata[1] = { 0 }; static mpint _mpzero = { - 1, - 1, - 0, + 1, 1, 0, _mpzerodata, - MPstatic + MPstatic|MPnorm }; mpint *mpzero = &_mpzero; @@ -57,18 +51,17 @@ mpnew(int n) if(n < 0) sysfatal("mpsetminbits: n < 0"); - b = mallocz(sizeof(mpint), 1); - setmalloctag(b, getcallerpc(&n)); - if(b == nil) - sysfatal("mpnew: %r"); n = DIGITS(n); if(n < mpmindigits) n = mpmindigits; - b->p = (mpdigit*)mallocz(n*Dbytes, 1); - if(b->p == nil) + b = mallocz(sizeof(mpint) + n*Dbytes, 1); + if(b == nil) sysfatal("mpnew: %r"); + setmalloctag(b, getcallerpc(&n)); + b->p = (mpdigit*)&b[1]; b->size = n; b->sign = 1; + b->flags = MPnorm; return b; } @@ -83,16 +76,23 @@ mpbits(mpint *b, int m) if(b->size >= n){ if(b->top >= n) return; - memset(&b->p[b->top], 0, Dbytes*(n - b->top)); - b->top = n; - return; + } else { + if(b->p == (mpdigit*)&b[1]){ + b->p = (mpdigit*)mallocz(n*Dbytes, 0); + if(b->p == nil) + sysfatal("mpbits: %r"); + memmove(b->p, &b[1], Dbytes*b->top); + memset(&b[1], 0, Dbytes*b->size); + } else { + b->p = (mpdigit*)realloc(b->p, n*Dbytes); + if(b->p == nil) + sysfatal("mpbits: %r"); + } + b->size = n; } - b->p = (mpdigit*)realloc(b->p, n*Dbytes); - if(b->p == nil) - sysfatal("mpbits: %r"); memset(&b->p[b->top], 0, Dbytes*(n - b->top)); - b->size = n; b->top = n; + b->flags &= ~MPnorm; } void @@ -102,22 +102,30 @@ mpfree(mpint *b) return; if(b->flags & MPstatic) sysfatal("freeing mp constant"); - memset(b->p, 0, b->size*Dbytes); // information hiding - free(b->p); + memset(b->p, 0, b->size*Dbytes); + if(b->p != (mpdigit*)&b[1]) + free(b->p); free(b); } -void +mpint* mpnorm(mpint *b) { int i; + if(b->flags & MPtimesafe){ + assert(b->sign == 1); + b->flags &= ~MPnorm; + return b; + } for(i = b->top-1; i >= 0; i--) if(b->p[i] != 0) break; b->top = i+1; if(b->top == 0) b->sign = 1; + b->flags |= MPnorm; + return b; } mpint* @@ -126,8 +134,10 @@ mpcopy(mpint *old) mpint *new; new = mpnew(Dbits*old->size); - new->top = old->top; + setmalloctag(new, getcallerpc(&old)); new->sign = old->sign; + new->top = old->top; + new->flags = old->flags & ~MPstatic; memmove(new->p, old->p, Dbytes*old->top); return new; } @@ -135,9 +145,14 @@ mpcopy(mpint *old) void mpassign(mpint *old, mpint *new) { + if(new == nil || old == new) + return; + new->top = 0; mpbits(new, Dbits*old->top); new->sign = old->sign; new->top = old->top; + new->flags &= ~MPnorm; + new->flags |= old->flags & ~MPstatic; memmove(new->p, old->p, Dbytes*old->top); } @@ -167,6 +182,7 @@ mplowbits0(mpint *n) int k, bit, digit; mpdigit d; + assert(n->flags & MPnorm); if(n->top==0) return 0; k = 0; @@ -187,4 +203,3 @@ mplowbits0(mpint *n) } return k; } - diff --git a/sys/src/libmp/port/mpcmp.c b/sys/src/libmp/port/mpcmp.c index a2e3cf724..7ab5a16b6 100644 --- a/sys/src/libmp/port/mpcmp.c +++ b/sys/src/libmp/port/mpcmp.c @@ -8,10 +8,14 @@ mpmagcmp(mpint *b1, mpint *b2) { int i; - i = b1->top - b2->top; - if(i) - return i; - + i = b1->flags | b2->flags; + if(i & MPtimesafe) + return mpvectscmp(b1->p, b1->top, b2->p, b2->top); + if(i & MPnorm){ + i = b1->top - b2->top; + if(i) + return i; + } return mpveccmp(b1->p, b1->top, b2->p, b2->top); } @@ -19,10 +23,8 @@ mpmagcmp(mpint *b1, mpint *b2) int mpcmp(mpint *b1, mpint *b2) { - if(b1->sign != b2->sign) - return b1->sign - b2->sign; - if(b1->sign < 0) - return mpmagcmp(b2, b1); - else - return mpmagcmp(b1, b2); + int sign; + + sign = (b1->sign - b2->sign) >> 1; // -1, 0, 1 + return sign | (sign&1)-1 & mpmagcmp(b1, b2)*b1->sign; } diff --git a/sys/src/libmp/port/mpdiv.c b/sys/src/libmp/port/mpdiv.c index 92aee03f4..54b943862 100644 --- a/sys/src/libmp/port/mpdiv.c +++ b/sys/src/libmp/port/mpdiv.c @@ -13,10 +13,29 @@ mpdiv(mpint *dividend, mpint *divisor, mpint *quotient, mpint *remainder) mpdigit qd, *up, *vp, *qp; mpint *u, *v, *t; + assert(quotient != remainder); + assert(divisor->flags & MPnorm); + // divide bv zero if(divisor->top == 0) abort(); + // division by one or small powers of two + if(divisor->top == 1 && (divisor->p[0] & divisor->p[0]-1) == 0){ + vlong r = (vlong)dividend->sign * (dividend->p[0] & divisor->p[0]-1); + if(quotient != nil){ + for(s = 0; ((divisor->p[0] >> s) & 1) == 0; s++) + ; + mpright(dividend, s, quotient); + } + if(remainder != nil){ + remainder->flags |= dividend->flags & MPtimesafe; + vtomp(r, remainder); + } + return; + } + assert((dividend->flags & MPtimesafe) == 0); + // quick check if(mpmagcmp(dividend, divisor) < 0){ if(remainder != nil) @@ -95,12 +114,14 @@ mpdiv(mpint *dividend, mpint *divisor, mpint *quotient, mpint *remainder) *up-- = 0; } if(qp != nil){ + assert((quotient->flags & MPtimesafe) == 0); mpnorm(quotient); if(dividend->sign != divisor->sign) quotient->sign = -1; } if(remainder != nil){ + assert((remainder->flags & MPtimesafe) == 0); mpright(u, s, remainder); // u is the remainder shifted remainder->sign = dividend->sign; } diff --git a/sys/src/libmp/port/mpeuclid.c b/sys/src/libmp/port/mpeuclid.c index 80b5983bf..586b9cc22 100644 --- a/sys/src/libmp/port/mpeuclid.c +++ b/sys/src/libmp/port/mpeuclid.c @@ -13,6 +13,9 @@ mpeuclid(mpint *a, mpint *b, mpint *d, mpint *x, mpint *y) { mpint *tmp, *x0, *x1, *x2, *y0, *y1, *y2, *q, *r; + assert((a->flags&b->flags) & MPnorm); + assert(((a->flags|b->flags|d->flags|x->flags|y->flags) & MPtimesafe) == 0); + if(a->sign<0 || b->sign<0) sysfatal("mpeuclid: negative arg"); diff --git a/sys/src/libmp/port/mpexp.c b/sys/src/libmp/port/mpexp.c index 9ec067cb9..1ebabba93 100644 --- a/sys/src/libmp/port/mpexp.c +++ b/sys/src/libmp/port/mpexp.c @@ -22,6 +22,10 @@ mpexp(mpint *b, mpint *e, mpint *m, mpint *res) mpdigit d, bit; int i, j; + assert(m->flags & MPnorm); + assert((e->flags & MPtimesafe) == 0); + res->flags |= b->flags & MPtimesafe; + i = mpcmp(e,mpzero); if(i==0){ mpassign(mpone, res); diff --git a/sys/src/libmp/port/mpextendedgcd.c b/sys/src/libmp/port/mpextendedgcd.c index 413a05c2a..72e49bce1 100644 --- a/sys/src/libmp/port/mpextendedgcd.c +++ b/sys/src/libmp/port/mpextendedgcd.c @@ -5,7 +5,7 @@ // extended binary gcd // -// For a anv b it solves, v = gcd(a,b) and finds x and y s.t. +// For a and b it solves, v = gcd(a,b) and finds x and y s.t. // ax + by = v // // Handbook of Applied Cryptography, Menezes et al, 1997, pg 608. @@ -15,6 +15,9 @@ mpextendedgcd(mpint *a, mpint *b, mpint *v, mpint *x, mpint *y) mpint *u, *A, *B, *C, *D; int g; + assert((a->flags&b->flags) & MPnorm); + assert(((a->flags|b->flags|v->flags|x->flags|y->flags) & MPtimesafe) == 0); + if(a->sign < 0 || b->sign < 0){ mpassign(mpzero, v); mpassign(mpzero, y); diff --git a/sys/src/libmp/port/mpfmt.c b/sys/src/libmp/port/mpfmt.c index f7c42a7bc..676b64be0 100644 --- a/sys/src/libmp/port/mpfmt.c +++ b/sys/src/libmp/port/mpfmt.c @@ -102,6 +102,7 @@ to10(mpint *b, char *buf, int len) return -1; d = mpcopy(b); + mpnorm(d); r = mpnew(0); billion = uitomp(1000000000, nil); out = buf+len; @@ -128,15 +129,20 @@ int mpfmt(Fmt *fmt) { mpint *b; - char *p; + char *p, f; b = va_arg(fmt->args, mpint*); if(b == nil) return fmtstrcpy(fmt, "*"); - + + f = b->flags; + b->flags &= ~MPtimesafe; + p = mptoa(b, fmt->prec, nil, 0); fmt->flags &= ~FmtPrec; + b->flags = f; + if(p == nil) return fmtstrcpy(fmt, "*"); else{ diff --git a/sys/src/libmp/port/mpleft.c b/sys/src/libmp/port/mpleft.c index cdcdff740..38929b82e 100644 --- a/sys/src/libmp/port/mpleft.c +++ b/sys/src/libmp/port/mpleft.c @@ -15,8 +15,8 @@ mpleft(mpint *b, int shift, mpint *res) return; } - // a negative left shift is a right shift - if(shift < 0){ + // a zero or negative left shift is a right shift + if(shift <= 0){ mpright(b, -shift, res); return; } @@ -46,7 +46,6 @@ mpleft(mpint *b, int shift, mpint *res) for(i = 0; i < d; i++) res->p[i] = 0; - // normalize - while(res->top > 0 && res->p[res->top-1] == 0) - res->top--; + res->flags |= b->flags & MPtimesafe; + mpnorm(res); } diff --git a/sys/src/libmp/port/mpmod.c b/sys/src/libmp/port/mpmod.c index 91bebfa27..c053f5b7f 100644 --- a/sys/src/libmp/port/mpmod.c +++ b/sys/src/libmp/port/mpmod.c @@ -2,14 +2,100 @@ #include #include "dat.h" -// remainder = b mod m -// -// knuth, vol 2, pp 398-400 - void -mpmod(mpint *b, mpint *m, mpint *remainder) +mpmod(mpint *x, mpint *n, mpint *r) { - mpdiv(b, m, nil, remainder); - if(remainder->sign < 0) - mpadd(m, remainder, remainder); + static int busy; + static mpint *p, *m, *c, *v; + mpdigit q[32], t[64], d; + int sign, k, s, qn, tn; + + sign = x->sign; + + assert(n->flags & MPnorm); + if(n->top < 2 || n->top > nelem(q) || (x->top-n->top) > nelem(q)) + goto hard; + + /* + * check if n = 2**k - c where c has few power of two factors + * above the lowest digit. + */ + for(k = n->top-1; k > 0; k--){ + d = n->p[k] >> 1; + if((d+1 & d) != 0) + goto hard; + } + + d = n->p[n->top-1]; + for(s = 0; (d & (mpdigit)1<top; + + while(_tas(&busy)) + ; + + if(p == nil || mpmagcmp(n, p) != 0){ + if(m == nil){ + m = mpnew(0); + c = mpnew(0); + p = mpnew(0); + } + mpassign(n, p); + + mpleft(n, s, m); + mpleft(mpone, k*Dbits, c); + mpsub(c, m, c); + } + + mpleft(x, s, r); + if(r->top <= k){ + mpbits(r, (k+1)*Dbits); + r->top = k+1; + } + + /* q = hi(r) */ + qn = r->top - k; + memmove(q, r->p+k, qn*Dbytes); + + /* r = lo(r) */ + r->top = k; + + do { + /* t = q*c */ + tn = qn + c->top; + memset(t, 0, tn*Dbytes); + mpvecmul(q, qn, c->p, c->top, t); + + /* q = hi(t) */ + qn = tn - k; + if(qn <= 0) qn = 0; + else memmove(q, t+k, qn*Dbytes); + + /* r += lo(t) */ + if(tn > k) + tn = k; + mpvecadd(r->p, k, t, tn, r->p); + + /* if(r >= m) r -= m */ + mpvecsub(r->p, k+1, m->p, k, t), d = t[k]; + for(tn = 0; tn < k; tn++) + r->p[tn] = (r->p[tn] & d) | (t[tn] & ~d); + } while(qn > 0); + + busy = 0; + + if(s != 0) + mpright(r, s, r); + else + mpnorm(r); + goto done; + +hard: + mpdiv(x, n, nil, r); + +done: + if(sign < 0) + mpmagsub(n, r, r); } diff --git a/sys/src/libmp/port/mpmodop.c b/sys/src/libmp/port/mpmodop.c new file mode 100644 index 000000000..8bc7cbb5a --- /dev/null +++ b/sys/src/libmp/port/mpmodop.c @@ -0,0 +1,96 @@ +#include +#include +#include + +/* operands need to have m->top+1 digits of space and satisfy 0 ≤ a ≤ m-1 */ +static mpint* +modarg(mpint *a, mpint *m) +{ + if(a->size <= m->top || a->sign < 0 || mpmagcmp(a, m) >= 0){ + a = mpcopy(a); + mpmod(a, m, a); + mpbits(a, Dbits*(m->top+1)); + a->top = m->top; + } else if(a->top < m->top){ + memset(&a->p[a->top], 0, (m->top - a->top)*Dbytes); + } + return a; +} + +void +mpmodadd(mpint *b1, mpint *b2, mpint *m, mpint *sum) +{ + mpint *a, *b; + mpdigit d; + int i, j; + + a = modarg(b1, m); + b = modarg(b2, m); + + sum->flags |= (a->flags | b->flags) & MPtimesafe; + mpbits(sum, Dbits*2*(m->top+1)); + + mpvecadd(a->p, m->top, b->p, m->top, sum->p); + mpvecsub(sum->p, m->top+1, m->p, m->top, sum->p+m->top+1); + + d = sum->p[2*m->top+1]; + for(i = 0, j = m->top+1; i < m->top; i++, j++) + sum->p[i] = (sum->p[i] & d) | (sum->p[j] & ~d); + + sum->top = m->top; + sum->sign = 1; + mpnorm(sum); + + if(a != b1) + mpfree(a); + if(b != b2) + mpfree(b); +} + +void +mpmodsub(mpint *b1, mpint *b2, mpint *m, mpint *diff) +{ + mpint *a, *b; + mpdigit d; + int i, j; + + a = modarg(b1, m); + b = modarg(b2, m); + + diff->flags |= (a->flags | b->flags) & MPtimesafe; + mpbits(diff, Dbits*2*(m->top+1)); + + a->p[m->top] = 0; + mpvecsub(a->p, m->top+1, b->p, m->top, diff->p); + mpvecadd(diff->p, m->top, m->p, m->top, diff->p+m->top+1); + + d = ~diff->p[m->top]; + for(i = 0, j = m->top+1; i < m->top; i++, j++) + diff->p[i] = (diff->p[i] & d) | (diff->p[j] & ~d); + + diff->top = m->top; + diff->sign = 1; + mpnorm(diff); + + if(a != b1) + mpfree(a); + if(b != b2) + mpfree(b); +} + +void +mpmodmul(mpint *b1, mpint *b2, mpint *m, mpint *prod) +{ + mpint *a, *b; + + a = modarg(b1, m); + b = modarg(b2, m); + + mpmul(a, b, prod); + mpmod(prod, m, prod); + + if(a != b1) + mpfree(a); + if(b != b2) + mpfree(b); +} diff --git a/sys/src/libmp/port/mpmul.c b/sys/src/libmp/port/mpmul.c index dedd474a7..777adf307 100644 --- a/sys/src/libmp/port/mpmul.c +++ b/sys/src/libmp/port/mpmul.c @@ -113,10 +113,6 @@ mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p) a = b; b = t; } - if(blen == 0){ - memset(p, 0, Dbytes*(alen+blen)); - return; - } if(alen >= KARATSUBAMIN && blen > 1){ // O(n^1.585) @@ -131,25 +127,49 @@ mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p) } } +void +mpvectsmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p) +{ + int i; + mpdigit *t; + + if(alen < blen){ + i = alen; + alen = blen; + blen = i; + t = a; + a = b; + b = t; + } + if(blen == 0) + return; + for(i = 0; i < blen; i++) + mpvecdigmuladd(a, alen, b[i], &p[i]); +} + void mpmul(mpint *b1, mpint *b2, mpint *prod) { mpint *oprod; - oprod = nil; + oprod = prod; if(prod == b1 || prod == b2){ - oprod = prod; prod = mpnew(0); + prod->flags = oprod->flags; } + prod->flags |= (b1->flags | b2->flags) & MPtimesafe; prod->top = 0; mpbits(prod, (b1->top+b2->top+1)*Dbits); - mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p); + if(prod->flags & MPtimesafe) + mpvectsmul(b1->p, b1->top, b2->p, b2->top, prod->p); + else + mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p); prod->top = b1->top+b2->top+1; prod->sign = b1->sign*b2->sign; mpnorm(prod); - if(oprod != nil){ + if(oprod != prod){ mpassign(prod, oprod); mpfree(prod); } diff --git a/sys/src/libmp/port/mpnrand.c b/sys/src/libmp/port/mpnrand.c index 600283d9d..ebbed5097 100644 --- a/sys/src/libmp/port/mpnrand.c +++ b/sys/src/libmp/port/mpnrand.c @@ -16,8 +16,10 @@ mpnrand(mpint *n, void (*gen)(uchar*, int), mpint *b) mpleft(mpone, bits, m); mpsub(m, mpone, m); - if(b == nil) + if(b == nil){ b = mpnew(bits); + setmalloctag(b, getcallerpc(&n)); + } /* m = m - (m % n) */ mpmod(m, n, b); diff --git a/sys/src/libmp/port/mprand.c b/sys/src/libmp/port/mprand.c index fd288f24e..29433b669 100644 --- a/sys/src/libmp/port/mprand.c +++ b/sys/src/libmp/port/mprand.c @@ -6,37 +6,32 @@ mpint* mprand(int bits, void (*gen)(uchar*, int), mpint *b) { - int n, m; mpdigit mask; + int n, m; uchar *p; n = DIGITS(bits); - if(b == nil) + if(b == nil){ b = mpnew(bits); - else + setmalloctag(b, getcallerpc(&bits)); + }else mpbits(b, bits); p = malloc(n*Dbytes); if(p == nil) - return nil; + sysfatal("mprand: %r"); (*gen)(p, n*Dbytes); betomp(p, n*Dbytes, b); free(p); // make sure we don't give too many bits m = bits%Dbits; - n--; - if(m > 0){ - mask = 1; - mask <<= m; - mask--; - b->p[n] &= mask; - } + if(m == 0) + return b; - for(; n >= 0; n--) - if(b->p[n] != 0) - break; - b->top = n+1; - b->sign = 1; - return b; + mask = 1; + mask <<= m; + mask--; + b->p[n-1] &= mask; + return mpnorm(b); } diff --git a/sys/src/libmp/port/mpright.c b/sys/src/libmp/port/mpright.c index 03039177b..dde7aeace 100644 --- a/sys/src/libmp/port/mpright.c +++ b/sys/src/libmp/port/mpright.c @@ -23,12 +23,16 @@ mpright(mpint *b, int shift, mpint *res) if(res != b) mpbits(res, b->top*Dbits - shift); + else if(shift == 0) + return; + d = shift/Dbits; r = shift - d*Dbits; l = Dbits - r; // shift all the bits out == zero if(d>=b->top){ + res->sign = 1; res->top = 0; return; } @@ -46,9 +50,8 @@ mpright(mpint *b, int shift, mpint *res) } res->p[i++] = last>>r; } - while(i > 0 && res->p[i-1] == 0) - i--; + res->top = i; - if(i==0) - res->sign = 1; + res->flags |= b->flags & MPtimesafe; + mpnorm(res); } diff --git a/sys/src/libmp/port/mpsel.c b/sys/src/libmp/port/mpsel.c new file mode 100644 index 000000000..a145b9d06 --- /dev/null +++ b/sys/src/libmp/port/mpsel.c @@ -0,0 +1,42 @@ +#include "os.h" +#include +#include "dat.h" + +// res = s != 0 ? b1 : b2 +void +mpsel(int s, mpint *b1, mpint *b2, mpint *res) +{ + mpdigit d; + int n, m, i; + + res->flags |= (b1->flags | b2->flags) & MPtimesafe; + if((res->flags & MPtimesafe) == 0){ + mpassign(s ? b1 : b2, res); + return; + } + res->flags &= ~MPnorm; + + n = b1->top; + m = b2->top; + mpbits(res, Dbits*(n >= m ? n : m)); + res->top = n >= m ? n : m; + + s = (-s^s|s)>>(sizeof(s)*8-1); + res->sign = (b1->sign & s) | (b2->sign & ~s); + + d = -((mpdigit)s & 1); + + i = 0; + while(i < n && i < m){ + res->p[i] = (b1->p[i] & d) | (b2->p[i] & ~d); + i++; + } + while(i < n){ + res->p[i] = b1->p[i] & d; + i++; + } + while(i < m){ + res->p[i] = b2->p[i] & ~d; + i++; + } +} diff --git a/sys/src/libmp/port/mpsub.c b/sys/src/libmp/port/mpsub.c index 3fe6ca095..292648f23 100644 --- a/sys/src/libmp/port/mpsub.c +++ b/sys/src/libmp/port/mpsub.c @@ -11,12 +11,15 @@ mpmagsub(mpint *b1, mpint *b2, mpint *diff) // get the sizes right if(mpmagcmp(b1, b2) < 0){ + assert(((b1->flags | b2->flags | diff->flags) & MPtimesafe) == 0); sign = -1; t = b1; b1 = b2; b2 = t; - } else + } else { + diff->flags |= (b1->flags | b2->flags) & MPtimesafe; sign = 1; + } n = b1->top; m = b2->top; if(m == 0){ @@ -39,6 +42,7 @@ mpsub(mpint *b1, mpint *b2, mpint *diff) int sign; if(b1->sign != b2->sign){ + assert(((b1->flags | b2->flags | diff->flags) & MPtimesafe) == 0); sign = b1->sign; mpmagadd(b1, b2, diff); diff->sign = sign; diff --git a/sys/src/libmp/port/mptobe.c b/sys/src/libmp/port/mptobe.c index ed527cc76..9ddea35ed 100644 --- a/sys/src/libmp/port/mptobe.c +++ b/sys/src/libmp/port/mptobe.c @@ -2,57 +2,31 @@ #include #include "dat.h" -// convert an mpint into a big endian byte array (most significant byte first) +// convert an mpint into a big endian byte array (most significant byte first; left adjusted) // return number of bytes converted // if p == nil, allocate and result array int mptobe(mpint *b, uchar *p, uint n, uchar **pp) { - int i, j, suppress; - mpdigit x; - uchar *e, *s, c; + int m; + m = (mpsignif(b)+7)/8; + if(m == 0) + m++; if(p == nil){ - n = (b->top+1)*Dbytes; + n = m; p = malloc(n); + if(p == nil) + sysfatal("mptobe: %r"); setmalloctag(p, getcallerpc(&b)); + } else { + if(n < m) + return -1; + if(n > m) + memset(p+m, 0, n-m); } - if(p == nil) - return -1; if(pp != nil) *pp = p; - memset(p, 0, n); - - // special case 0 - if(b->top == 0){ - if(n < 1) - return -1; - else - return 1; - } - - s = p; - e = s+n; - suppress = 1; - for(i = b->top-1; i >= 0; i--){ - x = b->p[i]; - for(j = Dbits-8; j >= 0; j -= 8){ - c = x>>j; - if(c == 0 && suppress) - continue; - if(p >= e) - return -1; - *p++ = c; - suppress = 0; - } - } - - // guarantee at least one byte - if(s == p){ - if(p >= e) - return -1; - *p++ = 0; - } - - return p - s; + mptober(b, p, m); + return m; } diff --git a/sys/src/libmp/port/mptober.c b/sys/src/libmp/port/mptober.c new file mode 100644 index 000000000..ce63d338d --- /dev/null +++ b/sys/src/libmp/port/mptober.c @@ -0,0 +1,34 @@ +#include "os.h" +#include +#include "dat.h" + +void +mptober(mpint *b, uchar *p, int n) +{ + int i, j, m; + mpdigit x; + + memset(p, 0, n); + + p += n; + m = b->top*Dbytes; + if(m < n) + n = m; + + i = 0; + while(n >= Dbytes){ + n -= Dbytes; + x = b->p[i++]; + for(j = 0; j < Dbytes; j++){ + *--p = x; + x >>= 8; + } + } + if(n > 0){ + x = b->p[i]; + for(j = 0; j < n; j++){ + *--p = x; + x >>= 8; + } + } +} diff --git a/sys/src/libmp/port/mptoi.c b/sys/src/libmp/port/mptoi.c index b3f22b424..6183fa7e5 100644 --- a/sys/src/libmp/port/mptoi.c +++ b/sys/src/libmp/port/mptoi.c @@ -10,17 +10,15 @@ mpint* itomp(int i, mpint *b) { - if(b == nil) + if(b == nil){ b = mpnew(0); - mpassign(mpzero, b); - if(i != 0) - b->top = 1; - if(i < 0){ - b->sign = -1; - *b->p = -i; - } else - *b->p = i; - return b; + setmalloctag(b, getcallerpc(&i)); + } + b->sign = (i >> (sizeof(i)*8 - 1)) | 1; + i *= b->sign; + *b->p = i; + b->top = 1; + return mpnorm(b); } int diff --git a/sys/src/libmp/port/mptole.c b/sys/src/libmp/port/mptole.c index 9421d5f66..3dd892401 100644 --- a/sys/src/libmp/port/mptole.c +++ b/sys/src/libmp/port/mptole.c @@ -3,52 +3,26 @@ #include "dat.h" // convert an mpint into a little endian byte array (least significant byte first) - // return number of bytes converted // if p == nil, allocate and result array int mptole(mpint *b, uchar *p, uint n, uchar **pp) { - int i, j; - mpdigit x; - uchar *e, *s; + int m; + m = (mpsignif(b)+7)/8; + if(m == 0) + m++; if(p == nil){ - n = (b->top+1)*Dbytes; + n = m; p = malloc(n); - } + if(p == nil) + sysfatal("mptole: %r"); + setmalloctag(p, getcallerpc(&b)); + } else if(n < m) + return -1; if(pp != nil) *pp = p; - if(p == nil) - return -1; - memset(p, 0, n); - - // special case 0 - if(b->top == 0){ - if(n < 1) - return -1; - else - return 0; - } - - s = p; - e = s+n; - for(i = 0; i < b->top-1; i++){ - x = b->p[i]; - for(j = 0; j < Dbytes; j++){ - if(p >= e) - return -1; - *p++ = x; - x >>= 8; - } - } - x = b->p[i]; - while(x > 0){ - if(p >= e) - return -1; - *p++ = x; - x >>= 8; - } - - return p - s; + mptolel(b, p, n); + return m; } diff --git a/sys/src/libmp/port/mptolel.c b/sys/src/libmp/port/mptolel.c new file mode 100644 index 000000000..4ee41971f --- /dev/null +++ b/sys/src/libmp/port/mptolel.c @@ -0,0 +1,33 @@ +#include "os.h" +#include +#include "dat.h" + +void +mptolel(mpint *b, uchar *p, int n) +{ + int i, j, m; + mpdigit x; + + memset(p, 0, n); + + m = b->top*Dbytes; + if(m < n) + n = m; + + i = 0; + while(n >= Dbytes){ + n -= Dbytes; + x = b->p[i++]; + for(j = 0; j < Dbytes; j++){ + *p++ = x; + x >>= 8; + } + } + if(n > 0){ + x = b->p[i]; + for(j = 0; j < n; j++){ + *p++ = x; + x >>= 8; + } + } +} diff --git a/sys/src/libmp/port/mptoui.c b/sys/src/libmp/port/mptoui.c index 41c0b0b67..2a963de0c 100644 --- a/sys/src/libmp/port/mptoui.c +++ b/sys/src/libmp/port/mptoui.c @@ -10,13 +10,14 @@ mpint* uitomp(uint i, mpint *b) { - if(b == nil) + if(b == nil){ b = mpnew(0); - mpassign(mpzero, b); - if(i != 0) - b->top = 1; + setmalloctag(b, getcallerpc(&i)); + } *b->p = i; - return b; + b->top = 1; + b->sign = 1; + return mpnorm(b); } uint diff --git a/sys/src/libmp/port/mptouv.c b/sys/src/libmp/port/mptouv.c index b2a7632d1..9e52a357f 100644 --- a/sys/src/libmp/port/mptouv.c +++ b/sys/src/libmp/port/mptouv.c @@ -13,19 +13,18 @@ uvtomp(uvlong v, mpint *b) { int s; - if(b == nil) + if(b == nil){ b = mpnew(VLDIGITS*sizeof(mpdigit)); - else + setmalloctag(b, getcallerpc(&v)); + }else mpbits(b, VLDIGITS*sizeof(mpdigit)); - mpassign(mpzero, b); - if(v == 0) - return b; - for(s = 0; s < VLDIGITS && v != 0; s++){ + b->sign = 1; + for(s = 0; s < VLDIGITS; s++){ b->p[s] = v; v >>= sizeof(mpdigit)*8; } b->top = s; - return b; + return mpnorm(b); } uvlong @@ -37,7 +36,6 @@ mptouv(mpint *b) if(b->top == 0) return 0LL; - mpnorm(b); if(b->top > VLDIGITS) return MAXVLONG; diff --git a/sys/src/libmp/port/mptov.c b/sys/src/libmp/port/mptov.c index b09718ef0..b1b3e93f7 100644 --- a/sys/src/libmp/port/mptov.c +++ b/sys/src/libmp/port/mptov.c @@ -14,24 +14,19 @@ vtomp(vlong v, mpint *b) int s; uvlong uv; - if(b == nil) + if(b == nil){ b = mpnew(VLDIGITS*sizeof(mpdigit)); - else + setmalloctag(b, getcallerpc(&v)); + }else mpbits(b, VLDIGITS*sizeof(mpdigit)); - mpassign(mpzero, b); - if(v == 0) - return b; - if(v < 0){ - b->sign = -1; - uv = -v; - } else - uv = v; - for(s = 0; s < VLDIGITS && uv != 0; s++){ + b->sign = (v >> (sizeof(v)*8 - 1)) | 1; + uv = v * b->sign; + for(s = 0; s < VLDIGITS; s++){ b->p[s] = uv; uv >>= sizeof(mpdigit)*8; } b->top = s; - return b; + return mpnorm(b); } vlong @@ -43,7 +38,6 @@ mptov(mpint *b) if(b->top == 0) return 0LL; - mpnorm(b); if(b->top > VLDIGITS){ if(b->sign > 0) return (vlong)MAXVLONG; diff --git a/sys/src/libmp/port/mpvectscmp.c b/sys/src/libmp/port/mpvectscmp.c new file mode 100644 index 000000000..ccad79b16 --- /dev/null +++ b/sys/src/libmp/port/mpvectscmp.c @@ -0,0 +1,34 @@ +#include "os.h" +#include +#include "dat.h" + +int +mpvectscmp(mpdigit *a, int alen, mpdigit *b, int blen) +{ + mpdigit x, y, z, v; + int m, p; + + if(alen > blen){ + v = 0; + while(alen > blen) + v |= a[--alen]; + m = p = (-v^v|v)>>Dbits-1; + } else if(blen > alen){ + v = 0; + while(blen > alen) + v |= b[--blen]; + m = (-v^v|v)>>Dbits-1; + p = m^1; + } else + m = p = 0; + while(alen-- > 0){ + x = a[alen]; + y = b[alen]; + z = x - y; + x = ~x; + v = ((-z^z|z)>>Dbits-1) & ~m; + p = ((~(x&y|x&z|y&z)>>Dbits-1) & v) | (p & ~v); + m |= v; + } + return (p-m) | m; +} diff --git a/sys/src/libmp/port/strtomp.c b/sys/src/libmp/port/strtomp.c index 2ef8c2109..0a9959692 100644 --- a/sys/src/libmp/port/strtomp.c +++ b/sys/src/libmp/port/strtomp.c @@ -50,7 +50,6 @@ from16(char *a, mpint *b) int i; mpdigit x; - b->top = 0; for(p = a; *p; p++) if(tab.t16[*(uchar*)p] == INVAL) break; @@ -157,8 +156,10 @@ strtomp(char *a, char **pp, int base, mpint *b) int sign; char *e; - if(b == nil) + if(b == nil){ b = mpnew(0); + setmalloctag(b, getcallerpc(&a)); + } if(tab.inited == 0) init(); @@ -196,10 +197,9 @@ strtomp(char *a, char **pp, int base, mpint *b) if(e == a) return nil; - mpnorm(b); - b->sign = sign; if(pp != nil) *pp = e; - return b; + b->sign = sign; + return mpnorm(b); }