mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-22 15:42:02 +00:00
Merge pull request #2303 from pyos/cpp17-folds
Use C++17 fold expressions to simplify FunctionsArithmetic.h
This commit is contained in:
commit
1fc714f6b2
@ -64,9 +64,9 @@ struct BinaryOperationImplBase
|
||||
c[i] = Op::template apply<ResultType>(a, b[i]);
|
||||
}
|
||||
|
||||
static void constant_constant(A a, B b, ResultType & c)
|
||||
static ResultType constant_constant(A a, B b)
|
||||
{
|
||||
c = Op::template apply<ResultType>(a, b);
|
||||
return Op::template apply<ResultType>(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
@ -476,27 +476,13 @@ struct IntExp10Impl
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// this one is just for convenience
|
||||
template <bool B, typename T1, typename T2> using If = std::conditional_t<B, T1, T2>;
|
||||
/// these ones for better semantics
|
||||
template <typename T> using Then = T;
|
||||
template <typename T> using Else = T;
|
||||
|
||||
/// Used to indicate undefined operation
|
||||
struct InvalidType;
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeFromFieldType
|
||||
{
|
||||
using Type = DataTypeNumber<T>;
|
||||
};
|
||||
template <bool V, typename T> struct Case : std::bool_constant<V> { using type = T; };
|
||||
|
||||
template <>
|
||||
struct DataTypeFromFieldType<NumberTraits::Error>
|
||||
{
|
||||
using Type = InvalidType;
|
||||
};
|
||||
/// Switch<Case<C0, T0>, ...> -- select the first Ti for which Ci is true; InvalidType if none.
|
||||
template <typename... Ts> using Switch = typename std::disjunction<Ts..., Case<true, InvalidType>>::type;
|
||||
|
||||
template <typename DataType> constexpr bool IsIntegral = false;
|
||||
template <> constexpr bool IsIntegral<DataTypeUInt8> = true;
|
||||
@ -512,270 +498,74 @@ template <typename DataType> constexpr bool IsDateOrDateTime = false;
|
||||
template <> constexpr bool IsDateOrDateTime<DataTypeDate> = true;
|
||||
template <> constexpr bool IsDateOrDateTime<DataTypeDateTime> = true;
|
||||
|
||||
/** Returns appropriate result type for binary operator on dates (or datetimes):
|
||||
* Date + Integral -> Date
|
||||
* Integral + Date -> Date
|
||||
* Date - Date -> Int32
|
||||
* Date - Integral -> Date
|
||||
* least(Date, Date) -> Date
|
||||
* greatest(Date, Date) -> Date
|
||||
* All other operations are not defined and return InvalidType, operations on
|
||||
* distinct date types are also undefined (e.g. DataTypeDate - DataTypeDateTime)
|
||||
*/
|
||||
template <typename T> using DataTypeFromFieldType = std::conditional_t<std::is_same_v<T, NumberTraits::Error>, InvalidType, DataTypeNumber<T>>;
|
||||
|
||||
template <template <typename, typename> class Operation, typename LeftDataType, typename RightDataType>
|
||||
struct DateBinaryOperationTraits
|
||||
struct BinaryOperationTraits
|
||||
{
|
||||
using T0 = typename LeftDataType::FieldType;
|
||||
using T1 = typename RightDataType::FieldType;
|
||||
using Op = Operation<T0, T1>;
|
||||
|
||||
using ResultDataType =
|
||||
If<std::is_same_v<Op, PlusImpl<T0, T1>>,
|
||||
Then<
|
||||
If<IsDateOrDateTime<LeftDataType> && IsIntegral<RightDataType>,
|
||||
Then<LeftDataType>,
|
||||
Else<
|
||||
If<IsIntegral<LeftDataType> && IsDateOrDateTime<RightDataType>,
|
||||
Then<RightDataType>,
|
||||
Else<InvalidType>>>>>,
|
||||
Else<
|
||||
If<std::is_same_v<Op, MinusImpl<T0, T1>>,
|
||||
Then<
|
||||
If<IsDateOrDateTime<LeftDataType>,
|
||||
Then<
|
||||
If<std::is_same_v<LeftDataType, RightDataType>,
|
||||
Then<DataTypeInt32>,
|
||||
Else<
|
||||
If<IsIntegral<RightDataType>,
|
||||
Then<LeftDataType>,
|
||||
Else<InvalidType>>>>>,
|
||||
Else<InvalidType>>>,
|
||||
Else<
|
||||
If<std::is_same_v<T0, T1>
|
||||
&& (std::is_same_v<Op, LeastImpl<T0, T1>> || std::is_same_v<Op, GreatestImpl<T0, T1>>),
|
||||
Then<LeftDataType>,
|
||||
Else<InvalidType>>>>>>;
|
||||
/// Appropriate result type for binary operator on numeric types. "Date" can also mean
|
||||
/// DateTime, but if both operands are Dates, their type must be the same (e.g. Date - DateTime is invalid).
|
||||
using ResultDataType = Switch<
|
||||
/// number <op> number -> see corresponding impl
|
||||
Case<!IsDateOrDateTime<LeftDataType> && !IsDateOrDateTime<RightDataType>,
|
||||
DataTypeFromFieldType<typename Op::ResultType>>,
|
||||
/// Date + Integral -> Date
|
||||
/// Integral + Date -> Date
|
||||
Case<std::is_same_v<Op, PlusImpl<T0, T1>>, Switch<
|
||||
Case<IsIntegral<RightDataType>, LeftDataType>,
|
||||
Case<IsIntegral<LeftDataType>, RightDataType>>>,
|
||||
/// Date - Date -> Int32
|
||||
/// Date - Integral -> Date
|
||||
Case<std::is_same_v<Op, MinusImpl<T0, T1>>, Switch<
|
||||
Case<std::is_same_v<LeftDataType, RightDataType>, DataTypeInt32>,
|
||||
Case<IsDateOrDateTime<LeftDataType> && IsIntegral<RightDataType>, LeftDataType>>>,
|
||||
/// least(Date, Date) -> Date
|
||||
/// greatest(Date, Date) -> Date
|
||||
Case<std::is_same_v<LeftDataType, RightDataType> && (std::is_same_v<Op, LeastImpl<T0, T1>> || std::is_same_v<Op, GreatestImpl<T0, T1>>),
|
||||
LeftDataType>>;
|
||||
};
|
||||
|
||||
|
||||
/// Decides among date and numeric operations
|
||||
template <template <typename, typename> class Operation, typename LeftDataType, typename RightDataType>
|
||||
struct BinaryOperationTraits
|
||||
template <typename... Ts, typename F>
|
||||
static bool castTypeToEither(const IDataType * type, F && f)
|
||||
{
|
||||
using ResultDataType =
|
||||
If<IsDateOrDateTime<LeftDataType> || IsDateOrDateTime<RightDataType>,
|
||||
Then<
|
||||
typename DateBinaryOperationTraits<
|
||||
Operation, LeftDataType, RightDataType>::ResultDataType>,
|
||||
Else<
|
||||
typename DataTypeFromFieldType<
|
||||
typename Operation<
|
||||
typename LeftDataType::FieldType,
|
||||
typename RightDataType::FieldType>::ResultType>::Type>>;
|
||||
};
|
||||
/// XXX can't use && here because gcc-7 complains about parentheses around && within ||
|
||||
return ((typeid_cast<const Ts *>(type) ? f(*typeid_cast<const Ts *>(type)) : false) || ...);
|
||||
}
|
||||
|
||||
|
||||
template <template <typename, typename> class Op, typename Name>
|
||||
class FunctionBinaryArithmetic : public IFunction
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = Name::name;
|
||||
static FunctionPtr create(const Context & context) { return std::make_shared<FunctionBinaryArithmetic>(context); }
|
||||
|
||||
FunctionBinaryArithmetic(const Context & context) : context(context) {}
|
||||
|
||||
private:
|
||||
const Context & context;
|
||||
|
||||
template <typename ResultDataType>
|
||||
bool checkRightTypeImpl(DataTypePtr & type_res) const
|
||||
template <typename F>
|
||||
static bool castType(const IDataType * type, F && f)
|
||||
{
|
||||
/// Overload for InvalidType
|
||||
if constexpr (std::is_same_v<ResultDataType, InvalidType>)
|
||||
return false;
|
||||
else
|
||||
{
|
||||
type_res = std::make_shared<ResultDataType>();
|
||||
return true;
|
||||
}
|
||||
return castTypeToEither<
|
||||
DataTypeUInt8,
|
||||
DataTypeUInt16,
|
||||
DataTypeUInt32,
|
||||
DataTypeUInt64,
|
||||
DataTypeInt8,
|
||||
DataTypeInt16,
|
||||
DataTypeInt32,
|
||||
DataTypeInt64,
|
||||
DataTypeFloat32,
|
||||
DataTypeFloat64,
|
||||
DataTypeDate,
|
||||
DataTypeDateTime
|
||||
>(type, std::forward<F>(f));
|
||||
}
|
||||
|
||||
template <typename LeftDataType, typename RightDataType>
|
||||
bool checkRightType(const DataTypes & arguments, DataTypePtr & type_res) const
|
||||
template <typename F>
|
||||
static bool castBothTypes(const IDataType * left, const IDataType * right, F && f)
|
||||
{
|
||||
using ResultDataType = typename BinaryOperationTraits<Op, LeftDataType, RightDataType>::ResultDataType;
|
||||
|
||||
if (typeid_cast<const RightDataType *>(arguments[1].get()))
|
||||
return checkRightTypeImpl<ResultDataType>(type_res);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename T0>
|
||||
bool checkLeftType(const DataTypes & arguments, DataTypePtr & type_res) const
|
||||
{
|
||||
if (typeid_cast<const T0 *>(arguments[0].get()))
|
||||
{
|
||||
if ( checkRightType<T0, DataTypeDate>(arguments, type_res)
|
||||
|| checkRightType<T0, DataTypeDateTime>(arguments, type_res)
|
||||
|| checkRightType<T0, DataTypeUInt8>(arguments, type_res)
|
||||
|| checkRightType<T0, DataTypeUInt16>(arguments, type_res)
|
||||
|| checkRightType<T0, DataTypeUInt32>(arguments, type_res)
|
||||
|| checkRightType<T0, DataTypeUInt64>(arguments, type_res)
|
||||
|| checkRightType<T0, DataTypeInt8>(arguments, type_res)
|
||||
|| checkRightType<T0, DataTypeInt16>(arguments, type_res)
|
||||
|| checkRightType<T0, DataTypeInt32>(arguments, type_res)
|
||||
|| checkRightType<T0, DataTypeInt64>(arguments, type_res)
|
||||
|| checkRightType<T0, DataTypeFloat32>(arguments, type_res)
|
||||
|| checkRightType<T0, DataTypeFloat64>(arguments, type_res))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Overload for date operations
|
||||
template <typename LeftDataType, typename RightDataType, typename ColumnType>
|
||||
bool executeRightType(Block & block, const ColumnNumbers & arguments, const size_t result, const ColumnType * col_left)
|
||||
{
|
||||
if (!typeid_cast<const RightDataType *>(block.getByPosition(arguments[1]).type.get()))
|
||||
return false;
|
||||
|
||||
using ResultDataType = typename BinaryOperationTraits<Op, LeftDataType, RightDataType>::ResultDataType;
|
||||
|
||||
return executeRightTypeDispatch<LeftDataType, RightDataType, ResultDataType>(
|
||||
block, arguments, result, col_left);
|
||||
}
|
||||
|
||||
/// Overload for InvalidType
|
||||
template <typename LeftDataType, typename RightDataType, typename ResultDataType, typename ColumnType>
|
||||
bool executeRightTypeDispatch(Block & block, const ColumnNumbers & arguments,
|
||||
[[maybe_unused]] const size_t result, [[maybe_unused]] const ColumnType * col_left)
|
||||
{
|
||||
if constexpr (std::is_same_v<ResultDataType, InvalidType>)
|
||||
throw Exception("Types " + String(TypeName<typename LeftDataType::FieldType>::get())
|
||||
+ " and " + String(TypeName<typename LeftDataType::FieldType>::get())
|
||||
+ " are incompatible for function " + getName() + " or not upscaleable to common type", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
else
|
||||
{
|
||||
using T0 = typename LeftDataType::FieldType;
|
||||
using T1 = typename RightDataType::FieldType;
|
||||
using ResultType = typename ResultDataType::FieldType;
|
||||
|
||||
return executeRightTypeImpl<T0, T1, ResultType>(block, arguments, result, col_left);
|
||||
}
|
||||
}
|
||||
|
||||
/// ColumnVector overload
|
||||
template <typename T0, typename T1, typename ResultType = typename Op<T0, T1>::ResultType>
|
||||
bool executeRightTypeImpl(Block & block, const ColumnNumbers & arguments, size_t result, const ColumnVector<T0> * col_left)
|
||||
{
|
||||
if (auto col_right = checkAndGetColumn<ColumnVector<T1>>(block.getByPosition(arguments[1]).column.get()))
|
||||
{
|
||||
auto col_res = ColumnVector<ResultType>::create();
|
||||
|
||||
auto & vec_res = col_res->getData();
|
||||
vec_res.resize(col_left->getData().size());
|
||||
BinaryOperationImpl<T0, T1, Op<T0, T1>, ResultType>::vector_vector(col_left->getData(), col_right->getData(), vec_res);
|
||||
|
||||
block.getByPosition(result).column = std::move(col_res);
|
||||
return true;
|
||||
}
|
||||
else if (auto col_right = checkAndGetColumnConst<ColumnVector<T1>>(block.getByPosition(arguments[1]).column.get()))
|
||||
{
|
||||
auto col_res = ColumnVector<ResultType>::create();
|
||||
|
||||
auto & vec_res = col_res->getData();
|
||||
vec_res.resize(col_left->getData().size());
|
||||
BinaryOperationImpl<T0, T1, Op<T0, T1>, ResultType>::vector_constant(col_left->getData(), col_right->template getValue<T1>(), vec_res);
|
||||
|
||||
block.getByPosition(result).column = std::move(col_res);
|
||||
return true;
|
||||
}
|
||||
|
||||
throw Exception("Logical error: unexpected type of column", ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
/// ColumnConst overload
|
||||
template <typename T0, typename T1, typename ResultType = typename Op<T0, T1>::ResultType>
|
||||
bool executeRightTypeImpl(Block & block, const ColumnNumbers & arguments, size_t result, const ColumnConst * col_left)
|
||||
{
|
||||
if (auto col_right = checkAndGetColumn<ColumnVector<T1>>(block.getByPosition(arguments[1]).column.get()))
|
||||
{
|
||||
auto col_res = ColumnVector<ResultType>::create();
|
||||
|
||||
auto & vec_res = col_res->getData();
|
||||
vec_res.resize(col_left->size());
|
||||
BinaryOperationImpl<T0, T1, Op<T0, T1>, ResultType>::constant_vector(col_left->template getValue<T0>(), col_right->getData(), vec_res);
|
||||
|
||||
block.getByPosition(result).column = std::move(col_res);
|
||||
return true;
|
||||
}
|
||||
else if (auto col_right = checkAndGetColumnConst<ColumnVector<T1>>(block.getByPosition(arguments[1]).column.get()))
|
||||
{
|
||||
ResultType res = 0;
|
||||
BinaryOperationImpl<T0, T1, Op<T0, T1>, ResultType>::constant_constant(col_left->template getValue<T0>(), col_right->template getValue<T1>(), res);
|
||||
block.getByPosition(result).column = DataTypeNumber<ResultType>().createColumnConst(col_left->size(), toField(res));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename LeftDataType>
|
||||
bool executeLeftType(Block & block, const ColumnNumbers & arguments, const size_t result)
|
||||
{
|
||||
if (!typeid_cast<const LeftDataType *>(block.getByPosition(arguments[0]).type.get()))
|
||||
return false;
|
||||
|
||||
return executeLeftTypeImpl<LeftDataType>(block, arguments, result);
|
||||
}
|
||||
|
||||
template <typename LeftDataType>
|
||||
bool executeLeftTypeImpl(Block & block, const ColumnNumbers & arguments, const size_t result)
|
||||
{
|
||||
if (auto col_left = checkAndGetColumn<ColumnVector<typename LeftDataType::FieldType>>(block.getByPosition(arguments[0]).column.get()))
|
||||
{
|
||||
if ( executeRightType<LeftDataType, DataTypeDate>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeDateTime>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeUInt8>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeUInt16>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeUInt32>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeUInt64>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeInt8>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeInt16>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeInt32>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeInt64>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeFloat32>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeFloat64>(block, arguments, result, col_left))
|
||||
return true;
|
||||
else
|
||||
throw Exception("Illegal column " + block.getByPosition(arguments[1]).column->getName()
|
||||
+ " of second argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_COLUMN);
|
||||
}
|
||||
else if (auto col_left = checkAndGetColumnConst<ColumnVector<typename LeftDataType::FieldType>>(block.getByPosition(arguments[0]).column.get()))
|
||||
{
|
||||
if ( executeRightType<LeftDataType, DataTypeDate>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeDateTime>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeUInt8>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeUInt16>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeUInt32>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeUInt64>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeInt8>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeInt16>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeInt32>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeInt64>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeFloat32>(block, arguments, result, col_left)
|
||||
|| executeRightType<LeftDataType, DataTypeFloat64>(block, arguments, result, col_left))
|
||||
return true;
|
||||
else
|
||||
throw Exception("Illegal column " + block.getByPosition(arguments[1]).column->getName()
|
||||
+ " of second argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_COLUMN);
|
||||
}
|
||||
|
||||
return false;
|
||||
return castType(left, [&](const auto & left) { return castType(right, [&](const auto & right) { return f(left, right); }); });
|
||||
}
|
||||
|
||||
FunctionBuilderPtr getFunctionForIntervalArithmetic(const DataTypePtr & type0, const DataTypePtr & type1) const
|
||||
@ -820,6 +610,11 @@ private:
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr auto name = Name::name;
|
||||
static FunctionPtr create(const Context & context) { return std::make_shared<FunctionBinaryArithmetic>(context); }
|
||||
|
||||
FunctionBinaryArithmetic(const Context & context) : context(context) {}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return name;
|
||||
@ -849,22 +644,21 @@ public:
|
||||
}
|
||||
|
||||
DataTypePtr type_res;
|
||||
|
||||
if (!( checkLeftType<DataTypeDate>(arguments, type_res)
|
||||
|| checkLeftType<DataTypeDateTime>(arguments, type_res)
|
||||
|| checkLeftType<DataTypeUInt8>(arguments, type_res)
|
||||
|| checkLeftType<DataTypeUInt16>(arguments, type_res)
|
||||
|| checkLeftType<DataTypeUInt32>(arguments, type_res)
|
||||
|| checkLeftType<DataTypeUInt64>(arguments, type_res)
|
||||
|| checkLeftType<DataTypeInt8>(arguments, type_res)
|
||||
|| checkLeftType<DataTypeInt16>(arguments, type_res)
|
||||
|| checkLeftType<DataTypeInt32>(arguments, type_res)
|
||||
|| checkLeftType<DataTypeInt64>(arguments, type_res)
|
||||
|| checkLeftType<DataTypeFloat32>(arguments, type_res)
|
||||
|| checkLeftType<DataTypeFloat64>(arguments, type_res)))
|
||||
bool valid = castBothTypes(arguments[0].get(), arguments[1].get(), [&](const auto & left, const auto & right)
|
||||
{
|
||||
using LeftDataType = std::decay_t<decltype(left)>;
|
||||
using RightDataType = std::decay_t<decltype(right)>;
|
||||
using ResultDataType = typename BinaryOperationTraits<Op, LeftDataType, RightDataType>::ResultDataType;
|
||||
if constexpr (!std::is_same_v<ResultDataType, InvalidType>)
|
||||
{
|
||||
type_res = std::make_shared<ResultDataType>();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
if (!valid)
|
||||
throw Exception("Illegal types " + arguments[0]->getName() + " and " + arguments[1]->getName() + " of arguments of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
return type_res;
|
||||
}
|
||||
|
||||
@ -893,21 +687,63 @@ public:
|
||||
return;
|
||||
}
|
||||
|
||||
if (!( executeLeftType<DataTypeDate>(block, arguments, result)
|
||||
|| executeLeftType<DataTypeDateTime>(block, arguments, result)
|
||||
|| executeLeftType<DataTypeUInt8>(block, arguments, result)
|
||||
|| executeLeftType<DataTypeUInt16>(block, arguments, result)
|
||||
|| executeLeftType<DataTypeUInt32>(block, arguments, result)
|
||||
|| executeLeftType<DataTypeUInt64>(block, arguments, result)
|
||||
|| executeLeftType<DataTypeInt8>(block, arguments, result)
|
||||
|| executeLeftType<DataTypeInt16>(block, arguments, result)
|
||||
|| executeLeftType<DataTypeInt32>(block, arguments, result)
|
||||
|| executeLeftType<DataTypeInt64>(block, arguments, result)
|
||||
|| executeLeftType<DataTypeFloat32>(block, arguments, result)
|
||||
|| executeLeftType<DataTypeFloat64>(block, arguments, result)))
|
||||
throw Exception("Illegal column " + block.getByPosition(arguments[0]).column->getName()
|
||||
+ " of first argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_COLUMN);
|
||||
auto * left = block.getByPosition(arguments[0]).type.get();
|
||||
auto * right = block.getByPosition(arguments[1]).type.get();
|
||||
bool valid = castBothTypes(left, right, [&](const auto & left, const auto & right)
|
||||
{
|
||||
using LeftDataType = std::decay_t<decltype(left)>;
|
||||
using RightDataType = std::decay_t<decltype(right)>;
|
||||
using ResultDataType = typename BinaryOperationTraits<Op, LeftDataType, RightDataType>::ResultDataType;
|
||||
if constexpr (!std::is_same_v<ResultDataType, InvalidType>)
|
||||
{
|
||||
using T0 = typename LeftDataType::FieldType;
|
||||
using T1 = typename RightDataType::FieldType;
|
||||
using ResultType = typename ResultDataType::FieldType;
|
||||
using OpImpl = BinaryOperationImpl<T0, T1, Op<T0, T1>, ResultType>;
|
||||
|
||||
auto col_left_raw = block.getByPosition(arguments[0]).column.get();
|
||||
auto col_right_raw = block.getByPosition(arguments[1]).column.get();
|
||||
if (auto col_left = checkAndGetColumnConst<ColumnVector<T0>>(col_left_raw))
|
||||
{
|
||||
if (auto col_right = checkAndGetColumnConst<ColumnVector<T1>>(col_right_raw))
|
||||
{
|
||||
/// the only case with a non-vector result
|
||||
auto res = OpImpl::constant_constant(col_left->template getValue<T0>(), col_right->template getValue<T1>());
|
||||
block.getByPosition(result).column = ResultDataType().createColumnConst(col_left->size(), toField(res));
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
auto col_res = ColumnVector<ResultType>::create();
|
||||
auto & vec_res = col_res->getData();
|
||||
vec_res.resize(block.rows());
|
||||
if (auto col_left = checkAndGetColumnConst<ColumnVector<T0>>(col_left_raw))
|
||||
{
|
||||
if (auto col_right = checkAndGetColumn<ColumnVector<T1>>(col_right_raw))
|
||||
OpImpl::constant_vector(col_left->template getValue<T0>(), col_right->getData(), vec_res);
|
||||
else
|
||||
return false;
|
||||
}
|
||||
else if (auto col_left = checkAndGetColumn<ColumnVector<T0>>(col_left_raw))
|
||||
{
|
||||
if (auto col_right = checkAndGetColumn<ColumnVector<T1>>(col_right_raw))
|
||||
OpImpl::vector_vector(col_left->getData(), col_right->getData(), vec_res);
|
||||
else if (auto col_right = checkAndGetColumnConst<ColumnVector<T1>>(col_right_raw))
|
||||
OpImpl::vector_constant(col_left->getData(), col_right->template getValue<T1>(), vec_res);
|
||||
else
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
block.getByPosition(result).column = std::move(col_res);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
if (!valid)
|
||||
throw Exception(getName() + "'s arguments do not match the expected data types", ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
};
|
||||
|
||||
@ -919,43 +755,27 @@ struct FunctionUnaryArithmeticMonotonicity;
|
||||
template <template <typename> class Op, typename Name, bool is_injective>
|
||||
class FunctionUnaryArithmetic : public IFunction
|
||||
{
|
||||
template <typename F>
|
||||
static bool castType(const IDataType * type, F && f)
|
||||
{
|
||||
return castTypeToEither<
|
||||
DataTypeUInt8,
|
||||
DataTypeUInt16,
|
||||
DataTypeUInt32,
|
||||
DataTypeUInt64,
|
||||
DataTypeInt8,
|
||||
DataTypeInt16,
|
||||
DataTypeInt32,
|
||||
DataTypeInt64,
|
||||
DataTypeFloat32,
|
||||
DataTypeFloat64
|
||||
>(type, std::forward<F>(f));
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr auto name = Name::name;
|
||||
static FunctionPtr create(const Context &) { return std::make_shared<FunctionUnaryArithmetic>(); }
|
||||
|
||||
private:
|
||||
template <typename T0>
|
||||
bool checkType(const DataTypes & arguments, DataTypePtr & result) const
|
||||
{
|
||||
if (typeid_cast<const T0 *>(arguments[0].get()))
|
||||
{
|
||||
result = std::make_shared<DataTypeNumber<typename Op<typename T0::FieldType>::ResultType>>();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename T0>
|
||||
bool executeType(Block & block, const ColumnNumbers & arguments, size_t result)
|
||||
{
|
||||
if (const ColumnVector<T0> * col = checkAndGetColumn<ColumnVector<T0>>(block.getByPosition(arguments[0]).column.get()))
|
||||
{
|
||||
using ResultType = typename Op<T0>::ResultType;
|
||||
|
||||
auto col_res = ColumnVector<ResultType>::create();
|
||||
|
||||
typename ColumnVector<ResultType>::Container & vec_res = col_res->getData();
|
||||
vec_res.resize(col->getData().size());
|
||||
UnaryOperationImpl<T0, Op<T0>>::vector(col->getData(), vec_res);
|
||||
|
||||
block.getByPosition(result).column = std::move(col_res);
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
public:
|
||||
String getName() const override
|
||||
{
|
||||
return name;
|
||||
@ -969,38 +789,36 @@ public:
|
||||
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
|
||||
{
|
||||
DataTypePtr result;
|
||||
|
||||
if (!( checkType<DataTypeUInt8>(arguments, result)
|
||||
|| checkType<DataTypeUInt16>(arguments, result)
|
||||
|| checkType<DataTypeUInt32>(arguments, result)
|
||||
|| checkType<DataTypeUInt64>(arguments, result)
|
||||
|| checkType<DataTypeInt8>(arguments, result)
|
||||
|| checkType<DataTypeInt16>(arguments, result)
|
||||
|| checkType<DataTypeInt32>(arguments, result)
|
||||
|| checkType<DataTypeInt64>(arguments, result)
|
||||
|| checkType<DataTypeFloat32>(arguments, result)
|
||||
|| checkType<DataTypeFloat64>(arguments, result)))
|
||||
bool valid = castType(arguments[0].get(), [&](const auto & type)
|
||||
{
|
||||
using T0 = typename std::decay_t<decltype(type)>::FieldType;
|
||||
result = std::make_shared<DataTypeNumber<typename Op<T0>::ResultType>>();
|
||||
return true;
|
||||
});
|
||||
if (!valid)
|
||||
throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override
|
||||
{
|
||||
if (!( executeType<UInt8>(block, arguments, result)
|
||||
|| executeType<UInt16>(block, arguments, result)
|
||||
|| executeType<UInt32>(block, arguments, result)
|
||||
|| executeType<UInt64>(block, arguments, result)
|
||||
|| executeType<Int8>(block, arguments, result)
|
||||
|| executeType<Int16>(block, arguments, result)
|
||||
|| executeType<Int32>(block, arguments, result)
|
||||
|| executeType<Int64>(block, arguments, result)
|
||||
|| executeType<Float32>(block, arguments, result)
|
||||
|| executeType<Float64>(block, arguments, result)))
|
||||
throw Exception("Illegal column " + block.getByPosition(arguments[0]).column->getName()
|
||||
+ " of argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_COLUMN);
|
||||
bool valid = castType(block.getByPosition(arguments[0]).type.get(), [&](const auto & type)
|
||||
{
|
||||
using T0 = typename std::decay_t<decltype(type)>::FieldType;
|
||||
if (auto col = checkAndGetColumn<ColumnVector<T0>>(block.getByPosition(arguments[0]).column.get()))
|
||||
{
|
||||
auto col_res = ColumnVector<typename Op<T0>::ResultType>::create();
|
||||
auto & vec_res = col_res->getData();
|
||||
vec_res.resize(col->getData().size());
|
||||
UnaryOperationImpl<T0, Op<T0>>::vector(col->getData(), vec_res);
|
||||
block.getByPosition(result).column = std::move(col_res);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
if (!valid)
|
||||
throw Exception(getName() + "'s argument does not match the expected data type", ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
bool hasInformationAboutMonotonicity() const override
|
||||
|
Loading…
Reference in New Issue
Block a user