diff --git a/lib/compress/zstd_compress.c b/lib/compress/zstd_compress.c index c0ccf83c..a4a3ee08 100644 --- a/lib/compress/zstd_compress.c +++ b/lib/compress/zstd_compress.c @@ -3034,17 +3034,18 @@ size_t ZSTD_resetCStream(ZSTD_CStream* zcs, unsigned long long pledgedSrcSize) { zcs->params.fParams.contentSizeFlag = (pledgedSrcSize > 0); - return ZSTD_resetCStream_internal(zcs, pledgedSrcSize); } /* ZSTD_initCStream_internal() : - * params are supposed validated at this stage */ -static size_t ZSTD_initCStream_internal(ZSTD_CStream* zcs, - const void* dict, size_t dictSize, - ZSTD_parameters params, unsigned long long pledgedSrcSize) + * params are supposed validated at this stage + * and zcs->cdict is supposed to be correct */ +static size_t ZSTD_initCStream_stage2(ZSTD_CStream* zcs, + const ZSTD_parameters params, + unsigned long long pledgedSrcSize) { assert(!ZSTD_isError(ZSTD_checkCParams(params.cParams))); + /* allocate buffers */ { size_t const neededInBuffSize = (size_t)1 << params.cParams.windowLog; if (zcs->inBuffSize < neededInBuffSize) { @@ -3065,19 +3066,39 @@ static size_t ZSTD_initCStream_internal(ZSTD_CStream* zcs, zcs->outBuffSize = outBuffSize; } - if (dict && dictSize >= 8) { - ZSTD_freeCDict(zcs->cdictLocal); - zcs->cdictLocal = ZSTD_createCDict_advanced(dict, dictSize, 0, params, zcs->customMem); - if (zcs->cdictLocal == NULL) return ERROR(memory_allocation); - zcs->cdict = zcs->cdictLocal; - } else zcs->cdict = NULL; - zcs->checksum = params.fParams.checksumFlag > 0; zcs->params = params; return ZSTD_resetCStream_internal(zcs, pledgedSrcSize); } +/* note : cdict must outlive compression session */ +size_t ZSTD_initCStream_usingCDict(ZSTD_CStream* zcs, const ZSTD_CDict* cdict) +{ + if (!cdict) return ERROR(GENERIC); /* cannot handle NULL cdict (does not know what to do) */ + { ZSTD_parameters const params = ZSTD_getParamsFromCDict(cdict); + zcs->cdict = cdict; + return ZSTD_initCStream_stage2(zcs, params, 0); + } +} + +static size_t ZSTD_initCStream_internal(ZSTD_CStream* zcs, + const void* dict, size_t dictSize, + ZSTD_parameters params, unsigned long long pledgedSrcSize) +{ + assert(!ZSTD_isError(ZSTD_checkCParams(params.cParams))); + zcs->cdict = NULL; + + if (dict && dictSize >= 8) { + ZSTD_freeCDict(zcs->cdictLocal); + zcs->cdictLocal = ZSTD_createCDict_advanced(dict, dictSize, 0 /* copy */, params, zcs->customMem); + if (zcs->cdictLocal == NULL) return ERROR(memory_allocation); + zcs->cdict = zcs->cdictLocal; + } + + return ZSTD_initCStream_stage2(zcs, params, pledgedSrcSize); +} + size_t ZSTD_initCStream_advanced(ZSTD_CStream* zcs, const void* dict, size_t dictSize, ZSTD_parameters params, unsigned long long pledgedSrcSize) @@ -3086,16 +3107,6 @@ size_t ZSTD_initCStream_advanced(ZSTD_CStream* zcs, return ZSTD_initCStream_internal(zcs, dict, dictSize, params, pledgedSrcSize); } -/* note : cdict must outlive compression session */ -size_t ZSTD_initCStream_usingCDict(ZSTD_CStream* zcs, const ZSTD_CDict* cdict) -{ - ZSTD_parameters const params = ZSTD_getParamsFromCDict(cdict); - size_t const initError = ZSTD_initCStream_internal(zcs, NULL, 0, params, 0); - zcs->cdict = cdict; - if (ZSTD_isError(initError)) return initError; - return ZSTD_resetCStream_internal(zcs, 0); -} - size_t ZSTD_initCStream_usingDict(ZSTD_CStream* zcs, const void* dict, size_t dictSize, int compressionLevel) { ZSTD_parameters const params = ZSTD_getParams(compressionLevel, 0, dictSize); @@ -3105,13 +3116,14 @@ size_t ZSTD_initCStream_usingDict(ZSTD_CStream* zcs, const void* dict, size_t di size_t ZSTD_initCStream_srcSize(ZSTD_CStream* zcs, int compressionLevel, unsigned long long pledgedSrcSize) { ZSTD_parameters params = ZSTD_getParams(compressionLevel, pledgedSrcSize, 0); - if (pledgedSrcSize) params.fParams.contentSizeFlag = 1; + params.fParams.contentSizeFlag = (pledgedSrcSize>0); return ZSTD_initCStream_internal(zcs, NULL, 0, params, pledgedSrcSize); } size_t ZSTD_initCStream(ZSTD_CStream* zcs, int compressionLevel) { - return ZSTD_initCStream_usingDict(zcs, NULL, 0, compressionLevel); + ZSTD_parameters const params = ZSTD_getParams(compressionLevel, 0, 0); + return ZSTD_initCStream_internal(zcs, NULL, 0, params, 0); } size_t ZSTD_sizeof_CStream(const ZSTD_CStream* zcs)