mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-09-19 16:20:50 +00:00
Merge 4cfa85c77c
into bb22736bc3
This commit is contained in:
commit
09af676d4f
@ -102,25 +102,28 @@ String toString(TargetArch arch);
|
||||
/// NOLINTNEXTLINE
|
||||
#define USE_MULTITARGET_CODE 1
|
||||
|
||||
#define AVX512VBMI2_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512vl,avx512vbmi,avx512vbmi2")))
|
||||
#define AVX512VBMI_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512vl,avx512vbmi")))
|
||||
#define AVX512BW_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw")))
|
||||
#define AMXBF16_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512dq,avx512vl,avx512bf16,avx512vbmi,avx512vbmi2,amx-bf16")))
|
||||
#define AVX512VBMI2_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512dq,avx512vl,avx512bf16,avx512vbmi,avx512vbmi2")))
|
||||
#define AVX512VBMI_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512dq,avx512bf16,avx512vl,avx512vbmi")))
|
||||
#define AVX512BW_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512dq,avx512bf16")))
|
||||
#define AVX512_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f")))
|
||||
#define AVX2_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,bmi2")))
|
||||
#define AVX_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx"))
|
||||
#define AVX2_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,fma,bmi2")))
|
||||
#define AVX_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx")))
|
||||
#define SSE42_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt")))
|
||||
#define DEFAULT_FUNCTION_SPECIFIC_ATTRIBUTE
|
||||
|
||||
# define BEGIN_AMXBF16_SPECIFIC_CODE \
|
||||
_Pragma("clang attribute push(__attribute__((target(\"sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512dq,avx512bf16,avx512vl,avx512vbmi,avx512vbmi2,amx-bf16\"))),apply_to=function)")
|
||||
# define BEGIN_AVX512VBMI2_SPECIFIC_CODE \
|
||||
_Pragma("clang attribute push(__attribute__((target(\"sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512vl,avx512vbmi,avx512vbmi2\"))),apply_to=function)")
|
||||
_Pragma("clang attribute push(__attribute__((target(\"sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512dq,avx512bf16,avx512vl,avx512vbmi,avx512vbmi2\"))),apply_to=function)")
|
||||
# define BEGIN_AVX512VBMI_SPECIFIC_CODE \
|
||||
_Pragma("clang attribute push(__attribute__((target(\"sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512vl,avx512vbmi\"))),apply_to=function)")
|
||||
_Pragma("clang attribute push(__attribute__((target(\"sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512dq,avx512bf16,avx512vl,avx512vbmi\"))),apply_to=function)")
|
||||
# define BEGIN_AVX512BW_SPECIFIC_CODE \
|
||||
_Pragma("clang attribute push(__attribute__((target(\"sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw\"))),apply_to=function)")
|
||||
_Pragma("clang attribute push(__attribute__((target(\"sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512dq,avx512bf16\"))),apply_to=function)")
|
||||
# define BEGIN_AVX512F_SPECIFIC_CODE \
|
||||
_Pragma("clang attribute push(__attribute__((target(\"sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f\"))),apply_to=function)")
|
||||
# define BEGIN_AVX2_SPECIFIC_CODE \
|
||||
_Pragma("clang attribute push(__attribute__((target(\"sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,bmi2\"))),apply_to=function)")
|
||||
_Pragma("clang attribute push(__attribute__((target(\"sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,fma,bmi2\"))),apply_to=function)")
|
||||
# define BEGIN_AVX_SPECIFIC_CODE \
|
||||
_Pragma("clang attribute push(__attribute__((target(\"sse,sse2,sse3,ssse3,sse4,popcnt,avx\"))),apply_to=function)")
|
||||
# define BEGIN_SSE42_SPECIFIC_CODE \
|
||||
@ -197,6 +200,14 @@ namespace TargetSpecific::AVX512VBMI2 { \
|
||||
} \
|
||||
END_TARGET_SPECIFIC_CODE
|
||||
|
||||
#define DECLARE_AMXBF16_SPECIFIC_CODE(...) \
|
||||
BEGIN_AMXBF16_SPECIFIC_CODE \
|
||||
namespace TargetSpecific::AMXBF16 { \
|
||||
DUMMY_FUNCTION_DEFINITION \
|
||||
using namespace DB::TargetSpecific::AMXBF16; \
|
||||
__VA_ARGS__ \
|
||||
} \
|
||||
END_TARGET_SPECIFIC_CODE
|
||||
|
||||
#else
|
||||
|
||||
@ -211,6 +222,7 @@ END_TARGET_SPECIFIC_CODE
|
||||
#define DECLARE_AVX512BW_SPECIFIC_CODE(...)
|
||||
#define DECLARE_AVX512VBMI_SPECIFIC_CODE(...)
|
||||
#define DECLARE_AVX512VBMI2_SPECIFIC_CODE(...)
|
||||
#define DECLARE_AMXBF16_SPECIFIC_CODE(...)
|
||||
|
||||
#endif
|
||||
|
||||
@ -229,7 +241,8 @@ DECLARE_AVX2_SPECIFIC_CODE (__VA_ARGS__) \
|
||||
DECLARE_AVX512F_SPECIFIC_CODE(__VA_ARGS__) \
|
||||
DECLARE_AVX512BW_SPECIFIC_CODE (__VA_ARGS__) \
|
||||
DECLARE_AVX512VBMI_SPECIFIC_CODE (__VA_ARGS__) \
|
||||
DECLARE_AVX512VBMI2_SPECIFIC_CODE (__VA_ARGS__)
|
||||
DECLARE_AVX512VBMI2_SPECIFIC_CODE (__VA_ARGS__) \
|
||||
DECLARE_AMXBF16_SPECIFIC_CODE (__VA_ARGS__)
|
||||
|
||||
DECLARE_DEFAULT_CODE(
|
||||
constexpr auto BuildArch = TargetArch::Default; /// NOLINT
|
||||
@ -263,6 +276,10 @@ DECLARE_AVX512VBMI2_SPECIFIC_CODE(
|
||||
constexpr auto BuildArch = TargetArch::AVX512VBMI2; /// NOLINT
|
||||
) // DECLARE_AVX512VBMI2_SPECIFIC_CODE
|
||||
|
||||
DECLARE_AMXBF16_SPECIFIC_CODE(
|
||||
constexpr auto BuildArch = TargetArch::AMXBF16; /// NOLINT
|
||||
) // DECLARE_AMXBF16_SPECIFIC_CODE
|
||||
|
||||
/** Runtime Dispatch helpers for class members.
|
||||
*
|
||||
* Example of usage:
|
||||
|
504
src/Functions/PartitionByHyperplanes.cpp
Normal file
504
src/Functions/PartitionByHyperplanes.cpp
Normal file
@ -0,0 +1,504 @@
|
||||
#include <cstddef>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <sys/syscall.h>
|
||||
#include <base/types.h>
|
||||
#include <Columns/ColumnNullable.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include <Functions/IFunction.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Common/FunctionDocumentation.h>
|
||||
#include <Common/logger_useful.h>
|
||||
#include <Columns/ColumnFixedString.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnConst.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <Columns/IColumn.h>
|
||||
#include <Core/TypeId.h>
|
||||
#include <DataTypes/DataTypeFixedString.h>
|
||||
#include <DataTypes/IDataType.h>
|
||||
#include <config.h>
|
||||
|
||||
|
||||
#if defined(__AMX_BF16__)
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
|
||||
namespace DB
|
||||
{
|
||||
/// Error codes used by this translation unit; the values themselves are defined elsewhere in the project.
namespace ErrorCodes
{
    /// Raised when the vectors / normals do not all share one dimension.
    extern const int SIZES_OF_ARRAYS_DONT_MATCH;
    /// Raised when an argument column has an unsupported layout or element type.
    extern const int ILLEGAL_COLUMN;
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
/// Edge length (in elements) of the square tiles the matrix product is split into.
constexpr size_t tile_size = 16;
|
||||
|
||||
DECLARE_DEFAULT_CODE(
|
||||
void doMultiplyTile(
|
||||
const ColumnFloat32 & nested_vectors_data,
|
||||
const ColumnFloat32 & nested_normals_data,
|
||||
ColumnFloat32 & col_res,
|
||||
size_t vector_count,
|
||||
size_t normal_count,
|
||||
size_t dimension,
|
||||
size_t vectors_data_index,
|
||||
size_t normals_data_index,
|
||||
size_t coordinate_index)
|
||||
{
|
||||
for (size_t nt = vectors_data_index; nt < vectors_data_index + tile_size && nt < vector_count; ++nt)
|
||||
for (size_t mt = normals_data_index; mt < normals_data_index + tile_size && mt < normal_count; ++mt)
|
||||
for (size_t kt = coordinate_index; kt < coordinate_index + tile_size && kt < dimension; ++kt)
|
||||
{
|
||||
auto vector_element = nested_vectors_data.getFloat32(nt * dimension + kt);
|
||||
auto normal_element = nested_normals_data.getFloat32(mt * dimension + kt);
|
||||
|
||||
col_res.getElement(nt * normal_count + mt) += vector_element * normal_element;
|
||||
}
|
||||
}
|
||||
) // DECLARE_DEFAULT_CODE
|
||||
|
||||
|
||||
#if defined(OS_LINUX)

DECLARE_AMXBF16_SPECIFIC_CODE(

/// Memory image consumed by LDTILECFG (`_tile_loadconfig`): exactly 64 bytes.
/// All bytes that do not describe a used tile must be zero, so always value-initialise it.
struct TileConfig
{
    uint8_t palette_id;     // must be 1
    uint8_t start_row;      // must be 0
    uint8_t reserved_0[14]; // must be 0
    uint16_t colsb[16];     // bytes per row, per tile (at most 64)
    uint8_t rows[16];       // number of rows, per tile (at most 16)
};

/// Asks the Linux kernel for permission to use the AMX tile data state.
/// The request is issued once; the cached answer is returned afterwards.
/// It has to succeed before the first tile instruction is executed.
bool requestTilePermission()
{
    static const bool granted = []
    {
        constexpr int arch_req_xcomp_perm = 0x1023;
        constexpr int xfeature_xtiledata = 18;
        return syscall(SYS_arch_prctl, arch_req_xcomp_perm, xfeature_xtiledata) == 0;
    }();
    return granted;
}

/// Converts `count` (<= 16) floats starting at `src` into 16 bf16 values at `dst`.
/// Lanes past `count` are not read from memory and are written as zero.
void convertRow(const float * src, size_t count, uint16_t * dst)
{
    const __mmask16 mask = count >= 16 ? static_cast<__mmask16>(0xFFFF) : static_cast<__mmask16>((1u << count) - 1u);
    const __m256bh packed = _mm512_cvtneps_pbh(_mm512_maskz_loadu_ps(mask, src));
    __builtin_memcpy(dst, &packed, sizeof(packed));
}

/** AMX-BF16 kernel for one tile_size x tile_size x tile_size tile of `vectors * transpose(normals)`.
  *
  * The inputs are copied into zero-padded, bf16-converted scratch tiles, so partial tiles at the matrix
  * borders need no special handling and nothing outside the columns is read or written.
  * The product is computed into a local fp32 tile and then merged into `col_res`:
  * the first coordinate tile (coordinate_index == 0) overwrites the cell, later tiles accumulate.
  *
  * If the kernel refuses AMX permission the same tile is computed with scalar code.
  *
  * NOTE(review): the tile configuration is loaded and released on every call, which is simple and correct
  * but not cheap; batching several coordinate tiles per configuration would be a follow-up optimisation.
  * NOTE(review): inputs are rounded to bf16, so results are less precise than in the other kernels.
  */
void doMultiplyTile(
    const ColumnFloat32 & nested_vectors_data,
    const ColumnFloat32 & nested_normals_data,
    ColumnFloat32 & col_res,
    size_t vector_count,
    size_t normal_count,
    size_t dimension,
    size_t vectors_data_index,
    size_t normals_data_index,
    size_t coordinate_index)
{
    static_assert(sizeof(TileConfig) == 64);
    static_assert(tile_size == 16);

    const float * vectors = nested_vectors_data.getData().data();
    const float * normals = nested_normals_data.getData().data();
    float * result = col_res.getData().data();

    /// Extent of the (possibly partial) tile along each axis.
    const size_t rows = vector_count - vectors_data_index < tile_size ? vector_count - vectors_data_index : tile_size;
    const size_t cols = normal_count - normals_data_index < tile_size ? normal_count - normals_data_index : tile_size;
    const size_t depth = dimension - coordinate_index < tile_size ? dimension - coordinate_index : tile_size;

    alignas(64) float c_tile[tile_size * tile_size] = {};

    if (requestTilePermission())
    {
        /// A operand: row r holds 16 bf16 coordinates of vector r (32 bytes per row).
        alignas(64) uint16_t a_tile[tile_size * tile_size] = {};
        /// B operand in the layout TDPBF16PS expects: row p holds, for every normal c,
        /// the coordinate pair (2p, 2p + 1) of that normal (8 rows of 64 bytes).
        alignas(64) uint16_t b_tile[tile_size * tile_size] = {};

        for (size_t r = 0; r < rows; ++r)
            convertRow(vectors + (vectors_data_index + r) * dimension + coordinate_index, depth, a_tile + r * tile_size);

        for (size_t c = 0; c < cols; ++c)
        {
            uint16_t converted[tile_size];
            convertRow(normals + (normals_data_index + c) * dimension + coordinate_index, depth, converted);
            for (size_t pair = 0; pair < tile_size / 2; ++pair)
            {
                b_tile[(pair * tile_size + c) * 2] = converted[2 * pair];
                b_tile[(pair * tile_size + c) * 2 + 1] = converted[2 * pair + 1];
            }
        }

        TileConfig config{};
        config.palette_id = 1;
        config.start_row = 0;
        /// Tile 0: fp32 accumulator, 16 rows of 16 floats.
        config.rows[0] = 16;
        config.colsb[0] = 64;
        /// Tile 1: A operand, 16 rows of 16 bf16.
        config.rows[1] = 16;
        config.colsb[1] = 32;
        /// Tile 2: B operand, 8 rows of 16 bf16 pairs.
        config.rows[2] = 8;
        config.colsb[2] = 64;

        _tile_loadconfig(&config);
        _tile_zero(0);
        _tile_loadd(1, a_tile, 32);
        _tile_loadd(2, b_tile, 64);
        _tile_dpbf16ps(0, 1, 2);
        _tile_stored(0, c_tile, 64);
        _tile_release();
    }
    else
    {
        /// AMX state is not available to this process: compute the same tile with scalar code.
        for (size_t r = 0; r < rows; ++r)
            for (size_t c = 0; c < cols; ++c)
            {
                const float * vector_row = vectors + (vectors_data_index + r) * dimension + coordinate_index;
                const float * normal_row = normals + (normals_data_index + c) * dimension + coordinate_index;
                float partial = 0;
                for (size_t k = 0; k < depth; ++k)
                    partial += vector_row[k] * normal_row[k];
                c_tile[r * tile_size + c] = partial;
            }
    }

    const bool first_coordinate_tile = coordinate_index == 0;
    for (size_t r = 0; r < rows; ++r)
        for (size_t c = 0; c < cols; ++c)
        {
            float & cell = result[(vectors_data_index + r) * normal_count + normals_data_index + c];
            cell = first_coordinate_tile ? c_tile[r * tile_size + c] : cell + c_tile[r * tile_size + c];
        }
}
) // DECLARE_AMXBF16_SPECIFIC_CODE

