From 18ce8b54ddeb7cd80de8978d7fb0b66b966089d7 Mon Sep 17 00:00:00 2001 From: Sean Purcell Date: Wed, 1 Feb 2017 17:05:45 -0800 Subject: [PATCH] Switch IO to go through streams --- contrib/educational_decoder/harness.c | 15 +- contrib/educational_decoder/zstd_decompress.c | 1263 ++++++++--------- 2 files changed, 604 insertions(+), 674 deletions(-) diff --git a/contrib/educational_decoder/harness.c b/contrib/educational_decoder/harness.c index cff8239d..683278df 100644 --- a/contrib/educational_decoder/harness.c +++ b/contrib/educational_decoder/harness.c @@ -18,6 +18,9 @@ typedef unsigned char u8; // compression ratio is at most 16 #define MAX_COMPRESSION_RATIO (16) +// Protect against allocating too much memory for output +#define MAX_OUTPUT_SIZE ((size_t)1024 * 1024 * 1024) + u8 *input; u8 *output; u8 *dict; @@ -86,11 +89,17 @@ int main(int argc, char **argv) { size_t decompressed_size = ZSTD_get_decompressed_size(input, input_size); if (decompressed_size == -1) { decompressed_size = MAX_COMPRESSION_RATIO * input_size; - fprintf(stderr, "WARNING: Compressed data does contain decompressed " - "size, going to assume the compression ratio is at " - "most %d (decompressed size of at most %zu)\n", + fprintf(stderr, "WARNING: Compressed data does not contain " + "decompressed size, going to assume the compression " + "ratio is at most %d (decompressed size of at most " + "%zu)\n", MAX_COMPRESSION_RATIO, decompressed_size); } + if (decompressed_size > MAX_OUTPUT_SIZE) { + fprintf(stderr, + "Required output size too large for this implementation\n"); + return 1; + } output = malloc(decompressed_size); if (!output) { fprintf(stderr, "failed to allocate memory\n"); diff --git a/contrib/educational_decoder/zstd_decompress.c b/contrib/educational_decoder/zstd_decompress.c index e2fbcf2c..8f28313e 100644 --- a/contrib/educational_decoder/zstd_decompress.c +++ b/contrib/educational_decoder/zstd_decompress.c @@ -48,6 +48,7 @@ size_t ZSTD_get_decompressed_size(const void *const 
src, const size_t src_len); #define OUT_SIZE() ERROR("Output buffer too small for output") #define CORRUPTION() ERROR("Corruption detected while decompressing") #define BAD_ALLOC() ERROR("Memory allocation error") +#define IMPOSSIBLE() ERROR("An impossibility has occurred") typedef uint8_t u8; typedef uint16_t u16; @@ -65,6 +66,62 @@ typedef int64_t i64; /// file. They implement low-level functionality needed for the higher level /// decompression functions. +/*** IO STREAM OPERATIONS *************/ +/// These structs are the interface for IO, and do bounds checking on all +/// operations. They should be used opaquely to ensure safety. + +/// Output is always done byte-by-byte +typedef struct { + u8 *ptr; + size_t len; +} ostream_t; + +/// Input often reads a few bits at a time, so maintain an internal offset +typedef struct { + const u8 *ptr; + int bit_offset; + size_t len; +} istream_t; + +/// The following two functions are the only ones that allow the istream to be +/// non-byte aligned + +/// Reads `num` bits from a bitstream, and updates the internal offset +static inline u64 IO_read_bits(istream_t *const in, const int num); +/// Rewinds the stream by `num` bits +static inline void IO_rewind_bits(istream_t *const in, const int num); +/// If the remaining bits in a byte will be unused, advance to the end of the +/// byte +static inline void IO_align_stream(istream_t *const in); + +/// Write the given byte into the output stream +static inline void IO_write_byte(ostream_t *const out, u8 symb); + +/// Returns the number of bytes left to be read in this stream. The stream must +/// be byte aligned. +static inline size_t IO_istream_len(const istream_t *const in); + +/// Returns a pointer where `len` bytes can be read, and advances the internal +/// state. The stream must be byte aligned. +static inline const u8 *IO_read_bytes(istream_t *const in, size_t len); +/// Returns a pointer where `len` bytes can be written, and advances the internal +/// state. 
The stream must be byte aligned. +static inline u8 *IO_write_bytes(ostream_t *const out, size_t len); + +/// Advance the inner state by `len` bytes. The stream must be byte aligned. +static inline void IO_advance_input(istream_t *const in, size_t len); + +/// Returns an `ostream_t` constructed from the given pointer and length +static inline ostream_t IO_make_ostream(u8 *out, size_t len); +/// Returns an `istream_t` constructed from the given pointer and length +static inline istream_t IO_make_istream(const u8 *in, size_t len); + +/// Returns an `istream_t` with the same base as `in`, and length `len` +/// Then, advance `in` to account for the consumed bytes +/// `in` must be byte aligned +static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len); +/*** END IO STREAM OPERATIONS *********/ + /*** BITSTREAM OPERATIONS *************/ /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits static inline u64 read_bits_LE(const u8 *src, const int num, @@ -109,15 +166,13 @@ static inline void HUF_init_state(const HUF_dtable *const dtable, /// Decompresses a single Huffman stream, returns the number of bytes decoded. /// `src_len` must be the exact length of the Huffman-coded block. -static size_t HUF_decompress_1stream(const HUF_dtable *const dtable, u8 *dst, - const size_t dst_len, const u8 *src, - size_t src_len); +static size_t HUF_decompress_1stream(const HUF_dtable *const dtable, + ostream_t *const out, istream_t *const in); /// Same as previous but decodes 4 streams, formatted as in the Zstandard /// specification. /// `src_len` must be the exact length of the Huffman-coded block. 
-static size_t HUF_decompress_4stream(const HUF_dtable *const dtable, u8 *dst, - const size_t dst_len, const u8 *const src, - const size_t src_len); +static size_t HUF_decompress_4stream(const HUF_dtable *const dtable, + ostream_t *const out, istream_t *const in); /// Initialize a Huffman decoding table using the table of bit counts provided static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits, @@ -176,9 +231,8 @@ static inline void FSE_init_state(const FSE_dtable *const dtable, /// using an FSE decoding table. `src_len` must be the exact length of the /// block. static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable, - u8 *dst, const size_t dst_len, - const u8 *const src, - const size_t src_len); + ostream_t *const out, + istream_t *const in); /// Initialize a decoding table using normalized frequencies. static void FSE_init_dtable(FSE_dtable *const dtable, @@ -187,8 +241,7 @@ static void FSE_init_dtable(FSE_dtable *const dtable, /// Decode an FSE header as defined in the Zstandard format specification and /// use the decoded frequencies to initialize a decoding table. 
-static size_t FSE_decode_header(FSE_dtable *const dtable, const u8 *const src, - const size_t src_len, +static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in, const int max_accuracy_log); /// Initialize an FSE table that will always return the same symbol and consume @@ -207,16 +260,6 @@ static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src); /******* ZSTD HELPER STRUCTS AND PROTOTYPES ***********************************/ -/// Input and output pointers to allow them to be advanced by -/// functions that consume input/produce output -typedef struct { - u8 *dst; - size_t dst_len; - - const u8 *src; - size_t src_len; -} io_streams_t; - /// A small structure that can be reused in various places that need to access /// frame header information typedef struct { @@ -233,9 +276,6 @@ typedef struct { int content_checksum_flag; // Whether or not the output for this frame is in a single segment int single_segment_flag; - - // The size in bytes of this header - int header_size; } frame_header_t; /// The context needed to decode blocks in a frame @@ -256,9 +296,8 @@ typedef struct { FSE_dtable ml_dtable; FSE_dtable of_dtable; - // The last 3 offsets for the special "repeat offsets". Array size is 4 so - // that previous_offsets[1] corresponds to the most recent offset - u64 previous_offsets[4]; + // The last 3 offsets for the special "repeat offsets". + u64 previous_offsets[3]; } frame_context_t; /// The decoded contents of a dictionary so that it doesn't have to be repeated @@ -275,7 +314,7 @@ typedef struct { size_t content_size; // Offset history to prepopulate the frame's history - u64 previous_offsets[4]; + u64 previous_offsets[3]; u32 dictionary_id; } dictionary_t; @@ -301,34 +340,31 @@ typedef struct { /// Accepts a dict argument, which may be NULL indicating no dictionary. 
/// See /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame-concatenation -static void decode_frame(io_streams_t *const streams, +static void decode_frame(ostream_t *const out, istream_t *const in, const dictionary_t *const dict); // Decode data in a compressed block -static void decompress_block(io_streams_t *const streams, - frame_context_t *const ctx, - const size_t block_len); +static void decompress_block(frame_context_t *const ctx, ostream_t *const out, + istream_t *const in); // Decode the literals section of a block -static size_t decode_literals(io_streams_t *const streams, - frame_context_t *const ctx, u8 **const literals); +static size_t decode_literals(frame_context_t *const ctx, istream_t *const in, + u8 **const literals); // Decode the sequences part of a block -static size_t decode_sequences(frame_context_t *const ctx, const u8 *const src, - const size_t src_len, +static size_t decode_sequences(frame_context_t *const ctx, istream_t *const in, sequence_command_t **const sequences); // Execute the decoded sequences on the literals block -static void execute_sequences(io_streams_t *const streams, - frame_context_t *const ctx, +static void execute_sequences(frame_context_t *const ctx, ostream_t *const out, + const u8 *const literals, + const size_t literals_len, const sequence_command_t *const sequences, - const size_t num_sequences, - const u8 *literals, - size_t literals_len); + const size_t num_sequences); // Parse a provided dictionary blob for use in decompression -static void parse_dictionary(dictionary_t *const dict, const u8 *const src, - const size_t src_len); +static void parse_dictionary(dictionary_t *const dict, const u8 *src, + size_t src_len); static void free_dictionary(dictionary_t *const dict); /******* END ZSTD HELPER STRUCTS AND PROTOTYPES *******************************/ @@ -348,58 +384,46 @@ size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len, parse_dictionary(&parsed_dict, (const u8 
*)dict, dict_len); } - io_streams_t streams = {(u8 *)dst, dst_len, (const u8 *)src, src_len}; - while (streams.src_len > 0) { - decode_frame(&streams, &parsed_dict); + istream_t in = {(const u8 *)src, 0, src_len}; + ostream_t out = {(u8 *)dst, dst_len}; + while (IO_istream_len(&in) > 0) { + decode_frame(&out, &in, &parsed_dict); } free_dictionary(&parsed_dict); - return streams.dst - (u8 *)dst; + return out.ptr - (u8 *)dst; } /******* FRAME DECODING ******************************************************/ -static void decode_data_frame(io_streams_t *const streams, +static void decode_data_frame(ostream_t *const out, istream_t *const in, const dictionary_t *const dict); -static void init_frame_context(io_streams_t *const streams, - frame_context_t *const context, +static void init_frame_context(frame_context_t *const context, + istream_t *const in, const dictionary_t *const dict); static void free_frame_context(frame_context_t *const context); static void parse_frame_header(frame_header_t *const header, - const u8 *const src, const size_t src_len); + istream_t *const in); static void frame_context_apply_dict(frame_context_t *const ctx, const dictionary_t *const dict); -static void decompress_data(io_streams_t *const streams, - frame_context_t *const ctx); +static void decompress_data(frame_context_t *const ctx, ostream_t *const out, + istream_t *const in); -static void decode_frame(io_streams_t *const streams, +static void decode_frame(ostream_t *const out, istream_t *const in, const dictionary_t *const dict) { - if (streams->src_len < 4) { - INP_SIZE(); - } - const u32 magic_number = read_bits_LE(streams->src, 32, 0); + const u32 magic_number = IO_read_bits(in, 32); - streams->src += 4; - streams->src_len -= 4; - if (magic_number >= 0x184D2A50U && magic_number <= 0x184D2A5F) { - // skippable frame - if (streams->src_len < 4) { - INP_SIZE(); - } - const size_t frame_size = read_bits_LE(streams->src, 32, 32); - - if (streams->src_len < 4 + frame_size) { - INP_SIZE(); 
- } + if ((magic_number & ~0xFU) == 0x184D2A50U) { + // Skippable frame + const size_t frame_size = IO_read_bits(in, 32); // skip over frame - streams->src += 4 + frame_size; - streams->src_len -= 4 + frame_size; + IO_advance_input(in, frame_size); } else if (magic_number == 0xFD2FB528U) { // ZSTD frame - decode_data_frame(streams, dict); + decode_data_frame(out, in, dict); } else { // not a real frame ERROR("Invalid magic number"); @@ -410,40 +434,38 @@ static void decode_frame(io_streams_t *const streams, /// are skippable frames. /// See /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#general-structure-of-zstandard-frame-format -static void decode_data_frame(io_streams_t *const streams, +static void decode_data_frame(ostream_t *const out, istream_t *const in, const dictionary_t *const dict) { frame_context_t ctx; // Initialize the context that needs to be carried from block to block - init_frame_context(streams, &ctx, dict); + init_frame_context(&ctx, in, dict); if (ctx.header.frame_content_size != 0 && - ctx.header.frame_content_size > streams->dst_len) { + ctx.header.frame_content_size > out->len) { OUT_SIZE(); } - decompress_data(streams, &ctx); + decompress_data(&ctx, out, in); free_frame_context(&ctx); } /// Takes the information provided in the header and dictionary, and initializes /// the context for this frame -static void init_frame_context(io_streams_t *const streams, - frame_context_t *const context, +static void init_frame_context(frame_context_t *const context, + istream_t *const in, const dictionary_t *const dict) { // Most fields in context are correct when initialized to 0 - memset(context, 0x00, sizeof(frame_context_t)); + memset(context, 0, sizeof(frame_context_t)); // Parse data from the frame header - parse_frame_header(&context->header, streams->src, streams->src_len); - streams->src += context->header.header_size; - streams->src_len -= context->header.header_size; + parse_frame_header(&context->header, in); // 
Set up the offset history for the repeat offset commands - context->previous_offsets[1] = 1; - context->previous_offsets[2] = 4; - context->previous_offsets[3] = 8; + context->previous_offsets[0] = 1; + context->previous_offsets[1] = 4; + context->previous_offsets[2] = 8; // Apply details from the dict if it exists frame_context_apply_dict(context, dict); @@ -460,12 +482,8 @@ static void free_frame_context(frame_context_t *const context) { } static void parse_frame_header(frame_header_t *const header, - const u8 *const src, const size_t src_len) { - if (src_len < 1) { - INP_SIZE(); - } - - const u8 descriptor = read_bits_LE(src, 8, 0); + istream_t *const in) { + const u8 descriptor = IO_read_bits(in, 8); // decode frame header descriptor into flags const u8 frame_content_size_flag = descriptor >> 6; @@ -478,28 +496,20 @@ static void parse_frame_header(frame_header_t *const header, CORRUPTION(); } - int header_size = 1; - header->single_segment_flag = single_segment_flag; header->content_checksum_flag = content_checksum_flag; // decode window size if (!single_segment_flag) { - if (src_len < header_size + 1) { - INP_SIZE(); - } - // Use the algorithm from the specification to compute window size // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor - u8 window_descriptor = src[header_size]; + u8 window_descriptor = IO_read_bits(in, 8); u8 exponent = window_descriptor >> 3; u8 mantissa = window_descriptor & 7; size_t window_base = (size_t)1 << (10 + exponent); size_t window_add = (window_base / 8) * mantissa; header->window_size = window_base + window_add; - - header_size++; } // decode dictionary id if it exists @@ -507,13 +517,7 @@ static void parse_frame_header(frame_header_t *const header, const int bytes_array[] = {0, 1, 2, 4}; const int bytes = bytes_array[dictionary_id_flag]; - if (src_len < header_size + bytes) { - INP_SIZE(); - } - - header->dictionary_id = read_bits_LE(src + header_size, bytes * 8, 0); - - header_size 
+= bytes; + header->dictionary_id = IO_read_bits(in, bytes * 8); } else { header->dictionary_id = 0; } @@ -525,17 +529,10 @@ static void parse_frame_header(frame_header_t *const header, const int bytes_array[] = {1, 2, 4, 8}; const int bytes = bytes_array[frame_content_size_flag]; - if (src_len < header_size + bytes) { - INP_SIZE(); - } - - header->frame_content_size = - read_bits_LE(src + header_size, bytes * 8, 0); + header->frame_content_size = IO_read_bits(in, bytes * 8); if (bytes == 2) { header->frame_content_size += 256; } - - header_size += bytes; } else { header->frame_content_size = 0; } @@ -546,8 +543,6 @@ static void parse_frame_header(frame_header_t *const header, // back to the dictionary or not on large offsets header->window_size = header->frame_content_size; } - - header->header_size = header_size; } /// A dictionary acts as initializing values for the frame context before @@ -559,20 +554,15 @@ static void frame_context_apply_dict(frame_context_t *const ctx, if (!dict || !dict->content) return; - if (ctx->header.dictionary_id == 0 && dict->dictionary_id != 0) { - // The dictionary is unneeded, and shouldn't be used as it may interfere - // with the default offset history - return; - } - - // If the dictionary id is 0, it doesn't matter if we provide the wrong raw - // content dict, it won't change anything + // If the requested dictionary_id is non-zero, the correct dictionary must + // be present if (ctx->header.dictionary_id != 0 && ctx->header.dictionary_id != dict->dictionary_id) { - ERROR("Wrong/no dictionary provided"); + ERROR("Wrong dictionary provided"); } - // Copy the pointer in so we can reference it in sequence execution + // Copy the dict content to the context for references during sequence + // execution ctx->dict_content = dict->content; ctx->dict_content_len = dict->content_size; @@ -592,188 +582,137 @@ static void frame_context_apply_dict(frame_context_t *const ctx, } /// Decompress the data from a frame block by block -static 
void decompress_data(io_streams_t *const streams, - frame_context_t *const ctx) { +static void decompress_data(frame_context_t *const ctx, ostream_t *const out, + istream_t *const in) { int last_block = 0; do { - if (streams->src_len < 3) { - INP_SIZE(); - } // Parse the block header - last_block = streams->src[0] & 1; - const int block_type = (streams->src[0] >> 1) & 3; - const size_t block_len = read_bits_LE(streams->src, 21, 3); - - streams->src += 3; - streams->src_len -= 3; + last_block = IO_read_bits(in, 1); + const int block_type = IO_read_bits(in, 2); + const size_t block_len = IO_read_bits(in, 21); switch (block_type) { case 0: { // Raw, uncompressed block - if (streams->src_len < block_len) { - INP_SIZE(); - } - if (streams->dst_len < block_len) { - OUT_SIZE(); - } - + const u8 *const read_ptr = IO_read_bytes(in, block_len); + u8 *const write_ptr = IO_write_bytes(out, block_len); + // // Copy the raw data into the output - memcpy(streams->dst, streams->src, block_len); - - streams->src += block_len; - streams->src_len -= block_len; - - streams->dst += block_len; - streams->dst_len -= block_len; + memcpy(write_ptr, read_ptr, block_len); ctx->current_total_output += block_len; break; } case 1: { // RLE block, repeat the first byte N times - if (streams->src_len < 1) { - INP_SIZE(); - } - if (streams->dst_len < block_len) { - OUT_SIZE(); - } + const u8 *const read_ptr = IO_read_bytes(in, 1); + u8 *const write_ptr = IO_write_bytes(out, block_len); // Copy `block_len` copies of `streams->src[0]` to the output - memset(streams->dst, streams->src[0], block_len); - - streams->dst += block_len; - streams->dst_len -= block_len; - - streams->src += 1; - streams->src_len -= 1; + memset(write_ptr, read_ptr[0], block_len); ctx->current_total_output += block_len; break; } - case 2: - // Compressed block, this is mode complex - decompress_block(streams, ctx, block_len); + case 2: { + // Compressed block + // Create a sub-stream for the block + istream_t block_stream = 
IO_make_sub_istream(in, block_len); + decompress_block(ctx, out, &block_stream); break; + } case 3: // Reserved block type CORRUPTION(); break; + default: + IMPOSSIBLE(); } } while (!last_block); if (ctx->header.content_checksum_flag) { // This program does not support checking the checksum, so skip over it // if it's present - if (streams->src_len < 4) { - INP_SIZE(); - } - streams->src += 4; - streams->src_len -= 4; + IO_advance_input(in, 4); } } /******* END FRAME DECODING ***************************************************/ /******* BLOCK DECOMPRESSION **************************************************/ -static void decompress_block(io_streams_t *const streams, frame_context_t *const ctx, - const size_t block_len) { - if (streams->src_len < block_len) { - INP_SIZE(); - } - // We need this to determine how long the compressed literals block was - const u8 *const end_of_block = streams->src + block_len; - +static void decompress_block(frame_context_t *const ctx, ostream_t *const out, + istream_t *const in) { // Part 1: decode the literals block u8 *literals = NULL; - const size_t literals_size = decode_literals(streams, ctx, &literals); + const size_t literals_size = decode_literals(ctx, in, &literals); // Part 2: decode the sequences block - if (streams->src > end_of_block) { - INP_SIZE(); - } - const size_t sequences_size = end_of_block - streams->src; sequence_command_t *sequences = NULL; const size_t num_sequences = - decode_sequences(ctx, streams->src, sequences_size, &sequences); - - streams->src += sequences_size; - streams->src_len -= sequences_size; + decode_sequences(ctx, in, &sequences); // Part 3: combine literals and sequence commands to generate output - execute_sequences(streams, ctx, sequences, num_sequences, literals, - literals_size); + execute_sequences(ctx, out, literals, literals_size, sequences, + num_sequences); free(literals); free(sequences); } /******* END BLOCK DECOMPRESSION **********************************************/ /******* 
LITERALS DECODING ****************************************************/ -static size_t decode_literals_simple(io_streams_t *const streams, - u8 **const literals, const int block_type, +static size_t decode_literals_simple(istream_t *const in, u8 **const literals, + const int block_type, const int size_format); -static size_t decode_literals_compressed(io_streams_t *const streams, - frame_context_t *const ctx, +static size_t decode_literals_compressed(frame_context_t *const ctx, + istream_t *const in, u8 **const literals, const int block_type, const int size_format); -static size_t decode_huf_table(const u8 *src, size_t src_len, - HUF_dtable *const dtable); -static size_t fse_decode_hufweights(const u8 *const src, const size_t src_len, - u8 *const weights, int *const num_symbs, - const size_t compressed_size); +static void decode_huf_table(istream_t *const in, HUF_dtable *const dtable); +static void fse_decode_hufweights(ostream_t *weights, istream_t *const in, + int *const num_symbs); -static size_t decode_literals(io_streams_t *const streams, - frame_context_t *const ctx, u8 **const literals) { - if (streams->src_len < 1) { - INP_SIZE(); - } +static size_t decode_literals(frame_context_t *const ctx, istream_t *const in, + u8 **const literals) { // Decode literals header - int block_type = streams->src[0] & 3; - int size_format = (streams->src[0] >> 2) & 3; + int block_type = IO_read_bits(in, 2); + int size_format = IO_read_bits(in, 2); if (block_type <= 1) { // Raw or RLE literals block - return decode_literals_simple(streams, literals, block_type, + return decode_literals_simple(in, literals, block_type, size_format); } else { // Huffman compressed literals - return decode_literals_compressed(streams, ctx, literals, block_type, + return decode_literals_compressed(ctx, in, literals, block_type, size_format); } } /// Decodes literals blocks in raw or RLE form -static size_t decode_literals_simple(io_streams_t *const streams, - u8 **const literals, const int 
block_type, +static size_t decode_literals_simple(istream_t *const in, u8 **const literals, + const int block_type, const int size_format) { size_t size; switch (size_format) { - // These cases are in the form X0 - // In this case, the X bit is actually part of the size field + // These cases are in the form ?0 + // In this case, the ? bit is actually part of the size field case 0: case 2: - size = read_bits_LE(streams->src, 5, 3); - streams->src += 1; - streams->src_len -= 1; + // "Size_Format uses 1 bit. Regenerated_Size uses 5 bits (0-31)." + IO_rewind_bits(in, 1); + size = IO_read_bits(in, 2); break; case 1: - if (streams->src_len < 2) { - INP_SIZE(); - } - size = read_bits_LE(streams->src, 12, 4); - streams->src += 2; - streams->src_len -= 2; + // "Size_Format uses 2 bits. Regenerated_Size uses 12 bits (0-4095)." + size = IO_read_bits(in, 12); break; case 3: - if (streams->src_len < 2) { - INP_SIZE(); - } - size = read_bits_LE(streams->src, 20, 4); - streams->src += 3; - streams->src_len -= 3; + // "Size_Format uses 2 bits. Regenerated_Size uses 20 bits (0-1048575)." 
+ size = IO_read_bits(in, 20); break; default: - // Impossible - size = -1; + // Size format is in range 0-3 + IMPOSSIBLE(); } if (size > MAX_LITERALS_SIZE) { @@ -786,32 +725,28 @@ static size_t decode_literals_simple(io_streams_t *const streams, } switch (block_type) { - case 0: + case 0: { // Raw data - if (size > streams->src_len) { - INP_SIZE(); - } - memcpy(*literals, streams->src, size); - streams->src += size; - streams->src_len -= size; + const u8 *const read_ptr = IO_read_bytes(in, size); + memcpy(*literals, read_ptr, size); break; - case 1: + } + case 1: { // Single repeated byte - if (1 > streams->src_len) { - INP_SIZE(); - } - memset(*literals, streams->src[0], size); - streams->src += 1; - streams->src_len -= 1; + const u8 *const read_ptr = IO_read_bytes(in, 1); + memset(*literals, read_ptr[0], size); break; } + default: + IMPOSSIBLE(); + } return size; } /// Decodes Huffman compressed literals -static size_t decode_literals_compressed(io_streams_t *const streams, - frame_context_t *const ctx, +static size_t decode_literals_compressed(frame_context_t *const ctx, + istream_t *const in, u8 **const literals, const int block_type, const int size_format) { @@ -820,98 +755,78 @@ static size_t decode_literals_compressed(io_streams_t *const streams, int num_streams = 4; switch (size_format) { case 0: + // "A single stream. Both Compressed_Size and Regenerated_Size use 10 + // bits (0-1023)." num_streams = 1; // Fall through as it has the same size format case 1: - if (streams->src_len < 3) { - INP_SIZE(); - } - regenerated_size = read_bits_LE(streams->src, 10, 4); - compressed_size = read_bits_LE(streams->src, 10, 14); - streams->src += 3; - streams->src_len -= 3; + // "4 streams. Both Compressed_Size and Regenerated_Size use 10 bits + // (0-1023)." 
+ regenerated_size = IO_read_bits(in, 10); + compressed_size = IO_read_bits(in, 10); break; case 2: - if (streams->src_len < 4) { - INP_SIZE(); - } - regenerated_size = read_bits_LE(streams->src, 14, 4); - compressed_size = read_bits_LE(streams->src, 14, 18); - streams->src += 4; - streams->src_len -= 4; + // "4 streams. Both Compressed_Size and Regenerated_Size use 14 bits + // (0-16383)." + regenerated_size = IO_read_bits(in, 14); + compressed_size = IO_read_bits(in, 14); break; case 3: - if (streams->src_len < 5) { - INP_SIZE(); - } - regenerated_size = read_bits_LE(streams->src, 18, 4); - compressed_size = read_bits_LE(streams->src, 18, 22); - streams->src += 5; - streams->src_len -= 5; + // "4 streams. Both Compressed_Size and Regenerated_Size use 18 bits + // (0-262143)." + regenerated_size = IO_read_bits(in, 18); + compressed_size = IO_read_bits(in, 18); break; default: // Impossible - compressed_size = regenerated_size = -1; + IMPOSSIBLE(); } if (regenerated_size > MAX_LITERALS_SIZE || compressed_size > regenerated_size) { CORRUPTION(); } - if (compressed_size > streams->src_len) { - INP_SIZE(); - } - *literals = malloc(regenerated_size); if (!*literals) { BAD_ALLOC(); } + ostream_t lit_stream = IO_make_ostream(*literals, regenerated_size); + istream_t huf_stream = IO_make_sub_istream(in, compressed_size); + if (block_type == 2) { // Decode provided Huffman table HUF_free_dtable(&ctx->literals_dtable); - const size_t size = decode_huf_table(streams->src, compressed_size, - &ctx->literals_dtable); - streams->src += size; - streams->src_len -= size; - compressed_size -= size; + decode_huf_table(&huf_stream, &ctx->literals_dtable); } else { - // If we're to repeat the previous Huffman table, make sure it exists + // If the previous Huffman table is being repeated, ensure it exists if (!ctx->literals_dtable.symbols) { CORRUPTION(); } } + size_t symbols_decoded; if (num_streams == 1) { - HUF_decompress_1stream(&ctx->literals_dtable, *literals, - 
regenerated_size, streams->src, compressed_size); + symbols_decoded = HUF_decompress_1stream(&ctx->literals_dtable, &lit_stream, &huf_stream); } else { - HUF_decompress_4stream(&ctx->literals_dtable, *literals, - regenerated_size, streams->src, compressed_size); + symbols_decoded = HUF_decompress_4stream(&ctx->literals_dtable, &lit_stream, &huf_stream); + } + + if (symbols_decoded != regenerated_size) { + CORRUPTION(); } - streams->src += compressed_size; - streams->src_len -= compressed_size; return regenerated_size; } // Decode the Huffman table description -static size_t decode_huf_table(const u8 *src, size_t src_len, - HUF_dtable *const dtable) { - if (src_len < 1) { - INP_SIZE(); - } +static void decode_huf_table(istream_t *const in, HUF_dtable *const dtable) { + const u8 header = IO_read_bits(in, 8); - const u8 *const osrc = src; - - const u8 header = src[0]; u8 weights[HUF_MAX_SYMBS]; memset(weights, 0, sizeof(weights)); - src++; - src_len--; - int num_symbs; if (header >= 128) { @@ -919,67 +834,56 @@ static size_t decode_huf_table(const u8 *src, size_t src_len, num_symbs = header - 127; const size_t bytes = (num_symbs + 1) / 2; - if (bytes > src_len) { - INP_SIZE(); - } + const u8 *const weight_src = IO_read_bytes(in, bytes); for (int i = 0; i < num_symbs; i++) { // read_bits_LE isn't applicable here because the weights are order // reversed within each byte // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#huffman-tree-header if (i % 2 == 0) { - weights[i] = src[i / 2] >> 4; + weights[i] = weight_src[i / 2] >> 4; } else { - weights[i] = src[i / 2] & 0xf; + weights[i] = weight_src[i / 2] & 0xf; } } - - src += bytes; - src_len -= bytes; } else { // The weights are FSE encoded, decode them before we can construct the // table - const size_t size = - fse_decode_hufweights(src, src_len, weights, &num_symbs, header); - src += size; - src_len -= size; + istream_t fse_stream = IO_make_sub_istream(in, header); + ostream_t weight_stream = 
IO_make_ostream(weights, HUF_MAX_SYMBS); + fse_decode_hufweights(&weight_stream, &fse_stream, &num_symbs); } // Construct the table using the decoded weights HUF_init_dtable_usingweights(dtable, weights, num_symbs); - return src - osrc; } -static size_t fse_decode_hufweights(const u8 *const src, const size_t src_len, - u8 *const weights, int *const num_symbs, - const size_t compressed_size) { +static void fse_decode_hufweights(ostream_t *weights, istream_t *const in, + int *const num_symbs) { const int MAX_ACCURACY_LOG = 7; FSE_dtable dtable; // Construct the FSE table - const size_t read = - FSE_decode_header(&dtable, src, src_len, MAX_ACCURACY_LOG); - - if (src_len < compressed_size) { - INP_SIZE(); - } + FSE_decode_header(&dtable, in, MAX_ACCURACY_LOG); // Decode the weights - *num_symbs = FSE_decompress_interleaved2( - &dtable, weights, HUF_MAX_SYMBS, src + read, compressed_size - read); + *num_symbs = FSE_decompress_interleaved2(&dtable, weights, in); FSE_free_dtable(&dtable); - - return compressed_size; } /******* END LITERALS DECODING ************************************************/ /******* SEQUENCE DECODING ****************************************************/ /// The combination of FSE states needed to decode sequences typedef struct { - u16 ll_state, of_state, ml_state; - FSE_dtable ll_table, of_table, ml_table; + FSE_dtable ll_table; + FSE_dtable of_table; + FSE_dtable ml_table; + + u16 ll_state; + u16 of_state; + u16 ml_state; } sequence_state_t; /// Different modes to signal to decode_seq_tables what to do @@ -1031,47 +935,36 @@ static const u8 SEQ_MATCH_LENGTH_EXTRA_BITS[53] = { /// Offset decoding is simpler so we just need a maximum code value static const u8 SEQ_MAX_CODES[3] = {35, -1, 52}; -static void decompress_sequences(frame_context_t *const ctx, const u8 *src, - size_t src_len, +static void decompress_sequences(frame_context_t *const ctx, + istream_t *const in, sequence_command_t *const sequences, const size_t num_sequences); static 
sequence_command_t decode_sequence(sequence_state_t *const state, const u8 *const src, i64 *const offset); -static size_t decode_seq_table(const u8 *src, size_t src_len, - FSE_dtable *const table, const seq_part_t type, - const seq_mode_t mode); +static void decode_seq_table(istream_t *const in, FSE_dtable *const table, + const seq_part_t type, const seq_mode_t mode); -static size_t decode_sequences(frame_context_t *const ctx, const u8 *src, - size_t src_len, +static size_t decode_sequences(frame_context_t *const ctx, istream_t *in, sequence_command_t **const sequences) { size_t num_sequences; // Decode the sequence header and allocate space for the output - if (src_len < 1) { - INP_SIZE(); - } - if (src[0] == 0) { + u8 header = IO_read_bits(in, 8); + if (header == 0) { + // "There are no sequences. The sequence section stops there. + // Regenerated content is defined entirely by literals section." *sequences = NULL; return 0; - } else if (src[0] < 128) { - num_sequences = src[0]; - src++; - src_len--; - } else if (src[0] < 255) { - if (src_len < 2) { - INP_SIZE(); - } - num_sequences = ((src[0] - 128) << 8) + src[1]; - src += 2; - src_len -= 2; + } else if (header < 128) { + // "Number_of_Sequences = byte0 . Uses 1 byte." + num_sequences = header; + } else if (header < 255) { + // "Number_of_Sequences = ((byte0-128) << 8) + byte1 . Uses 2 bytes." + num_sequences = ((header - 128) << 8) + IO_read_bits(in, 8); } else { - if (src_len < 3) { - INP_SIZE(); - } - num_sequences = src[1] + ((u64)src[2] << 8) + 0x7F00; - src += 3; - src_len -= 3; + // "Number_of_Sequences = byte1 + (byte2<<8) + 0x7F00 . Uses 3 bytes." 
+ num_sequences = IO_read_bits(in, 16) + 0x7F00; } *sequences = malloc(num_sequences * sizeof(sequence_command_t)); @@ -1079,51 +972,29 @@ static size_t decode_sequences(frame_context_t *const ctx, const u8 *src, BAD_ALLOC(); } - decompress_sequences(ctx, src, src_len, *sequences, num_sequences); + decompress_sequences(ctx, in, *sequences, num_sequences); return num_sequences; } /// Decompress the FSE encoded sequence commands -static void decompress_sequences(frame_context_t *const ctx, const u8 *src, - size_t src_len, +static void decompress_sequences(frame_context_t *const ctx, istream_t *in, sequence_command_t *const sequences, const size_t num_sequences) { - if (src_len < 1) { - INP_SIZE(); - } - u8 compression_modes = src[0]; - src++; - src_len--; + u8 compression_modes = IO_read_bits(in, 8); if ((compression_modes & 3) != 0) { CORRUPTION(); } - { - size_t read; - // Update the tables we have stored in the context - read = decode_seq_table(src, src_len, &ctx->ll_dtable, - seq_literal_length, - (compression_modes >> 6) & 3); - src += read; - src_len -= read; - } + // Update the tables we have stored in the context + decode_seq_table(in, &ctx->ll_dtable, seq_literal_length, + (compression_modes >> 6) & 3); - { - const size_t read = - decode_seq_table(src, src_len, &ctx->of_dtable, seq_offset, - (compression_modes >> 4) & 3); - src += read; - src_len -= read; - } + decode_seq_table(in, &ctx->of_dtable, seq_offset, + (compression_modes >> 4) & 3); - { - const size_t read = decode_seq_table(src, src_len, &ctx->ml_dtable, - seq_match_length, - (compression_modes >> 2) & 3); - src += read; - src_len -= read; - } + decode_seq_table(in, &ctx->ml_dtable, seq_match_length, + (compression_modes >> 2) & 3); // Check to make sure none of the tables are uninitialized if (!ctx->ll_dtable.symbols || !ctx->of_dtable.symbols || @@ -1137,8 +1008,13 @@ static void decompress_sequences(frame_context_t *const ctx, const u8 *src, memcpy(&state.of_table, &ctx->of_dtable, 
sizeof(FSE_dtable)); memcpy(&state.ml_table, &ctx->ml_dtable, sizeof(FSE_dtable)); - const int padding = 8 - log2inf(src[src_len - 1]); - i64 offset = src_len * 8 - padding; + size_t len = IO_istream_len(in); + const u8 *const src = IO_read_bytes(in, len); + + // "After writing the last bit containing information, the compressor writes + // a single 1-bit and then fills the byte with 0-7 0 bits of padding." + const int padding = 8 - log2inf(src[len - 1]); + i64 offset = len * 8 - padding; FSE_init_state(&state.ll_table, &state.ll_state, src, &offset); FSE_init_state(&state.of_table, &state.of_state, src, &offset); @@ -1153,7 +1029,7 @@ static void decompress_sequences(frame_context_t *const ctx, const u8 *src, CORRUPTION(); } - // Don't free our tables so they can be used in the next block + // Don't free tables so they can be used in the next block } // Decode a single sequence and update the state @@ -1194,9 +1070,8 @@ static sequence_command_t decode_sequence(sequence_state_t *const state, } /// Given a sequence part and table mode, decode the FSE distribution -static size_t decode_seq_table(const u8 *src, size_t src_len, - FSE_dtable *const table, const seq_part_t type, - const seq_mode_t mode) { +static void decode_seq_table(istream_t *const in, FSE_dtable *const table, + const seq_part_t type, const seq_mode_t mode) { // Constant arrays indexed by seq_part_t const i16 *const default_distributions[] = {SEQ_LITERAL_LENGTH_DEFAULT_DIST, SEQ_OFFSET_DEFAULT_DIST, @@ -1207,7 +1082,7 @@ static size_t decode_seq_table(const u8 *src, size_t src_len, const size_t max_accuracies[] = {9, 8, 9}; if (mode != seq_repeat) { - // ree old one before overwriting + // Free old one before overwriting FSE_free_dtable(table); } @@ -1218,102 +1093,102 @@ static size_t decode_seq_table(const u8 *src, size_t src_len, const size_t accuracy_log = default_distribution_accuracies[type]; FSE_init_dtable(table, distribution, symbs, accuracy_log); - - return 0; + break; } case seq_rle: { - 
if (src_len < 1) { - INP_SIZE(); - } - const u8 symb = src[0]; - src++; - src_len--; + const u8 symb = IO_read_bits(in, 8); FSE_init_dtable_rle(table, symb); - - return 1; + break; } case seq_fse: { - size_t read = - FSE_decode_header(table, src, src_len, max_accuracies[type]); - src += read; - src_len -= read; - - return read; + FSE_decode_header(table, in, max_accuracies[type]); + break; } case seq_repeat: - // Don't have to do anything here as we're not changing the table - return 0; + // Nothing to do here, table will be unchanged + break; default: // Impossible, as mode is from 0-3 - return -1; + IMPOSSIBLE(); + break; } } /******* END SEQUENCE DECODING ************************************************/ /******* SEQUENCE EXECUTION ***************************************************/ -static void execute_sequences(io_streams_t *const streams, - frame_context_t *const ctx, +static void execute_sequences(frame_context_t *const ctx, ostream_t *const out, + const u8 *const literals, + const size_t literals_len, const sequence_command_t *const sequences, - const size_t num_sequences, - const u8 *literals, - size_t literals_len) { + const size_t num_sequences) { + istream_t litstream = IO_make_istream(literals, literals_len); + u64 *const offset_hist = ctx->previous_offsets; size_t total_output = ctx->current_total_output; for (size_t i = 0; i < num_sequences; i++) { const sequence_command_t seq = sequences[i]; - if (seq.literal_length > literals_len) { - CORRUPTION(); + { + if (seq.literal_length > IO_istream_len(&litstream)) { + CORRUPTION(); + } + + u8 *const write_ptr = IO_write_bytes(out, seq.literal_length); + const u8 *const read_ptr = + IO_read_bytes(&litstream, seq.literal_length); + // Copy literals to output + memcpy(write_ptr, read_ptr, seq.literal_length); + + total_output += seq.literal_length; } - if (streams->dst_len < seq.literal_length + seq.match_length) { - OUT_SIZE(); - } - // Copy literals to output - memcpy(streams->dst, literals, 
seq.literal_length); - - literals += seq.literal_length; - literals_len -= seq.literal_length; - - streams->dst += seq.literal_length; - streams->dst_len -= seq.literal_length; - - total_output += seq.literal_length; - size_t offset; // Offsets are special, we need to handle the repeat offsets if (seq.offset <= 3) { - u32 idx = seq.offset; + // "The first 3 values define a repeated offset and we will call + // them Repeated_Offset1, Repeated_Offset2, and Repeated_Offset3. + // They are sorted in recency order, with Repeated_Offset1 meaning + // 'most recent one'". + + // Use 0 indexing for the array + u32 idx = seq.offset - 1; if (seq.literal_length == 0) { - // Special case when literal length is 0 + // "There is an exception though, when current sequence's + // literals length is 0. In this case, repeated offsets are + // shifted by one, so Repeated_Offset1 becomes Repeated_Offset2, + // Repeated_Offset2 becomes Repeated_Offset3, and + // Repeated_Offset3 becomes Repeated_Offset1 - 1_byte." idx++; } - if (idx == 1) { - offset = offset_hist[1]; + if (idx == 0) { + offset = offset_hist[0]; } else { - // If idx == 4 then literal length was 0 and the offset was 3 - offset = idx < 4 ? offset_hist[idx] : offset_hist[1] - 1; + // If idx == 3 then literal length was 0 and the offset was 3, + // as per the exception listed above + offset = idx < 3 ? 
offset_hist[idx] : offset_hist[0] - 1; - // If idx == 2 we don't need to modify offset_hist[3] - if (idx > 2) { - offset_hist[3] = offset_hist[2]; + // If idx == 1 we don't need to modify offset_hist[2] + if (idx > 1) { + offset_hist[2] = offset_hist[1]; } - offset_hist[2] = offset_hist[1]; - offset_hist[1] = offset; + offset_hist[1] = offset_hist[0]; + offset_hist[0] = offset; } } else { offset = seq.offset - 3; // Shift back history - offset_hist[3] = offset_hist[2]; offset_hist[2] = offset_hist[1]; - offset_hist[1] = offset; + offset_hist[1] = offset_hist[0]; + offset_hist[0] = offset; } size_t match_length = seq.match_length; + + u8 *write_ptr = IO_write_bytes(out, match_length); if (total_output <= ctx->header.window_size) { // In this case offset might go back into the dictionary if (offset > total_output + ctx->dict_content_len) { @@ -1322,13 +1197,16 @@ static void execute_sequences(io_streams_t *const streams, } if (offset > total_output) { + // "The rest of the dictionary is its content. The content act + // as a "past" in front of data to compress or decompress, so it + // can be referenced in sequence commands." 
const size_t dict_copy = MIN(offset - total_output, match_length); const size_t dict_offset = ctx->dict_content_len - (offset - total_output); - for (size_t i = 0; i < dict_copy; i++) { - *streams->dst++ = ctx->dict_content[dict_offset + i]; - } + + memcpy(write_ptr, ctx->dict_content + dict_offset, dict_copy); + write_ptr += dict_copy; match_length -= dict_copy; } } else if (offset > ctx->header.window_size) { @@ -1340,31 +1218,29 @@ static void execute_sequences(io_streams_t *const streams, // ex: if the output so far was "abc", a command with offset=3 and // match_length=6 would produce "abcabcabc" as the new output for (size_t i = 0; i < match_length; i++) { - *streams->dst = *(streams->dst - offset); - streams->dst++; + *write_ptr = *(write_ptr - offset); + write_ptr++; } - streams->dst_len -= seq.match_length; total_output += seq.match_length; } - if (streams->dst_len < literals_len) { - OUT_SIZE(); - } - // Copy any leftover literals - memcpy(streams->dst, literals, literals_len); - streams->dst += literals_len; - streams->dst_len -= literals_len; + { + size_t len = IO_istream_len(&litstream); + u8 *const write_ptr = IO_write_bytes(out, len); + const u8 *const read_ptr = IO_read_bytes(&litstream, len); + // Copy any leftover literals + memcpy(write_ptr, read_ptr, len); - total_output += literals_len; + total_output += len; + } ctx->current_total_output = total_output; } /******* END SEQUENCE EXECUTION ***********************************************/ /******* OUTPUT SIZE COUNTING *************************************************/ -size_t traverse_frame(const frame_header_t *const header, const u8 *src, - size_t src_len); +static void traverse_frame(const frame_header_t *const header, istream_t *const in); /// Get the decompressed size of an input stream so memory can be allocated in /// advance. 
@@ -1372,115 +1248,75 @@ size_t traverse_frame(const frame_header_t *const header, const u8 *src, /// implementation, as this API allows for the decompression of multiple /// concatenated frames. size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) { - const u8 *ip = (const u8 *) src; - size_t ip_len = src_len; - size_t dst_size = 0; + istream_t in = IO_make_istream(src, src_len); + size_t dst_size = 0; - // Each frame header only gives us the size of its frame, so iterate over all - // frames - while (ip_len > 0) { - if (ip_len < 4) { - INP_SIZE(); + // Each frame header only gives us the size of its frame, so iterate over + // all + // frames + while (IO_istream_len(&in) > 0) { + const u32 magic_number = IO_read_bits(&in, 32); + + if ((magic_number & ~0xFU) == 0x184D2A50U) { + // skippable frame, this has no impact on output size + const size_t frame_size = IO_read_bits(&in, 32); + IO_advance_input(&in, frame_size); + } else if (magic_number == 0xFD2FB528U) { + // ZSTD frame + frame_header_t header; + parse_frame_header(&header, &in); + + if (header.frame_content_size == 0 && !header.single_segment_flag) { + // Content size not provided, we can't tell + return -1; + } + + dst_size += header.frame_content_size; + + // Consume the input from the frame to reach the start of the next + traverse_frame(&header, &in); + } else { + // not a real frame + ERROR("Invalid magic number"); + } } - const u32 magic_number = read_bits_LE(ip, 32, 0); - - ip += 4; - ip_len -= 4; - if (magic_number >= 0x184D2A50U && magic_number <= 0x184D2A5F) { - // skippable frame, this has no impact on output size - if (ip_len < 4) { - INP_SIZE(); - } - const size_t frame_size = read_bits_LE(ip, 32, 32); - - if (ip_len < 4 + frame_size) { - INP_SIZE(); - } - - // skip over frame - ip += 4 + frame_size; - ip_len -= 4 + frame_size; - } else if (magic_number == 0xFD2FB528U) { - // ZSTD frame - frame_header_t header; - parse_frame_header(&header, ip, ip_len); - - if 
(header.frame_content_size == 0 && !header.single_segment_flag) { - // Content size not provided, we can't tell - return -1; - } - - dst_size += header.frame_content_size; - - // we need to traverse the frame to find when the next one starts - const size_t traversed = traverse_frame(&header, ip, ip_len); - ip += traversed; - ip_len -= traversed; - } else { - // not a real frame - ERROR("Invalid magic number"); - } - } - - return dst_size; + return dst_size; } /// Iterate over each block in a frame to find the end of it, to get to the /// start of the next frame -size_t traverse_frame(const frame_header_t *const header, const u8 *src, - size_t src_len) { - const u8 *const src_beg = src; - const u8 *const src_end = src + src_len; - src += header->header_size; - src_len += header->header_size; - +static void traverse_frame(const frame_header_t *const header, istream_t *const in) { int last_block = 0; do { - if (src + 3 > src_end) { - INP_SIZE(); - } // Parse the block header - last_block = src[0] & 1; - const int block_type = (src[0] >> 1) & 3; - const size_t block_len = read_bits_LE(src, 21, 3); + last_block = IO_read_bits(in, 1); + const int block_type = IO_read_bits(in, 2); + const size_t block_len = IO_read_bits(in, 21); - src += 3; switch (block_type) { case 0: // Raw block, block_len bytes - if (src + block_len > src_end) { - INP_SIZE(); - } - src += block_len; + IO_advance_input(in, block_len); break; case 1: // RLE block, 1 byte - if (src + 1 > src_end) { - INP_SIZE(); - } - src++; + IO_advance_input(in, 1); break; case 2: // Compressed block, compressed size is block_len - if (src + block_len > src_end) { - INP_SIZE(); - } - src += block_len; + IO_advance_input(in, block_len); break; case 3: // Reserved block type CORRUPTION(); break; + default: + IMPOSSIBLE(); } } while (!last_block); if (header->content_checksum_flag) { - if (src + 4 > src_end) { - INP_SIZE(); - } - src += 4; + IO_advance_input(in, 4); } - - return src - src_beg; } /******* END OUTPUT SIZE 
COUNTING *********************************************/ @@ -1495,68 +1331,46 @@ static void parse_dictionary(dictionary_t *const dict, const u8 *src, if (src_len < 8) { INP_SIZE(); } - const u32 magic_number = read_bits_LE(src, 32, 0); + + istream_t in = IO_make_istream(src, src_len); + + const u32 magic_number = IO_read_bits(&in, 32); if (magic_number != 0xEC30A437) { // raw content dict init_raw_content_dict(dict, src, src_len); return; } - dict->dictionary_id = read_bits_LE(src, 32, 32); - src += 8; - src_len -= 8; + dict->dictionary_id = IO_read_bits(&in, 32); // Parse the provided entropy tables in order - { - const size_t read = - decode_huf_table(src, src_len, &dict->literals_dtable); - src += read; - src_len -= read; - } - { - const size_t read = decode_seq_table(src, src_len, &dict->of_dtable, - seq_offset, seq_fse); - src += read; - src_len -= read; - } - { - const size_t read = decode_seq_table(src, src_len, &dict->ml_dtable, - seq_match_length, seq_fse); - src += read; - src_len -= read; - } - { - const size_t read = decode_seq_table(src, src_len, &dict->ll_dtable, - seq_literal_length, seq_fse); - src += read; - src_len -= read; - } + decode_huf_table(&in, &dict->literals_dtable); + decode_seq_table(&in, &dict->of_dtable, seq_offset, seq_fse); + decode_seq_table(&in, &dict->ml_dtable, seq_match_length, seq_fse); + decode_seq_table(&in, &dict->ll_dtable, seq_literal_length, seq_fse); - if (src_len < 12) { - INP_SIZE(); - } // Read in the previous offset history - dict->previous_offsets[1] = read_bits_LE(src, 32, 0); - dict->previous_offsets[2] = read_bits_LE(src, 32, 32); - dict->previous_offsets[3] = read_bits_LE(src, 32, 64); - - src += 12; - src_len -= 12; + dict->previous_offsets[0] = IO_read_bits(&in, 32); + dict->previous_offsets[1] = IO_read_bits(&in, 32); + dict->previous_offsets[2] = IO_read_bits(&in, 32); // Ensure the provided offsets aren't too large - for (int i = 1; i <= 3; i++) { + for (int i = 0; i < 3; i++) { if 
(dict->previous_offsets[i] > src_len) { ERROR("Dictionary corrupted"); } } + // The rest is the content - dict->content = malloc(src_len); + dict->content_size = IO_istream_len(&in); + dict->content = malloc(dict->content_size); if (!dict->content) { BAD_ALLOC(); } - dict->content_size = src_len; - memcpy(dict->content, src, src_len); + const u8 *const content = IO_read_bytes(&in, dict->content_size); + + memcpy(dict->content, content, dict->content_size); } /// If parse_dictionary is given a raw content dictionary, it delegates here @@ -1586,6 +1400,143 @@ static void free_dictionary(dictionary_t *const dict) { } /******* END DICTIONARY PARSING ***********************************************/ +/******* IO STREAM OPERATIONS *************************************************/ +#define UNALIGNED() ERROR("Attempting to operate on a non-byte aligned stream") +/// Reads `num` bits from a bitstream, and updates the internal offset +static inline u64 IO_read_bits(istream_t *const in, const int num) { + if (num > 64) { + return -1; + } + + const size_t bytes = (num + in->bit_offset + 7) / 8; + const size_t full_bytes = (num + in->bit_offset) / 8; + if (bytes > in->len) { + INP_SIZE(); + } + + const u64 result = read_bits_LE(in->ptr, num, in->bit_offset); + + in->bit_offset = (num + in->bit_offset) % 8; + in->ptr += full_bytes; + in->len -= full_bytes; + + return result; +} + +/// Rewinds the stream by `num` bits so that they can be read again, stepping +/// `ptr` and `len` back over byte boundaries as needed +static inline void IO_rewind_bits(istream_t *const in, int num) { + if (num < 0) { + ERROR("Attempting to rewind stream by a negative number of bits"); + } + + const int new_offset = in->bit_offset - num; + const i64 bytes = (new_offset - 7) / 8; + + in->ptr += bytes; + in->len -= bytes; + in->bit_offset = ((new_offset % 8) + 8) % 8; +} + +/// If the remaining bits in a byte will be unused, advance to the end of the +/// byte +static inline void IO_align_stream(istream_t *const in) { + if 
(in->bit_offset != 0) { + if (in->len == 0) { + INP_SIZE(); + } + in->ptr++; + in->len--; + in->bit_offset = 0; + } +} + +/// Write the given byte into the output stream +static inline void IO_write_byte(ostream_t *const out, u8 symb) { + if (out->len == 0) { + OUT_SIZE(); + } + + out->ptr[0] = symb; + out->ptr++; + out->len--; +} + +/// Returns the number of bytes left to be read in this stream. The stream must +/// be byte aligned. +static inline size_t IO_istream_len(const istream_t *const in) { + return in->len; +} + +/// Returns a pointer where `len` bytes can be read, and advances the internal +/// state. The stream must be byte aligned. +static inline const u8 *IO_read_bytes(istream_t *const in, size_t len) { + if (len > in->len) { + INP_SIZE(); + } + if (in->bit_offset != 0) { + UNALIGNED(); + } + const u8 *const ptr = in->ptr; + in->ptr += len; + in->len -= len; + + return ptr; +} +/// Returns a pointer to write `len` bytes to, and advances the internal state +static inline u8 *IO_write_bytes(ostream_t *const out, size_t len) { + if (len > out->len) { + OUT_SIZE(); + } + u8 *const ptr = out->ptr; + out->ptr += len; + out->len -= len; + + return ptr; +} + +/// Advance the inner state by `len` bytes +static inline void IO_advance_input(istream_t *const in, size_t len) { + if (len > in->len) { + INP_SIZE(); + } + if (in->bit_offset != 0) { + UNALIGNED(); + } + + in->ptr += len; + in->len -= len; +} + +/// Returns an `ostream_t` constructed from the given pointer and length +static inline ostream_t IO_make_ostream(u8 *out, size_t len) { + return (ostream_t) { out, len }; +} + +/// Returns an `istream_t` constructed from the given pointer and length +static inline istream_t IO_make_istream(const u8 *in, size_t len) { + return (istream_t) { in, 0, len }; +} + +/// Returns an `istream_t` with the same base as `in`, and length `len` +/// Then, advance `in` to account for the consumed bytes +/// `in` must be byte aligned +static inline istream_t 
IO_make_sub_istream(istream_t *const in, size_t len) { + if (len > in->len) { + INP_SIZE(); + } + if (in->bit_offset != 0) { + UNALIGNED(); + } + const istream_t sub = { in->ptr, in->bit_offset, len }; + + in->ptr += len; + in->len -= len; + + return sub; +} +/******* END IO STREAM OPERATIONS *********************************************/ + /******* BITSTREAM OPERATIONS *************************************************/ /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits static inline u64 read_bits_LE(const u8 *src, const int num, @@ -1676,28 +1627,29 @@ static inline void HUF_init_state(const HUF_dtable *const dtable, *state = STREAM_read_bits(src, bits, offset); } -static size_t HUF_decompress_1stream(const HUF_dtable *const dtable, u8 *dst, - const size_t dst_len, const u8 *src, - size_t src_len) { - const u8 *const dst_max = dst + dst_len; - const u8 *const odst = dst; +static size_t HUF_decompress_1stream(const HUF_dtable *const dtable, + ostream_t *const out, + istream_t *const in) { + const size_t len = IO_istream_len(in); + if (len == 0) { + INP_SIZE(); + } + const u8 *const src = IO_read_bytes(in, len); // To maintain similarity with FSE, start from the end // Find the last 1 bit - const int padding = 8 - log2inf(src[src_len - 1]); + const int padding = 8 - log2inf(src[len - 1]); - i64 offset = src_len * 8 - padding; + i64 offset = len * 8 - padding; u16 state; HUF_init_state(dtable, &state, src, &offset); - while (dst < dst_max && offset > -dtable->max_bits) { + size_t symbols_written = 0; + while (offset > -dtable->max_bits) { // Iterate over the stream, decoding one symbol at a time - *dst++ = HUF_decode_symbol(dtable, &state, src, &offset); - } - // If we stopped before consuming all the input, we didn't have enough space - if (dst == dst_max && offset > -dtable->max_bits) { - OUT_SIZE(); + IO_write_byte(out, HUF_decode_symbol(dtable, &state, src, &offset)); + symbols_written++; } // When all symbols have been decoded, the 
final state value shouldn't have @@ -1709,50 +1661,30 @@ static size_t HUF_decompress_1stream(const HUF_dtable *const dtable, u8 *dst, CORRUPTION(); } - return dst - odst; + return symbols_written; } -static size_t HUF_decompress_4stream(const HUF_dtable *const dtable, u8 *dst, - const size_t dst_len, const u8 *const src, - const size_t src_len) { - if (src_len < 6) { - INP_SIZE(); - } +static size_t HUF_decompress_4stream(const HUF_dtable *const dtable, + ostream_t *const out, istream_t *const in) { + const size_t csize1 = IO_read_bits(in, 16); + const size_t csize2 = IO_read_bits(in, 16); + const size_t csize3 = IO_read_bits(in, 16); - const u8 *const src1 = src + 6; - const u8 *const src2 = src1 + read_bits_LE(src, 16, 0); - const u8 *const src3 = src2 + read_bits_LE(src, 16, 16); - const u8 *const src4 = src3 + read_bits_LE(src, 16, 32); - const u8 *const src_end = src + src_len; - - // We can't test with all 4 sizes because the 4th size is a function of the - // other 3 and the provided length - if (src4 - src >= src_len) { - INP_SIZE(); - } - - const size_t segment_size = (dst_len + 3) / 4; - u8 *const dst1 = dst; - u8 *const dst2 = dst1 + segment_size; - u8 *const dst3 = dst2 + segment_size; - u8 *const dst4 = dst3 + segment_size; - u8 *const dst_end = dst + dst_len; - - size_t total_out = 0; + istream_t in1 = IO_make_sub_istream(in, csize1); + istream_t in2 = IO_make_sub_istream(in, csize2); + istream_t in3 = IO_make_sub_istream(in, csize3); + istream_t in4 = IO_make_sub_istream(in, IO_istream_len(in)); + size_t total_output = 0; // Decode each stream independently for simplicity // If we wanted to we could decode all 4 at the same time for speed, // utilizing more execution units - total_out += HUF_decompress_1stream(dtable, dst1, segment_size, src1, - src2 - src1); - total_out += HUF_decompress_1stream(dtable, dst2, segment_size, src2, - src3 - src2); - total_out += HUF_decompress_1stream(dtable, dst3, segment_size, src3, - src4 - src3); - total_out += 
HUF_decompress_1stream(dtable, dst4, dst_end - dst4, src4, - src_end - src4); + total_output += HUF_decompress_1stream(dtable, out, &in1); + total_output += HUF_decompress_1stream(dtable, out, &in2); + total_output += HUF_decompress_1stream(dtable, out, &in3); + total_output += HUF_decompress_1stream(dtable, out, &in4); - return total_out; + return total_output; } static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits, @@ -1827,6 +1759,10 @@ static void HUF_init_dtable_usingweights(HUF_dtable *const table, u64 weight_sum = 0; for (int i = 0; i < num_symbs; i++) { + // Weights are in the same range as bit count + if (weights[i] > HUF_MAX_BITS) { + CORRUPTION(); + } weight_sum += weights[i] > 0 ? (u64)1 << (weights[i] - 1) : 0; } @@ -1913,20 +1849,17 @@ static inline void FSE_init_state(const FSE_dtable *const dtable, } static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable, - u8 *dst, const size_t dst_len, - const u8 *const src, - const size_t src_len) { - if (src_len == 0) { + ostream_t *const out, + istream_t *const in) { + const size_t len = IO_istream_len(in); + if (len == 0) { INP_SIZE(); } - - const u8 *const dst_max = dst + dst_len; - const u8 *const odst = dst; + const u8 *const src = IO_read_bytes(in, len); // Find the last 1 bit - const int padding = 8 - log2inf(src[src_len - 1]); - - i64 offset = src_len * 8 - padding; + const int padding = 8 - log2inf(src[len - 1]); + i64 offset = len * 8 - padding; // The end of the stream contains the 2 states, in this order u16 state1, state2; @@ -1936,30 +1869,28 @@ static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable, // Decode until we overflow the stream // Since we decode in reverse order, overflowing the stream is offset going // negative + size_t symbols_written = 0; while (1) { - if (dst > dst_max - 2) { - OUT_SIZE(); - } - *dst++ = FSE_decode_symbol(dtable, &state1, src, &offset); + IO_write_byte(out, FSE_decode_symbol(dtable, &state1, src, &offset)); + 
symbols_written++; if (offset < 0) { // There's still a symbol to decode in state2 - *dst++ = FSE_peek_symbol(dtable, state2); + IO_write_byte(out, FSE_peek_symbol(dtable, state2)); + symbols_written++; break; } - if (dst > dst_max - 2) { - OUT_SIZE(); - } - *dst++ = FSE_decode_symbol(dtable, &state2, src, &offset); + IO_write_byte(out, FSE_decode_symbol(dtable, &state2, src, &offset)); + symbols_written++; if (offset < 0) { // There's still a symbol to decode in state1 - *dst++ = FSE_peek_symbol(dtable, state1); + IO_write_byte(out, FSE_peek_symbol(dtable, state1)); + symbols_written++; break; } } - // Number of symbols read - return dst - odst; + return symbols_written; } static void FSE_init_dtable(FSE_dtable *const dtable, @@ -2042,17 +1973,13 @@ static void FSE_init_dtable(FSE_dtable *const dtable, /// Decode an FSE header as defined in the Zstandard format specification and /// use the decoded frequencies to initialize a decoding table. -static size_t FSE_decode_header(FSE_dtable *const dtable, const u8 *const src, - const size_t src_len, +static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in, const int max_accuracy_log) { if (max_accuracy_log > FSE_MAX_ACCURACY_LOG) { ERROR("FSE accuracy too large"); } - if (src_len < 1) { - INP_SIZE(); - } - const int accuracy_log = 5 + read_bits_LE(src, 4, 0); + const int accuracy_log = 5 + IO_read_bits(in, 4); if (accuracy_log > max_accuracy_log) { ERROR("FSE accuracy too large"); } @@ -2062,14 +1989,11 @@ static size_t FSE_decode_header(FSE_dtable *const dtable, const u8 *const src, i16 frequencies[FSE_MAX_SYMBS]; int symb = 0; - // Offset of 4 because 4 bits were already read in for accuracy - size_t offset = 4; while (remaining > 1 && symb < FSE_MAX_SYMBS) { // Log of the number of possible values we could read int bits = log2inf(remaining) + 1; - u16 val = read_bits_LE(src, bits, offset); - offset += bits; + u16 val = IO_read_bits(in, bits); // Try to mask out the lower bits to see if it 
qualifies for the "small // value" threshold @@ -2077,7 +2001,7 @@ static size_t FSE_decode_header(FSE_dtable *const dtable, const u8 *const src, const u16 threshold = ((u16)1 << bits) - 1 - remaining; if ((val & lower_mask) < threshold) { - offset--; + IO_rewind_bits(in, 1); val = val & lower_mask; } else if (val > lower_mask) { val = val - threshold; @@ -2093,22 +2017,21 @@ static size_t FSE_decode_header(FSE_dtable *const dtable, const u8 *const src, // Handle the special probability = 0 case if (proba == 0) { // Read the next two bits to see how many more 0s - int repeat = read_bits_LE(src, 2, offset); - offset += 2; + int repeat = IO_read_bits(in, 2); while (1) { for (int i = 0; i < repeat && symb < FSE_MAX_SYMBS; i++) { frequencies[symb++] = 0; } if (repeat == 3) { - repeat = read_bits_LE(src, 2, offset); - offset += 2; + repeat = IO_read_bits(in, 2); } else { break; } } } } + IO_align_stream(in); if (remaining != 1 || symb >= FSE_MAX_SYMBS) { CORRUPTION(); @@ -2116,8 +2039,6 @@ static size_t FSE_decode_header(FSE_dtable *const dtable, const u8 *const src, // Initialize the decoding table using the determined weights FSE_init_dtable(dtable, frequencies, symb, accuracy_log); - - return (offset + 7) / 8; } static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb) {