speed up small blocks

This commit is contained in:
Nick Terrell 2020-08-14 15:28:59 -07:00
parent a8006264cf
commit 6004c1117f
6 changed files with 143 additions and 59 deletions

View File

@ -50,6 +50,7 @@ size_t FSE_readNCount (short* normalizedCounter, unsigned* maxSVPtr, unsigned* t
U32 bitStream; U32 bitStream;
int bitCount; int bitCount;
unsigned charnum = 0; unsigned charnum = 0;
unsigned const maxSV1 = *maxSVPtr + 1;
int previous0 = 0; int previous0 = 0;
if (hbSize < 4) { if (hbSize < 4) {
@ -76,27 +77,39 @@ size_t FSE_readNCount (short* normalizedCounter, unsigned* maxSVPtr, unsigned* t
threshold = 1<<nbBits; threshold = 1<<nbBits;
nbBits++; nbBits++;
while ((remaining>1) & (charnum<=*maxSVPtr)) { for (;;) {
if (previous0) { if (previous0) {
unsigned n0 = charnum; // TODO: Generalize to FSE_countTrailingZeros() or something
while ((bitStream & 0xFFFF) == 0xFFFF) { int repeats = __builtin_ctz(~bitStream) >> 1;
n0 += 24; while (repeats >= 12) {
if (ip < iend-5) { charnum += 3 * 12;
ip += 2; if (ip < iend-6) {
ip += 3;
bitStream = MEM_readLE32(ip) >> bitCount; bitStream = MEM_readLE32(ip) >> bitCount;
} else { } else {
bitStream >>= 16; bitStream >>= 24;
bitCount += 16; bitCount += 24;
} }
while ((bitStream & 3) == 3) {
n0 += 3;
bitStream >>= 2;
bitCount += 2;
} }
n0 += bitStream & 3; repeats = __builtin_ctz(~bitStream) >> 1;
}
charnum += 3 * repeats;
bitStream >>= 2 * repeats;
bitCount += 2 * repeats;
assert(bitCount < 30 && (bitStream & 3) != 3);
charnum += bitStream & 3;
bitCount += 2; bitCount += 2;
if (n0 > *maxSVPtr) return ERROR(maxSymbolValue_tooSmall);
while (charnum < n0) normalizedCounter[charnum++] = 0; /* This is an error, but break and return an error
* at the end, because returning out of a loop makes
* it harder for the compiler to optimize.
*/
if (charnum >= maxSV1) break;
/* We don't need to set the normalized count to 0
* because we already memset the whole buffer to 0.
*/
if ((ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) { if ((ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
assert((bitCount >> 3) <= 3); /* For first condition to work */ assert((bitCount >> 3) <= 3); /* For first condition to work */
ip += bitCount>>3; ip += bitCount>>3;
@ -104,8 +117,10 @@ size_t FSE_readNCount (short* normalizedCounter, unsigned* maxSVPtr, unsigned* t
bitStream = MEM_readLE32(ip) >> bitCount; bitStream = MEM_readLE32(ip) >> bitCount;
} else { } else {
bitStream >>= 2; bitStream >>= 2;
} } }
{ int const max = (2*threshold-1) - remaining; }
{
int const max = (2*threshold-1) - remaining;
int count; int count;
if ((bitStream & (threshold-1)) < (U32)max) { if ((bitStream & (threshold-1)) < (U32)max) {
@ -118,15 +133,31 @@ size_t FSE_readNCount (short* normalizedCounter, unsigned* maxSVPtr, unsigned* t
} }
count--; /* extra accuracy */ count--; /* extra accuracy */
remaining -= count < 0 ? -count : count; /* -1 means +1 */ /* When it matters (small blocks), this is a
* predictable branch, because we don't use -1.
*/
if (count >= 0) {
remaining -= count;
} else {
assert(count == -1);
remaining += count;
}
normalizedCounter[charnum++] = (short)count; normalizedCounter[charnum++] = (short)count;
previous0 = !count; previous0 = !count;
while (remaining < threshold) {
nbBits--;
threshold >>= 1;
}
if ((ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) { assert(threshold > 1);
if (remaining < threshold) {
/* This branch can be folded into the
* threshold update condition because we
* know that threshold > 1.
*/
if (remaining <= 1) break;
nbBits = BIT_highbit32(remaining) + 1;
threshold = 1 << (nbBits - 1);
}
if (charnum >= maxSV1) break;
if (LIKELY((ip <= iend-7) || (ip + (bitCount>>3) <= iend-4))) {
ip += bitCount>>3; ip += bitCount>>3;
bitCount &= 7; bitCount &= 7;
} else { } else {
@ -134,8 +165,10 @@ size_t FSE_readNCount (short* normalizedCounter, unsigned* maxSVPtr, unsigned* t
ip = iend - 4; ip = iend - 4;
} }
bitStream = MEM_readLE32(ip) >> (bitCount & 31); bitStream = MEM_readLE32(ip) >> (bitCount & 31);
} } /* while ((remaining>1) & (charnum<=*maxSVPtr)) */ } }
if (remaining != 1) return ERROR(corruption_detected); if (remaining != 1) return ERROR(corruption_detected);
/* Only possible when there are too many zeros. */
if (charnum > maxSV1) return ERROR(maxSymbolValue_tooSmall);
if (bitCount > 32) return ERROR(corruption_detected); if (bitCount > 32) return ERROR(corruption_detected);
*maxSVPtr = charnum-1; *maxSVPtr = charnum-1;
@ -143,7 +176,6 @@ size_t FSE_readNCount (short* normalizedCounter, unsigned* maxSVPtr, unsigned* t
return ip-istart; return ip-istart;
} }
/*! HUF_readStats() : /*! HUF_readStats() :
Read compact Huffman tree, saved by HUF_writeCTable(). Read compact Huffman tree, saved by HUF_writeCTable().
`huffWeight` is destination buffer. `huffWeight` is destination buffer.

View File

@ -341,6 +341,8 @@ unsigned FSE_optimalTableLog(unsigned maxTableLog, size_t srcSize, unsigned maxS
return FSE_optimalTableLog_internal(maxTableLog, srcSize, maxSymbolValue, 2); return FSE_optimalTableLog_internal(maxTableLog, srcSize, maxSymbolValue, 2);
} }
// TODO: Emit -1 based on # of symbols
#define LOW_PROB 0
/* Secondary normalization method. /* Secondary normalization method.
To be used when primary method fails. */ To be used when primary method fails. */
@ -361,7 +363,7 @@ static size_t FSE_normalizeM2(short* norm, U32 tableLog, const unsigned* count,
norm[s]=0; norm[s]=0;
continue; continue;
} }
if (count[s] <= lowThreshold) { if (LOW_PROB && count[s] <= lowThreshold) {
norm[s] = -1; norm[s] = -1;
distributed++; distributed++;
total -= count[s]; total -= count[s];
@ -431,7 +433,6 @@ static size_t FSE_normalizeM2(short* norm, U32 tableLog, const unsigned* count,
return 0; return 0;
} }
size_t FSE_normalizeCount (short* normalizedCounter, unsigned tableLog, size_t FSE_normalizeCount (short* normalizedCounter, unsigned tableLog,
const unsigned* count, size_t total, const unsigned* count, size_t total,
unsigned maxSymbolValue) unsigned maxSymbolValue)
@ -455,7 +456,7 @@ size_t FSE_normalizeCount (short* normalizedCounter, unsigned tableLog,
for (s=0; s<=maxSymbolValue; s++) { for (s=0; s<=maxSymbolValue; s++) {
if (count[s] == total) return 0; /* rle special case */ if (count[s] == total) return 0; /* rle special case */
if (count[s] == 0) { normalizedCounter[s]=0; continue; } if (count[s] == 0) { normalizedCounter[s]=0; continue; }
if (count[s] <= lowThreshold) { if (LOW_PROB && count[s] <= lowThreshold) {
normalizedCounter[s] = -1; normalizedCounter[s] = -1;
stillToDistribute--; stillToDistribute--;
} else { } else {
@ -476,20 +477,6 @@ size_t FSE_normalizeCount (short* normalizedCounter, unsigned tableLog,
else normalizedCounter[largest] += (short)stillToDistribute; else normalizedCounter[largest] += (short)stillToDistribute;
} }
#if 0
{ /* Print Table (debug) */
U32 s;
U32 nTotal = 0;
for (s=0; s<=maxSymbolValue; s++)
RAWLOG(2, "%3i: %4i \n", s, normalizedCounter[s]);
for (s=0; s<=maxSymbolValue; s++)
nTotal += abs(normalizedCounter[s]);
if (nTotal != (1U<<tableLog))
RAWLOG(2, "Warning !!! Total == %u != %u !!!", nTotal, 1U<<tableLog);
getchar();
}
#endif
return tableLog; return tableLog;
} }

View File

@ -1091,7 +1091,8 @@ ZSTD_loadDEntropy(ZSTD_entropyDTables_t* entropy,
ZSTD_buildFSETable( entropy->OFTable, ZSTD_buildFSETable( entropy->OFTable,
offcodeNCount, offcodeMaxValue, offcodeNCount, offcodeMaxValue,
OF_base, OF_bits, OF_base, OF_bits,
offcodeLog); offcodeLog,
entropy->workspace, sizeof(entropy->workspace));
dictPtr += offcodeHeaderSize; dictPtr += offcodeHeaderSize;
} }
@ -1104,7 +1105,8 @@ ZSTD_loadDEntropy(ZSTD_entropyDTables_t* entropy,
ZSTD_buildFSETable( entropy->MLTable, ZSTD_buildFSETable( entropy->MLTable,
matchlengthNCount, matchlengthMaxValue, matchlengthNCount, matchlengthMaxValue,
ML_base, ML_bits, ML_base, ML_bits,
matchlengthLog); matchlengthLog,
entropy->workspace, sizeof(entropy->workspace));
dictPtr += matchlengthHeaderSize; dictPtr += matchlengthHeaderSize;
} }
@ -1117,7 +1119,8 @@ ZSTD_loadDEntropy(ZSTD_entropyDTables_t* entropy,
ZSTD_buildFSETable( entropy->LLTable, ZSTD_buildFSETable( entropy->LLTable,
litlengthNCount, litlengthMaxValue, litlengthNCount, litlengthMaxValue,
LL_base, LL_bits, LL_base, LL_bits,
litlengthLog); litlengthLog,
entropy->workspace, sizeof(entropy->workspace));
dictPtr += litlengthHeaderSize; dictPtr += litlengthHeaderSize;
} }

View File

@ -368,19 +368,18 @@ void
ZSTD_buildFSETable(ZSTD_seqSymbol* dt, ZSTD_buildFSETable(ZSTD_seqSymbol* dt,
const short* normalizedCounter, unsigned maxSymbolValue, const short* normalizedCounter, unsigned maxSymbolValue,
const U32* baseValue, const U32* nbAdditionalBits, const U32* baseValue, const U32* nbAdditionalBits,
unsigned tableLog) unsigned tableLog, U32* wksp, size_t wkspSize)
{ {
ZSTD_seqSymbol* const tableDecode = dt+1; ZSTD_seqSymbol* const tableDecode = dt+1;
U16 symbolNext[MaxSeq+1]; U16 symbolNext[MaxSeq+1];
U32 const maxSV1 = maxSymbolValue + 1; U32 const maxSV1 = maxSymbolValue + 1;
U32 const tableSize = 1 << tableLog; U32 const tableSize = 1 << tableLog;
U32 highThreshold = tableSize-1;
/* Sanity Checks */ /* Sanity Checks */
assert(maxSymbolValue <= MaxSeq); assert(maxSymbolValue <= MaxSeq);
assert(tableLog <= MaxFSELog); assert(tableLog <= MaxFSELog);
U32 highThreshold = tableSize - 1;
/* Init, lay down lowprob symbols */ /* Init, lay down lowprob symbols */
{ ZSTD_seqSymbol_header DTableH; { ZSTD_seqSymbol_header DTableH;
DTableH.tableLog = tableLog; DTableH.tableLog = tableLog;
@ -400,12 +399,68 @@ ZSTD_buildFSETable(ZSTD_seqSymbol* dt,
} }
/* Spread symbols */ /* Spread symbols */
{ U32 const tableMask = tableSize-1; assert(tableSize <= 512);
/* Specialized symbol spreading for the case when there are
* no low probability (-1 count) symbols. When compressing
* small blocks we avoid low probability symbols to hit this
* case, since header decoding speed matters more.
*/
if (highThreshold == tableSize - 1) {
size_t const tableMask = tableSize-1;
size_t const step = FSE_TABLESTEP(tableSize);
/* First lay down the symbols in order.
* We use a uint64_t to lay down 8 bytes at a time. This reduces branch
* misses since small blocks generally have small table logs, so nearly
* all symbols have counts <= 8. We ensure we have 8 bytes at the end of
* our buffer to handle the over-write.
*/
BYTE* spread = (BYTE*)wksp;
assert(wkspSize >= (1u << MaxFSELog) + sizeof(U64));
(void)wkspSize;
{
U64 const add = 0x0101010101010101ull;
size_t pos = 0;
U64 sv = 0;
U32 s;
for (s=0; s<maxSV1; ++s, sv += add) {
int i;
int const n = normalizedCounter[s];
MEM_write64(spread + pos, sv);
for (i = 8; i < n; i += 8) {
MEM_write64(spread + pos + i, sv);
}
pos += n;
}
}
/* Now we spread those positions across the table.
 * The benefit of doing it in two stages is that we avoid the
* variable size inner loop, which caused lots of branch misses.
* Now we can run through all the positions without any branch misses.
 * We unroll the loop twice, since that is what empirically worked best.
*/
{
size_t position = 0;
size_t s;
size_t const unroll = 2;
assert(tableSize % unroll == 0); /* FSE_MIN_TABLELOG is 5 */
for (s = 0; s < (size_t)tableSize; s += unroll) {
size_t u;
for (u = 0; u < unroll; ++u) {
size_t const uPosition = (position + (u * step)) & tableMask;
tableDecode[uPosition].baseValue = spread[s + u];
}
position = (position + (unroll * step)) & tableMask;
}
assert(position == 0);
}
} else {
U32 const tableMask = tableSize-1;
U32 const step = FSE_TABLESTEP(tableSize); U32 const step = FSE_TABLESTEP(tableSize);
U32 s, position = 0; U32 s, position = 0;
for (s=0; s<maxSV1; s++) { for (s=0; s<maxSV1; s++) {
int i; int i;
for (i=0; i<normalizedCounter[s]; i++) { int const n = normalizedCounter[s];
for (i=0; i<n; i++) {
tableDecode[position].baseValue = s; tableDecode[position].baseValue = s;
position = (position + step) & tableMask; position = (position + step) & tableMask;
while (position > highThreshold) position = (position + step) & tableMask; /* lowprob area */ while (position > highThreshold) position = (position + step) & tableMask; /* lowprob area */
@ -414,7 +469,8 @@ ZSTD_buildFSETable(ZSTD_seqSymbol* dt,
} }
/* Build Decoding table */ /* Build Decoding table */
{ U32 u; {
U32 u;
for (u=0; u<tableSize; u++) { for (u=0; u<tableSize; u++) {
U32 const symbol = tableDecode[u].baseValue; U32 const symbol = tableDecode[u].baseValue;
U32 const nextState = symbolNext[symbol]++; U32 const nextState = symbolNext[symbol]++;
@ -423,7 +479,8 @@ ZSTD_buildFSETable(ZSTD_seqSymbol* dt,
assert(nbAdditionalBits[symbol] < 255); assert(nbAdditionalBits[symbol] < 255);
tableDecode[u].nbAdditionalBits = (BYTE)nbAdditionalBits[symbol]; tableDecode[u].nbAdditionalBits = (BYTE)nbAdditionalBits[symbol];
tableDecode[u].baseValue = baseValue[symbol]; tableDecode[u].baseValue = baseValue[symbol];
} } }
}
} }
@ -435,7 +492,7 @@ static size_t ZSTD_buildSeqTable(ZSTD_seqSymbol* DTableSpace, const ZSTD_seqSymb
const void* src, size_t srcSize, const void* src, size_t srcSize,
const U32* baseValue, const U32* nbAdditionalBits, const U32* baseValue, const U32* nbAdditionalBits,
const ZSTD_seqSymbol* defaultTable, U32 flagRepeatTable, const ZSTD_seqSymbol* defaultTable, U32 flagRepeatTable,
int ddictIsCold, int nbSeq) int ddictIsCold, int nbSeq, U32* wksp, size_t wkspSize)
{ {
switch(type) switch(type)
{ {
@ -467,7 +524,7 @@ static size_t ZSTD_buildSeqTable(ZSTD_seqSymbol* DTableSpace, const ZSTD_seqSymb
size_t const headerSize = FSE_readNCount(norm, &max, &tableLog, src, srcSize); size_t const headerSize = FSE_readNCount(norm, &max, &tableLog, src, srcSize);
RETURN_ERROR_IF(FSE_isError(headerSize), corruption_detected, ""); RETURN_ERROR_IF(FSE_isError(headerSize), corruption_detected, "");
RETURN_ERROR_IF(tableLog > maxLog, corruption_detected, ""); RETURN_ERROR_IF(tableLog > maxLog, corruption_detected, "");
ZSTD_buildFSETable(DTableSpace, norm, max, baseValue, nbAdditionalBits, tableLog); ZSTD_buildFSETable(DTableSpace, norm, max, baseValue, nbAdditionalBits, tableLog, wksp, wkspSize);
*DTablePtr = DTableSpace; *DTablePtr = DTableSpace;
return headerSize; return headerSize;
} }
@ -520,7 +577,8 @@ size_t ZSTD_decodeSeqHeaders(ZSTD_DCtx* dctx, int* nbSeqPtr,
ip, iend-ip, ip, iend-ip,
LL_base, LL_bits, LL_base, LL_bits,
LL_defaultDTable, dctx->fseEntropy, LL_defaultDTable, dctx->fseEntropy,
dctx->ddictIsCold, nbSeq); dctx->ddictIsCold, nbSeq,
dctx->workspace, sizeof(dctx->workspace));
RETURN_ERROR_IF(ZSTD_isError(llhSize), corruption_detected, "ZSTD_buildSeqTable failed"); RETURN_ERROR_IF(ZSTD_isError(llhSize), corruption_detected, "ZSTD_buildSeqTable failed");
ip += llhSize; ip += llhSize;
} }
@ -530,7 +588,8 @@ size_t ZSTD_decodeSeqHeaders(ZSTD_DCtx* dctx, int* nbSeqPtr,
ip, iend-ip, ip, iend-ip,
OF_base, OF_bits, OF_base, OF_bits,
OF_defaultDTable, dctx->fseEntropy, OF_defaultDTable, dctx->fseEntropy,
dctx->ddictIsCold, nbSeq); dctx->ddictIsCold, nbSeq,
dctx->workspace, sizeof(dctx->workspace));
RETURN_ERROR_IF(ZSTD_isError(ofhSize), corruption_detected, "ZSTD_buildSeqTable failed"); RETURN_ERROR_IF(ZSTD_isError(ofhSize), corruption_detected, "ZSTD_buildSeqTable failed");
ip += ofhSize; ip += ofhSize;
} }
@ -540,7 +599,8 @@ size_t ZSTD_decodeSeqHeaders(ZSTD_DCtx* dctx, int* nbSeqPtr,
ip, iend-ip, ip, iend-ip,
ML_base, ML_bits, ML_base, ML_bits,
ML_defaultDTable, dctx->fseEntropy, ML_defaultDTable, dctx->fseEntropy,
dctx->ddictIsCold, nbSeq); dctx->ddictIsCold, nbSeq,
dctx->workspace, sizeof(dctx->workspace));
RETURN_ERROR_IF(ZSTD_isError(mlhSize), corruption_detected, "ZSTD_buildSeqTable failed"); RETURN_ERROR_IF(ZSTD_isError(mlhSize), corruption_detected, "ZSTD_buildSeqTable failed");
ip += mlhSize; ip += mlhSize;
} }

View File

@ -53,7 +53,7 @@ size_t ZSTD_decompressBlock_internal(ZSTD_DCtx* dctx,
void ZSTD_buildFSETable(ZSTD_seqSymbol* dt, void ZSTD_buildFSETable(ZSTD_seqSymbol* dt,
const short* normalizedCounter, unsigned maxSymbolValue, const short* normalizedCounter, unsigned maxSymbolValue,
const U32* baseValue, const U32* nbAdditionalBits, const U32* baseValue, const U32* nbAdditionalBits,
unsigned tableLog); unsigned tableLog, U32* wksp, size_t wkspSize);
#endif /* ZSTD_DEC_BLOCK_H */ #endif /* ZSTD_DEC_BLOCK_H */

View File

@ -72,6 +72,7 @@ static const U32 ML_base[MaxML+1] = {
} ZSTD_seqSymbol; } ZSTD_seqSymbol;
#define SEQSYMBOL_TABLE_SIZE(log) (1 + (1 << (log))) #define SEQSYMBOL_TABLE_SIZE(log) (1 + (1 << (log)))
#define ZSTD_FSE_WKSP_SIZE_U32 130
typedef struct { typedef struct {
ZSTD_seqSymbol LLTable[SEQSYMBOL_TABLE_SIZE(LLFSELog)]; /* Note : Space reserved for FSE Tables */ ZSTD_seqSymbol LLTable[SEQSYMBOL_TABLE_SIZE(LLFSELog)]; /* Note : Space reserved for FSE Tables */
@ -79,6 +80,7 @@ typedef struct {
ZSTD_seqSymbol MLTable[SEQSYMBOL_TABLE_SIZE(MLFSELog)]; /* and therefore must be at least HUF_DECOMPRESS_WORKSPACE_SIZE large */ ZSTD_seqSymbol MLTable[SEQSYMBOL_TABLE_SIZE(MLFSELog)]; /* and therefore must be at least HUF_DECOMPRESS_WORKSPACE_SIZE large */
HUF_DTable hufTable[HUF_DTABLE_SIZE(HufLog)]; /* can accommodate HUF_decompress4X */ HUF_DTable hufTable[HUF_DTABLE_SIZE(HufLog)]; /* can accommodate HUF_decompress4X */
U32 rep[ZSTD_REP_NUM]; U32 rep[ZSTD_REP_NUM];
U32 workspace[ZSTD_FSE_WKSP_SIZE_U32];
} ZSTD_entropyDTables_t; } ZSTD_entropyDTables_t;
typedef enum { ZSTDds_getFrameHeaderSize, ZSTDds_decodeFrameHeader, typedef enum { ZSTDds_getFrameHeaderSize, ZSTDds_decodeFrameHeader,