diff --git a/lib/common/zstd_internal.h b/lib/common/zstd_internal.h index f7359646..7ea2ecc4 100644 --- a/lib/common/zstd_internal.h +++ b/lib/common/zstd_internal.h @@ -232,6 +232,8 @@ typedef struct ZSTD_CCtx_params_s { ZSTD_dictMode_e dictMode; U32 dictContentByRef; + /* Multithreading */ + U32 nbThreads; } ZSTD_CCtx_params; diff --git a/lib/compress/zstd_compress.c b/lib/compress/zstd_compress.c index 7802df50..e47ae9a1 100644 --- a/lib/compress/zstd_compress.c +++ b/lib/compress/zstd_compress.c @@ -124,7 +124,7 @@ struct ZSTD_CCtx_s { size_t prefixSize; /* Multi-threading */ - U32 nbThreads; +// U32 nbThreads; ZSTDMT_CCtx* mtctx; }; @@ -434,25 +434,25 @@ size_t ZSTD_CCtx_setParameter(ZSTD_CCtx* cctx, ZSTD_cParameter param, unsigned v #ifndef ZSTD_MULTITHREAD if (value > 1) return ERROR(parameter_unsupported); #endif - if ((value>1) && (cctx->nbThreads != value)) { + if ((value>1) && (cctx->requestedParams.nbThreads != value)) { if (cctx->staticSize) /* MT not compatible with static alloc */ return ERROR(parameter_unsupported); ZSTDMT_freeCCtx(cctx->mtctx); - cctx->nbThreads = 1; + cctx->requestedParams.nbThreads = 1; cctx->mtctx = ZSTDMT_createCCtx_advanced(value, cctx->customMem); if (cctx->mtctx == NULL) return ERROR(memory_allocation); } - cctx->nbThreads = value; + cctx->requestedParams.nbThreads = value; return 0; case ZSTD_p_jobSize: - if (cctx->nbThreads <= 1) return ERROR(parameter_unsupported); + if (cctx->requestedParams.nbThreads <= 1) return ERROR(parameter_unsupported); assert(cctx->mtctx != NULL); return ZSTDMT_setMTCtxParameter(cctx->mtctx, ZSTDMT_p_sectionSize, value); case ZSTD_p_overlapSizeLog: - DEBUGLOG(5, " setting overlap with nbThreads == %u", cctx->nbThreads); - if (cctx->nbThreads <= 1) return ERROR(parameter_unsupported); + DEBUGLOG(5, " setting overlap with nbThreads == %u", cctx->requestedParams.nbThreads); + if (cctx->requestedParams.nbThreads <= 1) return ERROR(parameter_unsupported); assert(cctx->mtctx != NULL); return ZSTDMT_setMTCtxParameter(cctx->mtctx, ZSTDMT_p_overlapSectionLog, value); @@ -549,7 +549,12 @@ size_t ZSTD_CCtxParam_setParameter( return 0; case ZSTD_p_nbThreads : - // TODO + if (value == 0) { return 0; } +#ifndef ZSTD_MULTITHREAD + if (value > 1) return ERROR(parameter_unsupported); +#endif + // Do checks when applying parameters to cctx. + params->nbThreads = value; return 0; case ZSTD_p_jobSize : @@ -4224,8 +4229,8 @@ size_t ZSTD_compress_generic (ZSTD_CCtx* cctx, assert(prefix==NULL || cctx->cdict==NULL); /* only one can be set */ #ifdef ZSTD_MULTITHREAD - if (cctx->nbThreads > 1) { - DEBUGLOG(4, "call ZSTDMT_initCStream_internal as nbThreads=%u", cctx->nbThreads); + if (cctx->requestedParams.nbThreads > 1) { + DEBUGLOG(4, "call ZSTDMT_initCStream_internal as nbThreads=%u", cctx->requestedParams.nbThreads); CHECK_F( ZSTDMT_initCStream_internal(cctx->mtctx, prefix, prefixSize, cctx->cdict, params, cctx->pledgedSrcSizePlusOne-1) ); cctx->streamStage = zcss_load; } else @@ -4236,7 +4241,7 @@ size_t ZSTD_compress_generic (ZSTD_CCtx* cctx, /* compression stage */ #ifdef ZSTD_MULTITHREAD - if (cctx->nbThreads > 1) { + if (cctx->requestedParams.nbThreads > 1) { size_t const flushMin = ZSTDMT_compressStream_generic(cctx->mtctx, output, input, endOp); DEBUGLOG(5, "ZSTDMT_compressStream_generic : %u", (U32)flushMin); if ( ZSTD_isError(flushMin)