2023-12-11 12:02:25 +00:00
|
|
|
#include <Interpreters/RewriteSumFunctionWithSumAndCountVisitor.h>
|
|
|
|
#include <Interpreters/IdentifierSemantic.h>
|
|
|
|
#include <Parsers/ASTLiteral.h>
|
|
|
|
#include <Parsers/ASTFunction.h>
|
|
|
|
#include <Parsers/ASTIdentifier.h>
|
|
|
|
|
|
|
|
namespace DB
|
|
|
|
{
|
|
|
|
|
|
|
|
/// Matcher entry point: forwards function nodes to the overload that performs
/// the actual rewrite; nodes of any other kind are left untouched.
void RewriteSumFunctionWithSumAndCountMatcher::visit(ASTPtr & ast, const Data & data)
{
    auto * function_node = ast->as<ASTFunction>();
    if (function_node == nullptr)
        return;

    visit(*function_node, ast, data);
}
|
|
|
|
|
2024-01-19 17:43:06 +00:00
|
|
|
/** Rewrites `sum(column +/- literal)` into two individual functions
|
2023-12-11 12:02:25 +00:00
|
|
|
* `sum(column)` and `literal * count(column)`.
|
2024-01-18 10:31:02 +00:00
|
|
|
* sum(column + literal) -> sum(column) + literal * count(column)
|
|
|
|
* sum(literal + column) -> literal * count(column) + sum(column)
|
|
|
|
* sum(column - literal) -> sum(column) - literal * count(column)
|
|
|
|
* sum(literal - column) -> literal * count(column) - sum(column)
|
2023-12-11 12:02:25 +00:00
|
|
|
*/
|
|
|
|
void RewriteSumFunctionWithSumAndCountMatcher::visit(const ASTFunction & function, ASTPtr & ast, const Data & data)
|
|
|
|
{
|
2024-01-16 14:40:12 +00:00
|
|
|
static const std::unordered_set<String> function_supported = {
|
2023-12-11 12:02:25 +00:00
|
|
|
"plus",
|
|
|
|
"minus"
|
|
|
|
};
|
|
|
|
|
|
|
|
if (!function.arguments || Poco::toLower(function.name) != "sum" || function.arguments->children.size() != 1)
|
|
|
|
return;
|
|
|
|
|
2024-01-16 14:40:12 +00:00
|
|
|
const auto * func_plus_minus = function.arguments->children[0]->as<ASTFunction>();
|
2023-12-11 12:02:25 +00:00
|
|
|
|
2024-01-16 14:40:12 +00:00
|
|
|
if (!func_plus_minus || !function_supported.contains(Poco::toLower(func_plus_minus->name)) || func_plus_minus->arguments->children.size() != 2)
|
2023-12-11 12:02:25 +00:00
|
|
|
return;
|
|
|
|
|
2024-01-19 17:43:06 +00:00
|
|
|
size_t column_id;
|
|
|
|
if (func_plus_minus->arguments->children[0]->as<ASTIdentifier>() && func_plus_minus->arguments->children[1]->as<ASTLiteral>())
|
|
|
|
column_id = 0;
|
|
|
|
else if (func_plus_minus->arguments->children[0]->as<ASTLiteral>() && func_plus_minus->arguments->children[1]->as<ASTIdentifier>())
|
|
|
|
column_id = 1;
|
|
|
|
else
|
2023-12-11 12:02:25 +00:00
|
|
|
return;
|
|
|
|
|
|
|
|
size_t literal_id = 1 - column_id;
|
2024-01-16 14:40:12 +00:00
|
|
|
const auto * literal = func_plus_minus->arguments->children[literal_id]->as<ASTLiteral>();
|
2023-12-21 15:29:29 +00:00
|
|
|
if (!literal)
|
|
|
|
return;
|
|
|
|
|
2024-01-19 17:43:06 +00:00
|
|
|
///all the types listed are numbers and supported by 'plus' and 'minus'.
|
2023-12-21 15:29:29 +00:00
|
|
|
Field::Types::Which literal_type = literal->value.getType();
|
|
|
|
if (literal_type != Field::Types::UInt64 &&
|
|
|
|
literal_type != Field::Types::Int64 &&
|
|
|
|
literal_type != Field::Types::UInt128 &&
|
|
|
|
literal_type != Field::Types::Int128 &&
|
|
|
|
literal_type != Field::Types::UInt256 &&
|
|
|
|
literal_type != Field::Types::Int256 &&
|
|
|
|
literal_type != Field::Types::Float64 &&
|
|
|
|
literal_type != Field::Types::Decimal32 &&
|
|
|
|
literal_type != Field::Types::Decimal64 &&
|
|
|
|
literal_type != Field::Types::Decimal128 &&
|
|
|
|
literal_type != Field::Types::Decimal256)
|
|
|
|
return;
|
2023-12-11 12:02:25 +00:00
|
|
|
|
2024-01-16 14:40:12 +00:00
|
|
|
const auto * column = func_plus_minus->arguments->children[column_id]->as<ASTIdentifier>();
|
2023-12-21 15:29:29 +00:00
|
|
|
if (!column)
|
2023-12-11 12:02:25 +00:00
|
|
|
return;
|
|
|
|
|
|
|
|
auto pos = IdentifierSemantic::getMembership(*column);
|
|
|
|
if (!pos)
|
|
|
|
pos = IdentifierSemantic::chooseTableColumnMatch(*column, data.tables, true);
|
|
|
|
if (!pos)
|
|
|
|
return;
|
|
|
|
|
|
|
|
if (*pos >= data.tables.size())
|
|
|
|
return;
|
|
|
|
|
|
|
|
auto column_type_name = data.tables[*pos].columns.tryGetByName(column->shortName());
|
|
|
|
if (!column_type_name)
|
|
|
|
return;
|
|
|
|
|
|
|
|
const auto column_type = column_type_name->type;
|
|
|
|
if (!column_type || !isNumber(*column_type))
|
|
|
|
return;
|
|
|
|
|
|
|
|
const String & column_name = column_type_name->name;
|
|
|
|
|
2023-12-28 17:14:36 +00:00
|
|
|
if (column_id == 0)
|
|
|
|
{
|
2024-01-16 14:40:12 +00:00
|
|
|
const auto new_ast = makeASTFunction(func_plus_minus->name,
|
2023-12-28 17:14:36 +00:00
|
|
|
makeASTFunction("sum",
|
|
|
|
std::make_shared<ASTIdentifier>(column_name)
|
|
|
|
),
|
|
|
|
makeASTFunction("multiply",
|
|
|
|
std::make_shared<ASTLiteral>(* literal),
|
|
|
|
makeASTFunction("count", std::make_shared<ASTIdentifier>(column_name))
|
|
|
|
)
|
|
|
|
);
|
|
|
|
if (!new_ast)
|
|
|
|
return;
|
|
|
|
else
|
|
|
|
ast = new_ast;
|
|
|
|
}
|
|
|
|
else if (column_id == 1)
|
|
|
|
{
|
2024-01-16 14:40:12 +00:00
|
|
|
const auto new_ast = makeASTFunction(func_plus_minus->name,
|
2023-12-28 17:14:36 +00:00
|
|
|
makeASTFunction("multiply",
|
|
|
|
std::make_shared<ASTLiteral>(* literal),
|
|
|
|
makeASTFunction("count", std::make_shared<ASTIdentifier>(column_name))
|
|
|
|
),
|
|
|
|
makeASTFunction("sum",
|
|
|
|
std::make_shared<ASTIdentifier>(column_name)
|
|
|
|
)
|
|
|
|
);
|
|
|
|
if (!new_ast)
|
|
|
|
return;
|
|
|
|
else
|
|
|
|
ast = new_ast;
|
|
|
|
}
|
2023-12-11 12:02:25 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
}
|