From 99b89999aa93eed375a95fa35204d61917b856dc Mon Sep 17 00:00:00 2001 From: Artem Mustafin Date: Sun, 18 Feb 2024 23:07:39 +0000 Subject: [PATCH 01/14] initial file of hilbertEncode + separate common functions code --- .../FunctionSpaceFillingCurveEncode.h | 68 +++++++++++++ src/Functions/hilbertEncode.cpp | 96 +++++++++++++++++++ src/Functions/mortonEncode.cpp | 55 +---------- 3 files changed, 166 insertions(+), 53 deletions(-) create mode 100644 src/Functions/FunctionSpaceFillingCurveEncode.h create mode 100644 src/Functions/hilbertEncode.cpp diff --git a/src/Functions/FunctionSpaceFillingCurveEncode.h b/src/Functions/FunctionSpaceFillingCurveEncode.h new file mode 100644 index 00000000000..257b49176bc --- /dev/null +++ b/src/Functions/FunctionSpaceFillingCurveEncode.h @@ -0,0 +1,68 @@ +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int ARGUMENT_OUT_OF_BOUND; + extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION; +} + +class FunctionSpaceFillingCurveEncode: public IFunction { +public: + bool isVariadic() const override + { + return true; + } + + size_t getNumberOfArguments() const override + { + return 0; + } + + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; } + + bool useDefaultImplementationForConstants() const override { return true; } + + DataTypePtr getReturnTypeImpl(const DB::DataTypes & arguments) const override + { + size_t vector_start_index = 0; + if (arguments.empty()) + throw Exception(ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION, + "At least one UInt argument is required for function {}", + getName()); + if (WhichDataType(arguments[0]).isTuple()) + { + vector_start_index = 1; + const auto * type_tuple = typeid_cast(arguments[0].get()); + auto tuple_size = type_tuple->getElements().size(); + if (tuple_size != (arguments.size() - 1)) + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, + "Illegal argument {} for function {}, tuple size should be equal to number of UInt arguments", + arguments[0]->getName(), getName()); + for (size_t i = 0; i < tuple_size; i++) + { + if (!WhichDataType(type_tuple->getElement(i)).isNativeUInt()) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal type {} of argument in tuple for function {}, should be a native UInt", + type_tuple->getElement(i)->getName(), getName()); + } + } + + for (size_t i = vector_start_index; i < arguments.size(); i++) + { + const auto & arg = arguments[i]; + if (!WhichDataType(arg).isNativeUInt()) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal type {} of argument of function {}, should be a native UInt", + arg->getName(), getName()); + } + return std::make_shared(); + } +}; + +} diff --git a/src/Functions/hilbertEncode.cpp b/src/Functions/hilbertEncode.cpp new file mode 100644 index 00000000000..a9b137df86d --- /dev/null +++ b/src/Functions/hilbertEncode.cpp @@ -0,0 +1,96 @@ +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int ARGUMENT_OUT_OF_BOUND; + extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION; +} + + +class FunctionHilbertEncode : public FunctionSpaceFillingCurveEncode +{ +public: + static constexpr auto name = "hilbertEncode"; + static FunctionPtr create(ContextPtr) + { + return std::make_shared(); + } + + String getName() const override { return name; } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override + { + size_t num_dimensions = arguments.size(); + if (num_dimensions < 1 || num_dimensions > 2) { + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal number of UInt arguments of function {}: should be at least 1 and not more than 2", + getName()); + } + + size_t vector_start_index = 0; + const auto * const_col = typeid_cast(arguments[0].column.get()); + const ColumnTuple * mask; + if (const_col) + mask = typeid_cast(const_col->getDataColumnPtr().get()); + else + mask = typeid_cast(arguments[0].column.get()); + if (mask) + { + num_dimensions = mask->tupleSize(); + vector_start_index = 1; + for (size_t i = 0; i < num_dimensions; i++) + { + auto ratio = mask->getColumn(i).getUInt(0); + if (ratio > 8 || ratio < 1) + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, + "Illegal argument {} of function {}, should be a number in range 1-8", + arguments[0].column->getName(), getName()); + } + } + + auto non_const_arguments = arguments; + for (auto & argument : non_const_arguments) + argument.column = argument.column->convertToFullColumnIfConst(); + + auto col_res = ColumnUInt64::create(); + ColumnUInt64::Container & vec_res = col_res->getData(); + vec_res.resize(input_rows_count); + + const ColumnPtr & col0 = non_const_arguments[0 + vector_start_index].column; + if (num_dimensions == 1) { + for (size_t i = 0; i < input_rows_count; i++) + { + vec_res[i] = col0->getUInt(i); + } + return col_res; + } + + return nullptr; + } +}; + + +REGISTER_FUNCTION(HilbertEncode) +{ + factory.registerFunction(FunctionDocumentation{ + .description=R"( + +)", + .examples{ + }, + .categories {} + }); +} + +} diff --git a/src/Functions/mortonEncode.cpp b/src/Functions/mortonEncode.cpp index fee14c7784b..5365e3d1cca 100644 --- a/src/Functions/mortonEncode.cpp +++ b/src/Functions/mortonEncode.cpp @@ -1,10 +1,9 @@ #include #include -#include -#include #include #include #include +#include #include #include @@ -144,7 +143,7 @@ constexpr auto MortonND_5D_Enc = mortonnd::MortonNDLutEncoder<5, 12, 8>(); constexpr auto MortonND_6D_Enc = mortonnd::MortonNDLutEncoder<6, 10, 8>(); constexpr auto MortonND_7D_Enc = mortonnd::MortonNDLutEncoder<7, 9, 8>(); constexpr auto MortonND_8D_Enc = mortonnd::MortonNDLutEncoder<8, 8, 8>(); -class FunctionMortonEncode : public IFunction +class FunctionMortonEncode : public FunctionSpaceFillingCurveEncode { public: static constexpr auto name = "mortonEncode"; @@ -158,56 +157,6 @@ public: return name; } - bool isVariadic() const override - { - return true; - } - - size_t getNumberOfArguments() const override - { - return 0; - } - - bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; } - - bool useDefaultImplementationForConstants() const override { return true; } - - DataTypePtr getReturnTypeImpl(const DB::DataTypes & arguments) const override - { - size_t vectorStartIndex = 0; - if (arguments.empty()) - throw Exception(ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION, - "At least one UInt argument is required for function {}", - getName()); - if (WhichDataType(arguments[0]).isTuple()) - { - vectorStartIndex = 1; - const auto * type_tuple = typeid_cast(arguments[0].get()); - auto tuple_size = type_tuple->getElements().size(); - if (tuple_size != (arguments.size() - 1)) - throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, - "Illegal argument {} for function {}, tuple size should be equal to number of UInt arguments", - arguments[0]->getName(), getName()); - for (size_t i = 0; i < tuple_size; i++) - { - if (!WhichDataType(type_tuple->getElement(i)).isNativeUInt()) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "Illegal type {} of argument in tuple for function {}, should be a native UInt", - type_tuple->getElement(i)->getName(), getName()); - } - } - - for (size_t i = vectorStartIndex; i < arguments.size(); i++) - { - const auto & arg = arguments[i]; - if (!WhichDataType(arg).isNativeUInt()) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "Illegal type {} of argument of function {}, should be a native UInt", - arg->getName(), getName()); - } - return std::make_shared(); - } - static UInt64 expand(UInt64 ratio, UInt64 value) { switch (ratio) // NOLINT(bugprone-switch-missing-default-case) From c13dd9dc8c5c940d03e6c9dd8d98ea363f332c86 Mon Sep 17 00:00:00 2001 From: Artem Mustafin Date: Mon, 19 Feb 2024 20:21:52 +0000 Subject: [PATCH 02/14] hilbert encode function added --- src/Functions/hilbertEncode.cpp | 86 ++++++++++++++++++++++++++++----- 1 file changed, 75 insertions(+), 11 deletions(-) diff --git a/src/Functions/hilbertEncode.cpp b/src/Functions/hilbertEncode.cpp index a9b137df86d..2bcb46c79a3 100644 --- a/src/Functions/hilbertEncode.cpp +++ b/src/Functions/hilbertEncode.cpp @@ -1,21 +1,80 @@ -#include -#include -#include #include #include +#include +#include +#include #include +#include #include +#include namespace DB { -namespace ErrorCodes -{ - extern const int ILLEGAL_TYPE_OF_ARGUMENT; - extern const int ARGUMENT_OUT_OF_BOUND; - extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION; -} +class FunctionHilbertEncode2DWIthLookupTableImpl { +public: + static UInt64 encode(UInt64 x, UInt64 y) { + const auto leading_zeros_count = getLeadingZeroBits(x | y); + const auto used_bits = std::numeric_limits::digits - leading_zeros_count; + + UInt8 remaind_shift = BIT_STEP - used_bits % BIT_STEP; + if (remaind_shift == BIT_STEP) + remaind_shift = 0; + x <<= remaind_shift; + y <<= remaind_shift; + + UInt8 current_state = 0; + UInt64 hilbert_code = 0; + Int8 current_shift = used_bits + remaind_shift - BIT_STEP; + + while (current_shift > 0) + { + const UInt8 x_bits = (x >> current_shift) & STEP_MASK; + const UInt8 y_bits = (y >> current_shift) & STEP_MASK; + const auto hilbert_bits = getCodeAndUpdateState(x_bits, y_bits, current_state); + const UInt8 hilbert_code_shift = static_cast(current_shift) << 1; + hilbert_code |= (hilbert_bits << hilbert_code_shift); + + current_shift -= BIT_STEP; + } + + hilbert_code >>= (remaind_shift << 1); + return hilbert_code; + } + +private: + + // LOOKUP_TABLE[SSXXXYYY] = SSHHHHHH] + // where SS - 2 bits for state, XXX - 3 bits of x, YYY - 3 bits of y + static UInt8 getCodeAndUpdateState(UInt8 x_bits, UInt8 y_bits, UInt8& state) { + const UInt8 table_index = state | (x_bits << BIT_STEP) | y_bits; + const auto table_code = LOOKUP_TABLE[table_index]; + state = table_code & STATE_MASK; + return table_code & HILBERT_MASK; + } + + constexpr static UInt8 BIT_STEP = 3; + constexpr static UInt8 STEP_MASK = (1 << BIT_STEP) - 1; + constexpr static UInt8 HILBERT_MASK = (1 << (BIT_STEP << 1)) - 1; + constexpr static UInt8 STATE_MASK = static_cast(-1) - HILBERT_MASK; + + constexpr static UInt8 LOOKUP_TABLE[256] = { + 64, 1, 206, 79, 16, 211, 84, 21, 131, 2, 205, 140, 81, 82, 151, 22, 4, 199, 8, 203, 158, + 157, 88, 25, 69, 70, 73, 74, 31, 220, 155, 26, 186, 185, 182, 181, 32, 227, 100, 37, 59, + 248, 55, 244, 97, 98, 167, 38, 124, 61, 242, 115, 174, 173, 104, 41, 191, 62, 241, 176, 47, + 236, 171, 42, 0, 195, 68, 5, 250, 123, 60, 255, 65, 66, 135, 6, 249, 184, 125, 126, 142, + 141, 72, 9, 246, 119, 178, 177, 15, 204, 139, 10, 245, 180, 51, 240, 80, 17, 222, 95, 96, + 33, 238, 111, 147, 18, 221, 156, 163, 34, 237, 172, 20, 215, 24, 219, 36, 231, 40, 235, 85, + 86, 89, 90, 101, 102, 105, 106, 170, 169, 166, 165, 154, 153, 150, 149, 43, 232, 39, 228, + 27, 216, 23, 212, 108, 45, 226, 99, 92, 29, 210, 83, 175, 46, 225, 160, 159, 30, 209, 144, + 48, 243, 116, 53, 202, 75, 12, 207, 113, 114, 183, 54, 201, 136, 77, 78, 190, 189, 120, 57, + 198, 71, 130, 129, 63, 252, 187, 58, 197, 132, 3, 192, 234, 107, 44, 239, 112, 49, 254, + 127, 233, 168, 109, 110, 179, 50, 253, 188, 230, 103, 162, 161, 52, 247, 56, 251, 229, 164, + 35, 224, 117, 118, 121, 122, 218, 91, 28, 223, 138, 137, 134, 133, 217, 152, 93, 94, 11, + 200, 7, 196, 214, 87, 146, 145, 76, 13, 194, 67, 213, 148, 19, 208, 143, 14, 193, 128, + }; +}; class FunctionHilbertEncode : public FunctionSpaceFillingCurveEncode @@ -69,14 +128,19 @@ public: const ColumnPtr & col0 = non_const_arguments[0 + vector_start_index].column; if (num_dimensions == 1) { - for (size_t i = 0; i < input_rows_count; i++) + for (size_t i = 0; i < input_rows_count; ++i) { vec_res[i] = col0->getUInt(i); } return col_res; } - return nullptr; + const ColumnPtr & col1 = non_const_arguments[1 + vector_start_index].column; + for (size_t i = 0; i < input_rows_count; ++i) + { + vec_res[i] = FunctionHilbertEncode2DWIthLookupTableImpl::encode(col0->getUInt(i), col1->getUInt(i)); + } + return col_res; } }; From 46e81dae49a86f8cdd024b083cbb76d7b0fabe8e Mon Sep 17 00:00:00 2001 From: Artem Mustafin Date: Mon, 19 Feb 2024 21:56:49 +0000 Subject: [PATCH 03/14] code style + renaming --- src/Functions/hilbertEncode.cpp | 48 ++++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/src/Functions/hilbertEncode.cpp b/src/Functions/hilbertEncode.cpp index 2bcb46c79a3..f486b49eba8 100644 --- a/src/Functions/hilbertEncode.cpp +++ b/src/Functions/hilbertEncode.cpp @@ -12,48 +12,63 @@ namespace DB { -class FunctionHilbertEncode2DWIthLookupTableImpl { +class FunctionHilbertEncode2DWIthLookupTableImpl +{ public: - static UInt64 encode(UInt64 x, UInt64 y) { + static UInt64 encode(UInt64 x, UInt64 y) + { const auto leading_zeros_count = getLeadingZeroBits(x | y); const auto used_bits = std::numeric_limits::digits - leading_zeros_count; - UInt8 remaind_shift = BIT_STEP - used_bits % BIT_STEP; - if (remaind_shift == BIT_STEP) - remaind_shift = 0; - x <<= remaind_shift; - y <<= remaind_shift; + const auto shift_for_align = getShiftForStepsAlign(used_bits); + x <<= shift_for_align; + y <<= shift_for_align; UInt8 current_state = 0; UInt64 hilbert_code = 0; - Int8 current_shift = used_bits + remaind_shift - BIT_STEP; + Int8 current_shift = used_bits + shift_for_align - BIT_STEP; while (current_shift > 0) { const UInt8 x_bits = (x >> current_shift) & STEP_MASK; const UInt8 y_bits = (y >> current_shift) & STEP_MASK; const auto hilbert_bits = getCodeAndUpdateState(x_bits, y_bits, current_state); - const UInt8 hilbert_code_shift = static_cast(current_shift) << 1; - hilbert_code |= (hilbert_bits << hilbert_code_shift); + hilbert_code |= (hilbert_bits << getHilbertShift(current_shift)); current_shift -= BIT_STEP; } - hilbert_code >>= (remaind_shift << 1); + hilbert_code >>= getHilbertShift(shift_for_align); return hilbert_code; } private: - // LOOKUP_TABLE[SSXXXYYY] = SSHHHHHH] + // LOOKUP_TABLE[SSXXXYYY] = SSHHHHHH // where SS - 2 bits for state, XXX - 3 bits of x, YYY - 3 bits of y - static UInt8 getCodeAndUpdateState(UInt8 x_bits, UInt8 y_bits, UInt8& state) { + static UInt8 getCodeAndUpdateState(UInt8 x_bits, UInt8 y_bits, UInt8& state) + { const UInt8 table_index = state | (x_bits << BIT_STEP) | y_bits; const auto table_code = LOOKUP_TABLE[table_index]; state = table_code & STATE_MASK; return table_code & HILBERT_MASK; } + // hilbert code is double size of input values + static UInt8 getHilbertShift(UInt8 shift) + { + return shift << 1; + } + + static UInt8 getShiftForStepsAlign(UInt8 used_bits) + { + UInt8 shift_for_align = BIT_STEP - used_bits % BIT_STEP; + if (shift_for_align == BIT_STEP) + shift_for_align = 0; + + return shift_for_align; + } + constexpr static UInt8 BIT_STEP = 3; constexpr static UInt8 STEP_MASK = (1 << BIT_STEP) - 1; constexpr static UInt8 HILBERT_MASK = (1 << (BIT_STEP << 1)) - 1; @@ -113,8 +128,8 @@ public: auto ratio = mask->getColumn(i).getUInt(0); if (ratio > 8 || ratio < 1) throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, - "Illegal argument {} of function {}, should be a number in range 1-8", - arguments[0].column->getName(), getName()); + "Illegal argument {} of function {}, should be a number in range 1-8", + arguments[0].column->getName(), getName()); } } @@ -127,7 +142,8 @@ public: vec_res.resize(input_rows_count); const ColumnPtr & col0 = non_const_arguments[0 + vector_start_index].column; - if (num_dimensions == 1) { + if (num_dimensions == 1) + { for (size_t i = 0; i < input_rows_count; ++i) { vec_res[i] = col0->getUInt(i); From 9a65f9a80d0942201b2a928b9b4e67451cb57840 Mon Sep 17 00:00:00 2001 From: Yarik Briukhovetskyi <114298166+yariks5s@users.noreply.github.com> Date: Tue, 20 Feb 2024 14:12:26 +0100 Subject: [PATCH 04/14] restart CI --- src/Functions/hilbertEncode.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Functions/hilbertEncode.cpp b/src/Functions/hilbertEncode.cpp index f486b49eba8..861cf42fbdb 100644 --- a/src/Functions/hilbertEncode.cpp +++ b/src/Functions/hilbertEncode.cpp @@ -167,8 +167,6 @@ REGISTER_FUNCTION(HilbertEncode) .description=R"( )", - .examples{ - }, .categories {} }); } From 1f9b4a74d958ac3d3577f44c87dd3116347ce97a Mon Sep 17 00:00:00 2001 From: Artem Mustafin Date: Thu, 22 Feb 2024 11:14:53 +0000 Subject: [PATCH 05/14] fixed algorithm + template for steps sizes --- src/Functions/hilbertEncode.cpp | 144 +++++++++++++++++++------------- 1 file changed, 84 insertions(+), 60 deletions(-) diff --git a/src/Functions/hilbertEncode.cpp b/src/Functions/hilbertEncode.cpp index f486b49eba8..52090e259c5 100644 --- a/src/Functions/hilbertEncode.cpp +++ b/src/Functions/hilbertEncode.cpp @@ -12,68 +12,26 @@ namespace DB { -class FunctionHilbertEncode2DWIthLookupTableImpl -{ +template +class HilbertLookupTable { public: - static UInt64 encode(UInt64 x, UInt64 y) - { - const auto leading_zeros_count = getLeadingZeroBits(x | y); - const auto used_bits = std::numeric_limits::digits - leading_zeros_count; + constexpr static UInt8 LOOKUP_TABLE[0] = {}; +}; - const auto shift_for_align = getShiftForStepsAlign(used_bits); - x <<= shift_for_align; - y <<= shift_for_align; - - UInt8 current_state = 0; - UInt64 hilbert_code = 0; - Int8 current_shift = used_bits + shift_for_align - BIT_STEP; - - while (current_shift > 0) - { - const UInt8 x_bits = (x >> current_shift) & STEP_MASK; - const UInt8 y_bits = (y >> current_shift) & STEP_MASK; - const auto hilbert_bits = getCodeAndUpdateState(x_bits, y_bits, current_state); - hilbert_code |= (hilbert_bits << getHilbertShift(current_shift)); - - current_shift -= BIT_STEP; - } - - hilbert_code >>= getHilbertShift(shift_for_align); - return hilbert_code; - } - -private: - - // LOOKUP_TABLE[SSXXXYYY] = SSHHHHHH - // where SS - 2 bits for state, XXX - 3 bits of x, YYY - 3 bits of y - static UInt8 getCodeAndUpdateState(UInt8 x_bits, UInt8 y_bits, UInt8& state) - { - const UInt8 table_index = state | (x_bits << BIT_STEP) | y_bits; - const auto table_code = LOOKUP_TABLE[table_index]; - state = table_code & STATE_MASK; - return table_code & HILBERT_MASK; - } - - // hilbert code is double size of input values - static UInt8 getHilbertShift(UInt8 shift) - { - return shift << 1; - } - - static UInt8 getShiftForStepsAlign(UInt8 used_bits) - { - UInt8 shift_for_align = BIT_STEP - used_bits % BIT_STEP; - if (shift_for_align == BIT_STEP) - shift_for_align = 0; - - return shift_for_align; - } - - constexpr static UInt8 BIT_STEP = 3; - constexpr static UInt8 STEP_MASK = (1 << BIT_STEP) - 1; - constexpr static UInt8 HILBERT_MASK = (1 << (BIT_STEP << 1)) - 1; - constexpr static UInt8 STATE_MASK = static_cast(-1) - HILBERT_MASK; +template <> +class HilbertLookupTable<2> { +public: + constexpr static UInt8 LOOKUP_TABLE[16] = { + 4, 1, 11, 2, + 0, 15, 5, 6, + 10, 9, 3, 12, + 14, 7, 13, 8 + }; +}; +template <> +class HilbertLookupTable<3> { +public: constexpr static UInt8 LOOKUP_TABLE[256] = { 64, 1, 206, 79, 16, 211, 84, 21, 131, 2, 205, 140, 81, 82, 151, 22, 4, 199, 8, 203, 158, 157, 88, 25, 69, 70, 73, 74, 31, 220, 155, 26, 186, 185, 182, 181, 32, 227, 100, 37, 59, @@ -92,6 +50,72 @@ private: }; + +template +class FunctionHilbertEncode2DWIthLookupTableImpl +{ +public: + static UInt64 encode(UInt64 x, UInt64 y) + { + const auto leading_zeros_count = getLeadingZeroBits(x | y); + const auto used_bits = std::numeric_limits::digits - leading_zeros_count; + + auto [iterations, current_shift] = getIterationsAndInitialShift(used_bits); + UInt8 current_state = 0; + UInt64 hilbert_code = 0; + + for (; iterations > 0; --iterations, current_shift -= bit_step) + { + if (iterations % 2 == 0) { + std::swap(x, y); + } + const UInt8 x_bits = (x >> current_shift) & STEP_MASK; + const UInt8 y_bits = (y >> current_shift) & STEP_MASK; + const auto hilbert_bits = getCodeAndUpdateState(x_bits, y_bits, current_state); + hilbert_code |= (hilbert_bits << getHilbertShift(current_shift)); + } + + return hilbert_code; + } + +private: + + // LOOKUP_TABLE[SSXXXYYY] = SSHHHHHH + // where SS - 2 bits for state, XXX - 3 bits of x, YYY - 3 bits of y + // State is rotation of curve on every step, left/up/right/down - therefore 2 bits + static UInt8 getCodeAndUpdateState(UInt8 x_bits, UInt8 y_bits, UInt8& state) + { + const UInt8 table_index = state | (x_bits << bit_step) | y_bits; + const auto table_code = HilbertLookupTable::LOOKUP_TABLE[table_index]; + state = table_code & STATE_MASK; + return table_code & HILBERT_MASK; + } + + // hilbert code is double size of input values + static constexpr UInt8 getHilbertShift(UInt8 shift) + { + return shift << 1; + } + + static std::pair getIterationsAndInitialShift(UInt8 used_bits) + { + UInt8 iterations = used_bits / bit_step; + UInt8 initial_shift = iterations * bit_step; + if (initial_shift < used_bits) + { + ++iterations; + } else { + initial_shift -= bit_step; + } + return {iterations, initial_shift}; + } + + constexpr static UInt8 STEP_MASK = (1 << bit_step) - 1; + constexpr static UInt8 HILBERT_MASK = (1 << getHilbertShift(bit_step)) - 1; + constexpr static UInt8 STATE_MASK = 0b11 << getHilbertShift(bit_step); +}; + + class FunctionHilbertEncode : public FunctionSpaceFillingCurveEncode { public: @@ -154,7 +178,7 @@ public: const ColumnPtr & col1 = non_const_arguments[1 + vector_start_index].column; for (size_t i = 0; i < input_rows_count; ++i) { - vec_res[i] = FunctionHilbertEncode2DWIthLookupTableImpl::encode(col0->getUInt(i), col1->getUInt(i)); + vec_res[i] = FunctionHilbertEncode2DWIthLookupTableImpl<3>::encode(col0->getUInt(i), col1->getUInt(i)); } return col_res; } From bf6bfcfb6d2d47a540e3b1d0f8d9cd187efe6819 Mon Sep 17 00:00:00 2001 From: Artem Mustafin Date: Thu, 22 Feb 2024 12:01:11 +0000 Subject: [PATCH 06/14] add unit test --- src/Functions/hilbertEncode.cpp | 185 +--------------- src/Functions/hilbertEncode.h | 202 ++++++++++++++++++ .../tests/gtest_hilbert_lookup_table.cpp | 23 ++ 3 files changed, 227 insertions(+), 183 deletions(-) create mode 100644 src/Functions/hilbertEncode.h create mode 100644 src/Functions/tests/gtest_hilbert_lookup_table.cpp diff --git a/src/Functions/hilbertEncode.cpp b/src/Functions/hilbertEncode.cpp index 52090e259c5..d24f734695e 100644 --- a/src/Functions/hilbertEncode.cpp +++ b/src/Functions/hilbertEncode.cpp @@ -1,189 +1,8 @@ -#include -#include -#include -#include +#include #include -#include -#include -#include -#include -namespace DB -{ - -template -class HilbertLookupTable { -public: - constexpr static UInt8 LOOKUP_TABLE[0] = {}; -}; - -template <> -class HilbertLookupTable<2> { -public: - constexpr static UInt8 LOOKUP_TABLE[16] = { - 4, 1, 11, 2, - 0, 15, 5, 6, - 10, 9, 3, 12, - 14, 7, 13, 8 - }; -}; - -template <> -class HilbertLookupTable<3> { -public: - constexpr static UInt8 LOOKUP_TABLE[256] = { - 64, 1, 206, 79, 16, 211, 84, 21, 131, 2, 205, 140, 81, 82, 151, 22, 4, 199, 8, 203, 158, - 157, 88, 25, 69, 70, 73, 74, 31, 220, 155, 26, 186, 185, 182, 181, 32, 227, 100, 37, 59, - 248, 55, 244, 97, 98, 167, 38, 124, 61, 242, 115, 174, 173, 104, 41, 191, 62, 241, 176, 47, - 236, 171, 42, 0, 195, 68, 5, 250, 123, 60, 255, 65, 66, 135, 6, 249, 184, 125, 126, 142, - 141, 72, 9, 246, 119, 178, 177, 15, 204, 139, 10, 245, 180, 51, 240, 80, 17, 222, 95, 96, - 33, 238, 111, 147, 18, 221, 156, 163, 34, 237, 172, 20, 215, 24, 219, 36, 231, 40, 235, 85, - 86, 89, 90, 101, 102, 105, 106, 170, 169, 166, 165, 154, 153, 150, 149, 43, 232, 39, 228, - 27, 216, 23, 212, 108, 45, 226, 99, 92, 29, 210, 83, 175, 46, 225, 160, 159, 30, 209, 144, - 48, 243, 116, 53, 202, 75, 12, 207, 113, 114, 183, 54, 201, 136, 77, 78, 190, 189, 120, 57, - 198, 71, 130, 129, 63, 252, 187, 58, 197, 132, 3, 192, 234, 107, 44, 239, 112, 49, 254, - 127, 233, 168, 109, 110, 179, 50, 253, 188, 230, 103, 162, 161, 52, 247, 56, 251, 229, 164, - 35, 224, 117, 118, 121, 122, 218, 91, 28, 223, 138, 137, 134, 133, 217, 152, 93, 94, 11, - 200, 7, 196, 214, 87, 146, 145, 76, 13, 194, 67, 213, 148, 19, 208, 143, 14, 193, 128, - }; -}; - - - -template -class FunctionHilbertEncode2DWIthLookupTableImpl -{ -public: - static UInt64 encode(UInt64 x, UInt64 y) - { - const auto leading_zeros_count = getLeadingZeroBits(x | y); - const auto used_bits = std::numeric_limits::digits - leading_zeros_count; - - auto [iterations, current_shift] = getIterationsAndInitialShift(used_bits); - UInt8 current_state = 0; - UInt64 hilbert_code = 0; - - for (; iterations > 0; --iterations, current_shift -= bit_step) - { - if (iterations % 2 == 0) { - std::swap(x, y); - } - const UInt8 x_bits = (x >> current_shift) & STEP_MASK; - const UInt8 y_bits = (y >> current_shift) & STEP_MASK; - const auto hilbert_bits = getCodeAndUpdateState(x_bits, y_bits, current_state); - hilbert_code |= (hilbert_bits << getHilbertShift(current_shift)); - } - - return hilbert_code; - } - -private: - - // LOOKUP_TABLE[SSXXXYYY] = SSHHHHHH - // where SS - 2 bits for state, XXX - 3 bits of x, YYY - 3 bits of y - // State is rotation of curve on every step, left/up/right/down - therefore 2 bits - static UInt8 getCodeAndUpdateState(UInt8 x_bits, UInt8 y_bits, UInt8& state) - { - const UInt8 table_index = state | (x_bits << bit_step) | y_bits; - const auto table_code = HilbertLookupTable::LOOKUP_TABLE[table_index]; - state = table_code & STATE_MASK; - return table_code & HILBERT_MASK; - } - - // hilbert code is double size of input values - static constexpr UInt8 getHilbertShift(UInt8 shift) - { - return shift << 1; - } - - static std::pair getIterationsAndInitialShift(UInt8 used_bits) - { - UInt8 iterations = used_bits / bit_step; - UInt8 initial_shift = iterations * bit_step; - if (initial_shift < used_bits) - { - ++iterations; - } else { - initial_shift -= bit_step; - } - return {iterations, initial_shift}; - } - - constexpr static UInt8 STEP_MASK = (1 << bit_step) - 1; - constexpr static UInt8 HILBERT_MASK = (1 << getHilbertShift(bit_step)) - 1; - constexpr static UInt8 STATE_MASK = 0b11 << getHilbertShift(bit_step); -}; - - -class FunctionHilbertEncode : public FunctionSpaceFillingCurveEncode -{ -public: - static constexpr auto name = "hilbertEncode"; - static FunctionPtr create(ContextPtr) - { - return std::make_shared(); - } - - String getName() const override { return name; } - - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override - { - size_t num_dimensions = arguments.size(); - if (num_dimensions < 1 || num_dimensions > 2) { - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "Illegal number of UInt arguments of function {}: should be at least 1 and not more than 2", - getName()); - } - - size_t vector_start_index = 0; - const auto * const_col = typeid_cast(arguments[0].column.get()); - const ColumnTuple * mask; - if (const_col) - mask = typeid_cast(const_col->getDataColumnPtr().get()); - else - mask = typeid_cast(arguments[0].column.get()); - if (mask) - { - num_dimensions = mask->tupleSize(); - vector_start_index = 1; - for (size_t i = 0; i < num_dimensions; i++) - { - auto ratio = mask->getColumn(i).getUInt(0); - if (ratio > 8 || ratio < 1) - throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, - "Illegal argument {} of function {}, should be a number in range 1-8", - arguments[0].column->getName(), getName()); - } - } - - auto non_const_arguments = arguments; - for (auto & argument : non_const_arguments) - argument.column = argument.column->convertToFullColumnIfConst(); - - auto col_res = ColumnUInt64::create(); - ColumnUInt64::Container & vec_res = col_res->getData(); - vec_res.resize(input_rows_count); - - const ColumnPtr & col0 = non_const_arguments[0 + vector_start_index].column; - if (num_dimensions == 1) - { - for (size_t i = 0; i < input_rows_count; ++i) - { - vec_res[i] = col0->getUInt(i); - } - return col_res; - } - - const ColumnPtr & col1 = non_const_arguments[1 + vector_start_index].column; - for (size_t i = 0; i < input_rows_count; ++i) - { - vec_res[i] = FunctionHilbertEncode2DWIthLookupTableImpl<3>::encode(col0->getUInt(i), col1->getUInt(i)); - } - return col_res; - } -}; - +namespace DB { REGISTER_FUNCTION(HilbertEncode) { diff --git a/src/Functions/hilbertEncode.h b/src/Functions/hilbertEncode.h new file mode 100644 index 00000000000..12c5fc4577b --- /dev/null +++ b/src/Functions/hilbertEncode.h @@ -0,0 +1,202 @@ +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +namespace HilbertDetails +{ + +template +class HilbertLookupTable { +public: + constexpr static UInt8 LOOKUP_TABLE[0] = {}; +}; + +template <> +class HilbertLookupTable<1> { +public: + constexpr static UInt8 LOOKUP_TABLE[16] = { + 4, 1, 11, 2, + 0, 15, 5, 6, + 10, 9, 3, 12, + 14, 7, 13, 8 + }; +}; + +template <> +class HilbertLookupTable<3> { +public: + constexpr static UInt8 LOOKUP_TABLE[256] = { + 64, 1, 206, 79, 16, 211, 84, 21, 131, 2, 205, 140, 81, 82, 151, 22, 4, 199, 8, 203, 158, + 157, 88, 25, 69, 70, 73, 74, 31, 220, 155, 26, 186, 185, 182, 181, 32, 227, 100, 37, 59, + 248, 55, 244, 97, 98, 167, 38, 124, 61, 242, 115, 174, 173, 104, 41, 191, 62, 241, 176, 47, + 236, 171, 42, 0, 195, 68, 5, 250, 123, 60, 255, 65, 66, 135, 6, 249, 184, 125, 126, 142, + 141, 72, 9, 246, 119, 178, 177, 15, 204, 139, 10, 245, 180, 51, 240, 80, 17, 222, 95, 96, + 33, 238, 111, 147, 18, 221, 156, 163, 34, 237, 172, 20, 215, 24, 219, 36, 231, 40, 235, 85, + 86, 89, 90, 101, 102, 105, 106, 170, 169, 166, 165, 154, 153, 150, 149, 43, 232, 39, 228, + 27, 216, 23, 212, 108, 45, 226, 99, 92, 29, 210, 83, 175, 46, 225, 160, 159, 30, 209, 144, + 48, 243, 116, 53, 202, 75, 12, 207, 113, 114, 183, 54, 201, 136, 77, 78, 190, 189, 120, 57, + 198, 71, 130, 129, 63, 252, 187, 58, 197, 132, 3, 192, 234, 107, 44, 239, 112, 49, 254, + 127, 233, 168, 109, 110, 179, 50, 253, 188, 230, 103, 162, 161, 52, 247, 56, 251, 229, 164, + 35, 224, 117, 118, 121, 122, 218, 91, 28, 223, 138, 137, 134, 133, 217, 152, 93, 94, 11, + 200, 7, 196, 214, 87, 146, 145, 76, 13, 194, 67, 213, 148, 19, 208, 143, 14, 193, 128, + }; +}; + +} + + +template +class FunctionHilbertEncode2DWIthLookupTableImpl +{ +public: + struct HilbertEncodeState { + UInt64 hilbert_code = 0; + UInt8 state = 0; + }; + + static UInt64 encode(UInt64 x, UInt64 y) + { + return encodeFromState(x, y, 0).hilbert_code; + } + + static HilbertEncodeState encodeFromState(UInt64 x, UInt64 y, UInt8 state) + { + HilbertEncodeState result; + result.state = state; + const auto leading_zeros_count = getLeadingZeroBits(x | y); + const auto used_bits = std::numeric_limits::digits - leading_zeros_count; + + auto [iterations, current_shift] = getIterationsAndInitialShift(used_bits); + + for (; iterations > 0; --iterations, current_shift -= bit_step) + { + if (iterations % 2 == 0) { + std::swap(x, y); + } + const UInt8 x_bits = (x >> current_shift) & STEP_MASK; + const UInt8 y_bits = (y >> current_shift) & STEP_MASK; + const auto current_step_state = getCodeAndUpdateState(x_bits, y_bits, result.state); + result.hilbert_code |= (current_step_state.hilbert_code << getHilbertShift(current_shift)); + result.state = current_step_state.state; + } + + return result; + } + +private: + // LOOKUP_TABLE[SSXXXYYY] = SSHHHHHH + // where SS - 2 bits for state, XXX - 3 bits of x, YYY - 3 bits of y + // State is rotation of curve on every step, left/up/right/down - therefore 2 bits + static HilbertEncodeState getCodeAndUpdateState(UInt8 x_bits, UInt8 y_bits, UInt8 state) + { + HilbertEncodeState result; + const UInt8 table_index = state | (x_bits << bit_step) | y_bits; + const auto table_code = HilbertDetails::HilbertLookupTable::LOOKUP_TABLE[table_index]; + result.state = table_code & STATE_MASK; + result.hilbert_code = table_code & HILBERT_MASK; + return result; + } + + // hilbert code is double size of input values + static constexpr UInt8 getHilbertShift(UInt8 shift) + { + return shift << 1; + } + + static std::pair getIterationsAndInitialShift(UInt8 used_bits) + { + UInt8 iterations = used_bits / bit_step; + UInt8 initial_shift = iterations * bit_step; + if (initial_shift < used_bits) + { + ++iterations; + } else { + initial_shift -= bit_step; + } + return {iterations, initial_shift}; + } + + constexpr static UInt8 STEP_MASK = (1 << bit_step) - 1; + constexpr static UInt8 HILBERT_MASK = (1 << getHilbertShift(bit_step)) - 1; + constexpr static UInt8 STATE_MASK = 0b11 << getHilbertShift(bit_step); +}; + + +class FunctionHilbertEncode : public FunctionSpaceFillingCurveEncode +{ +public: + static constexpr auto name = "hilbertEncode"; + static FunctionPtr create(ContextPtr) + { + return std::make_shared(); + } + + String getName() const override { return name; } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override + { + size_t num_dimensions = arguments.size(); + if (num_dimensions < 1 || num_dimensions > 2) { + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal number of UInt arguments of function {}: should be at least 1 and not more than 2", + getName()); + } + + size_t vector_start_index = 0; + const auto * const_col = typeid_cast(arguments[0].column.get()); + const ColumnTuple * mask; + if (const_col) + mask = typeid_cast(const_col->getDataColumnPtr().get()); + else + mask = typeid_cast(arguments[0].column.get()); + if (mask) + { + num_dimensions = mask->tupleSize(); + vector_start_index = 1; + for (size_t i = 0; i < num_dimensions; i++) + { + auto ratio = mask->getColumn(i).getUInt(0); + if (ratio > 8 || ratio < 1) + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, + "Illegal argument {} of function {}, should be a number in range 1-8", + arguments[0].column->getName(), getName()); + } + } + + auto non_const_arguments = arguments; + for (auto & argument : non_const_arguments) + argument.column = argument.column->convertToFullColumnIfConst(); + + auto col_res = ColumnUInt64::create(); + ColumnUInt64::Container & vec_res = col_res->getData(); + vec_res.resize(input_rows_count); + + const ColumnPtr & col0 = non_const_arguments[0 + vector_start_index].column; + if (num_dimensions == 1) + { + for (size_t i = 0; i < input_rows_count; ++i) + { + vec_res[i] = col0->getUInt(i); + } + return col_res; + } + + const ColumnPtr & col1 = non_const_arguments[1 + vector_start_index].column; + for (size_t i = 0; i < input_rows_count; ++i) + { + vec_res[i] = FunctionHilbertEncode2DWIthLookupTableImpl<3>::encode(col0->getUInt(i), col1->getUInt(i)); + } + return col_res; + } +}; + +} diff --git a/src/Functions/tests/gtest_hilbert_lookup_table.cpp b/src/Functions/tests/gtest_hilbert_lookup_table.cpp new file mode 100644 index 00000000000..f8143a6c47e --- /dev/null +++ b/src/Functions/tests/gtest_hilbert_lookup_table.cpp @@ -0,0 +1,23 @@ +#include +#include + + +void checkLookupTableConsistency(UInt8 x, UInt8 y, UInt8 state) +{ + auto step1 = DB::FunctionHilbertEncode2DWIthLookupTableImpl<1>::encodeFromState(x, y, state); + auto step2 = DB::FunctionHilbertEncode2DWIthLookupTableImpl<3>::encodeFromState(x, y, state); + ASSERT_EQ(step1.hilbert_code, step2.hilbert_code); + ASSERT_EQ(step1.state, step2.state); +} + + +TEST(HilbertLookupTable, bitStep1And3Consistnecy) +{ + for (int x = 0; x < 8; ++x) { + for (int y = 0; y < 8; ++y) { + for (int state = 0; state < 4; ++state) { + checkLookupTableConsistency(x, y, state); + } + } + } +} From 96f763b1ae453e4c491efa5947d634e058826d35 Mon Sep 17 00:00:00 2001 From: Artem Mustafin Date: Thu, 22 Feb 2024 16:08:13 +0000 Subject: [PATCH 07/14] refactoring + ut + description + ratio --- .../FunctionSpaceFillingCurveEncode.h | 4 +- src/Functions/hilbertEncode.cpp | 43 ++++++++- src/Functions/hilbertEncode.h | 92 ++++++++++--------- src/Functions/mortonEncode.cpp | 1 - src/Functions/tests/gtest_hilbert_encode.cpp | 18 ++++ .../tests/gtest_hilbert_lookup_table.cpp | 23 ----- 6 files changed, 111 insertions(+), 70 deletions(-) create mode 100644 src/Functions/tests/gtest_hilbert_encode.cpp delete mode 100644 src/Functions/tests/gtest_hilbert_lookup_table.cpp diff --git a/src/Functions/FunctionSpaceFillingCurveEncode.h b/src/Functions/FunctionSpaceFillingCurveEncode.h index 257b49176bc..399010bad54 100644 --- a/src/Functions/FunctionSpaceFillingCurveEncode.h +++ b/src/Functions/FunctionSpaceFillingCurveEncode.h @@ -1,3 +1,4 @@ +#pragma once #include #include #include @@ -12,7 +13,8 @@ namespace ErrorCodes extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION; } -class FunctionSpaceFillingCurveEncode: public IFunction { +class FunctionSpaceFillingCurveEncode: public IFunction +{ public: bool isVariadic() const override { diff --git a/src/Functions/hilbertEncode.cpp b/src/Functions/hilbertEncode.cpp index 8f0227227f0..8f09ba9531a 100644 --- a/src/Functions/hilbertEncode.cpp +++ b/src/Functions/hilbertEncode.cpp @@ -8,9 +8,50 @@ REGISTER_FUNCTION(HilbertEncode) { factory.registerFunction(FunctionDocumentation{ .description=R"( +Calculates code for Hilbert Curve for a list of unsigned integers +The function has two modes of operation: +- Simple +- Expanded + +Simple: accepts up to 2 unsigned integers as arguments and produces a UInt64 code. +[example:simple] + +Expanded: accepts a range mask (tuple) as a first argument and up to 2 unsigned integers as other arguments. +Each number in mask configures the amount of bits that corresponding argument will be shifted left +[example:range_expanded] +Note: tuple size must be equal to the number of the other arguments + +Range expansion can be beneficial when you need a similar distribution for arguments with wildly different ranges (or cardinality) +For example: 'IP Address' (0...FFFFFFFF) and 'Country code' (0...FF) + +Hilbert encoding for one argument is always the argument itself. +[example:identity] +Produces: `1` + +You can expand one argument too: +[example:identity_expanded] +Produces: `512` + +The function also accepts columns as arguments: +[example:from_table] + +But the range tuple must still be a constant: +[example:from_table_range] + +Please note that you can fit only so much bits of information into Morton code as UInt64 has. +Two arguments will have a range of maximum 2^32 (64/2) each +All overflow will be clamped to zero )", - .categories {} + .examples{ + {"simple", "SELECT hilbertEncode(1, 2, 3)", ""}, + {"range_expanded", "SELECT hilbertEncode((1,6), 1024, 16)", ""}, + {"identity", "SELECT hilbertEncode(1)", ""}, + {"identity_expanded", "SELECT hilbertEncode(tuple(2), 128)", ""}, + {"from_table", "SELECT hilbertEncode(n1, n2) FROM table", ""}, + {"from_table_range", "SELECT hilbertEncode((1,2), n1, n2) FROM table", ""}, + }, + .categories {"Hilbert coding", "Hilbert Curve"} }); } diff --git a/src/Functions/hilbertEncode.h b/src/Functions/hilbertEncode.h index 12c5fc4577b..876b3a07b5a 100644 --- a/src/Functions/hilbertEncode.h +++ b/src/Functions/hilbertEncode.h @@ -1,3 +1,4 @@ +#pragma once #include #include #include @@ -11,17 +12,25 @@ namespace DB { +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int ARGUMENT_OUT_OF_BOUND; +} + namespace HilbertDetails { template -class HilbertLookupTable { +class HilbertLookupTable +{ public: constexpr static UInt8 LOOKUP_TABLE[0] = {}; }; template <> -class HilbertLookupTable<1> { +class HilbertLookupTable<1> +{ public: constexpr static UInt8 LOOKUP_TABLE[16] = { 4, 1, 11, 2, @@ -32,7 +41,8 @@ public: }; template <> -class HilbertLookupTable<3> { +class HilbertLookupTable<3> +{ public: constexpr static UInt8 LOOKUP_TABLE[256] = { 64, 1, 206, 79, 16, 211, 84, 21, 131, 2, 205, 140, 81, 82, 151, 22, 4, 199, 8, 203, 158, @@ -58,52 +68,36 @@ template class FunctionHilbertEncode2DWIthLookupTableImpl { public: - struct HilbertEncodeState { - UInt64 hilbert_code = 0; - UInt8 state = 0; - }; - static UInt64 encode(UInt64 x, UInt64 y) { - return encodeFromState(x, y, 0).hilbert_code; - } - - static HilbertEncodeState encodeFromState(UInt64 x, UInt64 y, UInt8 state) - { - HilbertEncodeState result; - result.state = state; + UInt64 hilbert_code = 0; const auto leading_zeros_count = getLeadingZeroBits(x | y); const auto used_bits = std::numeric_limits::digits - leading_zeros_count; - auto [iterations, current_shift] = getIterationsAndInitialShift(used_bits); + auto [current_shift, state] = getInitialShiftAndState(used_bits); - for (; iterations > 0; --iterations, current_shift -= bit_step) + while (current_shift >= 0) { - if (iterations % 2 == 0) { - std::swap(x, y); - } const UInt8 x_bits = (x >> current_shift) & STEP_MASK; const UInt8 y_bits = (y >> current_shift) & STEP_MASK; - const auto current_step_state = getCodeAndUpdateState(x_bits, y_bits, result.state); - result.hilbert_code |= (current_step_state.hilbert_code << getHilbertShift(current_shift)); - result.state = current_step_state.state; + const auto hilbert_bits = getCodeAndUpdateState(x_bits, y_bits, state); + hilbert_code |= (hilbert_bits << getHilbertShift(current_shift)); + current_shift -= bit_step; } - return result; + return hilbert_code; } private: // LOOKUP_TABLE[SSXXXYYY] = SSHHHHHH // where SS - 2 bits for state, XXX - 3 bits of x, YYY - 3 bits of y // State is rotation of curve on every step, left/up/right/down - therefore 2 bits - static HilbertEncodeState getCodeAndUpdateState(UInt8 x_bits, UInt8 y_bits, UInt8 state) + static UInt64 getCodeAndUpdateState(UInt8 x_bits, UInt8 y_bits, UInt8& state) { - HilbertEncodeState result; const UInt8 table_index = state | (x_bits << bit_step) | y_bits; const auto table_code = HilbertDetails::HilbertLookupTable::LOOKUP_TABLE[table_index]; - result.state = table_code & STATE_MASK; - result.hilbert_code = table_code & HILBERT_MASK; - return result; + state = table_code & STATE_MASK; + return table_code & HILBERT_MASK; } // hilbert code is double size of input values @@ -112,17 +106,18 @@ private: return shift << 1; } - static std::pair getIterationsAndInitialShift(UInt8 used_bits) + static std::pair getInitialShiftAndState(UInt8 used_bits) { UInt8 iterations = used_bits / bit_step; - UInt8 initial_shift = iterations * bit_step; + Int8 initial_shift = iterations * bit_step; if (initial_shift < used_bits) { ++iterations; } else { initial_shift -= bit_step; } - return {iterations, initial_shift}; + UInt8 state = iterations % 2 == 0 ? 0b01 << getHilbertShift(bit_step) : 0; + return {initial_shift, state}; } constexpr static UInt8 STEP_MASK = (1 << bit_step) - 1; @@ -145,12 +140,6 @@ public: ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override { size_t num_dimensions = arguments.size(); - if (num_dimensions < 1 || num_dimensions > 2) { - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "Illegal number of UInt arguments of function {}: should be at least 1 and not more than 2", - getName()); - } - size_t vector_start_index = 0; const auto * const_col = typeid_cast(arguments[0].column.get()); const ColumnTuple * mask; @@ -165,9 +154,9 @@ public: for (size_t i = 0; i < num_dimensions; i++) { auto ratio = mask->getColumn(i).getUInt(0); - if (ratio > 8 || ratio < 1) + if (ratio > 32) throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, - "Illegal argument {} of function {}, should be a number in range 1-8", + "Illegal argument {} of function {}, should be a number in range 0-32", arguments[0].column->getName(), getName()); } } @@ -180,22 +169,37 @@ public: ColumnUInt64::Container & vec_res = col_res->getData(); vec_res.resize(input_rows_count); + const auto expand = [mask](const UInt64 value, const UInt8 column_id) { + if (mask) + return value << mask->getColumn(column_id).getUInt(0); + return value; + }; + const ColumnPtr & col0 = non_const_arguments[0 + vector_start_index].column; if (num_dimensions == 1) { for (size_t i = 0; i < input_rows_count; ++i) { - vec_res[i] = col0->getUInt(i); + vec_res[i] = expand(col0->getUInt(i), 0); } return col_res; } const ColumnPtr & col1 = non_const_arguments[1 + vector_start_index].column; - for (size_t i = 0; i < input_rows_count; ++i) + if (num_dimensions == 2) { - vec_res[i] = FunctionHilbertEncode2DWIthLookupTableImpl<3>::encode(col0->getUInt(i), col1->getUInt(i)); + for (size_t i = 0; i < input_rows_count; ++i) + { + vec_res[i] = FunctionHilbertEncode2DWIthLookupTableImpl<3>::encode( + expand(col0->getUInt(i), 0), + expand(col1->getUInt(i), 1)); + } + return col_res; } - return col_res; + + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal number of UInt arguments of function {}: should be not more than 2 dimensions", + getName()); } }; diff --git a/src/Functions/mortonEncode.cpp b/src/Functions/mortonEncode.cpp index 5365e3d1cca..63cabe5b77f 100644 --- a/src/Functions/mortonEncode.cpp +++ b/src/Functions/mortonEncode.cpp @@ -18,7 +18,6 @@ namespace ErrorCodes { extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int ARGUMENT_OUT_OF_BOUND; - extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION; } #define EXTRACT_VECTOR(INDEX) \ diff --git a/src/Functions/tests/gtest_hilbert_encode.cpp b/src/Functions/tests/gtest_hilbert_encode.cpp new file mode 100644 index 00000000000..43e72258355 --- /dev/null +++ b/src/Functions/tests/gtest_hilbert_encode.cpp @@ -0,0 +1,18 @@ +#include +#include +#include + + +TEST(HilbertLookupTable, bitStep1And3Consistnecy) +{ + const size_t bound = 1000; + for (size_t x = 0; x < bound; ++x) + { + for (size_t y = 0; y < bound; ++y) + { + auto hilbert1bit = DB::FunctionHilbertEncode2DWIthLookupTableImpl<1>::encode(x, y); + auto hilbert3bit = DB::FunctionHilbertEncode2DWIthLookupTableImpl<3>::encode(x, y); + ASSERT_EQ(hilbert1bit, hilbert3bit); + } + } +} diff --git a/src/Functions/tests/gtest_hilbert_lookup_table.cpp b/src/Functions/tests/gtest_hilbert_lookup_table.cpp deleted file mode 100644 index f8143a6c47e..00000000000 --- a/src/Functions/tests/gtest_hilbert_lookup_table.cpp +++ /dev/null @@ -1,23 +0,0 @@ -#include -#include - - -void checkLookupTableConsistency(UInt8 x, UInt8 y, UInt8 state) -{ - auto step1 = DB::FunctionHilbertEncode2DWIthLookupTableImpl<1>::encodeFromState(x, y, state); - auto step2 = DB::FunctionHilbertEncode2DWIthLookupTableImpl<3>::encodeFromState(x, y, state); - ASSERT_EQ(step1.hilbert_code, step2.hilbert_code); - ASSERT_EQ(step1.state, step2.state); -} - - -TEST(HilbertLookupTable, bitStep1And3Consistnecy) -{ - for (int x = 0; x < 8; ++x) { - for (int y = 0; y < 8; ++y) { - for (int state = 0; state < 4; ++state) { - checkLookupTableConsistency(x, y, state); - } - } - } -} From e63e7a4fa572602f2b72429b8752a59dd366aaf3 Mon Sep 17 00:00:00 2001 From: Yarik Briukhovetskyi <114298166+yariks5s@users.noreply.github.com> Date: Thu, 22 Feb 2024 23:06:52 +0100 Subject: [PATCH 08/14] style check --- src/Functions/hilbertEncode.cpp | 3 ++- src/Functions/hilbertEncode.h | 9 ++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/Functions/hilbertEncode.cpp b/src/Functions/hilbertEncode.cpp index 8f09ba9531a..0bad6f36b30 100644 --- a/src/Functions/hilbertEncode.cpp +++ b/src/Functions/hilbertEncode.cpp @@ -2,7 +2,8 @@ #include -namespace DB { +namespace DB +{ REGISTER_FUNCTION(HilbertEncode) { diff --git a/src/Functions/hilbertEncode.h b/src/Functions/hilbertEncode.h index 876b3a07b5a..28ad1e72666 100644 --- a/src/Functions/hilbertEncode.h +++ b/src/Functions/hilbertEncode.h @@ -113,7 +113,9 @@ private: if (initial_shift < used_bits) { ++iterations; - } else { + } + else + { initial_shift -= bit_step; } UInt8 state = iterations % 2 == 0 ? 0b01 << getHilbertShift(bit_step) : 0; @@ -169,8 +171,9 @@ public: ColumnUInt64::Container & vec_res = col_res->getData(); vec_res.resize(input_rows_count); - const auto expand = [mask](const UInt64 value, const UInt8 column_id) { - if (mask) + const auto expand = [mask](const UInt64 value, const UInt8 column_id) + { + if z(mask) return value << mask->getColumn(column_id).getUInt(0); return value; }; From c21d8495ba6a73c19d98d14daffa4c4650fb981a Mon Sep 17 00:00:00 2001 From: Artem Mustafin Date: Fri, 23 Feb 2024 21:18:05 +0000 Subject: [PATCH 09/14] add hilbert decode --- src/Functions/FunctionSpaceFillingCurve.h | 142 ++++++++++++ .../FunctionSpaceFillingCurveEncode.h | 70 ------ src/Functions/hilbertDecode.cpp | 55 +++++ src/Functions/hilbertDecode.h | 204 ++++++++++++++++++ src/Functions/hilbertEncode.cpp | 20 +- src/Functions/hilbertEncode.h | 16 +- src/Functions/mortonDecode.cpp | 77 +------ src/Functions/mortonEncode.cpp | 2 +- src/Functions/tests/gtest_hilbert_curve.cpp | 29 +++ src/Functions/tests/gtest_hilbert_encode.cpp | 18 -- 10 files changed, 458 insertions(+), 175 deletions(-) create mode 100644 src/Functions/FunctionSpaceFillingCurve.h delete mode 100644 src/Functions/FunctionSpaceFillingCurveEncode.h create mode 100644 src/Functions/hilbertDecode.cpp create mode 100644 src/Functions/hilbertDecode.h create mode 100644 src/Functions/tests/gtest_hilbert_curve.cpp delete mode 100644 src/Functions/tests/gtest_hilbert_encode.cpp diff --git a/src/Functions/FunctionSpaceFillingCurve.h b/src/Functions/FunctionSpaceFillingCurve.h new file mode 100644 index 00000000000..37c298e9e54 --- /dev/null +++ b/src/Functions/FunctionSpaceFillingCurve.h @@ -0,0 +1,142 @@ +#pragma once +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int ARGUMENT_OUT_OF_BOUND; + extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION; + extern const int ILLEGAL_COLUMN; +} + +class FunctionSpaceFillingCurveEncode: public IFunction +{ +public: + bool isVariadic() const override + { + return true; + } + + size_t getNumberOfArguments() const override + { + return 0; + } + + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; } + + bool useDefaultImplementationForConstants() const override { return true; } + + DataTypePtr getReturnTypeImpl(const DB::DataTypes & arguments) const override + { + size_t vector_start_index = 0; + if (arguments.empty()) + throw Exception(ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION, + "At least one UInt argument is required for function {}", + getName()); + if (WhichDataType(arguments[0]).isTuple()) + { + vector_start_index = 1; + const auto * type_tuple = typeid_cast(arguments[0].get()); + auto tuple_size = type_tuple->getElements().size(); + if (tuple_size != (arguments.size() - 1)) + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, + "Illegal argument {} for function {}, tuple size should be equal to number of UInt arguments", + arguments[0]->getName(), getName()); + for (size_t i = 0; i < tuple_size; i++) + { + if (!WhichDataType(type_tuple->getElement(i)).isNativeUInt()) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal type {} of argument in tuple for function {}, should be a native UInt", + type_tuple->getElement(i)->getName(), getName()); + } + } + + for (size_t i = vector_start_index; i < arguments.size(); i++) + { + const auto & arg = arguments[i]; + if (!WhichDataType(arg).isNativeUInt()) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal type {} of argument of function {}, should be a native UInt", + arg->getName(), getName()); + } + return std::make_shared(); + } +}; + +template +class FunctionSpaceFillingCurveDecode: public IFunction +{ +public: + size_t getNumberOfArguments() const override + { + return 2; + } + + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; } + + ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {0}; } + + DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override + { + UInt64 tuple_size = 0; + const auto * col_const = typeid_cast(arguments[0].column.get()); + if (!col_const) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, + "Illegal column type {} of function {}, should be a constant (UInt or Tuple)", + arguments[0].type->getName(), getName()); + if (!WhichDataType(arguments[1].type).isNativeUInt()) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, + "Illegal column type {} of function {}, should be a native UInt", + arguments[1].type->getName(), getName()); + const auto * mask = typeid_cast(col_const->getDataColumnPtr().get()); + if (mask) + { + tuple_size = mask->tupleSize(); + } + else if (WhichDataType(arguments[0].type).isNativeUInt()) + { + tuple_size = col_const->getUInt(0); + } + else + throw Exception(ErrorCodes::ILLEGAL_COLUMN, + "Illegal column type {} of function {}, should be UInt or Tuple", + arguments[0].type->getName(), getName()); + if (tuple_size > max_dimensions || tuple_size < 1) + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, + "Illegal first argument for function {}, should be a number in range 1-{} or a Tuple of such size", + getName(), String{max_dimensions}); + if (mask) + { + const auto * type_tuple = typeid_cast(arguments[0].type.get()); + for (size_t i = 0; i < tuple_size; i++) + { + if (!WhichDataType(type_tuple->getElement(i)).isNativeUInt()) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal type {} of argument in tuple for function {}, should be a native UInt", + type_tuple->getElement(i)->getName(), getName()); + auto ratio = mask->getColumn(i).getUInt(0); + if (ratio > max_ratio || ratio < min_ratio) + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, + "Illegal argument {} in tuple for function {}, should be a number in range {}-{}", + ratio, getName(), String{min_ratio}, String{max_ratio}); + } + } + DataTypes types(tuple_size); + for (size_t i = 0; i < tuple_size; i++) + { + types[i] = std::make_shared(); + } + return std::make_shared(types); + } +}; + +} diff --git a/src/Functions/FunctionSpaceFillingCurveEncode.h b/src/Functions/FunctionSpaceFillingCurveEncode.h deleted file mode 100644 index 399010bad54..00000000000 --- a/src/Functions/FunctionSpaceFillingCurveEncode.h +++ /dev/null @@ -1,70 +0,0 @@ -#pragma once -#include -#include -#include - -namespace DB -{ - -namespace ErrorCodes -{ - extern const int ILLEGAL_TYPE_OF_ARGUMENT; - extern const int ARGUMENT_OUT_OF_BOUND; - extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION; -} - -class FunctionSpaceFillingCurveEncode: public IFunction -{ -public: - bool isVariadic() const override - { - return true; - } - - size_t getNumberOfArguments() const override - { - return 0; - } - - bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; } - - bool useDefaultImplementationForConstants() const override { return true; } - - DataTypePtr getReturnTypeImpl(const DB::DataTypes & arguments) const override - { - size_t vector_start_index = 0; - if (arguments.empty()) - throw Exception(ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION, - "At least one UInt argument is required for function {}", - getName()); - if (WhichDataType(arguments[0]).isTuple()) - { - vector_start_index = 1; - const auto * type_tuple = typeid_cast(arguments[0].get()); - auto tuple_size = type_tuple->getElements().size(); - if (tuple_size != (arguments.size() - 1)) - throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, - "Illegal argument {} for function {}, tuple size should be equal to number of UInt arguments", - arguments[0]->getName(), getName()); - for (size_t i = 0; i < tuple_size; i++) - { - if (!WhichDataType(type_tuple->getElement(i)).isNativeUInt()) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "Illegal type {} of argument in tuple for function {}, should be a native UInt", - type_tuple->getElement(i)->getName(), getName()); - } - } - - for (size_t i = vector_start_index; i < arguments.size(); i++) - { - const auto & arg = arguments[i]; - if (!WhichDataType(arg).isNativeUInt()) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "Illegal type {} of argument of function {}, should be a native UInt", - arg->getName(), getName()); - } - return std::make_shared(); - } -}; - -} diff --git a/src/Functions/hilbertDecode.cpp b/src/Functions/hilbertDecode.cpp new file mode 100644 index 00000000000..7bace81ba5c --- /dev/null +++ b/src/Functions/hilbertDecode.cpp @@ -0,0 +1,55 @@ +#include +#include + + +namespace DB +{ + +REGISTER_FUNCTION(HilbertDecode) +{ + factory.registerFunction(FunctionDocumentation{ + .description=R"( +Decodes Hilbert Curve code into the corresponding unsigned integer tuple + +The function has two modes of operation: +- Simple +- Expanded + +Simple: accepts a resulting tuple size as a first argument and the code as a second argument. +[example:simple] +Will decode into: `(8, 0)` +The resulting tuple size cannot be more than 2 + +Expanded: accepts a range mask (tuple) as a first argument and the code as a second argument. +Each number in mask configures the amount of bits that corresponding argument will be shifted right +[example:range_shrank] +Note: see hilbertEncode() docs on why range change might be beneficial. +Still limited to 2 numbers at most. + +Hilbert code for one argument is always the argument itself (as a tuple). +[example:identity] +Produces: `(1)` + +You can shrink one argument too: +[example:identity_shrank] +Produces: `(128)` + +The function accepts a column of codes as a second argument: +[example:from_table] + +The range tuple must be a constant: +[example:from_table_range] +)", + .examples{ + {"simple", "SELECT hilbertDecode(2, 64)", ""}, + {"range_shrank", "SELECT hilbertDecode((1,2), 1572864)", ""}, + {"identity", "SELECT hilbertDecode(1, 1)", ""}, + {"identity_shrank", "SELECT hilbertDecode(tuple(2), 512)", ""}, + {"from_table", "SELECT hilbertDecode(2, code) FROM table", ""}, + {"from_table_range", "SELECT hilbertDecode((1,2), code) FROM table", ""}, + }, + .categories {"Hilbert coding", "Hilbert Curve"} + }); +} + +} diff --git a/src/Functions/hilbertDecode.h b/src/Functions/hilbertDecode.h new file mode 100644 index 00000000000..783b26c174f --- /dev/null +++ b/src/Functions/hilbertDecode.h @@ -0,0 +1,204 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int ARGUMENT_OUT_OF_BOUND; +} + +namespace HilbertDetails +{ + +template +class HilbertDecodeLookupTable +{ +public: + constexpr static UInt8 LOOKUP_TABLE[0] = {}; +}; + +template <> +class HilbertDecodeLookupTable<1> +{ +public: + constexpr static UInt8 LOOKUP_TABLE[16] = { + 4, 1, 3, 10, + 0, 6, 7, 13, + 15, 9, 8, 2, + 11, 14, 12, 5 + }; +}; + +template <> +class HilbertDecodeLookupTable<3> +{ +public: + constexpr static UInt8 LOOKUP_TABLE[256] = { + 64, 1, 9, 136, 16, 88, 89, 209, 18, 90, 91, 211, 139, 202, 194, 67, 4, 76, 77, 197, 70, 7, + 15, 142, 86, 23, 31, 158, 221, 149, 148, 28, 36, 108, 109, 229, 102, 39, 47, 174, 118, 55, + 63, 190, 253, 181, 180, 60, 187, 250, 242, 115, 235, 163, 162, 42, 233, 161, 160, 40, 112, + 49, 57, 184, 0, 72, 73, 193, 66, 3, 11, 138, 82, 19, 27, 154, 217, 145, 144, 24, 96, 33, + 41, 168, 48, 120, 121, 241, 50, 122, 123, 243, 171, 234, 226, 99, 100, 37, 45, 172, 52, + 124, 125, 245, 54, 126, 127, 247, 175, 238, 230, 103, 223, 151, 150, 30, 157, 220, 212, 85, + 141, 204, 196, 69, 6, 78, 79, 199, 255, 183, 182, 62, 189, 252, 244, 117, 173, 236, 228, + 101, 38, 110, 111, 231, 159, 222, 214, 87, 207, 135, 134, 14, 205, 133, 132, 12, 84, 21, + 29, 156, 155, 218, 210, 83, 203, 131, 130, 10, 201, 129, 128, 8, 80, 17, 25, 152, 32, 104, + 105, 225, 98, 35, 43, 170, 114, 51, 59, 186, 249, 177, 176, 56, 191, 254, 246, 119, 239, + 167, 166, 46, 237, 165, 164, 44, 116, 53, 61, 188, 251, 179, 178, 58, 185, 248, 240, 113, + 169, 232, 224, 97, 34, 106, 107, 227, 219, 147, 146, 26, 153, 216, 208, 81, 137, 200, 192, + 65, 2, 74, 75, 195, 68, 5, 13, 140, 20, 92, 93, 213, 22, 94, 95, 215, 143, 206, 198, 71 + }; +}; + +} + + +template +class FunctionHilbertDecode2DWIthLookupTableImpl +{ + static_assert(bit_step <= 3, "bit_step should not be more than 3 to fit in UInt8"); +public: + static std::tuple decode(UInt64 hilbert_code) + { + UInt64 x = 0; + UInt64 y = 0; + const auto leading_zeros_count = getLeadingZeroBits(hilbert_code); + const auto used_bits = std::numeric_limits::digits - leading_zeros_count; + + auto [current_shift, state] = getInitialShiftAndState(used_bits); + + while (current_shift >= 0) + { + const UInt8 hilbert_bits = (hilbert_code >> current_shift) & HILBERT_MASK; + const auto [x_bits, y_bits] = getCodeAndUpdateState(hilbert_bits, state); + x |= (x_bits << (current_shift >> 1)); + y |= (y_bits << (current_shift >> 1)); + current_shift -= getHilbertShift(bit_step); + } + + return {x, y}; + } + +private: + // for bit_step = 3 + // LOOKUP_TABLE[SSHHHHHH] = SSXXXYYY + // where SS - 2 bits for state, XXX - 3 bits of x, YYY - 3 bits of y + // State is rotation of curve on every step, left/up/right/down - therefore 2 bits + static std::pair getCodeAndUpdateState(UInt8 hilbert_bits, UInt8& state) + { + const UInt8 table_index = state | hilbert_bits; + const auto table_code = HilbertDetails::HilbertDecodeLookupTable::LOOKUP_TABLE[table_index]; + state = table_code & STATE_MASK; + const UInt64 x_bits = (table_code & X_MASK) >> bit_step; + const UInt64 y_bits = table_code & Y_MASK; + return {x_bits, y_bits}; + } + + // hilbert code is double size of input values + static constexpr UInt8 getHilbertShift(UInt8 shift) + { + return shift << 1; + } + + static std::pair getInitialShiftAndState(UInt8 used_bits) + { + const UInt8 hilbert_shift = getHilbertShift(bit_step); + UInt8 iterations = used_bits / hilbert_shift; + Int8 initial_shift = iterations * hilbert_shift; + if (initial_shift < used_bits) + { + ++iterations; + } + else + { + initial_shift -= hilbert_shift; + } + UInt8 state = iterations % 2 == 0 ? 0b01 << hilbert_shift : 0; + return {initial_shift, state}; + } + + constexpr static UInt8 STEP_MASK = (1 << bit_step) - 1; + constexpr static UInt8 HILBERT_MASK = (1 << getHilbertShift(bit_step)) - 1; + constexpr static UInt8 STATE_MASK = 0b11 << getHilbertShift(bit_step); + constexpr static UInt8 Y_MASK = STEP_MASK; + constexpr static UInt8 X_MASK = STEP_MASK << bit_step; +}; + + +class FunctionHilbertDecode : public FunctionSpaceFillingCurveDecode<2, 0, 32> +{ +public: + static constexpr auto name = "hilbertDecode"; + static FunctionPtr create(ContextPtr) + { + return std::make_shared(); + } + + String getName() const override { return name; } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override + { + size_t nd; + const auto * col_const = typeid_cast(arguments[0].column.get()); + const auto * mask = typeid_cast(col_const->getDataColumnPtr().get()); + if (mask) + nd = mask->tupleSize(); + else + nd = col_const->getUInt(0); + auto non_const_arguments = arguments; + non_const_arguments[1].column = non_const_arguments[1].column->convertToFullColumnIfConst(); + const ColumnPtr & col_code = non_const_arguments[1].column; + Columns tuple_columns(nd); + + const auto shrink = [mask](const UInt64 value, const UInt8 column_id) { + if (mask) + return value >> mask->getColumn(column_id).getUInt(0); + return value; + }; + + auto col0 = ColumnUInt64::create(); + auto & vec0 = col0->getData(); + vec0.resize(input_rows_count); + + if (nd == 1) + { + for (size_t i = 0; i < input_rows_count; i++) + { + vec0[i] = shrink(col_code->getUInt(i), 0); + } + tuple_columns[0] = std::move(col0); + return ColumnTuple::create(tuple_columns); + } + + auto col1 = ColumnUInt64::create(); + auto & vec1 = col1->getData(); + vec1.resize(input_rows_count); + + if (nd == 2) + { + for (size_t i = 0; i < input_rows_count; i++) + { + const auto res = FunctionHilbertDecode2DWIthLookupTableImpl<3>::decode(col_code->getUInt(i)); + vec0[i] = shrink(std::get<0>(res), 0); + vec1[i] = shrink(std::get<1>(res), 1); + } + tuple_columns[0] = std::move(col0); + return ColumnTuple::create(tuple_columns); + } + + return ColumnTuple::create(tuple_columns); + } +}; + +} diff --git a/src/Functions/hilbertEncode.cpp b/src/Functions/hilbertEncode.cpp index 0bad6f36b30..e98628a5a44 100644 --- a/src/Functions/hilbertEncode.cpp +++ b/src/Functions/hilbertEncode.cpp @@ -8,7 +8,7 @@ namespace DB REGISTER_FUNCTION(HilbertEncode) { factory.registerFunction(FunctionDocumentation{ - .description=R"( + .description=R"( Calculates code for Hilbert Curve for a list of unsigned integers The function has two modes of operation: @@ -44,15 +44,15 @@ Please note that you can fit only so much bits of information into Morton code a Two arguments will have a range of maximum 2^32 (64/2) each All overflow will be clamped to zero )", - .examples{ - {"simple", "SELECT hilbertEncode(1, 2, 3)", ""}, - {"range_expanded", "SELECT hilbertEncode((1,6), 1024, 16)", ""}, - {"identity", "SELECT hilbertEncode(1)", ""}, - {"identity_expanded", "SELECT hilbertEncode(tuple(2), 128)", ""}, - {"from_table", "SELECT hilbertEncode(n1, n2) FROM table", ""}, - {"from_table_range", "SELECT hilbertEncode((1,2), n1, n2) FROM table", ""}, - }, - .categories {"Hilbert coding", "Hilbert Curve"} + .examples{ + {"simple", "SELECT hilbertEncode(1, 2, 3)", ""}, + {"range_expanded", "SELECT hilbertEncode((1,6), 1024, 16)", ""}, + {"identity", "SELECT hilbertEncode(1)", ""}, + {"identity_expanded", "SELECT hilbertEncode(tuple(2), 128)", ""}, + {"from_table", "SELECT hilbertEncode(n1, n2) FROM table", ""}, + {"from_table_range", "SELECT hilbertEncode((1,2), n1, n2) FROM table", ""}, + }, + .categories {"Hilbert coding", "Hilbert Curve"} }); } diff --git a/src/Functions/hilbertEncode.h b/src/Functions/hilbertEncode.h index 28ad1e72666..7dc7ec8fdf2 100644 --- a/src/Functions/hilbertEncode.h +++ b/src/Functions/hilbertEncode.h @@ -3,7 +3,7 @@ #include #include #include -#include +#include #include #include #include @@ -22,14 +22,14 @@ namespace HilbertDetails { template -class HilbertLookupTable +class HilbertEncodeLookupTable { public: constexpr static UInt8 LOOKUP_TABLE[0] = {}; }; template <> -class HilbertLookupTable<1> +class HilbertEncodeLookupTable<1> { public: constexpr static UInt8 LOOKUP_TABLE[16] = { @@ -41,7 +41,7 @@ public: }; template <> -class HilbertLookupTable<3> +class HilbertEncodeLookupTable<3> { public: constexpr static UInt8 LOOKUP_TABLE[256] = { @@ -64,9 +64,10 @@ public: } -template +template class FunctionHilbertEncode2DWIthLookupTableImpl { + static_assert(bit_step <= 3, "bit_step should not be more than 3 to fit in UInt8"); public: static UInt64 encode(UInt64 x, UInt64 y) { @@ -89,13 +90,14 @@ public: } private: + // for bit_step = 3 // LOOKUP_TABLE[SSXXXYYY] = SSHHHHHH // where SS - 2 bits for state, XXX - 3 bits of x, YYY - 3 bits of y // State is rotation of curve on every step, left/up/right/down - therefore 2 bits static UInt64 getCodeAndUpdateState(UInt8 x_bits, UInt8 y_bits, UInt8& state) { const UInt8 table_index = state | (x_bits << bit_step) | y_bits; - const auto table_code = HilbertDetails::HilbertLookupTable::LOOKUP_TABLE[table_index]; + const auto table_code = HilbertDetails::HilbertEncodeLookupTable::LOOKUP_TABLE[table_index]; state = table_code & STATE_MASK; return table_code & HILBERT_MASK; } @@ -173,7 +175,7 @@ public: const auto expand = [mask](const UInt64 value, const UInt8 column_id) { - if z(mask) + if (mask) return value << mask->getColumn(column_id).getUInt(0); return value; }; diff --git a/src/Functions/mortonDecode.cpp b/src/Functions/mortonDecode.cpp index f65f38fb097..7da1d1084eb 100644 --- a/src/Functions/mortonDecode.cpp +++ b/src/Functions/mortonDecode.cpp @@ -1,10 +1,11 @@ -#include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include +#include +#include #include #include @@ -186,7 +187,7 @@ constexpr auto MortonND_5D_Dec = mortonnd::MortonNDLutDecoder<5, 12, 8>(); constexpr auto MortonND_6D_Dec = mortonnd::MortonNDLutDecoder<6, 10, 8>(); constexpr auto MortonND_7D_Dec = mortonnd::MortonNDLutDecoder<7, 9, 8>(); constexpr auto MortonND_8D_Dec = mortonnd::MortonNDLutDecoder<8, 8, 8>(); -class FunctionMortonDecode : public IFunction +class FunctionMortonDecode : public FunctionSpaceFillingCurveDecode<8, 1, 8> { public: static constexpr auto name = "mortonDecode"; @@ -200,68 +201,6 @@ public: return name; } - size_t getNumberOfArguments() const override - { - return 2; - } - - bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; } - - ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {0}; } - - DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override - { - UInt64 tuple_size = 0; - const auto * col_const = typeid_cast(arguments[0].column.get()); - if (!col_const) - throw Exception(ErrorCodes::ILLEGAL_COLUMN, - "Illegal column type {} of function {}, should be a constant (UInt or Tuple)", - arguments[0].type->getName(), getName()); - if (!WhichDataType(arguments[1].type).isNativeUInt()) - throw Exception(ErrorCodes::ILLEGAL_COLUMN, - "Illegal column type {} of function {}, should be a native UInt", - arguments[1].type->getName(), getName()); - const auto * mask = typeid_cast(col_const->getDataColumnPtr().get()); - if (mask) - { - tuple_size = mask->tupleSize(); - } - else if (WhichDataType(arguments[0].type).isNativeUInt()) - { - tuple_size = col_const->getUInt(0); - } - else - throw Exception(ErrorCodes::ILLEGAL_COLUMN, - "Illegal column type {} of function {}, should be UInt or Tuple", - arguments[0].type->getName(), getName()); - if (tuple_size > 8 || tuple_size < 1) - throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, - "Illegal first argument for function {}, should be a number in range 1-8 or a Tuple of such size", - getName()); - if (mask) - { - const auto * type_tuple = typeid_cast(arguments[0].type.get()); - for (size_t i = 0; i < tuple_size; i++) - { - if (!WhichDataType(type_tuple->getElement(i)).isNativeUInt()) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "Illegal type {} of argument in tuple for function {}, should be a native UInt", - type_tuple->getElement(i)->getName(), getName()); - auto ratio = mask->getColumn(i).getUInt(0); - if (ratio > 8 || ratio < 1) - throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, - "Illegal argument {} in tuple for function {}, should be a number in range 1-8", - ratio, getName()); - } - } - DataTypes types(tuple_size); - for (size_t i = 0; i < tuple_size; i++) - { - types[i] = std::make_shared(); - } - return std::make_shared(types); - } - static UInt64 shrink(UInt64 ratio, UInt64 value) { switch (ratio) // NOLINT(bugprone-switch-missing-default-case) diff --git a/src/Functions/mortonEncode.cpp b/src/Functions/mortonEncode.cpp index 63cabe5b77f..5ae5fd41b28 100644 --- a/src/Functions/mortonEncode.cpp +++ b/src/Functions/mortonEncode.cpp @@ -3,7 +3,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/Functions/tests/gtest_hilbert_curve.cpp b/src/Functions/tests/gtest_hilbert_curve.cpp new file mode 100644 index 00000000000..108ab6a6ccf --- /dev/null +++ b/src/Functions/tests/gtest_hilbert_curve.cpp @@ -0,0 +1,29 @@ +#include +#include +#include + + +TEST(HilbertLookupTable, EncodeBit1And3Consistnecy) +{ + const size_t bound = 1000; + for (size_t x = 0; x < bound; ++x) + { + for (size_t y = 0; y < bound; ++y) + { + auto hilbert1bit = DB::FunctionHilbertEncode2DWIthLookupTableImpl<1>::encode(x, y); + auto hilbert3bit = DB::FunctionHilbertEncode2DWIthLookupTableImpl<3>::encode(x, y); + ASSERT_EQ(hilbert1bit, hilbert3bit); + } + } +} + +TEST(HilbertLookupTable, DecodeBit1And3Consistnecy) +{ + const size_t bound = 1000 * 1000; + for (size_t hilbert_code = 0; hilbert_code < bound; ++hilbert_code) + { + auto res1 = DB::FunctionHilbertDecode2DWIthLookupTableImpl<1>::decode(hilbert_code); + auto res3 = DB::FunctionHilbertDecode2DWIthLookupTableImpl<3>::decode(hilbert_code); + ASSERT_EQ(res1, res3); + } +} diff --git a/src/Functions/tests/gtest_hilbert_encode.cpp b/src/Functions/tests/gtest_hilbert_encode.cpp deleted file mode 100644 index 43e72258355..00000000000 --- a/src/Functions/tests/gtest_hilbert_encode.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include -#include -#include - - -TEST(HilbertLookupTable, bitStep1And3Consistnecy) -{ - const size_t bound = 1000; - for (size_t x = 0; x < bound; ++x) - { - for (size_t y = 0; y < bound; ++y) - { - auto hilbert1bit = DB::FunctionHilbertEncode2DWIthLookupTableImpl<1>::encode(x, y); - auto hilbert3bit = DB::FunctionHilbertEncode2DWIthLookupTableImpl<3>::encode(x, y); - ASSERT_EQ(hilbert1bit, hilbert3bit); - } - } -} From 5fc6020540c4766ad57befe198f828e590f99403 Mon Sep 17 00:00:00 2001 From: Artem Mustafin Date: Fri, 23 Feb 2024 21:35:49 +0000 Subject: [PATCH 10/14] style --- src/Functions/hilbertDecode.h | 9 ++------- src/Functions/mortonDecode.cpp | 7 ------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/src/Functions/hilbertDecode.h b/src/Functions/hilbertDecode.h index 783b26c174f..326c5d7bdaf 100644 --- a/src/Functions/hilbertDecode.h +++ b/src/Functions/hilbertDecode.h @@ -12,12 +12,6 @@ namespace DB { -namespace ErrorCodes -{ - extern const int ILLEGAL_TYPE_OF_ARGUMENT; - extern const int ARGUMENT_OUT_OF_BOUND; -} - namespace HilbertDetails { @@ -161,7 +155,8 @@ public: const ColumnPtr & col_code = non_const_arguments[1].column; Columns tuple_columns(nd); - const auto shrink = [mask](const UInt64 value, const UInt8 column_id) { + const auto shrink = [mask](const UInt64 value, const UInt8 column_id) + { if (mask) return value >> mask->getColumn(column_id).getUInt(0); return value; diff --git a/src/Functions/mortonDecode.cpp b/src/Functions/mortonDecode.cpp index 7da1d1084eb..2b7b7b4f2e7 100644 --- a/src/Functions/mortonDecode.cpp +++ b/src/Functions/mortonDecode.cpp @@ -16,13 +16,6 @@ namespace DB { -namespace ErrorCodes -{ - extern const int ILLEGAL_TYPE_OF_ARGUMENT; - extern const int ILLEGAL_COLUMN; - extern const int ARGUMENT_OUT_OF_BOUND; -} - // NOLINTBEGIN(bugprone-switch-missing-default-case) #define EXTRACT_VECTOR(INDEX) \ From 695ea5f0294d29ff85fdeba9f034446f5cb20dbe Mon Sep 17 00:00:00 2001 From: Yarik Briukhovetskyi <114298166+yariks5s@users.noreply.github.com> Date: Tue, 5 Mar 2024 13:53:14 +0100 Subject: [PATCH 11/14] reload ci From cf489bd907f5e74d2dd357621ab20a9e1d092ba6 Mon Sep 17 00:00:00 2001 From: Yarik Briukhovetskyi <114298166+yariks5s@users.noreply.github.com> Date: Tue, 19 Mar 2024 19:39:33 +0100 Subject: [PATCH 12/14] "of function" -> "for function" --- src/Functions/FunctionSpaceFillingCurve.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Functions/FunctionSpaceFillingCurve.h b/src/Functions/FunctionSpaceFillingCurve.h index 37c298e9e54..9ce8fa6584e 100644 --- a/src/Functions/FunctionSpaceFillingCurve.h +++ b/src/Functions/FunctionSpaceFillingCurve.h @@ -65,7 +65,7 @@ public: const auto & arg = arguments[i]; if (!WhichDataType(arg).isNativeUInt()) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "Illegal type {} of argument of function {}, should be a native UInt", + "Illegal type {} of argument for function {}, should be a native UInt", arg->getName(), getName()); } return std::make_shared(); @@ -91,11 +91,11 @@ public: const auto * col_const = typeid_cast(arguments[0].column.get()); if (!col_const) throw Exception(ErrorCodes::ILLEGAL_COLUMN, - "Illegal column type {} of function {}, should be a constant (UInt or Tuple)", + "Illegal column type {} for function {}, should be a constant (UInt or Tuple)", arguments[0].type->getName(), getName()); if (!WhichDataType(arguments[1].type).isNativeUInt()) throw Exception(ErrorCodes::ILLEGAL_COLUMN, - "Illegal column type {} of function {}, should be a native UInt", + "Illegal column type {} for function {}, should be a native UInt", arguments[1].type->getName(), getName()); const auto * mask = typeid_cast(col_const->getDataColumnPtr().get()); if (mask) @@ -108,7 +108,7 @@ public: } else throw Exception(ErrorCodes::ILLEGAL_COLUMN, - "Illegal column type {} of function {}, should be UInt or Tuple", + "Illegal column type {} for function {}, should be UInt or Tuple", arguments[0].type->getName(), getName()); if (tuple_size > max_dimensions || tuple_size < 1) throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, From 6ab2e083864d4dc90978250610b803292f4bd660 Mon Sep 17 00:00:00 2001 From: Artem Mustafin Date: Tue, 23 Apr 2024 23:53:23 +0000 Subject: [PATCH 13/14] add bit_step=2 and some tests --- src/Functions/hilbertDecode.h | 61 +++++++++++------ src/Functions/hilbertEncode.h | 76 ++++++++++++++++----- src/Functions/tests/gtest_hilbert_curve.cpp | 56 ++++++++++++++- 3 files changed, 153 insertions(+), 40 deletions(-) diff --git a/src/Functions/hilbertDecode.h b/src/Functions/hilbertDecode.h index 326c5d7bdaf..4c46143399b 100644 --- a/src/Functions/hilbertDecode.h +++ b/src/Functions/hilbertDecode.h @@ -34,24 +34,43 @@ public: }; }; +template <> +class HilbertDecodeLookupTable<2> +{ +public: + constexpr static UInt8 LOOKUP_TABLE[64] = { + 0, 20, 21, 49, 18, 3, 7, 38, + 26, 11, 15, 46, 61, 41, 40, 12, + 16, 1, 5, 36, 8, 28, 29, 57, + 10, 30, 31, 59, 39, 54, 50, 19, + 47, 62, 58, 27, 55, 35, 34, 6, + 53, 33, 32, 4, 24, 9, 13, 44, + 63, 43, 42, 14, 45, 60, 56, 25, + 37, 52, 48, 17, 2, 22, 23, 51 + }; +}; + template <> class HilbertDecodeLookupTable<3> { public: constexpr static UInt8 LOOKUP_TABLE[256] = { - 64, 1, 9, 136, 16, 88, 89, 209, 18, 90, 91, 211, 139, 202, 194, 67, 4, 76, 77, 197, 70, 7, - 15, 142, 86, 23, 31, 158, 221, 149, 148, 28, 36, 108, 109, 229, 102, 39, 47, 174, 118, 55, - 63, 190, 253, 181, 180, 60, 187, 250, 242, 115, 235, 163, 162, 42, 233, 161, 160, 40, 112, - 49, 57, 184, 0, 72, 73, 193, 66, 3, 11, 138, 82, 19, 27, 154, 217, 145, 144, 24, 96, 33, - 41, 168, 48, 120, 121, 241, 50, 122, 123, 243, 171, 234, 226, 99, 100, 37, 45, 172, 52, - 124, 125, 245, 54, 126, 127, 247, 175, 238, 230, 103, 223, 151, 150, 30, 157, 220, 212, 85, - 141, 204, 196, 69, 6, 78, 79, 199, 255, 183, 182, 62, 189, 252, 244, 117, 173, 236, 228, - 101, 38, 110, 111, 231, 159, 222, 214, 87, 207, 135, 134, 14, 205, 133, 132, 12, 84, 21, - 29, 156, 155, 218, 210, 83, 203, 131, 130, 10, 201, 129, 128, 8, 80, 17, 25, 152, 32, 104, - 105, 225, 98, 35, 43, 170, 114, 51, 59, 186, 249, 177, 176, 56, 191, 254, 246, 119, 239, - 167, 166, 46, 237, 165, 164, 44, 116, 53, 61, 188, 251, 179, 178, 58, 185, 248, 240, 113, - 169, 232, 224, 97, 34, 106, 107, 227, 219, 147, 146, 26, 153, 216, 208, 81, 137, 200, 192, - 65, 2, 74, 75, 195, 68, 5, 13, 140, 20, 92, 93, 213, 22, 94, 95, 215, 143, 206, 198, 71 + 64, 1, 9, 136, 16, 88, 89, 209, 18, 90, 91, 211, 139, 202, 194, 67, + 4, 76, 77, 197, 70, 7, 15, 142, 86, 23, 31, 158, 221, 149, 148, 28, + 36, 108, 109, 229, 102, 39, 47, 174, 118, 55, 63, 190, 253, 181, 180, 60, + 187, 250, 242, 115, 235, 163, 162, 42, 233, 161, 160, 40, 112, 49, 57, 184, + 0, 72, 73, 193, 66, 3, 11, 138, 82, 19, 27, 154, 217, 145, 144, 24, + 96, 33, 41, 168, 48, 120, 121, 241, 50, 122, 123, 243, 171, 234, 226, 99, + 100, 37, 45, 172, 52, 124, 125, 245, 54, 126, 127, 247, 175, 238, 230, 103, + 223, 151, 150, 30, 157, 220, 212, 85, 141, 204, 196, 69, 6, 78, 79, 199, + 255, 183, 182, 62, 189, 252, 244, 117, 173, 236, 228, 101, 38, 110, 111, 231, + 159, 222, 214, 87, 207, 135, 134, 14, 205, 133, 132, 12, 84, 21, 29, 156, + 155, 218, 210, 83, 203, 131, 130, 10, 201, 129, 128, 8, 80, 17, 25, 152, + 32, 104, 105, 225, 98, 35, 43, 170, 114, 51, 59, 186, 249, 177, 176, 56, + 191, 254, 246, 119, 239, 167, 166, 46, 237, 165, 164, 44, 116, 53, 61, 188, + 251, 179, 178, 58, 185, 248, 240, 113, 169, 232, 224, 97, 34, 106, 107, 227, + 219, 147, 146, 26, 153, 216, 208, 81, 137, 200, 192, 65, 2, 74, 75, 195, + 68, 5, 13, 140, 20, 92, 93, 213, 22, 94, 95, 215, 143, 206, 198, 71 }; }; @@ -107,26 +126,28 @@ private: static std::pair getInitialShiftAndState(UInt8 used_bits) { - const UInt8 hilbert_shift = getHilbertShift(bit_step); - UInt8 iterations = used_bits / hilbert_shift; - Int8 initial_shift = iterations * hilbert_shift; + UInt8 iterations = used_bits / HILBERT_SHIFT; + Int8 initial_shift = iterations * HILBERT_SHIFT; if (initial_shift < used_bits) { ++iterations; } else { - initial_shift -= hilbert_shift; + initial_shift -= HILBERT_SHIFT; } - UInt8 state = iterations % 2 == 0 ? 0b01 << hilbert_shift : 0; + UInt8 state = iterations % 2 == 0 ? LEFT_STATE : DEFAULT_STATE; return {initial_shift, state}; } constexpr static UInt8 STEP_MASK = (1 << bit_step) - 1; - constexpr static UInt8 HILBERT_MASK = (1 << getHilbertShift(bit_step)) - 1; - constexpr static UInt8 STATE_MASK = 0b11 << getHilbertShift(bit_step); + constexpr static UInt8 HILBERT_SHIFT = getHilbertShift(bit_step); + constexpr static UInt8 HILBERT_MASK = (1 << HILBERT_SHIFT) - 1; + constexpr static UInt8 STATE_MASK = 0b11 << HILBERT_SHIFT; constexpr static UInt8 Y_MASK = STEP_MASK; constexpr static UInt8 X_MASK = STEP_MASK << bit_step; + constexpr static UInt8 LEFT_STATE = 0b01 << HILBERT_SHIFT; + constexpr static UInt8 DEFAULT_STATE = bit_step % 2 == 0 ? LEFT_STATE : 0; }; diff --git a/src/Functions/hilbertEncode.h b/src/Functions/hilbertEncode.h index 7dc7ec8fdf2..825065b34d3 100644 --- a/src/Functions/hilbertEncode.h +++ b/src/Functions/hilbertEncode.h @@ -7,6 +7,7 @@ #include #include #include +#include namespace DB @@ -40,24 +41,44 @@ public: }; }; +template <> +class HilbertEncodeLookupTable<2> +{ +public: + constexpr static UInt8 LOOKUP_TABLE[64] = { + 0, 51, 20, 5, 17, 18, 39, 6, + 46, 45, 24, 9, 15, 60, 43, 10, + 16, 1, 62, 31, 35, 2, 61, 44, + 4, 55, 8, 59, 21, 22, 25, 26, + 42, 41, 38, 37, 11, 56, 7, 52, + 28, 13, 50, 19, 47, 14, 49, 32, + 58, 27, 12, 63, 57, 40, 29, 30, + 54, 23, 34, 33, 53, 36, 3, 48 + }; +}; + + template <> class HilbertEncodeLookupTable<3> { public: constexpr static UInt8 LOOKUP_TABLE[256] = { - 64, 1, 206, 79, 16, 211, 84, 21, 131, 2, 205, 140, 81, 82, 151, 22, 4, 199, 8, 203, 158, - 157, 88, 25, 69, 70, 73, 74, 31, 220, 155, 26, 186, 185, 182, 181, 32, 227, 100, 37, 59, - 248, 55, 244, 97, 98, 167, 38, 124, 61, 242, 115, 174, 173, 104, 41, 191, 62, 241, 176, 47, - 236, 171, 42, 0, 195, 68, 5, 250, 123, 60, 255, 65, 66, 135, 6, 249, 184, 125, 126, 142, - 141, 72, 9, 246, 119, 178, 177, 15, 204, 139, 10, 245, 180, 51, 240, 80, 17, 222, 95, 96, - 33, 238, 111, 147, 18, 221, 156, 163, 34, 237, 172, 20, 215, 24, 219, 36, 231, 40, 235, 85, - 86, 89, 90, 101, 102, 105, 106, 170, 169, 166, 165, 154, 153, 150, 149, 43, 232, 39, 228, - 27, 216, 23, 212, 108, 45, 226, 99, 92, 29, 210, 83, 175, 46, 225, 160, 159, 30, 209, 144, - 48, 243, 116, 53, 202, 75, 12, 207, 113, 114, 183, 54, 201, 136, 77, 78, 190, 189, 120, 57, - 198, 71, 130, 129, 63, 252, 187, 58, 197, 132, 3, 192, 234, 107, 44, 239, 112, 49, 254, - 127, 233, 168, 109, 110, 179, 50, 253, 188, 230, 103, 162, 161, 52, 247, 56, 251, 229, 164, - 35, 224, 117, 118, 121, 122, 218, 91, 28, 223, 138, 137, 134, 133, 217, 152, 93, 94, 11, - 200, 7, 196, 214, 87, 146, 145, 76, 13, 194, 67, 213, 148, 19, 208, 143, 14, 193, 128, + 64, 1, 206, 79, 16, 211, 84, 21, 131, 2, 205, 140, 81, 82, 151, 22, 4, + 199, 8, 203, 158, 157, 88, 25, 69, 70, 73, 74, 31, 220, 155, 26, 186, + 185, 182, 181, 32, 227, 100, 37, 59, 248, 55, 244, 97, 98, 167, 38, 124, + 61, 242, 115, 174, 173, 104, 41, 191, 62, 241, 176, 47, 236, 171, 42, 0, + 195, 68, 5, 250, 123, 60, 255, 65, 66, 135, 6, 249, 184, 125, 126, 142, + 141, 72, 9, 246, 119, 178, 177, 15, 204, 139, 10, 245, 180, 51, 240, 80, + 17, 222, 95, 96, 33, 238, 111, 147, 18, 221, 156, 163, 34, 237, 172, 20, + 215, 24, 219, 36, 231, 40, 235, 85, 86, 89, 90, 101, 102, 105, 106, 170, + 169, 166, 165, 154, 153, 150, 149, 43, 232, 39, 228, 27, 216, 23, 212, 108, + 45, 226, 99, 92, 29, 210, 83, 175, 46, 225, 160, 159, 30, 209, 144, 48, + 243, 116, 53, 202, 75, 12, 207, 113, 114, 183, 54, 201, 136, 77, 78, 190, + 189, 120, 57, 198, 71, 130, 129, 63, 252, 187, 58, 197, 132, 3, 192, 234, + 107, 44, 239, 112, 49, 254, 127, 233, 168, 109, 110, 179, 50, 253, 188, 230, + 103, 162, 161, 52, 247, 56, 251, 229, 164, 35, 224, 117, 118, 121, 122, 218, + 91, 28, 223, 138, 137, 134, 133, 217, 152, 93, 94, 11, 200, 7, 196, 214, + 87, 146, 145, 76, 13, 194, 67, 213, 148, 19, 208, 143, 14, 193, 128, }; }; @@ -70,23 +91,39 @@ class FunctionHilbertEncode2DWIthLookupTableImpl static_assert(bit_step <= 3, "bit_step should not be more than 3 to fit in UInt8"); public: static UInt64 encode(UInt64 x, UInt64 y) + { + return encodeImpl(x, y, std::nullopt).hilbert_code; + } + + + struct EncodeResult { UInt64 hilbert_code = 0; + UInt64 state = 0; + }; + + static EncodeResult encodeImpl(UInt64 x, UInt64 y, std::optional start_state) + { + EncodeResult encode_result; const auto leading_zeros_count = getLeadingZeroBits(x | y); const auto used_bits = std::numeric_limits::digits - leading_zeros_count; auto [current_shift, state] = getInitialShiftAndState(used_bits); + if (start_state.has_value()) { + state = *start_state; + } while (current_shift >= 0) { const UInt8 x_bits = (x >> current_shift) & STEP_MASK; const UInt8 y_bits = (y >> current_shift) & STEP_MASK; const auto hilbert_bits = getCodeAndUpdateState(x_bits, y_bits, state); - hilbert_code |= (hilbert_bits << getHilbertShift(current_shift)); + encode_result.hilbert_code |= (hilbert_bits << getHilbertShift(current_shift)); current_shift -= bit_step; } - return hilbert_code; + encode_result.state = state; + return encode_result; } private: @@ -120,13 +157,16 @@ private: { initial_shift -= bit_step; } - UInt8 state = iterations % 2 == 0 ? 0b01 << getHilbertShift(bit_step) : 0; + UInt8 state = iterations % 2 == 0 ? LEFT_STATE : DEFAULT_STATE; return {initial_shift, state}; } constexpr static UInt8 STEP_MASK = (1 << bit_step) - 1; - constexpr static UInt8 HILBERT_MASK = (1 << getHilbertShift(bit_step)) - 1; - constexpr static UInt8 STATE_MASK = 0b11 << getHilbertShift(bit_step); + constexpr static UInt8 HILBERT_SHIFT = getHilbertShift(bit_step); + constexpr static UInt8 HILBERT_MASK = (1 << HILBERT_SHIFT) - 1; + constexpr static UInt8 STATE_MASK = 0b11 << HILBERT_SHIFT; + constexpr static UInt8 LEFT_STATE = 0b01 << HILBERT_SHIFT; + constexpr static UInt8 DEFAULT_STATE = bit_step % 2 == 0 ? LEFT_STATE : 0; }; diff --git a/src/Functions/tests/gtest_hilbert_curve.cpp b/src/Functions/tests/gtest_hilbert_curve.cpp index 108ab6a6ccf..716a8663c9a 100644 --- a/src/Functions/tests/gtest_hilbert_curve.cpp +++ b/src/Functions/tests/gtest_hilbert_curve.cpp @@ -1,9 +1,10 @@ #include #include #include +#include "base/types.h" -TEST(HilbertLookupTable, EncodeBit1And3Consistnecy) +TEST(HilbertLookupTable, EncodeBit1And3Consistency) { const size_t bound = 1000; for (size_t x = 0; x < bound; ++x) @@ -17,7 +18,21 @@ TEST(HilbertLookupTable, EncodeBit1And3Consistnecy) } } -TEST(HilbertLookupTable, DecodeBit1And3Consistnecy) +TEST(HilbertLookupTable, EncodeBit2And3Consistency) +{ + const size_t bound = 1000; + for (size_t x = 0; x < bound; ++x) + { + for (size_t y = 0; y < bound; ++y) + { + auto hilbert2bit = DB::FunctionHilbertEncode2DWIthLookupTableImpl<2>::encode(x, y); + auto hilbert3bit = DB::FunctionHilbertEncode2DWIthLookupTableImpl<3>::encode(x, y); + ASSERT_EQ(hilbert3bit, hilbert2bit); + } + } +} + +TEST(HilbertLookupTable, DecodeBit1And3Consistency) { const size_t bound = 1000 * 1000; for (size_t hilbert_code = 0; hilbert_code < bound; ++hilbert_code) @@ -27,3 +42,40 @@ TEST(HilbertLookupTable, DecodeBit1And3Consistnecy) ASSERT_EQ(res1, res3); } } + +TEST(HilbertLookupTable, DecodeBit2And3Consistency) +{ + const size_t bound = 1000 * 1000; + for (size_t hilbert_code = 0; hilbert_code < bound; ++hilbert_code) + { + auto res2 = DB::FunctionHilbertDecode2DWIthLookupTableImpl<2>::decode(hilbert_code); + auto res3 = DB::FunctionHilbertDecode2DWIthLookupTableImpl<3>::decode(hilbert_code); + ASSERT_EQ(res2, res3); + } +} + +TEST(HilbertLookupTable, DecodeAndEncodeAreInverseOperations) +{ + const size_t bound = 1000; + for (size_t x = 0; x < bound; ++x) + { + for (size_t y = 0; y < bound; ++y) + { + auto hilbert_code = DB::FunctionHilbertEncode2DWIthLookupTableImpl<3>::encode(x, y); + auto [x_new, y_new] = DB::FunctionHilbertDecode2DWIthLookupTableImpl<3>::decode(hilbert_code); + ASSERT_EQ(x_new, x); + ASSERT_EQ(y_new, y); + } + } +} + +TEST(HilbertLookupTable, EncodeAndDecodeAreInverseOperations) +{ + const size_t bound = 1000 * 1000; + for (size_t hilbert_code = 0; hilbert_code < bound; ++hilbert_code) + { + auto [x, y] = DB::FunctionHilbertDecode2DWIthLookupTableImpl<3>::decode(hilbert_code); + auto hilbert_new = DB::FunctionHilbertEncode2DWIthLookupTableImpl<3>::encode(x, y); + ASSERT_EQ(hilbert_new, hilbert_code); + } +} From 3d1074cc4003863abe2839e957c75383d67e2836 Mon Sep 17 00:00:00 2001 From: Yarik Briukhovetskyi <114298166+yariks5s@users.noreply.github.com> Date: Wed, 24 Apr 2024 14:26:29 +0200 Subject: [PATCH 14/14] fix style --- src/Functions/hilbertEncode.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Functions/hilbertEncode.h b/src/Functions/hilbertEncode.h index 825065b34d3..2eabf666d49 100644 --- a/src/Functions/hilbertEncode.h +++ b/src/Functions/hilbertEncode.h @@ -109,7 +109,8 @@ public: const auto used_bits = std::numeric_limits::digits - leading_zeros_count; auto [current_shift, state] = getInitialShiftAndState(used_bits); - if (start_state.has_value()) { + if (start_state.has_value()) + { state = *start_state; }