diff --git a/lib/std/compress/deflate.zig b/lib/std/compress/deflate.zig
index 9fe96cacb..2697fd5b8 100644
--- a/lib/std/compress/deflate.zig
+++ b/lib/std/compress/deflate.zig
@@ -27,6 +27,8 @@ const FIXLCODES = 288;
 const PREFIX_LUT_BITS = 9;

 const Huffman = struct {
+    const LUTEntry = packed struct { symbol: u16 align(4), len: u16 };
+
     // Number of codes for each possible length
     count: [MAXBITS + 1]u16,
     // Mapping between codes and symbols
@@ -40,19 +42,23 @@ const Huffman = struct {
     // canonical Huffman code and we have to decode it using a slower method.
     //
     // [1] https://github.com/madler/zlib/blob/v1.2.11/doc/algorithm.txt#L58
-    prefix_lut: [1 << PREFIX_LUT_BITS]u16,
-    prefix_lut_len: [1 << PREFIX_LUT_BITS]u16,
+    prefix_lut: [1 << PREFIX_LUT_BITS]LUTEntry,
     // The following info refer to the codes of length PREFIX_LUT_BITS+1 and are
     // used to bootstrap the bit-by-bit reading method if the fast-path fails.
     last_code: u16,
     last_index: u16,

+    min_code_len: u16,
+
     fn construct(self: *Huffman, code_length: []const u16) !void {
         for (self.count) |*val| {
             val.* = 0;
         }

+        self.min_code_len = math.maxInt(u16);
         for (code_length) |len| {
+            if (len != 0 and len < self.min_code_len)
+                self.min_code_len = len;
             self.count[len] += 1;
         }

@@ -85,39 +91,38 @@
             }
         }

-        self.prefix_lut_len = mem.zeroes(@TypeOf(self.prefix_lut_len));
+        self.prefix_lut = mem.zeroes(@TypeOf(self.prefix_lut));

         for (code_length) |len, symbol| {
             if (len != 0) {
                 // Fill the symbol table.
                 // The symbols are assigned sequentially for each length.
                 self.symbol[offset[len]] = @truncate(u16, symbol);
-                // Track the last assigned offset
+                // Track the last assigned offset.
                 offset[len] += 1;
             }

             if (len == 0 or len > PREFIX_LUT_BITS)
                 continue;

-            // Given a Huffman code of length N we have to massage it so
-            // that it becomes an index in the lookup table.
-            // The bit order is reversed as the fast path reads the bit
-            // sequence MSB to LSB using an &, the order is flipped wrt the
-            // one obtained by reading bit-by-bit.
-            // The codes are prefix-free, if the prefix matches we can
-            // safely ignore the trail bits. We do so by replicating the
-            // symbol info for each combination of the trailing bits.
+            // Given a Huffman code of length N, we transform it into an index
+            // into the lookup table by reversing its bits and filling the
+            // remaining bits (PREFIX_LUT_BITS - N) with every possible
+            // combination of bits to act as a wildcard.
             const bits_to_fill = @intCast(u5, PREFIX_LUT_BITS - len);
-            const rev_code = bitReverse(codes[len], len);
-            // Track the last used code, but only for lengths < PREFIX_LUT_BITS
+            const rev_code = bitReverse(u16, codes[len], len);
+
+            // Track the last used code, but only for lengths < PREFIX_LUT_BITS.
             codes[len] += 1;

             var j: usize = 0;
             while (j < @as(usize, 1) << bits_to_fill) : (j += 1) {
                 const index = rev_code | (j << @intCast(u5, len));
-                assert(self.prefix_lut_len[index] == 0);
-                self.prefix_lut[index] = @truncate(u16, symbol);
-                self.prefix_lut_len[index] = @truncate(u16, len);
+                assert(self.prefix_lut[index].len == 0);
+                self.prefix_lut[index] = .{
+                    .symbol = @truncate(u16, symbol),
+                    .len = @truncate(u16, len),
+                };
             }
         }

@@ -126,14 +131,10 @@
     }
 };

-// Reverse bit-by-bit a N-bit value
-fn bitReverse(x: usize, N: usize) usize {
-    var tmp: usize = 0;
-    var i: usize = 0;
-    while (i < N) : (i += 1) {
-        tmp |= ((x >> @intCast(u5, i)) & 1) << @intCast(u5, N - i - 1);
-    }
-    return tmp;
+// Reverse bit-by-bit an N-bit code.
+fn bitReverse(comptime T: type, value: T, N: usize) T {
+    const r = @bitReverse(T, value);
+    return r >> @intCast(math.Log2Int(T), @typeInfo(T).Int.bits - N);
 }

 pub fn InflateStream(comptime ReaderType: type) type {
@@ -269,8 +270,8 @@
         hdist: *Huffman,
         hlen: *Huffman,

-        // Temporary buffer for the bitstream, only bits 0..`bits_left` are
-        // considered valid.
+        // Temporary buffer for the bitstream.
+        // Bits 0..`bits_left` are filled with data, the remaining ones are zero.
         bits: u32,
         bits_left: usize,

@@ -280,7 +281,8 @@
                 self.bits |= @as(u32, byte) << @intCast(u5, self.bits_left);
                 self.bits_left += 8;
             }
-            return self.bits & ((@as(u32, 1) << @intCast(u5, bits)) - 1);
+            const mask = (@as(u32, 1) << @intCast(u5, bits)) - 1;
+            return self.bits & mask;
         }
         fn readBits(self: *Self, bits: usize) !u32 {
             const val = self.peekBits(bits);
@@ -293,8 +295,8 @@
         }

         fn stored(self: *Self) !void {
-            // Discard the remaining bits, the lenght field is always
-            // byte-aligned (and so is the data)
+            // Discard the remaining bits, the length field is always
+            // byte-aligned (and so is the data).
             self.discardBits(self.bits_left);

             const length = try self.inner_reader.readIntLittle(u16);
@@ -481,32 +483,52 @@
         }

         fn decode(self: *Self, h: *Huffman) !u16 {
-            // Fast path, read some bits and hope they're prefixes of some code
-            const prefix = try self.peekBits(PREFIX_LUT_BITS);
-            if (h.prefix_lut_len[prefix] != 0) {
-                self.discardBits(h.prefix_lut_len[prefix]);
-                return h.prefix_lut[prefix];
+            // Using u32 instead of u16 to reduce the number of casts needed.
+            var prefix: u32 = 0;
+
+            // Fast path: read some bits and hope they're the prefix of some code.
+            // We can't read PREFIX_LUT_BITS bits at once as we don't want to
+            // read past the deflate stream end, so use an incremental approach.
+            var code_len = h.min_code_len;
+            while (true) {
+                _ = try self.peekBits(code_len);
+                // Small optimization: use as many bits as possible in the
+                // table lookup.
+                prefix = self.bits & ((1 << PREFIX_LUT_BITS) - 1);
+
+                const lut_entry = &h.prefix_lut[prefix];
+                // The code is longer than PREFIX_LUT_BITS!
+                if (lut_entry.len == 0)
+                    break;
+                // If the code length doesn't increase, we've found a match.
+                if (lut_entry.len <= code_len) {
+                    self.discardBits(code_len);
+                    return lut_entry.symbol;
+                }
+
+                code_len = lut_entry.len;
             }

             // The sequence we've read is not a prefix of any code of length <=
-            // PREFIX_LUT_BITS, keep decoding it using a slower method
-            self.discardBits(PREFIX_LUT_BITS);
+            // PREFIX_LUT_BITS, keep decoding it using a slower method.
+            prefix = try self.readBits(PREFIX_LUT_BITS);

             // Speed up the decoding by starting from the first code length
-            // that's not covered by the table
+            // that's not covered by the table.
             var len: usize = PREFIX_LUT_BITS + 1;
             var first: usize = h.last_code;
             var index: usize = h.last_index;

             // Reverse the prefix so that the LSB becomes the MSB and make space
-            // for the next bit
-            var code = bitReverse(prefix, PREFIX_LUT_BITS + 1);
+            // for the next bit.
+            var code = bitReverse(u32, prefix, PREFIX_LUT_BITS + 1);

             while (len <= MAXBITS) : (len += 1) {
                 code |= try self.readBits(1);
                 const count = h.count[len];
-                if (code < first + count)
+                if (code < first + count) {
                     return h.symbol[index + (code - first)];
+                }
                 index += count;
                 first += count;
                 first <<= 1;
@@ -520,7 +542,7 @@
             while (true) {
                 switch (self.state) {
                     .DecodeBlockHeader => {
-                        // The compressed stream is done
+                        // The compressed stream is done.
                         if (self.seen_eos) return;

                         const last = @intCast(u1, try self.readBits(1));
@@ -528,7 +550,7 @@

                         self.seen_eos = last != 0;

-                        // The next state depends on the block type
+                        // The next state depends on the block type.
                         switch (kind) {
                             0 => try self.stored(),
                             1 => try self.fixed(),
@@ -553,7 +575,7 @@
                     var tmp: [1]u8 = undefined;
                     if ((try self.inner_reader.read(&tmp)) != 1) {
                         // Unexpected end of stream, keep this error
-                        // consistent with the use of readBitsNoEof
+                        // consistent with the use of readBitsNoEof.
                         return error.EndOfStream;
                     }
                     self.window.appendUnsafe(tmp[0]);
diff --git a/lib/std/compress/zlib.zig b/lib/std/compress/zlib.zig
index d4bac4a8a..63ef6c2ae 100644
--- a/lib/std/compress/zlib.zig
+++ b/lib/std/compress/zlib.zig
@@ -144,6 +144,19 @@ test "compressed data" {
     );
 }

+test "don't read past deflate stream's end" {
+    try testReader(
+        &[_]u8{
+            0x08, 0xd7, 0x63, 0xf8, 0xcf, 0xc0, 0xc0, 0x00, 0xc1, 0xff,
+            0xff, 0x43, 0x30, 0x03, 0x03, 0xc3, 0xff, 0xff, 0xff, 0x01,
+            0x83, 0x95, 0x0b, 0xf5,
+        },
+        // SHA256 of
+        // 00ff 0000 00ff 0000 00ff 00ff ffff 00ff ffff 0000 0000 ffff ff
+        "3bbba1cc65408445c81abb61f3d2b86b1b60ee0d70b4c05b96d1499091a08c93",
+    );
+}
+
 test "sanity checks" {
     // Truncated header
     testing.expectError(
diff --git a/lib/std/math.zig b/lib/std/math.zig
index f0c4f74d7..ffc0aa168 100644
--- a/lib/std/math.zig
+++ b/lib/std/math.zig
@@ -1141,4 +1141,3 @@ test "math.comptime" {
     comptime const v = sin(@as(f32, 1)) + ln(@as(f32, 5));
     testing.expect(v == sin(@as(f32, 1)) + ln(@as(f32, 5)));
 }
-
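--
For illustration only, not part of the patch: a minimal standalone sketch of
the new bitReverse helper, written against the same Zig version the patch
targets (where @bitReverse takes an explicit type parameter and testing
functions return void). The test name and example values below are made up.

const std = @import("std");
const math = std.math;
const testing = std.testing;

// Same shape as the new helper in deflate.zig: reverse all bits of T, then
// shift right so that only the low N (now reversed) bits remain.
fn bitReverse(comptime T: type, value: T, N: usize) T {
    const r = @bitReverse(T, value);
    return r >> @intCast(math.Log2Int(T), @typeInfo(T).Int.bits - N);
}

test "bitReverse reverses only the low N bits" {
    // The 3-bit code 0b110 becomes 0b011; the palindrome 0b101 is unchanged.
    // This is the transformation construct() applies to each code before
    // using it (plus the wildcard-filled high bits) as a prefix_lut index.
    testing.expectEqual(@as(u16, 0b011), bitReverse(u16, 0b110, 3));
    testing.expectEqual(@as(u16, 0b101), bitReverse(u16, 0b101, 3));
}

Shifting after a full-width @bitReverse replaces the old per-bit loop with a
single instruction on most targets; since Huffman codes are at most MAXBITS
(15) bits, the shift amount always fits in math.Log2Int(T) for the u16/u32
instantiations used here.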