This commit is contained in:
feng lv 2021-12-23 03:55:40 +00:00
parent d0c5a887a3
commit dc6f7858f8
4 changed files with 209 additions and 64 deletions

View File

@ -1,17 +1,21 @@
#pragma once
#include <base/arithmeticOverflow.h>
#include <Core/Block.h>
#include <Core/AccurateComparison.h>
#include <Core/callOnTypeIndex.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnConst.h>
#include <Core/AccurateComparison.h>
#include <Core/Block.h>
#include <Core/DecimalFloatComparison.h>
#include <Core/callOnTypeIndex.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionHelpers.h> /// TODO Core should not depend on Functions
#include <Functions/IsOperation.h>
#include <base/arithmeticOverflow.h>
#include <type_traits>
namespace DB
{
@ -52,9 +56,14 @@ struct DecCompareInt
using TypeB = Type;
};
///
template <typename A, typename B, template <typename, typename> typename Operation, bool _check_overflow = true,
bool _actual = is_decimal<A> || is_decimal<B>>
template <
typename A,
typename B,
template <typename, typename>
typename Operation,
bool _check_overflow = true,
bool _actual = is_decimal<A> || is_decimal<B>,
bool _has_float = std::is_floating_point_v<A> || std::is_floating_point_v<B>>
class DecimalComparison
{
public:
@ -220,6 +229,92 @@ private:
template <bool scale_left, bool scale_right>
static NO_INLINE UInt8 apply(A a, B b, CompareInt scale [[maybe_unused]])
{
/// Decimal compares with Float
if constexpr (_has_float)
{
if constexpr (IsOperation<Operation>::equals)
{
if constexpr (std::is_floating_point_v<A> && is_decimal<B>)
{
CompareInt decimal_value = b.value;
return DecimalFloatComparison::equals(a, decimal_value, scale);
}
if constexpr (std::is_floating_point_v<B> && is_decimal<A>)
{
CompareInt decimal_value = a.value;
return DecimalFloatComparison::equals(b, decimal_value, scale);
}
}
if constexpr (IsOperation<Operation>::not_equals)
{
if constexpr (std::is_floating_point_v<A> && is_decimal<B>)
{
CompareInt decimal_value = b.value;
return DecimalFloatComparison::notEquals(a, decimal_value, scale);
}
if constexpr (std::is_floating_point_v<B> && is_decimal<A>)
{
CompareInt decimal_value = a.value;
return DecimalFloatComparison::notEquals(b, decimal_value, scale);
}
}
if constexpr (IsOperation<Operation>::less)
{
if constexpr (std::is_floating_point_v<A> && is_decimal<B>)
{
CompareInt decimal_value = b.value;
return DecimalFloatComparison::less(a, decimal_value, scale);
}
if constexpr (std::is_floating_point_v<B> && is_decimal<A>)
{
CompareInt decimal_value = a.value;
return DecimalFloatComparison::greater(b, decimal_value, scale);
}
}
if constexpr (IsOperation<Operation>::less_or_equals)
{
if constexpr (std::is_floating_point_v<A> && is_decimal<B>)
{
CompareInt decimal_value = b.value;
return DecimalFloatComparison::lessOrEquals(a, decimal_value, scale);
}
if constexpr (std::is_floating_point_v<B> && is_decimal<A>)
{
CompareInt decimal_value = a.value;
return DecimalFloatComparison::greaterOrEquals(b, decimal_value, scale);
}
}
if constexpr (IsOperation<Operation>::greater)
{
if constexpr (std::is_floating_point_v<A> && is_decimal<B>)
{
CompareInt decimal_value = b.value;
return DecimalFloatComparison::greater(a, decimal_value, scale);
}
if constexpr (std::is_floating_point_v<B> && is_decimal<A>)
{
CompareInt decimal_value = a.value;
return DecimalFloatComparison::less(b, decimal_value, scale);
}
}
if constexpr (IsOperation<Operation>::greater_or_equals)
{
if constexpr (std::is_floating_point_v<A> && is_decimal<B>)
{
CompareInt decimal_value = b.value;
return DecimalFloatComparison::greaterOrEquals(a, decimal_value, scale);
}
if constexpr (std::is_floating_point_v<B> && is_decimal<A>)
{
CompareInt decimal_value = a.value;
return DecimalFloatComparison::lessOrEquals(b, decimal_value, scale);
}
}
}
/// Decimal compares with Int
else
{
CompareInt x;
if constexpr (is_decimal<A>)
@ -264,6 +359,7 @@ private:
return Op::apply(x, y);
}
}
template <bool scale_left, bool scale_right>
static void NO_INLINE vectorVector(const ArrayA & a, const ArrayB & b, PaddedPODArray<UInt8> & c,

View File

@ -0,0 +1,58 @@
#pragma once
#include <base/DecomposedFloat.h>
namespace DB
{
struct DecimalFloatComparison
{
template <typename Float, typename Int>
static int compare(Float a, Int b, Int scale)
{
/// TODO need to implement comparison
if (a)
return -1;
if (b)
return 0;
if (scale)
return 1;
return 0;
}
template <typename Float, typename Int>
static bool equals(Float a, Int b, Int scale)
{
return compare(a, b, scale) == 0;
}
template <typename Float, typename Int>
static bool notEquals(Float a, Int b, Int scale)
{
return compare(a, b, scale) != 0;
}
template <typename Float, typename Int>
static bool less(Float a, Int b, Int scale)
{
return compare(a, b, scale) < 0;
}
template <typename Float, typename Int>
static bool greater(Float a, Int b, Int scale)
{
return compare(a, b, scale) > 0;
}
template <typename Float, typename Int>
static bool lessOrEquals(Float a, Int b, Int scale)
{
return compare(a, b, scale) <= 0;
}
template <typename Float, typename Int>
static bool greaterOrEquals(Float a, Int b, Int scale)
{
return compare(a, b, scale) >= 0;
}
};
}

View File

@ -687,7 +687,7 @@ private:
return (res = DecimalComparison<LeftDataType, RightDataType, Op, false>::apply(col_left, col_right)) != nullptr;
};
if (!callOnBasicTypes<true, false, true, true>(left_number, right_number, call))
if (!callOnBasicTypes<true, true, true, true>(left_number, right_number, call))
throw Exception("Wrong call for " + getName() + " with " + col_left.type->getName() + " and " + col_right.type->getName(),
ErrorCodes::LOGICAL_ERROR);
@ -1175,9 +1175,6 @@ public:
const bool left_is_num = col_left_untyped->isNumeric();
const bool right_is_num = col_right_untyped->isNumeric();
const bool left_is_float = which_left.isFloat();
const bool right_is_float = which_right.isFloat();
const bool left_is_string = which_left.isStringOrFixedString();
const bool right_is_string = which_right.isStringOrFixedString();
@ -1240,16 +1237,6 @@ public:
throw Exception(
"No operation " + getName() + " between " + left_type->getName() + " and " + right_type->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (left_is_float)
{
ColumnPtr left_converted = castColumn(col_with_type_and_name_left, right_type);
return executeDecimal({left_converted, right_type, "left"}, col_with_type_and_name_right);
}
if (right_is_float)
{
ColumnPtr right_converted = castColumn(col_with_type_and_name_right, left_type);
return executeDecimal(col_with_type_and_name_left, {right_converted, left_type, "right"});
}
return executeDecimal(col_with_type_and_name_left, col_with_type_and_name_right);
}

View File

@ -17,7 +17,9 @@ template <typename, typename> struct GreatestBaseImpl;
template <typename, typename> struct ModuloImpl;
template <typename, typename> struct EqualsOp;
template <typename, typename> struct NotEqualsOp;
template <typename, typename> struct LessOp;
template <typename, typename> struct LessOrEqualsOp;
template <typename, typename> struct GreaterOp;
template <typename, typename> struct GreaterOrEqualsOp;
template <typename>
@ -42,7 +44,9 @@ struct IsOperation
{
static constexpr bool equals = IsSameOperation<Op, EqualsOp>::value;
static constexpr bool not_equals = IsSameOperation<Op, NotEqualsOp>::value;
static constexpr bool less = IsSameOperation<Op, LessOp>::value;
static constexpr bool less_or_equals = IsSameOperation<Op, LessOrEqualsOp>::value;
static constexpr bool greater = IsSameOperation<Op, GreaterOp>::value;
static constexpr bool greater_or_equals = IsSameOperation<Op, GreaterOrEqualsOp>::value;
static constexpr bool plus = IsSameOperation<Op, PlusImpl>::value;