Improved vectorized execution of main loop for array norm/distance

This commit is contained in:
Alexander Gololobov 2022-07-02 22:34:06 +02:00
parent 92cbc2a3b5
commit c6691cc5f2
3 changed files with 107 additions and 11 deletions

View File

@ -38,6 +38,12 @@ struct L1Distance
state.sum += fabs(x - y);
}
template <typename ResultType>
static void combine(State<ResultType> & state, const State<ResultType> & other_state, const ConstParams &)
{
state.sum += other_state.sum;
}
template <typename ResultType>
static ResultType finalize(const State<ResultType> & state, const ConstParams &)
{
@ -63,6 +69,12 @@ struct L2Distance
state.sum += (x - y) * (x - y);
}
template <typename ResultType>
static void combine(State<ResultType> & state, const State<ResultType> & other_state, const ConstParams &)
{
state.sum += other_state.sum;
}
template <typename ResultType>
static ResultType finalize(const State<ResultType> & state, const ConstParams &)
{
@ -103,6 +115,12 @@ struct LpDistance
state.sum += std::pow(fabs(x - y), params.power);
}
template <typename ResultType>
static void combine(State<ResultType> & state, const State<ResultType> & other_state, const ConstParams &)
{
state.sum += other_state.sum;
}
template <typename ResultType>
static ResultType finalize(const State<ResultType> & state, const ConstParams & params)
{
@ -128,6 +146,12 @@ struct LinfDistance
state.dist = fmax(state.dist, fabs(x - y));
}
template <typename ResultType>
static void combine(State<ResultType> & state, const State<ResultType> & other_state, const ConstParams &)
{
state.dist = fmax(state.dist, other_state.dist);
}
template <typename ResultType>
static ResultType finalize(const State<ResultType> & state, const ConstParams &)
{
@ -157,6 +181,14 @@ struct CosineDistance
state.y_squared += y * y;
}
template <typename ResultType>
static void combine(State<ResultType> & state, const State<ResultType> & other_state, const ConstParams &)
{
state.dot_prod += other_state.dot_prod;
state.x_squared += other_state.x_squared;
state.y_squared += other_state.y_squared;
}
template <typename ResultType>
static ResultType finalize(const State<ResultType> & state, const ConstParams &)
{
@ -339,10 +371,23 @@ private:
size_t row = 0;
for (auto off : offsets_x)
{
typename Kernel::template State<Float64> state;
/// Process chunks in vectorized manner
static constexpr size_t VEC_SIZE = 4;
typename Kernel::template State<ResultType> states[VEC_SIZE];
for (; prev + VEC_SIZE < off; prev += VEC_SIZE)
{
for (size_t s = 0; s < VEC_SIZE; ++s)
Kernel::template accumulate<ResultType>(states[s], data_x[prev+s], data_y[prev+s], kernel_params);
}
typename Kernel::template State<ResultType> state;
for (const auto & other_state : states)
Kernel::template combine<ResultType>(state, other_state, kernel_params);
/// Process the tail
for (; prev < off; ++prev)
{
Kernel::template accumulate<Float64>(state, data_x[prev], data_y[prev], kernel_params);
Kernel::template accumulate<ResultType>(state, data_x[prev], data_y[prev], kernel_params);
}
result_data[row] = Kernel::finalize(state, kernel_params);
row++;
@ -392,10 +437,24 @@ private:
size_t row = 0;
for (auto off : offsets_y)
{
typename Kernel::template State<Float64> state;
for (size_t i = 0; prev < off; ++i, ++prev)
/// Process chunks in vectorized manner
static constexpr size_t VEC_SIZE = 4;
typename Kernel::template State<ResultType> states[VEC_SIZE];
size_t i = 0;
for (; prev + VEC_SIZE < off; i += VEC_SIZE, prev += VEC_SIZE)
{
Kernel::template accumulate<Float64>(state, data_x[i], data_y[prev], kernel_params);
for (size_t s = 0; s < VEC_SIZE; ++s)
Kernel::template accumulate<ResultType>(states[s], data_x[i+s], data_y[prev+s], kernel_params);
}
typename Kernel::template State<ResultType> state;
for (const auto & other_state : states)
Kernel::template combine<ResultType>(state, other_state, kernel_params);
/// Process the tail
for (; prev < off; ++i, ++prev)
{
Kernel::template accumulate<ResultType>(state, data_x[i], data_y[prev], kernel_params);
}
result_data[row] = Kernel::finalize(state, kernel_params);
row++;

View File

@ -31,6 +31,12 @@ struct L1Norm
return result + fabs(value);
}
template <typename ResultType>
inline static ResultType combine(ResultType result, ResultType other_result, const ConstParams &)
{
return result + other_result;
}
template <typename ResultType>
inline static ResultType finalize(ResultType result, const ConstParams &)
{
@ -50,6 +56,12 @@ struct L2Norm
return result + value * value;
}
template <typename ResultType>
inline static ResultType combine(ResultType result, ResultType other_result, const ConstParams &)
{
return result + other_result;
}
template <typename ResultType>
inline static ResultType finalize(ResultType result, const ConstParams &)
{
@ -85,6 +97,12 @@ struct LpNorm
return result + std::pow(fabs(value), params.power);
}
template <typename ResultType>
inline static ResultType combine(ResultType result, ResultType other_result, const ConstParams &)
{
return result + other_result;
}
template <typename ResultType>
inline static ResultType finalize(ResultType result, const ConstParams & params)
{
@ -104,6 +122,12 @@ struct LinfNorm
return fmax(result, fabs(value));
}
template <typename ResultType>
inline static ResultType combine(ResultType result, ResultType other_result, const ConstParams &)
{
return fmax(result, other_result);
}
template <typename ResultType>
inline static ResultType finalize(ResultType result, const ConstParams &)
{
@ -221,10 +245,23 @@ private:
size_t row = 0;
for (auto off : offsets)
{
Float64 result = 0;
/// Process chunks in vectorized manner
static constexpr size_t VEC_SIZE = 4;
ResultType results[VEC_SIZE] = {0};
for (; prev + VEC_SIZE < off; prev += VEC_SIZE)
{
for (size_t s = 0; s < VEC_SIZE; ++s)
results[s] = Kernel::template accumulate<ResultType>(results[s], data[prev+s], kernel_params);
}
ResultType result = 0;
for (const auto & other_state : results)
result = Kernel::template combine<ResultType>(result, other_state, kernel_params);
/// Process the tail
for (; prev < off; ++prev)
{
result = Kernel::template accumulate<Float64>(result, data[prev], kernel_params);
result = Kernel::template accumulate<ResultType>(result, data[prev], kernel_params);
}
result_data[row] = Kernel::finalize(result, kernel_params);
row++;

View File

@ -37,12 +37,12 @@ nan
2 1 2031 788 981.3289733414064 1182.129011571918 1397429 0.1939823640079572
2 2 0 0 0 0 0 0
3 3 0 0 0 0 0 0
3 4 68 2 6.238144819822315 11.661903789690601 136 0.0010041996325123037
4 3 68 2 6.238144819822315 11.661903789690601 136 0.0010041996325123037
3 4 68 2 6.238144819822316 11.661903789690601 136 0.0010041996325123037
4 3 68 2 6.238144819822316 11.661903789690601 136 0.0010041996325123037
4 4 0 0 0 0 0 0
5 5 0 0 0 0 0 0
5 6 268 2 9.70940985211152 23.15167380558045 536 0.00007815428961455151
6 5 268 2 9.70940985211152 23.15167380558045 536 0.00007815428961455151
5 6 268 2 9.70940985211151 23.15167380558045 536 0.00007815428961455151
6 5 268 2 9.70940985211151 23.15167380558045 536 0.00007815428961455151
6 6 0 0 0 0 0 0
1 1 0 0 0 0 0 0
1 2 2031 788 992.2102104083964 1182.129011571918 1397429 0.1939823640079572