Cache literal prices in the optimal parser (ZSTD_getLiteralPrice).

Review notes:
- Line structure and indentation were reconstructed from a whitespace-collapsed
  copy of this patch; apply with `git apply --ignore-whitespace` if context
  lines do not match exactly.
- The `(2<<Litbits)` context line was restored from a garbled span; confirm it
  against the target tree.
- The cache-hit branch now also requires litLength >= cachedLitLength, so the
  `index` blob ids below no longer describe the resulting files.

diff --git a/lib/zstd_internal.h b/lib/zstd_internal.h
index 31a99cbd..9c888e77 100644
--- a/lib/zstd_internal.h
+++ b/lib/zstd_internal.h
@@ -242,6 +242,9 @@ typedef struct {
     U32 log2litSum;
     U32 log2offCodeSum;
     U32 factor;
+    U32 cachedPrice;                /* price of the first cachedLitLength bytes at cachedLiterals */
+    U32 cachedLitLength;            /* number of literals covered by cachedPrice */
+    const BYTE* cachedLiterals;     /* start of the cached literal run; NULL when nothing is cached */
     ZSTD_stats_t stats;
 } seqStore_t;
 
diff --git a/lib/zstd_opt.h b/lib/zstd_opt.h
index 4f6e4c8e..efda97e5 100644
--- a/lib/zstd_opt.h
+++ b/lib/zstd_opt.h
@@ -54,6 +54,8 @@ MEM_STATIC void ZSTD_rescaleFreqs(seqStore_t* ssPtr)
     unsigned u;
 
     if (ssPtr->litLengthSum == 0) {
+        ssPtr->cachedLiterals = NULL;   /* drop any cached literal price */
+        ssPtr->cachedPrice = ssPtr->cachedLitLength = 0;
         ssPtr->litSum = (2<<Litbits);
         ssPtr->litLengthSum = MaxLL+1;
         ssPtr->matchLengthSum = MaxML+1;
@@ -98,17 +100,41 @@ MEM_STATIC void ZSTD_rescaleFreqs(seqStore_t* ssPtr)
 }
 
 
-FORCE_INLINE U32 ZSTD_getLiteralPrice(seqStore_t* seqStorePtr, U32 litLength, const BYTE* literals)
+FORCE_INLINE U32 ZSTD_getLiteralPrice(seqStore_t* ssPtr, U32 litLength, const BYTE* literals)
 {
     U32 price, u;
 
     if (litLength == 0)
-        return seqStorePtr->log2litLengthSum - ZSTD_highbit(seqStorePtr->litLengthFreq[0]+1);
+        return ssPtr->log2litLengthSum - ZSTD_highbit(ssPtr->litLengthFreq[0]+1);
 
     /* literals */
-    price = litLength * seqStorePtr->log2litSum;
+#define ZSTD_CACHE_LITPRICES   /* also selects litstart = anchor in ZSTD_compressBlock_opt_generic */
+#ifdef ZSTD_CACHE_LITPRICES
+    if (ssPtr->cachedLiterals == literals && litLength >= ssPtr->cachedLitLength) {
+        /* same run, not shorter than the cached one: price only the new tail (guard keeps `additional` from wrapping) */
+        U32 additional = litLength - ssPtr->cachedLitLength;
+        const BYTE* literals2 = ssPtr->cachedLiterals + ssPtr->cachedLitLength;
+        price = ssPtr->cachedPrice + additional * ssPtr->log2litSum;
+        for (u=0; u < additional; u++)
+            price -= ZSTD_highbit(ssPtr->litFreq[literals2[u]]+1);
+        ssPtr->cachedPrice = price;
+        ssPtr->cachedLitLength = litLength;
+    } else {
+        price = litLength * ssPtr->log2litSum;   /* no usable cache entry: price the whole run */
+        for (u=0; u < litLength; u++)
+            price -= ZSTD_highbit(ssPtr->litFreq[literals[u]]+1);
+
+        if (litLength >= 12) {   /* only runs of 12+ literals are remembered */
+            ssPtr->cachedLiterals = literals;
+            ssPtr->cachedPrice = price;
+            ssPtr->cachedLitLength = litLength;
+        }
+    }
+#else
+    price = litLength * ssPtr->log2litSum;
     for (u=0; u < litLength; u++)
-        price -= ZSTD_highbit(seqStorePtr->litFreq[literals[u]]+1);
+        price -= ZSTD_highbit(ssPtr->litFreq[literals[u]]+1);
+#endif
 
     /* literal Length */
     {   static const BYTE LL_Code[64] = { 0, 1, 2, 3, 4, 5, 6, 7,
@@ -121,7 +147,7 @@ FORCE_INLINE U32 ZSTD_getLiteralPrice(seqStore_t* seqStorePtr, U32 litLength, co
                                           24, 24, 24, 24, 24, 24, 24, 24 };
         const BYTE LL_deltaCode = 19;
         const BYTE llCode = (litLength>63) ? (BYTE)ZSTD_highbit(litLength) + LL_deltaCode : LL_Code[litLength];
-        price += LL_bits[llCode] + seqStorePtr->log2litLengthSum - ZSTD_highbit(seqStorePtr->litLengthFreq[llCode]+1);
+        price += LL_bits[llCode] + ssPtr->log2litLengthSum - ZSTD_highbit(ssPtr->litLengthFreq[llCode]+1);
     }
 
     return price;
@@ -465,7 +491,11 @@ void ZSTD_compressBlock_opt_generic(ZSTD_CCtx* ctx,
         memset(opt, 0, sizeof(ZSTD_optimal_t));
         last_pos = 0;
         inr = ip;
+#ifdef ZSTD_CACHE_LITPRICES
+        litstart = anchor;   /* always start at anchor so repeated price queries share one literals pointer */
+#else
         litstart = ((U32)(ip - anchor) > 128) ? ip - 128 : anchor;
+#endif
         opt[0].litlen = (U32)(ip - litstart);
 
         /* check repCode */