diff --git a/tests/decodecorpus.c b/tests/decodecorpus.c index 632acabb..357d831e 100644 --- a/tests/decodecorpus.c +++ b/tests/decodecorpus.c @@ -18,6 +18,7 @@ #include "zstd.h" #include "zstd_internal.h" #include "mem.h" +#include "zdict.h" // Direct access to internal compression functions is required #include "zstd_compress.c" @@ -316,7 +317,8 @@ static void writeFrameHeader(U32* seed, frame_t* frame, int genDict, size_t dict op[pos++] = windowByte; } if(genDict) { - MEM_writeLE32(op + pos, (U32) dictSize); + MEM_writeLE32(op + pos, (U32) dictID); + pos += 4; } if (contentSizeFlag) { switch (fcsCode) { @@ -608,7 +610,7 @@ static inline void initSeqStore(seqStore_t *seqStore) { /* Randomly generate sequence commands */ static U32 generateSequences(U32* seed, frame_t* frame, seqStore_t* seqStore, - size_t contentSize, size_t literalsSize, int genDict, size_t dictSize) + size_t contentSize, size_t literalsSize, int genDict, size_t dictSize, BYTE* dictContent) { /* The total length of all the matches */ size_t const remainingMatch = contentSize - literalsSize; @@ -686,11 +688,17 @@ static U32 generateSequences(U32* seed, frame_t* frame, seqStore_t* seqStore, repIndex = MIN(2, offsetCode + 1); } } - } while (offset > (size_t)((BYTE*)srcPtr - (BYTE*)frame->srcStart) || offset == 0); + } while (((!genDict) && (offset > (size_t)((BYTE*)srcPtr - (BYTE*)frame->srcStart))) || offset == 0); { size_t j; for (j = 0; j < matchLen; j++) { - *srcPtr = *(srcPtr-offset); + if(srcPtr-offset < frame->srcStart){ + /* copy from dictionary instead of literals */ + *srcPtr = *(dictContent + dictSize - (offset-(srcPtr-frame->srcStart))); + } + else{ + *srcPtr = *(srcPtr-offset); + } srcPtr++; } } @@ -940,7 +948,7 @@ static size_t writeSequences(U32* seed, frame_t* frame, seqStore_t* seqStorePtr, } static size_t writeSequencesBlock(U32* seed, frame_t* frame, size_t contentSize, - size_t literalsSize, int genDict, size_t dictSize) + size_t literalsSize, int genDict, size_t dictSize, BYTE* dictContent) { seqStore_t seqStore; size_t numSequences; @@ -949,14 +957,14 @@ static size_t writeSequencesBlock(U32* seed, frame_t* frame, size_t contentSize, initSeqStore(&seqStore); /* randomly generate sequences */ - numSequences = generateSequences(seed, frame, &seqStore, contentSize, literalsSize, genDict, dictSize); + numSequences = generateSequences(seed, frame, &seqStore, contentSize, literalsSize, genDict, dictSize, dictContent); /* write them out to the frame data */ CHECKERR(writeSequences(seed, frame, &seqStore, numSequences)); return numSequences; } -static size_t writeCompressedBlock(U32* seed, frame_t* frame, size_t contentSize, int genDict, size_t dictSize) +static size_t writeCompressedBlock(U32* seed, frame_t* frame, size_t contentSize, int genDict, size_t dictSize, BYTE* dictContent) { BYTE* const blockStart = (BYTE*)frame->data; size_t literalsSize; @@ -968,7 +976,7 @@ static size_t writeCompressedBlock(U32* seed, frame_t* frame, size_t contentSize DISPLAYLEVEL(4, " literals size: %u\n", (U32)literalsSize); - nbSeq = writeSequencesBlock(seed, frame, contentSize, literalsSize, genDict, dictSize); + nbSeq = writeSequencesBlock(seed, frame, contentSize, literalsSize, genDict, dictSize, dictContent); DISPLAYLEVEL(4, " number of sequences: %u\n", (U32)nbSeq); @@ -976,7 +984,7 @@ static size_t writeCompressedBlock(U32* seed, frame_t* frame, size_t contentSize } static void writeBlock(U32* seed, frame_t* frame, size_t contentSize, - int lastBlock, int genDict, size_t dictSize) + int lastBlock, int genDict, size_t dictSize, BYTE* dictContent) { int const blockTypeDesc = RAND(seed) % 8; size_t blockSize; @@ -1016,7 +1024,7 @@ static void writeBlock(U32* seed, frame_t* frame, size_t contentSize, frame->oldStats = frame->stats; frame->data = op; - compressedSize = writeCompressedBlock(seed, frame, contentSize, genDict, dictSize); + compressedSize = writeCompressedBlock(seed, frame, contentSize, genDict, dictSize, dictContent); if (compressedSize > contentSize) { blockType = 0; memcpy(op, frame->src, contentSize); @@ -1042,7 +1050,7 @@ static void writeBlock(U32* seed, frame_t* frame, size_t contentSize, frame->data = op; } -static void writeBlocks(U32* seed, frame_t* frame, int genDict, size_t dictSize) +static void writeBlocks(U32* seed, frame_t* frame, int genDict, size_t dictSize, BYTE* dictContent) { size_t contentLeft = frame->header.contentSize; size_t const maxBlockSize = MIN(MAX_BLOCK_SIZE, frame->header.windowSize); @@ -1065,7 +1073,7 @@ static void writeBlocks(U32* seed, frame_t* frame, int genDict, size_t dictSize) } } - writeBlock(seed, frame, blockContentSize, lastBlock, genDict, dictSize); + writeBlock(seed, frame, blockContentSize, lastBlock, genDict, dictSize, dictContent); contentLeft -= blockContentSize; if (lastBlock) break; @@ -1130,14 +1138,14 @@ static void initFrame(frame_t* fr) } /* Return the final seed */ -static U32 generateFrame(U32 seed, frame_t* fr, int genDict, size_t dictSize) +static U32 generateFrame(U32 seed, frame_t* fr, int genDict, size_t dictSize, BYTE* dictContent) { /* generate a complete frame */ DISPLAYLEVEL(1, "frame seed: %u\n", seed); initFrame(fr); writeFrameHeader(&seed, fr, genDict, dictSize); - writeBlocks(&seed, fr, genDict, dictSize); + writeBlocks(&seed, fr, genDict, dictSize, dictContent); writeChecksum(fr); return seed; @@ -1224,7 +1232,7 @@ static int runTestMode(U32 seed, unsigned numFiles, unsigned const testDurationS else DISPLAYUPDATE("\r%u ", fnum); - seed = generateFrame(seed, &fr, 0, 0); + seed = generateFrame(seed, &fr, 0, 0, NULL); { size_t const r = testDecodeSimple(&fr); if (ZSTD_isError(r)) { @@ -1259,7 +1267,7 @@ static int generateFile(U32 seed, const char* const path, DISPLAY("seed: %u\n", seed); - generateFrame(seed, &fr, 0, 0); + generateFrame(seed, &fr, 0, 0, NULL); outputBuffer(fr.dataStart, (BYTE*)fr.data - (BYTE*)fr.dataStart, path); if (origPath) { @@ -1281,7 +1289,7 @@ static int generateCorpus(U32 seed, unsigned numFiles, const char* const path, DISPLAYUPDATE("\r%u/%u ", fnum, numFiles); - seed = generateFrame(seed, &fr, 0, 0); + seed = generateFrame(seed, &fr, 0, 0, NULL); if (snprintf(outPath, MAX_PATH, "%s/z%06u.zst", path, fnum) + 1 > MAX_PATH) { DISPLAY("Error: path too long\n"); @@ -1308,9 +1316,11 @@ static int generateCorpusWithDict(U32 seed, unsigned numFiles, const char* const { const size_t minDictSize = 8; char outPath[MAX_PATH]; + BYTE* dictContent; + BYTE* fullDict; U32 dictID; - BYTE* dictStart; unsigned fnum; + BYTE* decompressedPtr; ZSTD_DCtx* dctx = ZSTD_createDCtx(); if(snprintf(outPath, MAX_PATH, "%s/dictionary", path) + 1 > MAX_PATH) { DISPLAY("Error: path too long\n"); @@ -1318,37 +1328,50 @@ static int generateCorpusWithDict(U32 seed, unsigned numFiles, const char* const } /* Generate the dictionary randomly first */ - if(dictSize < minDictSize){ - DISPLAY("Error: dictionary size (%zu) is too small\n", dictSize); - } - else{ - /* variable declaration */ - dictStart = malloc(dictSize); - size_t pos = 0; - dictID = RAND(&seed) + 1; + dictContent = malloc(dictSize-400); + dictID = RAND(&seed); + fullDict = malloc(dictSize); + RAND_buffer(&seed, dictContent, dictSize-40); + { + /* create random samples */ + unsigned numSamples = RAND(&seed); + unsigned i = 0; + size_t* sampleSizes = malloc(numSamples*sizeof(size_t)); + size_t* curr = sampleSizes; + size_t totalSize = 0; + while(i < numSamples){ + *curr = RAND(&seed) % (4 << 20); + totalSize += *curr; + curr++; + } + ZDICT_params_t zdictParams; + BYTE* samples = malloc(totalSize); + RAND_buffer(&seed, samples, totalSize); - /* write dictionary magic number */ - MEM_writeLE32(dictStart + pos, ZSTD_DICT_MAGIC); - pos += 4; + /* set dictionary params */ + memset(&zdictParams, 0, sizeof(zdictParams)); + zdictParams.notificationLevel = 1; + zdictParams.dictID = dictID; + zdictParams.compressionLevel = 5; - /* write random dictionary ID */ - MEM_writeLE32(dictStart + pos, dictID); - pos += 4; - - /* randomly generate the rest of the dictionary */ - RAND_buffer(&seed, dictStart + pos, dictSize-8); - outputBuffer(dictStart, dictSize, outPath); + /* finalize dictionary with random samples */ + ZDICT_finalizeDictionary(fullDict, dictSize, + dictContent, dictSize-400, + samples, sampleSizes, numSamples, + zdictParams); } + + decompressedPtr = malloc(MAX_DECOMPRESSED_SIZE); /* generate random compressed/decompressed files */ for (fnum = 0; fnum < numFiles; fnum++) { frame_t fr; size_t returnValue; - BYTE* decompressedPtr = malloc(MAX_DECOMPRESSED_SIZE); + DISPLAYUPDATE("\r%u/%u ", fnum, numFiles); - seed = generateFrame(seed, &fr, 1, dictSize); + seed = generateFrame(seed, &fr, 1, dictSize, dictContent); if (snprintf(outPath, MAX_PATH, "%s/z%06u.zst", path, fnum) + 1 > MAX_PATH) { DISPLAY("Error: path too long\n"); @@ -1368,13 +1391,10 @@ static int generateCorpusWithDict(U32 seed, unsigned numFiles, const char* const returnValue = ZSTD_decompress_usingDict(dctx, decompressedPtr, MAX_DECOMPRESSED_SIZE, fr.srcStart, (BYTE*)fr.src - (BYTE*)fr.srcStart, - dictStart,dictSize); + fullDict, dictSize); } - - /* write uncompressed versions of files */ - DISPLAY("This is origPath: %s\nAnd this is numFiles: %d\n", origPath, numFiles); return 0; }