[libzstd] Fix memcpy() on potential NULL source

* `ZSTD_decompressStream_generic()` `ip` may be `NULL` for one of the calls
  to `memcpy()`
* Assert the source is not `NULL` for calls to `memcpy()` where I believe
  the source should not be `NULL`.
dev
Nick Terrell 2017-07-03 12:31:55 -07:00
parent 1d39550471
commit c80fc50a8d
1 changed files with 12 additions and 3 deletions

View File

@ -1690,6 +1690,7 @@ ZSTD_nextInputType_e ZSTD_nextInputType(ZSTD_DCtx* dctx) {
switch(dctx->stage) switch(dctx->stage)
{ {
default: /* should not happen */ default: /* should not happen */
assert(0);
case ZSTDds_getFrameHeaderSize: case ZSTDds_getFrameHeaderSize:
case ZSTDds_decodeFrameHeader: case ZSTDds_decodeFrameHeader:
return ZSTDnit_frameHeader; return ZSTDnit_frameHeader;
@ -1724,6 +1725,7 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c
{ {
case ZSTDds_getFrameHeaderSize : case ZSTDds_getFrameHeaderSize :
if (srcSize != ZSTD_frameHeaderSize_prefix) return ERROR(srcSize_wrong); /* unauthorized */ if (srcSize != ZSTD_frameHeaderSize_prefix) return ERROR(srcSize_wrong); /* unauthorized */
assert(src != NULL);
if ((MEM_readLE32(src) & 0xFFFFFFF0U) == ZSTD_MAGIC_SKIPPABLE_START) { /* skippable frame */ if ((MEM_readLE32(src) & 0xFFFFFFF0U) == ZSTD_MAGIC_SKIPPABLE_START) { /* skippable frame */
memcpy(dctx->headerBuffer, src, ZSTD_frameHeaderSize_prefix); memcpy(dctx->headerBuffer, src, ZSTD_frameHeaderSize_prefix);
dctx->expected = ZSTD_skippableHeaderSize - ZSTD_frameHeaderSize_prefix; /* magic number + skippable frame length */ dctx->expected = ZSTD_skippableHeaderSize - ZSTD_frameHeaderSize_prefix; /* magic number + skippable frame length */
@ -1741,6 +1743,7 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c
dctx->expected = 0; /* not necessary to copy more */ dctx->expected = 0; /* not necessary to copy more */
case ZSTDds_decodeFrameHeader: case ZSTDds_decodeFrameHeader:
assert(src != NULL);
memcpy(dctx->headerBuffer + ZSTD_frameHeaderSize_prefix, src, dctx->expected); memcpy(dctx->headerBuffer + ZSTD_frameHeaderSize_prefix, src, dctx->expected);
CHECK_F(ZSTD_decodeFrameHeader(dctx, dctx->headerBuffer, dctx->headerSize)); CHECK_F(ZSTD_decodeFrameHeader(dctx, dctx->headerBuffer, dctx->headerSize));
dctx->expected = ZSTD_blockHeaderSize; dctx->expected = ZSTD_blockHeaderSize;
@ -1820,7 +1823,8 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c
return 0; return 0;
} }
case ZSTDds_decodeSkippableHeader: case ZSTDds_decodeSkippableHeader:
{ memcpy(dctx->headerBuffer + ZSTD_frameHeaderSize_prefix, src, dctx->expected); { assert(src != NULL);
memcpy(dctx->headerBuffer + ZSTD_frameHeaderSize_prefix, src, dctx->expected);
dctx->expected = MEM_readLE32(dctx->headerBuffer + 4); dctx->expected = MEM_readLE32(dctx->headerBuffer + 4);
dctx->stage = ZSTDds_skipFrame; dctx->stage = ZSTDds_skipFrame;
return 0; return 0;
@ -2063,6 +2067,8 @@ ZSTD_DDict* ZSTD_initStaticDDict(void* workspace, size_t workspaceSize,
{ {
size_t const neededSpace = sizeof(ZSTD_DDict) + (byReference ? 0 : dictSize); size_t const neededSpace = sizeof(ZSTD_DDict) + (byReference ? 0 : dictSize);
ZSTD_DDict* const ddict = (ZSTD_DDict*)workspace; ZSTD_DDict* const ddict = (ZSTD_DDict*)workspace;
assert(workspace != NULL);
assert(dict != NULL);
if ((size_t)workspace & 7) return NULL; /* 8-aligned */ if ((size_t)workspace & 7) return NULL; /* 8-aligned */
if (workspaceSize < neededSpace) return NULL; if (workspaceSize < neededSpace) return NULL;
if (!byReference) { if (!byReference) {
@ -2321,11 +2327,14 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
if (hSize != 0) { /* need more input */ if (hSize != 0) { /* need more input */
size_t const toLoad = hSize - zds->lhSize; /* if hSize!=0, hSize > zds->lhSize */ size_t const toLoad = hSize - zds->lhSize; /* if hSize!=0, hSize > zds->lhSize */
if (toLoad > (size_t)(iend-ip)) { /* not enough input to load full header */ if (toLoad > (size_t)(iend-ip)) { /* not enough input to load full header */
if (iend-ip > 0) {
memcpy(zds->headerBuffer + zds->lhSize, ip, iend-ip); memcpy(zds->headerBuffer + zds->lhSize, ip, iend-ip);
zds->lhSize += iend-ip; zds->lhSize += iend-ip;
}
input->pos = input->size; input->pos = input->size;
return (MAX(ZSTD_frameHeaderSize_min, hSize) - zds->lhSize) + ZSTD_blockHeaderSize; /* remaining header bytes + next block header */ return (MAX(ZSTD_frameHeaderSize_min, hSize) - zds->lhSize) + ZSTD_blockHeaderSize; /* remaining header bytes + next block header */
} }
assert(ip != NULL);
memcpy(zds->headerBuffer + zds->lhSize, ip, toLoad); zds->lhSize = hSize; ip += toLoad; memcpy(zds->headerBuffer + zds->lhSize, ip, toLoad); zds->lhSize = hSize; ip += toLoad;
break; break;
} } } }