Merge pull request #56079 from ZhiguoZh/20231027-combine-filter-avx512

Optimize DB::combineFilters with AVX512_VBMI2 intrinsic
This commit is contained in:
Dmitry Novik 2023-11-22 15:42:33 +01:00 committed by GitHub
commit 7539928814
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 204 additions and 4 deletions

View File

@ -18,6 +18,10 @@
#include <emmintrin.h>
#endif
#if USE_MULTITARGET_CODE
#include <immintrin.h>
#endif
#if defined(__aarch64__) && defined(__ARM_NEON)
# include <arm_neon.h>
# pragma clang diagnostic ignored "-Wreserved-identifier"
@ -1253,6 +1257,32 @@ static void checkCombinedFiltersSize(size_t bytes_in_first_filter, size_t second
"does not match second filter size ({})", bytes_in_first_filter, second_filter_size);
}
DECLARE_AVX512VBMI2_SPECIFIC_CODE(
/// Overwrites, in place, every non-zero byte of [first_begin, first_end) with the next unread byte
/// taken from second_begin; zero bytes are left untouched.
/// The caller must make sure that second_begin points to at least as many readable bytes as there are
/// non-zero bytes in the first range.
inline void combineFiltersImpl(UInt8 * first_begin, const UInt8 * first_end, const UInt8 * second_begin)
{
    constexpr size_t bytes_per_vector = 64;

    /// Vectorized part: handle 64 bytes of the first range per iteration.
    for (; first_begin + bytes_per_vector <= first_end; first_begin += bytes_per_vector)
    {
        /// NOTE(review): presumably bit i of the mask is set iff byte i of the chunk is non-zero - confirm against the helper.
        const UInt64 selected = bytes64MaskToBits64Mask(first_begin);
        const __m512i current = _mm512_loadu_si512(reinterpret_cast<void *>(first_begin));
        /// Expand-load: lanes whose mask bit is set receive consecutive bytes read from second_begin,
        /// all other lanes keep the value they had in `current`.
        const __m512i merged = _mm512_mask_expandloadu_epi8(current, static_cast<__mmask64>(selected), reinterpret_cast<const void *>(second_begin));
        _mm512_storeu_si512(reinterpret_cast<void *>(first_begin), merged);
        /// Exactly one byte of the second range was consumed per set mask bit.
        second_begin += std::popcount(selected);
    }

    /// Scalar part: the remaining (fewer than 64) bytes.
    while (first_begin < first_end)
    {
        if (*first_begin)
            *first_begin = *second_begin++;
        ++first_begin;
    }
}
)
/// Second filter size must be equal to number of 1s in the first filter.
/// The result has size equal to first filter size and contains 1s only where both filters contain 1s.
static ColumnPtr combineFilters(ColumnPtr first, ColumnPtr second)
@ -1295,12 +1325,21 @@ static ColumnPtr combineFilters(ColumnPtr first, ColumnPtr second)
auto & first_data = typeid_cast<ColumnUInt8 *>(mut_first.get())->getData();
const auto * second_data = second_descr.data->data();
for (auto & val : first_data)
#if USE_MULTITARGET_CODE
if (isArchSupported(TargetArch::AVX512VBMI2))
{
if (val)
TargetSpecific::AVX512VBMI2::combineFiltersImpl(first_data.begin(), first_data.end(), second_data);
}
else
#endif
{
for (auto & val : first_data)
{
val = *second_data;
++second_data;
if (val)
{
val = *second_data;
++second_data;
}
}
}

View File

@ -0,0 +1,161 @@
#include <gtest/gtest.h>
#include <Columns/ColumnVector.h>
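// Including a .cpp file is bad practice, but combineFilters is a static (file-local) function of
// MergeTreeRangeReader.cpp, so the test can only reach it by including the translation unit itself.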
#include <Storages/MergeTree/MergeTreeRangeReader.cpp> // NOLINT
using namespace DB;
/* The combineFilters function from MergeTreeRangeReader.cpp could be optimized with Intel's AVX512VBMI2 intrinsic,
* _mm512_mask_expandloadu_epi8. And these tests are added to ensure that the vectorized code outputs the exact results
* as the original scalar code when the required hardware feature is supported on the device.
*
* To avoid the degenerate all-one/all-zero sequences, this test fills both filters with alternating 1s and 0s,
* so that only the elements at indices divisible by 4 in the combined filter equal 1 and all others equal 0.
* For example, given the size of the first filter to be 11, the generated and the output filters are:
*
* first_filter: [1 0 1 0 1 0 1 0 1 0 1]
* second_filter: [1 0 1 0 1 0]
* output_filter: [1 0 0 0 1 0 0 0 1 0 0]
*/
bool testCombineFilters(size_t size)
{
auto generateFilterWithAlternatingOneAndZero = [](size_t len)->ColumnPtr
{
auto filter = ColumnUInt8::create(len, 0);
auto & filter_data = filter->getData();
for (size_t i = 0; i < len; i += 2)
filter_data[i] = 1;
return filter;
};
auto first_filter = generateFilterWithAlternatingOneAndZero(size);
/// The count of 1s in the first_filter is floor((size + 1) / 2), which should be the size of the second_filter.
auto second_filter = generateFilterWithAlternatingOneAndZero((size + 1) / 2);
auto result = combineFilters(first_filter, second_filter);
if (result->size() != size)
{
return false;
}
for (size_t i = 0; i < size; i++)
{
if (i % 4 == 0)
{
if (result->get64(i) != 1)
{
return false;
}
}
else
{
if (result->get64(i) != 0)
{
return false;
}
}
}
return true;
}
/* This test is to further test DB::combineFilters by combining two UInt8 columns. Given the implementation of
* DB::combineFilters, the non-zero values in the first column are contiguously replaced with the elements in the
* second column. And to validate the first column with arbitrary intervals, this test constructs its values in
* the following manner: the count of 0s between two consecutive 1s increases in step of 1. An example column
* with the size of 16 looks like:
* [1 1 0 1 0 0 1 0 0 0 1 0 0 0 0 1]
*
* The second column contains the consecutively incremented UInt8 integers between 0x00 and 0xFF, and when the overflow
* occurs, the value would reset to 0x00 and increment again.
*/
bool testCombineColumns(size_t size)
{
auto generateFirstColumn = [] (size_t len, size_t & non_zero_count)->ColumnPtr
{
auto column = ColumnUInt8::create(len, 0);
auto & column_data = column->getData();
non_zero_count = 0;
for (size_t i = 0; i < len; non_zero_count++, i += non_zero_count)
{
column_data[i] = 1;
}
return column;
};
auto generateSecondColumn = [] (size_t len)->ColumnPtr
{
auto column = ColumnUInt8::create(len, 0);
auto & column_data = column->getData();
for (size_t i = 0; i < len; i++)
{
column_data[i] = static_cast<UInt8>(i);
}
return column;
};
size_t non_zero_count = 0;
auto first_column = generateFirstColumn(size, non_zero_count);
const auto & first_column_data = typeid_cast<const ColumnUInt8 *>(first_column.get())->getData();
/// The count of non-zero values in the first column should be the size of the second column.
auto second_column = generateSecondColumn(non_zero_count);
auto result = combineFilters(first_column, second_column);
const auto & result_data = typeid_cast<const ColumnUInt8 *>(result.get())->getData();
if (result->size() != size) return false;
UInt8 expected = 0;
for (size_t i = 0; i < size; ++i)
{
if (first_column_data[i])
{
if (result_data[i] != expected)
{
return false;
}
/// Integer overflow is speculated during the integer increments. It is the expected behavior.
expected++;
}
else
{
if (result_data[i] != 0)
{
return false;
}
}
}
return true;
}
/// Runs both helpers over sizes below, at and above the 64-byte block handled per iteration by the
/// AVX512VBMI2 implementation, so that the vectorized loop and the scalar tail are both exercised.
TEST(MergeTree, CombineFilters)
{
    /// Tests with only 0/1 and fixed intervals.
    static constexpr size_t filter_sizes[] = {1, 2, 63, 64, 65, 200, 201, 300};
    for (const size_t size : filter_sizes)
        EXPECT_TRUE(testCombineFilters(size)) << "filter size: " << size;

    /// Extended tests: combination of two UInt8 columns.
    /// 65 is included here as well: one full 64-byte block followed by a single-byte scalar tail.
    static constexpr size_t column_sizes[] = {1, 2, 63, 64, 65, 200, 201, 2000, 200000};
    for (const size_t size : column_sizes)
        EXPECT_TRUE(testCombineColumns(size)) << "column size: " << size;
}