Merge pull request #53588 from kitaisreal/aarch64-neon-memequal-wide

AARCH64 Neon memequal wide
Alexey Milovidov, 2023-08-21 21:11:01 +03:00 (committed by GitHub)
commit 316664456f
8 changed files with 67 additions and 34 deletions
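
Summary: this change generalizes the SSE2-only equality fast path in StringRef.h to AArch64 NEON. compareSSE2, compareSSE2x4 and memequalSSE2Wide are renamed to the platform-neutral compare8, compare64 and memequalWide, NEON implementations of the two comparison helpers are added, and the shrn-based nibble-mask helper that several files previously duplicated as a local get_nibble_mask lambda is factored out into the new header base/base/simd.h as getNibbleMask and reused everywhere.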


@@ -11,6 +11,7 @@
 #include <base/defines.h>
 #include <base/types.h>
 #include <base/unaligned.h>
+#include <base/simd.h>
 #include <city.h>
@@ -29,6 +30,11 @@
 #define CRC_INT __crc32cd
 #endif
+#if defined(__aarch64__) && defined(__ARM_NEON)
+#include <arm_neon.h>
+#pragma clang diagnostic ignored "-Wreserved-identifier"
+#endif
 /**
   * The std::string_view-like container to avoid creating strings to find substrings in the hash table.
@@ -74,14 +80,14 @@ using StringRefs = std::vector<StringRef>;
   * For more information, see hash_map_string_2.cpp
   */
-inline bool compareSSE2(const char * p1, const char * p2)
+inline bool compare8(const char * p1, const char * p2)
 {
     return 0xFFFF == _mm_movemask_epi8(_mm_cmpeq_epi8(
         _mm_loadu_si128(reinterpret_cast<const __m128i *>(p1)),
         _mm_loadu_si128(reinterpret_cast<const __m128i *>(p2))));
 }
-inline bool compareSSE2x4(const char * p1, const char * p2)
+inline bool compare64(const char * p1, const char * p2)
 {
     return 0xFFFF == _mm_movemask_epi8(
         _mm_and_si128(
@@ -101,7 +107,30 @@ inline bool compareSSE2x4(const char * p1, const char * p2)
             _mm_loadu_si128(reinterpret_cast<const __m128i *>(p2) + 3)))));
 }
-inline bool memequalSSE2Wide(const char * p1, const char * p2, size_t size)
+#elif defined(__aarch64__) && defined(__ARM_NEON)
+
+inline bool compare8(const char * p1, const char * p2)
+{
+    uint64_t mask = getNibbleMask(vceqq_u8(
+        vld1q_u8(reinterpret_cast<const unsigned char *>(p1)), vld1q_u8(reinterpret_cast<const unsigned char *>(p2))));
+    return 0xFFFFFFFFFFFFFFFF == mask;
+}
+
+inline bool compare64(const char * p1, const char * p2)
+{
+    uint64_t mask = getNibbleMask(vandq_u8(
+        vandq_u8(vceqq_u8(vld1q_u8(reinterpret_cast<const unsigned char *>(p1)), vld1q_u8(reinterpret_cast<const unsigned char *>(p2))),
+            vceqq_u8(vld1q_u8(reinterpret_cast<const unsigned char *>(p1 + 16)), vld1q_u8(reinterpret_cast<const unsigned char *>(p2 + 16)))),
+        vandq_u8(vceqq_u8(vld1q_u8(reinterpret_cast<const unsigned char *>(p1 + 32)), vld1q_u8(reinterpret_cast<const unsigned char *>(p2 + 32))),
+            vceqq_u8(vld1q_u8(reinterpret_cast<const unsigned char *>(p1 + 48)), vld1q_u8(reinterpret_cast<const unsigned char *>(p2 + 48))))));
+    return 0xFFFFFFFFFFFFFFFF == mask;
+}
+
+#endif
+
+#if defined(__SSE2__) || (defined(__aarch64__) && defined(__ARM_NEON))
+
+inline bool memequalWide(const char * p1, const char * p2, size_t size)
 {
     /** The order of branches and the trick with overlapping comparisons
       * are the same as in memcpy implementation.
@@ -138,7 +167,7 @@ inline bool memequalSSE2Wide(const char * p1, const char * p2, size_t size)
     while (size >= 64)
     {
-        if (compareSSE2x4(p1, p2))
+        if (compare64(p1, p2))
         {
             p1 += 64;
             p2 += 64;
@@ -150,17 +179,16 @@ inline bool memequalSSE2Wide(const char * p1, const char * p2, size_t size)
     switch (size / 16)
     {
-        case 3: if (!compareSSE2(p1 + 32, p2 + 32)) return false; [[fallthrough]];
-        case 2: if (!compareSSE2(p1 + 16, p2 + 16)) return false; [[fallthrough]];
-        case 1: if (!compareSSE2(p1, p2)) return false;
+        case 3: if (!compare8(p1 + 32, p2 + 32)) return false; [[fallthrough]];
+        case 2: if (!compare8(p1 + 16, p2 + 16)) return false; [[fallthrough]];
+        case 1: if (!compare8(p1, p2)) return false;
     }
-    return compareSSE2(p1 + size - 16, p2 + size - 16);
+    return compare8(p1 + size - 16, p2 + size - 16);
 }
 #endif
 inline bool operator== (StringRef lhs, StringRef rhs)
 {
     if (lhs.size != rhs.size)
@@ -169,8 +197,8 @@ inline bool operator== (StringRef lhs, StringRef rhs)
     if (lhs.size == 0)
         return true;
-#if defined(__SSE2__)
-    return memequalSSE2Wide(lhs.data, rhs.data, lhs.size);
+#if defined(__SSE2__) || (defined(__aarch64__) && defined(__ARM_NEON))
+    return memequalWide(lhs.data, rhs.data, lhs.size);
 #else
     return 0 == memcmp(lhs.data, rhs.data, lhs.size);
 #endif
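
A note on the tail handling above: after the 64-byte loop, the switch compares whole 16-byte blocks, and the final compare8(p1 + size - 16, p2 + size - 16) re-reads up to 15 already-compared bytes, so every size from 16 to 64 is covered without a scalar byte loop. The same overlapping-read trick works at any width; here is a minimal scalar sketch using 8-byte words (a hypothetical helper for illustration only, not part of this commit):

    #include <cstring>
    #include <cstdint>
    #include <cstddef>

    /// Scalar analogue of memequalWide's tail trick: compare 8-byte words,
    /// letting the final load overlap bytes that were already compared.
    inline bool memequalScalarWide(const char * p1, const char * p2, size_t size)
    {
        auto load8 = [](const char * p)
        {
            uint64_t w;
            std::memcpy(&w, p, sizeof(w)); /// unaligned-safe load
            return w;
        };
        if (size < 8)
            return 0 == std::memcmp(p1, p2, size); /// too small for one word
        while (size > 8)
        {
            if (load8(p1) != load8(p2))
                return false;
            p1 += 8;
            p2 += 8;
            size -= 8;
        }
        /// Final word overlaps up to 7 already-compared bytes.
        return load8(p1 + size - 8) == load8(p2 + size - 8);
    }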

base/base/simd.h (new file, 14 lines)

@@ -0,0 +1,14 @@
+#pragma once
+
+#if defined(__aarch64__) && defined(__ARM_NEON)
+#    include <arm_neon.h>
+#    pragma clang diagnostic ignored "-Wreserved-identifier"
+
+/// Returns a 64 bit mask of nibbles (4 bits for each byte).
+inline uint64_t getNibbleMask(uint8x16_t res)
+{
+    return vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(res), 4)), 0);
+}
+
+#endif
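
For reference, the trick in getNibbleMask: vshrn_n_u16(x, 4) shifts each 16-bit lane right by 4 and narrows it to 8 bits, so each byte of a NEON comparison result (0x00 or 0xFF from vceqq_u8 and friends) lands as one 4-bit nibble in a 64-bit scalar, playing the role of SSE2's _mm_movemask_epi8. A self-contained sketch of consuming such a mask (assumes an AArch64 target and C++20; the test data is made up):

    #include <arm_neon.h>
    #include <bit>
    #include <cstdint>
    #include <cstdio>

    int main()
    {
        const uint8_t a[16] = {'a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p'};
        const uint8_t b[16] = {'a','b','c','d','e','X','g','h','i','j','k','l','m','n','o','p'};

        /// 0xFF for equal bytes, 0x00 for the mismatch at index 5.
        uint8x16_t eq = vceqq_u8(vld1q_u8(a), vld1q_u8(b));
        /// Same expression as getNibbleMask above.
        uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(eq), 4)), 0);

        /// 4 mask bits per byte: popcount / 4 counts matching bytes,
        /// countr_zero of the inverted mask / 4 locates the first mismatch.
        std::printf("matching bytes: %d\n", std::popcount(mask) / 4);      /// prints 15
        std::printf("first mismatch: %d\n", std::countr_zero(~mask) / 4);  /// prints 5
    }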


@@ -2,6 +2,7 @@
 #include <optional>
 #include <base/types.h>
+#include <base/simd.h>
 #include <Common/BitHelpers.h>
 #include <Poco/UTF8Encoding.h>
@@ -72,16 +73,13 @@ inline size_t countCodePoints(const UInt8 * data, size_t size)
         res += __builtin_popcount(_mm_movemask_epi8(
             _mm_cmpgt_epi8(_mm_loadu_si128(reinterpret_cast<const __m128i *>(data)), threshold)));
 #elif defined(__aarch64__) && defined(__ARM_NEON)
-    /// Returns a 64 bit mask of nibbles (4 bits for each byte).
-    auto get_nibble_mask
-        = [](uint8x16_t input) -> uint64_t { return vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(input), 4)), 0); };
     constexpr auto bytes_sse = 16;
     const auto * src_end_sse = data + size / bytes_sse * bytes_sse;
     const auto threshold = vdupq_n_s8(0xBF);
     for (; data < src_end_sse; data += bytes_sse)
-        res += std::popcount(get_nibble_mask(vcgtq_s8(vld1q_s8(reinterpret_cast<const int8_t *>(data)), threshold)));
+        res += std::popcount(getNibbleMask(vcgtq_s8(vld1q_s8(reinterpret_cast<const int8_t *>(data)), threshold)));
     res >>= 2;
 #endif
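
On the accounting in this hunk: vcgtq_s8 against the 0xBF threshold marks every byte that is not a UTF-8 continuation byte (the signed comparison selects 0x00-0x7F and 0xC0-0xFF), each marked byte sets 4 bits in the nibble mask, and std::popcount therefore accumulates four times the code point count; the single res >>= 2 after the loop divides it back. For example, 16 ASCII bytes give a mask with 64 set bits, and 64 >> 2 = 16 code points.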


@@ -4,6 +4,8 @@
 #include <bit>
 #include <cstdint>
+
+#include <base/simd.h>
 #include <Core/Defines.h>
@@ -504,11 +506,6 @@ inline bool memoryIsZeroSmallAllowOverflow15(const void * data, size_t size)
 #    include <arm_neon.h>
 #    pragma clang diagnostic ignored "-Wreserved-identifier"
-inline uint64_t getNibbleMask(uint8x16_t res)
-{
-    return vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(res), 4)), 0);
-}
 template <typename Char>
 inline int memcmpSmallAllowOverflow15(const Char * a, size_t a_size, const Char * b, size_t b_size)
 {


@@ -7,6 +7,8 @@
 #include <string_view>
+
+#include <base/simd.h>
 #ifdef __SSE2__
 #    include <emmintrin.h>
 #endif
@@ -73,16 +75,13 @@ struct ToValidUTF8Impl
         /// Fast skip of ASCII for aarch64.
         static constexpr size_t SIMD_BYTES = 16;
         const char * simd_end = p + (end - p) / SIMD_BYTES * SIMD_BYTES;
-        /// Returns a 64 bit mask of nibbles (4 bits for each byte).
-        auto get_nibble_mask = [](uint8x16_t input) -> uint64_t
-        { return vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(input), 4)), 0); };
         /// Other options include
         /// vmaxvq_u8(input) < 0b10000000;
         /// Used by SIMDJSON, has latency 3 for M1, 6 for everything else
         /// SIMDJSON uses it for 64 byte masks, so it's a little different.
        /// vmaxvq_u32(vandq_u32(input, vdupq_n_u32(0x80808080))) // u32 version has latency 3
         /// shrn version has universally <=3 cycles, on servers 2 cycles.
-        while (p < simd_end && get_nibble_mask(vcgeq_u8(vld1q_u8(reinterpret_cast<const uint8_t *>(p)), vdupq_n_u8(0x80))) == 0)
+        while (p < simd_end && getNibbleMask(vcgeq_u8(vld1q_u8(reinterpret_cast<const uint8_t *>(p)), vdupq_n_u8(0x80))) == 0)
             p += SIMD_BYTES;
         if (!(p < end))


@@ -12,6 +12,8 @@
 #include <cstdlib>
 #include <bit>
+
+#include <base/simd.h>
 #ifdef __SSE2__
 #include <emmintrin.h>
 #endif
@@ -819,14 +821,11 @@ void readCSVStringInto(Vector & s, ReadBuffer & buf, const FormatSettings::CSV &
         auto rc = vdupq_n_u8('\r');
         auto nc = vdupq_n_u8('\n');
         auto dc = vdupq_n_u8(delimiter);
-        /// Returns a 64 bit mask of nibbles (4 bits for each byte).
-        auto get_nibble_mask = [](uint8x16_t input) -> uint64_t
-        { return vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(input), 4)), 0); };
         for (; next_pos + 15 < buf.buffer().end(); next_pos += 16)
         {
             uint8x16_t bytes = vld1q_u8(reinterpret_cast<const uint8_t *>(next_pos));
             auto eq = vorrq_u8(vorrq_u8(vceqq_u8(bytes, rc), vceqq_u8(bytes, nc)), vceqq_u8(bytes, dc));
-            uint64_t bit_mask = get_nibble_mask(eq);
+            uint64_t bit_mask = getNibbleMask(eq);
             if (bit_mask)
             {
                 next_pos += std::countr_zero(bit_mask) >> 2;
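
The scan above uses the nibble mask as a find-first primitive rather than an all-equal test: the three comparisons are OR-ed together, and because each byte owns 4 mask bits, std::countr_zero(bit_mask) >> 2 turns the position of the first set bit into a byte offset. A standalone sketch of the same idea (hypothetical helper name; assumes AArch64 NEON and C++20):

    #include <arm_neon.h>
    #include <bit>
    #include <cstddef>
    #include <cstdint>

    /// Returns the offset of the first '\r', '\n' or delimiter within the next
    /// 16 bytes of p, or 16 if none occurs in that block.
    inline size_t findFirstSpecial16(const char * p, char delimiter)
    {
        const uint8x16_t bytes = vld1q_u8(reinterpret_cast<const uint8_t *>(p));
        const uint8x16_t eq = vorrq_u8(
            vorrq_u8(vceqq_u8(bytes, vdupq_n_u8('\r')), vceqq_u8(bytes, vdupq_n_u8('\n'))),
            vceqq_u8(bytes, vdupq_n_u8(delimiter)));
        /// Same expression as getNibbleMask in base/base/simd.h.
        const uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(eq), 4)), 0);
        if (mask == 0)
            return 16; /// no match in this block
        return std::countr_zero(mask) >> 2; /// 4 mask bits per byte
    }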


@@ -1,6 +1,7 @@
 #include <Poco/UTF8Encoding.h>
 #include <IO/WriteBufferValidUTF8.h>
 #include <base/types.h>
+#include <base/simd.h>
 #ifdef __SSE2__
 #include <emmintrin.h>
@@ -84,16 +85,13 @@ void WriteBufferValidUTF8::nextImpl()
     /// Fast skip of ASCII for aarch64.
     static constexpr size_t SIMD_BYTES = 16;
     const char * simd_end = p + (pos - p) / SIMD_BYTES * SIMD_BYTES;
-    /// Returns a 64 bit mask of nibbles (4 bits for each byte).
-    auto get_nibble_mask = [](uint8x16_t input) -> uint64_t
-    { return vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(input), 4)), 0); };
     /// Other options include
     /// vmaxvq_u8(input) < 0b10000000;
     /// Used by SIMDJSON, has latency 3 for M1, 6 for everything else
     /// SIMDJSON uses it for 64 byte masks, so it's a little different.
     /// vmaxvq_u32(vandq_u32(input, vdupq_n_u32(0x80808080))) // u32 version has latency 3
     /// shrn version has universally <=3 cycles, on servers 2 cycles.
-    while (p < simd_end && get_nibble_mask(vcgeq_u8(vld1q_u8(reinterpret_cast<const uint8_t *>(p)), vdupq_n_u8(0x80))) == 0)
+    while (p < simd_end && getNibbleMask(vcgeq_u8(vld1q_u8(reinterpret_cast<const uint8_t *>(p)), vdupq_n_u8(0x80))) == 0)
         p += SIMD_BYTES;
     if (!(p < pos))


@@ -64,8 +64,8 @@ inline bool operator==(SmallStringRef lhs, SmallStringRef rhs)
     if (lhs.size == 0)
         return true;
-#ifdef __SSE2__
-    return memequalSSE2Wide(lhs.data(), rhs.data(), lhs.size);
+#if defined(__SSE2__) || (defined(__aarch64__) && defined(__ARM_NEON))
+    return memequalWide(lhs.data(), rhs.data(), lhs.size);
 #else
     return 0 == memcmp(lhs.data(), rhs.data(), lhs.size);
 #endif