std/deflate: Avoid reading past end of stream

Use a conservative (and slower) approach in the Huffman decoder fast
path.

Closes #6847
master
LemonBoy 2020-10-29 17:16:03 +01:00
parent 88eb3ae8e5
commit 20fba0933f
3 changed files with 79 additions and 45 deletions

View File

@ -27,6 +27,8 @@ const FIXLCODES = 288;
const PREFIX_LUT_BITS = 9;
const Huffman = struct {
const LUTEntry = packed struct { symbol: u16 align(4), len: u16 };
// Number of codes for each possible length
count: [MAXBITS + 1]u16,
// Mapping between codes and symbols
@ -40,19 +42,23 @@ const Huffman = struct {
// canonical Huffman code and we have to decode it using a slower method.
//
// [1] https://github.com/madler/zlib/blob/v1.2.11/doc/algorithm.txt#L58
prefix_lut: [1 << PREFIX_LUT_BITS]u16,
prefix_lut_len: [1 << PREFIX_LUT_BITS]u16,
prefix_lut: [1 << PREFIX_LUT_BITS]LUTEntry,
// The following fields refer to the codes of length PREFIX_LUT_BITS+1 and are
// used to bootstrap the bit-by-bit reading method if the fast-path fails.
last_code: u16,
last_index: u16,
min_code_len: u16,
fn construct(self: *Huffman, code_length: []const u16) !void {
for (self.count) |*val| {
val.* = 0;
}
self.min_code_len = math.maxInt(u16);
for (code_length) |len| {
if (len != 0 and len < self.min_code_len)
self.min_code_len = len;
self.count[len] += 1;
}
@ -85,39 +91,38 @@ const Huffman = struct {
}
}
self.prefix_lut_len = mem.zeroes(@TypeOf(self.prefix_lut_len));
self.prefix_lut = mem.zeroes(@TypeOf(self.prefix_lut));
for (code_length) |len, symbol| {
if (len != 0) {
// Fill the symbol table.
// The symbols are assigned sequentially for each length.
self.symbol[offset[len]] = @truncate(u16, symbol);
// Track the last assigned offset
// Track the last assigned offset.
offset[len] += 1;
}
if (len == 0 or len > PREFIX_LUT_BITS)
continue;
// Given a Huffman code of length N we have to massage it so
// that it becomes an index in the lookup table.
// The bit order is reversed as the fast path reads the bit
// sequence MSB to LSB using an &, the order is flipped wrt the
// one obtained by reading bit-by-bit.
// The codes are prefix-free, if the prefix matches we can
// safely ignore the trail bits. We do so by replicating the
// symbol info for each combination of the trailing bits.
// Given a Huffman code of length N we transform it into an index
// into the lookup table by reversing its bits and filling the
// remaining bits (PREFIX_LUT_BITS - N) with every possible
// combination of bits to act as a wildcard.
const bits_to_fill = @intCast(u5, PREFIX_LUT_BITS - len);
const rev_code = bitReverse(codes[len], len);
// Track the last used code, but only for lengths < PREFIX_LUT_BITS
const rev_code = bitReverse(u16, codes[len], len);
// Track the last used code, but only for lengths < PREFIX_LUT_BITS.
codes[len] += 1;
var j: usize = 0;
while (j < @as(usize, 1) << bits_to_fill) : (j += 1) {
const index = rev_code | (j << @intCast(u5, len));
assert(self.prefix_lut_len[index] == 0);
self.prefix_lut[index] = @truncate(u16, symbol);
self.prefix_lut_len[index] = @truncate(u16, len);
assert(self.prefix_lut[index].len == 0);
self.prefix_lut[index] = .{
.symbol = @truncate(u16, symbol),
.len = @truncate(u16, len),
};
}
}
@ -126,14 +131,10 @@ const Huffman = struct {
}
};
// Reverse bit-by-bit a N-bit value
fn bitReverse(x: usize, N: usize) usize {
var tmp: usize = 0;
var i: usize = 0;
while (i < N) : (i += 1) {
tmp |= ((x >> @intCast(u5, i)) & 1) << @intCast(u5, N - i - 1);
}
return tmp;
// Reverse an N-bit code bit-by-bit.
fn bitReverse(comptime T: type, value: T, N: usize) T {
const r = @bitReverse(T, value);
return r >> @intCast(math.Log2Int(T), @typeInfo(T).Int.bits - N);
}
pub fn InflateStream(comptime ReaderType: type) type {
@ -269,8 +270,8 @@ pub fn InflateStream(comptime ReaderType: type) type {
hdist: *Huffman,
hlen: *Huffman,
// Temporary buffer for the bitstream, only bits 0..`bits_left` are
// considered valid.
// Temporary buffer for the bitstream.
// Bits 0..`bits_left` are filled with data, the remaining ones are zeros.
bits: u32,
bits_left: usize,
@ -280,7 +281,8 @@ pub fn InflateStream(comptime ReaderType: type) type {
self.bits |= @as(u32, byte) << @intCast(u5, self.bits_left);
self.bits_left += 8;
}
return self.bits & ((@as(u32, 1) << @intCast(u5, bits)) - 1);
const mask = (@as(u32, 1) << @intCast(u5, bits)) - 1;
return self.bits & mask;
}
fn readBits(self: *Self, bits: usize) !u32 {
const val = self.peekBits(bits);
@ -293,8 +295,8 @@ pub fn InflateStream(comptime ReaderType: type) type {
}
fn stored(self: *Self) !void {
// Discard the remaining bits, the lenght field is always
// byte-aligned (and so is the data)
// Discard the remaining bits, the length field is always
// byte-aligned (and so is the data).
self.discardBits(self.bits_left);
const length = try self.inner_reader.readIntLittle(u16);
@ -481,32 +483,52 @@ pub fn InflateStream(comptime ReaderType: type) type {
}
fn decode(self: *Self, h: *Huffman) !u16 {
// Fast path, read some bits and hope they're prefixes of some code
const prefix = try self.peekBits(PREFIX_LUT_BITS);
if (h.prefix_lut_len[prefix] != 0) {
self.discardBits(h.prefix_lut_len[prefix]);
return h.prefix_lut[prefix];
// Using u32 instead of u16 to reduce the number of casts needed.
var prefix: u32 = 0;
// Fast path, read some bits and hope they're the prefix of some code.
// We can't read PREFIX_LUT_BITS as we don't want to read past the
// deflate stream end, use an incremental approach instead.
var code_len = h.min_code_len;
while (true) {
_ = try self.peekBits(code_len);
// Small optimization win, use as many bits as possible in the
// table lookup.
prefix = self.bits & ((1 << PREFIX_LUT_BITS) - 1);
const lut_entry = &h.prefix_lut[prefix];
// The code is longer than PREFIX_LUT_BITS!
if (lut_entry.len == 0)
break;
// If the code length doesn't increase we found a match.
if (lut_entry.len <= code_len) {
self.discardBits(code_len);
return lut_entry.symbol;
}
code_len = lut_entry.len;
}
// The sequence we've read is not a prefix of any code of length <=
// PREFIX_LUT_BITS, keep decoding it using a slower method
self.discardBits(PREFIX_LUT_BITS);
// PREFIX_LUT_BITS, keep decoding it using a slower method.
prefix = try self.readBits(PREFIX_LUT_BITS);
// Speed up the decoding by starting from the first code length
// that's not covered by the table
// that's not covered by the table.
var len: usize = PREFIX_LUT_BITS + 1;
var first: usize = h.last_code;
var index: usize = h.last_index;
// Reverse the prefix so that the LSB becomes the MSB and make space
// for the next bit
var code = bitReverse(prefix, PREFIX_LUT_BITS + 1);
// for the next bit.
var code = bitReverse(u32, prefix, PREFIX_LUT_BITS + 1);
while (len <= MAXBITS) : (len += 1) {
code |= try self.readBits(1);
const count = h.count[len];
if (code < first + count)
if (code < first + count) {
return h.symbol[index + (code - first)];
}
index += count;
first += count;
first <<= 1;
@ -520,7 +542,7 @@ pub fn InflateStream(comptime ReaderType: type) type {
while (true) {
switch (self.state) {
.DecodeBlockHeader => {
// The compressed stream is done
// The compressed stream is done.
if (self.seen_eos) return;
const last = @intCast(u1, try self.readBits(1));
@ -528,7 +550,7 @@ pub fn InflateStream(comptime ReaderType: type) type {
self.seen_eos = last != 0;
// The next state depends on the block type
// The next state depends on the block type.
switch (kind) {
0 => try self.stored(),
1 => try self.fixed(),
@ -553,7 +575,7 @@ pub fn InflateStream(comptime ReaderType: type) type {
var tmp: [1]u8 = undefined;
if ((try self.inner_reader.read(&tmp)) != 1) {
// Unexpected end of stream, keep this error
// consistent with the use of readBitsNoEof
// consistent with the use of readBitsNoEof.
return error.EndOfStream;
}
self.window.appendUnsafe(tmp[0]);

View File

@ -144,6 +144,19 @@ test "compressed data" {
);
}
test "don't read past deflate stream's end" {
try testReader(
&[_]u8{
0x08, 0xd7, 0x63, 0xf8, 0xcf, 0xc0, 0xc0, 0x00, 0xc1, 0xff,
0xff, 0x43, 0x30, 0x03, 0x03, 0xc3, 0xff, 0xff, 0xff, 0x01,
0x83, 0x95, 0x0b, 0xf5,
},
// SHA256 of
// 00ff 0000 00ff 0000 00ff 00ff ffff 00ff ffff 0000 0000 ffff ff
"3bbba1cc65408445c81abb61f3d2b86b1b60ee0d70b4c05b96d1499091a08c93",
);
}
test "sanity checks" {
// Truncated header
testing.expectError(

View File

@ -1141,4 +1141,3 @@ test "math.comptime" {
comptime const v = sin(@as(f32, 1)) + ln(@as(f32, 5));
testing.expect(v == sin(@as(f32, 1)) + ln(@as(f32, 5)));
}