From 6004c1117f1f79a2b44363f05571eb19ccc00817 Mon Sep 17 00:00:00 2001 From: Nick Terrell Date: Fri, 14 Aug 2020 15:28:59 -0700 Subject: [PATCH] speed up small blocks --- lib/common/entropy_common.c | 84 ++++++++++++++++------- lib/compress/fse_compress.c | 21 ++---- lib/decompress/zstd_decompress.c | 9 ++- lib/decompress/zstd_decompress_block.c | 84 +++++++++++++++++++---- lib/decompress/zstd_decompress_block.h | 2 +- lib/decompress/zstd_decompress_internal.h | 2 + 6 files changed, 143 insertions(+), 59 deletions(-) diff --git a/lib/common/entropy_common.c b/lib/common/entropy_common.c index 6b825afe..0a13d9d9 100644 --- a/lib/common/entropy_common.c +++ b/lib/common/entropy_common.c @@ -50,6 +50,7 @@ size_t FSE_readNCount (short* normalizedCounter, unsigned* maxSVPtr, unsigned* t U32 bitStream; int bitCount; unsigned charnum = 0; + unsigned const maxSV1 = *maxSVPtr + 1; int previous0 = 0; if (hbSize < 4) { @@ -76,27 +77,39 @@ size_t FSE_readNCount (short* normalizedCounter, unsigned* maxSVPtr, unsigned* t threshold = 1<1) & (charnum<=*maxSVPtr)) { + for (;;) { if (previous0) { - unsigned n0 = charnum; - while ((bitStream & 0xFFFF) == 0xFFFF) { - n0 += 24; - if (ip < iend-5) { - ip += 2; + // TODO: Generalize to FSE_countTrailingZeros() or something + int repeats = __builtin_ctz(~bitStream) >> 1; + while (repeats >= 12) { + charnum += 3 * 12; + if (ip < iend-6) { + ip += 3; bitStream = MEM_readLE32(ip) >> bitCount; } else { - bitStream >>= 16; - bitCount += 16; - } } - while ((bitStream & 3) == 3) { - n0 += 3; - bitStream >>= 2; - bitCount += 2; + bitStream >>= 24; + bitCount += 24; + } + repeats = __builtin_ctz(~bitStream) >> 1; } - n0 += bitStream & 3; + charnum += 3 * repeats; + bitStream >>= 2 * repeats; + bitCount += 2 * repeats; + + assert(bitCount < 30 && (bitStream & 3) != 3); + charnum += bitStream & 3; bitCount += 2; - if (n0 > *maxSVPtr) return ERROR(maxSymbolValue_tooSmall); - while (charnum < n0) normalizedCounter[charnum++] = 0; + + /* This is an error, but break and return an error + * at the end, because returning out of a loop makes + * it harder for the compiler to optimize. + */ + if (charnum >= maxSV1) break; + + /* We don't need to set the normalized count to 0 + * because we already memset the whole buffer to 0. + */ + if ((ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) { assert((bitCount >> 3) <= 3); /* For first condition to work */ ip += bitCount>>3; @@ -104,8 +117,10 @@ size_t FSE_readNCount (short* normalizedCounter, unsigned* maxSVPtr, unsigned* t bitStream = MEM_readLE32(ip) >> bitCount; } else { bitStream >>= 2; - } } - { int const max = (2*threshold-1) - remaining; + } + } + { + int const max = (2*threshold-1) - remaining; int count; if ((bitStream & (threshold-1)) < (U32)max) { @@ -118,15 +133,31 @@ size_t FSE_readNCount (short* normalizedCounter, unsigned* maxSVPtr, unsigned* t } count--; /* extra accuracy */ - remaining -= count < 0 ? -count : count; /* -1 means +1 */ + /* When it matters (small blocks), this is a + * predictable branch, because we don't use -1. + */ + if (count >= 0) { + remaining -= count; + } else { + assert(count == -1); + remaining += count; + } normalizedCounter[charnum++] = (short)count; previous0 = !count; - while (remaining < threshold) { - nbBits--; - threshold >>= 1; - } - if ((ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) { + assert(threshold > 1); + if (remaining < threshold) { + /* This branch can be folded into the + * threshold update condition because we + * know that threshold > 1. + */ + if (remaining <= 1) break; + nbBits = BIT_highbit32(remaining) + 1; + threshold = 1 << (nbBits - 1); + } + if (charnum >= maxSV1) break; + + if (LIKELY((ip <= iend-7) || (ip + (bitCount>>3) <= iend-4))) { ip += bitCount>>3; bitCount &= 7; } else { @@ -134,8 +165,10 @@ size_t FSE_readNCount (short* normalizedCounter, unsigned* maxSVPtr, unsigned* t ip = iend - 4; } bitStream = MEM_readLE32(ip) >> (bitCount & 31); - } } /* while ((remaining>1) & (charnum<=*maxSVPtr)) */ + } } if (remaining != 1) return ERROR(corruption_detected); + /* Only possible when there are too many zeros. */ + if (charnum > maxSV1) return ERROR(maxSymbolValue_tooSmall); if (bitCount > 32) return ERROR(corruption_detected); *maxSVPtr = charnum-1; @@ -143,7 +176,6 @@ size_t FSE_readNCount (short* normalizedCounter, unsigned* maxSVPtr, unsigned* t return ip-istart; } - /*! HUF_readStats() : Read compact Huffman tree, saved by HUF_writeCTable(). `huffWeight` is destination buffer. diff --git a/lib/compress/fse_compress.c b/lib/compress/fse_compress.c index 5290a918..1187e3e6 100644 --- a/lib/compress/fse_compress.c +++ b/lib/compress/fse_compress.c @@ -341,6 +341,8 @@ unsigned FSE_optimalTableLog(unsigned maxTableLog, size_t srcSize, unsigned maxS return FSE_optimalTableLog_internal(maxTableLog, srcSize, maxSymbolValue, 2); } +// TODO: Emit -1 based on # of symbols +#define LOW_PROB 0 /* Secondary normalization method. To be used when primary method fails. */ @@ -361,7 +363,7 @@ static size_t FSE_normalizeM2(short* norm, U32 tableLog, const unsigned* count, norm[s]=0; continue; } - if (count[s] <= lowThreshold) { + if (LOW_PROB && count[s] <= lowThreshold) { norm[s] = -1; distributed++; total -= count[s]; @@ -431,7 +433,6 @@ static size_t FSE_normalizeM2(short* norm, U32 tableLog, const unsigned* count, return 0; } - size_t FSE_normalizeCount (short* normalizedCounter, unsigned tableLog, const unsigned* count, size_t total, unsigned maxSymbolValue) @@ -455,7 +456,7 @@ size_t FSE_normalizeCount (short* normalizedCounter, unsigned tableLog, for (s=0; s<=maxSymbolValue; s++) { if (count[s] == total) return 0; /* rle special case */ if (count[s] == 0) { normalizedCounter[s]=0; continue; } - if (count[s] <= lowThreshold) { + if (LOW_PROB && count[s] <= lowThreshold) { normalizedCounter[s] = -1; stillToDistribute--; } else { @@ -476,20 +477,6 @@ size_t FSE_normalizeCount (short* normalizedCounter, unsigned tableLog, else normalizedCounter[largest] += (short)stillToDistribute; } -#if 0 - { /* Print Table (debug) */ - U32 s; - U32 nTotal = 0; - for (s=0; s<=maxSymbolValue; s++) - RAWLOG(2, "%3i: %4i \n", s, normalizedCounter[s]); - for (s=0; s<=maxSymbolValue; s++) - nTotal += abs(normalizedCounter[s]); - if (nTotal != (1U<OFTable, offcodeNCount, offcodeMaxValue, OF_base, OF_bits, - offcodeLog); + offcodeLog, + entropy->workspace, sizeof(entropy->workspace)); dictPtr += offcodeHeaderSize; } @@ -1104,7 +1105,8 @@ ZSTD_loadDEntropy(ZSTD_entropyDTables_t* entropy, ZSTD_buildFSETable( entropy->MLTable, matchlengthNCount, matchlengthMaxValue, ML_base, ML_bits, - matchlengthLog); + matchlengthLog, + entropy->workspace, sizeof(entropy->workspace)); dictPtr += matchlengthHeaderSize; } @@ -1117,7 +1119,8 @@ ZSTD_loadDEntropy(ZSTD_entropyDTables_t* entropy, ZSTD_buildFSETable( entropy->LLTable, litlengthNCount, litlengthMaxValue, LL_base, LL_bits, - litlengthLog); + litlengthLog, + entropy->workspace, sizeof(entropy->workspace)); dictPtr += litlengthHeaderSize; } diff --git a/lib/decompress/zstd_decompress_block.c b/lib/decompress/zstd_decompress_block.c index e93d6feb..95afcaa3 100644 --- a/lib/decompress/zstd_decompress_block.c +++ b/lib/decompress/zstd_decompress_block.c @@ -368,19 +368,18 @@ void ZSTD_buildFSETable(ZSTD_seqSymbol* dt, const short* normalizedCounter, unsigned maxSymbolValue, const U32* baseValue, const U32* nbAdditionalBits, - unsigned tableLog) + unsigned tableLog, U32* wksp, size_t wkspSize) { ZSTD_seqSymbol* const tableDecode = dt+1; U16 symbolNext[MaxSeq+1]; U32 const maxSV1 = maxSymbolValue + 1; U32 const tableSize = 1 << tableLog; - U32 highThreshold = tableSize-1; /* Sanity Checks */ assert(maxSymbolValue <= MaxSeq); assert(tableLog <= MaxFSELog); - + U32 highThreshold = tableSize - 1; /* Init, lay down lowprob symbols */ { ZSTD_seqSymbol_header DTableH; DTableH.tableLog = tableLog; @@ -400,12 +399,68 @@ ZSTD_buildFSETable(ZSTD_seqSymbol* dt, } /* Spread symbols */ - { U32 const tableMask = tableSize-1; + assert(tableSize <= 512); + /* Specialized symbol spreading for the case when there are + * no low probability (-1 count) symbols. When compressing + * small blocks we avoid low probability symbols to hit this + * case, since header decoding speed matters more. + */ + if (highThreshold == tableSize - 1) { + size_t const tableMask = tableSize-1; + size_t const step = FSE_TABLESTEP(tableSize); + /* First lay down the symbols in order. + * We use a uint64_t to lay down 8 bytes at a time. This reduces branch + * misses since small blocks generally have small table logs, so nearly + * all symbols have counts <= 8. We ensure we have 8 bytes at the end of + * our buffer to handle the over-write. + */ + BYTE* spread = (BYTE*)wksp; + assert(wkspSize >= (1u << MaxFSELog) + sizeof(U64)); + (void)wkspSize; + { + U64 const add = 0x0101010101010101ull; + size_t pos = 0; + U64 sv = 0; + U32 s; + for (s=0; s highThreshold) position = (position + step) & tableMask; /* lowprob area */ @@ -414,7 +469,8 @@ ZSTD_buildFSETable(ZSTD_seqSymbol* dt, } /* Build Decoding table */ - { U32 u; + { + U32 u; for (u=0; u maxLog, corruption_detected, ""); - ZSTD_buildFSETable(DTableSpace, norm, max, baseValue, nbAdditionalBits, tableLog); + ZSTD_buildFSETable(DTableSpace, norm, max, baseValue, nbAdditionalBits, tableLog, wksp, wkspSize); *DTablePtr = DTableSpace; return headerSize; } @@ -520,7 +577,8 @@ size_t ZSTD_decodeSeqHeaders(ZSTD_DCtx* dctx, int* nbSeqPtr, ip, iend-ip, LL_base, LL_bits, LL_defaultDTable, dctx->fseEntropy, - dctx->ddictIsCold, nbSeq); + dctx->ddictIsCold, nbSeq, + dctx->workspace, sizeof(dctx->workspace)); RETURN_ERROR_IF(ZSTD_isError(llhSize), corruption_detected, "ZSTD_buildSeqTable failed"); ip += llhSize; } @@ -530,7 +588,8 @@ size_t ZSTD_decodeSeqHeaders(ZSTD_DCtx* dctx, int* nbSeqPtr, ip, iend-ip, OF_base, OF_bits, OF_defaultDTable, dctx->fseEntropy, - dctx->ddictIsCold, nbSeq); + dctx->ddictIsCold, nbSeq, + dctx->workspace, sizeof(dctx->workspace)); RETURN_ERROR_IF(ZSTD_isError(ofhSize), corruption_detected, "ZSTD_buildSeqTable failed"); ip += ofhSize; } @@ -540,7 +599,8 @@ size_t ZSTD_decodeSeqHeaders(ZSTD_DCtx* dctx, int* nbSeqPtr, ip, iend-ip, ML_base, ML_bits, ML_defaultDTable, dctx->fseEntropy, - dctx->ddictIsCold, nbSeq); + dctx->ddictIsCold, nbSeq, + dctx->workspace, sizeof(dctx->workspace)); RETURN_ERROR_IF(ZSTD_isError(mlhSize), corruption_detected, "ZSTD_buildSeqTable failed"); ip += mlhSize; } diff --git a/lib/decompress/zstd_decompress_block.h b/lib/decompress/zstd_decompress_block.h index bf39b735..201d6a9f 100644 --- a/lib/decompress/zstd_decompress_block.h +++ b/lib/decompress/zstd_decompress_block.h @@ -53,7 +53,7 @@ size_t ZSTD_decompressBlock_internal(ZSTD_DCtx* dctx, void ZSTD_buildFSETable(ZSTD_seqSymbol* dt, const short* normalizedCounter, unsigned maxSymbolValue, const U32* baseValue, const U32* nbAdditionalBits, - unsigned tableLog); + unsigned tableLog, U32* wksp, size_t wkspSize); #endif /* ZSTD_DEC_BLOCK_H */ diff --git a/lib/decompress/zstd_decompress_internal.h b/lib/decompress/zstd_decompress_internal.h index 9ad96c55..1a5c7ee6 100644 --- a/lib/decompress/zstd_decompress_internal.h +++ b/lib/decompress/zstd_decompress_internal.h @@ -72,6 +72,7 @@ static const U32 ML_base[MaxML+1] = { } ZSTD_seqSymbol; #define SEQSYMBOL_TABLE_SIZE(log) (1 + (1 << (log))) + #define ZSTD_FSE_WKSP_SIZE_U32 130 typedef struct { ZSTD_seqSymbol LLTable[SEQSYMBOL_TABLE_SIZE(LLFSELog)]; /* Note : Space reserved for FSE Tables */ @@ -79,6 +80,7 @@ typedef struct { ZSTD_seqSymbol MLTable[SEQSYMBOL_TABLE_SIZE(MLFSELog)]; /* and therefore must be at least HUF_DECOMPRESS_WORKSPACE_SIZE large */ HUF_DTable hufTable[HUF_DTABLE_SIZE(HufLog)]; /* can accommodate HUF_decompress4X */ U32 rep[ZSTD_REP_NUM]; + U32 workspace[ZSTD_FSE_WKSP_SIZE_U32]; } ZSTD_entropyDTables_t; typedef enum { ZSTDds_getFrameHeaderSize, ZSTDds_decodeFrameHeader,