IN function result

Allowing function result to be used on the right side of IN statement.
All functions except aggregate ones, should work just Ok.
Fixed printing AST of operator IN.
This commit is contained in:
Vasily Nemkov 2019-05-19 22:22:51 +03:00
parent 62e18d89ea
commit ce70e4f238
6 changed files with 116 additions and 46 deletions

View File

@ -82,21 +82,7 @@ SetPtr makeExplicitSet(
if (prepared_sets.count(set_key))
return prepared_sets.at(set_key); /// Already prepared.
auto getTupleTypeFromAst = [&context](const ASTPtr & tuple_ast) -> DataTypePtr
{
const auto * func = tuple_ast->as<ASTFunction>();
if (func && func->name == "tuple" && !func->arguments->children.empty())
{
/// Won't parse all values of outer tuple.
auto element = func->arguments->children.at(0);
std::pair<Field, DataTypePtr> value_raw = evaluateConstantExpression(element, context);
return std::make_shared<DataTypeTuple>(DataTypes({value_raw.second}));
}
return evaluateConstantExpression(tuple_ast, context).second;
};
const DataTypePtr & right_arg_type = getTupleTypeFromAst(right_arg);
const auto right_arg_evaluated = evaluateConstantExpression(right_arg, context);
std::function<size_t(const DataTypePtr &)> getTupleDepth;
getTupleDepth = [&getTupleDepth](const DataTypePtr & type) -> size_t
@ -107,37 +93,77 @@ SetPtr makeExplicitSet(
return 0;
};
size_t left_tuple_depth = getTupleDepth(left_arg_type);
size_t right_tuple_depth = getTupleDepth(right_arg_type);
const auto& right_arg_type = right_arg_evaluated.second;
const auto& right_arg_value = right_arg_evaluated.first;
ASTPtr elements_ast = nullptr;
/// 1 in 1; (1, 2) in (1, 2); identity(tuple(tuple(tuple(1)))) in tuple(tuple(tuple(1))); etc.
if (left_tuple_depth == right_tuple_depth)
const size_t left_tuple_depth = getTupleDepth(left_arg_type);
const size_t right_tuple_depth = getTupleDepth(right_arg_type);
if (left_tuple_depth != right_tuple_depth && left_tuple_depth + 1 != right_tuple_depth)
{
ASTPtr exp_list = std::make_shared<ASTExpressionList>();
exp_list->children.push_back(right_arg);
elements_ast = exp_list;
}
/// 1 in (1, 2); (1, 2) in ((1, 2), (3, 4)); etc.
else if (left_tuple_depth + 1 == right_tuple_depth)
{
const auto * set_func = right_arg->as<ASTFunction>();
if (!set_func || set_func->name != "tuple")
throw Exception("Incorrect type of 2nd argument for function " + node->name
+ ". Must be subquery or set of elements with type " + left_arg_type->getName() + ".",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
elements_ast = set_func->arguments;
}
else
throw Exception("Invalid types for IN function: "
+ left_arg_type->getName() + " and " + right_arg_type->getName() + ".",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
Block block;
auto col = left_arg_type->createColumn();
switch (right_arg_type->getTypeId())
{
case TypeIndex::UInt8: [[fallthrough]];
case TypeIndex::UInt16: [[fallthrough]];
case TypeIndex::UInt32: [[fallthrough]];
case TypeIndex::UInt64: [[fallthrough]];
case TypeIndex::UInt128: [[fallthrough]];
case TypeIndex::Int8: [[fallthrough]];
case TypeIndex::Int16: [[fallthrough]];
case TypeIndex::Int32: [[fallthrough]];
case TypeIndex::Int64: [[fallthrough]];
case TypeIndex::Int128: [[fallthrough]];
case TypeIndex::Float32: [[fallthrough]];
case TypeIndex::Float64: [[fallthrough]];
case TypeIndex::Date: [[fallthrough]];
case TypeIndex::DateTime: [[fallthrough]];
case TypeIndex::String: [[fallthrough]];
case TypeIndex::FixedString: [[fallthrough]];
case TypeIndex::Enum8: [[fallthrough]];
case TypeIndex::Enum16: [[fallthrough]];
case TypeIndex::Decimal32: [[fallthrough]];
case TypeIndex::Decimal64: [[fallthrough]];
case TypeIndex::Decimal128: [[fallthrough]];
case TypeIndex::UUID:
{
col->insert(convertFieldToType(right_arg_value, *left_arg_type, right_arg_type.get()));
break;
}
// flatten compound values:
case TypeIndex::Array: [[fallthrough]];
case TypeIndex::Tuple: [[fallthrough]];
case TypeIndex::Set:
{
const Array & array = DB::get<const Array &>(right_arg_value);
if (array.size() == 0)
break;
for (size_t i = 0 ; i < array.size(); ++i)
{
col->insert(convertFieldToType(array[i], *left_arg_type, nullptr));
}
break;
}
default:
throw Exception("Unsupported value type at the right-side of IN:"
+ right_arg_type->getName() + ".",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
block.insert(ColumnWithTypeAndName{std::move(col),
left_arg_type,
"dummy_" + left_arg_type->getName()});
SetPtr set = std::make_shared<Set>(size_limits, create_ordered_set);
set->createFromAST(set_element_types, elements_ast, context);
set->setHeader(block);
set->insertFromBlock(block);
prepared_sets[set_key] = set;
return set;
}

View File

@ -187,8 +187,10 @@ void ASTFunction::formatImplWithoutAlias(const FormatSettings & settings, Format
const auto * second_arg_func = arguments->children[1]->as<ASTFunction>();
const auto * second_arg_literal = arguments->children[1]->as<ASTLiteral>();
bool extra_parents_around_in_rhs = (name == "in" || name == "notIn" || name == "globalIn" || name == "globalNotIn")
&& !(second_arg_func && second_arg_func->name == "tuple")
&& !(second_arg_literal && second_arg_literal->value.getType() == Field::Types::Tuple)
&& !second_arg_func
&& !(second_arg_literal
&& (second_arg_literal->value.getType() == Field::Types::Tuple
|| second_arg_literal->value.getType() == Field::Types::Array))
&& !arguments->children[1]->as<ASTSubquery>();
if (extra_parents_around_in_rhs)

View File

@ -15,6 +15,7 @@ class ASTLiteral : public ASTWithAlias
public:
Field value;
ASTLiteral(Field && value_) : value(value_) {}
ASTLiteral(const Field & value_) : value(value_) {}
/** Get the text that identifies this element. */

View File

@ -0,0 +1,7 @@
5
2
5
empty:
0
0
errors:

View File

@ -0,0 +1,34 @@
SET force_primary_key = 1;
DROP TABLE IF EXISTS samples;
CREATE TABLE samples (key UInt32, value UInt32) ENGINE = MergeTree() ORDER BY key PRIMARY KEY key;
INSERT INTO samples VALUES (1, 1)(2, 2)(3, 3)(4, 4)(5, 5);
-- all etries, verify that index is used
SELECT count() FROM samples WHERE key IN range(10);
-- some entries:
SELECT count() FROM samples WHERE key IN arraySlice(range(100), 5, 10);
-- different type
SELECT count() FROM samples WHERE toUInt64(key) IN range(100);
SELECT 'empty:';
-- should be empty
SELECT count() FROM samples WHERE key IN arraySlice(range(100), 10, 10);
-- not only ints:
SELECT 'a' IN splitByChar('c', 'abcdef');
SELECT 'errors:';
-- non-constant expressions in the right side of IN
SELECT count() FROM samples WHERE 1 IN range(samples.value); -- { serverError 47 }
SELECT count() FROM samples WHERE 1 IN range(rand()); -- { serverError 36 }
-- index is not used
SELECT count() FROM samples WHERE value IN range(3); -- { serverError 277 }
-- wrong type
SELECT 123 IN splitByChar('c', 'abcdef'); -- { serverError 53 }
DROP TABLE samples;

View File

@ -1,13 +1,13 @@
SELECT 1 IN (1)
SELECT 1 IN (1)
SELECT 1 IN (1, 2)
SELECT 1 IN (f(1))
SELECT 1 IN (f(1))
SELECT 1 IN f(1)
SELECT 1 IN f(1)
SELECT 1 IN (f(1), f(2))
SELECT 1 IN (f(1, 2))
SELECT 1 IN f(1, 2)
SELECT 1 IN (1 + 1)
SELECT 1 IN ('hello')
SELECT 1 IN (f('hello'))
SELECT 1 IN f('hello')
SELECT 1 IN ('hello', 'world')
SELECT 1 IN (f('hello', 'world'))
SELECT 1 IN f('hello', 'world')
SELECT 1 IN (SELECT 1)