mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-22 15:42:02 +00:00
Merge pull request #58866 from rschu1ze/distance-with-cpu-dispatch
AVX vectorization of distance functions
This commit is contained in:
commit
a6fcd63159
@ -1,6 +1,7 @@
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <Columns/IColumn.h>
|
||||
#include <Common/TargetSpecific.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/IDataType.h>
|
||||
@ -9,6 +10,10 @@
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include <base/range.h>
|
||||
|
||||
#if USE_MULTITARGET_CODE
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
@ -75,6 +80,49 @@ struct L2Distance
|
||||
state.sum += other_state.sum;
|
||||
}
|
||||
|
||||
#if USE_MULTITARGET_CODE
|
||||
template <typename ResultType>
|
||||
AVX512_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombine(
|
||||
const ResultType * __restrict data_x,
|
||||
const ResultType * __restrict data_y,
|
||||
size_t i_max,
|
||||
size_t & i_x,
|
||||
size_t & i_y,
|
||||
State<ResultType> & state)
|
||||
{
|
||||
__m512 sums;
|
||||
if constexpr (std::is_same_v<ResultType, Float32>)
|
||||
sums = _mm512_setzero_ps();
|
||||
else
|
||||
sums = _mm512_setzero_pd();
|
||||
|
||||
const size_t n = (std::is_same_v<ResultType, Float32>) ? 16 : 8;
|
||||
|
||||
for (; i_x + n < i_max; i_x += n, i_y += n)
|
||||
{
|
||||
if constexpr (std::is_same_v<ResultType, Float32>)
|
||||
{
|
||||
__m512 x = _mm512_loadu_ps(data_x + i_x);
|
||||
__m512 y = _mm512_loadu_ps(data_y + i_y);
|
||||
__m512 differences = _mm512_sub_ps(x, y);
|
||||
sums = _mm512_fmadd_ps(differences, differences, sums);
|
||||
}
|
||||
else
|
||||
{
|
||||
__m512 x = _mm512_loadu_pd(data_x + i_x);
|
||||
__m512 y = _mm512_loadu_pd(data_y + i_y);
|
||||
__m512 differences = _mm512_sub_pd(x, y);
|
||||
sums = _mm512_fmadd_pd(differences, differences, sums);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<ResultType, Float32>)
|
||||
state.sum = _mm512_reduce_add_ps(sums);
|
||||
else
|
||||
state.sum = _mm512_reduce_add_pd(sums);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename ResultType>
|
||||
static ResultType finalize(const State<ResultType> & state, const ConstParams &)
|
||||
{
|
||||
@ -189,6 +237,70 @@ struct CosineDistance
|
||||
state.y_squared += other_state.y_squared;
|
||||
}
|
||||
|
||||
#if USE_MULTITARGET_CODE
|
||||
template <typename ResultType>
|
||||
AVX512_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombine(
|
||||
const ResultType * __restrict data_x,
|
||||
const ResultType * __restrict data_y,
|
||||
size_t i_max,
|
||||
size_t & i_x,
|
||||
size_t & i_y,
|
||||
State<ResultType> & state)
|
||||
{
|
||||
__m512 dot_products;
|
||||
__m512 x_squareds;
|
||||
__m512 y_squareds;
|
||||
|
||||
if constexpr (std::is_same_v<ResultType, Float32>)
|
||||
{
|
||||
dot_products = _mm512_setzero_ps();
|
||||
x_squareds = _mm512_setzero_ps();
|
||||
y_squareds = _mm512_setzero_ps();
|
||||
}
|
||||
else
|
||||
{
|
||||
dot_products = _mm512_setzero_pd();
|
||||
x_squareds = _mm512_setzero_pd();
|
||||
y_squareds = _mm512_setzero_pd();
|
||||
}
|
||||
|
||||
const size_t n = (std::is_same_v<ResultType, Float32>) ? 16 : 8;
|
||||
|
||||
for (; i_x + n < i_max; i_x += n, i_y += n)
|
||||
{
|
||||
if constexpr (std::is_same_v<ResultType, Float32>)
|
||||
{
|
||||
__m512 x = _mm512_loadu_ps(data_x + i_x);
|
||||
__m512 y = _mm512_loadu_ps(data_y + i_y);
|
||||
dot_products = _mm512_fmadd_ps(x, y, dot_products);
|
||||
x_squareds = _mm512_fmadd_ps(x, x, x_squareds);
|
||||
y_squareds = _mm512_fmadd_ps(y, y, y_squareds);
|
||||
}
|
||||
else
|
||||
{
|
||||
__m512 x = _mm512_loadu_pd(data_x + i_x);
|
||||
__m512 y = _mm512_loadu_pd(data_y + i_y);
|
||||
dot_products = _mm512_fmadd_pd(x, y, dot_products);
|
||||
x_squareds = _mm512_fmadd_pd(x, x, x_squareds);
|
||||
y_squareds = _mm512_fmadd_pd(y, y, y_squareds);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<ResultType, Float32>)
|
||||
{
|
||||
state.dot_prod = _mm512_reduce_add_ps(dot_products);
|
||||
state.x_squared = _mm512_reduce_add_ps(x_squareds);
|
||||
state.y_squared = _mm512_reduce_add_ps(y_squareds);
|
||||
}
|
||||
else
|
||||
{
|
||||
state.dot_prod = _mm512_reduce_add_pd(dot_products);
|
||||
state.x_squared = _mm512_reduce_add_pd(x_squareds);
|
||||
state.y_squared = _mm512_reduce_add_pd(y_squareds);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename ResultType>
|
||||
static ResultType finalize(const State<ResultType> & state, const ConstParams &)
|
||||
{
|
||||
@ -352,7 +464,7 @@ private:
|
||||
/// Check that arrays in both columns are the sames size
|
||||
for (size_t row = 0; row < offsets_x.size(); ++row)
|
||||
{
|
||||
if (unlikely(offsets_x[row] != offsets_y[row]))
|
||||
if (offsets_x[row] != offsets_y[row]) [[unlikely]]
|
||||
{
|
||||
ColumnArray::Offset prev_offset = row > 0 ? offsets_x[row] : 0;
|
||||
throw Exception(
|
||||
@ -420,7 +532,7 @@ private:
|
||||
ColumnArray::Offset prev_offset = 0;
|
||||
for (size_t row : collections::range(0, offsets_y.size()))
|
||||
{
|
||||
if (unlikely(offsets_x[0] != offsets_y[row] - prev_offset))
|
||||
if (offsets_x[0] != offsets_y[row] - prev_offset) [[unlikely]]
|
||||
{
|
||||
throw Exception(
|
||||
ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH,
|
||||
@ -438,14 +550,35 @@ private:
|
||||
auto & result_data = result->getData();
|
||||
|
||||
/// Do the actual computation
|
||||
ColumnArray::Offset prev = 0;
|
||||
size_t prev = 0;
|
||||
size_t row = 0;
|
||||
|
||||
for (auto off : offsets_y)
|
||||
{
|
||||
size_t i = 0;
|
||||
typename Kernel::template State<ResultType> state;
|
||||
|
||||
/// SIMD optimization: process multiple elements in both input arrays at once.
|
||||
/// To avoid combinatorial explosion of SIMD kernels, focus on
|
||||
/// - the two most common input/output types (Float32 x Float32) --> Float32 and (Float64 x Float64) --> Float64 instead of 10 x
|
||||
/// 10 input types x 2 output types,
|
||||
/// - const/non-const inputs instead of non-const/non-const inputs
|
||||
/// - the two most common metrics L2 and cosine distance,
|
||||
/// - the most powerful SIMD instruction set (AVX-512F).
|
||||
#if USE_MULTITARGET_CODE
|
||||
if constexpr (std::is_same_v<ResultType, FirstArgType> && std::is_same_v<ResultType, SecondArgType>) /// ResultType is Float32 or Float64
|
||||
{
|
||||
if constexpr (std::is_same_v<Kernel, L2Distance>
|
||||
|| std::is_same_v<Kernel, CosineDistance>)
|
||||
{
|
||||
if (isArchSupported(TargetArch::AVX512F))
|
||||
Kernel::template accumulateCombine<ResultType>(data_x.data(), data_y.data(), i + offsets_x[0], i, prev, state);
|
||||
}
|
||||
}
|
||||
#else
|
||||
/// Process chunks in vectorized manner
|
||||
static constexpr size_t VEC_SIZE = 4;
|
||||
typename Kernel::template State<ResultType> states[VEC_SIZE];
|
||||
size_t i = 0;
|
||||
for (; prev + VEC_SIZE < off; i += VEC_SIZE, prev += VEC_SIZE)
|
||||
{
|
||||
for (size_t s = 0; s < VEC_SIZE; ++s)
|
||||
@ -453,10 +586,9 @@ private:
|
||||
states[s], static_cast<ResultType>(data_x[i + s]), static_cast<ResultType>(data_y[prev + s]), kernel_params);
|
||||
}
|
||||
|
||||
typename Kernel::template State<ResultType> state;
|
||||
for (const auto & other_state : states)
|
||||
Kernel::template combine<ResultType>(state, other_state, kernel_params);
|
||||
|
||||
#endif
|
||||
/// Process the tail
|
||||
for (; prev < off; ++i, ++prev)
|
||||
{
|
||||
@ -466,6 +598,7 @@ private:
|
||||
result_data[row] = Kernel::finalize(state, kernel_params);
|
||||
row++;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -80,3 +80,7 @@ nan
|
||||
5 6 268 2 10.234459893824097 23.15167380558045 536 0.00007815428961455151
|
||||
6 5 268 2 10.234459893824097 23.15167380558045 536 0.00007815428961455151
|
||||
6 6 0 0 0 0 0 0
|
||||
5.8309517
|
||||
0.0003244877
|
||||
5.830951894845301
|
||||
0.0003245172890904424
|
||||
|
@ -12,10 +12,10 @@ SELECT cosineDistance([1, 2, 3], [0, 0, 0]);
|
||||
-- Overflows
|
||||
WITH CAST([-547274980, 1790553898, 1981517754, 1908431500, 1352428565, -573412550, -552499284, 2096941042], 'Array(Int32)') AS a
|
||||
SELECT
|
||||
L1Distance(a,a),
|
||||
L2Distance(a,a),
|
||||
L2SquaredDistance(a,a),
|
||||
LinfDistance(a,a),
|
||||
L1Distance(a, a),
|
||||
L2Distance(a, a),
|
||||
L2SquaredDistance(a, a),
|
||||
LinfDistance(a, a),
|
||||
cosineDistance(a, a);
|
||||
|
||||
DROP TABLE IF EXISTS vec1;
|
||||
@ -88,15 +88,33 @@ SELECT
|
||||
FROM vec2f v1, vec2d v2
|
||||
WHERE length(v1.v) == length(v2.v);
|
||||
|
||||
SELECT L1Distance([0, 0], [1]); -- { serverError 190 }
|
||||
SELECT L2Distance([1, 2], (3,4)); -- { serverError 43 }
|
||||
SELECT L2SquaredDistance([1, 2], (3,4)); -- { serverError 43 }
|
||||
SELECT LpDistance([1, 2], [3,4]); -- { serverError 42 }
|
||||
SELECT LpDistance([1, 2], [3,4], -1.); -- { serverError 69 }
|
||||
SELECT LpDistance([1, 2], [3,4], 'aaa'); -- { serverError 43 }
|
||||
SELECT LpDistance([1, 2], [3,4], materialize(2.7)); -- { serverError 44 }
|
||||
SELECT L1Distance([0, 0], [1]); -- { serverError SIZES_OF_ARRAYS_DONT_MATCH }
|
||||
SELECT L2Distance([1, 2], (3,4)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT L2SquaredDistance([1, 2], (3,4)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT LpDistance([1, 2], [3,4]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
|
||||
SELECT LpDistance([1, 2], [3,4], -1.); -- { serverError ARGUMENT_OUT_OF_BOUND }
|
||||
SELECT LpDistance([1, 2], [3,4], 'aaa'); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT LpDistance([1, 2], [3,4], materialize(2.7)); -- { serverError ILLEGAL_COLUMN }
|
||||
|
||||
DROP TABLE vec1;
|
||||
DROP TABLE vec2;
|
||||
DROP TABLE vec2f;
|
||||
DROP TABLE vec2d;
|
||||
|
||||
-- Queries which trigger manually vectorized implementation
|
||||
|
||||
SELECT L2Distance(
|
||||
[toFloat32(0.0), toFloat32(1.0), toFloat32(2.0), toFloat32(3.0), toFloat32(4.0), toFloat32(5.0), toFloat32(6.0), toFloat32(7.0), toFloat32(8.0), toFloat32(9.0), toFloat32(10.0), toFloat32(11.0), toFloat32(12.0), toFloat32(13.0), toFloat32(14.0), toFloat32(15.0), toFloat32(16.0), toFloat32(17.0), toFloat32(18.0), toFloat32(19.0), toFloat32(20.0), toFloat32(21.0), toFloat32(22.0), toFloat32(23.0), toFloat32(24.0), toFloat32(25.0), toFloat32(26.0), toFloat32(27.0), toFloat32(28.0), toFloat32(29.0), toFloat32(30.0), toFloat32(31.0), toFloat32(32.0), toFloat32(33.0)],
|
||||
materialize([toFloat32(1.0), toFloat32(2.0), toFloat32(3.0), toFloat32(4.0), toFloat32(5.0), toFloat32(6.0), toFloat32(7.0), toFloat32(8.0), toFloat32(9.0), toFloat32(10.0), toFloat32(11.0), toFloat32(12.0), toFloat32(13.0), toFloat32(14.0), toFloat32(15.0), toFloat32(16.0), toFloat32(17.0), toFloat32(18.0), toFloat32(19.0), toFloat32(20.0), toFloat32(21.0), toFloat32(22.0), toFloat32(23.0), toFloat32(24.0), toFloat32(25.0), toFloat32(26.0), toFloat32(27.0), toFloat32(28.0), toFloat32(29.0), toFloat32(30.0), toFloat32(31.0), toFloat32(32.0), toFloat32(33.0), toFloat32(34.0)]));
|
||||
|
||||
SELECT cosineDistance(
|
||||
[toFloat32(0.0), toFloat32(1.0), toFloat32(2.0), toFloat32(3.0), toFloat32(4.0), toFloat32(5.0), toFloat32(6.0), toFloat32(7.0), toFloat32(8.0), toFloat32(9.0), toFloat32(10.0), toFloat32(11.0), toFloat32(12.0), toFloat32(13.0), toFloat32(14.0), toFloat32(15.0), toFloat32(16.0), toFloat32(17.0), toFloat32(18.0), toFloat32(19.0), toFloat32(20.0), toFloat32(21.0), toFloat32(22.0), toFloat32(23.0), toFloat32(24.0), toFloat32(25.0), toFloat32(26.0), toFloat32(27.0), toFloat32(28.0), toFloat32(29.0), toFloat32(30.0), toFloat32(31.0), toFloat32(32.0), toFloat32(33.0)],
|
||||
materialize([toFloat32(1.0), toFloat32(2.0), toFloat32(3.0), toFloat32(4.0), toFloat32(5.0), toFloat32(6.0), toFloat32(7.0), toFloat32(8.0), toFloat32(9.0), toFloat32(10.0), toFloat32(11.0), toFloat32(12.0), toFloat32(13.0), toFloat32(14.0), toFloat32(15.0), toFloat32(16.0), toFloat32(17.0), toFloat32(18.0), toFloat32(19.0), toFloat32(20.0), toFloat32(21.0), toFloat32(22.0), toFloat32(23.0), toFloat32(24.0), toFloat32(25.0), toFloat32(26.0), toFloat32(27.0), toFloat32(28.0), toFloat32(29.0), toFloat32(30.0), toFloat32(31.0), toFloat32(32.0), toFloat32(33.0), toFloat32(34.0)]));
|
||||
|
||||
SELECT L2Distance(
|
||||
[toFloat64(0.0), toFloat64(1.0), toFloat64(2.0), toFloat64(3.0), toFloat64(4.0), toFloat64(5.0), toFloat64(6.0), toFloat64(7.0), toFloat64(8.0), toFloat64(9.0), toFloat64(10.0), toFloat64(11.0), toFloat64(12.0), toFloat64(13.0), toFloat64(14.0), toFloat64(15.0), toFloat64(16.0), toFloat64(17.0), toFloat64(18.0), toFloat64(19.0), toFloat64(20.0), toFloat64(21.0), toFloat64(22.0), toFloat64(23.0), toFloat64(24.0), toFloat64(25.0), toFloat64(26.0), toFloat64(27.0), toFloat64(28.0), toFloat64(29.0), toFloat64(30.0), toFloat64(31.0), toFloat64(32.0), toFloat64(33.0)],
|
||||
materialize([toFloat64(1.0), toFloat64(2.0), toFloat64(3.0), toFloat64(4.0), toFloat64(5.0), toFloat64(6.0), toFloat64(7.0), toFloat64(8.0), toFloat64(9.0), toFloat64(10.0), toFloat64(11.0), toFloat64(12.0), toFloat64(13.0), toFloat64(14.0), toFloat64(15.0), toFloat64(16.0), toFloat64(17.0), toFloat64(18.0), toFloat64(19.0), toFloat64(20.0), toFloat64(21.0), toFloat64(22.0), toFloat64(23.0), toFloat64(24.0), toFloat64(25.0), toFloat64(26.0), toFloat64(27.0), toFloat64(28.0), toFloat64(29.0), toFloat64(30.0), toFloat64(31.0), toFloat64(32.0), toFloat64(33.0), toFloat64(34.0)]));
|
||||
|
||||
SELECT cosineDistance(
|
||||
[toFloat64(0.0), toFloat64(1.0), toFloat64(2.0), toFloat64(3.0), toFloat64(4.0), toFloat64(5.0), toFloat64(6.0), toFloat64(7.0), toFloat64(8.0), toFloat64(9.0), toFloat64(10.0), toFloat64(11.0), toFloat64(12.0), toFloat64(13.0), toFloat64(14.0), toFloat64(15.0), toFloat64(16.0), toFloat64(17.0), toFloat64(18.0), toFloat64(19.0), toFloat64(20.0), toFloat64(21.0), toFloat64(22.0), toFloat64(23.0), toFloat64(24.0), toFloat64(25.0), toFloat64(26.0), toFloat64(27.0), toFloat64(28.0), toFloat64(29.0), toFloat64(30.0), toFloat64(31.0), toFloat64(32.0), toFloat64(33.0)],
|
||||
materialize([toFloat64(1.0), toFloat64(2.0), toFloat64(3.0), toFloat64(4.0), toFloat64(5.0), toFloat64(6.0), toFloat64(7.0), toFloat64(8.0), toFloat64(9.0), toFloat64(10.0), toFloat64(11.0), toFloat64(12.0), toFloat64(13.0), toFloat64(14.0), toFloat64(15.0), toFloat64(16.0), toFloat64(17.0), toFloat64(18.0), toFloat64(19.0), toFloat64(20.0), toFloat64(21.0), toFloat64(22.0), toFloat64(23.0), toFloat64(24.0), toFloat64(25.0), toFloat64(26.0), toFloat64(27.0), toFloat64(28.0), toFloat64(29.0), toFloat64(30.0), toFloat64(31.0), toFloat64(32.0), toFloat64(33.0), toFloat64(34.0)]));
|
||||
|
Loading…
Reference in New Issue
Block a user