mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 23:21:59 +00:00
rewrite
This commit is contained in:
parent
fdc6abaaa1
commit
35125c1b33
@ -79,8 +79,9 @@ struct InvalidType;
|
||||
template <template <typename> class Op, typename Name, bool is_injective>
|
||||
class FunctionUnaryArithmetic : public IFunction
|
||||
{
|
||||
static constexpr bool allow_decimal = IsUnaryOperation<Op>::negate || IsUnaryOperation<Op>::abs;
|
||||
static constexpr bool allow_decimal = IsUnaryOperation<Op>::negate || IsUnaryOperation<Op>::abs || IsUnaryOperation<Op>::sign;
|
||||
static constexpr bool allow_fixed_string = Op<UInt8>::allow_fixed_string;
|
||||
static constexpr bool is_sign_function = IsUnaryOperation<Op>::sign;
|
||||
|
||||
template <typename F>
|
||||
static bool castType(const IDataType * type, F && f)
|
||||
@ -137,7 +138,7 @@ public:
|
||||
{
|
||||
using T0 = typename DataType::FieldType;
|
||||
|
||||
if constexpr (IsDataTypeDecimal<DataType>)
|
||||
if constexpr (IsDataTypeDecimal<DataType> && !is_sign_function)
|
||||
{
|
||||
if constexpr (!allow_decimal)
|
||||
return false;
|
||||
@ -182,6 +183,17 @@ public:
|
||||
if constexpr (allow_decimal)
|
||||
{
|
||||
if (auto col = checkAndGetColumn<ColumnDecimal<T0>>(arguments[0].column.get()))
|
||||
{
|
||||
if constexpr (is_sign_function)
|
||||
{
|
||||
auto col_res = ColumnVector<typename Op<T0>::ResultType>::create();
|
||||
auto & vec_res = col_res->getData();
|
||||
vec_res.resize(col->getData().size());
|
||||
UnaryOperationImpl<T0, Op<T0>>::vector(col->getData(), vec_res);
|
||||
result_column = std::move(col_res);
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
auto col_res = ColumnDecimal<typename Op<T0>::ResultType>::create(0, type.getScale());
|
||||
auto & vec_res = col_res->getData();
|
||||
@ -192,6 +204,7 @@ public:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
using T0 = typename DataType::FieldType;
|
||||
|
@ -20,6 +20,9 @@ template <typename, typename> struct NotEqualsOp;
|
||||
template <typename, typename> struct LessOrEqualsOp;
|
||||
template <typename, typename> struct GreaterOrEqualsOp;
|
||||
|
||||
template <typename>
|
||||
struct SignImpl;
|
||||
|
||||
template <template <typename, typename> typename Op1, template <typename, typename> typename Op2>
|
||||
struct IsSameOperation
|
||||
{
|
||||
@ -31,6 +34,7 @@ struct IsUnaryOperation
|
||||
{
|
||||
static constexpr bool abs = std::is_same_v<Op<Int8>, AbsImpl<Int8>>;
|
||||
static constexpr bool negate = std::is_same_v<Op<Int8>, NegateImpl<Int8>>;
|
||||
static constexpr bool sign = std::is_same_v<Op<Int8>, SignImpl<Int8>>;
|
||||
};
|
||||
|
||||
template <template <typename, typename> typename Op>
|
||||
|
@ -1,100 +1,47 @@
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include <Functions/IFunction.h>
|
||||
#include <Functions/FunctionUnaryArithmetic.h>
|
||||
#include <DataTypes/NumberTraits.h>
|
||||
#include <Common/FieldVisitors.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
template <typename A>
|
||||
struct SignImpl
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
using ResultType = Int8;
|
||||
static const constexpr bool allow_fixed_string = false;
|
||||
|
||||
static inline NO_SANITIZE_UNDEFINED ResultType apply(A a)
|
||||
{
|
||||
if constexpr (IsDecimalNumber<A> || std::is_floating_point_v<A>)
|
||||
return a < A(0) ? -1 : a == A(0) ? 0 : 1;
|
||||
else if constexpr (is_signed_v<A>)
|
||||
return a < 0 ? -1 : a == 0 ? 0 : 1;
|
||||
else if constexpr (is_unsigned_v<A>)
|
||||
return a == 0 ? 0 : 1;
|
||||
}
|
||||
|
||||
class FunctionSign : public IFunction
|
||||
{
|
||||
private:
|
||||
const Context & context;
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
static constexpr bool compilable = false;
|
||||
#endif
|
||||
};
|
||||
|
||||
public:
|
||||
struct NameSign
|
||||
{
|
||||
static constexpr auto name = "sign";
|
||||
};
|
||||
using FunctionSign = FunctionUnaryArithmetic<SignImpl, NameSign, false>;
|
||||
|
||||
explicit FunctionSign(const Context & context_) : context(context_) { }
|
||||
static FunctionPtr create(const Context & context) { return std::make_shared<FunctionSign>(context); }
|
||||
|
||||
String getName() const override { return name; }
|
||||
|
||||
size_t getNumberOfArguments() const override { return 1; }
|
||||
|
||||
bool useDefaultImplementationForConstants() const override { return true; }
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
|
||||
template <>
|
||||
struct FunctionUnaryArithmeticMonotonicity<NameSign>
|
||||
{
|
||||
if (!isNumber(arguments[0].type))
|
||||
throw Exception(
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Argument of function {} should be a number, got {}",
|
||||
getName(),
|
||||
arguments[0].type->getName());
|
||||
|
||||
ColumnWithTypeAndName compare_elem{arguments[0].column, arguments[0].type, {}};
|
||||
|
||||
auto greater = FunctionFactory::instance().get("greater", context);
|
||||
auto greater_compare = greater->build(ColumnsWithTypeAndName{compare_elem, compare_elem});
|
||||
|
||||
if (isUnsignedInteger(arguments[0].type.get()))
|
||||
{
|
||||
return greater_compare->getResultType();
|
||||
}
|
||||
|
||||
auto compare_type = greater_compare->getResultType();
|
||||
|
||||
ColumnWithTypeAndName minus_elem = {compare_type, {}};
|
||||
|
||||
auto minus = FunctionFactory::instance().get("minus", context);
|
||||
auto elem_minus = minus->build(ColumnsWithTypeAndName{minus_elem, minus_elem});
|
||||
|
||||
return elem_minus->getResultType();
|
||||
}
|
||||
|
||||
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
|
||||
{
|
||||
auto zero_column = arguments[0].type->createColumnConstWithDefaultValue(input_rows_count);
|
||||
|
||||
ColumnWithTypeAndName left{arguments[0].column, arguments[0].type, {}};
|
||||
ColumnWithTypeAndName right{zero_column, arguments[0].type, {}};
|
||||
|
||||
auto func_arg = ColumnsWithTypeAndName{left, right};
|
||||
|
||||
auto greater = FunctionFactory::instance().get("greater", context);
|
||||
auto greater_compare = greater->build(func_arg);
|
||||
|
||||
/// Unsigned number: sign(n) = greater(n, 0)
|
||||
if (isUnsignedInteger(arguments[0].type.get()))
|
||||
{
|
||||
return greater_compare->execute(func_arg, greater_compare->getResultType(), input_rows_count);
|
||||
}
|
||||
|
||||
/// Signed number: sign(n) = minus(greater(n, 0), less(n, 0))
|
||||
auto less = FunctionFactory::instance().get("less", context);
|
||||
auto less_compare = less->build(func_arg);
|
||||
|
||||
ColumnsWithTypeAndName columns(2);
|
||||
|
||||
columns[0].type = greater_compare->getResultType();
|
||||
columns[0].column = greater_compare->execute(func_arg, greater_compare->getResultType(), input_rows_count);
|
||||
columns[1].type = less_compare->getResultType();
|
||||
columns[1].column = less_compare->execute(func_arg, less_compare->getResultType(), input_rows_count);
|
||||
|
||||
auto minus = FunctionFactory::instance().get("minus", context);
|
||||
auto elem_minus = minus->build(columns);
|
||||
|
||||
return elem_minus->execute(columns, elem_minus->getResultType(), input_rows_count);
|
||||
}
|
||||
static bool has() { return true; }
|
||||
static IFunction::Monotonicity get(const Field &, const Field &) { return {true, true, false}; }
|
||||
};
|
||||
|
||||
void registerFunctionSign(FunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction<FunctionSign>();
|
||||
factory.registerFunction<FunctionSign>(FunctionFactory::CaseInsensitive);
|
||||
}
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user