Binary operator monotonicity

This commit is contained in:
Amos Bird 2020-09-05 22:12:47 +08:00
parent c2f762e20a
commit 34b9547ce1
No known key found for this signature in database
GPG Key ID: 80D430DCBECFEDB4
24 changed files with 247 additions and 31 deletions

View File

@ -28,6 +28,7 @@
#include "FunctionFactory.h"
#include <Common/typeid_cast.h>
#include <Common/assert_cast.h>
#include <ext/map.h>
#if !defined(ARCADIA_BUILD)
# include <Common/config.h>
@ -51,6 +52,7 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
extern const int DECIMAL_OVERFLOW;
extern const int CANNOT_ADD_DIFFERENT_AGGREGATE_STATES;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
@ -602,7 +604,8 @@ class FunctionBinaryArithmetic : public IFunction
return castType(left, [&](const auto & left_) { return castType(right, [&](const auto & right_) { return f(left_, right_); }); });
}
FunctionOverloadResolverPtr getFunctionForIntervalArithmetic(const DataTypePtr & type0, const DataTypePtr & type1) const
static FunctionOverloadResolverPtr
getFunctionForIntervalArithmetic(const DataTypePtr & type0, const DataTypePtr & type1, const Context & context)
{
bool first_is_date_or_datetime = isDateOrDateTime(type0);
bool second_is_date_or_datetime = isDateOrDateTime(type1);
@ -632,7 +635,7 @@ class FunctionBinaryArithmetic : public IFunction
}
if (second_is_date_or_datetime && is_minus)
throw Exception("Wrong order of arguments for function " + getName() + ": argument of type Interval cannot be first.",
throw Exception("Wrong order of arguments for function " + String(name) + ": argument of type Interval cannot be first.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
std::string function_name;
@ -651,7 +654,7 @@ class FunctionBinaryArithmetic : public IFunction
return FunctionFactory::instance().get(function_name, context);
}
bool isAggregateMultiply(const DataTypePtr & type0, const DataTypePtr & type1) const
static bool isAggregateMultiply(const DataTypePtr & type0, const DataTypePtr & type1)
{
if constexpr (!is_multiply)
return false;
@ -663,7 +666,7 @@ class FunctionBinaryArithmetic : public IFunction
|| (which0.isNativeUInt() && which1.isAggregateFunction());
}
bool isAggregateAddition(const DataTypePtr & type0, const DataTypePtr & type1) const
static bool isAggregateAddition(const DataTypePtr & type0, const DataTypePtr & type1)
{
if constexpr (!is_plus)
return false;
@ -812,6 +815,11 @@ public:
size_t getNumberOfArguments() const override { return 2; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
return getReturnTypeImplStatic(arguments, context);
}
static DataTypePtr getReturnTypeImplStatic(const DataTypes & arguments, const Context & context)
{
/// Special case when multiply aggregate function state
if (isAggregateMultiply(arguments[0], arguments[1]))
@ -832,7 +840,7 @@ public:
}
/// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
if (auto function_builder = getFunctionForIntervalArithmetic(arguments[0], arguments[1]))
if (auto function_builder = getFunctionForIntervalArithmetic(arguments[0], arguments[1], context))
{
ColumnsWithTypeAndName new_arguments(2);
@ -903,7 +911,7 @@ public:
return false;
});
if (!valid)
throw Exception("Illegal types " + arguments[0]->getName() + " and " + arguments[1]->getName() + " of arguments of function " + getName(),
throw Exception("Illegal types " + arguments[0]->getName() + " and " + arguments[1]->getName() + " of arguments of function " + String(name),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return type_res;
}
@ -1110,7 +1118,8 @@ public:
}
/// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
if (auto function_builder = getFunctionForIntervalArithmetic(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type))
if (auto function_builder
= getFunctionForIntervalArithmetic(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type, context))
{
executeDateTimeIntervalPlusMinus(block, arguments, result, input_rows_count, function_builder);
return;
@ -1200,4 +1209,167 @@ public:
bool canBeExecutedOnDefaultArguments() const override { return valid_on_default_arguments; }
};
/** Binary arithmetic function that remembers which of its two operands was a constant column
  * when the function was built.
  *
  * This serves two purposes visible in this class:
  *  - the function can be executed with a single argument (only the non-constant operand),
  *    the remembered constant is substituted for the missing one;
  *  - monotonicity with respect to the non-constant operand can be reported for
  *    plus, minus, multiply, divide and intDiv.
  */
template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true>
class FunctionBinaryArithmeticWithConstants : public FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments>
{
public:
    using Base = FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments>;
    using Monotonicity = typename Base::Monotonicity;

    static FunctionPtr create(const ColumnWithTypeAndName & left_, const ColumnWithTypeAndName & right_, const Context & context)
    {
        return std::make_shared<FunctionBinaryArithmeticWithConstants>(left_, right_, context);
    }

    /// left_ / right_ describe the two operands; the `column` of an operand may be null when it is not a known constant.
    FunctionBinaryArithmeticWithConstants(
        const ColumnWithTypeAndName & left_, const ColumnWithTypeAndName & right_, const Context & context_)
        : Base(context_), left(left_), right(right_)
    {
    }

    /// When called with a single argument, the remembered constant operand is materialized to
    /// input_rows_count rows and put on its original side; otherwise the call is forwarded unchanged.
    void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) const override
    {
        if (left.column && isColumnConst(*left.column) && arguments.size() == 1)
        {
            /// constant <op> argument
            Block block_with_constant
                = {{left.column->cloneResized(input_rows_count), left.type, left.name},
                   block.getByPosition(arguments[0]),
                   block.getByPosition(result)};
            Base::executeImpl(block_with_constant, {0, 1}, 2, input_rows_count);
            block.getByPosition(result) = block_with_constant.getByPosition(2);
        }
        else if (right.column && isColumnConst(*right.column) && arguments.size() == 1)
        {
            /// argument <op> constant
            Block block_with_constant
                = {block.getByPosition(arguments[0]),
                   {right.column->cloneResized(input_rows_count), right.type, right.name},
                   block.getByPosition(result)};
            Base::executeImpl(block_with_constant, {0, 1}, 2, input_rows_count);
            block.getByPosition(result) = block_with_constant.getByPosition(2);
        }
        else
            Base::executeImpl(block, arguments, result, input_rows_count);
    }

    /// Monotonicity is only reported for the operations handled in getMonotonicityForRange.
    bool hasInformationAboutMonotonicity() const override
    {
        std::string_view name_ = Name::name;
        return name_ == "minus" || name_ == "plus" || name_ == "multiply" || name_ == "divide" || name_ == "intDiv";
    }

    /** Monotonicity of the function with respect to its non-constant operand on [left_point, right_point].
      *
      * The result is {is_monotonic, is_positive, is_always_monotonic}. Whenever the answer cannot be
      * established from the information available here, "not monotonic" is returned.
      *
      *  x + c, c + x : non-decreasing
      *  x - c        : non-decreasing
      *  c - x        : non-increasing
      *  x * c, c * x : direction is the sign of c
      *  x / c        : direction is the sign of c; c == 0 is treated as not monotonic
      *  c / x        : opposite to the sign of c, and only on ranges that do not cross zero
      *
      * NOTE(review): overflow of the result type is not taken into account - confirm this is acceptable for callers.
      */
    Monotonicity getMonotonicityForRange(const IDataType &, const Field & left_point, const Field & right_point) const override
    {
        const Monotonicity not_monotonic{false, true, false};

        std::string_view name_ = Name::name;
        const bool op_is_plus = name_ == "plus";
        const bool op_is_minus = name_ == "minus";
        const bool op_is_multiply = name_ == "multiply";
        const bool op_is_division = name_ == "divide" || name_ == "intDiv";
        if (!op_is_plus && !op_is_minus && !op_is_multiply && !op_is_division)
            return not_monotonic;

        const bool left_is_constant = left.column && isColumnConst(*left.column);
        const bool right_is_constant = right.column && isColumnConst(*right.column);

        /// A function of two constants is itself a constant.
        if (left_is_constant && right_is_constant)
            return {true, true, true};
        if (!left_is_constant && !right_is_constant)
            return not_monotonic;

        const ColumnWithTypeAndName & constant = left_is_constant ? left : right;
        const ColumnWithTypeAndName & variable = left_is_constant ? right : left;

        if (op_is_plus)
            return {true, true, true};

        /// "x - c" grows together with x, "c - x" shrinks when x grows.
        if (op_is_minus)
            return {true, !left_is_constant, true};

        /// Everything below needs the sign of the constant, which is only read for native integer constants.
        const WhichDataType constant_type(constant.type);
        if (!constant_type.isInt() && !constant_type.isUInt())
            return not_monotonic;

        const Int64 constant_value = constant.column->getInt(0);
        const bool constant_is_zero = constant_value == 0;
        /// The signed view of the value is only meaningful for signed types; unsigned constants are never negative.
        const bool constant_is_non_negative = constant_type.isUInt() || constant_value >= 0;

        /// Multiplication is commutative and its direction does not depend on the sign of the variable.
        if (op_is_multiply)
            return {true, constant_is_non_negative, true};

        if (!left_is_constant)
        {
            /// variable / constant
            if (constant_is_zero)
                return not_monotonic;
            return {true, constant_is_non_negative, true};
        }

        /// constant / variable: the direction is opposite to the sign of the constant,
        /// and the function is discontinuous at zero.
        const WhichDataType variable_type(variable.type);
        if (variable_type.isUInt())
            return {true, !constant_is_non_negative, true};

        /// Range endpoints are read as Int64, so the variable must be a signed integer and both endpoints must be present.
        if (!variable_type.isInt() || left_point.isNull() || right_point.isNull())
            return not_monotonic;

        const Int64 range_begin = left_point.get<Int64>();
        const Int64 range_end = right_point.get<Int64>();

        /// A single point, a range of non-negative values (as in the unsigned case above)
        /// or a range of strictly negative values does not cross the discontinuity.
        if (range_begin == range_end || range_begin >= 0 || range_end < 0)
            return {true, !constant_is_non_negative, false};

        return not_monotonic;
    }

private:
    ColumnWithTypeAndName left;
    ColumnWithTypeAndName right;
};
/** Overload resolver for binary arithmetic functions.
  * It always takes exactly two arguments. When at least one of them is a constant column,
  * the constant-aware implementation is built so that the constant is remembered by the function;
  * otherwise the plain implementation is used.
  */
