diff --git a/lib/compress/zstd_compress.c b/lib/compress/zstd_compress.c index 0c5acba6..bf0a1488 100644 --- a/lib/compress/zstd_compress.c +++ b/lib/compress/zstd_compress.c @@ -3803,9 +3803,9 @@ static void ZSTD_overflowCorrectIfNeeded(ZSTD_matchState_t* ms, void const* ip, void const* iend) { - if (ZSTD_window_needOverflowCorrection(ms->window, iend)) { - U32 const maxDist = (U32)1 << params->cParams.windowLog; - U32 const cycleLog = ZSTD_cycleLog(params->cParams.chainLog, params->cParams.strategy); + U32 const cycleLog = ZSTD_cycleLog(params->cParams.chainLog, params->cParams.strategy); + U32 const maxDist = (U32)1 << params->cParams.windowLog; + if (ZSTD_window_needOverflowCorrection(ms->window, cycleLog, maxDist, ms->loadedDictEnd, ip, iend)) { U32 const correction = ZSTD_window_correctOverflow(&ms->window, cycleLog, maxDist, ip); ZSTD_STATIC_ASSERT(ZSTD_CHAINLOG_MAX <= 30); ZSTD_STATIC_ASSERT(ZSTD_WINDOWLOG_MAX_32 <= 30); diff --git a/lib/compress/zstd_compress_internal.h b/lib/compress/zstd_compress_internal.h index 3c488c61..3e6c576d 100644 --- a/lib/compress/zstd_compress_internal.h +++ b/lib/compress/zstd_compress_internal.h @@ -895,15 +895,55 @@ MEM_STATIC ZSTD_dictMode_e ZSTD_matchState_dictMode(const ZSTD_matchState_t *ms) ZSTD_noDict; } +/* Defining this macro to non-zero tells zstd to run the overflow correction + * code much more frequently. This is very inefficient, and should only be + * used for tests and fuzzers. + */ +#ifndef ZSTD_WINDOW_OVERFLOW_CORRECT_FREQUENTLY +# ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION +# define ZSTD_WINDOW_OVERFLOW_CORRECT_FREQUENTLY 1 +# else +# define ZSTD_WINDOW_OVERFLOW_CORRECT_FREQUENTLY 0 +# endif +#endif + +/** + * ZSTD_window_canOverflowCorrect(): + * Returns non-zero if the indices are large enough for overflow correction + * to work correctly without impacting compression ratio. + */ +MEM_STATIC U32 ZSTD_window_canOverflowCorrect(ZSTD_window_t const window, + U32 cycleLog, + U32 maxDist, + U32 loadedDictEnd, + void const* src) +{ + U32 const cycleSize = 1u << cycleLog; + U32 const curr = (U32)((BYTE const*)src - window.base); + U32 const minIndexToOverflowCorrect = cycleSize + MAX(maxDist, cycleSize); + U32 const indexLargeEnough = curr > minIndexToOverflowCorrect; + U32 const dictionaryInvalidated = curr > maxDist + loadedDictEnd; + return indexLargeEnough && dictionaryInvalidated; +} + /** * ZSTD_window_needOverflowCorrection(): * Returns non-zero if the indices are getting too large and need overflow * protection. */ MEM_STATIC U32 ZSTD_window_needOverflowCorrection(ZSTD_window_t const window, + U32 cycleLog, + U32 maxDist, + U32 loadedDictEnd, + void const* src, void const* srcEnd) { U32 const curr = (U32)((BYTE const*)srcEnd - window.base); + if (ZSTD_WINDOW_OVERFLOW_CORRECT_FREQUENTLY) { + if (ZSTD_window_canOverflowCorrect(window, cycleLog, maxDist, loadedDictEnd, src)) { + return 1; + } + } return curr > ZSTD_CURRENT_MAX; } @@ -915,7 +955,6 @@ MEM_STATIC U32 ZSTD_window_needOverflowCorrection(ZSTD_window_t const window, * * The least significant cycleLog bits of the indices must remain the same, * which may be 0. Every index up to maxDist in the past must be valid. - * NOTE: (maxDist & cycleMask) must be zero. */ MEM_STATIC U32 ZSTD_window_correctOverflow(ZSTD_window_t* window, U32 cycleLog, U32 maxDist, void const* src) @@ -939,17 +978,25 @@ MEM_STATIC U32 ZSTD_window_correctOverflow(ZSTD_window_t* window, U32 cycleLog, * 3. (cctx->lowLimit + 1< 3<<29 + 1<base); U32 const currentCycle0 = curr & cycleMask; /* Exclude zero so that newCurrent - maxDist >= 1. */ - U32 const currentCycle1 = currentCycle0 == 0 ? (1U << cycleLog) : currentCycle0; - U32 const newCurrent = currentCycle1 + maxDist; + U32 const currentCycle1 = currentCycle0 == 0 ? cycleSize : currentCycle0; + U32 const newCurrent = currentCycle1 + MAX(maxDist, cycleSize); U32 const correction = curr - newCurrent; - assert((maxDist & cycleMask) == 0); + /* maxDist must be a power of two so that: + * (newCurrent & cycleMask) == (curr & cycleMask) + * This is required to not corrupt the chains / binary tree. + */ + assert((maxDist & (maxDist - 1)) == 0); + assert((curr & cycleMask) == (newCurrent & cycleMask)); assert(curr > newCurrent); - /* Loose bound, should be around 1<<29 (see above) */ - assert(correction > 1<<28); + if (!ZSTD_WINDOW_OVERFLOW_CORRECT_FREQUENTLY) { + /* Loose bound, should be around 1<<29 (see above) */ + assert(correction > 1<<28); + } window->base += correction; window->dictBase += correction; diff --git a/lib/compress/zstd_ldm.c b/lib/compress/zstd_ldm.c index 7dea97aa..60245ac2 100644 --- a/lib/compress/zstd_ldm.c +++ b/lib/compress/zstd_ldm.c @@ -500,7 +500,7 @@ size_t ZSTD_ldm_generateSequences( assert(chunkStart < iend); /* 1. Perform overflow correction if necessary. */ - if (ZSTD_window_needOverflowCorrection(ldmState->window, chunkEnd)) { + if (ZSTD_window_needOverflowCorrection(ldmState->window, 0, maxDist, ldmState->loadedDictEnd, chunkStart, chunkEnd)) { U32 const ldmHSize = 1U << params->hashLog; U32 const correction = ZSTD_window_correctOverflow( &ldmState->window, /* cycleLog */ 0, maxDist, chunkStart); diff --git a/tests/Makefile b/tests/Makefile index d80b4c4e..c4d1472d 100644 --- a/tests/Makefile +++ b/tests/Makefile @@ -28,7 +28,8 @@ DEBUGLEVEL ?= 1 export DEBUGLEVEL # transmit value to sub-makefiles DEBUGFLAGS = -g -DDEBUGLEVEL=$(DEBUGLEVEL) CPPFLAGS += -I$(ZSTDDIR) -I$(ZSTDDIR)/common -I$(ZSTDDIR)/compress \ - -I$(ZSTDDIR)/dictBuilder -I$(ZSTDDIR)/deprecated -I$(PRGDIR) + -I$(ZSTDDIR)/dictBuilder -I$(ZSTDDIR)/deprecated -I$(PRGDIR) \ + -DZSTD_WINDOW_OVERFLOW_CORRECT_FREQUENTLY=1 ifeq ($(OS),Windows_NT) # MinGW assumed CPPFLAGS += -D__USE_MINGW_ANSI_STDIO # compatibility with %zu formatting endif diff --git a/tests/zstreamtest.c b/tests/zstreamtest.c index 9cbc82a8..f2b16db2 100644 --- a/tests/zstreamtest.c +++ b/tests/zstreamtest.c @@ -672,7 +672,7 @@ static int basicUnitTests(U32 seed, double compressibility) cSize = ZSTD_compress(compressedBuffer, compressedBufferSize, CNBuffer, CNBufferSize, 1); CHECK_Z(cSize); { ZSTD_DCtx* dctx = ZSTD_createDCtx(); - size_t const dctxSize0 = ZSTD_sizeof_DCtx(dctx); + size_t const dctxSize0 = ZSTD_sizeof_DCtx(dctx); size_t dctxSize1; CHECK_Z(ZSTD_DCtx_setParameter(dctx, ZSTD_d_stableOutBuffer, 1)); @@ -735,7 +735,7 @@ static int basicUnitTests(U32 seed, double compressibility) CHECK(ZSTD_getErrorCode(r) != ZSTD_error_dstBuffer_wrong, "Must error but got %s", ZSTD_getErrorName(r)); } DISPLAYLEVEL(3, "OK \n"); - + DISPLAYLEVEL(3, "test%3i : ZSTD_decompressStream() buffered output : ", testNb++); ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only); CHECK_Z(ZSTD_DCtx_setParameter(dctx, ZSTD_d_stableOutBuffer, 0));