[rsyncable] Ensure ZSTD_compressBound() is respected

In degenerate cases `--rsyncable` could create very small blocks (1
byte). This causes the compressed output to be larger than
`ZSTD_compressBound()`. Fix the issue by ensuring that rsyncable mode
never outputs blocks smaller than 128 KB.

The minimum job size is 512 KB, so we shouldn't lose many
synchronization points from skipping any that cause blocks smaller than
128 KB. And even if we do, that is fine, because we'll find the next
one.

This fixes the `raw_dictionary_round_trip` oss-fuzz assert.

Credit to OSS-Fuzz
dev
Nick Terrell 2021-09-13 16:59:20 -07:00
parent c10067c44e
commit a418b4e478
1 changed files with 30 additions and 2 deletions

View File

@ -807,6 +807,15 @@ typedef struct {
static const roundBuff_t kNullRoundBuff = {NULL, 0, 0};
#define RSYNC_LENGTH 32
/* Don't create chunks smaller than the zstd block size.
* This stops us from regressing compression ratio too much,
* and ensures our output fits in ZSTD_compressBound().
*
* If this is shrunk < ZSTD_BLOCKSIZELOG_MIN then
* ZSTD_COMPRESSBOUND() will need to be updated.
*/
#define RSYNC_MIN_BLOCK_LOG ZSTD_BLOCKSIZELOG_MAX
#define RSYNC_MIN_BLOCK_SIZE (1<<RSYNC_MIN_BLOCK_LOG)
typedef struct {
U64 hash;
@ -1252,6 +1261,9 @@ size_t ZSTDMT_initCStream_internal(
/* Aim for the targetsectionSize as the average job size. */
U32 const jobSizeKB = (U32)(mtctx->targetSectionSize >> 10);
U32 const rsyncBits = (assert(jobSizeKB >= 1), ZSTD_highbit32(jobSizeKB) + 10);
/* We refuse to create jobs < RSYNC_MIN_BLOCK_SIZE bytes, so make sure our
* expected job size is at least 4x larger. */
assert(rsyncBits >= RSYNC_MIN_BLOCK_LOG + 2);
DEBUGLOG(4, "rsyncLog = %u", rsyncBits);
mtctx->rsync.hash = 0;
mtctx->rsync.hitMask = (1ULL << rsyncBits) - 1;
@ -1678,6 +1690,11 @@ findSynchronizationPoint(ZSTDMT_CCtx const* mtctx, ZSTD_inBuffer const input)
if (!mtctx->params.rsyncable)
/* Rsync is disabled. */
return syncPoint;
if (mtctx->inBuff.filled + input.size - input.pos < RSYNC_MIN_BLOCK_SIZE)
/* We don't emit synchronization points if it would produce too small blocks.
* We don't have enough input to find a synchronization point, so don't look.
*/
return syncPoint;
if (mtctx->inBuff.filled + syncPoint.toLoad < RSYNC_LENGTH)
/* Not enough to compute the hash.
* We will miss any synchronization points in this RSYNC_LENGTH byte
@ -1688,14 +1705,24 @@ findSynchronizationPoint(ZSTDMT_CCtx const* mtctx, ZSTD_inBuffer const input)
*/
return syncPoint;
/* Initialize the loop variables. */
if (mtctx->inBuff.filled >= RSYNC_LENGTH) {
if (mtctx->inBuff.filled < RSYNC_MIN_BLOCK_SIZE - RSYNC_LENGTH) {
/* We don't need to scan the first RSYNC_MIN_BLOCK_SIZE positions
* because they can't possibly be a sync point. So we can start
* part way through the input buffer.
*/
pos = RSYNC_MIN_BLOCK_SIZE - mtctx->inBuff.filled;
assert(pos <= input.size - input.pos /* validated earlier */);
assert(pos >= RSYNC_LENGTH);
prev = istart + pos - RSYNC_LENGTH;
hash = ZSTD_rollingHash_compute(prev, RSYNC_LENGTH);
} else if (mtctx->inBuff.filled >= RSYNC_LENGTH) {
/* We have enough bytes buffered to initialize the hash.
* Start scanning at the beginning of the input.
*/
pos = 0;
prev = (BYTE const*)mtctx->inBuff.buffer.start + mtctx->inBuff.filled - RSYNC_LENGTH;
hash = ZSTD_rollingHash_compute(prev, RSYNC_LENGTH);
if ((hash & hitMask) == hitMask) {
if ((hash & hitMask) == hitMask && mtctx->inBuff.filled >= RSYNC_MIN_BLOCK_SIZE) {
/* We're already at a sync point so don't load any more until
* we're able to flush this sync point.
* This likely happened because the job table was full so we
@ -1728,6 +1755,7 @@ findSynchronizationPoint(ZSTDMT_CCtx const* mtctx, ZSTD_inBuffer const input)
BYTE const toRemove = pos < RSYNC_LENGTH ? prev[pos] : istart[pos - RSYNC_LENGTH];
/* if (pos >= RSYNC_LENGTH) assert(ZSTD_rollingHash_compute(istart + pos - RSYNC_LENGTH, RSYNC_LENGTH) == hash); */
hash = ZSTD_rollingHash_rotate(hash, toRemove, istart[pos], primePower);
assert(mtctx->inBuff.filled + pos >= RSYNC_MIN_BLOCK_SIZE);
if ((hash & hitMask) == hitMask) {
syncPoint.toLoad = pos + 1;
syncPoint.flush = 1;