Refactor prefetching for the decoding loop

Following #2545,
I noticed that one field in `seq_t` is optional,
and only used in combination with prefetching.
(This may have contributed to static analyzer failure to detect correct initialization).

I then wondered if it would be possible to rewrite the code
so that this optional part is handled directly by the prefetching code
rather than delegated as an option into `ZSTD_decodeSequence()`.

This resulted into this refactoring exercise
where the prefetching responsibility is better isolated into its own function
and `ZSTD_decodeSequence()` is streamlined to contain strictly Sequence decoding operations.
Incidently, due to better code locality,
it reduces the need to send information around,
leading to simplified interface, and smaller state structures.
dev
Yann Collet 2021-03-18 11:28:22 -07:00
parent eace4abc25
commit f5434663ea
1 changed files with 28 additions and 22 deletions

View File

@ -658,7 +658,6 @@ typedef struct {
size_t litLength;
size_t matchLength;
size_t offset;
const BYTE* match;
} seq_t;
typedef struct {
@ -672,9 +671,6 @@ typedef struct {
ZSTD_fseState stateOffb;
ZSTD_fseState stateML;
size_t prevOffset[ZSTD_REP_NUM];
const BYTE* prefixStart;
const BYTE* dictEnd;
size_t pos;
} seqState_t;
/*! ZSTD_overlapCopy8() :
@ -936,10 +932,9 @@ ZSTD_updateFseStateWithDInfo(ZSTD_fseState* DStatePtr, BIT_DStream_t* bitD, ZSTD
: 0)
typedef enum { ZSTD_lo_isRegularOffset, ZSTD_lo_isLongOffset=1 } ZSTD_longOffset_e;
typedef enum { ZSTD_p_noPrefetch=0, ZSTD_p_prefetch=1 } ZSTD_prefetch_e;
FORCE_INLINE_TEMPLATE seq_t
ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets, const ZSTD_prefetch_e prefetch)
ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets)
{
seq_t seq;
ZSTD_seqSymbol const llDInfo = seqState->stateLL.table[seqState->stateLL.state];
@ -1014,14 +1009,6 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets, c
DEBUGLOG(6, "seq: litL=%u, matchL=%u, offset=%u",
(U32)seq.litLength, (U32)seq.matchLength, (U32)seq.offset);
if (prefetch == ZSTD_p_prefetch) {
size_t const pos = seqState->pos + seq.litLength;
const BYTE* const matchBase = (seq.offset > pos) ? seqState->dictEnd : seqState->prefixStart;
seq.match = matchBase + pos - seq.offset; /* note : this operation can overflow when seq.offset is really too large, which can only happen when input is corrupted.
* No consequence though : no memory access will occur, offset is only used for prefetching */
seqState->pos = pos + seq.matchLength;
}
/* ANS state update
* gcc-9.0.0 does 2.5% worse with ZSTD_updateFseStateWithDInfo().
* clang-9.2.0 does 7% worse with ZSTD_updateFseState().
@ -1180,7 +1167,7 @@ ZSTD_decompressSequences_body( ZSTD_DCtx* dctx,
__asm__(".p2align 4");
#endif
for ( ; ; ) {
seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset, ZSTD_p_noPrefetch);
seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset);
size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litEnd, prefixStart, vBase, dictEnd);
#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE)
assert(!ZSTD_isError(oneSeqSize));
@ -1232,6 +1219,24 @@ ZSTD_decompressSequences_default(ZSTD_DCtx* dctx,
#endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG */
#ifndef ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT
FORCE_INLINE_TEMPLATE size_t
ZSTD_prefetchMatch(size_t prefixPos, seq_t const sequence,
const BYTE* const prefixStart, const BYTE* const dictEnd)
{
prefixPos += sequence.litLength;
{ const BYTE* const matchBase = (sequence.offset > prefixPos) ? dictEnd : prefixStart;
const BYTE* const match = matchBase + prefixPos - sequence.offset; /* note : this operation can overflow when seq.offset is really too large, which can only happen when input is corrupted.
* No consequence though : no memory access will occur, offset is only used for prefetching */
PREFETCH_L1(match); PREFETCH_L1(match + sequence.matchLength - 1); /* note : it's safe to invoke PREFETCH() on any memory address, including invalid ones */
}
return prefixPos + sequence.matchLength;
}
/* This decoding function employs prefetching
* to reduce latency impact of cache misses.
* It's generally employed when block contains a significant portion of long-distance matches
* or when coupled with a "cold" dictionary */
FORCE_INLINE_TEMPLATE size_t
ZSTD_decompressSequencesLong_body(
ZSTD_DCtx* dctx,
@ -1261,11 +1266,10 @@ ZSTD_decompressSequencesLong_body(
int const seqAdvance = MIN(nbSeq, ADVANCED_SEQS);
seqState_t seqState;
int seqNb;
size_t prefixPos = (size_t)(op-prefixStart); /* track position relative to prefixStart */
dctx->fseEntropy = 1;
{ int i; for (i=0; i<ZSTD_REP_NUM; i++) seqState.prevOffset[i] = dctx->entropy.rep[i]; }
seqState.prefixStart = prefixStart;
seqState.pos = (size_t)(op-prefixStart);
seqState.dictEnd = dictEnd;
assert(dst != NULL);
assert(iend >= ip);
RETURN_ERROR_IF(
@ -1277,21 +1281,23 @@ ZSTD_decompressSequencesLong_body(
/* prepare in advance */
for (seqNb=0; (BIT_reloadDStream(&seqState.DStream) <= BIT_DStream_completed) && (seqNb<seqAdvance); seqNb++) {
sequences[seqNb] = ZSTD_decodeSequence(&seqState, isLongOffset, ZSTD_p_prefetch);
PREFETCH_L1(sequences[seqNb].match); PREFETCH_L1(sequences[seqNb].match + sequences[seqNb].matchLength - 1); /* note : it's safe to invoke PREFETCH() on any memory address, including invalid ones */
seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset);
prefixPos = ZSTD_prefetchMatch(prefixPos, sequence, prefixStart, dictEnd);
sequences[seqNb] = sequence;
}
RETURN_ERROR_IF(seqNb<seqAdvance, corruption_detected, "");
/* decode and decompress */
for ( ; (BIT_reloadDStream(&(seqState.DStream)) <= BIT_DStream_completed) && (seqNb<nbSeq) ; seqNb++) {
seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset, ZSTD_p_prefetch);
seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset);
size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequences[(seqNb-ADVANCED_SEQS) & STORED_SEQS_MASK], &litPtr, litEnd, prefixStart, dictStart, dictEnd);
#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE)
assert(!ZSTD_isError(oneSeqSize));
if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequences[(seqNb-ADVANCED_SEQS) & STORED_SEQS_MASK], prefixStart, dictStart);
#endif
if (ZSTD_isError(oneSeqSize)) return oneSeqSize;
PREFETCH_L1(sequence.match); PREFETCH_L1(sequence.match + sequence.matchLength - 1); /* note : it's safe to invoke PREFETCH() on any memory address, including invalid ones */
prefixPos = ZSTD_prefetchMatch(prefixPos, sequence, prefixStart, dictEnd);
sequences[seqNb & STORED_SEQS_MASK] = sequence;
op += oneSeqSize;
}