diff --git a/lib/zstd.c b/lib/zstd.c index 37c67144..b70e9684 100644 --- a/lib/zstd.c +++ b/lib/zstd.c @@ -1031,7 +1031,7 @@ struct ZSTD_DCtx_s const BYTE* litPtr; size_t litBufSize; size_t litSize; - BYTE litBuffer[BLOCKSIZE]; + BYTE litBuffer[BLOCKSIZE + 8 /* margin for wildcopy */]; }; /* typedef'd to ZSTD_Dctx within "zstd_static.h" */ @@ -1098,17 +1098,25 @@ size_t ZSTD_decodeLiteralsBlock(void* ctx, default: case 0: { - size_t nbLiterals = BLOCKSIZE; - const size_t readSize = ZSTD_decompressLiterals(dctx->litBuffer, &nbLiterals, src, srcSize); + size_t litSize = BLOCKSIZE; + const size_t readSize = ZSTD_decompressLiterals(dctx->litBuffer, &litSize, src, srcSize); dctx->litPtr = dctx->litBuffer; dctx->litBufSize = BLOCKSIZE; - dctx->litSize = nbLiterals; + dctx->litSize = litSize; return readSize; /* works if it's an error too */ } case IS_RAW: { const size_t litSize = (MEM_readLE32(istart) & 0xFFFFFF) >> 2; /* no buffer issue : srcSize >= MIN_CBLOCK_SIZE */ - if (litSize > srcSize-3) return ERROR(corruption_detected); + if (litSize > srcSize-11) /* risk of reading too far with wildcopy */ + { + if (litSize > srcSize-3) return ERROR(corruption_detected); + memcpy(dctx->litBuffer, istart, litSize); + dctx->litBufSize = BLOCKSIZE; + dctx->litSize = litSize; + return litSize+3; + } + /* direct reference into compressed stream */ dctx->litPtr = istart+3; dctx->litBufSize = srcSize-3; dctx->litSize = litSize; @@ -1328,10 +1336,10 @@ static size_t ZSTD_execSequence(BYTE* op, BYTE* const oend_8 = oend-8; const BYTE* const litEnd = *litPtr + sequence.litLength; - /* check */ + /* checks */ if (oLitEnd > oend_8) return ERROR(dstSize_tooSmall); /* last match must start at a minimum distance of 8 from oend */ if (oMatchEnd > oend) return ERROR(dstSize_tooSmall); /* overwrite beyond dst buffer */ - if (litEnd > litLimit) return ERROR(corruption_detected); /* overRead beyond lit buffer */ + if (litEnd > litLimit-8) return ERROR(corruption_detected); /* overRead beyond lit buffer */ /* copy Literals */ ZSTD_wildcopy(op, *litPtr, sequence.litLength); /* note : oLitEnd <= oend-8 : no risk of overwrite beyond oend */