template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true>
class BinaryArithmeticOverloadResolver : public IFunctionOverloadResolverImpl
{
public:
    static constexpr auto name = Name::name;

    static FunctionOverloadResolverImplPtr create(const Context & context)
    {
        return std::make_unique<BinaryArithmeticOverloadResolver>(context);
    }

    explicit BinaryArithmeticOverloadResolver(const Context & context_) : context(context_) {}

    String getName() const override { return name; }
    size_t getNumberOfArguments() const override { return 2; }
    bool isVariadic() const override { return false; }

    /// Chooses the implementation depending on whether one of the two arguments is a constant column.
    FunctionBaseImplPtr build(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const override
    {
        const auto holds_constant = [](const ColumnWithTypeAndName & argument)
        {
            return argument.column && isColumnConst(*argument.column);
        };
        const bool has_constant_operand
            = arguments.size() == 2 && (holds_constant(arguments[0]) || holds_constant(arguments[1]));

        FunctionPtr function;
        if (has_constant_operand)
            function = FunctionBinaryArithmeticWithConstants<Op, Name, valid_on_default_arguments>::create(
                arguments[0], arguments[1], context);
        else
            function = FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments>::create(context);

        auto argument_types = ext::map<DataTypes>(arguments, [](const auto & elem) { return elem.type; });
        return std::make_unique<DefaultFunction>(std::move(function), std::move(argument_types), return_type);
    }

    /// Validates the number of arguments and delegates type deduction to the function implementation.
    DataTypePtr getReturnType(const DataTypes & arguments) const override
    {
        const size_t passed = arguments.size();
        if (passed == 2)
            return FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments>::getReturnTypeImplStatic(arguments, context);

        throw Exception(
            "Number of arguments for function " + getName() + " doesn't match: passed " + toString(passed) + ", should be 2",
            ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
    }

private:
    const Context & context;
};
}

View File

@ -37,7 +37,7 @@ struct BitAndImpl
};
struct NameBitAnd { static constexpr auto name = "bitAnd"; };
using FunctionBitAnd = FunctionBinaryArithmetic<BitAndImpl, NameBitAnd, true>;
using FunctionBitAnd = BinaryArithmeticOverloadResolver<BitAndImpl, NameBitAnd, true>;
}

