diff --git a/lib/std/math/big/int.zig b/lib/std/math/big/int.zig index 207d76af6..54ad2f55d 100644 --- a/lib/std/math/big/int.zig +++ b/lib/std/math/big/int.zig @@ -59,8 +59,8 @@ pub fn calcSetStringLimbCount(base: u8, string_len: usize) usize { } pub fn calcPowLimbsBufferLen(a_bit_count: usize, y: usize) usize { - // The 1 accounts for the multiplication carry - return 1 + (a_bit_count * y + (limb_bits - 1)) / limb_bits; + // The 2 accounts for the minimum space requirement for llmulacc + return 2 + (a_bit_count * y + (limb_bits - 1)) / limb_bits; } /// a + b * c + *carry, sets carry to the overflow bits @@ -2205,47 +2205,51 @@ fn llxor(r: []Limb, a: []const Limb, b: []const Limb) void { /// Knuth 4.6.3 fn llpow(r: []Limb, a: []const Limb, b: u32, tmp_limbs: []Limb) void { - mem.copy(Limb, r, a); - mem.set(Limb, r[a.len..], 0); + var tmp1: []Limb = undefined; + var tmp2: []Limb = undefined; // Multiplication requires no aliasing between the operand and the result // variable, use the output limbs and another temporary set to overcome this - // limit. - // Note that the order is important in the code below. - var list = [_][]Limb{ r, tmp_limbs }; - var index: usize = 0; - - // Scan the exponent as a binary number, from left to right, dropping the - // most significant bit set - var exp = @bitReverse(u32, b) >> (1 + @intCast(u5, @clz(u32, b))); - while (exp != 0) : (exp >>= 1) { - // Square - { - const cur_buf = list[index]; - const cur_buf_len = llnormalize(cur_buf); - const cur_buf_out = list[index ^ 1]; - - mem.set(Limb, cur_buf_out, 0); - llmulacc(null, cur_buf_out, cur_buf[0..cur_buf_len], cur_buf[0..cur_buf_len]); - - index ^= 1; - } - - if ((exp & 1) != 0) { - // Multiply - const cur_buf = list[index]; - const cur_buf_len = llnormalize(cur_buf); - const cur_buf_out = list[index ^ 1]; - - mem.set(Limb, cur_buf_out, 0); - llmulacc(null, cur_buf_out, cur_buf, a); - - index ^= 1; - } + // limitation. + // The initial assignment makes the result end in `r` so an extra memory + // copy is saved, each 1 flips the index twice so it's a no-op so count the + // 0. + const b_leading_zeros = @intCast(u5, @clz(u32, b)); + const exp_zeros = @popCount(u32, ~b) - b_leading_zeros; + if (exp_zeros & 1 != 0) { + tmp1 = tmp_limbs; + tmp2 = r; + } else { + tmp1 = r; + tmp2 = tmp_limbs; } - if (index != 0) { - mem.copy(Limb, r, tmp_limbs); + const a_norm = a[0..llnormalize(a)]; + + mem.copy(Limb, tmp1, a_norm); + mem.set(Limb, tmp1[a_norm.len..], 0); + + // Scan the exponent as a binary number, from left to right, dropping the + // most significant bit set. + const exp_bits = @intCast(u5, 31 - b_leading_zeros); + var exp = @bitReverse(u32, b) >> 1 + b_leading_zeros; + + var i: u5 = 0; + while (i < exp_bits) : (i += 1) { + // Square + { + mem.set(Limb, tmp2, 0); + const op = tmp1[0..llnormalize(tmp1)]; + llmulacc(null, tmp2, op, op); + mem.swap([]Limb, &tmp1, &tmp2); + } + // Multiply by a + if (exp & 1 != 0) { + mem.set(Limb, tmp2, 0); + llmulacc(null, tmp2, tmp1[0..llnormalize(tmp1)], a_norm); + mem.swap([]Limb, &tmp1, &tmp2); + } + exp >>= 1; } } diff --git a/lib/std/math/big/int_test.zig b/lib/std/math/big/int_test.zig index 85c0ff387..5d07bee9b 100644 --- a/lib/std/math/big/int_test.zig +++ b/lib/std/math/big/int_test.zig @@ -1482,6 +1482,13 @@ test "big.int const to managed" { } test "big.int pow" { + { + var a = try Managed.initSet(testing.allocator, 10); + defer a.deinit(); + + try a.pow(a, 8); + testing.expectEqual(@as(u32, 100000000), try a.to(u32)); + } { var a = try Managed.initSet(testing.allocator, 10); defer a.deinit();