Merge pull request #2524 from terrelln/huf-stack-reduction

[huf] Reduce stack usage of HUF_readDTableX2 by ~972 bytes
This commit is contained in:
Nick Terrell 2021-03-22 12:37:54 -07:00 committed by GitHub
commit ebc2dfa821
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 67 additions and 56 deletions

View File

@ -352,7 +352,7 @@ size_t FSE_buildDTable_raw (FSE_DTable* dt, unsigned nbBits);
size_t FSE_buildDTable_rle (FSE_DTable* dt, unsigned char symbolValue); size_t FSE_buildDTable_rle (FSE_DTable* dt, unsigned char symbolValue);
/**< build a fake FSE_DTable, designed to always generate the same symbolValue */ /**< build a fake FSE_DTable, designed to always generate the same symbolValue */
#define FSE_DECOMPRESS_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) (FSE_DTABLE_SIZE_U32(maxTableLog) + FSE_BUILD_DTABLE_WKSP_SIZE_U32(maxTableLog, maxSymbolValue)) #define FSE_DECOMPRESS_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) (FSE_DTABLE_SIZE_U32(maxTableLog) + FSE_BUILD_DTABLE_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) + (FSE_MAX_SYMBOL_VALUE + 1) / 2 + 1)
#define FSE_DECOMPRESS_WKSP_SIZE(maxTableLog, maxSymbolValue) (FSE_DECOMPRESS_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) * sizeof(unsigned)) #define FSE_DECOMPRESS_WKSP_SIZE(maxTableLog, maxSymbolValue) (FSE_DECOMPRESS_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) * sizeof(unsigned))
size_t FSE_decompress_wksp(void* dst, size_t dstCapacity, const void* cSrc, size_t cSrcSize, unsigned maxLog, void* workSpace, size_t wkspSize); size_t FSE_decompress_wksp(void* dst, size_t dstCapacity, const void* cSrc, size_t cSrcSize, unsigned maxLog, void* workSpace, size_t wkspSize);
/**< same as FSE_decompress(), using an externally allocated `workSpace` produced with `FSE_DECOMPRESS_WKSP_SIZE_U32(maxLog, maxSymbolValue)` */ /**< same as FSE_decompress(), using an externally allocated `workSpace` produced with `FSE_DECOMPRESS_WKSP_SIZE_U32(maxLog, maxSymbolValue)` */

View File