View File

@ -42,7 +42,7 @@ struct BitBoolMaskAndImpl
};
struct NameBitBoolMaskAnd { static constexpr auto name = "__bitBoolMaskAnd"; };
using FunctionBitBoolMaskAnd = FunctionBinaryArithmetic<BitBoolMaskAndImpl, NameBitBoolMaskAnd>;
using FunctionBitBoolMaskAnd = BinaryArithmeticOverloadResolver<BitBoolMaskAndImpl, NameBitBoolMaskAnd>;
}

View File

@ -42,7 +42,7 @@ struct BitBoolMaskOrImpl
};
struct NameBitBoolMaskOr { static constexpr auto name = "__bitBoolMaskOr"; };
using FunctionBitBoolMaskOr = FunctionBinaryArithmetic<BitBoolMaskOrImpl, NameBitBoolMaskOr>;
using FunctionBitBoolMaskOr = BinaryArithmeticOverloadResolver<BitBoolMaskOrImpl, NameBitBoolMaskOr>;
}

View File

@ -36,7 +36,7 @@ struct BitOrImpl
};
struct NameBitOr { static constexpr auto name = "bitOr"; };
using FunctionBitOr = FunctionBinaryArithmetic<BitOrImpl, NameBitOr, true>;
using FunctionBitOr = BinaryArithmeticOverloadResolver<BitOrImpl, NameBitOr, true>;
}

