Streaming decompression can detect incorrect header ID sooner
Streaming decompression used to wait for a minimum of 5 bytes before attempting to decode. As a result, when only a few bytes (<5) were provided and those bytes were incorrect, no error was reported: the streaming API simply requested more data, waiting until at least 5 bytes were available. This PR makes it possible to detect an incorrect frame ID as soon as the first byte is provided. Fixes #3169.
This commit is contained in:
parent
f6ef14329f
commit
91aeade735
@ -79,11 +79,11 @@
|
||||
*************************************/
|
||||
|
||||
#define DDICT_HASHSET_MAX_LOAD_FACTOR_COUNT_MULT 4
|
||||
#define DDICT_HASHSET_MAX_LOAD_FACTOR_SIZE_MULT 3 /* These two constants represent SIZE_MULT/COUNT_MULT load factor without using a float.
|
||||
* Currently, that means a 0.75 load factor.
|
||||
* So, if count * COUNT_MULT / size * SIZE_MULT != 0, then we've exceeded
|
||||
* the load factor of the ddict hash set.
|
||||
*/
|
||||
#define DDICT_HASHSET_MAX_LOAD_FACTOR_SIZE_MULT 3 /* These two constants represent SIZE_MULT/COUNT_MULT load factor without using a float.
|
||||
* Currently, that means a 0.75 load factor.
|
||||
* So, if count * COUNT_MULT / size * SIZE_MULT != 0, then we've exceeded
|
||||
* the load factor of the ddict hash set.
|
||||
*/
|
||||
|
||||
#define DDICT_HASHSET_TABLE_BASE_SIZE 64
|
||||
#define DDICT_HASHSET_RESIZE_FACTOR 2
|
||||
@ -439,16 +439,40 @@ size_t ZSTD_frameHeaderSize(const void* src, size_t srcSize)
|
||||
* note : only works for formats ZSTD_f_zstd1 and ZSTD_f_zstd1_magicless
|
||||
* @return : 0, `zfhPtr` is correctly filled,
|
||||
* >0, `srcSize` is too small, value is wanted `srcSize` amount,
|
||||
* or an error code, which can be tested using ZSTD_isError() */
|
||||
** or an error code, which can be tested using ZSTD_isError() */
|
||||
size_t ZSTD_getFrameHeader_advanced(ZSTD_frameHeader* zfhPtr, const void* src, size_t srcSize, ZSTD_format_e format)
|
||||
{
|
||||
const BYTE* ip = (const BYTE*)src;
|
||||
size_t const minInputSize = ZSTD_startingInputLength(format);
|
||||
|
||||
ZSTD_memset(zfhPtr, 0, sizeof(*zfhPtr)); /* not strictly necessary, but static analyzer do not understand that zfhPtr is only going to be read only if return value is zero, since they are 2 different signals */
|
||||
if (srcSize < minInputSize) return minInputSize;
|
||||
RETURN_ERROR_IF(src==NULL, GENERIC, "invalid parameter");
|
||||
DEBUGLOG(5, "ZSTD_getFrameHeader_advanced: minInputSize = %zu, srcSize = %zu", minInputSize, srcSize);
|
||||
|
||||
if (srcSize > 0) {
|
||||
/* note : technically could be considered an assert(), since it's an invalid entry */
|
||||
RETURN_ERROR_IF(src==NULL, GENERIC, "invalid parameter : src==NULL, but srcSize>0");
|
||||
}
|
||||
if (srcSize < minInputSize) {
|
||||
if (srcSize > 0 && format != ZSTD_f_zstd1_magicless) {
|
||||
/* when receiving less than @minInputSize bytes,
|
||||
* control these bytes at least correspond to a supported magic number
|
||||
* in order to error out early if they don't.
|
||||
**/
|
||||
size_t const toCopy = MIN(4, srcSize);
|
||||
unsigned char hbuf[4]; MEM_writeLE32(hbuf, ZSTD_MAGICNUMBER);
|
||||
assert(src != NULL);
|
||||
ZSTD_memcpy(hbuf, src, toCopy);
|
||||
if ( MEM_readLE32(hbuf) != ZSTD_MAGICNUMBER ) {
|
||||
/* not a zstd frame : let's check if it's a skippable frame */
|
||||
MEM_writeLE32(hbuf, ZSTD_MAGIC_SKIPPABLE_START);
|
||||
ZSTD_memcpy(hbuf, src, toCopy);
|
||||
if ((MEM_readLE32(hbuf) & ZSTD_MAGIC_SKIPPABLE_MASK) != ZSTD_MAGIC_SKIPPABLE_START) {
|
||||
RETURN_ERROR(prefix_unknown,
|
||||
"first bytes don't correspond to any supported magic number");
|
||||
} } }
|
||||
return minInputSize;
|
||||
}
|
||||
|
||||
ZSTD_memset(zfhPtr, 0, sizeof(*zfhPtr)); /* not strictly necessary, but static analyzers may not understand that zfhPtr will be read only if return value is zero, since they are 2 different signals */
|
||||
if ( (format != ZSTD_f_zstd1_magicless)
|
||||
&& (MEM_readLE32(src) != ZSTD_MAGICNUMBER) ) {
|
||||
if ((MEM_readLE32(src) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) {
|
||||
@ -1981,7 +2005,6 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
|
||||
if (zds->refMultipleDDicts && zds->ddictSet) {
|
||||
ZSTD_DCtx_selectFrameDDict(zds);
|
||||
}
|
||||
DEBUGLOG(5, "header size : %u", (U32)hSize);
|
||||
if (ZSTD_isError(hSize)) {
|
||||
#if defined(ZSTD_LEGACY_SUPPORT) && (ZSTD_LEGACY_SUPPORT>=1)
|
||||
U32 const legacyVersion = ZSTD_isLegacy(istart, iend-istart);
|
||||
@ -2013,6 +2036,11 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
|
||||
zds->lhSize += remainingInput;
|
||||
}
|
||||
input->pos = input->size;
|
||||
/* check first few bytes */
|
||||
FORWARD_IF_ERROR(
|
||||
ZSTD_getFrameHeader_advanced(&zds->fParams, zds->headerBuffer, zds->lhSize, zds->format),
|
||||
"First few bytes detected incorrect" );
|
||||
/* return hint input size */
|
||||
return (MAX((size_t)ZSTD_FRAMEHEADERSIZE_MIN(zds->format), hSize) - zds->lhSize) + ZSTD_blockHeaderSize; /* remaining header bytes + next block header */
|
||||
}
|
||||
assert(ip != NULL);
|
||||
|
@ -424,6 +424,15 @@ static int basicUnitTests(U32 seed, double compressibility)
|
||||
} }
|
||||
DISPLAYLEVEL(3, "OK \n");
|
||||
|
||||
/* check decompression fails early if first bytes are wrong */
|
||||
DISPLAYLEVEL(3, "test%3i : early decompression error if first bytes are incorrect : ", testNb++);
|
||||
{ const char buf[3] = { 0 }; /* too short, not enough to start decoding header */
|
||||
ZSTD_inBuffer inb = { buf, sizeof(buf), 0 };
|
||||
size_t const remaining = ZSTD_decompressStream(zd, &outBuff, &inb);
|
||||
if (!ZSTD_isError(remaining)) goto _output_error; /* should have errored out immediately (note: this does not test the exact error code) */
|
||||
}
|
||||
DISPLAYLEVEL(3, "OK \n");
|
||||
|
||||
/* context size functions */
|
||||
DISPLAYLEVEL(3, "test%3i : estimate DStream size : ", testNb++);
|
||||
{ ZSTD_frameHeader fhi;
|
||||
|
Loading…
x
Reference in New Issue
Block a user