std/deflate: Avoid reading past end of stream

Use a conservative (and slower) approach in the Huffman decoder fast path.

Closes #6847

branch: master
commit: 20fba0933f
parent: 88eb3ae8e5

@@ -27,6 +27,8 @@ const FIXLCODES = 288;
 const PREFIX_LUT_BITS = 9;
 
 const Huffman = struct {
+    const LUTEntry = packed struct { symbol: u16 align(4), len: u16 };
+
     // Number of codes for each possible length
     count: [MAXBITS + 1]u16,
     // Mapping between codes and symbols
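
The new LUTEntry packs a code's symbol and bit length side by side so the fast path can fetch both with a single table lookup, replacing the two parallel u16 arrays the next hunk removes. A quick layout check (an illustrative sketch, not part of the commit):

    const std = @import("std");

    comptime {
        const LUTEntry = packed struct { symbol: u16 align(4), len: u16 };
        // Both fields share one aligned 32-bit word.
        std.debug.assert(@sizeOf(LUTEntry) == 4);
    }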

@@ -40,19 +42,23 @@ const Huffman = struct {
     // canonical Huffman code and we have to decode it using a slower method.
     //
     // [1] https://github.com/madler/zlib/blob/v1.2.11/doc/algorithm.txt#L58
-    prefix_lut: [1 << PREFIX_LUT_BITS]u16,
-    prefix_lut_len: [1 << PREFIX_LUT_BITS]u16,
+    prefix_lut: [1 << PREFIX_LUT_BITS]LUTEntry,
     // The following info refer to the codes of length PREFIX_LUT_BITS+1 and are
     // used to bootstrap the bit-by-bit reading method if the fast-path fails.
     last_code: u16,
     last_index: u16,
 
+    min_code_len: u16,
+
     fn construct(self: *Huffman, code_length: []const u16) !void {
         for (self.count) |*val| {
             val.* = 0;
         }
 
+        self.min_code_len = math.maxInt(u16);
         for (code_length) |len| {
+            if (len != 0 and len < self.min_code_len)
+                self.min_code_len = len;
             self.count[len] += 1;
         }
 
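
construct() now also records the shortest code length present in the table; decode() (further down) uses it as the starting number of bits to peek. A minimal sketch of the bookkeeping, assuming test access to the file-private Huffman type and the file's `testing` import:

    test "min_code_len bookkeeping" {
        var h: Huffman = undefined;
        // A complete 5-symbol code: three codes of length 2, two of length 3.
        try h.construct(&[_]u16{ 2, 2, 2, 3, 3 });
        // The fast path can start by peeking just 2 bits.
        testing.expect(h.min_code_len == 2);
    }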

@@ -85,39 +91,38 @@ const Huffman = struct {
             }
         }
 
-        self.prefix_lut_len = mem.zeroes(@TypeOf(self.prefix_lut_len));
         self.prefix_lut = mem.zeroes(@TypeOf(self.prefix_lut));
 
         for (code_length) |len, symbol| {
             if (len != 0) {
                 // Fill the symbol table.
                 // The symbols are assigned sequentially for each length.
                 self.symbol[offset[len]] = @truncate(u16, symbol);
-                // Track the last assigned offset
+                // Track the last assigned offset.
                 offset[len] += 1;
             }
 
             if (len == 0 or len > PREFIX_LUT_BITS)
                 continue;
 
-            // Given a Huffman code of length N we have to massage it so
-            // that it becomes an index in the lookup table.
-            // The bit order is reversed as the fast path reads the bit
-            // sequence MSB to LSB using an &, the order is flipped wrt the
-            // one obtained by reading bit-by-bit.
-            // The codes are prefix-free, if the prefix matches we can
-            // safely ignore the trail bits. We do so by replicating the
-            // symbol info for each combination of the trailing bits.
+            // Given a Huffman code of length N we transform it into an index
+            // into the lookup table by reversing its bits and filling the
+            // remaining bits (PREFIX_LUT_BITS - N) with every possible
+            // combination of bits to act as a wildcard.
             const bits_to_fill = @intCast(u5, PREFIX_LUT_BITS - len);
-            const rev_code = bitReverse(codes[len], len);
-            // Track the last used code, but only for lengths < PREFIX_LUT_BITS
+            const rev_code = bitReverse(u16, codes[len], len);
+
+            // Track the last used code, but only for lengths < PREFIX_LUT_BITS.
             codes[len] += 1;
 
             var j: usize = 0;
             while (j < @as(usize, 1) << bits_to_fill) : (j += 1) {
                 const index = rev_code | (j << @intCast(u5, len));
-                assert(self.prefix_lut_len[index] == 0);
-                self.prefix_lut[index] = @truncate(u16, symbol);
-                self.prefix_lut_len[index] = @truncate(u16, len);
+                assert(self.prefix_lut[index].len == 0);
+                self.prefix_lut[index] = .{
+                    .symbol = @truncate(u16, symbol),
+                    .len = @truncate(u16, len),
+                };
             }
         }
 
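
Concretely, with PREFIX_LUT_BITS = 9 (illustrative numbers, not from the commit): a 3-bit code 0b110 is bit-reversed to 0b011, and the loop above stores the same entry at all 64 indices of the form 0b011 | (j << 3) for j in 0..63, so any 9-bit window whose low-order bits start with the reversed code lands on the right {symbol, len} pair.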

@@ -126,14 +131,10 @@ const Huffman = struct {
     }
 };
 
-// Reverse bit-by-bit a N-bit value
-fn bitReverse(x: usize, N: usize) usize {
-    var tmp: usize = 0;
-    var i: usize = 0;
-    while (i < N) : (i += 1) {
-        tmp |= ((x >> @intCast(u5, i)) & 1) << @intCast(u5, N - i - 1);
-    }
-    return tmp;
+// Reverse bit-by-bit an N-bit code.
+fn bitReverse(comptime T: type, value: T, N: usize) T {
+    const r = @bitReverse(T, value);
+    return r >> @intCast(math.Log2Int(T), @typeInfo(T).Int.bits - N);
 }
 
 pub fn InflateStream(comptime ReaderType: type) type {
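
The helper now leans on the @bitReverse builtin: reverse all of T's bits, then shift right so only the N interesting bits remain. An illustrative check (not part of the commit; it assumes the file's `assert` import):

    comptime {
        // 0b1101 read back to front over 4 bits is 0b1011:
        // @bitReverse(u16, 0b1101) == 0b1011_0000_0000_0000, shifted right by 16 - 4.
        assert(bitReverse(u16, 0b1101, 4) == 0b1011);
    }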

@@ -269,8 +270,8 @@ pub fn InflateStream(comptime ReaderType: type) type {
         hdist: *Huffman,
         hlen: *Huffman,
 
-        // Temporary buffer for the bitstream, only bits 0..`bits_left` are
-        // considered valid.
+        // Temporary buffer for the bitstream.
+        // Bits 0..`bits_left` are filled with data, the remaining ones are zeros.
         bits: u32,
         bits_left: usize,
 

@@ -280,7 +281,8 @@ pub fn InflateStream(comptime ReaderType: type) type {
                 self.bits |= @as(u32, byte) << @intCast(u5, self.bits_left);
                 self.bits_left += 8;
             }
-            return self.bits & ((@as(u32, 1) << @intCast(u5, bits)) - 1);
+            const mask = (@as(u32, 1) << @intCast(u5, bits)) - 1;
+            return self.bits & mask;
         }
         fn readBits(self: *Self, bits: usize) !u32 {
             const val = self.peekBits(bits);
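
In other words, peekBits(n) tops the buffer up one byte at a time until at least n bits are available and returns the n lowest bits without consuming them: starting from an empty buffer, peekBits(3) reads one byte, sets bits_left to 8, and returns the low 3 bits; repeated peeks yield the same value until discardBits() drops those bits.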

@@ -293,8 +295,8 @@ pub fn InflateStream(comptime ReaderType: type) type {
         }
 
         fn stored(self: *Self) !void {
-            // Discard the remaining bits, the lenght field is always
-            // byte-aligned (and so is the data)
+            // Discard the remaining bits, the length field is always
+            // byte-aligned (and so is the data).
            self.discardBits(self.bits_left);
 
             const length = try self.inner_reader.readIntLittle(u16);

@@ -481,32 +483,52 @@ pub fn InflateStream(comptime ReaderType: type) type {
         }
 
         fn decode(self: *Self, h: *Huffman) !u16 {
-            // Fast path, read some bits and hope they're prefixes of some code
-            const prefix = try self.peekBits(PREFIX_LUT_BITS);
-            if (h.prefix_lut_len[prefix] != 0) {
-                self.discardBits(h.prefix_lut_len[prefix]);
-                return h.prefix_lut[prefix];
+            // Using u32 instead of u16 to reduce the number of casts needed.
+            var prefix: u32 = 0;
+
+            // Fast path, read some bits and hope they're the prefix of some code.
+            // We can't read PREFIX_LUT_BITS as we don't want to read past the
+            // deflate stream end, use an incremental approach instead.
+            var code_len = h.min_code_len;
+            while (true) {
+                _ = try self.peekBits(code_len);
+                // Small optimization win, use as many bits as possible in the
+                // table lookup.
+                prefix = self.bits & ((1 << PREFIX_LUT_BITS) - 1);
+
+                const lut_entry = &h.prefix_lut[prefix];
+                // The code is longer than PREFIX_LUT_BITS!
+                if (lut_entry.len == 0)
+                    break;
+                // If the code length doesn't increase we found a match.
+                if (lut_entry.len <= code_len) {
+                    self.discardBits(code_len);
+                    return lut_entry.symbol;
+                }
+
+                code_len = lut_entry.len;
             }
 
             // The sequence we've read is not a prefix of any code of length <=
-            // PREFIX_LUT_BITS, keep decoding it using a slower method
-            self.discardBits(PREFIX_LUT_BITS);
+            // PREFIX_LUT_BITS, keep decoding it using a slower method.
+            prefix = try self.readBits(PREFIX_LUT_BITS);
 
             // Speed up the decoding by starting from the first code length
-            // that's not covered by the table
+            // that's not covered by the table.
             var len: usize = PREFIX_LUT_BITS + 1;
             var first: usize = h.last_code;
             var index: usize = h.last_index;
 
             // Reverse the prefix so that the LSB becomes the MSB and make space
-            // for the next bit
-            var code = bitReverse(prefix, PREFIX_LUT_BITS + 1);
+            // for the next bit.
+            var code = bitReverse(u32, prefix, PREFIX_LUT_BITS + 1);
 
             while (len <= MAXBITS) : (len += 1) {
                 code |= try self.readBits(1);
                 const count = h.count[len];
-                if (code < first + count)
+                if (code < first + count) {
                     return h.symbol[index + (code - first)];
+                }
                 index += count;
                 first += count;
                 first <<= 1;
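
This is the heart of the fix. The old fast path always peeked PREFIX_LUT_BITS (9) bits, which could pull bytes from the underlying reader even when the deflate stream had fewer bits left, consuming data that belongs to whatever follows the stream. The new loop peeks only min_code_len bits, and each table probe either matches (the entry's len fits within the bits already peeked) or says exactly how many more bits to request. As an illustrative scenario consistent with #6847: if the stream ends with a 7-bit end-of-block code and only 7 bits remain, the loop peeks 7 bits, finds lut_entry.len == 7, and returns the symbol without ever asking the reader for another byte.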

@@ -520,7 +542,7 @@ pub fn InflateStream(comptime ReaderType: type) type {
             while (true) {
                 switch (self.state) {
                     .DecodeBlockHeader => {
-                        // The compressed stream is done
+                        // The compressed stream is done.
                         if (self.seen_eos) return;
 
                         const last = @intCast(u1, try self.readBits(1));

@@ -528,7 +550,7 @@ pub fn InflateStream(comptime ReaderType: type) type {
 
                         self.seen_eos = last != 0;
 
-                        // The next state depends on the block type
+                        // The next state depends on the block type.
                         switch (kind) {
                             0 => try self.stored(),
                             1 => try self.fixed(),

@@ -553,7 +575,7 @@ pub fn InflateStream(comptime ReaderType: type) type {
                     var tmp: [1]u8 = undefined;
                     if ((try self.inner_reader.read(&tmp)) != 1) {
                         // Unexpected end of stream, keep this error
-                        // consistent with the use of readBitsNoEof
+                        // consistent with the use of readBitsNoEof.
                         return error.EndOfStream;
                     }
                     self.window.appendUnsafe(tmp[0]);

@@ -144,6 +144,19 @@ test "compressed data" {
     );
 }
 
+test "don't read past deflate stream's end" {
+    try testReader(
+        &[_]u8{
+            0x08, 0xd7, 0x63, 0xf8, 0xcf, 0xc0, 0xc0, 0x00, 0xc1, 0xff,
+            0xff, 0x43, 0x30, 0x03, 0x03, 0xc3, 0xff, 0xff, 0xff, 0x01,
+            0x83, 0x95, 0x0b, 0xf5,
+        },
+        // SHA256 of
+        // 00ff 0000 00ff 0000 00ff 00ff ffff 00ff ffff 0000 0000 ffff ff
+        "3bbba1cc65408445c81abb61f3d2b86b1b60ee0d70b4c05b96d1499091a08c93",
+    );
+}
+
 test "sanity checks" {
     // Truncated header
     testing.expectError(
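
The new test's input ends exactly at the deflate stream boundary, so a decoder that reads past the end would consume the trailing bytes. testReader compares the decompressed output against the expected SHA256 given as hex; a hedged sketch of recomputing such a digest (assuming the std.crypto API of this era and a hypothetical `decompressed` slice):

    const Sha256 = std.crypto.hash.sha2.Sha256;

    var digest: [Sha256.digest_length]u8 = undefined;
    Sha256.hash(decompressed, &digest, .{});
    // Hex-encode `digest` and compare it with the string passed to testReader.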

@@ -1141,4 +1141,3 @@ test "math.comptime" {
     comptime const v = sin(@as(f32, 1)) + ln(@as(f32, 5));
     testing.expect(v == sin(@as(f32, 1)) + ln(@as(f32, 5)));
 }
-