This commit is contained in:
Raúl Marín 2024-11-18 20:03:35 +01:00
parent 47bed13b42
commit 445a5e9c9e

View File

@ -994,120 +994,126 @@ ColumnPtr ColumnVector<T>::createWithOffsets(const IColumn::Offsets & offsets, c
return res; return res;
} }
/// Generic fallback: gathers `limit` elements of `data` by position,
/// i.e. res_data[i] = data[indexes[i]] for every i in [0, limit).
DECLARE_DEFAULT_CODE(
template <typename Container, typename Type>
void vectorIndexImpl(const Container & data, const PaddedPODArray<Type> & indexes, size_t limit, Container & res_data)
{
    for (size_t i = 0; i < limit; ++i)
        res_data[i] = data[indexes[i]];
}
);

/// AVX-512 VBMI variant of the same gather, processing 64 index bytes per iteration.
///
/// `data`, `indexes` and `res_data` are accessed as raw byte arrays and `data.size()` is used
/// as a byte count, so this code assumes 1-byte elements and 1-byte indexes.
/// NOTE(review): confirm that callers only dispatch here for such types.
///
/// The lookup table (`data`) is loaded once into one, two or four zmm registers depending on
/// its size; the final partial register is filled with a zero-masking load, so lanes past
/// `data_size` stay zero. Only the first 256 table bytes are ever loaded because a byte index
/// cannot address more.
///
/// NOTE(review): if `data` is empty while `limit > 0`, the first mask computation shifts a
/// 64-bit value by 64, which is undefined; presumably callers never index into an empty
/// column - confirm.
DECLARE_AVX512VBMI_SPECIFIC_CODE(
template <typename Container, typename Type>
void vectorIndexImpl(const Container & data, const PaddedPODArray<Type> & indexes, size_t limit, Container & res_data)
{
    if (limit == 0)
        return; /// nothing to do, just return

    static constexpr UInt64 MASK64 = 0xffffffffffffffff;

    /// `limit` rounded down to a multiple of 64: the part handled by full-width loads/stores.
    const size_t limit64 = limit & ~size_t{63};
    size_t pos = 0;
    size_t data_size = data.size();

    const auto * data_pos = reinterpret_cast<const UInt8 *>(data.data());
    const auto * indexes_pos = reinterpret_cast<const UInt8 *>(indexes.data());
    auto * res_pos = reinterpret_cast<UInt8 *>(res_data.data());

    if (data_size <= 64)
    {
        /// one single mask load for table size <= 64
        const __mmask64 last_mask = MASK64 >> (64 - data_size);
        const __m512i table1 = _mm512_maskz_loadu_epi8(last_mask, data_pos);

        /// 64 bytes table lookup using one single permutexvar_epi8
        while (pos < limit64)
        {
            const __m512i vidx = _mm512_loadu_epi8(indexes_pos + pos);
            const __m512i out = _mm512_permutexvar_epi8(vidx, table1);
            _mm512_storeu_epi8(res_pos + pos, out);
            pos += 64;
        }

        /// tail handling: the mask keeps the low (limit - limit64) lanes,
        /// so only the remaining indexes are loaded and only their results are stored.
        if (limit > limit64)
        {
            const __mmask64 tail_mask = MASK64 >> (limit64 + 64 - limit);
            const __m512i vidx = _mm512_maskz_loadu_epi8(tail_mask, indexes_pos + pos);
            const __m512i out = _mm512_permutexvar_epi8(vidx, table1);
            _mm512_mask_storeu_epi8(res_pos + pos, tail_mask, out);
        }
    }
    else if (data_size <= 128)
    {
        /// table size (64, 128] requires 2 zmm load
        const __mmask64 last_mask = MASK64 >> (128 - data_size);
        const __m512i table1 = _mm512_loadu_epi8(data_pos);
        const __m512i table2 = _mm512_maskz_loadu_epi8(last_mask, data_pos + 64);

        /// 128 bytes table lookup using one single permute2xvar_epi8
        while (pos < limit64)
        {
            const __m512i vidx = _mm512_loadu_epi8(indexes_pos + pos);
            const __m512i out = _mm512_permutex2var_epi8(table1, vidx, table2);
            _mm512_storeu_epi8(res_pos + pos, out);
            pos += 64;
        }

        /// tail handling, same masking scheme as above
        if (limit > limit64)
        {
            const __mmask64 tail_mask = MASK64 >> (limit64 + 64 - limit);
            const __m512i vidx = _mm512_maskz_loadu_epi8(tail_mask, indexes_pos + pos);
            const __m512i out = _mm512_permutex2var_epi8(table1, vidx, table2);
            _mm512_mask_storeu_epi8(res_pos + pos, tail_mask, out);
        }
    }
    else
    {
        if (data_size > 256)
        {
            /// byte index will not exceed 256 boundary.
            data_size = 256;
        }

        const __m512i table1 = _mm512_loadu_epi8(data_pos);
        const __m512i table2 = _mm512_loadu_epi8(data_pos + 64);
        __m512i table3;
        __m512i table4;
        if (data_size <= 192)
        {
            /// only 3 tables need to load if size <= 192
            const __mmask64 last_mask = MASK64 >> (192 - data_size);
            table3 = _mm512_maskz_loadu_epi8(last_mask, data_pos + 128);
            table4 = _mm512_setzero_si512();
        }
        else
        {
            const __mmask64 last_mask = MASK64 >> (256 - data_size);
            table3 = _mm512_loadu_epi8(data_pos + 128);
            table4 = _mm512_maskz_loadu_epi8(last_mask, data_pos + 192);
        }

        /// 256 bytes table lookup can use: 2 permute2xvar_epi8 plus 1 blender with MSB.
        /// tmp1 is the lookup in table bytes [0, 128), tmp2 the lookup in [128, 256);
        /// the top bit of each index byte picks which of the two results is kept.
        while (pos < limit64)
        {
            const __m512i vidx = _mm512_loadu_epi8(indexes_pos + pos);
            const __m512i tmp1 = _mm512_permutex2var_epi8(table1, vidx, table2);
            const __m512i tmp2 = _mm512_permutex2var_epi8(table3, vidx, table4);
            const __mmask64 msb = _mm512_movepi8_mask(vidx);
            const __m512i out = _mm512_mask_blend_epi8(msb, tmp1, tmp2);
            _mm512_storeu_epi8(res_pos + pos, out);
            pos += 64;
        }

        /// tail handling, same masking scheme as above
        if (limit > limit64)
        {
            const __mmask64 tail_mask = MASK64 >> (limit64 + 64 - limit);
            const __m512i vidx = _mm512_maskz_loadu_epi8(tail_mask, indexes_pos + pos);
            const __m512i tmp1 = _mm512_permutex2var_epi8(table1, vidx, table2);
            const __m512i tmp2 = _mm512_permutex2var_epi8(table3, vidx, table4);
            const __mmask64 msb = _mm512_movepi8_mask(vidx);
            const __m512i out = _mm512_mask_blend_epi8(msb, tmp1, tmp2);
            _mm512_mask_storeu_epi8(res_pos + pos, tail_mask, out);
        }
    }
}
);
template <typename T> template <typename T>
template <typename Type> template <typename Type>