diff --git a/tests/decodecorpus.c b/tests/decodecorpus.c index 08c2cc0a..d2bf565c 100644 --- a/tests/decodecorpus.c +++ b/tests/decodecorpus.c @@ -1297,6 +1297,54 @@ cleanup: return ret; } +static size_t testDecodeWithDict(U32 seed, size_t dictSize) +{ + U32 const dictID = RAND(&seed); + int errorDetected = 0; + BYTE* const fullDict = malloc(dictSize); + if (fullDict == NULL) { + return ERROR(GENERIC); + } + + { + int ret = genRandomDict(dictID, seed, dictSize, fullDict); + if (ret != 0) { + errorDetected = ERROR(GENERIC); + goto dictTestCleanup; + } + } + + frame_t* fr; + { + size_t dictContentSize = dictSize-dictSize/4; + BYTE* const dictContent = fullDict+dictSize/4; + dictInfo const info = initDictInfo(1, dictContentSize, dictContent, dictID); + seed = generateFrame(seed, &fr, info); + } + + { + ZSTD_DCtx* const dctx = ZSTD_createDCtx(); + { + size_t const returnValue = ZSTD_decompress_usingDict(dctx, DECOMPRESSED_BUFFER, MAX_DECOMPRESSED_SIZE, + fr.dataStart, (BYTE*)fr.data - (BYTE*)fr.dataStart, + fullDict, dictSize); + if (ZSTD_isError(returnValue)) { + errorDetected = ZSTD_getErrorName(returnValue); + goto dictTestCleanup + } + } + + if (memcmp(DECOMPRESSED_BUFFER, fr->srcStart, (BYTE*)fr->src - (BYTE*)fr->srcStart) != 0) { + errorDetected = ERROR(corruption_detected); + goto dictTestCleanup; + } + } + +dictTestCleanup: + free(fullDict); + return errorDetected; +} + static int runTestMode(U32 seed, unsigned numFiles, unsigned const testDurationS) { unsigned fnum;