diff --git a/src/Analyzer/Passes/GroupingFunctionsResolvePass.cpp b/src/Analyzer/Passes/GroupingFunctionsResolvePass.cpp index fc122730b37..d7ccc53041b 100644 --- a/src/Analyzer/Passes/GroupingFunctionsResolvePass.cpp +++ b/src/Analyzer/Passes/GroupingFunctionsResolvePass.cpp @@ -4,6 +4,7 @@ #include #include +#include #include diff --git a/src/Analyzer/Resolve/QueryAnalyzer.cpp b/src/Analyzer/Resolve/QueryAnalyzer.cpp index 03ebd893c47..d118cb281ae 100644 --- a/src/Analyzer/Resolve/QueryAnalyzer.cpp +++ b/src/Analyzer/Resolve/QueryAnalyzer.cpp @@ -13,6 +13,7 @@ #include #include +#include #include #include #include diff --git a/src/Columns/ColumnDecimal.cpp b/src/Columns/ColumnDecimal.cpp index bb4433f8956..73366150e7d 100644 --- a/src/Columns/ColumnDecimal.cpp +++ b/src/Columns/ColumnDecimal.cpp @@ -8,6 +8,10 @@ #include #include +#include +#include + +#include #include #include @@ -30,6 +34,19 @@ namespace ErrorCodes extern const int NOT_IMPLEMENTED; } +template +const char * ColumnDecimal::getFamilyName() const +{ + return TypeName.data(); +} + +template +TypeIndex ColumnDecimal::getDataType() const +{ + return TypeToTypeIndex; +} + + template #if !defined(DEBUG_OR_SANITIZER_BUILD) int ColumnDecimal::compareAt(size_t n, size_t m, const IColumn & rhs_, int) const @@ -46,6 +63,12 @@ int ColumnDecimal::doCompareAt(size_t n, size_t m, const IColumn & rhs_, int) return decimalLess(b, a, other.scale, scale) ? 1 : (decimalLess(a, b, scale, other.scale) ? -1 : 0); } +template +Float64 ColumnDecimal::getFloat64(size_t n) const +{ + return DecimalUtils::convertTo(data[n], scale); +} + template const char * ColumnDecimal::deserializeAndInsertFromArena(const char * pos) { diff --git a/src/Columns/ColumnDecimal.h b/src/Columns/ColumnDecimal.h index 6f8360a54dd..690549e4a56 100644 --- a/src/Columns/ColumnDecimal.h +++ b/src/Columns/ColumnDecimal.h @@ -1,14 +1,9 @@ #pragma once -#include -#include -#include -#include -#include -#include #include #include #include +#include namespace DB @@ -39,8 +34,8 @@ private: {} public: - const char * getFamilyName() const override { return TypeName.data(); } - TypeIndex getDataType() const override { return TypeToTypeIndex; } + const char * getFamilyName() const override; + TypeIndex getDataType() const override; bool isNumeric() const override { return false; } bool canBeInsideNullable() const override { return true; } @@ -98,7 +93,7 @@ public: return StringRef(reinterpret_cast(&data[n]), sizeof(data[n])); } - Float64 getFloat64(size_t n) const final { return DecimalUtils::convertTo(data[n], scale); } + Float64 getFloat64(size_t n) const final; const char * deserializeAndInsertFromArena(const char * pos) override; const char * skipSerializedInArena(const char * pos) const override; diff --git a/src/Columns/ColumnFunction.cpp b/src/Columns/ColumnFunction.cpp index cc80d04444e..5e41e95fdc5 100644 --- a/src/Columns/ColumnFunction.cpp +++ b/src/Columns/ColumnFunction.cpp @@ -347,7 +347,7 @@ ColumnWithTypeAndName ColumnFunction::reduce() const if (is_function_compiled) ProfileEvents::increment(ProfileEvents::CompiledFunctionExecute); - res.column = function->execute(columns, res.type, elements_size); + res.column = function->execute(columns, res.type, elements_size, /* dry_run = */ false); if (res.column->getDataType() != res.type->getColumnType()) throw Exception( ErrorCodes::LOGICAL_ERROR, diff --git a/src/Columns/ColumnVector.cpp b/src/Columns/ColumnVector.cpp index 84fc6ebc61d..3c7727f37c4 100644 --- a/src/Columns/ColumnVector.cpp +++ b/src/Columns/ColumnVector.cpp @@ -32,6 +32,8 @@ # include #endif +#include "config.h" + #if USE_MULTITARGET_CODE # include #endif @@ -658,7 +660,7 @@ inline void doFilterAligned(const UInt8 *& filt_pos, const UInt8 *& filt_end_ali reinterpret_cast(&res_data[current_offset]), mask & KMASK); current_offset += std::popcount(mask & KMASK); /// prepare mask for next iter, if ELEMENTS_PER_VEC = 64, no next iter - if (ELEMENTS_PER_VEC < 64) + if constexpr (ELEMENTS_PER_VEC < 64) { mask >>= ELEMENTS_PER_VEC; } @@ -992,6 +994,151 @@ ColumnPtr ColumnVector::createWithOffsets(const IColumn::Offsets & offsets, c return res; } +DECLARE_DEFAULT_CODE( + template void vectorIndexImpl( + const Container & data, const PaddedPODArray & indexes, size_t limit, Container & res_data) + { + for (size_t i = 0; i < limit; ++i) + res_data[i] = data[indexes[i]]; + } +); + +DECLARE_AVX512VBMI_SPECIFIC_CODE( + template + void vectorIndexImpl(const Container & data, const PaddedPODArray & indexes, size_t limit, Container & res_data) + { + static constexpr UInt64 MASK64 = 0xffffffffffffffff; + const size_t limit64 = limit & ~63; + size_t pos = 0; + size_t data_size = data.size(); + + auto data_pos = reinterpret_cast(data.data()); + auto indexes_pos = reinterpret_cast(indexes.data()); + auto res_pos = reinterpret_cast(res_data.data()); + + if (limit == 0) + return; /// nothing to do, just return + + if (data_size <= 64) + { + /// one single mask load for table size <= 64 + __mmask64 last_mask = MASK64 >> (64 - data_size); + __m512i table1 = _mm512_maskz_loadu_epi8(last_mask, data_pos); + + /// 64 bytes table lookup using one single permutexvar_epi8 + while (pos < limit64) + { + __m512i vidx = _mm512_loadu_epi8(indexes_pos + pos); + __m512i out = _mm512_permutexvar_epi8(vidx, table1); + _mm512_storeu_epi8(res_pos + pos, out); + pos += 64; + } + /// tail handling + if (limit > limit64) + { + __mmask64 tail_mask = MASK64 >> (limit64 + 64 - limit); + __m512i vidx = _mm512_maskz_loadu_epi8(tail_mask, indexes_pos + pos); + __m512i out = _mm512_permutexvar_epi8(vidx, table1); + _mm512_mask_storeu_epi8(res_pos + pos, tail_mask, out); + } + } + else if (data_size <= 128) + { + /// table size (64, 128] requires 2 zmm load + __mmask64 last_mask = MASK64 >> (128 - data_size); + __m512i table1 = _mm512_loadu_epi8(data_pos); + __m512i table2 = _mm512_maskz_loadu_epi8(last_mask, data_pos + 64); + + /// 128 bytes table lookup using one single permute2xvar_epi8 + while (pos < limit64) + { + __m512i vidx = _mm512_loadu_epi8(indexes_pos + pos); + __m512i out = _mm512_permutex2var_epi8(table1, vidx, table2); + _mm512_storeu_epi8(res_pos + pos, out); + pos += 64; + } + if (limit > limit64) + { + __mmask64 tail_mask = MASK64 >> (limit64 + 64 - limit); + __m512i vidx = _mm512_maskz_loadu_epi8(tail_mask, indexes_pos + pos); + __m512i out = _mm512_permutex2var_epi8(table1, vidx, table2); + _mm512_mask_storeu_epi8(res_pos + pos, tail_mask, out); + } + } + else + { + if (data_size > 256) + { + /// byte index will not exceed 256 boundary. + data_size = 256; + } + + __m512i table1 = _mm512_loadu_epi8(data_pos); + __m512i table2 = _mm512_loadu_epi8(data_pos + 64); + __m512i table3, table4; + if (data_size <= 192) + { + /// only 3 tables need to load if size <= 192 + __mmask64 last_mask = MASK64 >> (192 - data_size); + table3 = _mm512_maskz_loadu_epi8(last_mask, data_pos + 128); + table4 = _mm512_setzero_si512(); + } + else + { + __mmask64 last_mask = MASK64 >> (256 - data_size); + table3 = _mm512_loadu_epi8(data_pos + 128); + table4 = _mm512_maskz_loadu_epi8(last_mask, data_pos + 192); + } + + /// 256 bytes table lookup can use: 2 permute2xvar_epi8 plus 1 blender with MSB + while (pos < limit64) + { + __m512i vidx = _mm512_loadu_epi8(indexes_pos + pos); + __m512i tmp1 = _mm512_permutex2var_epi8(table1, vidx, table2); + __m512i tmp2 = _mm512_permutex2var_epi8(table3, vidx, table4); + __mmask64 msb = _mm512_movepi8_mask(vidx); + __m512i out = _mm512_mask_blend_epi8(msb, tmp1, tmp2); + _mm512_storeu_epi8(res_pos + pos, out); + pos += 64; + } + if (limit > limit64) + { + __mmask64 tail_mask = MASK64 >> (limit64 + 64 - limit); + __m512i vidx = _mm512_maskz_loadu_epi8(tail_mask, indexes_pos + pos); + __m512i tmp1 = _mm512_permutex2var_epi8(table1, vidx, table2); + __m512i tmp2 = _mm512_permutex2var_epi8(table3, vidx, table4); + __mmask64 msb = _mm512_movepi8_mask(vidx); + __m512i out = _mm512_mask_blend_epi8(msb, tmp1, tmp2); + _mm512_mask_storeu_epi8(res_pos + pos, tail_mask, out); + } + } + } +); + +template +template +ColumnPtr ColumnVector::indexImpl(const PaddedPODArray & indexes, size_t limit) const +{ + chassert(limit <= indexes.size()); + + auto res = this->create(limit); + typename Self::Container & res_data = res->getData(); +#if USE_MULTITARGET_CODE + if constexpr (sizeof(T) == 1 && sizeof(Type) == 1) + { + /// VBMI optimization only applicable for (U)Int8 types + if (isArchSupported(TargetArch::AVX512VBMI)) + { + TargetSpecific::AVX512VBMI::vectorIndexImpl(data, indexes, limit, res_data); + return res; + } + } +#endif + TargetSpecific::Default::vectorIndexImpl(data, indexes, limit, res_data); + + return res; +} + /// Explicit template instantiations - to avoid code bloat in headers. template class ColumnVector; template class ColumnVector; @@ -1012,4 +1159,17 @@ template class ColumnVector; template class ColumnVector; template class ColumnVector; +INSTANTIATE_INDEX_TEMPLATE_IMPL(ColumnVector) +/// Used by ColumnVariant.cpp +template ColumnPtr ColumnVector::indexImpl(const PaddedPODArray & indexes, size_t limit) const; +template ColumnPtr ColumnVector::indexImpl(const PaddedPODArray & indexes, size_t limit) const; +template ColumnPtr ColumnVector::indexImpl(const PaddedPODArray & indexes, size_t limit) const; +template ColumnPtr ColumnVector::indexImpl(const PaddedPODArray & indexes, size_t limit) const; +template ColumnPtr ColumnVector::indexImpl(const PaddedPODArray & indexes, size_t limit) const; +template ColumnPtr ColumnVector::indexImpl(const PaddedPODArray & indexes, size_t limit) const; + +#if defined(OS_DARWIN) +template ColumnPtr ColumnVector::indexImpl(const PaddedPODArray & indexes, size_t limit) const; +template ColumnPtr ColumnVector::indexImpl(const PaddedPODArray & indexes, size_t limit) const; +#endif } diff --git a/src/Columns/ColumnVector.h b/src/Columns/ColumnVector.h index e8bb6ad6798..1387cca1ece 100644 --- a/src/Columns/ColumnVector.h +++ b/src/Columns/ColumnVector.h @@ -13,10 +13,6 @@ #include "config.h" -#if USE_MULTITARGET_CODE -# include -#endif - namespace DB { @@ -320,151 +316,6 @@ protected: Container data; }; -DECLARE_DEFAULT_CODE( -template -inline void vectorIndexImpl(const Container & data, const PaddedPODArray & indexes, size_t limit, Container & res_data) -{ - for (size_t i = 0; i < limit; ++i) - res_data[i] = data[indexes[i]]; -} -); - -DECLARE_AVX512VBMI_SPECIFIC_CODE( -template -inline void vectorIndexImpl(const Container & data, const PaddedPODArray & indexes, size_t limit, Container & res_data) -{ - static constexpr UInt64 MASK64 = 0xffffffffffffffff; - const size_t limit64 = limit & ~63; - size_t pos = 0; - size_t data_size = data.size(); - - auto data_pos = reinterpret_cast(data.data()); - auto indexes_pos = reinterpret_cast(indexes.data()); - auto res_pos = reinterpret_cast(res_data.data()); - - if (limit == 0) - return; /// nothing to do, just return - - if (data_size <= 64) - { - /// one single mask load for table size <= 64 - __mmask64 last_mask = MASK64 >> (64 - data_size); - __m512i table1 = _mm512_maskz_loadu_epi8(last_mask, data_pos); - - /// 64 bytes table lookup using one single permutexvar_epi8 - while (pos < limit64) - { - __m512i vidx = _mm512_loadu_epi8(indexes_pos + pos); - __m512i out = _mm512_permutexvar_epi8(vidx, table1); - _mm512_storeu_epi8(res_pos + pos, out); - pos += 64; - } - /// tail handling - if (limit > limit64) - { - __mmask64 tail_mask = MASK64 >> (limit64 + 64 - limit); - __m512i vidx = _mm512_maskz_loadu_epi8(tail_mask, indexes_pos + pos); - __m512i out = _mm512_permutexvar_epi8(vidx, table1); - _mm512_mask_storeu_epi8(res_pos + pos, tail_mask, out); - } - } - else if (data_size <= 128) - { - /// table size (64, 128] requires 2 zmm load - __mmask64 last_mask = MASK64 >> (128 - data_size); - __m512i table1 = _mm512_loadu_epi8(data_pos); - __m512i table2 = _mm512_maskz_loadu_epi8(last_mask, data_pos + 64); - - /// 128 bytes table lookup using one single permute2xvar_epi8 - while (pos < limit64) - { - __m512i vidx = _mm512_loadu_epi8(indexes_pos + pos); - __m512i out = _mm512_permutex2var_epi8(table1, vidx, table2); - _mm512_storeu_epi8(res_pos + pos, out); - pos += 64; - } - if (limit > limit64) - { - __mmask64 tail_mask = MASK64 >> (limit64 + 64 - limit); - __m512i vidx = _mm512_maskz_loadu_epi8(tail_mask, indexes_pos + pos); - __m512i out = _mm512_permutex2var_epi8(table1, vidx, table2); - _mm512_mask_storeu_epi8(res_pos + pos, tail_mask, out); - } - } - else - { - if (data_size > 256) - { - /// byte index will not exceed 256 boundary. - data_size = 256; - } - - __m512i table1 = _mm512_loadu_epi8(data_pos); - __m512i table2 = _mm512_loadu_epi8(data_pos + 64); - __m512i table3, table4; - if (data_size <= 192) - { - /// only 3 tables need to load if size <= 192 - __mmask64 last_mask = MASK64 >> (192 - data_size); - table3 = _mm512_maskz_loadu_epi8(last_mask, data_pos + 128); - table4 = _mm512_setzero_si512(); - } - else - { - __mmask64 last_mask = MASK64 >> (256 - data_size); - table3 = _mm512_loadu_epi8(data_pos + 128); - table4 = _mm512_maskz_loadu_epi8(last_mask, data_pos + 192); - } - - /// 256 bytes table lookup can use: 2 permute2xvar_epi8 plus 1 blender with MSB - while (pos < limit64) - { - __m512i vidx = _mm512_loadu_epi8(indexes_pos + pos); - __m512i tmp1 = _mm512_permutex2var_epi8(table1, vidx, table2); - __m512i tmp2 = _mm512_permutex2var_epi8(table3, vidx, table4); - __mmask64 msb = _mm512_movepi8_mask(vidx); - __m512i out = _mm512_mask_blend_epi8(msb, tmp1, tmp2); - _mm512_storeu_epi8(res_pos + pos, out); - pos += 64; - } - if (limit > limit64) - { - __mmask64 tail_mask = MASK64 >> (limit64 + 64 - limit); - __m512i vidx = _mm512_maskz_loadu_epi8(tail_mask, indexes_pos + pos); - __m512i tmp1 = _mm512_permutex2var_epi8(table1, vidx, table2); - __m512i tmp2 = _mm512_permutex2var_epi8(table3, vidx, table4); - __mmask64 msb = _mm512_movepi8_mask(vidx); - __m512i out = _mm512_mask_blend_epi8(msb, tmp1, tmp2); - _mm512_mask_storeu_epi8(res_pos + pos, tail_mask, out); - } - } -} -); - -template -template -ColumnPtr ColumnVector::indexImpl(const PaddedPODArray & indexes, size_t limit) const -{ - assert(limit <= indexes.size()); - - auto res = this->create(limit); - typename Self::Container & res_data = res->getData(); -#if USE_MULTITARGET_CODE - if constexpr (sizeof(T) == 1 && sizeof(Type) == 1) - { - /// VBMI optimization only applicable for (U)Int8 types - if (isArchSupported(TargetArch::AVX512VBMI)) - { - TargetSpecific::AVX512VBMI::vectorIndexImpl(data, indexes, limit, res_data); - return res; - } - } -#endif - TargetSpecific::Default::vectorIndexImpl(data, indexes, limit, res_data); - - return res; -} - template concept is_col_vector = std::is_same_v>; diff --git a/src/Columns/ColumnsCommon.h b/src/Columns/ColumnsCommon.h index 99f1d2da47e..f0d6cff2e35 100644 --- a/src/Columns/ColumnsCommon.h +++ b/src/Columns/ColumnsCommon.h @@ -142,4 +142,10 @@ ColumnPtr permuteImpl(const Column & column, const IColumn::Permutation & perm, template ColumnPtr Column::indexImpl(const PaddedPODArray & indexes, size_t limit) const; \ template ColumnPtr Column::indexImpl(const PaddedPODArray & indexes, size_t limit) const; \ template ColumnPtr Column::indexImpl(const PaddedPODArray & indexes, size_t limit) const; + +#define INSTANTIATE_INDEX_TEMPLATE_IMPL(ColumnTemplate) \ + template ColumnPtr ColumnTemplate::indexImpl(const PaddedPODArray & indexes, size_t limit) const; \ + template ColumnPtr ColumnTemplate::indexImpl(const PaddedPODArray & indexes, size_t limit) const; \ + template ColumnPtr ColumnTemplate::indexImpl(const PaddedPODArray & indexes, size_t limit) const; \ + template ColumnPtr ColumnTemplate::indexImpl(const PaddedPODArray & indexes, size_t limit) const; } diff --git a/src/Core/DecimalComparison.h b/src/Core/DecimalComparison.h index 77402adf164..4b6783265d0 100644 --- a/src/Core/DecimalComparison.h +++ b/src/Core/DecimalComparison.h @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -52,8 +53,8 @@ struct DecCompareInt using TypeB = Type; }; -template typename Operation, bool _check_overflow = true, - bool _actual = is_decimal || is_decimal> +template typename Operation> +requires is_decimal || is_decimal class DecimalComparison { public: @@ -65,20 +66,17 @@ public: using ArrayA = typename ColVecA::Container; using ArrayB = typename ColVecB::Container; - static ColumnPtr apply(const ColumnWithTypeAndName & col_left, const ColumnWithTypeAndName & col_right) + static ColumnPtr apply(const ColumnWithTypeAndName & col_left, const ColumnWithTypeAndName & col_right, bool check_overflow) { - if constexpr (_actual) - { - ColumnPtr c_res; - Shift shift = getScales(col_left.type, col_right.type); + ColumnPtr c_res; + Shift shift = getScales(col_left.type, col_right.type); - return applyWithScale(col_left.column, col_right.column, shift); - } - else - return nullptr; + if (check_overflow) + return applyWithScale(col_left.column, col_right.column, shift); + return applyWithScale(col_left.column, col_right.column, shift); } - static bool compare(A a, B b, UInt32 scale_a, UInt32 scale_b) + static bool compare(A a, B b, UInt32 scale_a, UInt32 scale_b, bool check_overflow) { static const UInt32 max_scale = DecimalUtils::max_precision; if (scale_a > max_scale || scale_b > max_scale) @@ -90,7 +88,9 @@ public: if (scale_a > scale_b) shift.b = static_cast(DecimalUtils::scaleMultiplier(scale_a - scale_b)); - return applyWithScale(a, b, shift); + if (check_overflow) + return applyWithScale(a, b, shift); + return applyWithScale(a, b, shift); } private: @@ -104,14 +104,14 @@ private: bool right() const { return b != 1; } }; - template + template static auto applyWithScale(T a, U b, const Shift & shift) { if (shift.left()) - return apply(a, b, shift.a); + return apply(a, b, shift.a); if (shift.right()) - return apply(a, b, shift.b); - return apply(a, b, 1); + return apply(a, b, shift.b); + return apply(a, b, 1); } template @@ -125,8 +125,8 @@ private: if (decimal0 && decimal1) { auto result_type = DecimalUtils::binaryOpResult(*decimal0, *decimal1); - shift.a = static_cast(result_type.scaleFactorFor(decimal0->getTrait(), false).value); - shift.b = static_cast(result_type.scaleFactorFor(decimal1->getTrait(), false).value); + shift.a = static_cast(result_type.scaleFactorFor(DecimalUtils::DataTypeDecimalTrait{decimal0->getPrecision(), decimal0->getScale()}, false).value); + shift.b = static_cast(result_type.scaleFactorFor(DecimalUtils::DataTypeDecimalTrait{decimal1->getPrecision(), decimal1->getScale()}, false).value); } else if (decimal0) shift.b = static_cast(decimal0->getScaleMultiplier().value); @@ -158,66 +158,63 @@ private: return shift; } - template + template static ColumnPtr apply(const ColumnPtr & c0, const ColumnPtr & c1, CompareInt scale) { auto c_res = ColumnUInt8::create(); - if constexpr (_actual) + bool c0_is_const = isColumnConst(*c0); + bool c1_is_const = isColumnConst(*c1); + + if (c0_is_const && c1_is_const) { - bool c0_is_const = isColumnConst(*c0); - bool c1_is_const = isColumnConst(*c1); + const ColumnConst & c0_const = checkAndGetColumnConst(*c0); + const ColumnConst & c1_const = checkAndGetColumnConst(*c1); - if (c0_is_const && c1_is_const) + A a = c0_const.template getValue(); + B b = c1_const.template getValue(); + UInt8 res = apply(a, b, scale); + return DataTypeUInt8().createColumnConst(c0->size(), toField(res)); + } + + ColumnUInt8::Container & vec_res = c_res->getData(); + vec_res.resize(c0->size()); + + if (c0_is_const) + { + const ColumnConst & c0_const = checkAndGetColumnConst(*c0); + A a = c0_const.template getValue(); + if (const ColVecB * c1_vec = checkAndGetColumn(c1.get())) + constantVector(a, c1_vec->getData(), vec_res, scale); + else + throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong column in Decimal comparison"); + } + else if (c1_is_const) + { + const ColumnConst & c1_const = checkAndGetColumnConst(*c1); + B b = c1_const.template getValue(); + if (const ColVecA * c0_vec = checkAndGetColumn(c0.get())) + vectorConstant(c0_vec->getData(), b, vec_res, scale); + else + throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong column in Decimal comparison"); + } + else + { + if (const ColVecA * c0_vec = checkAndGetColumn(c0.get())) { - const ColumnConst & c0_const = checkAndGetColumnConst(*c0); - const ColumnConst & c1_const = checkAndGetColumnConst(*c1); - - A a = c0_const.template getValue(); - B b = c1_const.template getValue(); - UInt8 res = apply(a, b, scale); - return DataTypeUInt8().createColumnConst(c0->size(), toField(res)); - } - - ColumnUInt8::Container & vec_res = c_res->getData(); - vec_res.resize(c0->size()); - - if (c0_is_const) - { - const ColumnConst & c0_const = checkAndGetColumnConst(*c0); - A a = c0_const.template getValue(); if (const ColVecB * c1_vec = checkAndGetColumn(c1.get())) - constantVector(a, c1_vec->getData(), vec_res, scale); - else - throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong column in Decimal comparison"); - } - else if (c1_is_const) - { - const ColumnConst & c1_const = checkAndGetColumnConst(*c1); - B b = c1_const.template getValue(); - if (const ColVecA * c0_vec = checkAndGetColumn(c0.get())) - vectorConstant(c0_vec->getData(), b, vec_res, scale); + vectorVector(c0_vec->getData(), c1_vec->getData(), vec_res, scale); else throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong column in Decimal comparison"); } else - { - if (const ColVecA * c0_vec = checkAndGetColumn(c0.get())) - { - if (const ColVecB * c1_vec = checkAndGetColumn(c1.get())) - vectorVector(c0_vec->getData(), c1_vec->getData(), vec_res, scale); - else - throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong column in Decimal comparison"); - } - else - throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong column in Decimal comparison"); - } + throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong column in Decimal comparison"); } return c_res; } - template + template static NO_INLINE UInt8 apply(A a, B b, CompareInt scale [[maybe_unused]]) { CompareInt x; @@ -232,7 +229,7 @@ private: else y = static_cast(b); - if constexpr (_check_overflow) + if constexpr (check_overflow) { bool overflow = false; @@ -264,9 +261,8 @@ private: return Op::apply(x, y); } - template - static void NO_INLINE vectorVector(const ArrayA & a, const ArrayB & b, PaddedPODArray & c, - CompareInt scale) + template + static void NO_INLINE vectorVector(const ArrayA & a, const ArrayB & b, PaddedPODArray & c, CompareInt scale) { size_t size = a.size(); const A * a_pos = a.data(); @@ -276,14 +272,14 @@ private: while (a_pos < a_end) { - *c_pos = apply(*a_pos, *b_pos, scale); + *c_pos = apply(*a_pos, *b_pos, scale); ++a_pos; ++b_pos; ++c_pos; } } - template + template static void NO_INLINE vectorConstant(const ArrayA & a, B b, PaddedPODArray & c, CompareInt scale) { size_t size = a.size(); @@ -293,13 +289,13 @@ private: while (a_pos < a_end) { - *c_pos = apply(*a_pos, b, scale); + *c_pos = apply(*a_pos, b, scale); ++a_pos; ++c_pos; } } - template + template static void NO_INLINE constantVector(A a, const ArrayB & b, PaddedPODArray & c, CompareInt scale) { size_t size = b.size(); @@ -309,7 +305,7 @@ private: while (b_pos < b_end) { - *c_pos = apply(a, *b_pos, scale); + *c_pos = apply(a, *b_pos, scale); ++b_pos; ++c_pos; } diff --git a/src/Core/Field.cpp b/src/Core/Field.cpp index e774a95e19f..90f30b52c0c 100644 --- a/src/Core/Field.cpp +++ b/src/Core/Field.cpp @@ -529,22 +529,25 @@ Field Field::restoreFromDump(std::string_view dump_) template bool decimalEqual(T x, T y, UInt32 x_scale, UInt32 y_scale) { + bool check_overflow = true; using Comparator = DecimalComparison; - return Comparator::compare(x, y, x_scale, y_scale); + return Comparator::compare(x, y, x_scale, y_scale, check_overflow); } template bool decimalLess(T x, T y, UInt32 x_scale, UInt32 y_scale) { + bool check_overflow = true; using Comparator = DecimalComparison; - return Comparator::compare(x, y, x_scale, y_scale); + return Comparator::compare(x, y, x_scale, y_scale, check_overflow); } template bool decimalLessOrEqual(T x, T y, UInt32 x_scale, UInt32 y_scale) { + bool check_overflow = true; using Comparator = DecimalComparison; - return Comparator::compare(x, y, x_scale, y_scale); + return Comparator::compare(x, y, x_scale, y_scale, check_overflow); } diff --git a/src/Core/Field.h b/src/Core/Field.h index c08d5c9eb42..5a6ee9cdf29 100644 --- a/src/Core/Field.h +++ b/src/Core/Field.h @@ -863,6 +863,9 @@ template <> struct Field::EnumToType { usi template <> struct Field::EnumToType { using Type = CustomType; }; template <> struct Field::EnumToType { using Type = UInt64; }; +/// Use it to prevent inclusion of magic_enum in headers, which is very expensive for the compiler +std::string_view fieldTypeToString(Field::Types::Which type); + constexpr bool isInt64OrUInt64FieldType(Field::Types::Which t) { return t == Field::Types::Int64 @@ -886,7 +889,7 @@ auto & Field::safeGet() if (target != which && !(which == Field::Types::Bool && (target == Field::Types::UInt64 || target == Field::Types::Int64)) && !(isInt64OrUInt64FieldType(which) && isInt64OrUInt64FieldType(target))) - throw Exception(ErrorCodes::BAD_GET, "Bad get: has {}, requested {}", getTypeName(), target); + throw Exception(ErrorCodes::BAD_GET, "Bad get: has {}, requested {}", getTypeName(), fieldTypeToString(target)); return get(); } @@ -1002,8 +1005,6 @@ void readQuoted(DecimalField & x, ReadBuffer & buf); void writeFieldText(const Field & x, WriteBuffer & buf); String toString(const Field & x); - -std::string_view fieldTypeToString(Field::Types::Which type); } template <> diff --git a/src/Core/callOnTypeIndex.h b/src/Core/callOnTypeIndex.h index 0c8f2201b0d..09fbc7f1f10 100644 --- a/src/Core/callOnTypeIndex.h +++ b/src/Core/callOnTypeIndex.h @@ -87,6 +87,77 @@ static bool callOnBasicType(TypeIndex number, F && f) return false; } + +template +static bool callOnBasicTypeSecondArg(TypeIndex number, F && f) +{ + if constexpr (_int) + { + switch (number) + { + case TypeIndex::UInt8: return f(TypePair()); + case TypeIndex::UInt16: return f(TypePair()); + case TypeIndex::UInt32: return f(TypePair()); + case TypeIndex::UInt64: return f(TypePair()); + case TypeIndex::UInt128: return f(TypePair()); + case TypeIndex::UInt256: return f(TypePair()); + + case TypeIndex::Int8: return f(TypePair()); + case TypeIndex::Int16: return f(TypePair()); + case TypeIndex::Int32: return f(TypePair()); + case TypeIndex::Int64: return f(TypePair()); + case TypeIndex::Int128: return f(TypePair()); + case TypeIndex::Int256: return f(TypePair()); + + case TypeIndex::Enum8: return f(TypePair()); + case TypeIndex::Enum16: return f(TypePair()); + + default: + break; + } + } + + if constexpr (_decimal) + { + switch (number) + { + case TypeIndex::Decimal32: return f(TypePair()); + case TypeIndex::Decimal64: return f(TypePair()); + case TypeIndex::Decimal128: return f(TypePair()); + case TypeIndex::Decimal256: return f(TypePair()); + default: + break; + } + } + + if constexpr (_float) + { + switch (number) + { + case TypeIndex::BFloat16: return f(TypePair()); + case TypeIndex::Float32: return f(TypePair()); + case TypeIndex::Float64: return f(TypePair()); + default: + break; + } + } + + if constexpr (_datetime) + { + switch (number) + { + case TypeIndex::Date: return f(TypePair()); + case TypeIndex::Date32: return f(TypePair()); + case TypeIndex::DateTime: return f(TypePair()); + case TypeIndex::DateTime64: return f(TypePair()); + default: + break; + } + } + + return false; +} + /// Unroll template using TypeIndex template static inline bool callOnBasicTypes(TypeIndex type_num1, TypeIndex type_num2, F && f) diff --git a/src/DataTypes/DataTypeDecimalBase.cpp b/src/DataTypes/DataTypeDecimalBase.cpp index 68bfba475d6..423ab2e4765 100644 --- a/src/DataTypes/DataTypeDecimalBase.cpp +++ b/src/DataTypes/DataTypeDecimalBase.cpp @@ -1,7 +1,8 @@ +#include +#include #include #include #include -#include namespace DB { @@ -14,6 +15,12 @@ namespace ErrorCodes { } +template +constexpr size_t DataTypeDecimalBase::maxPrecision() +{ + return DecimalUtils::max_precision; +} + bool decimalCheckComparisonOverflow(ContextPtr context) { return context->getSettingsRef()[Setting::decimal_check_overflow]; @@ -41,6 +48,18 @@ T DataTypeDecimalBase::getScaleMultiplier(UInt32 scale_) return DecimalUtils::scaleMultiplier(scale_); } +template +T DataTypeDecimalBase::wholePart(T x) const +{ + return DecimalUtils::getWholePart(x, scale); +} + +template +T DataTypeDecimalBase::fractionalPart(T x) const +{ + return DecimalUtils::getFractionalPart(x, scale); +} + /// Explicit template instantiations. template class DataTypeDecimalBase; diff --git a/src/DataTypes/DataTypeDecimalBase.h b/src/DataTypes/DataTypeDecimalBase.h index c1e1d27557f..beba3c42616 100644 --- a/src/DataTypes/DataTypeDecimalBase.h +++ b/src/DataTypes/DataTypeDecimalBase.h @@ -3,11 +3,10 @@ #include #include -#include -#include #include -#include +#include #include +#include #include @@ -64,7 +63,7 @@ public: static constexpr bool is_parametric = true; - static constexpr size_t maxPrecision() { return DecimalUtils::max_precision; } + static constexpr size_t maxPrecision(); DataTypeDecimalBase(UInt32 precision_, UInt32 scale_) : precision(precision_), @@ -104,15 +103,8 @@ public: UInt32 getScale() const { return scale; } T getScaleMultiplier() const { return getScaleMultiplier(scale); } - T wholePart(T x) const - { - return DecimalUtils::getWholePart(x, scale); - } - - T fractionalPart(T x) const - { - return DecimalUtils::getFractionalPart(x, scale); - } + T wholePart(T x) const; + T fractionalPart(T x) const; T maxWholeValue() const { return getScaleMultiplier(precision - scale) - T(1); } @@ -147,11 +139,6 @@ public: static T getScaleMultiplier(UInt32 scale); - DecimalUtils::DataTypeDecimalTrait getTrait() const - { - return {precision, scale}; - } - protected: const UInt32 precision; const UInt32 scale; @@ -167,50 +154,35 @@ inline const DataTypeDecimalBase * checkDecimalBase(const IDataType & data_ty return nullptr; } -template typename DecimalType> -inline auto decimalResultType(const DecimalType & tx, const DecimalType & ty) -{ - const auto result_trait = DecimalUtils::binaryOpResult(tx, ty); - return DecimalType(result_trait.precision, result_trait.scale); -} +template <> constexpr size_t DataTypeDecimalBase::maxPrecision() { return 9; }; +template <> constexpr size_t DataTypeDecimalBase::maxPrecision() { return 18; }; +template <> constexpr size_t DataTypeDecimalBase::maxPrecision() { return 18; }; +template <> constexpr size_t DataTypeDecimalBase::maxPrecision() { return 38; }; +template <> constexpr size_t DataTypeDecimalBase::maxPrecision() { return 76; }; -template typename DecimalType> -inline DecimalType decimalResultType(const DecimalType & tx, const DataTypeNumber & ty) -{ - const auto result_trait = DecimalUtils::binaryOpResult(tx, ty); - return DecimalType(result_trait.precision, result_trait.scale); -} - -template typename DecimalType> -inline DecimalType decimalResultType(const DataTypeNumber & tx, const DecimalType & ty) -{ - const auto result_trait = DecimalUtils::binaryOpResult(tx, ty); - return DecimalType(result_trait.precision, result_trait.scale); -} +extern template class DataTypeDecimalBase; +extern template class DataTypeDecimalBase; +extern template class DataTypeDecimalBase; +extern template class DataTypeDecimalBase; +extern template class DataTypeDecimalBase; template