View File

@ -43,7 +43,7 @@ struct BitRotateLeftImpl
};
struct NameBitRotateLeft { static constexpr auto name = "bitRotateLeft"; };
using FunctionBitRotateLeft = FunctionBinaryArithmetic<BitRotateLeftImpl, NameBitRotateLeft>;
using FunctionBitRotateLeft = BinaryArithmeticOverloadResolver<BitRotateLeftImpl, NameBitRotateLeft>;
}

View File

@ -42,7 +42,7 @@ struct BitRotateRightImpl
};
struct NameBitRotateRight { static constexpr auto name = "bitRotateRight"; };
using FunctionBitRotateRight = FunctionBinaryArithmetic<BitRotateRightImpl, NameBitRotateRight>;
using FunctionBitRotateRight = BinaryArithmeticOverloadResolver<BitRotateRightImpl, NameBitRotateRight>;
}

View File

@ -42,7 +42,7 @@ struct BitShiftLeftImpl
};
struct NameBitShiftLeft { static constexpr auto name = "bitShiftLeft"; };
using FunctionBitShiftLeft = FunctionBinaryArithmetic<BitShiftLeftImpl, NameBitShiftLeft>;
using FunctionBitShiftLeft = BinaryArithmeticOverloadResolver<BitShiftLeftImpl, NameBitShiftLeft>;
}

View File

@ -42,7 +42,7 @@ struct BitShiftRightImpl
};
struct NameBitShiftRight { static constexpr auto name = "bitShiftRight"; };
using FunctionBitShiftRight = FunctionBinaryArithmetic<BitShiftRightImpl, NameBitShiftRight>;
using FunctionBitShiftRight = BinaryArithmeticOverloadResolver<BitShiftRightImpl, NameBitShiftRight>;
}

View File

@ -34,7 +34,7 @@ struct BitTestImpl
};
struct NameBitTest { static constexpr auto name = "bitTest"; };
using FunctionBitTest = FunctionBinaryArithmetic<BitTestImpl, NameBitTest>;
using FunctionBitTest = BinaryArithmeticOverloadResolver<BitTestImpl, NameBitTest>;
}

View File

@ -36,7 +36,7 @@ struct BitXorImpl
};
struct NameBitXor { static constexpr auto name = "bitXor"; };
using FunctionBitXor = FunctionBinaryArithmetic<BitXorImpl, NameBitXor, true>;
using FunctionBitXor = BinaryArithmeticOverloadResolver<BitXorImpl, NameBitXor, true>;
}

View File

