diff --git a/lib/compress/zstd_compress.c b/lib/compress/zstd_compress.c index ebc9c185..0a2d589d 100644 --- a/lib/compress/zstd_compress.c +++ b/lib/compress/zstd_compress.c @@ -2851,7 +2851,9 @@ size_t ZSTD_compress_generic (ZSTD_CCtx* cctx, if ((cctx->pledgedSrcSizePlusOne-1) <= ZSTDMT_JOBSIZE_MIN) params.nbThreads = 1; /* do not invoke multi-threading when src size is too small */ if (params.nbThreads > 1) { - if (cctx->mtctx == NULL || cctx->appliedParams.nbThreads != params.nbThreads) { + if (cctx->mtctx == NULL || (params.nbThreads != ZSTDMT_getNbThreads(cctx->mtctx))) { + DEBUGLOG(4, "ZSTD_compress_generic: creating new mtctx for nbThreads=%u (previous: %u)", + params.nbThreads, ZSTDMT_getNbThreads(cctx->mtctx)); ZSTDMT_freeCCtx(cctx->mtctx); cctx->mtctx = ZSTDMT_createCCtx_advanced(params.nbThreads, cctx->customMem); if (cctx->mtctx == NULL) return ERROR(memory_allocation); diff --git a/lib/compress/zstdmt_compress.c b/lib/compress/zstdmt_compress.c index a5e996d3..659992f3 100644 --- a/lib/compress/zstdmt_compress.c +++ b/lib/compress/zstdmt_compress.c @@ -459,6 +459,15 @@ size_t ZSTDMT_CCtxParam_setNbThreads(ZSTD_CCtx_params* params, unsigned nbThread return nbThreads; } +/* ZSTDMT_getNbThreads(): + * @return nb threads currently active in mtctx. + * mtctx must be valid */ +size_t ZSTDMT_getNbThreads(const ZSTDMT_CCtx* mtctx) +{ + assert(mtctx != NULL); + return mtctx->params.nbThreads; +} + ZSTDMT_CCtx* ZSTDMT_createCCtx_advanced(unsigned nbThreads, ZSTD_customMem cMem) { ZSTDMT_CCtx* mtctx; diff --git a/lib/compress/zstdmt_compress.h b/lib/compress/zstdmt_compress.h index 4209cf3c..d12f0adb 100644 --- a/lib/compress/zstdmt_compress.h +++ b/lib/compress/zstdmt_compress.h @@ -114,9 +114,14 @@ size_t ZSTDMT_CCtxParam_setMTCtxParameter(ZSTD_CCtx_params* params, ZSTDMT_param /* ZSTDMT_CCtxParam_setNbThreads() * Set nbThreads, and clamp it correctly, - * but also reset jobSize and overlapLog */ + * also reset jobSize and overlapLog */ size_t ZSTDMT_CCtxParam_setNbThreads(ZSTD_CCtx_params* params, unsigned nbThreads); +/* ZSTDMT_getNbThreads(): + * @return nb threads currently active in mtctx. + * mtctx must be valid */ +size_t ZSTDMT_getNbThreads(const ZSTDMT_CCtx* mtctx); + /*! ZSTDMT_initCStream_internal() : * Private use only. Init streaming operation. * expects params to be valid.