Add some optimizations and fixes

This commit is contained in:
Pavel Kruglov 2021-05-14 17:07:24 +03:00
parent fd56210652
commit cdbe4951f4
10 changed files with 52 additions and 51 deletions

View File

@ -69,6 +69,7 @@ INSTANTIATE(Decimal128)
INSTANTIATE(Decimal256)
INSTANTIATE(DateTime64)
INSTANTIATE(char *)
INSTANTIATE(UUID)
#undef INSTANTIATE
@ -251,14 +252,18 @@ void executeColumnIfNeeded(ColumnWithTypeAndName & column)
column = typeid_cast<const ColumnFunction *>(column_function)->reduce();
}
/// Diff view without +/- markers: the previous bool-returning checkArgumentsForColumnFunction
/// and its replacement, the int-returning checkShirtCircuitArguments, are interleaved below.
///
/// New version: scans `arguments` and returns the index of the last argument whose column is a
/// ColumnFunction for which isShortCircuitArgument() is true, or -1 if no argument qualifies.
/// Previous version: returned true as soon as any argument column was a ColumnFunction.
bool checkArgumentsForColumnFunction(const ColumnsWithTypeAndName & arguments)
int checkShirtCircuitArguments(const ColumnsWithTypeAndName & arguments)
{
for (const auto & arg : arguments)
/// -1 is the "nothing found" result; callers in this commit test it with `< 0` / `>= 0`.
int last_short_circuit_argument_index = -1;
for (size_t i = 0; i != arguments.size(); ++i)
{
if (checkAndGetColumn<ColumnFunction>(*arg.column))
return true;
const auto * column_func = checkAndGetColumn<ColumnFunction>(*arguments[i].column);
/// No early exit: a later qualifying argument overwrites the index, so the last one wins.
if (column_func && column_func->isShortCircuitArgument())
/// NOTE(review): `i` is size_t and is implicitly narrowed to int here.
last_short_circuit_argument_index = i;
}
return false;
return last_short_circuit_argument_index;
}
}

View File

@ -23,6 +23,6 @@ void maskedExecute(ColumnWithTypeAndName & column, const PaddedPODArray<UInt8> &
void executeColumnIfNeeded(ColumnWithTypeAndName & column);
/// Diff view: the first prototype is the removed one, the second is its replacement.
/// The replacement returns the index of the last argument that is a short-circuit
/// ColumnFunction, or a negative value when there is none.
bool checkArgumentsForColumnFunction(const ColumnsWithTypeAndName & arguments);
int checkShirtCircuitArguments(const ColumnsWithTypeAndName & arguments);
}

View File

@ -100,6 +100,7 @@ public:
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 0; }
bool isVariadic() const override { return true; }
bool isSuitableForShortCircuitArgumentsExecution() const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override

View File

@ -514,22 +514,20 @@ void FunctionAnyArityLogical<Impl, Name>::executeShortCircuitArguments(ColumnsWi
if (Name::name != NameAnd::name && Name::name != NameOr::name)
throw Exception("Function " + getName() + " doesn't support short circuit execution", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!checkArgumentsForColumnFunction(arguments))
int last_short_circuit_argument_index = checkShirtCircuitArguments(arguments);
if (last_short_circuit_argument_index < 0)
return;
bool reverse = Name::name != NameAnd::name;
UInt8 null_value = Name::name == NameAnd::name ? 1 : 0;
UInt8 value_for_mask_expanding = Name::name == NameAnd::name ? 0 : 1;
executeColumnIfNeeded(arguments[0]);
IColumn::Filter mask;
getMaskFromColumn(arguments[0].column, mask, reverse, nullptr, null_value);
for (size_t i = 1; i < arguments.size(); ++i)
for (int i = 1; i <= last_short_circuit_argument_index; ++i)
{
if (isColumnFunction(*arguments[i].column))
maskedExecute(arguments[i], mask, false, &value_for_mask_expanding);
getMaskFromColumn(arguments[i].column, mask, reverse, nullptr, null_value);
getMaskFromColumn(arguments[i - 1].column, mask, reverse, nullptr, null_value);
maskedExecute(arguments[i], mask, false, &value_for_mask_expanding);
}
}

