diff --git a/lib/zstd_decompress.c b/lib/zstd_decompress.c index 6c5c293b..7de1f29d 100644 --- a/lib/zstd_decompress.c +++ b/lib/zstd_decompress.c @@ -128,7 +128,7 @@ struct ZSTD_DCtx_s ZSTD_frameParams fParams; blockType_t bType; /* used in ZSTD_decompressContinue(), to transfer blockType between header decoding and block decoding stages */ ZSTD_dStage stage; - U32 flagStaticTables; + U32 flagRepeatTable; const BYTE* litPtr; size_t litBufSize; size_t litSize; @@ -147,7 +147,7 @@ size_t ZSTD_decompressBegin(ZSTD_DCtx* dctx) dctx->vBase = NULL; dctx->dictEnd = NULL; dctx->hufTableX4[0] = HufLog; - dctx->flagStaticTables = 0; + dctx->flagRepeatTable = 0; return 0; } @@ -419,7 +419,7 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, U32 lhSize = ((istart[0]) >> 4) & 3; if (lhSize != 1) /* only case supported for now : small litSize, single stream */ return ERROR(corruption_detected); - if (!dctx->flagStaticTables) + if (!dctx->flagRepeatTable) return ERROR(dictionary_corrupted); /* 2 - 2 - 10 - 10 */ @@ -503,7 +503,7 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, */ FORCE_INLINE size_t ZSTD_buildSeqTable(FSE_DTable* DTable, U32 type, U32 max, U32 maxLog, const void* src, size_t srcSize, - const S16* defaultNorm, U32 defaultLog) + const S16* defaultNorm, U32 defaultLog, U32 flagRepeatTable) { switch(type) { @@ -516,6 +516,7 @@ FORCE_INLINE size_t ZSTD_buildSeqTable(FSE_DTable* DTable, U32 type, U32 max, U3 FSE_buildDTable(DTable, defaultNorm, max, defaultLog); return 0; case FSE_ENCODING_STATIC: + if (!flagRepeatTable) return ERROR(corruption_detected); return 0; default : /* impossible */ case FSE_ENCODING_DYNAMIC : @@ -531,7 +532,7 @@ FORCE_INLINE size_t ZSTD_buildSeqTable(FSE_DTable* DTable, U32 type, U32 max, U3 size_t ZSTD_decodeSeqHeaders(int* nbSeqPtr, - FSE_DTable* DTableLL, FSE_DTable* DTableML, FSE_DTable* DTableOffb, + FSE_DTable* DTableLL, FSE_DTable* DTableML, FSE_DTable* DTableOffb, U32 flagRepeatTable, const void* src, size_t srcSize) { const BYTE* const istart = (const BYTE* const)src; @@ -563,15 +564,15 @@ size_t ZSTD_decodeSeqHeaders(int* nbSeqPtr, if (ip > iend-3) return ERROR(srcSize_wrong); /* min : all 3 are "raw", hence no header, but at least xxLog bits per type */ /* Build DTables */ - { size_t const bhSize = ZSTD_buildSeqTable(DTableLL, LLtype, MaxLL, LLFSELog, ip, iend-ip, LL_defaultNorm, LL_defaultNormLog); + { size_t const bhSize = ZSTD_buildSeqTable(DTableLL, LLtype, MaxLL, LLFSELog, ip, iend-ip, LL_defaultNorm, LL_defaultNormLog, flagRepeatTable); if (ZSTD_isError(bhSize)) return ERROR(corruption_detected); ip += bhSize; } - { size_t const bhSize = ZSTD_buildSeqTable(DTableOffb, Offtype, MaxOff, OffFSELog, ip, iend-ip, OF_defaultNorm, OF_defaultNormLog); + { size_t const bhSize = ZSTD_buildSeqTable(DTableOffb, Offtype, MaxOff, OffFSELog, ip, iend-ip, OF_defaultNorm, OF_defaultNormLog, flagRepeatTable); if (ZSTD_isError(bhSize)) return ERROR(corruption_detected); ip += bhSize; } - { size_t const bhSize = ZSTD_buildSeqTable(DTableML, MLtype, MaxML, MLFSELog, ip, iend-ip, ML_defaultNorm, ML_defaultNormLog); + { size_t const bhSize = ZSTD_buildSeqTable(DTableML, MLtype, MaxML, MLFSELog, ip, iend-ip, ML_defaultNorm, ML_defaultNormLog, flagRepeatTable); if (ZSTD_isError(bhSize)) return ERROR(corruption_detected); ip += bhSize; } } @@ -765,9 +766,10 @@ static size_t ZSTD_decompressSequences( int nbSeq; /* Build Decoding Tables */ - { size_t const seqHSize = ZSTD_decodeSeqHeaders(&nbSeq, DTableLL, DTableML, DTableOffb, ip, seqSize); + { size_t const seqHSize = ZSTD_decodeSeqHeaders(&nbSeq, DTableLL, DTableML, DTableOffb, dctx->flagRepeatTable, ip, seqSize); if (ZSTD_isError(seqHSize)) return seqHSize; ip += seqHSize; + dctx->flagRepeatTable = 1; } /* Regen sequences */ @@ -1089,7 +1091,7 @@ static size_t ZSTD_loadEntropy(ZSTD_DCtx* dctx, const void* dict, size_t dictSiz errorCode = FSE_buildDTable(dctx->LLTable, litlengthNCount, litlengthMaxValue, litlengthLog); if (FSE_isError(errorCode)) return ERROR(dictionary_corrupted); - dctx->flagStaticTables = 1; + dctx->flagRepeatTable = 1; return hSize + offcodeHeaderSize + matchlengthHeaderSize + litlengthHeaderSize; }