@ -37,7 +37,7 @@ struct DivideFloatingImpl
};
struct NameDivide { static constexpr auto name = "divide"; };
using FunctionDivide = FunctionBinaryArithmetic<DivideFloatingImpl, NameDivide>;
using FunctionDivide = BinaryArithmeticOverloadResolver<DivideFloatingImpl, NameDivide>;
void registerFunctionDivide(FunctionFactory & factory)
{

View File

@ -40,7 +40,7 @@ struct GCDImpl
};
struct NameGCD { static constexpr auto name = "gcd"; };
using FunctionGCD = FunctionBinaryArithmetic<GCDImpl, NameGCD, false>;
using FunctionGCD = BinaryArithmeticOverloadResolver<GCDImpl, NameGCD, false>;
}

View File

@ -110,7 +110,7 @@ template <> struct BinaryOperationImpl<Int32, Int64, DivideIntegralImpl<Int32, I
struct NameIntDiv { static constexpr auto name = "intDiv"; };
using FunctionIntDiv = FunctionBinaryArithmetic<DivideIntegralImpl, NameIntDiv, false>;
using FunctionIntDiv = BinaryArithmeticOverloadResolver<DivideIntegralImpl, NameIntDiv, false>;
void registerFunctionIntDiv(FunctionFactory & factory)
{

View File

@ -26,7 +26,7 @@ struct DivideIntegralOrZeroImpl
};
struct NameIntDivOrZero { static constexpr auto name = "intDivOrZero"; };
using FunctionIntDivOrZero = FunctionBinaryArithmetic<DivideIntegralOrZeroImpl, NameIntDivOrZero>;
using FunctionIntDivOrZero = BinaryArithmeticOverloadResolver<DivideIntegralOrZeroImpl, NameIntDivOrZero>;
void registerFunctionIntDivOrZero(FunctionFactory & factory)
{

View File

@ -78,7 +78,7 @@ struct LCMImpl
};
struct NameLCM { static constexpr auto name = "lcm"; };
using FunctionLCM = FunctionBinaryArithmetic<LCMImpl, NameLCM, false>;
using FunctionLCM = BinaryArithmeticOverloadResolver<LCMImpl, NameLCM, false>;
}

View File

@ -43,7 +43,7 @@ struct MinusImpl
};
struct NameMinus { static constexpr auto name = "minus"; };
using FunctionMinus = FunctionBinaryArithmetic<MinusImpl, NameMinus>;
using FunctionMinus = BinaryArithmeticOverloadResolver<MinusImpl, NameMinus>;
void registerFunctionMinus(FunctionFactory & factory)
{

View File

@ -101,7 +101,7 @@ template <> struct BinaryOperationImpl<Int32, Int64, ModuloImpl<Int32, Int64>> :
struct NameModulo { static constexpr auto name = "modulo"; };
using FunctionModulo = FunctionBinaryArithmetic<ModuloImpl, NameModulo, false>;
using FunctionModulo = BinaryArithmeticOverloadResolver<ModuloImpl, NameModulo, false>;
void registerFunctionModulo(FunctionFactory & factory)
{

View File

@ -36,7 +36,7 @@ struct ModuloOrZeroImpl
};
struct NameModuloOrZero { static constexpr auto name = "moduloOrZero"; };
using FunctionModuloOrZero = FunctionBinaryArithmetic<ModuloOrZeroImpl, NameModuloOrZero>;
using FunctionModuloOrZero = BinaryArithmeticOverloadResolver<ModuloOrZeroImpl, NameModuloOrZero>;
}

View File