View File

@ -923,11 +923,12 @@ public:
void executeShortCircuitArguments(ColumnsWithTypeAndName & arguments) const override
{
if (!checkArgumentsForColumnFunction(arguments))
int last_short_circuit_argument_index = checkShirtCircuitArguments(arguments);
if (last_short_circuit_argument_index < 0)
return;
executeColumnIfNeeded(arguments[0]);
if (isColumnFunction(*arguments[1].column) || isColumnFunction(*arguments[2].column))
if (last_short_circuit_argument_index > 0)
{
IColumn::Filter mask;
getMaskFromColumn(arguments[0].column, mask);

View File

@ -111,7 +111,8 @@ public:
void executeShortCircuitArguments(ColumnsWithTypeAndName & arguments) const override
{
if (!checkArgumentsForColumnFunction(arguments))
int last_short_circuit_argument_index = checkShirtCircuitArguments(arguments);
if (last_short_circuit_argument_index < 0)
return;
executeColumnIfNeeded(arguments[0]);
@ -119,24 +120,19 @@ public:
IColumn::Filter mask_disjunctions = IColumn::Filter(arguments[0].column->size(), 0);
UInt8 default_value_for_mask_expanding = 0;
size_t i = 1;
while (i < arguments.size())
int i = 1;
while (i <= last_short_circuit_argument_index)
{
getMaskFromColumn(arguments[i - 1].column, current_mask);
disjunctionMasks(mask_disjunctions, current_mask);
if (isColumnFunction(*arguments[i].column))
maskedExecute(arguments[i], current_mask);
maskedExecute(arguments[i], current_mask);
++i;
if (i > last_short_circuit_argument_index)
break;
if (isColumnFunction(*arguments[i].column))
{
if (i < arguments.size() - 1)
maskedExecute(arguments[i], mask_disjunctions, true, &default_value_for_mask_expanding);
else
maskedExecute(arguments[i], mask_disjunctions, true);
}
disjunctionMasks(mask_disjunctions, current_mask);
UInt8 * default_value_ptr = i + 1 == int(arguments.size()) ? nullptr: &default_value_for_mask_expanding;
maskedExecute(arguments[i], mask_disjunctions, true, default_value_ptr);
++i;
}
}

View File

@ -67,6 +67,13 @@ public:
using NodeRawPtrs = std::vector<Node *>;
using NodeRawConstPtrs = std::vector<const Node *>;
/// Per-node lazy execution state, stored in Node::lazy_execution.
enum class LazyExecution
{
/// Initial value of Node::lazy_execution; executeAction does not wrap such a node into a ColumnFunction.
DISABLED,
/// Set by rewriteShortCircuitArguments when force_rewrite is false. executeAction wraps the node
/// into a ColumnFunction only if its function reports isSuitableForShortCircuitArgumentsExecution()
/// or checkShirtCircuitArguments(arguments) >= 0.
ENABLED,
/// Set by rewriteShortCircuitArguments when force_rewrite is true. executeAction wraps the node
/// into a ColumnFunction without any further checks.
FORCE_ENABLED,
};
struct Node
{
NodeRawConstPtrs children;
@ -90,9 +97,9 @@ public:
/// For COLUMN node and propagated constants.
ColumnPtr column;
void toTree(JSONBuilder::JSONMap & map) const;
LazyExecution lazy_execution = LazyExecution::DISABLED;
bool is_lazy_executed = false;
void toTree(JSONBuilder::JSONMap & map) const;
};
/// NOTE: std::list is an implementation detail.

View File

@ -6,6 +6,7 @@
#include <Interpreters/Context.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnFunction.h>
#include <Columns/MaskOperations.h>
#include <Common/typeid_cast.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesNumber.h>
@ -69,38 +70,26 @@ ExpressionActionsPtr ExpressionActions::clone() const
return std::make_shared<ExpressionActions>(*this);
}
/// Diff view without +/- markers: the previous bool-returning version and the new void version
/// of this function are interleaved below.
///
/// New version: walks `children` recursively and marks FUNCTION nodes for lazy execution.
/// A child is skipped when it is absent from `need_outside`, is mapped to true there, or already
/// has a lazy_execution state other than DISABLED. FUNCTION nodes are marked FORCE_ENABLED when
/// `force_rewrite` is true and ENABLED otherwise; ALIAS nodes are not marked, only their children
/// are visited; every other node type is left untouched.
bool ExpressionActions::rewriteShortCircuitArguments(const ActionsDAG::NodeRawConstPtrs & children, const std::unordered_map<const ActionsDAG::Node *, bool> & need_outside, bool force_rewrite)
void ExpressionActions::rewriteShortCircuitArguments(const ActionsDAG::NodeRawConstPtrs & children, const std::unordered_map<const ActionsDAG::Node *, bool> & need_outside, bool force_rewrite)
{
bool have_rewritten_child = false;
for (const auto * child : children)
{
if (!need_outside.contains(child) || need_outside.at(child))
/// New condition: additionally skips children that were already marked, so each node is processed once.
if (!need_outside.contains(child) || need_outside.at(child) || child->lazy_execution != ActionsDAG::LazyExecution::DISABLED)
continue;
if (child->is_lazy_executed)
{
have_rewritten_child = true;
continue;
}
switch (child->type)
{
case ActionsDAG::ActionType::FUNCTION:
if (rewriteShortCircuitArguments(child->children, need_outside, force_rewrite) || child->function_base->isSuitableForShortCircuitArgumentsExecution() || force_rewrite)
{
const_cast<ActionsDAG::Node *>(child)->is_lazy_executed = true;
have_rewritten_child = true;
}
/// New version: descendants are processed first, then the node itself is marked unconditionally.
/// The suitability check from the old condition now appears in executeAction instead.
rewriteShortCircuitArguments(child->children, need_outside, force_rewrite);
/// The node is reached through a const pointer, hence the const_cast to update its state.
const_cast<ActionsDAG::Node *>(child)->lazy_execution = force_rewrite ? ActionsDAG::LazyExecution::FORCE_ENABLED : ActionsDAG::LazyExecution::ENABLED;
break;
case ActionsDAG::ActionType::ALIAS:
have_rewritten_child |= rewriteShortCircuitArguments(child->children, need_outside, force_rewrite);
rewriteShortCircuitArguments(child->children, need_outside, force_rewrite);
break;
default:
break;
}
}
return have_rewritten_child;
}
@ -426,7 +415,9 @@ static void executeAction(const ExpressionActions::Action & action, ExecutionCon
arguments[i] = columns[action.arguments[i].pos];
}
if (action.node->is_lazy_executed)
if (action.node->lazy_execution == ActionsDAG::LazyExecution::FORCE_ENABLED
|| (action.node->lazy_execution == ActionsDAG::LazyExecution::ENABLED
&& (action.node->function_base->isSuitableForShortCircuitArgumentsExecution() || checkShirtCircuitArguments(arguments) >= 0)))
res_column.column = ColumnFunction::create(num_rows, action.node->function_base, std::move(arguments), true);
else
{

View File

@ -133,7 +133,7 @@ private:
void checkLimits(const ColumnsWithTypeAndName & columns) const;
void linearizeActions();
bool rewriteShortCircuitArguments(
void rewriteShortCircuitArguments(
const ActionsDAG::NodeRawConstPtrs & children, const std::unordered_map<const ActionsDAG::Node *, bool> & need_outside, bool force_rewrite);
void rewriteArgumentsForShortCircuitFunctions(

View File

@ -166,6 +166,8 @@ public:
return dag.compile(builder, values);
}
bool isSuitableForShortCircuitArgumentsExecution() const override { return true; }
String getName() const override { return name; }
const DataTypes & getArgumentTypes() const override { return argument_types; }