Some fixes around any/all

This commit is contained in:
kssenii 2021-08-14 20:04:21 +03:00
parent 18ab53488f
commit f125fb3fef
8 changed files with 101 additions and 82 deletions

View File

@ -50,6 +50,8 @@ private:
T value;
public:
static constexpr bool is_nullable = false;
bool has() const
{
return has_value;
@ -470,6 +472,8 @@ private:
char small_data[MAX_SMALL_STRING_SIZE]; /// Including the terminating zero.
public:
static constexpr bool is_nullable = false;
bool has() const
{
return size >= 0;
@ -693,6 +697,8 @@ private:
Field value;
public:
static constexpr bool is_nullable = false;
bool has() const
{
return !value.isNull();
@ -979,6 +985,8 @@ struct AggregateFunctionAnyLastData : Data
template <typename Data>
struct AggregateFunctionSingleValueOrNullData : Data
{
static constexpr bool is_nullable = true;
using Self = AggregateFunctionSingleValueOrNullData;
bool first_value = true;
@ -1136,7 +1144,9 @@ public:
DataTypePtr getReturnType() const override
{
auto result_type = this->argument_types.at(0);
return Data::name() == "singleValueOrNull" ? makeNullable(result_type) : result_type;
if constexpr (Data::is_nullable)
return makeNullable(result_type);
return result_type;
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override

View File

@ -5,6 +5,10 @@
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
/*
* Note: there is a difference between intersect and except behaviour.

View File

@ -8,10 +8,6 @@
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
ASTPtr ASTSelectWithUnionQuery::clone() const
{

View File

@ -198,92 +198,84 @@ enum class SubqueryFunctionType
ALL
};
static bool modifyAST(const String & operator_name, ASTPtr function, SubqueryFunctionType type)
static bool modifyAST(ASTPtr ast, SubqueryFunctionType type)
{
// = ANY --> IN, != ALL --> NOT IN
if ((type == SubqueryFunctionType::ANY && operator_name == "equals")
|| (type == SubqueryFunctionType::ALL && operator_name == "notEquals"))
/* Rewrite in AST:
* = ANY --> IN
* != ALL --> NOT IN
* = ALL --> IN (SELECT singleValueOrNull(*) FROM subquery)
* != ANY --> NOT IN (SELECT singleValueOrNull(*) FROM subquery)
**/
auto * function = assert_cast<ASTFunction *>(ast.get());
String operator_name = function->name;
auto function_equals = operator_name == "equals";
auto function_not_equals = operator_name == "notEquals";
String aggregate_function_name;
if (function_equals || function_not_equals)
{
assert_cast<ASTFunction *>(function.get())->name = "in";
if (operator_name == "notEquals")
function->name = "notIn";
else
function->name = "in";
if ((type == SubqueryFunctionType::ANY && function_equals)
|| (type == SubqueryFunctionType::ALL && function_not_equals))
{
auto function_not = std::make_shared<ASTFunction>();
auto exp_list_not = std::make_shared<ASTExpressionList>();
exp_list_not->children.push_back(function);
function_not->name = "not";
function_not->children.push_back(exp_list_not);
function_not->arguments = exp_list_not;
function = function_not;
}
return true;
}
// subquery --> (SELECT aggregate_function(*) FROM subquery)
auto aggregate_function = std::make_shared<ASTFunction>();
auto aggregate_function_exp_list = std::make_shared<ASTExpressionList>();
aggregate_function_exp_list ->children.push_back(std::make_shared<ASTAsterisk>());
aggregate_function->arguments = aggregate_function_exp_list;
aggregate_function->children.push_back(aggregate_function_exp_list);
aggregate_function_name = "singleValueOrNull";
}
else if (operator_name == "greaterOrEquals" || operator_name == "greater")
{
aggregate_function_name = (type == SubqueryFunctionType::ANY ? "min" : "max");
}
else if (operator_name == "lessOrEquals" || operator_name == "less")
{
aggregate_function_name = (type == SubqueryFunctionType::ANY ? "max" : "min");
}
else
return false;
/// subquery --> (SELECT aggregate_function(*) FROM subquery)
auto aggregate_function = makeASTFunction(aggregate_function_name, std::make_shared<ASTAsterisk>());
auto subquery_node = function->children[0]->children[1];
ASTPtr subquery_node = function->children[0]->children[1];
auto select_query = std::make_shared<ASTSelectQuery>();
auto tables_in_select = std::make_shared<ASTTablesInSelectQuery>();
auto tables_in_select_element = std::make_shared<ASTTablesInSelectQueryElement>();
auto table_expression = std::make_shared<ASTTableExpression>();
table_expression->subquery = subquery_node;
table_expression->children.push_back(subquery_node);
tables_in_select_element->table_expression = table_expression;
tables_in_select_element->children.push_back(table_expression);
tables_in_select->children.push_back(tables_in_select_element);
table_expression->subquery = std::move(subquery_node);
table_expression->children.push_back(table_expression->subquery);
auto tables_in_select_element = std::make_shared<ASTTablesInSelectQueryElement>();
tables_in_select_element->table_expression = std::move(table_expression);
tables_in_select_element->children.push_back(tables_in_select_element->table_expression);
auto tables_in_select = std::make_shared<ASTTablesInSelectQuery>();
tables_in_select->children.push_back(std::move(tables_in_select_element));
auto select_exp_list = std::make_shared<ASTExpressionList>();
select_exp_list->children.push_back(aggregate_function);
auto select_query = std::make_shared<ASTSelectQuery>();
select_query->children.push_back(select_exp_list);
select_query->children.push_back(tables_in_select);
select_query->setExpression(ASTSelectQuery::Expression::SELECT, std::move(select_exp_list));
select_query->setExpression(ASTSelectQuery::Expression::TABLES, std::move(tables_in_select));
select_query->setExpression(ASTSelectQuery::Expression::SELECT, select_exp_list);
select_query->setExpression(ASTSelectQuery::Expression::TABLES, tables_in_select);
auto select_with_union_query = std::make_shared<ASTSelectWithUnionQuery>();
auto list_of_selects = std::make_shared<ASTExpressionList>();
list_of_selects->children.push_back(select_query);
select_with_union_query->list_of_selects = list_of_selects;
select_with_union_query->list_of_selects = std::make_shared<ASTExpressionList>();
select_with_union_query->list_of_selects->children.push_back(std::move(select_query));
select_with_union_query->children.push_back(select_with_union_query->list_of_selects);
auto new_subquery = std::make_shared<ASTSubquery>();
new_subquery->children.push_back(select_with_union_query);
function->children[0]->children.pop_back();
function->children[0]->children.push_back(new_subquery);
ast->children[0]->children.back() = std::move(new_subquery);
if (operator_name == "greaterOrEquals" || operator_name == "greater")
{
aggregate_function->name = type == SubqueryFunctionType::ANY ? "min" : "max";
return true;
}
if (operator_name == "lessOrEquals" || operator_name == "less")
{
aggregate_function->name = type == SubqueryFunctionType::ANY ? "max" : "min";
return true;
}
// = ALL --> IN (SELECT singleValueOrNull(*) FROM subquery)
// != ANY --> NOT IN (SELECT singleValueOrNull(*) FROM subquery)
if (operator_name == "equals" || operator_name == "notEquals")
{
aggregate_function->name = "singleValueOrNull";
assert_cast<ASTFunction *>(function.get())->name = "in";
if (operator_name == "notEquals")
{
auto function_not = std::make_shared<ASTFunction>();
auto exp_list_not = std::make_shared<ASTExpressionList>();
exp_list_not->children.push_back(function);
function_not->name = "not";
function_not->children.push_back(exp_list_not);
function_not->arguments = exp_list_not;
function = function_not;
}
return true;
}
return false;
}
bool ParserComparisonExpression::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
@ -346,7 +338,7 @@ bool ParserComparisonExpression::parseImpl(Pos & pos, ASTPtr & node, Expected &
exp_list->children.push_back(node);
exp_list->children.push_back(elem);
if (subquery_function_type != SubqueryFunctionType::NONE && !modifyAST(function->name, function, subquery_function_type))
if (subquery_function_type != SubqueryFunctionType::NONE && !modifyAST(function, subquery_function_type))
return false;
pos.increaseDepth();

View File

@ -12,7 +12,7 @@ using Operator = ASTSelectIntersectExceptQuery::Operator;
public:
/// max_threads is used to limit the number of threads for result pipeline.
IntersectOrExceptStep(DataStreams input_streams_, Operator operators_, size_t max_threads_ = 0);
IntersectOrExceptStep(DataStreams input_streams_, Operator operator_, size_t max_threads_ = 0);
String getName() const override { return "IntersectOrExcept"; }

View File

@ -14,7 +14,7 @@ class IntersectOrExceptTransform : public IProcessor
using Operator = ASTSelectIntersectExceptQuery::Operator;
public:
IntersectOrExceptTransform(const Block & header_, Operator operators);
IntersectOrExceptTransform(const Block & header_, Operator operator_);
String getName() const override { return "IntersectOrExcept"; }

View File

@ -3,17 +3,29 @@ select 1 == any (select number from numbers(10));
1
select 1 == any (select number from numbers(2, 10));
0
select 1 != all (select 1 from numbers(10));
0
select 1 != all (select number from numbers(10));
0
select 1 == all (select 1 from numbers(10));
1
select 1 == all (select number from numbers(10));
0
select 1 != any (select 1 from numbers(10));
0
select 1 != any (select number from numbers(10));
1
select number as a from numbers(10) where a == any (select number from numbers(3, 3));
3
4
5
-- TODO: Incorrect:
select 1 != any (select 1 from numbers(10));
select number as a from numbers(10) where a != any (select 5 from numbers(3, 3));
0
1
select 1 != all (select 1 from numbers(10));
1
select number as a from numbers(10) where a != any (select number from numbers(3, 3));
2
3
4
6
7
8
9

View File

@ -1,12 +1,17 @@
-- { echo }
select 1 == any (select number from numbers(10));
select 1 == any (select number from numbers(2, 10));
select 1 != all (select 1 from numbers(10));
select 1 != all (select number from numbers(10));
select 1 == all (select 1 from numbers(10));
select 1 == all (select number from numbers(10));
select number as a from numbers(10) where a == any (select number from numbers(3, 3));
-- TODO: Incorrect:
select 1 != any (select 1 from numbers(10));
select 1 != all (select 1 from numbers(10));
select number as a from numbers(10) where a != any (select number from numbers(3, 3));
select 1 != any (select number from numbers(10));
select number as a from numbers(10) where a == any (select number from numbers(3, 3));
select number as a from numbers(10) where a != any (select 5 from numbers(3, 3));