From 9700f925837d819896ee02bd7362b718e6f13916 Mon Sep 17 00:00:00 2001 From: Sean Purcell Date: Mon, 30 Jan 2017 11:42:45 -0800 Subject: [PATCH] Add educational decoder to /contrib --- contrib/educational_decoder/README.md | 18 + contrib/educational_decoder/harness.c | 93 + contrib/educational_decoder/zstd_decompress.c | 2096 +++++++++++++++++ contrib/educational_decoder/zstd_decompress.h | 6 + 4 files changed, 2213 insertions(+) create mode 100644 contrib/educational_decoder/README.md create mode 100644 contrib/educational_decoder/harness.c create mode 100644 contrib/educational_decoder/zstd_decompress.c create mode 100644 contrib/educational_decoder/zstd_decompress.h diff --git a/contrib/educational_decoder/README.md b/contrib/educational_decoder/README.md new file mode 100644 index 00000000..a1f703f6 --- /dev/null +++ b/contrib/educational_decoder/README.md @@ -0,0 +1,18 @@ +Educational Decoder +=================== + +`zstd_decompress.c` is a self-contained implementation of a decoder according +to the Zstandard format specification written in C99. +While it does not implement as many features as the reference decoder, +such as the streaming API or content checksums, it is written to be easy to +follow and understand, to help understand how the Zstandard format works. +It's laid out to match the [format specification], +so it can be used to understand how confusing segments could be implemented. +It also contains implementations of Huffman and FSE table decoding. 
+ +[format specification]: https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md + +`harness.c` provides a simple test harness around the decoder: + + harness [dictionary] + diff --git a/contrib/educational_decoder/harness.c b/contrib/educational_decoder/harness.c new file mode 100644 index 00000000..6f4765d9 --- /dev/null +++ b/contrib/educational_decoder/harness.c @@ -0,0 +1,93 @@ +#include +#include + +#include "zstd_decompress.h" + +typedef unsigned char u8; + +// There's no good way to determine output size without decompressing +// For this example assume we'll never decompress at a ratio larger than 16 +#define MAX_COMPRESSION_RATIO (16) + +u8 *input; +u8 *output; +u8 *dict; + +size_t read_file(const char *path, u8 **ptr) { + FILE *f = fopen(path, "rb"); + if (!f) { + fprintf(stderr, "failed to open file %s\n", path); + exit(1); + } + + fseek(f, 0L, SEEK_END); + size_t size = ftell(f); + rewind(f); + + *ptr = malloc(size); + if (!ptr) { + fprintf(stderr, "failed to allocate memory to hold %s\n", path); + exit(1); + } + + size_t pos = 0; + while (!feof(f)) { + size_t read = fread(&(*ptr)[pos], 1, size, f); + if (ferror(f)) { + fprintf(stderr, "error while reading file %s\n", path); + exit(1); + } + pos += read; + } + + fclose(f); + + return pos; +} + +void write_file(const char *path, const u8 *ptr, size_t size) { + FILE *f = fopen(path, "wb"); + + size_t written = 0; + while (written < size) { + written += fwrite(&ptr[written], 1, size, f); + if (ferror(f)) { + fprintf(stderr, "error while writing file %s\n", path); + exit(1); + } + } + + fclose(f); +} + +int main(int argc, char **argv) { + if (argc < 3) { + fprintf(stderr, "usage: %s [dictionary]\n", argv[0]); + + return 1; + } + + size_t input_size = read_file(argv[1], &input); + size_t dict_size = 0; + if (argc >= 4) { + dict_size = read_file(argv[3], &dict); + } + + output = malloc(MAX_COMPRESSION_RATIO * input_size); + if (!output) { + fprintf(stderr, "failed to allocate memory\n"); + 
return 1; + } + + size_t decompressed = + ZSTD_decompress_with_dict(output, input_size * MAX_COMPRESSION_RATIO, + input, input_size, dict, dict_size); + + write_file(argv[2], output, decompressed); + + free(input); + free(output); + free(dict); + input = output = dict = NULL; +} + diff --git a/contrib/educational_decoder/zstd_decompress.c b/contrib/educational_decoder/zstd_decompress.c new file mode 100644 index 00000000..8dc15900 --- /dev/null +++ b/contrib/educational_decoder/zstd_decompress.c @@ -0,0 +1,2096 @@ +/// Zstandard educational decoder implementation +/// See https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md + +#include +#include +#include +#include + +/// Zstandard decompression functions. +/// `dst` must point to a space at least as large as the reconstructed output. +size_t ZSTD_decompress(void *dst, size_t dst_len, const void *src, + size_t src_len); +/// If `dict != NULL` and `dict_len >= 8`, does the same thing as +/// `ZSTD_decompress` but uses the provided dict +size_t ZSTD_decompress_with_dict(void *dst, size_t dst_len, const void *src, + size_t src_len, const void *dict, + size_t dict_len); + +/******* UTILITY MACROS AND TYPES *********************************************/ +#define MAX_WINDOW_SIZE ((size_t)512 << 20) +// Max block size decompressed size is 128 KB and literal blocks must be smaller +// than that +#define MAX_LITERALS_SIZE ((size_t)(1024 * 128)) + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? 
(a) : (b)) + +#define ERROR(s) \ + do { \ + fprintf(stderr, "Error: %s\n", s); \ + exit(1); \ + } while (0) +#define INP_SIZE() \ + ERROR("Input buffer smaller than it should be or input is " \ + "corrupted") +#define OUT_SIZE() ERROR("Output buffer too small for output") +#define CORRUPTION() ERROR("Corruption detected while decompressing") +#define BAD_ALLOC() ERROR("Memory allocation error") + +typedef uint8_t u8; +typedef uint16_t u16; +typedef uint32_t u32; +typedef uint64_t u64; + +typedef int8_t i8; +typedef int16_t i16; +typedef int32_t i32; +typedef int64_t i64; +/******* END UTILITY MACROS AND TYPES *****************************************/ + +/******* IMPLEMENTATION PRIMITIVE PROTOTYPES **********************************/ +/// The implementations for these functions can be found at the bottom of this +/// file. They implement low-level functionality needed for the higher level +/// decompression functions. + +/*** CIRCULAR BUFFER ******************/ +/// A standard circular buffer, used to facilitate back reference commands +typedef struct { + u8 *ptr; + size_t idx, last_flush, size; +} cbuf_t; + +/// Initialize a circular buffer +static void cbuf_init(cbuf_t *buf, size_t size); +static void cbuf_free(cbuf_t *buf); + +/// Copies up to `src_len` bytes from `src` into the buffer, stopping if it +/// would need to flush. +/// Returns the total amount of data copied. +static size_t cbuf_write_data(cbuf_t *buf, const u8 *src, size_t src_len); +/// Copies `len` bytes from `offset` back in the buffer, stopping if it would +/// need to flush. +/// Returns the number of bytes copied. +static size_t cbuf_copy_offset(cbuf_t *buf, size_t offset, size_t len); +/// Writes up to `len` copies of `byte`, stopping if would need to flush. +/// Returns the number of bytes copied. +static size_t cbuf_repeat_byte(cbuf_t *buf, u8 byte, size_t len); + +/// The `full` versions of the above functions write the full amount requested, +/// flushing to `out` when necessary. 
+/// They return the number of bytes flushed to `out`, if any. +static size_t cbuf_write_data_full(cbuf_t *buf, const u8 *src, size_t src_len, + u8 *out, size_t out_len); +static size_t cbuf_copy_offset_full(cbuf_t *buf, size_t offset, size_t len, + u8 *out, size_t out_len); +static size_t cbuf_repeat_byte_full(cbuf_t *buf, u8 byte, size_t len, u8 *out, + size_t out_len); + +/// Flushes any unflushed data to `dst` +static size_t cbuf_flush(cbuf_t *buf, u8 *dst, size_t dst_len); +/*** END CIRCULAR BUFFER **************/ + +/*** BITSTREAM OPERATIONS *************/ +/// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits +static inline u64 read_bits_LE(const u8 *src, int num, size_t offset); + +/// Read bits from the end of a HUF or FSE bitstream. `offset` is in bits, so +/// it updates `offset` to `offset - bits`, and then reads `bits` bits from +/// `src + offset`. If the offset becomes negative, the extra bits at the +/// bottom are filled in with `0` bits instead of reading from before `src`. +static inline u64 STREAM_read_bits(const u8 *src, int bits, i64 *offset); +/*** END BITSTREAM OPERATIONS *********/ + +/*** BIT COUNTING OPERATIONS **********/ +/// Returns `x`, where `2^x` is the smallest power of 2 greater than or equal to +/// `num`, or `-1` if `num > 2^63` +static inline int log2sup(u64 num); + +/// Returns `x`, where `2^x` is the largest power of 2 less than or equal to +/// `num`, or `-1` if `num == 0`. 
+static inline int log2inf(u64 num); +/*** END BIT COUNTING OPERATIONS ******/ + +/*** HUFFMAN PRIMITIVES ***************/ +// Table decode method uses exponential memory, so we need to limit depth +#define HUF_MAX_BITS (16) + +// Limit the maximum number of symbols to 256 so we can store a symbol in a byte +#define HUF_MAX_SYMBS (256) + +/// Structure containing all tables necessary for efficient Huffman decoding +typedef struct { + u8 *symbols; + u8 *num_bits; + int max_bits; +} HUF_dtable; + +/// Decode a single symbol and read in enough bits to refresh the state +static inline u8 HUF_decode_symbol(HUF_dtable *dtable, u16 *state, + const u8 *src, i64 *offset); +/// Read in a full state's worth of bits to initialize it +static inline void HUF_init_state(HUF_dtable *dtable, u16 *state, const u8 *src, + i64 *offset); + +/// Initialize a Huffman decoding table using the table of bit counts provided +static void HUF_init_dtable(HUF_dtable *table, u8 *bits, int num_symbs); +/// Initialize a Huffman decoding table using the table of weights provided +/// Weights follow the definition provided in the Zstandard specification +static void HUF_init_dtable_usingweights(HUF_dtable *table, u8 *weights, + int num_symbs); + +/// Decompresses a single Huffman stream, returns the number of bytes decoded. +/// `src_len` must be the exact length of the Huffman-coded block. +static size_t HUF_decompress_1stream(HUF_dtable *table, u8 *dst, size_t dst_len, + const u8 *src, size_t src_len); +/// Same as previous but decodes 4 streams, formatted as in the Zstandard +/// specification. +/// `src_len` must be the exact length of the Huffman-coded block. +static size_t HUF_decompress_4stream(HUF_dtable *dtable, u8 *dst, + size_t dst_len, const u8 *src, + size_t src_len); + +/// Free the malloc'ed parts of a decoding table +static void HUF_free_dtable(HUF_dtable *dtable); + +/// Deep copy a decoding table, so that it can be used and free'd without +/// impacting the source table. 
+static void HUF_copy_dtable(HUF_dtable *dst, const HUF_dtable *src); +/*** END HUFFMAN PRIMITIVES ***********/ + +/*** FSE PRIMITIVES *******************/ +/// For more description of FSE see +/// https://github.com/Cyan4973/FiniteStateEntropy/ + +// FSE table decoding uses exponential memory, so limit the maximum accuracy +#define FSE_MAX_ACCURACY_LOG (15) +// Limit the maximum number of symbols so they can be stored in a single byte +#define FSE_MAX_SYMBS (256) + +/// The tables needed to decode FSE encoded streams +typedef struct { + u8 *symbols; + u8 *num_bits; + u16 *new_state_base; + int accuracy_log; +} FSE_dtable; + +/// Return the symbol for the current state +static inline u8 FSE_peek_symbol(FSE_dtable *dtable, u16 state); +/// Read the number of bits necessary to update state, update, and shift offset +/// back to reflect the bits read +static inline void FSE_update_state(FSE_dtable *dtable, u16 *state, + const u8 *src, i64 *offset); + +/// Combine peek and update: decode a symbol and update the state +static inline u8 FSE_decode_symbol(FSE_dtable *dtable, u16 *state, + const u8 *src, i64 *offset); + +/// Read bits from the stream to initialize the state and shift offset back +static inline void FSE_init_state(FSE_dtable *dtable, u16 *state, const u8 *src, + i64 *offset); + +/// Decompress two interleaved bitstreams (e.g. compressed Huffman weights) +/// using an FSE decoding table. `src_len` must be the exact length of the +/// block. +static size_t FSE_decompress_interleaved2(FSE_dtable *dtable, u8 *dst, + size_t dst_len, const u8 *src, + size_t src_len); + +/// Initialize a decoding table using normalized frequencies. +static void FSE_init_dtable(FSE_dtable *dtable, const i16 *norm_freqs, + int num_symbs, int accuracy_log); + +/// Decode an FSE header as defined in the Zstandard format specification and +/// use the decoded frequencies to initialize a decoding table. 
+static size_t FSE_decode_header(FSE_dtable *dtable, const u8 *src, + size_t src_len, int max_accuracy_log); + +/// Initialize an FSE table that will always return the same symbol and consume +/// 0 bits per symbol, to be used for RLE mode in sequence commands +static void FSE_init_dtable_rle(FSE_dtable *dtable, u8 symb); + +/// Free the malloc'ed parts of a decoding table +static void FSE_free_dtable(FSE_dtable *dtable); + +/// Deep copy a decoding table, so that it can be used and free'd without +/// impacting the source table. +static void FSE_copy_dtable(FSE_dtable *dst, const FSE_dtable *src); +/*** END FSE PRIMITIVES ***************/ + +/******* END IMPLEMENTATION PRIMITIVE PROTOTYPES ******************************/ + +/******* ZSTD HELPER STRUCTS AND PROTOTYPES ***********************************/ + +/// Input and output pointers to allow them to be advanced by +/// functions that consume input/produce output +typedef struct { + u8 *dst; + size_t dst_len; + + const u8 *src; + size_t src_len; +} io_streams_t; + +/// The context needed to decode blocks in a frame +typedef struct { + size_t window_size; + size_t frame_content_size; + + // The total amount of data available for backreferences, to determine if an + // offset too large to be correct + size_t current_total_output; + + // A sliding window of the past `window_size` bytes decoded + cbuf_t window; + + // Entropy encoding tables so they can be repeated by future blocks instead + // of + // retransmitting + HUF_dtable literals_dtable; + FSE_dtable ll_dtable; + FSE_dtable ml_dtable; + FSE_dtable of_dtable; + + // The last 3 offsets for the special "repeat offsets". 
Array size is 4 so + // that previous_offsets[1] corresponds to the most recent offset + u64 previous_offsets[4]; + + // The dictionary id for this frame if one exists + u32 dictionary_id; + + int single_segment_flag; + int content_checksum_flag; +} frame_context_t; + +/// The decoded contents of a dictionary so that it doesn't have to be repeated +/// for each frame that uses it +typedef struct { + // Entropy tables + HUF_dtable literals_dtable; + FSE_dtable ll_dtable; + FSE_dtable ml_dtable; + FSE_dtable of_dtable; + + // Raw content for backreferences + u8 *content; + size_t content_size; + + // Offset history to prepopulate the frame's history + u64 previous_offsets[4]; + + u32 dictionary_id; +} dictionary_t; + +/// A tuple containing the parts necessary to decode and execute a ZSTD sequence +/// command +typedef struct { + u32 literal_length; + u32 match_length; + u32 offset; +} sequence_command_t; + +/// The decoder works top-down, starting at the high level like Zstd frames, and +/// working down to lower more technical levels such as blocks, literals, and +/// sequences. The high-level functions roughly follow the outline of the +/// format specification: +/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md + +/// Before the implementation of each high-level function declared here, the +/// prototypes for their helper functions are defined and explained + +/// Decode a single Zstd frame, or error if the input is not a valid frame. +/// Accepts a dict argument, which may be NULL indicating no dictionary. 
+/// See +/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame-concatenation +static void decode_frame(io_streams_t *streams, dictionary_t *dict); + +// Decode data in a compressed block +static void decompress_block(io_streams_t *streams, frame_context_t *ctx, + size_t block_len); + +// Decode the literals section of a block +static size_t decode_literals(io_streams_t *streams, frame_context_t *ctx, + u8 **literals); + +// Decode the sequences part of a block +static size_t decode_sequences(frame_context_t *ctx, const u8 *src, + size_t src_len, sequence_command_t **sequences); + +// Execute the decoded sequences on the literals block +static size_t execute_sequences(io_streams_t *streams, frame_context_t *ctx, + sequence_command_t *sequences, + size_t num_sequences, const u8 *literals, + size_t literals_len); + +// Parse a provided dictionary blob for use in decompression +static void parse_dictionary(dictionary_t *dict, const u8 *src, size_t src_len); +static void free_dictionary(dictionary_t *dict); +/******* END ZSTD HELPER STRUCTS AND PROTOTYPES *******************************/ + +size_t ZSTD_decompress(void *dst, size_t dst_len, const void *src, + size_t src_len) { + return ZSTD_decompress_with_dict(dst, dst_len, src, src_len, NULL, 0); +} + +size_t ZSTD_decompress_usingDict(void *_ctx, void *dst, size_t dst_len, + const void *src, size_t src_len, + const void *dict, size_t dict_len) { + // _ctx needed to match ZSTD lib signature + return ZSTD_decompress_with_dict(dst, dst_len, src, src_len, dict, + dict_len); +} + +size_t ZSTD_decompress_with_dict(void *dst, size_t dst_len, const void *src, + size_t src_len, const void *dict, + size_t dict_len) { + dictionary_t parsed_dict; + memset(&parsed_dict, 0, sizeof(dictionary_t)); + // dict_len < 8 is not a valid dictionary + if (dict && dict_len > 8) { + parse_dictionary(&parsed_dict, (const u8 *)dict, dict_len); + } + + io_streams_t streams = {(u8 *)dst, dst_len, (const u8 *)src, 
src_len}; + while (streams.src_len > 0) { + decode_frame(&streams, &parsed_dict); + } + + free_dictionary(&parsed_dict); + + return streams.dst - (u8 *)dst; +} + +/******* FRAME DECODING ******************************************************/ + +static void decode_data_frame(io_streams_t *streams, dictionary_t *dict); +static void init_frame_context(frame_context_t *context); +static void free_frame_context(frame_context_t *context); +static void parse_frame_header(io_streams_t *streams, frame_context_t *ctx, + dictionary_t *dict); +static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict); + +static void decompress_data(io_streams_t *streams, frame_context_t *ctx); + +static void decode_frame(io_streams_t *streams, dictionary_t *dict) { + if (streams->src_len < 4) { + INP_SIZE(); + } + u32 magic_number = read_bits_LE(streams->src, 32, 0); + + streams->src += 4; + streams->src_len -= 4; + if (magic_number >= 0x184D2A50U && magic_number <= 0x184D2A5F) { + // skippable frame + if (streams->src_len < 4) { + INP_SIZE(); + } + size_t frame_size = read_bits_LE(streams->src, 32, 32); + + if (streams->src_len < 4 + frame_size) { + INP_SIZE(); + } + + // skip over frame + streams->src += 4 + frame_size; + streams->src_len -= 4 + frame_size; + } else if (magic_number == 0xFD2FB528U) { + // ZSTD frame + decode_data_frame(streams, dict); + } else { + // not a real frame + ERROR("Invalid magic number"); + } +} + +/// Decode a frame that contains compressed data. Not all frames do as there +/// are skippable frames. 
+/// See +/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#general-structure-of-zstandard-frame-format +static void decode_data_frame(io_streams_t *streams, dictionary_t *dict) { + frame_context_t ctx; + + // Initialize the context that needs to be carried from block to block + init_frame_context(&ctx); + parse_frame_header(streams, &ctx, dict); + frame_context_apply_dict(&ctx, dict); + + if (ctx.frame_content_size != 0 && + ctx.frame_content_size > streams->dst_len) { + OUT_SIZE(); + } + + decompress_data(streams, &ctx); + + free_frame_context(&ctx); +} + +static void init_frame_context(frame_context_t *context) { + memset(context, 0x00, sizeof(frame_context_t)); + + // Set up the offset history for the repeat offset commands + context->previous_offsets[1] = 1; + context->previous_offsets[2] = 4; + context->previous_offsets[3] = 8; +} + +static void free_frame_context(frame_context_t *context) { + HUF_free_dtable(&context->literals_dtable); + + FSE_free_dtable(&context->ll_dtable); + FSE_free_dtable(&context->ml_dtable); + FSE_free_dtable(&context->of_dtable); + + cbuf_free(&context->window); + + memset(context, 0, sizeof(frame_context_t)); +} + +static void parse_frame_header(io_streams_t *streams, frame_context_t *ctx, + dictionary_t *dict) { + if (streams->src_len < 1) { + INP_SIZE(); + } + + u8 descriptor = read_bits_LE(streams->src, 8, 0); + + // decode frame header descriptor into flags + u8 frame_content_size_flag = descriptor >> 6; + u8 single_segment_flag = (descriptor >> 5) & 1; + u8 reserved_bit = (descriptor >> 3) & 1; + u8 content_checksum_flag = (descriptor >> 2) & 1; + u8 dictionary_id_flag = descriptor & 3; + + if (reserved_bit != 0) { + CORRUPTION(); + } + + streams->src++; + streams->src_len--; + + ctx->single_segment_flag = single_segment_flag; + ctx->content_checksum_flag = content_checksum_flag; + + // decode window size + if (!single_segment_flag) { + if (streams->src_len < 1) { + INP_SIZE(); + } + + // Use the 
algorithm from the specification to compute window size + // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor + u8 window_descriptor = read_bits_LE(streams->src, 8, 0); + u8 exponent = window_descriptor >> 3; + u8 mantissa = window_descriptor & 7; + + size_t window_base = (size_t)1 << (10 + exponent); + size_t window_add = (window_base / 8) * mantissa; + ctx->window_size = window_base + window_add; + + streams->src++; + streams->src_len--; + } + + // decode dictionary id if it exists + if (dictionary_id_flag) { + const int bytes_array[] = {0, 1, 2, 4}; + const int bytes = bytes_array[dictionary_id_flag]; + + if (streams->src_len < bytes) { + INP_SIZE(); + } + + ctx->dictionary_id = read_bits_LE(streams->src, bytes * 8, 0); + streams->src += bytes; + streams->src_len -= bytes; + } else { + ctx->dictionary_id = 0; + } + + // decode frame content size if it exists + if (single_segment_flag || frame_content_size_flag) { + // if frame_content_size_flag == 0 but single_segment_flag is set, we + // still + // have a 1 byte field + const int bytes_array[] = {1, 2, 4, 8}; + const int bytes = bytes_array[frame_content_size_flag]; + + if (streams->src_len < bytes) { + INP_SIZE(); + } + + ctx->frame_content_size = read_bits_LE(streams->src, bytes * 8, 0); + if (bytes == 2) { + ctx->frame_content_size += 256; + } + + streams->src += bytes; + streams->src_len -= bytes; + } + + if (single_segment_flag) { + ctx->window_size = + ctx->frame_content_size + (dict ? dict->content_size : 0); + // We need to allocate a buffer to write to of size at least output + + // dict + // size + size_t size = ctx->frame_content_size + (dict ? 
dict->content_size : 0); + } + + // Allocate the window + if (ctx->window_size > MAX_WINDOW_SIZE) { + ERROR("Requested window size too large"); + } + cbuf_init(&ctx->window, ctx->window_size); +} + +/// A dictionary acts as initializing values for the frame context before +/// decompression, so we implement it by applying it's predetermined +/// tables and content to the context before beginning decompression +static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict) { + // If the content pointer is NULL then it must be an empty dict + if (!dict || !dict->content) + return; + + if (ctx->dictionary_id == 0 && dict->dictionary_id != 0) { + // The dictionary is unneeded, and shouldn't be used as it may interfere + // with the default offset history + return; + } + + // If the dictionary id is 0, it doesn't matter if we provide the wrong raw + // content dict, it won't change anything + if (ctx->dictionary_id != 0 && ctx->dictionary_id != dict->dictionary_id) { + ERROR("Wrong/no dictionary provided"); + } + + // Write the dict data in, and then flush to NULL so it's not sent to the + // output stream + cbuf_write_data_full(&ctx->window, dict->content, dict->content_size, NULL, + -1); + cbuf_flush(&ctx->window, NULL, -1); + ctx->current_total_output = dict->content_size; + + // If it's a formatted dict copy the precomputed tables in so they can + // be used in the table repeat modes + if (dict->dictionary_id != 0) { + // Deep copy the entropy tables so they can be freed independently of + // the + // dictionary struct + HUF_copy_dtable(&ctx->literals_dtable, &dict->literals_dtable); + FSE_copy_dtable(&ctx->ll_dtable, &dict->ll_dtable); + FSE_copy_dtable(&ctx->of_dtable, &dict->of_dtable); + FSE_copy_dtable(&ctx->ml_dtable, &dict->ml_dtable); + + memcpy(ctx->previous_offsets, dict->previous_offsets, + sizeof(ctx->previous_offsets)); + } +} + +/// Decompress the data from a frame block by block +static void decompress_data(io_streams_t *streams, 
frame_context_t *ctx) { + + u8 last_block = 0; + do { + if (streams->src_len < 3) { + INP_SIZE(); + } + // Parse the block header + last_block = streams->src[0] & 1; + u8 block_type = (streams->src[0] >> 1) & 3; + size_t block_len = read_bits_LE(streams->src, 21, 3); + + streams->src += 3; + streams->src_len -= 3; + + switch (block_type) { + case 0: { + // Raw, uncompressed block + if (streams->src_len < block_len) { + INP_SIZE(); + } + if (streams->dst_len < block_len) { + OUT_SIZE(); + } + + // Write the raw data into the window buffer + size_t written = + cbuf_write_data_full(&ctx->window, streams->src, block_len, + streams->dst, streams->dst_len); + streams->src += block_len; + streams->src_len -= block_len; + + streams->dst += written; + streams->dst_len -= written; + break; + } + case 1: { + // RLE block, repeat the first byte N times + if (streams->src_len < 1) { + INP_SIZE(); + } + if (streams->dst_len < block_len) { + OUT_SIZE(); + } + + // Write streams->src[0] into the buffer block_len times + size_t written = + cbuf_repeat_byte_full(&ctx->window, streams->src[0], block_len, + streams->dst, streams->dst_len); + streams->dst += written; + streams->dst_len -= written; + + streams->src += 1; + streams->src_len -= 1; + break; + } + case 2: + // Compressed block, this is mode complex + decompress_block(streams, ctx, block_len); + break; + } + } while (!last_block); + + // Flush out anything left in the window buffer to the destination stream + size_t written = cbuf_flush(&ctx->window, streams->dst, streams->dst_len); + streams->dst += written; + streams->dst_len -= written; + + if (ctx->content_checksum_flag) { + // This program does not support checking the checksum, so skip over it + // if + // it's present + if (streams->src_len < 4) { + INP_SIZE(); + } + streams->src += 4; + streams->src_len -= 4; + } +} +/******* END FRAME DECODING ***************************************************/ + +/******* BLOCK DECOMPRESSION 
**************************************************/ +static void decompress_block(io_streams_t *streams, frame_context_t *ctx, + size_t block_len) { + if (streams->src_len < block_len) { + INP_SIZE(); + } + // We need this to determine how long the compressed literals block was + const u8 *const end_of_block = streams->src + block_len; + + // Part 1: decode the literals block + u8 *literals = NULL; + size_t literals_size = decode_literals(streams, ctx, &literals); + + // Part 2: decode the sequences block + if (streams->src > end_of_block) { + INP_SIZE(); + } + size_t sequences_size = end_of_block - streams->src; + sequence_command_t *sequences = NULL; + size_t num_sequences = + decode_sequences(ctx, streams->src, sequences_size, &sequences); + + streams->src += sequences_size; + streams->src_len -= sequences_size; + + // Part 3: combine literals and sequence commands to generate output + execute_sequences(streams, ctx, sequences, num_sequences, literals, + literals_size); + free(literals); + free(sequences); +} +/******* END BLOCK DECOMPRESSION **********************************************/ + +/******* LITERALS DECODING ****************************************************/ +static size_t decode_literals_simple(io_streams_t *streams, u8 **literals, + int block_type, int size_format); +static size_t decode_literals_compressed(io_streams_t *streams, + frame_context_t *ctx, u8 **literals, + int block_type, int size_format); +static size_t decode_huf_table(const u8 *src, size_t src_len, + HUF_dtable *dtable); +static size_t fse_decode_hufweights(const u8 *src, size_t src_len, u8 *weights, + int *num_symbs, size_t compressed_size); + +static size_t decode_literals(io_streams_t *streams, frame_context_t *ctx, + u8 **literals) { + if (streams->src_len < 1) { + INP_SIZE(); + } + // Decode literals header + int block_type = streams->src[0] & 3; + int size_format = (streams->src[0] >> 2) & 3; + + if (block_type <= 1) { + // Raw or RLE literals block + return 
decode_literals_simple(streams, literals, block_type, + size_format); + } else { + // Huffman compressed literals + return decode_literals_compressed(streams, ctx, literals, block_type, + size_format); + } +} + +/// Decodes literals blocks in raw or RLE form +static size_t decode_literals_simple(io_streams_t *streams, u8 **literals, + int block_type, int size_format) { + size_t size; + switch (size_format) { + // These cases are in the form X0 + // In this case, the X bit is actually part of the size field + case 0: + case 2: + size = read_bits_LE(streams->src, 5, 3); + streams->src += 1; + streams->src_len -= 1; + break; + case 1: + if (streams->src_len < 2) { + INP_SIZE(); + } + size = read_bits_LE(streams->src, 12, 4); + streams->src += 2; + streams->src_len -= 2; + break; + case 3: + if (streams->src_len < 2) { + INP_SIZE(); + } + size = read_bits_LE(streams->src, 20, 4); + streams->src += 3; + streams->src_len -= 3; + break; + default: + // Impossible + size = -1; + } + + if (size > MAX_LITERALS_SIZE) { + CORRUPTION(); + } + + *literals = malloc(size); + if (!*literals) { + BAD_ALLOC(); + } + + switch (block_type) { + case 0: + // Raw data + if (size > streams->src_len) { + INP_SIZE(); + } + memcpy(*literals, streams->src, size); + streams->src += size; + streams->src_len -= size; + break; + case 1: + // Single repeated byte + if (1 > streams->src_len) { + INP_SIZE(); + } + memset(*literals, streams->src[0], size); + streams->src += 1; + streams->src_len -= 1; + break; + } + + return size; +} + +/// Decodes Huffman compressed literals +static size_t decode_literals_compressed(io_streams_t *streams, + frame_context_t *ctx, u8 **literals, + int block_type, int size_format) { + size_t regenerated_size, compressed_size; + // Only size_format=0 has 1 stream, so default to 4 + int num_streams = 4; + switch (size_format) { + case 0: + num_streams = 1; + // Fall through as it has the same size format + case 1: + if (streams->src_len < 3) { + INP_SIZE(); + } + 
regenerated_size = read_bits_LE(streams->src, 10, 4); + compressed_size = read_bits_LE(streams->src, 10, 14); + streams->src += 3; + streams->src_len -= 3; + break; + case 2: + if (streams->src_len < 4) { + INP_SIZE(); + } + regenerated_size = read_bits_LE(streams->src, 14, 4); + compressed_size = read_bits_LE(streams->src, 14, 18); + streams->src += 4; + streams->src_len -= 4; + break; + case 3: + if (streams->src_len < 5) { + INP_SIZE(); + } + regenerated_size = read_bits_LE(streams->src, 18, 4); + compressed_size = read_bits_LE(streams->src, 18, 22); + streams->src += 5; + streams->src_len -= 5; + break; + default: + // Impossible + compressed_size = regenerated_size = -1; + } + if (regenerated_size > MAX_LITERALS_SIZE || + compressed_size > regenerated_size) { + CORRUPTION(); + } + + if (compressed_size > streams->src_len) { + INP_SIZE(); + } + + *literals = malloc(regenerated_size); + if (!*literals) { + BAD_ALLOC(); + } + + if (block_type == 2) { + // Decode provided Huffman table + + HUF_free_dtable(&ctx->literals_dtable); + size_t size = decode_huf_table(streams->src, compressed_size, + &ctx->literals_dtable); + streams->src += size; + streams->src_len -= size; + compressed_size -= size; + } else { + // If we're to repeat the previous Huffman table, make sure it exists + if (!ctx->literals_dtable.symbols) { + CORRUPTION(); + } + } + + if (num_streams == 1) { + HUF_decompress_1stream(&ctx->literals_dtable, *literals, + regenerated_size, streams->src, compressed_size); + } else { + HUF_decompress_4stream(&ctx->literals_dtable, *literals, + regenerated_size, streams->src, compressed_size); + } + streams->src += compressed_size; + streams->src_len -= compressed_size; + + return regenerated_size; +} + +// Decode the Huffman table description +static size_t decode_huf_table(const u8 *src, size_t src_len, + HUF_dtable *dtable) { + if (src_len < 1) { + INP_SIZE(); + } + + const u8 *const osrc = src; + + u8 header = src[0]; + u8 weights[HUF_MAX_SYMBS]; + 
memset(weights, 0, sizeof(weights)); + + src++; + src_len--; + + int num_symbs; + + if (header >= 128) { + // Direct representation, read the weights out + num_symbs = header - 127; + size_t bytes = (num_symbs + 1) / 2; + + if (bytes > src_len) { + INP_SIZE(); + } + + for (int i = 0; i < num_symbs; i++) { + if (i % 2 == 0) { + weights[i] = src[i / 2] >> 4; + } else { + weights[i] = src[i / 2] & 0xf; + } + } + + src += bytes; + src_len -= bytes; + } else { + // The weights are FSE encoded, decode them before we can construct the + // table + size_t size = + fse_decode_hufweights(src, src_len, weights, &num_symbs, header); + src += size; + src_len -= size; + } + + // Construct the table using the decoded weights + HUF_init_dtable_usingweights(dtable, weights, num_symbs); + return src - osrc; +} + +static size_t fse_decode_hufweights(const u8 *src, size_t src_len, u8 *weights, + int *num_symbs, size_t compressed_size) { + const int MAX_ACCURACY_LOG = 7; + + FSE_dtable dtable; + + // Construct the FSE table + size_t read = FSE_decode_header(&dtable, src, src_len, MAX_ACCURACY_LOG); + + if (src_len < compressed_size) { + INP_SIZE(); + } + + // Decode the weights + *num_symbs = FSE_decompress_interleaved2( + &dtable, weights, HUF_MAX_SYMBS, src + read, compressed_size - read); + + FSE_free_dtable(&dtable); + + return compressed_size; +} +/******* END LITERALS DECODING ************************************************/ + +/******* SEQUENCE DECODING ****************************************************/ +/// The combination of FSE states needed to decode sequences +typedef struct { + u16 ll_state, of_state, ml_state; + FSE_dtable ll_table, of_table, ml_table; +} sequence_state_t; + +/// Different modes to signal to decode_seq_tables what to do +typedef enum { + seq_literal_length = 0, + seq_offset = 1, + seq_match_length = 2, +} seq_part_t; + +typedef enum { + seq_predefined = 0, + seq_rle = 1, + seq_fse = 2, + seq_repeat = 3, +} seq_mode_t; + +/// The predefined FSE 
distribution tables for `seq_predefined` mode +static const i16 SEQ_LITERAL_LENGTH_DEFAULT_DIST[36] = { + 4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, -1, -1, -1, -1}; +static const i16 SEQ_OFFSET_DEFAULT_DIST[29] = { + 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1}; +static const i16 SEQ_MATCH_LENGTH_DEFAULT_DIST[53] = { + 1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1}; + +/// The sequence decoding baseline and number of additional bits to read/add +/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#the-codes-for-literals-lengths-match-lengths-and-offsets +static const u32 SEQ_LITERAL_LENGTH_BASELINES[36] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 18, 20, 22, 24, 28, 32, 40, + 48, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65538}; +static const u8 SEQ_LITERAL_LENGTH_EXTRA_BITS[36] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + +static const u32 SEQ_MATCH_LENGTH_BASELINES[53] = { + 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 37, 39, 41, 43, 47, 51, 59, 67, 83, + 99, 131, 259, 515, 1027, 2051, 4099, 8195, 16387, 32771, 65539}; +static const u8 SEQ_MATCH_LENGTH_EXTRA_BITS[53] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, + 2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + +/// Offset decoding is simpler so we just need a maximum code value +static const u8 SEQ_MAX_CODES[3] = {35, -1, 52}; + +static void decompress_sequences(frame_context_t *ctx, const u8 *src, + size_t src_len, sequence_command_t *sequences, + size_t num_sequences); +static 
sequence_command_t decode_sequence(sequence_state_t *state, + const u8 *src, i64 *offset); +static size_t decode_seq_table(const u8 *src, size_t src_len, FSE_dtable *table, + seq_part_t type, seq_mode_t mode); + +static size_t decode_sequences(frame_context_t *ctx, const u8 *src, + size_t src_len, sequence_command_t **sequences) { + size_t num_sequences; + + // Decode the sequence header and allocate space for the output + if (src_len < 1) { + INP_SIZE(); + } + if (src[0] == 0) { + *sequences = NULL; + return 0; + } else if (src[0] < 128) { + num_sequences = src[0]; + src++; + src_len--; + } else if (src[0] < 255) { + if (src_len < 2) { + INP_SIZE(); + } + num_sequences = ((src[0] - 128) << 8) + src[1]; + src += 2; + src_len -= 2; + } else { + if (src_len < 3) { + INP_SIZE(); + } + num_sequences = src[1] + ((u64)src[2] << 8) + 0x7F00; + src += 3; + src_len -= 3; + } + + *sequences = malloc(num_sequences * sizeof(sequence_command_t)); + if (!*sequences) { + BAD_ALLOC(); + } + + decompress_sequences(ctx, src, src_len, *sequences, num_sequences); + return num_sequences; +} + +/// Decompress the FSE encoded sequence commands +static void decompress_sequences(frame_context_t *ctx, const u8 *src, + size_t src_len, sequence_command_t *sequences, + size_t num_sequences) { + if (src_len < 1) { + INP_SIZE(); + } + u8 compression_modes = src[0]; + src++; + src_len--; + + if ((compression_modes & 3) != 0) { + CORRUPTION(); + } + + sequence_state_t state; + size_t read; + // Update the tables we have stored in the context + read = decode_seq_table(src, src_len, &ctx->ll_dtable, seq_literal_length, + (compression_modes >> 6) & 3); + src += read; + src_len -= read; + read = decode_seq_table(src, src_len, &ctx->of_dtable, seq_offset, + (compression_modes >> 4) & 3); + src += read; + src_len -= read; + read = decode_seq_table(src, src_len, &ctx->ml_dtable, seq_match_length, + (compression_modes >> 2) & 3); + src += read; + src_len -= read; + + // Check to make sure none of the 
tables are uninitialized + if (!ctx->ll_dtable.symbols || !ctx->of_dtable.symbols || + !ctx->ml_dtable.symbols) { + CORRUPTION(); + } + + // Now use the context's tables + memcpy(&state.ll_table, &ctx->ll_dtable, sizeof(FSE_dtable)); + memcpy(&state.of_table, &ctx->of_dtable, sizeof(FSE_dtable)); + memcpy(&state.ml_table, &ctx->ml_dtable, sizeof(FSE_dtable)); + + int padding = 8 - log2inf(src[src_len - 1]); + i64 offset = src_len * 8 - padding; + + FSE_init_state(&state.ll_table, &state.ll_state, src, &offset); + FSE_init_state(&state.of_table, &state.of_state, src, &offset); + FSE_init_state(&state.ml_table, &state.ml_state, src, &offset); + + for (size_t i = 0; i < num_sequences; i++) { + // Decode sequences one by one + sequences[i] = decode_sequence(&state, src, &offset); + } + + if (offset != 0) { + CORRUPTION(); + } + + // Don't free our tables so they can be used in the next block +} + +// Decode a single sequence and update the state +static sequence_command_t decode_sequence(sequence_state_t *state, + const u8 *src, i64 *offset) { + // Decode symbols, but don't update states + u8 of_code = FSE_peek_symbol(&state->of_table, state->of_state); + u8 ll_code = FSE_peek_symbol(&state->ll_table, state->ll_state); + u8 ml_code = FSE_peek_symbol(&state->ml_table, state->ml_state); + + // Offset doesn't need a max value as it's not decoded using a table + if (ll_code > SEQ_MAX_CODES[seq_literal_length] || + ml_code > SEQ_MAX_CODES[seq_match_length]) { + CORRUPTION(); + } + + // Read the interleaved bits + sequence_command_t seq; + // Offset computation works differently + seq.offset = ((u32)1 << of_code) + STREAM_read_bits(src, of_code, offset); + seq.match_length = + SEQ_MATCH_LENGTH_BASELINES[ml_code] + + STREAM_read_bits(src, SEQ_MATCH_LENGTH_EXTRA_BITS[ml_code], offset); + seq.literal_length = + SEQ_LITERAL_LENGTH_BASELINES[ll_code] + + STREAM_read_bits(src, SEQ_LITERAL_LENGTH_EXTRA_BITS[ll_code], offset); + + // If the stream is complete don't read bits to 
/// Executes decoded sequence commands: for every sequence, a run of literals
/// is appended to the window, followed by a match copied from earlier output.
/// Literals left over after the last sequence are appended at the end.
///
/// `streams->dst` / `streams->dst_len` are advanced by the values returned by
/// the `cbuf_*_full` helpers, and `ctx->previous_offsets[1..3]` (the repeat
/// offset history) is updated as sequences are processed.
///
/// @param streams        output cursor; `dst` and `dst_len` are updated
/// @param ctx            frame context holding the window, the offset history
///                       and the running output total
/// @param sequences      the sequence commands to execute
/// @param num_sequences  number of entries in `sequences`
/// @param literals       literal bytes consumed in order by the sequences
/// @param literals_len   number of bytes available at `literals`
/// @return the updated running total, which is also stored back into
///         `ctx->current_total_output`
static size_t execute_sequences(io_streams_t *streams, frame_context_t *ctx,
                                sequence_command_t *sequences,
                                size_t num_sequences, const u8 *literals,
                                size_t literals_len) {
    // Only indices 1..3 of the history are used below
    u64 *offset_hist = ctx->previous_offsets;
    // Continue counting from what was produced before this call
    size_t total_output = ctx->current_total_output;

    for (size_t i = 0; i < num_sequences; i++) {
        sequence_command_t seq = sequences[i];

        // A sequence can't ask for more literals than are left
        if (seq.literal_length > literals_len) {
            CORRUPTION();
        }

        {
            // Copy literals to the buffer
            // `written` is whatever cbuf_write_data_full reports back; the
            // output cursor moves by that amount, while the running total
            // moves by the full literal length
            size_t written =
                cbuf_write_data_full(&ctx->window, literals, seq.literal_length,
                                     streams->dst, streams->dst_len);

            literals += seq.literal_length;
            literals_len -= seq.literal_length;

            streams->dst += written;
            streams->dst_len -= written;

            total_output += seq.literal_length;
        }

        size_t offset;

        // Offsets are special, we need to handle the repeat offsets
        // Values 1-3 select an entry of the history instead of encoding a
        // distance directly
        if (seq.offset <= 3) {
            u32 idx = seq.offset;
            if (seq.literal_length == 0) {
                // Special case when literal length is 0
                // The selection is shifted by one: 1 -> 2, 2 -> 3, 3 -> 4
                idx++;
            }

            if (idx == 1) {
                // Most recent offset again; the history stays as it is
                offset = offset_hist[1];
            } else {
                // If idx == 4 then literal length was 0 and the offset was 3
                // NOTE(review): offset_hist[1] - 1 is 0 when the most recent
                // offset is 1, and nothing here rejects a zero offset --
                // confirm whether that needs to be treated as corruption
                offset = idx < 4 ? offset_hist[idx] : offset_hist[1] - 1;

                // If idx == 2 we don't need to modify offset_hist[3]
                if (idx > 2) {
                    offset_hist[3] = offset_hist[2];
                }
                // Move the chosen offset to the front of the history
                offset_hist[2] = offset_hist[1];
                offset_hist[1] = offset;
            }
        } else {
            // A regular offset, stored with a bias of 3
            offset = seq.offset - 3;

            // Shift back history
            offset_hist[3] = offset_hist[2];
            offset_hist[2] = offset_hist[1];
            offset_hist[1] = offset;
        }

        // The match can't start before the beginning of the counted output
        if (offset > total_output) {
            CORRUPTION();
        }

        {
            // Do the offset copy operation
            size_t written =
                cbuf_copy_offset_full(&ctx->window, offset, seq.match_length,
                                      streams->dst, streams->dst_len);

            streams->dst += written;
            streams->dst_len -= written;
            total_output += seq.match_length;
        }
    }

    {
        // Copy any leftover literal bytes
        size_t written =
            cbuf_write_data_full(&ctx->window, literals, literals_len,
                                 streams->dst, streams->dst_len);
        streams->dst += written;
        streams->dst_len -= written;

        total_output += literals_len;
    }

    // Remember the total for the next call
    ctx->current_total_output = total_output;

    return total_output;
}
size_t read = decode_seq_table(src, src_len, &dict->ml_dtable, + seq_match_length, seq_fse); + src += read; + src_len -= read; + } + { + size_t read = decode_seq_table(src, src_len, &dict->ll_dtable, + seq_literal_length, seq_fse); + src += read; + src_len -= read; + } + + if (src_len < 12) { + INP_SIZE(); + } + // Read in the previous offset history + dict->previous_offsets[1] = read_bits_LE(src, 32, 0); + dict->previous_offsets[2] = read_bits_LE(src, 32, 32); + dict->previous_offsets[3] = read_bits_LE(src, 32, 64); + + src += 12; + src_len -= 12; + + // Ensure the provided offsets aren't too large + for (int i = 1; i <= 3; i++) { + if (dict->previous_offsets[i] > src_len) { + ERROR("Dictionary corrupted"); + } + } + // The rest is the content + dict->content = malloc(src_len); + if (!dict->content) { + BAD_ALLOC(); + } + + dict->content_size = src_len; + memcpy(dict->content, src, src_len); +} + +/// If parse_dictionary is given a raw content dictionary, it delegates here +static void init_raw_content_dict(dictionary_t *dict, const u8 *src, + size_t src_len) { + dict->dictionary_id = 0; + // Copy in the content + dict->content = malloc(src_len); + if (!dict->content) { + BAD_ALLOC(); + } + + dict->content_size = src_len; + memcpy(dict->content, src, src_len); +} + +/// Free an allocated dictionary +static void free_dictionary(dictionary_t *dict) { + HUF_free_dtable(&dict->literals_dtable); + FSE_free_dtable(&dict->ll_dtable); + FSE_free_dtable(&dict->of_dtable); + FSE_free_dtable(&dict->ml_dtable); + + free(dict->content); + + memset(dict, 0, sizeof(dictionary_t)); +} +/******* END DICTIONARY PARSING ***********************************************/ + +/******* CIRCULAR BUFFER ******************************************************/ +static void cbuf_init(cbuf_t *buf, size_t size) { + buf->ptr = malloc(size); + + if (!buf->ptr) { + BAD_ALLOC(); + } + + memset(buf->ptr, 0x3f, size); + + buf->size = size; + buf->idx = 0; + buf->last_flush = 0; +} + +static size_t 
/// Writes up to `len` copies of `byte` at the current position of the
/// circular buffer, stopping at the end of the storage.
///
/// Like cbuf_write_data, this advances `buf->idx` by the number of bytes
/// written, so that the data is picked up by the next cbuf_flush and is not
/// overwritten by the next write.
///
/// @param buf   the circular buffer to write into
/// @param byte  the value to repeat
/// @param len   the number of copies requested
/// @return the number of bytes actually written, which is less than `len`
///         when the end of the storage is reached first
static size_t cbuf_repeat_byte(cbuf_t *buf, u8 byte, size_t len) {
    if (buf->size == 0 && len > 0) {
        CORRUPTION();
    }
    size_t max_len = buf->size - buf->idx;
    len = MIN(len, max_len);

    memset(buf->ptr + buf->idx, byte, len);

    // Record the bytes as written; without this they would never be flushed
    buf->idx += len;

    return len;
}
+ break; + } else { + flushed += cbuf_flush(buf, out + flushed, out_len - flushed); + } + } + + return flushed; +} + +static size_t cbuf_flush(cbuf_t *buf, u8 *dst, size_t dst_len) { + if (buf->idx < buf->last_flush) { + CORRUPTION(); + } + + size_t len = buf->idx - buf->last_flush; + + if (dst && len > dst_len) { + OUT_SIZE(); + } + + // allow for NULL buffers to indicate flushing to nowhere + if (dst) { + memcpy(dst, buf->ptr + buf->last_flush, len); + } + + // we could have a 0 size buffer + if (buf->size) { + buf->idx = buf->idx % buf->size; + } + buf->last_flush = buf->idx; + + return len; +} + +static void cbuf_free(cbuf_t *buf) { + free(buf->ptr); + memset(buf, 0, sizeof(cbuf_t)); +} +/******* END CIRCULAR BUFFER **************************************************/ + +/******* BITSTREAM OPERATIONS *************************************************/ +static inline u64 read_bits_LE(const u8 *src, int num, size_t offset) { + if (num > 64) { + return -1; + } + + src += offset / 8; + offset %= 8; + u64 res = 0; + + int shift = 0; + int left = num; + while (left > 0) { + u64 mask = left >= 8 ? 0xff : (((u64)1 << left) - 1); + res += (((u64)*src++ >> offset) & mask) << shift; + shift += 8 - offset; + left -= 8 - offset; + offset = 0; + } + + return res; +} + +static inline u64 STREAM_read_bits(const u8 *src, int bits, i64 *offset) { + *offset = *offset - bits; + size_t actual_off = *offset; + if (*offset < 0) { + bits += *offset; + actual_off = 0; + } + u64 res = read_bits_LE(src, bits, actual_off); + + if (*offset < 0) { + // Fill in the bottom "overflowed" bits with 0's + res = -*offset >= 64 ? 
0 : (res << -*offset); + } + return res; +} +/******* END BITSTREAM OPERATIONS *********************************************/ + +/******* BIT COUNTING OPERATIONS **********************************************/ +static inline int log2sup(u64 num) { + for (int i = 0; i < 64; i++) { + if (((u64)1 << i) >= num) { + return i; + } + } + return -1; +} + +static inline int log2inf(u64 num) { + for (int i = 63; i >= 0; i--) { + if (((u64)1 << i) <= num) { + return i; + } + } + return -1; +} +/******* END BIT COUNTING OPERATIONS ******************************************/ + +/******* HUFFMAN PRIMITIVES ***************************************************/ +static inline u8 HUF_decode_symbol(HUF_dtable *dtable, u16 *state, + const u8 *src, i64 *offset) { + // Look up the symbol and number of bits to read + const u8 symb = dtable->symbols[*state]; + const u8 bits = dtable->num_bits[*state]; + const u16 rest = STREAM_read_bits(src, bits, offset); + *state = ((*state << bits) + rest) & (((u16)1 << dtable->max_bits) - 1); + + return symb; +} + +static inline void HUF_init_state(HUF_dtable *dtable, u16 *state, const u8 *src, + i64 *offset) { + // Read in a full dtable->max_bits to initialize the state + const u8 bits = dtable->max_bits; + *state = STREAM_read_bits(src, bits, offset); +} + +static size_t HUF_decompress_1stream(HUF_dtable *dtable, u8 *dst, + size_t dst_len, const u8 *src, + size_t src_len) { + u8 *const dst_max = dst + dst_len; + u8 *const odst = dst; + + // To maintain similarity with FSE, start from the end + // Find the last 1 bit + int padding = 8 - log2inf(src[src_len - 1]); + + i64 offset = src_len * 8 - padding; + u16 state; + + HUF_init_state(dtable, &state, src, &offset); + + while (dst < dst_max && offset > -dtable->max_bits) { + *dst++ = HUF_decode_symbol(dtable, &state, src, &offset); + } + // If we stopped before consuming all the input, we didn't have enough space + if (dst == dst_max && offset > -dtable->max_bits) { + OUT_SIZE(); + } + + // The current 
state should be the `max_bits` preceding the start as + // everything from `src` onward should be consumed + if (offset != -dtable->max_bits) { + CORRUPTION(); + } + + return dst - odst; +} + +static size_t HUF_decompress_4stream(HUF_dtable *dtable, u8 *dst, + size_t dst_len, const u8 *src, + size_t src_len) { + // Decode each stream independently for simplicity + // If we wanted to we could decode all 4 at the same time for speed, + // utilizing + // more execution units + + const u8 *src1, *src2, *src3, *src4, *src_end; + u8 *dst1, *dst2, *dst3, *dst4, *dst_end; + + size_t total_out = 0; + + if (src_len < 6) { + INP_SIZE(); + } + + src1 = src + 6; + src2 = src1 + read_bits_LE(src, 16, 0); + src3 = src2 + read_bits_LE(src, 16, 16); + src4 = src3 + read_bits_LE(src, 16, 32); + src_end = src + src_len; + + // We can't test with all 4 sizes because the 4th size is a function of the + // other 3 and the provided length + if (src4 - src >= src_len) { + INP_SIZE(); + } + + size_t segment_size = (dst_len + 3) / 4; + dst1 = dst; + dst2 = dst1 + segment_size; + dst3 = dst2 + segment_size; + dst4 = dst3 + segment_size; + dst_end = dst + dst_len; + + total_out += + HUF_decompress_1stream(dtable, dst1, segment_size, src1, src2 - src1); + total_out += + HUF_decompress_1stream(dtable, dst2, segment_size, src2, src3 - src2); + total_out += + HUF_decompress_1stream(dtable, dst3, segment_size, src3, src4 - src3); + total_out += HUF_decompress_1stream(dtable, dst4, dst_end - dst4, src4, + src_end - src4); + + return total_out; +} + +static void HUF_init_dtable(HUF_dtable *table, u8 *bits, int num_symbs) { + memset(table, 0, sizeof(HUF_dtable)); + if (num_symbs > HUF_MAX_SYMBS) { + ERROR("Too many symbols for Huffman"); + } + + u8 max_bits = 0; + u16 rank_count[HUF_MAX_BITS + 1]; + memset(rank_count, 0, sizeof(rank_count)); + + // Count the number of symbols for each number of bits, and determine the + // depth of the tree + for (int i = 0; i < num_symbs; i++) { + if (bits[i] > 
HUF_MAX_BITS) { + ERROR("Huffman table depth too large"); + } + max_bits = MAX(max_bits, bits[i]); + rank_count[bits[i]]++; + } + + size_t table_size = 1 << max_bits; + table->max_bits = max_bits; + table->symbols = malloc(table_size); + table->num_bits = malloc(table_size); + + if (!table->symbols || !table->num_bits) { + free(table->symbols); + free(table->num_bits); + BAD_ALLOC(); + } + + u32 rank_idx[HUF_MAX_BITS + 1]; + // Initialize the starting codes for each rank (number of bits) + rank_idx[max_bits] = 0; + for (int i = max_bits; i >= 1; i--) { + rank_idx[i - 1] = rank_idx[i] + rank_count[i] * (1 << (max_bits - i)); + // The entire range takes the same number of bits so we can memset it + memset(&table->num_bits[rank_idx[i]], i, rank_idx[i - 1] - rank_idx[i]); + } + + if (rank_idx[0] != table_size) { + CORRUPTION(); + } + + // Allocate codes and fill in the table + for (int i = 0; i < num_symbs; i++) { + if (bits[i] != 0) { + // Allocate a code for this symbol and set its range in the table + const u16 code = rank_idx[bits[i]]; + const u16 len = 1 << (max_bits - bits[i]); + memset(&table->symbols[code], i, len); + rank_idx[bits[i]] += len; + } + } +} + +static void HUF_init_dtable_usingweights(HUF_dtable *table, u8 *weights, + int num_symbs) { + // +1 because the last weight is not transmitted in the header + if (num_symbs + 1 > HUF_MAX_SYMBS) { + ERROR("Too many symbols for Huffman"); + } + + u8 bits[HUF_MAX_SYMBS]; + + u64 weight_sum = 0; + for (int i = 0; i < num_symbs; i++) { + weight_sum += weights[i] > 0 ? (u64)1 << (weights[i] - 1) : 0; + } + + // Find the first power of 2 larger than the sum + int max_bits = log2inf(weight_sum) + 1; + u64 left_over = ((u64)1 << max_bits) - weight_sum; + // If the left over isn't a power of 2, the weights are invalid + if (left_over & (left_over - 1)) { + CORRUPTION(); + } + + int last_weight = log2inf(left_over) + 1; + + for (int i = 0; i < num_symbs; i++) { + bits[i] = weights[i] > 0 ? 
(max_bits + 1 - weights[i]) : 0; + } + bits[num_symbs] = + max_bits + 1 - last_weight; // last weight is always non-zero + + HUF_init_dtable(table, bits, num_symbs + 1); +} + +static void HUF_free_dtable(HUF_dtable *dtable) { + free(dtable->symbols); + free(dtable->num_bits); + memset(dtable, 0, sizeof(HUF_dtable)); +} + +static void HUF_copy_dtable(HUF_dtable *dst, const HUF_dtable *src) { + if (src->max_bits == 0) { + memset(dst, 0, sizeof(HUF_dtable)); + return; + } + + size_t size = (size_t)1 << src->max_bits; + dst->max_bits = src->max_bits; + + dst->symbols = malloc(size); + dst->num_bits = malloc(size); + if (!dst->symbols || !dst->num_bits) { + BAD_ALLOC(); + } + + memcpy(dst->symbols, src->symbols, size); + memcpy(dst->num_bits, src->num_bits, size); +} +/******* END HUFFMAN PRIMITIVES ***********************************************/ + +/******* FSE PRIMITIVES *******************************************************/ +static inline u8 FSE_peek_symbol(FSE_dtable *dtable, u16 state) { + return dtable->symbols[state]; +} + +static inline void FSE_update_state(FSE_dtable *dtable, u16 *state, + const u8 *src, i64 *offset) { + const u8 bits = dtable->num_bits[*state]; + const u16 rest = STREAM_read_bits(src, bits, offset); + *state = dtable->new_state_base[*state] + rest; +} + +// Decodes a single FSE symbol and updates the offset +static inline u8 FSE_decode_symbol(FSE_dtable *dtable, u16 *state, + const u8 *src, i64 *offset) { + const u8 symb = FSE_peek_symbol(dtable, *state); + FSE_update_state(dtable, state, src, offset); + return symb; +} + +static inline void FSE_init_state(FSE_dtable *dtable, u16 *state, const u8 *src, + i64 *offset) { + const u8 bits = dtable->accuracy_log; + *state = STREAM_read_bits(src, bits, offset); +} + +static size_t FSE_decompress_interleaved2(FSE_dtable *dtable, u8 *dst, + size_t dst_len, const u8 *src, + size_t src_len) { + if (src_len == 0) { + INP_SIZE(); + } + + u8 *dst_max = dst + dst_len; + u8 *const odst = dst; + + // Find 
the last 1 bit + int padding = 8 - log2inf(src[src_len - 1]); + + i64 offset = src_len * 8 - padding; + + u16 state1, state2; + FSE_init_state(dtable, &state1, src, &offset); + FSE_init_state(dtable, &state2, src, &offset); + + // Decode until we overflow the stream + // Since we decode in reverse order, overflowing the stream is offset going + // negative + while (1) { + if (dst > dst_max - 2) { + OUT_SIZE(); + } + *dst++ = FSE_decode_symbol(dtable, &state1, src, &offset); + if (offset < 0) { + // There's still a symbol to decode in state2 + *dst++ = FSE_decode_symbol(dtable, &state2, src, &offset); + break; + } + + if (dst > dst_max - 2) { + OUT_SIZE(); + } + *dst++ = FSE_decode_symbol(dtable, &state2, src, &offset); + if (offset < 0) { + // There's still a symbol to decode in state1 + *dst++ = FSE_decode_symbol(dtable, &state1, src, &offset); + break; + } + } + + // number of symbols read + return dst - odst; +} + +static void FSE_init_dtable(FSE_dtable *dtable, const i16 *norm_freqs, + int num_symbs, int accuracy_log) { + if (accuracy_log > FSE_MAX_ACCURACY_LOG) { + ERROR("FSE accuracy too large"); + } + if (num_symbs > FSE_MAX_SYMBS) { + ERROR("Too many symbols for FSE"); + } + + dtable->accuracy_log = accuracy_log; + + size_t size = (size_t)1 << accuracy_log; + dtable->symbols = malloc(size * sizeof(u8)); + dtable->num_bits = malloc(size * sizeof(u8)); + dtable->new_state_base = malloc(size * sizeof(u16)); + + // Used to determine how many bits need to be read for each state, + // and where the destination range should start + // Needs to be u16 because max value is 2 * max number of symbols, + // which can be larger than a byte can store + u16 state_desc[FSE_MAX_SYMBS]; + + int high_threshold = size; + for (int s = 0; s < num_symbs; s++) { + // Scan for low probability symbols to put at the top + if (norm_freqs[s] == -1) { + dtable->symbols[--high_threshold] = s; + state_desc[s] = 1; + } + } + + // Place the rest in the table + u16 step = (size >> 1) + (size 
>> 3) + 3; + u16 mask = size - 1; + u16 pos = 0; + for (int s = 0; s < num_symbs; s++) { + if (norm_freqs[s] <= 0) { + continue; + } + + state_desc[s] = norm_freqs[s]; + + for (int i = 0; i < norm_freqs[s]; i++) { + dtable->symbols[pos] = s; + do { + pos = (pos + step) & mask; + } while (pos >= + high_threshold); // Make sure we don't occupy a spot taken + // by the low prob symbols + // Note: no other collision checking is necessary as `step` is + // coprime to + // `size`, so the cycle will visit each position exactly once + } + } + if (pos != 0) { + CORRUPTION(); + } + + // Now we can fill baseline and num bits + for (int i = 0; i < size; i++) { + u8 symbol = dtable->symbols[i]; + u16 next_state_desc = state_desc[symbol]++; + // Fills in the table appropriately + // next_state_desc increases by symbol over time, decreasing number of + // bits + dtable->num_bits[i] = (u8)(accuracy_log - log2inf(next_state_desc)); + // baseline increases until the bit threshold is passed, at which point + // it + // resets to 0 + dtable->new_state_base[i] = + ((u16)next_state_desc << dtable->num_bits[i]) - size; + } +} + +static size_t FSE_decode_header(FSE_dtable *dtable, const u8 *src, + size_t src_len, int max_accuracy_log) { + if (max_accuracy_log > FSE_MAX_ACCURACY_LOG) { + ERROR("FSE accuracy too large"); + } + if (src_len < 1) { + INP_SIZE(); + } + + int accuracy_log = 5 + read_bits_LE(src, 4, 0); + if (accuracy_log > max_accuracy_log) { + ERROR("FSE accuracy too large"); + } + + // The +1 facilitates the `-1` probabilities + i32 remaining = (1 << accuracy_log) + 1; + i16 frequencies[FSE_MAX_SYMBS]; + + int symb = 0; + size_t offset = 4; + while (remaining > 1 && symb < FSE_MAX_SYMBS) { + int bits = log2sup(remaining + + 1); // the number of possible values we could read + u16 val = read_bits_LE(src, bits, offset); + offset += bits; + + // try to mask out the lower bits to see if it qualifies for the "small + // value" threshold + u16 lower_mask = ((u16)1 << (bits - 1)) - 
1; + u16 threshold = ((u16)1 << bits) - 1 - remaining; + + if ((val & lower_mask) < threshold) { + offset--; + val = val & lower_mask; + } else if (val > lower_mask) { + val = val - threshold; + } + + i16 proba = (i16)val - 1; + // a value of -1 is possible, and has special meaning + remaining -= proba < 0 ? -proba : proba; + + frequencies[symb] = proba; + symb++; + + // Handle the special probability = 0 case + if (proba == 0) { + // read the next two bits to see how many more 0s + int repeat = read_bits_LE(src, 2, offset); + offset += 2; + + while (1) { + for (int i = 0; i < repeat && symb < FSE_MAX_SYMBS; i++) { + frequencies[symb++] = 0; + } + if (repeat == 3) { + repeat = read_bits_LE(src, 2, offset); + offset += 2; + } else { + break; + } + } + } + } + + if (remaining != 1 || symb >= FSE_MAX_SYMBS) { + CORRUPTION(); + } + + // Initialize the decoding table using the determined weights + FSE_init_dtable(dtable, frequencies, symb, accuracy_log); + + return (offset + 7) / 8; +} + +static void FSE_init_dtable_rle(FSE_dtable *dtable, u8 symb) { + dtable->symbols = malloc(sizeof(u8)); + dtable->num_bits = malloc(sizeof(u8)); + dtable->new_state_base = malloc(sizeof(u16)); + + // This setup will always have a state of 0, always return symbol `symb`, + // and + // never consume any bits + dtable->symbols[0] = symb; + dtable->num_bits[0] = 0; + dtable->new_state_base[0] = 0; + dtable->accuracy_log = 0; +} + +static void FSE_free_dtable(FSE_dtable *dtable) { + free(dtable->symbols); + free(dtable->num_bits); + free(dtable->new_state_base); + memset(dtable, 0, sizeof(FSE_dtable)); +} + +static void FSE_copy_dtable(FSE_dtable *dst, const FSE_dtable *src) { + if (src->accuracy_log == 0) { + memset(dst, 0, sizeof(FSE_dtable)); + return; + } + + size_t size = (size_t)1 << src->accuracy_log; + dst->accuracy_log = src->accuracy_log; + + dst->symbols = malloc(size); + dst->num_bits = malloc(size); + dst->new_state_base = malloc(size * sizeof(u16)); + if (!dst->symbols || 
!dst->num_bits || !dst->new_state_base) { + BAD_ALLOC(); + } + + memcpy(dst->symbols, src->symbols, size); + memcpy(dst->num_bits, src->num_bits, size); + memcpy(dst->new_state_base, src->new_state_base, size * sizeof(u16)); +} +/******* END FSE PRIMITIVES ***************************************************/ + diff --git a/contrib/educational_decoder/zstd_decompress.h b/contrib/educational_decoder/zstd_decompress.h new file mode 100644 index 00000000..3671678b --- /dev/null +++ b/contrib/educational_decoder/zstd_decompress.h @@ -0,0 +1,6 @@ +size_t ZSTD_decompress(void *dst, size_t dst_len, const void *src, + size_t src_len); +size_t ZSTD_decompress_with_dict(void *dst, size_t dst_len, const void *src, + size_t src_len, const void *dict, + size_t dict_len); +