Add error forwarding to loadCEntropy(), make check for dictSize >= 8 from bad merge

This commit is contained in:
Sen Huang 2019-11-07 13:58:35 -05:00
parent 4a61aaf368
commit 6ce335371b
3 changed files with 23 additions and 16 deletions

View File

@ -2775,6 +2775,10 @@ size_t ZSTD_loadCEntropy(ZSTD_compressedBlockState_t* bs, void* workspace,
const BYTE* dictPtr = (const BYTE*)dict + 8; /* skip magic num and dict ID */ const BYTE* dictPtr = (const BYTE*)dict + 8; /* skip magic num and dict ID */
const BYTE* const dictEnd = dictPtr + dictSize; const BYTE* const dictEnd = dictPtr + dictSize;
ZSTD_STATIC_ASSERT(HUF_WORKSPACE_SIZE >= (1<<MAX(MLFSELog,LLFSELog)));
assert(dictSize >= 8);
assert(MEM_readLE32(dictPtr) == ZSTD_MAGIC_DICTIONARY);
{ unsigned maxSymbolValue = 255; { unsigned maxSymbolValue = 255;
size_t const hufHeaderSize = HUF_readCTable((HUF_CElt*)bs->entropy.huf.CTable, &maxSymbolValue, dictPtr, dictEnd-dictPtr); size_t const hufHeaderSize = HUF_readCTable((HUF_CElt*)bs->entropy.huf.CTable, &maxSymbolValue, dictPtr, dictEnd-dictPtr);
RETURN_ERROR_IF(HUF_isError(hufHeaderSize), dictionary_corrupted); RETURN_ERROR_IF(HUF_isError(hufHeaderSize), dictionary_corrupted);
@ -2831,7 +2835,7 @@ size_t ZSTD_loadCEntropy(ZSTD_compressedBlockState_t* bs, void* workspace,
bs->rep[1] = MEM_readLE32(dictPtr+4); bs->rep[1] = MEM_readLE32(dictPtr+4);
bs->rep[2] = MEM_readLE32(dictPtr+8); bs->rep[2] = MEM_readLE32(dictPtr+8);
dictPtr += 12; dictPtr += 12;
DEBUGLOG(1, "size %u)", (unsigned)(dictPtr - (const BYTE*)dict));
return dictPtr - (const BYTE*)dict; return dictPtr - (const BYTE*)dict;
} }
@ -2859,17 +2863,10 @@ static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs,
size_t dictID; size_t dictID;
size_t eSize; size_t eSize;
ZSTD_STATIC_ASSERT(HUF_WORKSPACE_SIZE >= (1<<MAX(MLFSELog,LLFSELog))); dictID = params->fParams.noDictIDFlag ? 0 : MEM_readLE32(dictPtr + 4 /* skip magic number */ );
assert(dictSize > 8);
assert(MEM_readLE32(dictPtr) == ZSTD_MAGIC_DICTIONARY);
eSize = ZSTD_loadCEntropy(bs, workspace, offcodeNCount, &offcodeMaxValue, dict, dictSize); eSize = ZSTD_loadCEntropy(bs, workspace, offcodeNCount, &offcodeMaxValue, dict, dictSize);
FORWARD_IF_ERROR(eSize);
dictPtr += 4; /* skip magic number */ dictPtr += eSize;
dictID = params->fParams.noDictIDFlag ? 0 : MEM_readLE32(dictPtr);
dictPtr += 4;
dictPtr += eSize - 8;
{ size_t const dictContentSize = (size_t)(dictEnd - dictPtr); { size_t const dictContentSize = (size_t)(dictEnd - dictPtr);
U32 offcodeMax = MaxOff; U32 offcodeMax = MaxOff;

View File

@ -930,6 +930,7 @@ MEM_STATIC void ZSTD_debugTable(const U32* table, U32 max)
#if defined (__cplusplus) #if defined (__cplusplus)
} }
#endif #endif
/* =============================================================== /* ===============================================================
* Shared internal declarations * Shared internal declarations
* These prototypes may be called from sources not in lib/compress * These prototypes may be called from sources not in lib/compress
@ -937,7 +938,9 @@ MEM_STATIC void ZSTD_debugTable(const U32* table, U32 max)
/* ZSTD_loadCEntropy() : /* ZSTD_loadCEntropy() :
* dict : must point at beginning of a valid zstd dictionary. * dict : must point at beginning of a valid zstd dictionary.
* return : size of dictionary header (size of magic number + dict ID + entropy tables) */ * return : size of dictionary header (size of magic number + dict ID + entropy tables)
* assumptions : magic number supposed already checked
* and dictSize >= 8 */
size_t ZSTD_loadCEntropy(ZSTD_compressedBlockState_t* bs, void* workspace, size_t ZSTD_loadCEntropy(ZSTD_compressedBlockState_t* bs, void* workspace,
short* offcodeNCount, unsigned* offcodeMaxValue, short* offcodeNCount, unsigned* offcodeMaxValue,
const void* const dict, size_t dictSize); const void* const dict, size_t dictSize);

View File

@ -106,15 +106,22 @@ size_t ZDICT_getDictHeaderSize(const void* dictBuffer, size_t dictSize)
{ size_t headerSize; { size_t headerSize;
unsigned offcodeMaxValue = MaxOff; unsigned offcodeMaxValue = MaxOff;
ZSTD_compressedBlockState_t* dummyBs = (ZSTD_compressedBlockState_t*)malloc(sizeof(ZSTD_compressedBlockState_t)); ZSTD_compressedBlockState_t* bs = (ZSTD_compressedBlockState_t*)malloc(sizeof(ZSTD_compressedBlockState_t));
U32* wksp = (U32*)malloc(HUF_WORKSPACE_SIZE); U32* wksp = (U32*)malloc(HUF_WORKSPACE_SIZE);
short* offcodeNCount = (short*)malloc((MaxOff+1)*sizeof(short)); short* offcodeNCount = (short*)malloc((MaxOff+1)*sizeof(short));
if (!dummyBs || !wksp || !offcodeNCount) { if (!bs || !wksp || !offcodeNCount) {
return ERROR(memory_allocation); return ERROR(memory_allocation);
} }
headerSize = ZSTD_loadCEntropy(dummyBs, wksp, offcodeNCount, &offcodeMaxValue, dictBuffer, dictSize); int i;
free(dummyBs); for (i = 0; i < ZSTD_REP_NUM; ++i)
bs->rep[i] = repStartValue[i];
bs->entropy.huf.repeatMode = HUF_repeat_none;
bs->entropy.fse.offcode_repeatMode = FSE_repeat_none;
bs->entropy.fse.matchlength_repeatMode = FSE_repeat_none;
bs->entropy.fse.litlength_repeatMode = FSE_repeat_none;
headerSize = ZSTD_loadCEntropy(bs, wksp, offcodeNCount, &offcodeMaxValue, dictBuffer, dictSize);
free(bs);
free(wksp); free(wksp);
free(offcodeNCount); free(offcodeNCount);
return headerSize; return headerSize;