@ -310,6 +310,12 @@ size_t FSE_decompress_wksp(void* dst, size_t dstCapacity, const void* cSrc, size
return FSE_decompress_wksp_bmi2(dst, dstCapacity, cSrc, cSrcSize, maxLog, workSpace, wkspSize, /* bmi2 */ 0); return FSE_decompress_wksp_bmi2(dst, dstCapacity, cSrc, cSrcSize, maxLog, workSpace, wkspSize, /* bmi2 */ 0);
} }
typedef struct {
short ncount[FSE_MAX_SYMBOL_VALUE + 1];
FSE_DTable dtable[1]; /* Dynamically sized */
} FSE_DecompressWksp;
FORCE_INLINE_TEMPLATE size_t FSE_decompress_wksp_body( FORCE_INLINE_TEMPLATE size_t FSE_decompress_wksp_body(
void* dst, size_t dstCapacity, void* dst, size_t dstCapacity,
const void* cSrc, size_t cSrcSize, const void* cSrc, size_t cSrcSize,
@ -318,33 +324,37 @@ FORCE_INLINE_TEMPLATE size_t FSE_decompress_wksp_body(
{ {
const BYTE* const istart = (const BYTE*)cSrc; const BYTE* const istart = (const BYTE*)cSrc;
const BYTE* ip = istart; const BYTE* ip = istart;
short counting[FSE_MAX_SYMBOL_VALUE+1];
unsigned tableLog; unsigned tableLog;
unsigned maxSymbolValue = FSE_MAX_SYMBOL_VALUE; unsigned maxSymbolValue = FSE_MAX_SYMBOL_VALUE;
FSE_DTable* const dtable = (FSE_DTable*)workSpace; FSE_DecompressWksp* const wksp = (FSE_DecompressWksp*)workSpace;
DEBUG_STATIC_ASSERT((FSE_MAX_SYMBOL_VALUE + 1) % 2 == 0);
if (wkspSize < sizeof(*wksp)) return ERROR(GENERIC);
/* normal FSE decoding mode */ /* normal FSE decoding mode */
size_t const NCountLength = FSE_readNCount_bmi2(counting, &maxSymbolValue, &tableLog, istart, cSrcSize, bmi2); {
if (FSE_isError(NCountLength)) return NCountLength; size_t const NCountLength = FSE_readNCount_bmi2(wksp->ncount, &maxSymbolValue, &tableLog, istart, cSrcSize, bmi2);
if (tableLog > maxLog) return ERROR(tableLog_tooLarge); if (FSE_isError(NCountLength)) return NCountLength;
assert(NCountLength <= cSrcSize); if (tableLog > maxLog) return ERROR(tableLog_tooLarge);
ip += NCountLength; assert(NCountLength <= cSrcSize);
cSrcSize -= NCountLength; ip += NCountLength;
cSrcSize -= NCountLength;
}
if (FSE_DECOMPRESS_WKSP_SIZE(tableLog, maxSymbolValue) > wkspSize) return ERROR(tableLog_tooLarge); if (FSE_DECOMPRESS_WKSP_SIZE(tableLog, maxSymbolValue) > wkspSize) return ERROR(tableLog_tooLarge);
workSpace = dtable + FSE_DTABLE_SIZE_U32(tableLog); workSpace = wksp->dtable + FSE_DTABLE_SIZE_U32(tableLog);
wkspSize -= FSE_DTABLE_SIZE(tableLog); wkspSize -= sizeof(*wksp) + FSE_DTABLE_SIZE(tableLog);
CHECK_F( FSE_buildDTable_internal(dtable, counting, maxSymbolValue, tableLog, workSpace, wkspSize) ); CHECK_F( FSE_buildDTable_internal(wksp->dtable, wksp->ncount, maxSymbolValue, tableLog, workSpace, wkspSize) );
{ {
const void* ptr = dtable; const void* ptr = wksp->dtable;
const FSE_DTableHeader* DTableH = (const FSE_DTableHeader*)ptr; const FSE_DTableHeader* DTableH = (const FSE_DTableHeader*)ptr;
const U32 fastMode = DTableH->fastMode; const U32 fastMode = DTableH->fastMode;
/* select fast mode (static) */ /* select fast mode (static) */
if (fastMode) return FSE_decompress_usingDTable_generic(dst, dstCapacity, ip, cSrcSize, dtable, 1); if (fastMode) return FSE_decompress_usingDTable_generic(dst, dstCapacity, ip, cSrcSize, wksp->dtable, 1);
return FSE_decompress_usingDTable_generic(dst, dstCapacity, ip, cSrcSize, dtable, 0); return FSE_decompress_usingDTable_generic(dst, dstCapacity, ip, cSrcSize, wksp->dtable, 0);
} }
} }

View File

@ -279,7 +279,7 @@ U32 HUF_selectDecoder (size_t dstSize, size_t cSrcSize);
* a required workspace size greater than that specified in the following * a required workspace size greater than that specified in the following
* macro. * macro.
*/ */
#define HUF_DECOMPRESS_WORKSPACE_SIZE (2 << 10) #define HUF_DECOMPRESS_WORKSPACE_SIZE ((2 << 10) + (1 << 9))
#define HUF_DECOMPRESS_WORKSPACE_SIZE_U32 (HUF_DECOMPRESS_WORKSPACE_SIZE / sizeof(U32)) #define HUF_DECOMPRESS_WORKSPACE_SIZE_U32 (HUF_DECOMPRESS_WORKSPACE_SIZE / sizeof(U32))
#ifndef HUF_FORCE_DECOMPRESS_X2 #ifndef HUF_FORCE_DECOMPRESS_X2

View File

@ -528,13 +528,15 @@ typedef rankValCol_t rankVal_t[HUF_TABLELOG_MAX];
static void HUF_fillDTableX2Level2(HUF_DEltX2* DTable, U32 sizeLog, const U32 consumed, static void HUF_fillDTableX2Level2(HUF_DEltX2* DTable, U32 sizeLog, const U32 consumed,
const U32* rankValOrigin, const int minWeight, const U32* rankValOrigin, const int minWeight,
const sortedSymbol_t* sortedSymbols, const U32 sortedListSize, const sortedSymbol_t* sortedSymbols, const U32 sortedListSize,
U32 nbBitsBaseline, U16 baseSeq) U32 nbBitsBaseline, U16 baseSeq, U32* wksp, size_t wkspSize)
{ {
HUF_DEltX2 DElt; HUF_DEltX2 DElt;
U32 rankVal[HUF_TABLELOG_MAX + 1]; U32* rankVal = wksp;
assert(wkspSize >= HUF_TABLELOG_MAX + 1);
(void)wkspSize;
/* get pre-calculated rankVal */ /* get pre-calculated rankVal */
ZSTD_memcpy(rankVal, rankValOrigin, sizeof(rankVal)); ZSTD_memcpy(rankVal, rankValOrigin, sizeof(U32) * (HUF_TABLELOG_MAX + 1));
/* fill skipped values */ /* fill skipped values */
if (minWeight>1) { if (minWeight>1) {
@ -569,14 +571,18 @@ static void HUF_fillDTableX2Level2(HUF_DEltX2* DTable, U32 sizeLog, const U32 co
static void HUF_fillDTableX2(HUF_DEltX2* DTable, const U32 targetLog, static void HUF_fillDTableX2(HUF_DEltX2* DTable, const U32 targetLog,
const sortedSymbol_t* sortedList, const U32 sortedListSize, const sortedSymbol_t* sortedList, const U32 sortedListSize,
const U32* rankStart, rankVal_t rankValOrigin, const U32 maxWeight, const U32* rankStart, rankVal_t rankValOrigin, const U32 maxWeight,
const U32 nbBitsBaseline) const U32 nbBitsBaseline, U32* wksp, size_t wkspSize)
{ {
U32 rankVal[HUF_TABLELOG_MAX + 1]; U32* rankVal = wksp;
const int scaleLog = nbBitsBaseline - targetLog; /* note : targetLog >= srcLog, hence scaleLog <= 1 */ const int scaleLog = nbBitsBaseline - targetLog; /* note : targetLog >= srcLog, hence scaleLog <= 1 */
const U32 minBits = nbBitsBaseline - maxWeight; const U32 minBits = nbBitsBaseline - maxWeight;
U32 s; U32 s;
ZSTD_memcpy(rankVal, rankValOrigin, sizeof(rankVal)); assert(wkspSize >= HUF_TABLELOG_MAX + 1);
wksp += HUF_TABLELOG_MAX + 1;
wkspSize -= HUF_TABLELOG_MAX + 1;
ZSTD_memcpy(rankVal, rankValOrigin, sizeof(U32) * (HUF_TABLELOG_MAX + 1));
/* fill DTable */ /* fill DTable */
for (s=0; s<sortedListSize; s++) { for (s=0; s<sortedListSize; s++) {
@ -594,7 +600,7 @@ static void HUF_fillDTableX2(HUF_DEltX2* DTable, const U32 targetLog,
HUF_fillDTableX2Level2(DTable+start, targetLog-nbBits, nbBits, HUF_fillDTableX2Level2(DTable+start, targetLog-nbBits, nbBits,
rankValOrigin[nbBits], minWeight, rankValOrigin[nbBits], minWeight,
sortedList+sortedRank, sortedListSize-sortedRank, sortedList+sortedRank, sortedListSize-sortedRank,
nbBitsBaseline, symbol); nbBitsBaseline, symbol, wksp, wkspSize);
} else { } else {
HUF_DEltX2 DElt; HUF_DEltX2 DElt;
MEM_writeLE16(&(DElt.sequence), symbol); MEM_writeLE16(&(DElt.sequence), symbol);
@ -608,6 +614,15 @@ static void HUF_fillDTableX2(HUF_DEltX2* DTable, const U32 targetLog,
} }
} }
typedef struct {
rankValCol_t rankVal[HUF_TABLELOG_MAX];
U32 rankStats[HUF_TABLELOG_MAX + 1];
U32 rankStart0[HUF_TABLELOG_MAX + 2];
sortedSymbol_t sortedSymbol[HUF_SYMBOLVALUE_MAX + 1];
BYTE weightList[HUF_SYMBOLVALUE_MAX + 1];
U32 calleeWksp[HUF_READ_STATS_WORKSPACE_SIZE_U32];
} HUF_ReadDTableX2_Workspace;
size_t HUF_readDTableX2_wksp(HUF_DTable* DTable, size_t HUF_readDTableX2_wksp(HUF_DTable* DTable,
const void* src, size_t srcSize, const void* src, size_t srcSize,
void* workSpace, size_t wkspSize) void* workSpace, size_t wkspSize)
@ -620,47 +635,32 @@ size_t HUF_readDTableX2_wksp(HUF_DTable* DTable,
HUF_DEltX2* const dt = (HUF_DEltX2*)dtPtr; HUF_DEltX2* const dt = (HUF_DEltX2*)dtPtr;
U32 *rankStart; U32 *rankStart;
rankValCol_t* rankVal; HUF_ReadDTableX2_Workspace* const wksp = (HUF_ReadDTableX2_Workspace*)workSpace;
U32* rankStats;
U32* rankStart0;
sortedSymbol_t* sortedSymbol;
BYTE* weightList;
size_t spaceUsed32 = 0;
rankVal = (rankValCol_t *)((U32 *)workSpace + spaceUsed32); if (sizeof(*wksp) > wkspSize) return ERROR(GENERIC);
spaceUsed32 += (sizeof(rankValCol_t) * HUF_TABLELOG_MAX) >> 2;
rankStats = (U32 *)workSpace + spaceUsed32;
spaceUsed32 += HUF_TABLELOG_MAX + 1;
rankStart0 = (U32 *)workSpace + spaceUsed32;
spaceUsed32 += HUF_TABLELOG_MAX + 2;
sortedSymbol = (sortedSymbol_t *)workSpace + (spaceUsed32 * sizeof(U32)) / sizeof(sortedSymbol_t);
spaceUsed32 += HUF_ALIGN(sizeof(sortedSymbol_t) * (HUF_SYMBOLVALUE_MAX + 1), sizeof(U32)) >> 2;
weightList = (BYTE *)((U32 *)workSpace + spaceUsed32);
spaceUsed32 += HUF_ALIGN(HUF_SYMBOLVALUE_MAX + 1, sizeof(U32)) >> 2;
if ((spaceUsed32 << 2) > wkspSize) return ERROR(tableLog_tooLarge); rankStart = wksp->rankStart0 + 1;
ZSTD_memset(wksp->rankStats, 0, sizeof(wksp->rankStats));
rankStart = rankStart0 + 1; ZSTD_memset(wksp->rankStart0, 0, sizeof(wksp->rankStart0));
ZSTD_memset(rankStats, 0, sizeof(U32) * (2 * HUF_TABLELOG_MAX + 2 + 1));
DEBUG_STATIC_ASSERT(sizeof(HUF_DEltX2) == sizeof(HUF_DTable)); /* if compiler fails here, assertion is wrong */ DEBUG_STATIC_ASSERT(sizeof(HUF_DEltX2) == sizeof(HUF_DTable)); /* if compiler fails here, assertion is wrong */
if (maxTableLog > HUF_TABLELOG_MAX) return ERROR(tableLog_tooLarge); if (maxTableLog > HUF_TABLELOG_MAX) return ERROR(tableLog_tooLarge);
/* ZSTD_memset(weightList, 0, sizeof(weightList)); */ /* is not necessary, even though some analyzer complain ... */ /* ZSTD_memset(weightList, 0, sizeof(weightList)); */ /* is not necessary, even though some analyzer complain ... */
iSize = HUF_readStats(weightList, HUF_SYMBOLVALUE_MAX + 1, rankStats, &nbSymbols, &tableLog, src, srcSize); iSize = HUF_readStats_wksp(wksp->weightList, HUF_SYMBOLVALUE_MAX + 1, wksp->rankStats, &nbSymbols, &tableLog, src, srcSize, wksp->calleeWksp, sizeof(wksp->calleeWksp), /* bmi2 */ 0);
if (HUF_isError(iSize)) return iSize; if (HUF_isError(iSize)) return iSize;
/* check result */ /* check result */
if (tableLog > maxTableLog) return ERROR(tableLog_tooLarge); /* DTable can't fit code depth */ if (tableLog > maxTableLog) return ERROR(tableLog_tooLarge); /* DTable can't fit code depth */
/* find maxWeight */ /* find maxWeight */
for (maxW = tableLog; rankStats[maxW]==0; maxW--) {} /* necessarily finds a solution before 0 */ for (maxW = tableLog; wksp->rankStats[maxW]==0; maxW--) {} /* necessarily finds a solution before 0 */
/* Get start index of each weight */ /* Get start index of each weight */
{ U32 w, nextRankStart = 0; { U32 w, nextRankStart = 0;
for (w=1; w<maxW+1; w++) { for (w=1; w<maxW+1; w++) {
U32 curr = nextRankStart; U32 curr = nextRankStart;
nextRankStart += rankStats[w]; nextRankStart += wksp->rankStats[w];
rankStart[w] = curr; rankStart[w] = curr;
} }
rankStart[0] = nextRankStart; /* put all 0w symbols at the end of sorted list*/ rankStart[0] = nextRankStart; /* put all 0w symbols at the end of sorted list*/
@ -670,37 +670,38 @@ size_t HUF_readDTableX2_wksp(HUF_DTable* DTable,
/* sort symbols by weight */ /* sort symbols by weight */
{ U32 s; { U32 s;
for (s=0; s<nbSymbols; s++) { for (s=0; s<nbSymbols; s++) {
U32 const w = weightList[s]; U32 const w = wksp->weightList[s];
U32 const r = rankStart[w]++; U32 const r = rankStart[w]++;
sortedSymbol[r].symbol = (BYTE)s; wksp->sortedSymbol[r].symbol = (BYTE)s;
sortedSymbol[r].weight = (BYTE)w; wksp->sortedSymbol[r].weight = (BYTE)w;
} }
rankStart[0] = 0; /* forget 0w symbols; this is beginning of weight(1) */ rankStart[0] = 0; /* forget 0w symbols; this is beginning of weight(1) */
} }
/* Build rankVal */ /* Build rankVal */
{ U32* const rankVal0 = rankVal[0]; { U32* const rankVal0 = wksp->rankVal[0];
{ int const rescale = (maxTableLog-tableLog) - 1; /* tableLog <= maxTableLog */ { int const rescale = (maxTableLog-tableLog) - 1; /* tableLog <= maxTableLog */
U32 nextRankVal = 0; U32 nextRankVal = 0;
U32 w; U32 w;
for (w=1; w<maxW+1; w++) { for (w=1; w<maxW+1; w++) {
U32 curr = nextRankVal; U32 curr = nextRankVal;
nextRankVal += rankStats[w] << (w+rescale); nextRankVal += wksp->rankStats[w] << (w+rescale);
rankVal0[w] = curr; rankVal0[w] = curr;
} } } }
{ U32 const minBits = tableLog+1 - maxW; { U32 const minBits = tableLog+1 - maxW;
U32 consumed; U32 consumed;
for (consumed = minBits; consumed < maxTableLog - minBits + 1; consumed++) { for (consumed = minBits; consumed < maxTableLog - minBits + 1; consumed++) {
U32* const rankValPtr = rankVal[consumed]; U32* const rankValPtr = wksp->rankVal[consumed];
U32 w; U32 w;
for (w = 1; w < maxW+1; w++) { for (w = 1; w < maxW+1; w++) {
rankValPtr[w] = rankVal0[w] >> consumed; rankValPtr[w] = rankVal0[w] >> consumed;
} } } } } } } }
HUF_fillDTableX2(dt, maxTableLog, HUF_fillDTableX2(dt, maxTableLog,
sortedSymbol, sizeOfSort, wksp->sortedSymbol, sizeOfSort,
rankStart0, rankVal, maxW, wksp->rankStart0, wksp->rankVal, maxW,
tableLog+1); tableLog+1,
wksp->calleeWksp, sizeof(wksp->calleeWksp) / sizeof(U32));
dtd.tableLog = (BYTE)maxTableLog; dtd.tableLog = (BYTE)maxTableLog;
dtd.tableType = 1; dtd.tableType = 1;
@ -1225,7 +1226,7 @@ size_t HUF_decompress1X1 (void* dst, size_t dstSize, const void* cSrc, size_t cS
HUF_CREATE_STATIC_DTABLEX1(DTable, HUF_TABLELOG_MAX); HUF_CREATE_STATIC_DTABLEX1(DTable, HUF_TABLELOG_MAX);
return HUF_decompress1X1_DCtx (DTable, dst, dstSize, cSrc, cSrcSize); return HUF_decompress1X1_DCtx (DTable, dst, dstSize, cSrc, cSrcSize);
} }
#endif #endif
#ifndef HUF_FORCE_DECOMPRESS_X1 #ifndef HUF_FORCE_DECOMPRESS_X1
size_t HUF_readDTableX2(HUF_DTable* DTable, const void* src, size_t srcSize) size_t HUF_readDTableX2(HUF_DTable* DTable, const void* src, size_t srcSize)