ensures that sampleSizes table is large enough

as recommended by @terrelln
This commit is contained in:
Yann Collet 2017-09-15 15:31:31 -07:00
parent 25a60488dd
commit c68d17f2da

View File

@ -97,11 +97,16 @@ const char* DiB_getErrorName(size_t errorCode) { return ERR_getErrorName(errorCo
* File related operations * File related operations
**********************************************************/ **********************************************************/
/** DiB_loadFiles() : /** DiB_loadFiles() :
* load files listed in fileNamesTable into buffer, even if buffer is too small. * load samples from files listed in fileNamesTable into buffer.
* @return : nb of files effectively loaded into `buffer` * works even if buffer is too small to load all samples.
* *bufferSizePtr is modified, it provides the amount data loaded within buffer */ * Also provides the size of each sample into sampleSizes table
* which must be sized correctly, using DiB_fileStats().
* @return : nb of samples effectively loaded into `buffer`
* *bufferSizePtr is modified, it provides the amount data loaded within buffer.
* sampleSizes is filled with the size of each sample.
*/
static unsigned DiB_loadFiles(void* buffer, size_t* bufferSizePtr, static unsigned DiB_loadFiles(void* buffer, size_t* bufferSizePtr,
size_t* chunkSizes, size_t* sampleSizes, unsigned sstSize,
const char** fileNamesTable, unsigned nbFiles, size_t targetChunkSize, const char** fileNamesTable, unsigned nbFiles, size_t targetChunkSize,
unsigned displayLevel) unsigned displayLevel)
{ {
@ -126,8 +131,12 @@ static unsigned DiB_loadFiles(void* buffer, size_t* bufferSizePtr,
{ size_t const readSize = fread(buff+pos, 1, toLoad, f); { size_t const readSize = fread(buff+pos, 1, toLoad, f);
if (readSize != toLoad) EXM_THROW(11, "Pb reading %s", fileName); if (readSize != toLoad) EXM_THROW(11, "Pb reading %s", fileName);
pos += readSize; pos += readSize;
chunkSizes[nbLoadedChunks++] = toLoad; sampleSizes[nbLoadedChunks++] = toLoad;
remainingToLoad -= targetChunkSize; remainingToLoad -= targetChunkSize;
if (nbLoadedChunks == sstSize) { /* no more space left in sampleSizes table */
fileIndex = nbFiles; /* stop there */
break;
}
if (toLoad < targetChunkSize) { if (toLoad < targetChunkSize) {
fseek(f, (long)(targetChunkSize - toLoad), SEEK_CUR); fseek(f, (long)(targetChunkSize - toLoad), SEEK_CUR);
} } } } } }
@ -221,9 +230,14 @@ static void DiB_saveDict(const char* dictFileName,
typedef struct { typedef struct {
U64 totalSizeToLoad; U64 totalSizeToLoad;
unsigned oneSampleTooLarge; unsigned oneSampleTooLarge;
unsigned nbChunks; unsigned nbSamples;
} fileStats; } fileStats;
/*! DiB_fileStats() :
* Given a list of files, and a chunkSize (0 == no chunk, whole files)
* provides the amount of data to be loaded and the resulting nb of samples.
* This is useful primarily for allocation purpose => sample buffer, and sample sizes table.
*/
static fileStats DiB_fileStats(const char** fileNamesTable, unsigned nbFiles, size_t chunkSize, unsigned displayLevel) static fileStats DiB_fileStats(const char** fileNamesTable, unsigned nbFiles, size_t chunkSize, unsigned displayLevel)
{ {
fileStats fs; fileStats fs;
@ -231,12 +245,12 @@ static fileStats DiB_fileStats(const char** fileNamesTable, unsigned nbFiles, si
memset(&fs, 0, sizeof(fs)); memset(&fs, 0, sizeof(fs));
for (n=0; n<nbFiles; n++) { for (n=0; n<nbFiles; n++) {
U64 const fileSize = UTIL_getFileSize(fileNamesTable[n]); U64 const fileSize = UTIL_getFileSize(fileNamesTable[n]);
U32 const nbChunks = (U32)(chunkSize ? (fileSize + (chunkSize-1)) / chunkSize : 1); U32 const nbSamples = (U32)(chunkSize ? (fileSize + (chunkSize-1)) / chunkSize : 1);
U64 const chunkToLoad = chunkSize ? MIN(chunkSize, fileSize) : fileSize; U64 const chunkToLoad = chunkSize ? MIN(chunkSize, fileSize) : fileSize;
size_t const cappedChunkSize = (size_t)MIN(chunkToLoad, SAMPLESIZE_MAX); size_t const cappedChunkSize = (size_t)MIN(chunkToLoad, SAMPLESIZE_MAX);
fs.totalSizeToLoad += cappedChunkSize * nbChunks; fs.totalSizeToLoad += cappedChunkSize * nbSamples;
fs.oneSampleTooLarge |= (chunkSize > 2*SAMPLESIZE_MAX); fs.oneSampleTooLarge |= (chunkSize > 2*SAMPLESIZE_MAX);
fs.nbChunks += nbChunks; fs.nbSamples += nbSamples;
} }
DISPLAYLEVEL(4, "Preparing to load : %u KB \n", (U32)(fs.totalSizeToLoad >> 10)); DISPLAYLEVEL(4, "Preparing to load : %u KB \n", (U32)(fs.totalSizeToLoad >> 10));
return fs; return fs;
@ -260,12 +274,12 @@ int DiB_trainFromFiles(const char* dictFileName, unsigned maxDictSize,
ZDICT_legacy_params_t *params, ZDICT_cover_params_t *coverParams, ZDICT_legacy_params_t *params, ZDICT_cover_params_t *coverParams,
int optimizeCover) int optimizeCover)
{ {
unsigned displayLevel = params ? params->zParams.notificationLevel : unsigned const displayLevel = params ? params->zParams.notificationLevel :
coverParams ? coverParams->zParams.notificationLevel : coverParams ? coverParams->zParams.notificationLevel :
0; /* should never happen */ 0; /* should never happen */
void* const dictBuffer = malloc(maxDictSize); void* const dictBuffer = malloc(maxDictSize);
fileStats const fs = DiB_fileStats(fileNamesTable, nbFiles, chunkSize, displayLevel); fileStats const fs = DiB_fileStats(fileNamesTable, nbFiles, chunkSize, displayLevel);
size_t* const chunkSizes = (size_t*)malloc(fs.nbChunks * sizeof(size_t)); size_t* const sampleSizes = (size_t*)malloc(fs.nbSamples * sizeof(size_t));
size_t const memMult = params ? MEMMULT : COVER_MEMMULT; size_t const memMult = params ? MEMMULT : COVER_MEMMULT;
size_t const maxMem = DiB_findMaxMem(fs.totalSizeToLoad * memMult) / memMult; size_t const maxMem = DiB_findMaxMem(fs.totalSizeToLoad * memMult) / memMult;
size_t loadedSize = (size_t) MIN ((unsigned long long)maxMem, fs.totalSizeToLoad); size_t loadedSize = (size_t) MIN ((unsigned long long)maxMem, fs.totalSizeToLoad);
@ -273,14 +287,14 @@ int DiB_trainFromFiles(const char* dictFileName, unsigned maxDictSize,
int result = 0; int result = 0;
/* Checks */ /* Checks */
if ((!chunkSizes) || (!srcBuffer) || (!dictBuffer)) if ((!sampleSizes) || (!srcBuffer) || (!dictBuffer))
EXM_THROW(12, "not enough memory for DiB_trainFiles"); /* should not happen */ EXM_THROW(12, "not enough memory for DiB_trainFiles"); /* should not happen */
if (fs.oneSampleTooLarge) { if (fs.oneSampleTooLarge) {
DISPLAYLEVEL(2, "! Warning : some sample(s) are very large \n"); DISPLAYLEVEL(2, "! Warning : some sample(s) are very large \n");
DISPLAYLEVEL(2, "! Note that dictionary is only useful for small samples. \n"); DISPLAYLEVEL(2, "! Note that dictionary is only useful for small samples. \n");
DISPLAYLEVEL(2, "! As a consequence, only the first %u bytes of each sample are loaded \n", SAMPLESIZE_MAX); DISPLAYLEVEL(2, "! As a consequence, only the first %u bytes of each sample are loaded \n", SAMPLESIZE_MAX);
} }
if (fs.nbChunks < 5) { if (fs.nbSamples < 5) {
DISPLAYLEVEL(2, "! Warning : nb of samples too low for proper processing ! \n"); DISPLAYLEVEL(2, "! Warning : nb of samples too low for proper processing ! \n");
DISPLAYLEVEL(2, "! Please provide _one file per sample_. \n"); DISPLAYLEVEL(2, "! Please provide _one file per sample_. \n");
EXM_THROW(14, "nb of samples too low"); /* we now clearly forbid this case */ EXM_THROW(14, "nb of samples too low"); /* we now clearly forbid this case */
@ -297,24 +311,24 @@ int DiB_trainFromFiles(const char* dictFileName, unsigned maxDictSize,
/* Load input buffer */ /* Load input buffer */
DISPLAYLEVEL(3, "Shuffling input files\n"); DISPLAYLEVEL(3, "Shuffling input files\n");
DiB_shuffle(fileNamesTable, nbFiles); DiB_shuffle(fileNamesTable, nbFiles);
nbFiles = DiB_loadFiles(srcBuffer, &loadedSize, chunkSizes, fileNamesTable, nbFiles, chunkSize, displayLevel); nbFiles = DiB_loadFiles(srcBuffer, &loadedSize, sampleSizes, fs.nbSamples, fileNamesTable, nbFiles, chunkSize, displayLevel);
{ size_t dictSize; { size_t dictSize;
if (params) { if (params) {
DiB_fillNoise((char*)srcBuffer + loadedSize, NOISELENGTH); /* guard band, for end of buffer condition */ DiB_fillNoise((char*)srcBuffer + loadedSize, NOISELENGTH); /* guard band, for end of buffer condition */
dictSize = ZDICT_trainFromBuffer_unsafe_legacy(dictBuffer, maxDictSize, dictSize = ZDICT_trainFromBuffer_unsafe_legacy(dictBuffer, maxDictSize,
srcBuffer, chunkSizes, fs.nbChunks, srcBuffer, sampleSizes, fs.nbSamples,
*params); *params);
} else if (optimizeCover) { } else if (optimizeCover) {
dictSize = ZDICT_optimizeTrainFromBuffer_cover(dictBuffer, maxDictSize, dictSize = ZDICT_optimizeTrainFromBuffer_cover(dictBuffer, maxDictSize,
srcBuffer, chunkSizes, fs.nbChunks, srcBuffer, sampleSizes, fs.nbSamples,
coverParams); coverParams);
if (!ZDICT_isError(dictSize)) { if (!ZDICT_isError(dictSize)) {
DISPLAYLEVEL(2, "k=%u\nd=%u\nsteps=%u\n", coverParams->k, coverParams->d, coverParams->steps); DISPLAYLEVEL(2, "k=%u\nd=%u\nsteps=%u\n", coverParams->k, coverParams->d, coverParams->steps);
} }
} else { } else {
dictSize = ZDICT_trainFromBuffer_cover(dictBuffer, maxDictSize, srcBuffer, dictSize = ZDICT_trainFromBuffer_cover(dictBuffer, maxDictSize, srcBuffer,
chunkSizes, fs.nbChunks, *coverParams); sampleSizes, fs.nbSamples, *coverParams);
} }
if (ZDICT_isError(dictSize)) { if (ZDICT_isError(dictSize)) {
DISPLAYLEVEL(1, "dictionary training failed : %s \n", ZDICT_getErrorName(dictSize)); /* should not happen */ DISPLAYLEVEL(1, "dictionary training failed : %s \n", ZDICT_getErrorName(dictSize)); /* should not happen */
@ -329,7 +343,7 @@ int DiB_trainFromFiles(const char* dictFileName, unsigned maxDictSize,
/* clean up */ /* clean up */
_cleanup: _cleanup:
free(srcBuffer); free(srcBuffer);
free(chunkSizes); free(sampleSizes);
free(dictBuffer); free(dictBuffer);
return result; return result;
} }