diff --git a/lib/dictBuilder/cover.c b/lib/dictBuilder/cover.c index 5fd2c9c7..176c386c 100644 --- a/lib/dictBuilder/cover.c +++ b/lib/dictBuilder/cover.c @@ -223,10 +223,10 @@ static COVER_ctx_t *g_ctx = NULL; /** * Returns the sum of the sample sizes. */ -static size_t COVER_sum(const size_t *samplesSizes, unsigned firstSample, unsigned lastSample) { +static size_t COVER_sum(const size_t *samplesSizes, unsigned nbSamples) { size_t sum = 0; unsigned i; - for (i = firstSample; i < lastSample; ++i) { + for (i = 0; i < nbSamples; ++i) { sum += samplesSizes[i]; } return sum; @@ -540,13 +540,12 @@ static int COVER_ctx_init(COVER_ctx_t *ctx, const void *samplesBuffer, const size_t *samplesSizes, unsigned nbSamples, unsigned d, double splitPoint) { const BYTE *const samples = (const BYTE *)samplesBuffer; - const unsigned kFirst = 0; - const size_t totalSamplesSize = COVER_sum(samplesSizes, kFirst, nbSamples); + const size_t totalSamplesSize = COVER_sum(samplesSizes, nbSamples); /* Split samples into testing and training sets */ const unsigned nbTrainSamples = splitPoint < 1.0 ? (unsigned)((double)nbSamples * splitPoint) : nbSamples; const unsigned nbTestSamples = splitPoint < 1.0 ? nbSamples - nbTrainSamples : nbSamples; - const size_t trainingSamplesSize = splitPoint < 1.0 ? COVER_sum(samplesSizes, kFirst, nbTrainSamples) : totalSamplesSize; - const size_t testSamplesSize = splitPoint < 1.0 ? COVER_sum(samplesSizes, nbTrainSamples, nbSamples) : totalSamplesSize; + const size_t trainingSamplesSize = splitPoint < 1.0 ? COVER_sum(samplesSizes, nbTrainSamples) : totalSamplesSize; + const size_t testSamplesSize = splitPoint < 1.0 ? COVER_sum(samplesSizes + nbTrainSamples, nbTestSamples) : totalSamplesSize; /* Checks */ if (totalSamplesSize < MAX(d, sizeof(U64)) || totalSamplesSize >= (size_t)COVER_MAX_SAMPLES_SIZE) { @@ -564,11 +563,6 @@ static int COVER_ctx_init(COVER_ctx_t *ctx, const void *samplesBuffer, DISPLAYLEVEL(1, "Total number of testing samples is %u and is invalid.", nbTestSamples); return 0; } - /* Check if nbTrainSamples plus nbTestSamples add up to nbSamples when splitPoint is less than 1*/ - if (nbTrainSamples + nbTestSamples != nbSamples && splitPoint < 1.0) { - DISPLAYLEVEL(1, "nbTrainSamples plus nbTestSamples don't add up to nbSamples"); - return 0; - } /* Zero the context */ memset(ctx, 0, sizeof(*ctx)); DISPLAYLEVEL(2, "Training on %u samples of total size %u\n", nbTrainSamples, @@ -976,7 +970,7 @@ ZDICTLIB_API size_t ZDICT_optimizeTrainFromBuffer_cover( /* constants */ const unsigned nbThreads = parameters->nbThreads; const double splitPoint = - (parameters->splitPoint <= 0.0 || parameters->splitPoint > 1.0) ? DEFAULT_SPLITPOINT : parameters->splitPoint; + parameters->splitPoint <= 0.0 ? DEFAULT_SPLITPOINT : parameters->splitPoint; const unsigned kMinD = parameters->d == 0 ? 6 : parameters->d; const unsigned kMaxD = parameters->d == 0 ? 8 : parameters->d; const unsigned kMinK = parameters->k == 0 ? 50 : parameters->k; diff --git a/lib/dictBuilder/zdict.h b/lib/dictBuilder/zdict.h index 45d78b05..8244c3ba 100644 --- a/lib/dictBuilder/zdict.h +++ b/lib/dictBuilder/zdict.h @@ -86,7 +86,7 @@ typedef struct { unsigned d; /* dmer size : constraint: 0 < d <= k : Reasonable range [6, 16] */ unsigned steps; /* Number of steps : Only used for optimization : 0 means default (32) : Higher means more parameters checked */ unsigned nbThreads; /* Number of threads : constraint: 0 < nbThreads : 1 means single-threaded : Only used for optimization : Ignored if ZSTD_MULTITHREAD is not defined */ - double splitPoint; /* Percentage of samples used for training: the first nbSamples * splitPoint samples will be used to training */ + double splitPoint; /* Percentage of samples used for training: the first nbSamples * splitPoint samples will be used to training, 0 means default (0.8) */ ZDICT_params_t zParams; } ZDICT_cover_params_t;