diff --git a/lib/common/entropy_common.c b/lib/common/entropy_common.c index 0d27265a..38e18b16 100644 --- a/lib/common/entropy_common.c +++ b/lib/common/entropy_common.c @@ -110,13 +110,14 @@ size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigne int repeats = FSE_ctz(~bitStream | 0x80000000) >> 1; while (repeats >= 12) { charnum += 3 * 12; - if (ip <= iend-7) { + if (LIKELY(ip <= iend-7)) { ip += 3; - bitStream = MEM_readLE32(ip) >> bitCount; } else { - bitStream >>= 24; - bitCount += 24; + bitCount -= (int)(8 * (iend - 7 - ip)); + bitCount &= 31; + ip = iend - 4; } + bitStream = MEM_readLE32(ip) >> bitCount; repeats = FSE_ctz(~bitStream | 0x80000000) >> 1; } charnum += 3 * repeats; @@ -124,6 +125,7 @@ size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigne bitCount += 2 * repeats; /* Add the final repeat which isn't 0b11. */ + assert((bitStream & 3) < 3); charnum += bitStream & 3; bitCount += 2; @@ -137,14 +139,16 @@ size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigne * because we already memset the whole buffer to 0. 
*/ - if ((ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) { + if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) { assert((bitCount >> 3) <= 3); /* For first condition to work */ ip += bitCount>>3; bitCount &= 7; - bitStream = MEM_readLE32(ip) >> bitCount; } else { - bitStream >>= 2; + bitCount -= (int)(8 * (iend - 4 - ip)); + bitCount &= 31; + ip = iend - 4; } + bitStream = MEM_readLE32(ip) >> bitCount; } { int const max = (2*threshold-1) - remaining; @@ -184,14 +188,15 @@ size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigne } if (charnum >= maxSV1) break; - if (LIKELY((ip <= iend-7) || (ip + (bitCount>>3) <= iend-4))) { + if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) { ip += bitCount>>3; bitCount &= 7; } else { bitCount -= (int)(8 * (iend - 4 - ip)); + bitCount &= 31; ip = iend - 4; } - bitStream = MEM_readLE32(ip) >> (bitCount & 31); + bitStream = MEM_readLE32(ip) >> bitCount; } } if (remaining != 1) return ERROR(corruption_detected); /* Only possible when there are too many zeros. 
*/ diff --git a/tests/fuzz/Makefile b/tests/fuzz/Makefile index 42988c34..d88fae9c 100644 --- a/tests/fuzz/Makefile +++ b/tests/fuzz/Makefile @@ -96,7 +96,8 @@ FUZZ_TARGETS := \ dictionary_loader \ raw_dictionary_round_trip \ dictionary_stream_round_trip \ - decompress_dstSize_tooSmall + decompress_dstSize_tooSmall \ + fse_read_ncount all: $(FUZZ_TARGETS) @@ -184,6 +185,9 @@ dictionary_loader: $(FUZZ_HEADERS) $(FUZZ_ROUND_TRIP_OBJ) rt_fuzz_dictionary_loa decompress_dstSize_tooSmall: $(FUZZ_HEADERS) $(FUZZ_DECOMPRESS_OBJ) d_fuzz_decompress_dstSize_tooSmall.o $(CXX) $(FUZZ_TARGET_FLAGS) $(FUZZ_DECOMPRESS_OBJ) d_fuzz_decompress_dstSize_tooSmall.o $(LIB_FUZZING_ENGINE) -o $@ +fse_read_ncount: $(FUZZ_HEADERS) $(FUZZ_ROUND_TRIP_OBJ) rt_fuzz_fse_read_ncount.o + $(CXX) $(FUZZ_TARGET_FLAGS) $(FUZZ_ROUND_TRIP_OBJ) rt_fuzz_fse_read_ncount.o $(LIB_FUZZING_ENGINE) -o $@ + libregression.a: $(FUZZ_HEADERS) $(PRGDIR)/util.h $(PRGDIR)/util.c d_fuzz_regression_driver.o $(AR) $(FUZZ_ARFLAGS) $@ d_fuzz_regression_driver.o diff --git a/tests/fuzz/fse_read_ncount.c b/tests/fuzz/fse_read_ncount.c new file mode 100644 index 00000000..e20a9382 --- /dev/null +++ b/tests/fuzz/fse_read_ncount.c @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2016-2020, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the + * LICENSE file in the root directory of this source tree) and the GPLv2 (found + * in the COPYING file in the root directory of this source tree). + * You may select, at your option, one of the above-listed licenses. + */ + +/** + * This fuzz target round trips the FSE normalized count with FSE_writeNCount() + * and FSE_readNCount() to ensure that it can always round trip correctly. 
+ */ + +#define FSE_STATIC_LINKING_ONLY +#define ZSTD_STATIC_LINKING_ONLY + +#include <stddef.h> +#include <stdlib.h> +#include <stdio.h> +#include <string.h> +#include "fuzz_helpers.h" +#include "zstd_helpers.h" +#include "fuzz_data_producer.h" +#include "fse.h" + +int LLVMFuzzerTestOneInput(const uint8_t *src, size_t size) +{ + FUZZ_dataProducer_t *producer = FUZZ_dataProducer_create(src, size); + + /* Pick a random tableLog and maxSymbolValue */ + unsigned const tableLog = FUZZ_dataProducer_uint32Range(producer, FSE_MIN_TABLELOG, FSE_MAX_TABLELOG); + unsigned const maxSymbolValue = FUZZ_dataProducer_uint32Range(producer, 0, 255); + + unsigned remainingWeight = (1u << tableLog) - 1; + size_t dataSize; + BYTE data[512]; + short ncount[256]; + + /* Randomly fill the normalized count */ + memset(ncount, 0, sizeof(ncount)); + { + unsigned s; + for (s = 0; s < maxSymbolValue && remainingWeight > 0; ++s) { + short n = (short)FUZZ_dataProducer_int32Range(producer, -1, remainingWeight); + ncount[s] = n; + if (n < 0) { + remainingWeight -= 1; + } else { + assert((unsigned)n <= remainingWeight); + remainingWeight -= n; + } + } + /* Ensure ncount[maxSymbolValue] != 0 and the sum is (1<<tableLog) */ + ncount[maxSymbolValue] = remainingWeight + 1; + if (ncount[maxSymbolValue] == 1 && FUZZ_dataProducer_uint32Range(producer, 0, 1) == 1) { + ncount[maxSymbolValue] = -1; + } + } + /* Write the normalized count */ + { + FUZZ_ASSERT(sizeof(data) >= FSE_NCountWriteBound(maxSymbolValue, tableLog)); + dataSize = FSE_writeNCount(data, sizeof(data), ncount, maxSymbolValue, tableLog); + FUZZ_ZASSERT(dataSize); + } + /* Read & validate the normalized count */ + { + short rtNcount[256]; + unsigned rtMaxSymbolValue = 255; + unsigned rtTableLog; + /* Copy into a buffer with a random amount of random data at the end */ + size_t const buffSize = (size_t)FUZZ_dataProducer_uint32Range(producer, dataSize, sizeof(data)); + BYTE* const buff = FUZZ_malloc(buffSize); + size_t rtDataSize; + memcpy(buff, data, dataSize); + { + size_t b; + for (b = dataSize; b < buffSize; ++b) { + buff[b] = (BYTE)FUZZ_dataProducer_uint32Range(producer, 0, 255); + } + } + + rtDataSize = FSE_readNCount(rtNcount, &rtMaxSymbolValue, &rtTableLog, buff, buffSize); + FUZZ_ZASSERT(rtDataSize); + FUZZ_ASSERT(rtDataSize == 
dataSize); + FUZZ_ASSERT(rtMaxSymbolValue == maxSymbolValue); + FUZZ_ASSERT(rtTableLog == tableLog); + { + unsigned s; + for (s = 0; s <= maxSymbolValue; ++s) { + FUZZ_ASSERT(ncount[s] == rtNcount[s]); + } + } + free(buff); + } + + FUZZ_dataProducer_free(producer); + return 0; +} diff --git a/tests/fuzz/fuzz.py b/tests/fuzz/fuzz.py index 6332eeb9..24430a22 100755 --- a/tests/fuzz/fuzz.py +++ b/tests/fuzz/fuzz.py @@ -60,6 +60,7 @@ TARGET_INFO = { 'raw_dictionary_round_trip': TargetInfo(InputType.RAW_DATA), 'dictionary_stream_round_trip': TargetInfo(InputType.RAW_DATA), 'decompress_dstSize_tooSmall': TargetInfo(InputType.RAW_DATA), + 'fse_read_ncount': TargetInfo(InputType.RAW_DATA), } TARGETS = list(TARGET_INFO.keys()) ALL_TARGETS = TARGETS + ['all']