#endif
|
||||
|
||||
DECLARE_AVX2_SPECIFIC_CODE(
|
||||
void doMultiplyTile(
|
||||
const ColumnFloat32 & nested_vectors_data,
|
||||
const ColumnFloat32 & nested_normals_data,
|
||||
ColumnFloat32 & col_res,
|
||||
size_t vector_count,
|
||||
size_t normal_count,
|
||||
size_t dimension,
|
||||
size_t vectors_data_index,
|
||||
size_t normals_data_index,
|
||||
size_t coordinate_index)
|
||||
{
|
||||
for (size_t nt = vectors_data_index; nt < vectors_data_index + tile_size && nt < vector_count; ++nt)
|
||||
for (size_t mt = normals_data_index; mt < normals_data_index + tile_size && mt < normal_count; mt += 8)
|
||||
{
|
||||
__m256 sum = _mm256_setzero_ps();
|
||||
for (size_t kt = coordinate_index; kt < coordinate_index + tile_size && kt < dimension; ++kt)
|
||||
{
|
||||
__m256 va = _mm256_broadcast_ss(&nested_vectors_data.getData()[nt * dimension + kt]);
|
||||
__m256 vb = _mm256_loadu_ps(&nested_normals_data.getData()[mt * dimension + kt]);
|
||||
sum = _mm256_fmadd_ps(va, vb, sum);
|
||||
}
|
||||
_mm256_storeu_ps(&col_res.getData()[nt * normal_count + mt], sum);
|
||||
}
|
||||
}
|
||||
|
||||
) // DECLARE_AVX2_SPECIFIC_CODE
|
||||
|
||||
DECLARE_AVX512BW_SPECIFIC_CODE(
|
||||
void doMultiplyTile(
|
||||
const ColumnFloat32 & nested_vectors_data,
|
||||
const ColumnFloat32 & nested_normals_data,
|
||||
ColumnFloat32 & col_res,
|
||||
size_t vector_count,
|
||||
size_t normal_count,
|
||||
size_t dimension,
|
||||
size_t vectors_data_index,
|
||||
size_t normals_data_index,
|
||||
size_t coordinate_index)
|
||||
{
|
||||
for (size_t nt = vectors_data_index; nt < vectors_data_index + tile_size && nt < vector_count; ++nt)
|
||||
for (size_t mt = normals_data_index; mt < normals_data_index + tile_size && mt < normal_count; mt += 16)
|
||||
{
|
||||
__m512 sum = _mm512_setzero_ps();
|
||||
for (size_t kt = coordinate_index; kt < coordinate_index + tile_size && kt < dimension; ++kt)
|
||||
{
|
||||
__m256 va256 = _mm256_broadcast_ss(&nested_vectors_data.getData()[nt * dimension + kt]);
|
||||
__m512 va = _mm512_broadcast_f32x8(va256);
|
||||
__m512 vb = _mm512_loadu_ps(&nested_normals_data.getData()[mt * dimension + kt]);
|
||||
sum = _mm512_fmadd_ps(va, vb, sum);
|
||||
}
|
||||
_mm512_storeu_ps(&col_res.getData()[nt * normal_count + mt], sum);
|
||||
}
|
||||
}
|
||||
|
||||
) // DECLARE_AVX512_SPECIFIC_CODE
|
||||
|
||||
class FunctionPartitionByHyperplanes : public IFunction
|
||||
{
|
||||
private:
|
||||
static void multiplyTile(
|
||||
const ColumnFloat32 & nested_vectors_data,
|
||||
const ColumnFloat32 & nested_normals_data,
|
||||
ColumnFloat32 & col_res,
|
||||
size_t vector_count,
|
||||
size_t normal_count,
|
||||
size_t dimension,
|
||||
size_t vectors_data_index,
|
||||
size_t normals_data_index,
|
||||
size_t coordinate_index)
|
||||
{
|
||||
#if USE_MULTITARGET_CODE
|
||||
|
||||
#if defined(OS_LINUX)
|
||||
if (isArchSupported(TargetArch::AMXBF16))
|
||||
{
|
||||
TargetSpecific::AMXBF16::doMultiplyTile(
|
||||
nested_vectors_data,
|
||||
nested_normals_data,
|
||||
col_res,
|
||||
vector_count,
|
||||
normal_count,
|
||||
dimension,
|
||||
vectors_data_index,
|
||||
normals_data_index,
|
||||
coordinate_index);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (isArchSupported(TargetArch::AVX512BW))
|
||||
{
|
||||
TargetSpecific::AVX512BW::doMultiplyTile(
|
||||
nested_vectors_data,
|
||||
nested_normals_data,
|
||||
col_res,
|
||||
vector_count,
|
||||
normal_count,
|
||||
dimension,
|
||||
vectors_data_index,
|
||||
normals_data_index,
|
||||
coordinate_index);
|
||||
return;
|
||||
}
|
||||
|
||||
if (isArchSupported(TargetArch::AVX2))
|
||||
{
|
||||
TargetSpecific::AVX2::doMultiplyTile(
|
||||
nested_vectors_data,
|
||||
nested_normals_data,
|
||||
col_res,
|
||||
vector_count,
|
||||
normal_count,
|
||||
dimension,
|
||||
vectors_data_index,
|
||||
normals_data_index,
|
||||
coordinate_index);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
TargetSpecific::Default::doMultiplyTile(
|
||||
nested_vectors_data,
|
||||
nested_normals_data,
|
||||
col_res,
|
||||
vector_count,
|
||||
normal_count,
|
||||
dimension,
|
||||
vectors_data_index,
|
||||
normals_data_index,
|
||||
coordinate_index);
|
||||
}
|
||||
|
||||
/*
|
||||
The implementation uses tiled matrix multiplication.
|
||||
A tile is a part of a matrix of size tile_size x tile_size.
|
||||
The algorithm iterates over the tiles of the first and second matrix.
|
||||
Classical matrix multiplication is used inside each tile, but since the size of the matrices inside the tile is small,
|
||||
both tiles can fit entirely in the cache, which significantly reduces the number of cache misses.
|
||||
*/
|
||||
static void executeInternal(
|
||||
const ColumnFloat32 & nested_vectors_data,
|
||||
const ColumnFloat32 & nested_normals_data,
|
||||
size_t vector_count,
|
||||
size_t normal_count,
|
||||
size_t dimension,
|
||||
ColumnFloat32 & col_res)
|
||||
{
|
||||
for (size_t i = 0; i < vector_count; i += tile_size)
|
||||
{
|
||||
for (size_t j = 0; j < normal_count; j += tile_size)
|
||||
{
|
||||
for (size_t k = 0; k < dimension; k += tile_size)
|
||||
multiplyTile(
|
||||
nested_vectors_data,
|
||||
nested_normals_data,
|
||||
col_res,
|
||||
vector_count,
|
||||
normal_count,
|
||||
dimension,
|
||||
i,
|
||||
j,
|
||||
k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void checkDimension(const ColumnArray::Offsets & offsets, size_t dimension)
|
||||
{
|
||||
size_t prev_offset = 0;
|
||||
for (const auto offset : offsets)
|
||||
{
|
||||
if (offset - prev_offset != dimension)
|
||||
throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "All vectors must have equal size");
|
||||
prev_offset = offset;
|
||||
}
|
||||
}
|
||||
|
||||
ColumnPtr createFixedStringResult(
|
||||
const IColumn * nested_offsets_data,
|
||||
const ColumnFloat32::MutablePtr & col_res,
|
||||
size_t vector_count,
|
||||
size_t normal_count) const
|
||||
{
|
||||
auto res = ColumnFixedString::create((normal_count / 8) + (normal_count % 8 != 0));
|
||||
auto& res_chars = res->getChars();
|
||||
res_chars.reserve(vector_count * normal_count / 8);
|
||||
char temp = 0;
|
||||
for (size_t i = 0; i < vector_count; ++i)
|
||||
{
|
||||
for (size_t j = 0; j < normal_count; ++j)
|
||||
{
|
||||
temp = temp << 1;
|
||||
auto res_val = col_res->getData()[i * normal_count + j];
|
||||
auto offset = nested_offsets_data->getFloat32(j);
|
||||
if (res_val > offset)
|
||||
temp |= 1;
|
||||
if ((j + 1) % 8 == 0)
|
||||
{
|
||||
res_chars.push_back(temp);
|
||||
temp = 0;
|
||||
}
|
||||
}
|
||||
if (normal_count % 8 != 0)
|
||||
{
|
||||
res_chars.push_back(temp);
|
||||
temp = 0;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr auto name = "partitionByHyperplanes";
|
||||
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionPartitionByHyperplanes>(); }
|
||||
|
||||
String getName() const override { return name; }
|
||||
|
||||
bool isVariadic() const override { return true; }
|
||||
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
|
||||
|
||||
size_t getNumberOfArguments() const override { return 0; }
|
||||
bool useDefaultImplementationForConstants() const override { return true; }
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
|
||||
{
|
||||
FunctionArgumentDescriptors mandatory_args{
|
||||
{"vectors", static_cast<FunctionArgumentDescriptor::TypeValidator>(&isArray), nullptr, "Array"},
|
||||
{"normals", static_cast<FunctionArgumentDescriptor::TypeValidator>(&isArray), isColumnConst, "const Array"},
|
||||
};
|
||||
|
||||
FunctionArgumentDescriptors optional_args{
|
||||
{"offsets", static_cast<FunctionArgumentDescriptor::TypeValidator>(&isArray), isColumnConst, "const Array"},
|
||||
};
|
||||
|
||||
validateFunctionArgumentTypes(*this, arguments, mandatory_args, optional_args);
|
||||
|
||||
return std::make_shared<DataTypeFixedString>(1);
|
||||
}
|
||||
|
||||
const ColumnFloat32 & getVectorData(const ColumnArray * vectors, const ColumnArray::Offsets ** offsets) const
|
||||
{
|
||||
if (vectors->getData().getDataType() == TypeIndex::Array)
|
||||
{
|
||||
const auto & vectors_data = typeid_cast<const ColumnArray &>(vectors->getData());
|
||||
*offsets = &vectors_data.getOffsets();
|
||||
if (vectors_data.getData().getDataType() != TypeIndex::Float32)
|
||||
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Arguments of function {} must be Array of Float32", getName());
|
||||
return typeid_cast<const ColumnFloat32 &>(vectors_data.getData());
|
||||
}
|
||||
return typeid_cast<const ColumnFloat32 &>(vectors->getData());
|
||||
}
|
||||
|
||||
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & /*result_type*/, size_t /*input_rows_count*/) const override
|
||||
{
|
||||
const ColumnArray * vectors = typeid_cast<const ColumnArray *>(arguments[0].column.get());
|
||||
if (!vectors)
|
||||
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of function {} must be Array", getName());
|
||||
|
||||
const ColumnArray * normals = typeid_cast<const ColumnArray *>(arguments[1].column.get());
|
||||
if (!normals)
|
||||
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Second argument of function {} must be Array", getName());
|
||||
|
||||
if (vectors->getData().getDataType() == TypeIndex::Nothing || normals->getData().getDataType() == TypeIndex::Nothing)
|
||||
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Arguments of function {} must be Array", getName());
|
||||
|
||||
const auto * vector_offsets = &vectors->getOffsets();
|
||||
const auto & nested_vectors_data = getVectorData(vectors, &vector_offsets);
|
||||
|
||||
const auto * normal_offsets = &normals->getOffsets();
|
||||
const auto & nested_normals_data = getVectorData(normals, &normal_offsets);
|
||||
|
||||
const size_t dimension = vectors->getOffsets().front();
|
||||
const size_t vector_count = vectors->getOffsets().size();
|
||||
const size_t normal_count = normal_offsets->size();
|
||||
|
||||
if (vector_count == 0 || normal_count == 0)
|
||||
return ColumnFixedString::create(0);
|
||||
|
||||
auto offsets = ColumnConst::create(ColumnFloat32::create(1, 0), normal_count);
|
||||
const IColumn * nested_offsets_data = offsets.get();
|
||||
if (arguments.size() >= 3)
|
||||
{
|
||||
const ColumnConst * offsets_const = typeid_cast<const ColumnConst *>(arguments[2].column.get());
|
||||
if (!offsets_const)
|
||||
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Third argument of function {} must be Array", getName());
|
||||
nested_offsets_data = offsets_const;
|
||||
}
|
||||
|
||||
checkDimension(*vector_offsets, dimension);
|
||||
checkDimension(*normal_offsets, dimension);
|
||||
|
||||
auto col_res = ColumnFloat32::create(vector_count * normal_count);
|
||||
executeInternal(
|
||||
nested_vectors_data,
|
||||
nested_normals_data,
|
||||
vector_count,
|
||||
normal_count,
|
||||
dimension,
|
||||
*col_res);
|
||||
|
||||
return createFixedStringResult(nested_offsets_data, col_res, vector_count, normal_count);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
/// Registers partitionByHyperplanes in the function factory together with its user-facing documentation.
/// The example uses explicit Float32 casts because the implementation only accepts Float32 elements.
REGISTER_FUNCTION(partitionByHyperplanes)
{
    factory.registerFunction<FunctionPartitionByHyperplanes>(FunctionDocumentation{
        .description=R"(
This function partitions a given point in vector space based on its position relative to one or more defined hyperplanes.
A hyperplane in an N-dimensional space is defined by a normal vector and an optional offset from the origin.
The function takes a vector representing a point and arrays of normal vectors (and optional offsets) representing the hyperplanes.
It returns a FixedString, where each bit indicates the side of the corresponding hyperplane on which the point lies, with '0' indicating one side and '1' the other.
This method allows for efficient classification of points into multiple regions defined by these hyperplanes, suitable for complex geometric, spatial analysis, and machine learning applications.
Ensure all vectors and normals are in the same dimensional space for accurate results.
)",
        .examples={{"partitionByHyperplanes", "SELECT hex(partitionByHyperplanes([2.0, 3.0]::Array(Float32), [[1.0, -1.0], [-1.0, 2.0]]::Array(Array(Float32))))", "01"}},
        .categories{"Geometric", "Classification"}
    });
}
|
||||
|
||||
}
|
@ -0,0 +1,6 @@
|
||||
01
|
||||
00
|
||||
03
|
||||
01
|
||||
00
|
||||
1F
|
10
tests/queries/0_stateless/03152_partition_by_hyperplanes.sql
Normal file
10
tests/queries/0_stateless/03152_partition_by_hyperplanes.sql
Normal file
@ -0,0 +1,10 @@
|
||||
select hex((SELECT partitionByHyperplanes([toFloat32(2.0), toFloat32(3.0)], [toFloat32(1.0), toFloat32(1.0)])));
|
||||
select hex((SELECT partitionByHyperplanes([toFloat32(2.0), toFloat32(3.0)], [toFloat32(1.0), toFloat32(-1.0)])));
|
||||
SELECT hex((select partitionByHyperplanes([toFloat32(2.0), toFloat32(3.0)], [[toFloat32(1.0), toFloat32(1.0)], [toFloat32(1.0), toFloat32(2.0)]])));
|
||||
select hex((SELECT partitionByHyperplanes((select range(1024) :: Array(Float32)), (select range(1024) :: Array(Float32)))));
|
||||
select hex((SELECT partitionByHyperplanes((select range(1024) :: Array(Float32)), (select range(-1024, 0) :: Array(Float32)))));
|
||||
select hex((SELECT partitionByHyperplanes((select range(1024) :: Array(Float32)), (select arrayWithConstant(5, (select range(1024) :: Array(Float32)))))));
|
||||
|
||||
SELECT hex((select partitionByHyperplanes([toFloat32(2.0), toFloat32(3.0)], [[toFloat32(1.0), toFloat32(1.0)], [toFloat32(1.0)]]))); -- { serverError 190 }
|
||||
SELECT hex((select partitionByHyperplanes([toFloat32(2.0), toFloat32(3.0)], []))); -- { serverError 44 }
|
||||
SELECT hex((select partitionByHyperplanes([toFloat32(2.0), toFloat32(3.0)], [[[toFloat32(2.0), toFloat32(3.0)]]]))); -- { serverError 44 }
|
Loading…
Reference in New Issue
Block a user