This commit is contained in:
feng lv 2021-01-26 11:26:15 +00:00
parent fdc6abaaa1
commit 35125c1b33
3 changed files with 57 additions and 93 deletions

View File

@ -79,8 +79,9 @@ struct InvalidType;
template <template <typename> class Op, typename Name, bool is_injective>
class FunctionUnaryArithmetic : public IFunction
{
static constexpr bool allow_decimal = IsUnaryOperation<Op>::negate || IsUnaryOperation<Op>::abs;
static constexpr bool allow_decimal = IsUnaryOperation<Op>::negate || IsUnaryOperation<Op>::abs || IsUnaryOperation<Op>::sign;
static constexpr bool allow_fixed_string = Op<UInt8>::allow_fixed_string;
static constexpr bool is_sign_function = IsUnaryOperation<Op>::sign;
template <typename F>
static bool castType(const IDataType * type, F && f)
@ -137,7 +138,7 @@ public:
{
using T0 = typename DataType::FieldType;
if constexpr (IsDataTypeDecimal<DataType>)
if constexpr (IsDataTypeDecimal<DataType> && !is_sign_function)
{
if constexpr (!allow_decimal)
return false;
@ -183,12 +184,24 @@ public:
{
if (auto col = checkAndGetColumn<ColumnDecimal<T0>>(arguments[0].column.get()))
{
auto col_res = ColumnDecimal<typename Op<T0>::ResultType>::create(0, type.getScale());
auto & vec_res = col_res->getData();
vec_res.resize(col->getData().size());
UnaryOperationImpl<T0, Op<T0>>::vector(col->getData(), vec_res);
result_column = std::move(col_res);
return true;
if constexpr (is_sign_function)
{
auto col_res = ColumnVector<typename Op<T0>::ResultType>::create();
auto & vec_res = col_res->getData();
vec_res.resize(col->getData().size());
UnaryOperationImpl<T0, Op<T0>>::vector(col->getData(), vec_res);
result_column = std::move(col_res);
return true;
}
else
{
auto col_res = ColumnDecimal<typename Op<T0>::ResultType>::create(0, type.getScale());
auto & vec_res = col_res->getData();
vec_res.resize(col->getData().size());
UnaryOperationImpl<T0, Op<T0>>::vector(col->getData(), vec_res);
result_column = std::move(col_res);
return true;
}
}
}
}

View File

@ -20,6 +20,9 @@ template <typename, typename> struct NotEqualsOp;
template <typename, typename> struct LessOrEqualsOp;
template <typename, typename> struct GreaterOrEqualsOp;
template <typename>
struct SignImpl;
template <template <typename, typename> typename Op1, template <typename, typename> typename Op2>
struct IsSameOperation
{
@ -31,6 +34,7 @@ struct IsUnaryOperation
{
static constexpr bool abs = std::is_same_v<Op<Int8>, AbsImpl<Int8>>;
static constexpr bool negate = std::is_same_v<Op<Int8>, NegateImpl<Int8>>;
static constexpr bool sign = std::is_same_v<Op<Int8>, SignImpl<Int8>>;
};
template <template <typename, typename> typename Op>

View File

@ -1,100 +1,47 @@
#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <Functions/FunctionUnaryArithmetic.h>
#include <DataTypes/NumberTraits.h>
#include <Common/FieldVisitors.h>
namespace DB
{
namespace ErrorCodes
template <typename A>
struct SignImpl
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
using ResultType = Int8;
static const constexpr bool allow_fixed_string = false;
class FunctionSign : public IFunction
static inline NO_SANITIZE_UNDEFINED ResultType apply(A a)
{
if constexpr (IsDecimalNumber<A> || std::is_floating_point_v<A>)
return a < A(0) ? -1 : a == A(0) ? 0 : 1;
else if constexpr (is_signed_v<A>)
return a < 0 ? -1 : a == 0 ? 0 : 1;
else if constexpr (is_unsigned_v<A>)
return a == 0 ? 0 : 1;
}
#if USE_EMBEDDED_COMPILER
static constexpr bool compilable = false;
#endif
};
struct NameSign
{
private:
const Context & context;
public:
static constexpr auto name = "sign";
};
using FunctionSign = FunctionUnaryArithmetic<SignImpl, NameSign, false>;
explicit FunctionSign(const Context & context_) : context(context_) { }
static FunctionPtr create(const Context & context) { return std::make_shared<FunctionSign>(context); }
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 1; }
bool useDefaultImplementationForConstants() const override { return true; }
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if (!isNumber(arguments[0].type))
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Argument of function {} should be a number, got {}",
getName(),
arguments[0].type->getName());
ColumnWithTypeAndName compare_elem{arguments[0].column, arguments[0].type, {}};
auto greater = FunctionFactory::instance().get("greater", context);
auto greater_compare = greater->build(ColumnsWithTypeAndName{compare_elem, compare_elem});
if (isUnsignedInteger(arguments[0].type.get()))
{
return greater_compare->getResultType();
}
auto compare_type = greater_compare->getResultType();
ColumnWithTypeAndName minus_elem = {compare_type, {}};
auto minus = FunctionFactory::instance().get("minus", context);
auto elem_minus = minus->build(ColumnsWithTypeAndName{minus_elem, minus_elem});
return elem_minus->getResultType();
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
auto zero_column = arguments[0].type->createColumnConstWithDefaultValue(input_rows_count);
ColumnWithTypeAndName left{arguments[0].column, arguments[0].type, {}};
ColumnWithTypeAndName right{zero_column, arguments[0].type, {}};
auto func_arg = ColumnsWithTypeAndName{left, right};
auto greater = FunctionFactory::instance().get("greater", context);
auto greater_compare = greater->build(func_arg);
/// Unsigned number: sign(n) = greater(n, 0)
if (isUnsignedInteger(arguments[0].type.get()))
{
return greater_compare->execute(func_arg, greater_compare->getResultType(), input_rows_count);
}
/// Signed number: sign(n) = minus(greater(n, 0), less(n, 0))
auto less = FunctionFactory::instance().get("less", context);
auto less_compare = less->build(func_arg);
ColumnsWithTypeAndName columns(2);
columns[0].type = greater_compare->getResultType();
columns[0].column = greater_compare->execute(func_arg, greater_compare->getResultType(), input_rows_count);
columns[1].type = less_compare->getResultType();
columns[1].column = less_compare->execute(func_arg, less_compare->getResultType(), input_rows_count);
auto minus = FunctionFactory::instance().get("minus", context);
auto elem_minus = minus->build(columns);
return elem_minus->execute(columns, elem_minus->getResultType(), input_rows_count);
}
template <>
struct FunctionUnaryArithmeticMonotonicity<NameSign>
{
static bool has() { return true; }
static IFunction::Monotonicity get(const Field &, const Field &) { return {true, true, false}; }
};
void registerFunctionSign(FunctionFactory & factory)
{
factory.registerFunction<FunctionSign>();
factory.registerFunction<FunctionSign>(FunctionFactory::CaseInsensitive);
}
}