diff --git a/lib/compress/huf_compress.c b/lib/compress/huf_compress.c index 67f0c928..30fa0dcb 100644 --- a/lib/compress/huf_compress.c +++ b/lib/compress/huf_compress.c @@ -211,31 +211,62 @@ typedef struct nodeElt_s { BYTE nbBits; } nodeElt; +/** + * HUF_setMaxHeight(): + * Enforces maxNbBits on the Huffman tree described in huffNode. + * + * It sets all nodes with nbBits > maxNbBits to be maxNbBits. Then it adjusts + * the tree to so that it is a valid canonical Huffman tree. + * + * @pre The sum of the ranks of each symbol == 2^largestBits, + * where largestBits == huffNode[lastNonNull].nbBits. + * @post The sum of the ranks of each symbol == 2^largestBits, + * where largestBits is the return value <= maxNbBits. + * + * @param huffNode The Huffman tree modified in place to enforce maxNbBits. + * @param lastNonNull The symbol with the lowest count in the Huffman tree. + * @param maxNbBits The maximum allowed number of bits, which the Huffman tree + * may not respect. After this function the Huffman tree will + * respect maxNbBits. + * @return The maximum number of bits of the Huffman tree after adjustment, + * necessarily no more than maxNbBits. + */ static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 maxNbBits) { const U32 largestBits = huffNode[lastNonNull].nbBits; - if (largestBits <= maxNbBits) return largestBits; /* early exit : no elt > maxNbBits */ + /* early exit : no elt > maxNbBits, so the tree is already valid. */ + if (largestBits <= maxNbBits) return largestBits; /* there are several too large elements (at least >= 2) */ { int totalCost = 0; const U32 baseCost = 1 << (largestBits - maxNbBits); int n = (int)lastNonNull; + /* Adjust any ranks > maxNbBits to maxNbBits. + * Compute totalCost, which is how far the sum of the ranks is + * we are over 2^largestBits after adjust the offending ranks. + */ while (huffNode[n].nbBits > maxNbBits) { totalCost += baseCost - (1 << (largestBits - huffNode[n].nbBits)); huffNode[n].nbBits = (BYTE)maxNbBits; - n --; - } /* n stops at huffNode[n].nbBits <= maxNbBits */ - while (huffNode[n].nbBits == maxNbBits) n--; /* n end at index of smallest symbol using < maxNbBits */ + n--; + } + /* n stops at huffNode[n].nbBits <= maxNbBits */ + assert(huffNode[n].nbBits <= maxNbBits); + /* n end at index of smallest symbol using < maxNbBits */ + while (huffNode[n].nbBits == maxNbBits) --n; - /* renorm totalCost */ - totalCost >>= (largestBits - maxNbBits); /* note : totalCost is necessarily a multiple of baseCost */ + /* renorm totalCost from 2^largestBits to 2^maxNbBits + * note : totalCost is necessarily a multiple of baseCost */ + assert((totalCost & (baseCost - 1)) == 0); + totalCost >>= (largestBits - maxNbBits); + assert(totalCost > 0); /* repay normalized cost */ { U32 const noSymbol = 0xF0F0F0F0; U32 rankLast[HUF_TABLELOG_MAX+2]; - /* Get pos of last (smallest) symbol per rank */ + /* Get pos of last (smallest = lowest cum. count) symbol per rank */ ZSTD_memset(rankLast, 0xF0, sizeof(rankLast)); { U32 currentNbBits = maxNbBits; int pos; @@ -246,34 +277,65 @@ static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 maxNbBits) } } while (totalCost > 0) { + /* Try to reduce the next power of 2 above totalCost because we + * gain back half the rank. + */ U32 nBitsToDecrease = BIT_highbit32((U32)totalCost) + 1; for ( ; nBitsToDecrease > 1; nBitsToDecrease--) { U32 const highPos = rankLast[nBitsToDecrease]; U32 const lowPos = rankLast[nBitsToDecrease-1]; if (highPos == noSymbol) continue; + /* Decrease highPos if no symbols of lowPos or if it is + * not cheaper to remove 2 lowPos than highPos. + */ if (lowPos == noSymbol) break; { U32 const highTotal = huffNode[highPos].count; U32 const lowTotal = 2 * huffNode[lowPos].count; if (highTotal <= lowTotal) break; } } /* only triggered when no more rank 1 symbol left => find closest one (note : there is necessarily at least one !) */ + assert(rankLast[nBitsToDecrease] != noSymbol || nBitsToDecrease == 1); /* HUF_MAX_TABLELOG test just to please gcc 5+; but it should not be necessary */ while ((nBitsToDecrease<=HUF_TABLELOG_MAX) && (rankLast[nBitsToDecrease] == noSymbol)) - nBitsToDecrease ++; + nBitsToDecrease++; + assert(rankLast[nBitsToDecrease] != noSymbol); + /* Increase the number of bits to gain back half the rank cost. */ totalCost -= 1 << (nBitsToDecrease-1); + huffNode[rankLast[nBitsToDecrease]].nbBits++; + + /* Fix up the new rank. + * If the new rank was empty, this symbol is now its smallest. + * Otherwise, this symbol will be the largest in the new rank so no adjustment. + */ if (rankLast[nBitsToDecrease-1] == noSymbol) - rankLast[nBitsToDecrease-1] = rankLast[nBitsToDecrease]; /* this rank is no longer empty */ - huffNode[rankLast[nBitsToDecrease]].nbBits ++; + rankLast[nBitsToDecrease-1] = rankLast[nBitsToDecrease]; + /* Fix up the old rank. + * If the symbol was at position 0, meaning it was the highest weight symbol in the tree, + * it must be the only symbol in its rank, so the old rank now has no symbols. + * Otherwise, since the Huffman nodes are sorted by count, the previous position is now + * the smallest node in the rank. If the previous position belongs to a different rank, + * then the rank is now empty. + */ if (rankLast[nBitsToDecrease] == 0) /* special case, reached largest symbol */ rankLast[nBitsToDecrease] = noSymbol; else { rankLast[nBitsToDecrease]--; if (huffNode[rankLast[nBitsToDecrease]].nbBits != maxNbBits-nBitsToDecrease) rankLast[nBitsToDecrease] = noSymbol; /* this rank is now empty */ - } } /* while (totalCost > 0) */ + } + } /* while (totalCost > 0) */ + /* If we've removed too much weight, then we have to add it back. + * To avoid overshooting again, we only adjust the smallest rank. + * We take the largest nodes from the lowest rank 0 and move them + * to rank 1. There's guaranteed to be enough rank 0 symbols because + * TODO. + */ while (totalCost < 0) { /* Sometimes, cost correction overshoot */ - if (rankLast[1] == noSymbol) { /* special case : no rank 1 symbol (using maxNbBits-1); let's create one from largest rank 0 (using maxNbBits) */ + /* special case : no rank 1 symbol (using maxNbBits-1); + * let's create one from largest rank 0 (using maxNbBits). + */ + if (rankLast[1] == noSymbol) { while (huffNode[n].nbBits == maxNbBits) n--; huffNode[n+1].nbBits--; assert(n >= 0); @@ -284,7 +346,9 @@ static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 maxNbBits) huffNode[ rankLast[1] + 1 ].nbBits--; rankLast[1]++; totalCost ++; - } } } /* there are several too large elements (at least >= 2) */ + } + } /* repay normalized cost */ + } /* there are several too large elements (at least >= 2) */ return maxNbBits; } @@ -303,21 +367,45 @@ typedef struct { rankPos rankPosition[RANK_POSITION_TABLE_SIZE]; } HUF_buildCTable_wksp_tables; +/** + * HUF_sort(): + * Sorts the symbols [0, maxSymbolValue] by count[symbol] in decreasing order. + * + * @param[out] huffNode Sorted symbols by decreasing count. Only members `.count` and `.byte` are filled. + * Must have (maxSymbolValue + 1) entries. + * @param[in] count Histogram of the symbols. + * @param[in] maxSymbolValue Maximum symbol value. + * @param rankPosition This is a scratch workspace. Must have RANK_POSITION_TABLE_SIZE entries. + */ static void HUF_sort(nodeElt* huffNode, const unsigned* count, U32 maxSymbolValue, rankPos* rankPosition) { - U32 n; + int n; + int const maxSymbolValue1 = (int)maxSymbolValue + 1; + /* Compute base and set curr to base. + * For symbol s let lowerRank = BIT_highbit32(count[n]+1) and rank = lowerRank + 1. + * Then 2^lowerRank <= count[n]+1 <= 2^rank. + * We attribute each symbol to lowerRank's base value, because we want to know where + * each rank begins in the output, so for rank R we want to count ranks R+1 and above. + */ ZSTD_memset(rankPosition, 0, sizeof(*rankPosition) * RANK_POSITION_TABLE_SIZE); - for (n=0; n<=maxSymbolValue; n++) { - U32 r = BIT_highbit32(count[n] + 1); - rankPosition[r].base ++; + for (n = 0; n < maxSymbolValue1; ++n) { + U32 lowerRank = BIT_highbit32(count[n] + 1); + rankPosition[lowerRank].base++; } - for (n=30; n>0; n--) rankPosition[n-1].base += rankPosition[n].base; - for (n=0; n<32; n++) rankPosition[n].curr = rankPosition[n].base; - for (n=0; n<=maxSymbolValue; n++) { + assert(rankPosition[RANK_POSITION_TABLE_SIZE - 1].base == 0); + for (n = RANK_POSITION_TABLE_SIZE - 1; n > 0; --n) { + rankPosition[n-1].base += rankPosition[n].base; + rankPosition[n-1].curr = rankPosition[n-1].base; + } + /* Sort */ + for (n = 0; n < maxSymbolValue1; ++n) { U32 const c = count[n]; U32 const r = BIT_highbit32(c+1) + 1; U32 pos = rankPosition[r].curr++; + /* Insert into the correct position in the rank. + * We have at most 256 symbols, so this insertion should be fine. + */ while ((pos > rankPosition[r].base) && (c > huffNode[pos-1].count)) { huffNode[pos] = huffNode[pos-1]; pos--; @@ -334,28 +422,20 @@ static void HUF_sort(nodeElt* huffNode, const unsigned* count, U32 maxSymbolValu */ #define STARTNODE (HUF_SYMBOLVALUE_MAX+1) -size_t HUF_buildCTable_wksp (HUF_CElt* tree, const unsigned* count, U32 maxSymbolValue, U32 maxNbBits, void* workSpace, size_t wkspSize) +/* HUF_buildTree(): + * Takes the huffNode array sorted by HUF_sort() and builds an unlimited-depth Huffman tree. + * + * @param huffNode The array sorted by HUF_sort(). Builds the Huffman tree in this array. + * @param maxSymbolValue The maximum symbol value. + * @return The smallest node in the Huffman tree (by count). + */ +static int HUF_buildTree(nodeElt* huffNode, U32 maxSymbolValue) { - HUF_buildCTable_wksp_tables* const wksp_tables = (HUF_buildCTable_wksp_tables*)workSpace; - nodeElt* const huffNode0 = wksp_tables->huffNodeTbl; - nodeElt* const huffNode = huffNode0+1; + nodeElt* const huffNode0 = huffNode - 1; int nonNullRank; int lowS, lowN; int nodeNb = STARTNODE; int n, nodeRoot; - - /* safety checks */ - if (((size_t)workSpace & 3) != 0) return ERROR(GENERIC); /* must be aligned on 4-bytes boundaries */ - if (wkspSize < sizeof(HUF_buildCTable_wksp_tables)) - return ERROR(workSpace_tooSmall); - if (maxNbBits == 0) maxNbBits = HUF_TABLELOG_DEFAULT; - if (maxSymbolValue > HUF_SYMBOLVALUE_MAX) - return ERROR(maxSymbolValue_tooLarge); - ZSTD_memset(huffNode0, 0, sizeof(huffNodeTable)); - - /* sort, decreasing order */ - HUF_sort(huffNode, count, maxSymbolValue, wksp_tables->rankPosition); - /* init for parents */ nonNullRank = (int)maxSymbolValue; while(huffNode[nonNullRank].count == 0) nonNullRank--; @@ -382,28 +462,68 @@ size_t HUF_buildCTable_wksp (HUF_CElt* tree, const unsigned* count, U32 maxSymbo for (n=0; n<=nonNullRank; n++) huffNode[n].nbBits = huffNode[ huffNode[n].parent ].nbBits + 1; + return nonNullRank; +} + +/** + * HUF_buildCTableFromTree(): + * Build the CTable given the Huffman tree in huffNode. + * + * @param[out] CTable The output Huffman CTable. + * @param huffNode The Huffman tree. + * @param nonNullRank The last and smallest node in the Huffman tree. + * @param maxSymbolValue The maximum symbol value. + * @param maxNbBits The exact maximum number of bits used in the Huffman tree. + */ +static void HUF_buildCTableFromTree(HUF_CElt* CTable, nodeElt const* huffNode, int nonNullRank, U32 maxSymbolValue, U32 maxNbBits) +{ + /* fill result into ctable (val, nbBits) */ + int n; + U16 nbPerRank[HUF_TABLELOG_MAX+1] = {0}; + U16 valPerRank[HUF_TABLELOG_MAX+1] = {0}; + int const alphabetSize = (int)(maxSymbolValue + 1); + for (n=0; n<=nonNullRank; n++) + nbPerRank[huffNode[n].nbBits]++; + /* determine starting value per rank */ + { U16 min = 0; + for (n=(int)maxNbBits; n>0; n--) { + valPerRank[n] = min; /* get starting value within each rank */ + min += nbPerRank[n]; + min >>= 1; + } } + for (n=0; nhuffNodeTbl; + nodeElt* const huffNode = huffNode0+1; + int nonNullRank; + + /* safety checks */ + if (((size_t)workSpace & 3) != 0) return ERROR(GENERIC); /* must be aligned on 4-bytes boundaries */ + if (wkspSize < sizeof(HUF_buildCTable_wksp_tables)) + return ERROR(workSpace_tooSmall); + if (maxNbBits == 0) maxNbBits = HUF_TABLELOG_DEFAULT; + if (maxSymbolValue > HUF_SYMBOLVALUE_MAX) + return ERROR(maxSymbolValue_tooLarge); + ZSTD_memset(huffNode0, 0, sizeof(huffNodeTable)); + + /* sort, decreasing order */ + HUF_sort(huffNode, count, maxSymbolValue, wksp_tables->rankPosition); + + /* build tree */ + nonNullRank = HUF_buildTree(huffNode, maxSymbolValue); + /* enforce maxTableLog */ maxNbBits = HUF_setMaxHeight(huffNode, (U32)nonNullRank, maxNbBits); + if (maxNbBits > HUF_TABLELOG_MAX) return ERROR(GENERIC); /* check fit into table */ - /* fill result into tree (val, nbBits) */ - { U16 nbPerRank[HUF_TABLELOG_MAX+1] = {0}; - U16 valPerRank[HUF_TABLELOG_MAX+1] = {0}; - int const alphabetSize = (int)(maxSymbolValue + 1); - if (maxNbBits > HUF_TABLELOG_MAX) return ERROR(GENERIC); /* check fit into table */ - for (n=0; n<=nonNullRank; n++) - nbPerRank[huffNode[n].nbBits]++; - /* determine stating value per rank */ - { U16 min = 0; - for (n=(int)maxNbBits; n>0; n--) { - valPerRank[n] = min; /* get starting value within each rank */ - min += nbPerRank[n]; - min >>= 1; - } } - for (n=0; n