@ -43,7 +43,7 @@ struct MultiplyImpl
};
struct NameMultiply { static constexpr auto name = "multiply"; };
using FunctionMultiply = FunctionBinaryArithmetic<MultiplyImpl, NameMultiply>;
using FunctionMultiply = BinaryArithmeticOverloadResolver<MultiplyImpl, NameMultiply>;
void registerFunctionMultiply(FunctionFactory & factory)
{

View File

@ -45,7 +45,7 @@ struct PlusImpl
};
struct NamePlus { static constexpr auto name = "plus"; };
using FunctionPlus = FunctionBinaryArithmetic<PlusImpl, NamePlus>;
using FunctionPlus = BinaryArithmeticOverloadResolver<PlusImpl, NamePlus>;
void registerFunctionPlus(FunctionFactory & factory)
{

View File

@ -1,6 +1,7 @@
#include <Storages/MergeTree/KeyCondition.h>
#include <Storages/MergeTree/BoolMask.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/FieldToDataType.h>
#include <Interpreters/TreeRewriter.h>
#include <Interpreters/ExpressionAnalyzer.h>
#include <Interpreters/ExpressionActions.h>
@ -711,8 +712,26 @@ bool KeyCondition::isKeyPossiblyWrappedByMonotonicFunctions(
for (auto it = chain_not_tested_for_monotonicity.rbegin(); it != chain_not_tested_for_monotonicity.rend(); ++it)
{
const auto & args = (*it)->arguments->children;
auto func_builder = FunctionFactory::instance().tryGet((*it)->name, context);
ColumnsWithTypeAndName arguments{{ nullptr, key_column_type, "" }};
ColumnsWithTypeAndName arguments;
if (args.size() == 2)
{
if (const auto * arg_left = args[0]->as<ASTLiteral>())
{
auto left_arg_type = applyVisitor(FieldToDataType(), arg_left->value);
arguments.push_back({ left_arg_type->createColumnConst(0, arg_left->value), left_arg_type, "" });
arguments.push_back({ nullptr, key_column_type, "" });
}
else if (const auto * arg_right = args[1]->as<ASTLiteral>())
{
arguments.push_back({ nullptr, key_column_type, "" });
auto right_arg_type = applyVisitor(FieldToDataType(), arg_right->value);
arguments.push_back({ right_arg_type->createColumnConst(0, arg_right->value), right_arg_type, "" });
}
}
else
arguments.push_back({ nullptr, key_column_type, "" });
auto func = func_builder->build(arguments);
if (!func || !func->hasInformationAboutMonotonicity())
@ -750,12 +769,27 @@ bool KeyCondition::isKeyPossiblyWrappedByMonotonicFunctionsImpl(
if (const auto * func = node->as<ASTFunction>())
{
const auto & args = func->arguments->children;
if (args.size() != 1)
if (args.size() > 2)
return false;
out_functions_chain.push_back(func);
return isKeyPossiblyWrappedByMonotonicFunctionsImpl(args[0], out_key_column_num, out_key_column_type, out_functions_chain);
bool ret = false;
if (args.size() == 2)
{
if (args[0]->as<ASTLiteral>())
{
ret = isKeyPossiblyWrappedByMonotonicFunctionsImpl(args[1], out_key_column_num, out_key_column_type, out_functions_chain);
}
else if (args[1]->as<ASTLiteral>())
{
ret = isKeyPossiblyWrappedByMonotonicFunctionsImpl(args[0], out_key_column_num, out_key_column_type, out_functions_chain);
}
}
else
{
ret = isKeyPossiblyWrappedByMonotonicFunctionsImpl(args[0], out_key_column_num, out_key_column_type, out_functions_chain);
}
return ret;
}
return false;

View File

@ -0,0 +1,10 @@
-- Regression test: the partition key wraps column `i` in a binary arithmetic function
-- with a constant operand (`i / 1000`) before `toDate` is applied.
DROP TABLE IF EXISTS binary_op_mono;
CREATE TABLE binary_op_mono(i int, j int) ENGINE MergeTree PARTITION BY toDate(i / 1000) ORDER BY j;
-- Two rows with the same `i`, so both have the same partition key value.
-- NOTE(review): a millisecond timestamp presumably does not fit into a 32-bit `int` column - confirm the stored value is what the test intends.
INSERT INTO binary_op_mono VALUES (toUnixTimestamp('2020-09-01 00:00:00') * 1000, 1), (toUnixTimestamp('2020-09-01 00:00:00') * 1000, 2);
-- Limit reads to a single row: presumably the SELECT below fails if both inserted rows have to be read,
-- so it can only pass when the condition on `toDate(i / 1000)` is used to skip the data - verify against the reference output.
SET max_rows_to_read = 1;
SELECT * FROM binary_op_mono WHERE toDate(i / 1000) = '2020-09-02';
DROP TABLE IF EXISTS binary_op_mono;