more refactoring of FunctionToSubcolumnsPass

This commit is contained in:
Anton Popov 2024-02-07 18:19:15 +00:00
parent 46f6867896
commit 361b5a2077

View File

@ -26,31 +26,40 @@ namespace DB
namespace namespace
{ {
void optimizeFunctionLength(QueryTreeNodePtr & node, FunctionNode &, ColumnNode & column_node, ContextPtr) struct ColumnContext
{
NameAndTypePair column;
QueryTreeNodePtr column_source;
ContextPtr context;
};
/// Callable type stored in `node_transformers`. It receives the node being visited,
/// the function node to rewrite and the ColumnContext of its first argument.
using NodeToSubcolumnTransformer = std::function<void(QueryTreeNodePtr &, FunctionNode &, ColumnContext &)>;
/// Rewrites `length(argument)` into a direct read of the `argument.size0` subcolumn.
/// `argument` may be Array or Map. The visited node itself is replaced by the new column node.
void optimizeFunctionLength(QueryTreeNodePtr & node, FunctionNode &, ColumnContext & ctx)
{
    const String size_subcolumn_name = ctx.column.name + ".size0";
    NameAndTypePair size_subcolumn{size_subcolumn_name, std::make_shared<DataTypeUInt64>()};

    node = std::make_shared<ColumnNode>(size_subcolumn, ctx.column_source);
}
template <bool positive> template <bool positive>
void optimizeFunctionEmpty(QueryTreeNodePtr &, FunctionNode & function_node, ColumnNode & column_node, ContextPtr context) void optimizeFunctionEmpty(QueryTreeNodePtr &, FunctionNode & function_node, ColumnContext & ctx)
{ {
/// Replace `empty(argument)` with `equals(argument.size0, 0)` if positive /// Replace `empty(argument)` with `equals(argument.size0, 0)` if positive
/// Replace `notEmpty(argument)` with `notEquals(argument.size0, 0)` if not positive /// Replace `notEmpty(argument)` with `notEquals(argument.size0, 0)` if not positive
/// `argument` may be Array or Map. /// `argument` may be Array or Map.
NameAndTypePair column{column_node.getColumnName() + ".size0", std::make_shared<DataTypeUInt64>()}; NameAndTypePair column{ctx.column.name + ".size0", std::make_shared<DataTypeUInt64>()};
auto & function_arguments_nodes = function_node.getArguments().getNodes(); auto & function_arguments_nodes = function_node.getArguments().getNodes();
function_arguments_nodes.clear(); function_arguments_nodes.clear();
function_arguments_nodes.push_back(std::make_shared<ColumnNode>(column, column_node.getColumnSource())); function_arguments_nodes.push_back(std::make_shared<ColumnNode>(column, ctx.column_source));
function_arguments_nodes.push_back(std::make_shared<ConstantNode>(static_cast<UInt64>(0))); function_arguments_nodes.push_back(std::make_shared<ConstantNode>(static_cast<UInt64>(0)));
auto function_name = positive ? "equals" : "notEquals"; auto function_name = positive ? "equals" : "notEquals";
resolveOrdinaryFunctionNodeByName(function_node, function_name, std::move(context)); resolveOrdinaryFunctionNodeByName(function_node, function_name, std::move(ctx.context));
} }
String getSubcolumnNameForElement(const Field & value, const DataTypeTuple & data_type_tuple) String getSubcolumnNameForElement(const Field & value, const DataTypeTuple & data_type_tuple)
@ -73,7 +82,7 @@ String getSubcolumnNameForElement(const Field & value, const DataTypeVariant &)
} }
template <typename DataType> template <typename DataType>
void optimizeTupleOrVariantElement(QueryTreeNodePtr & node, FunctionNode & function_node, ColumnNode & column_node, ContextPtr) void optimizeTupleOrVariantElement(QueryTreeNodePtr & node, FunctionNode & function_node, ColumnContext & ctx)
{ {
/// Replace `tupleElement(tuple_argument, string_literal)`, `tupleElement(tuple_argument, integer_literal)` with `tuple_argument.column_name`. /// Replace `tupleElement(tuple_argument, string_literal)`, `tupleElement(tuple_argument, integer_literal)` with `tuple_argument.column_name`.
/// Replace `variantElement(variant_argument, string_literal)` with `variant_argument.column_name`. /// Replace `variantElement(variant_argument, string_literal)` with `variant_argument.column_name`.
@ -86,19 +95,16 @@ void optimizeTupleOrVariantElement(QueryTreeNodePtr & node, FunctionNode & funct
if (!second_argument_constant_node) if (!second_argument_constant_node)
return; return;
auto column_type = column_node.getColumnType(); const auto & data_type_concrete = assert_cast<const DataType &>(*ctx.column.type);
const auto & data_type_concrete = assert_cast<const DataType &>(*column_type);
auto subcolumn_name = getSubcolumnNameForElement(second_argument_constant_node->getValue(), data_type_concrete); auto subcolumn_name = getSubcolumnNameForElement(second_argument_constant_node->getValue(), data_type_concrete);
if (subcolumn_name.empty()) if (subcolumn_name.empty())
return; return;
NameAndTypePair column{column_node.getColumnName() + "." + subcolumn_name, function_node.getResultType()}; NameAndTypePair column{ctx.column.name + "." + subcolumn_name, function_node.getResultType()};
node = std::make_shared<ColumnNode>(column, column_node.getColumnSource()); node = std::make_shared<ColumnNode>(column, ctx.column_source);
} }
using NodeToSubcolumnTransformer = std::function<void(QueryTreeNodePtr &, FunctionNode &, ColumnNode &, ContextPtr)>;
std::map<std::pair<TypeIndex, String>, NodeToSubcolumnTransformer> node_transformers = std::map<std::pair<TypeIndex, String>, NodeToSubcolumnTransformer> node_transformers =
{ {
{ {
@ -121,52 +127,51 @@ std::map<std::pair<TypeIndex, String>, NodeToSubcolumnTransformer> node_transfor
}, },
{
    {TypeIndex::Map, "mapKeys"},
    [](QueryTreeNodePtr & node, FunctionNode & function_node, ColumnContext & ctx)
    {
        /// `mapKeys(map_argument)` is replaced by the `map_argument.keys` subcolumn,
        /// which is given the result type of the original function.
        const String keys_subcolumn_name = ctx.column.name + ".keys";
        NameAndTypePair keys_subcolumn{keys_subcolumn_name, function_node.getResultType()};

        node = std::make_shared<ColumnNode>(keys_subcolumn, ctx.column_source);
    },
},
{
    {TypeIndex::Map, "mapValues"},
    [](QueryTreeNodePtr & node, FunctionNode & function_node, ColumnContext & ctx)
    {
        /// `mapValues(map_argument)` is replaced by the `map_argument.values` subcolumn,
        /// which is given the result type of the original function.
        const String values_subcolumn_name = ctx.column.name + ".values";
        NameAndTypePair values_subcolumn{values_subcolumn_name, function_node.getResultType()};

        node = std::make_shared<ColumnNode>(values_subcolumn, ctx.column_source);
    },
},
{
    {TypeIndex::Map, "mapContains"},
    [](QueryTreeNodePtr &, FunctionNode & function_node, ColumnContext & ctx)
    {
        /// `mapContains(map_argument, argument)` becomes `has(map_argument.keys, argument)`:
        /// only the first argument is swapped, the remaining arguments stay where they are.
        const auto & map_type = assert_cast<const DataTypeMap &>(*ctx.column.type);
        NameAndTypePair keys_subcolumn{ctx.column.name + ".keys", std::make_shared<DataTypeArray>(map_type.getKeyType())};

        auto keys_subcolumn_node = std::make_shared<ColumnNode>(keys_subcolumn, ctx.column_source);
        function_node.getArguments().getNodes()[0] = std::move(keys_subcolumn_node);

        resolveOrdinaryFunctionNodeByName(function_node, "has", ctx.context);
    },
},
{ {
{TypeIndex::Nullable, "count"}, {TypeIndex::Nullable, "count"},
[](QueryTreeNodePtr &, FunctionNode & function_node, ColumnNode & column_node, ContextPtr context) [](QueryTreeNodePtr &, FunctionNode & function_node, ColumnContext & ctx)
{ {
/// Replace `count(nullable_argument)` with `sum(not(nullable_argument.null))` /// Replace `count(nullable_argument)` with `sum(not(nullable_argument.null))`
NameAndTypePair column{column_node.getColumnName() + ".null", std::make_shared<DataTypeUInt8>()}; NameAndTypePair column{ctx.column.name + ".null", std::make_shared<DataTypeUInt8>()};
auto & function_arguments_nodes = function_node.getArguments().getNodes(); auto & function_arguments_nodes = function_node.getArguments().getNodes();
auto new_column_node = std::make_shared<ColumnNode>(column, column_node.getColumnSource()); auto new_column_node = std::make_shared<ColumnNode>(column, ctx.column_source);
auto function_node_not = std::make_shared<FunctionNode>("not"); auto function_node_not = std::make_shared<FunctionNode>("not");
function_node_not->getArguments().getNodes().push_back(std::move(new_column_node)); function_node_not->getArguments().getNodes().push_back(std::move(new_column_node));
resolveOrdinaryFunctionNodeByName(*function_node_not, "not", context); resolveOrdinaryFunctionNodeByName(*function_node_not, "not", ctx.context);
function_arguments_nodes = {std::move(function_node_not)}; function_arguments_nodes = {std::move(function_node_not)};
resolveAggregateFunctionNodeByName(function_node, "sum"); resolveAggregateFunctionNodeByName(function_node, "sum");
@ -174,23 +179,23 @@ std::map<std::pair<TypeIndex, String>, NodeToSubcolumnTransformer> node_transfor
}, },
{
    {TypeIndex::Nullable, "isNull"},
    [](QueryTreeNodePtr & node, FunctionNode &, ColumnContext & ctx)
    {
        /// `isNull(nullable_argument)` is replaced by the UInt8 `nullable_argument.null` subcolumn.
        const String null_subcolumn_name = ctx.column.name + ".null";
        NameAndTypePair null_subcolumn{null_subcolumn_name, std::make_shared<DataTypeUInt8>()};

        node = std::make_shared<ColumnNode>(null_subcolumn, ctx.column_source);
    },
},
{
    {TypeIndex::Nullable, "isNotNull"},
    [](QueryTreeNodePtr &, FunctionNode & function_node, ColumnContext & ctx)
    {
        /// `isNotNull(nullable_argument)` becomes `not(nullable_argument.null)`:
        /// the function keeps its node, gets the `.null` subcolumn as its only argument
        /// and is re-resolved as `not`.
        NameAndTypePair null_subcolumn{ctx.column.name + ".null", std::make_shared<DataTypeUInt8>()};
        auto null_subcolumn_node = std::make_shared<ColumnNode>(null_subcolumn, ctx.column_source);

        auto & arguments = function_node.getArguments().getNodes();
        arguments = {std::move(null_subcolumn_node)};

        resolveOrdinaryFunctionNodeByName(function_node, "not", ctx.context);
    },
},
{ {
@ -380,7 +385,10 @@ public:
auto transformer_it = node_transformers.find({column.type->getTypeId(), function_node->getFunctionName()}); auto transformer_it = node_transformers.find({column.type->getTypeId(), function_node->getFunctionName()});
if (transformer_it != node_transformers.end()) if (transformer_it != node_transformers.end())
transformer_it->second(node, *function_node, *first_argument_column_node, getContext()); {
ColumnContext ctx{std::move(column), first_argument_column_node->getColumnSource(), getContext()};
transformer_it->second(node, *function_node, ctx);
}
} }
}; };