Merge pull request #3139 from danlark1/dev
[lazy] Optimize ZSTD_row_getMatchMask for levels 8-10 for ARM
commit 1c8a6974c7
@@ -954,6 +954,30 @@ void ZSTD_row_update(ZSTD_matchState_t* const ms, const BYTE* ip) {
     ZSTD_row_update_internal(ms, ip, mls, rowLog, rowMask, 0 /* don't use cache */);
 }

+/* Returns the mask width of bits group of which will be set to 1. Given not all
+ * architectures have easy movemask instruction, this helps to iterate over
+ * groups of bits easier and faster.
+ */
+FORCE_INLINE_TEMPLATE U32
+ZSTD_row_matchMaskGroupWidth(const U32 rowEntries)
+{
+    assert((rowEntries == 16) || (rowEntries == 32) || rowEntries == 64);
+    assert(rowEntries <= ZSTD_ROW_HASH_MAX_ENTRIES);
+    (void)rowEntries;
+#if defined(ZSTD_ARCH_ARM_NEON)
+    if (rowEntries == 16) {
+        return 4;
+    }
+    if (rowEntries == 32) {
+        return 2;
+    }
+    if (rowEntries == 64) {
+        return 1;
+    }
+#endif
+    return 1;
+}
+
 #if defined(ZSTD_ARCH_X86_SSE2)
 FORCE_INLINE_TEMPLATE ZSTD_VecMask
 ZSTD_row_getSSEMask(int nbChunks, const BYTE* const src, const BYTE tag, const U32 head)
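The group width is the number of mask bits each row entry occupies in the value returned by ZSTD_row_getMatchMask: 4 bits per entry for 16-entry rows on NEON, 2 bits for 32-entry rows, and 1 bit everywhere else, so the grouped mask always fits in the 64 bits of a ZSTD_VecMask (4*16 = 2*32 = 1*64 = 64). On targets without NEON the function always reports 1 and the mask degenerates to the usual one-bit-per-entry movemask. The sketch below is illustrative only: it re-derives the widths locally instead of calling zstd internals, and just checks the capacity invariant that the patch also asserts inside ZSTD_row_getMatchMask.

```c
/* Illustrative check (not zstd code): the group width is chosen so that
 * groupWidth * rowEntries never exceeds the 64 bits of a ZSTD_VecMask. */
#include <assert.h>
#include <stdio.h>

/* Local stand-in for what ZSTD_row_matchMaskGroupWidth returns on a NEON build. */
static unsigned groupWidthNEON(unsigned rowEntries) {
    return (rowEntries == 16) ? 4 : (rowEntries == 32) ? 2 : 1;
}

int main(void) {
    const unsigned rows[3] = { 16, 32, 64 };
    int i;
    for (i = 0; i < 3; ++i) {
        unsigned const w = groupWidthNEON(rows[i]);
        assert(w * rows[i] <= 64);  /* mirrors the new assert added to ZSTD_row_getMatchMask */
        printf("rowEntries=%2u -> groupWidth=%u (%2u mask bits used)\n", rows[i], w, w * rows[i]);
    }
    return 0;
}
```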
@@ -974,67 +998,78 @@ ZSTD_row_getSSEMask(int nbChunks, const BYTE* const src, const BYTE tag, const U32 head)
 }
 #endif

-/* Returns a ZSTD_VecMask (U32) that has the nth bit set to 1 if the newly-computed "tag" matches
- * the hash at the nth position in a row of the tagTable.
- * Each row is a circular buffer beginning at the value of "head". So we must rotate the "matches" bitfield
- * to match up with the actual layout of the entries within the hashTable */
+#if defined(ZSTD_ARCH_ARM_NEON)
 FORCE_INLINE_TEMPLATE ZSTD_VecMask
-ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head, const U32 rowEntries)
+ZSTD_row_getNEONMask(const U32 rowEntries, const BYTE* const src, const BYTE tag, const U32 headGrouped)
 {
+    assert((rowEntries == 16) || (rowEntries == 32) || rowEntries == 64);
+    if (rowEntries == 16) {
+        /* vshrn_n_u16 shifts by 4 every u16 and narrows to 8 lower bits.
+         * After that groups of 4 bits represent the equalMask. We lower
+         * all bits except the highest in these groups by doing AND with
+         * 0x88 = 0b10001000.
+         */
+        const uint8x16_t chunk = vld1q_u8(src);
+        const uint16x8_t equalMask = vreinterpretq_u16_u8(vceqq_u8(chunk, vdupq_n_u8(tag)));
+        const uint8x8_t res = vshrn_n_u16(equalMask, 4);
+        const U64 matches = vget_lane_u64(vreinterpret_u64_u8(res), 0);
+        return ZSTD_rotateRight_U64(matches, headGrouped) & 0x8888888888888888ull;
+    } else if (rowEntries == 32) {
+        /* Same idea as with rowEntries == 16 but doing AND with
+         * 0x55 = 0b01010101.
+         */
+        const uint16x8x2_t chunk = vld2q_u16((const uint16_t*)(const void*)src);
+        const uint8x16_t chunk0 = vreinterpretq_u8_u16(chunk.val[0]);
+        const uint8x16_t chunk1 = vreinterpretq_u8_u16(chunk.val[1]);
+        const uint8x16_t dup = vdupq_n_u8(tag);
+        const uint8x8_t t0 = vshrn_n_u16(vreinterpretq_u16_u8(vceqq_u8(chunk0, dup)), 6);
+        const uint8x8_t t1 = vshrn_n_u16(vreinterpretq_u16_u8(vceqq_u8(chunk1, dup)), 6);
+        const uint8x8_t res = vsli_n_u8(t0, t1, 4);
+        const U64 matches = vget_lane_u64(vreinterpret_u64_u8(res), 0) ;
+        return ZSTD_rotateRight_U64(matches, headGrouped) & 0x5555555555555555ull;
+    } else { /* rowEntries == 64 */
+        const uint8x16x4_t chunk = vld4q_u8(src);
+        const uint8x16_t dup = vdupq_n_u8(tag);
+        const uint8x16_t cmp0 = vceqq_u8(chunk.val[0], dup);
+        const uint8x16_t cmp1 = vceqq_u8(chunk.val[1], dup);
+        const uint8x16_t cmp2 = vceqq_u8(chunk.val[2], dup);
+        const uint8x16_t cmp3 = vceqq_u8(chunk.val[3], dup);
+
+        const uint8x16_t t0 = vsriq_n_u8(cmp1, cmp0, 1);
+        const uint8x16_t t1 = vsriq_n_u8(cmp3, cmp2, 1);
+        const uint8x16_t t2 = vsriq_n_u8(t1, t0, 2);
+        const uint8x16_t t3 = vsriq_n_u8(t2, t2, 4);
+        const uint8x8_t t4 = vshrn_n_u16(vreinterpretq_u16_u8(t3), 4);
+        const U64 matches = vget_lane_u64(vreinterpret_u64_u8(t4), 0);
+        return ZSTD_rotateRight_U64(matches, headGrouped);
+    }
+}
+#endif
+
+/* Returns a ZSTD_VecMask (U64) that has the nth group (determined by
+ * ZSTD_row_matchMaskGroupWidth) of bits set to 1 if the newly-computed "tag"
+ * matches the hash at the nth position in a row of the tagTable.
+ * Each row is a circular buffer beginning at the value of "headGrouped". So we
+ * must rotate the "matches" bitfield to match up with the actual layout of the
+ * entries within the hashTable */
+FORCE_INLINE_TEMPLATE ZSTD_VecMask
+ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 headGrouped, const U32 rowEntries)
+{
     const BYTE* const src = tagRow + ZSTD_ROW_HASH_TAG_OFFSET;
     assert((rowEntries == 16) || (rowEntries == 32) || rowEntries == 64);
     assert(rowEntries <= ZSTD_ROW_HASH_MAX_ENTRIES);
+    assert(ZSTD_row_matchMaskGroupWidth(rowEntries) * rowEntries <= sizeof(ZSTD_VecMask) * 8);

 #if defined(ZSTD_ARCH_X86_SSE2)

-    return ZSTD_row_getSSEMask(rowEntries / 16, src, tag, head);
+    return ZSTD_row_getSSEMask(rowEntries / 16, src, tag, headGrouped);

 #else /* SW or NEON-LE */

 # if defined(ZSTD_ARCH_ARM_NEON)
   /* This NEON path only works for little endian - otherwise use SWAR below */
     if (MEM_isLittleEndian()) {
-        if (rowEntries == 16) {
-            const uint8x16_t chunk = vld1q_u8(src);
-            const uint16x8_t equalMask = vreinterpretq_u16_u8(vceqq_u8(chunk, vdupq_n_u8(tag)));
-            const uint16x8_t t0 = vshlq_n_u16(equalMask, 7);
-            const uint32x4_t t1 = vreinterpretq_u32_u16(vsriq_n_u16(t0, t0, 14));
-            const uint64x2_t t2 = vreinterpretq_u64_u32(vshrq_n_u32(t1, 14));
-            const uint8x16_t t3 = vreinterpretq_u8_u64(vsraq_n_u64(t2, t2, 28));
-            const U16 hi = (U16)vgetq_lane_u8(t3, 8);
-            const U16 lo = (U16)vgetq_lane_u8(t3, 0);
-            return ZSTD_rotateRight_U16((hi << 8) | lo, head);
-        } else if (rowEntries == 32) {
-            const uint16x8x2_t chunk = vld2q_u16((const U16*)(const void*)src);
-            const uint8x16_t chunk0 = vreinterpretq_u8_u16(chunk.val[0]);
-            const uint8x16_t chunk1 = vreinterpretq_u8_u16(chunk.val[1]);
-            const uint8x16_t equalMask0 = vceqq_u8(chunk0, vdupq_n_u8(tag));
-            const uint8x16_t equalMask1 = vceqq_u8(chunk1, vdupq_n_u8(tag));
-            const int8x8_t pack0 = vqmovn_s16(vreinterpretq_s16_u8(equalMask0));
-            const int8x8_t pack1 = vqmovn_s16(vreinterpretq_s16_u8(equalMask1));
-            const uint8x8_t t0 = vreinterpret_u8_s8(pack0);
-            const uint8x8_t t1 = vreinterpret_u8_s8(pack1);
-            const uint8x8_t t2 = vsri_n_u8(t1, t0, 2);
-            const uint8x8x2_t t3 = vuzp_u8(t2, t0);
-            const uint8x8_t t4 = vsri_n_u8(t3.val[1], t3.val[0], 4);
-            const U32 matches = vget_lane_u32(vreinterpret_u32_u8(t4), 0);
-            return ZSTD_rotateRight_U32(matches, head);
-        } else { /* rowEntries == 64 */
-            const uint8x16x4_t chunk = vld4q_u8(src);
-            const uint8x16_t dup = vdupq_n_u8(tag);
-            const uint8x16_t cmp0 = vceqq_u8(chunk.val[0], dup);
-            const uint8x16_t cmp1 = vceqq_u8(chunk.val[1], dup);
-            const uint8x16_t cmp2 = vceqq_u8(chunk.val[2], dup);
-            const uint8x16_t cmp3 = vceqq_u8(chunk.val[3], dup);
-
-            const uint8x16_t t0 = vsriq_n_u8(cmp1, cmp0, 1);
-            const uint8x16_t t1 = vsriq_n_u8(cmp3, cmp2, 1);
-            const uint8x16_t t2 = vsriq_n_u8(t1, t0, 2);
-            const uint8x16_t t3 = vsriq_n_u8(t2, t2, 4);
-            const uint8x8_t t4 = vshrn_n_u16(vreinterpretq_u16_u8(t3), 4);
-            const U64 matches = vget_lane_u64(vreinterpret_u64_u8(t4), 0);
-            return ZSTD_rotateRight_U64(matches, head);
-        }
+        return ZSTD_row_getNEONMask(rowEntries, src, tag, headGrouped);
     }
 # endif /* ZSTD_ARCH_ARM_NEON */
     /* SWAR */
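The core of the speedup is the new ZSTD_row_getNEONMask. For 16-entry rows, the old inline path narrowed the comparison result with a chain of shift/insert/accumulate steps plus two lane extractions; the new path replaces all of that with a single vshrn_n_u16(equalMask, 4) ("shift right and narrow"), which squeezes the sixteen 0xFF/0x00 comparison bytes into a 64-bit value with one 4-bit nibble per entry, then rotates by headGrouped and keeps one bit per group. The scalar model below is not the NEON code, only a way to see the resulting bit layout on any machine; the helper name is made up.

```c
/* Scalar model (illustrative, not the NEON code) of the 16-entry path:
 * a matching byte at index i contributes the 4-bit group at bits [4*i, 4*i+3];
 * after the rotation and the AND with 0x8888..., one bit per match survives. */
#include <stdio.h>
#include <stdint.h>

static uint64_t modelNEONMask16(const uint8_t row[16], uint8_t tag, uint32_t headGrouped) {
    uint64_t grouped = 0;
    int i;
    for (i = 0; i < 16; ++i) {
        if (row[i] == tag)
            grouped |= 0xFull << (4 * i);   /* what vceqq_u8 + vshrn_n_u16(equalMask, 4) leave behind */
    }
    /* rotate right by headGrouped (same role as ZSTD_rotateRight_U64), then keep the top bit of each group */
    grouped = (grouped >> headGrouped) | (grouped << ((64 - headGrouped) & 63));
    return grouped & 0x8888888888888888ull;
}

int main(void) {
    uint8_t row[16] = { 0 };
    row[3] = row[7] = 0xAB;                 /* pretend the tag matched entries 3 and 7 */
    printf("%016llx\n", (unsigned long long)modelNEONMask16(row, 0xAB, 0 /* head * groupWidth */));
    /* prints 0000000080008000: bits 4*3+3 = 15 and 4*7+3 = 31 */
    return 0;
}
```

The 32-entry case applies the same idea with a shift of 6 and an interleaving vsli_n_u8, giving 2-bit groups and the 0x5555... mask; that is exactly why ZSTD_row_matchMaskGroupWidth reports 4, 2 and 1 for 16-, 32- and 64-entry rows.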
@@ -1071,11 +1106,11 @@ ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head, const U32 rowEntries)
         }
         matches = ~matches;
         if (rowEntries == 16) {
-            return ZSTD_rotateRight_U16((U16)matches, head);
+            return ZSTD_rotateRight_U16((U16)matches, headGrouped);
         } else if (rowEntries == 32) {
-            return ZSTD_rotateRight_U32((U32)matches, head);
+            return ZSTD_rotateRight_U32((U32)matches, headGrouped);
         } else {
-            return ZSTD_rotateRight_U64((U64)matches, head);
+            return ZSTD_rotateRight_U64((U64)matches, headGrouped);
         }
     }
 #endif
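On targets without NEON, ZSTD_row_matchMaskGroupWidth returns 1, so headGrouped is numerically equal to the old head and this SWAR hunk is effectively a rename. The rotation itself is what aligns the raw match bits with the circular row: rotating right by the head makes group 0 of the returned mask correspond to the entry currently at the head of the row. A minimal standalone sketch follows; the rotate helper only models what ZSTD_rotateRight_U16 is expected to do.

```c
/* Illustrative only: why the match bits are rotated. Suppose a 16-entry row whose
 * circular head is entry 5, and the raw comparison found matches at entries 5 and 6.
 * Rotating right by the head moves those matches to bits 0 and 1, i.e. positions
 * relative to the head of the circular buffer. */
#include <stdio.h>
#include <stdint.h>

/* models what ZSTD_rotateRight_U16 is expected to do */
static uint16_t rotateRight16(uint16_t v, unsigned n) {
    return (uint16_t)((v >> n) | (v << ((16 - n) & 15)));
}

int main(void) {
    uint16_t const raw  = (uint16_t)((1u << 5) | (1u << 6));  /* matches at entries 5 and 6 */
    unsigned const head = 5;                                  /* == headGrouped when groupWidth is 1 */
    printf("0x%04x\n", rotateRight16(raw, head));             /* prints 0x0003 */
    return 0;
}
```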
@@ -1123,6 +1158,7 @@ size_t ZSTD_RowFindBestMatch(
     const U32 rowEntries = (1U << rowLog);
     const U32 rowMask = rowEntries - 1;
     const U32 cappedSearchLog = MIN(cParams->searchLog, rowLog); /* nb of searches is capped at nb entries per row */
+    const U32 groupWidth = ZSTD_row_matchMaskGroupWidth(rowEntries);
     U32 nbAttempts = 1U << cappedSearchLog;
     size_t ml=4-1;

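All quantities in this hunk are small powers of two: rowEntries comes from rowLog, the number of attempts is capped at one probe per row entry, and groupWidth is the only new value. A quick numeric sketch with made-up parameters (these are not the values of any particular compression level):

```c
/* Made-up parameter values; only the formulas come from the lines above. */
#include <stdio.h>

int main(void) {
    unsigned const rowLog    = 5;   /* hypothetical */
    unsigned const searchLog = 7;   /* hypothetical */
    unsigned const rowEntries      = 1u << rowLog;                               /* 32 entries per row */
    unsigned const cappedSearchLog = (searchLog < rowLog) ? searchLog : rowLog;  /* MIN(searchLog, rowLog) = 5 */
    unsigned const nbAttempts      = 1u << cappedSearchLog;                      /* 32 */
    unsigned const groupWidth      = (rowEntries == 16) ? 4 : (rowEntries == 32) ? 2 : 1;  /* 2 on NEON */
    printf("rowEntries=%u nbAttempts=%u groupWidth=%u\n", rowEntries, nbAttempts, groupWidth);
    return 0;
}
```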
@@ -1165,15 +1201,15 @@ size_t ZSTD_RowFindBestMatch(
         U32 const tag = hash & ZSTD_ROW_HASH_TAG_MASK;
         U32* const row = hashTable + relRow;
         BYTE* tagRow = (BYTE*)(tagTable + relRow);
-        U32 const head = *tagRow & rowMask;
+        U32 const headGrouped = (*tagRow & rowMask) * groupWidth;
         U32 matchBuffer[ZSTD_ROW_HASH_MAX_ENTRIES];
         size_t numMatches = 0;
         size_t currMatch = 0;
-        ZSTD_VecMask matches = ZSTD_row_getMatchMask(tagRow, (BYTE)tag, head, rowEntries);
+        ZSTD_VecMask matches = ZSTD_row_getMatchMask(tagRow, (BYTE)tag, headGrouped, rowEntries);

         /* Cycle through the matches and prefetch */
         for (; (matches > 0) && (nbAttempts > 0); --nbAttempts, matches &= (matches - 1)) {
-            U32 const matchPos = (head + ZSTD_VecMask_next(matches)) & rowMask;
+            U32 const matchPos = ((headGrouped + ZSTD_VecMask_next(matches)) / groupWidth) & rowMask;
             U32 const matchIndex = row[matchPos];
             assert(numMatches < rowEntries);
             if (matchIndex < lowLimit)
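The only real change in the search loop is how a set bit is mapped back to a row slot: because each entry may now own several mask bits, the bit index returned by ZSTD_VecMask_next is divided by groupWidth before wrapping with rowMask. Below is a standalone sketch of that decode step; the names are illustrative and __builtin_ctzll stands in for ZSTD_VecMask_next, assuming a GCC/Clang-style compiler.

```c
/* Standalone sketch (not zstd code) of the decode step in the search loop:
 * __builtin_ctzll plays the role of ZSTD_VecMask_next (index of the lowest set bit). */
#include <stdio.h>
#include <stdint.h>

int main(void) {
    unsigned const groupWidth  = 4;        /* e.g. the NEON path with 16 entries per row */
    unsigned const rowMask     = 16 - 1;
    unsigned const headGrouped = 0;        /* head of the circular row, already scaled by groupWidth */
    /* pretend entries 2 and 9 matched: the 16-entry NEON mask keeps bit 4*i+3 per match */
    uint64_t matches = (1ull << (4*2 + 3)) | (1ull << (4*9 + 3));

    for (; matches != 0; matches &= (matches - 1)) {        /* same loop step as in the diff above */
        unsigned const bit      = (unsigned)__builtin_ctzll(matches);
        unsigned const matchPos = ((headGrouped + bit) / groupWidth) & rowMask;
        printf("match at row position %u\n", matchPos);     /* prints 2, then 9 */
    }
    return 0;
}
```

With groupWidth == 1 (the SSE2 and SWAR paths) this reduces exactly to the old (head + ZSTD_VecMask_next(matches)) & rowMask.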
@@ -1234,14 +1270,14 @@ size_t ZSTD_RowFindBestMatch(
         const U32 dmsSize = (U32)(dmsEnd - dmsBase);
         const U32 dmsIndexDelta = dictLimit - dmsSize;

-        {   U32 const head = *dmsTagRow & rowMask;
+        {   U32 const headGrouped = (*dmsTagRow & rowMask) * groupWidth;
             U32 matchBuffer[ZSTD_ROW_HASH_MAX_ENTRIES];
             size_t numMatches = 0;
             size_t currMatch = 0;
-            ZSTD_VecMask matches = ZSTD_row_getMatchMask(dmsTagRow, (BYTE)dmsTag, head, rowEntries);
+            ZSTD_VecMask matches = ZSTD_row_getMatchMask(dmsTagRow, (BYTE)dmsTag, headGrouped, rowEntries);

             for (; (matches > 0) && (nbAttempts > 0); --nbAttempts, matches &= (matches - 1)) {
-                U32 const matchPos = (head + ZSTD_VecMask_next(matches)) & rowMask;
+                U32 const matchPos = ((headGrouped + ZSTD_VecMask_next(matches)) / groupWidth) & rowMask;
                 U32 const matchIndex = dmsRow[matchPos];
                 if (matchIndex < dmsLowestIndex)
                     break;