Merge pull request #12039 from ClickHouse/fix-nullable-tuple-compare

Fix nullable tuple compare
This commit is contained in:
alexey-milovidov 2020-06-30 01:38:46 +03:00 committed by GitHub
commit c1d2d2d7f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 248 additions and 17 deletions

View File

@ -12,6 +12,7 @@
#include <Columns/ColumnArray.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeDateTime64.h>
#include <DataTypes/DataTypeDate.h>
@ -931,6 +932,8 @@ private:
if (0 == tuple_size)
throw Exception("Comparison of zero-sized tuples is not implemented.", ErrorCodes::NOT_IMPLEMENTED);
ColumnsWithTypeAndName convolution_types(tuple_size);
Block tmp_block;
for (size_t i = 0; i < tuple_size; ++i)
{
@ -938,9 +941,10 @@ private:
tmp_block.insert(y[i]);
auto impl = func_compare->build({x[i], y[i]});
convolution_types[i].type = impl->getReturnType();
/// Comparison of the elements.
tmp_block.insert({ nullptr, std::make_shared<DataTypeUInt8>(), "" });
tmp_block.insert({ nullptr, impl->getReturnType(), "" });
impl->execute(tmp_block, {i * 3, i * 3 + 1}, i * 3 + 2, input_rows_count);
}
@ -952,14 +956,13 @@ private:
}
/// Logical convolution.
tmp_block.insert({ nullptr, std::make_shared<DataTypeUInt8>(), "" });
ColumnNumbers convolution_args(tuple_size);
for (size_t i = 0; i < tuple_size; ++i)
convolution_args[i] = i * 3 + 2;
ColumnsWithTypeAndName convolution_types(convolution_args.size(), { nullptr, std::make_shared<DataTypeUInt8>(), "" });
auto impl = func_convolution->build(convolution_types);
tmp_block.insert({ nullptr, impl->getReturnType(), "" });
impl->execute(tmp_block, convolution_args, tuple_size * 3, input_rows_count);
block.getByPosition(result).column = tmp_block.getByPosition(tuple_size * 3).column;
@ -978,49 +981,71 @@ private:
size_t tuple_size,
size_t input_rows_count)
{
ColumnsWithTypeAndName bin_args = {{ nullptr, std::make_shared<DataTypeUInt8>(), "" },
{ nullptr, std::make_shared<DataTypeUInt8>(), "" }};
auto func_and_adaptor = func_and->build(bin_args);
auto func_or_adaptor = func_or->build(bin_args);
Block tmp_block;
/// Pairwise comparison of the inequality of all elements; on the equality of all elements except the last.
/// (x[i], y[i], x[i] < y[i], x[i] == y[i])
for (size_t i = 0; i < tuple_size; ++i)
{
tmp_block.insert(x[i]);
tmp_block.insert(y[i]);
tmp_block.insert({ nullptr, std::make_shared<DataTypeUInt8>(), "" });
tmp_block.insert(ColumnWithTypeAndName()); // pos == i * 4 + 2
if (i + 1 != tuple_size)
{
auto impl_head = func_compare_head->build({x[i], y[i]});
tmp_block.getByPosition(i * 4 + 2).type = impl_head->getReturnType();
impl_head->execute(tmp_block, {i * 4, i * 4 + 1}, i * 4 + 2, input_rows_count);
tmp_block.insert({ nullptr, std::make_shared<DataTypeUInt8>(), "" });
tmp_block.insert(ColumnWithTypeAndName()); // i * 4 + 3
auto impl_equals = func_equals->build({x[i], y[i]});
tmp_block.getByPosition(i * 4 + 3).type = impl_equals->getReturnType();
impl_equals->execute(tmp_block, {i * 4, i * 4 + 1}, i * 4 + 3, input_rows_count);
}
else
{
auto impl_tail = func_compare_tail->build({x[i], y[i]});
tmp_block.getByPosition(i * 4 + 2).type = impl_tail->getReturnType();
impl_tail->execute(tmp_block, {i * 4, i * 4 + 1}, i * 4 + 2, input_rows_count);
}
}
/// Combination. Complex code - make a drawing. It can be replaced by a recursive comparison of tuples.
/// Last column contains intermediate result.
/// Code is generally equivalent to:
/// res = `x < y`[tuple_size - 1];
/// for (int i = tuple_size - 2; i >= 0; --i)
/// res = (res && `x == y`[i]) || `x < y`[i];
size_t i = tuple_size - 1;
while (i > 0)
{
tmp_block.insert({ nullptr, std::make_shared<DataTypeUInt8>(), "" });
func_and_adaptor->execute(tmp_block, {tmp_block.columns() - 2, (i - 1) * 4 + 3}, tmp_block.columns() - 1, input_rows_count);
tmp_block.insert({ nullptr, std::make_shared<DataTypeUInt8>(), "" });
func_or_adaptor->execute(tmp_block, {tmp_block.columns() - 2, (i - 1) * 4 + 2}, tmp_block.columns() - 1, input_rows_count);
--i;
size_t and_lhs_pos = tmp_block.columns() - 1; // res
size_t and_rhs_pos = i * 4 + 3; // `x == y`[i]
tmp_block.insert(ColumnWithTypeAndName());
ColumnsWithTypeAndName and_args = {{ nullptr, tmp_block.getByPosition(and_lhs_pos).type, "" },
{ nullptr, tmp_block.getByPosition(and_rhs_pos).type, "" }};
auto func_and_adaptor = func_and->build(and_args);
tmp_block.getByPosition(tmp_block.columns() - 1).type = func_and_adaptor->getReturnType();
func_and_adaptor->execute(tmp_block, {and_lhs_pos, and_rhs_pos}, tmp_block.columns() - 1, input_rows_count);
size_t or_lhs_pos = tmp_block.columns() - 1; // (res && `x == y`[i])
size_t or_rhs_pos = i * 4 + 2; // `x < y`[i]
tmp_block.insert(ColumnWithTypeAndName());
ColumnsWithTypeAndName or_args = {{ nullptr, tmp_block.getByPosition(or_lhs_pos).type, "" },
{ nullptr, tmp_block.getByPosition(or_rhs_pos).type, "" }};
auto func_or_adaptor = func_or->build(or_args);
tmp_block.getByPosition(tmp_block.columns() - 1).type = func_or_adaptor->getReturnType();
func_or_adaptor->execute(tmp_block, {or_lhs_pos, or_rhs_pos}, tmp_block.columns() - 1, input_rows_count);
}
block.getByPosition(result).column = tmp_block.getByPosition(tmp_block.columns() - 1).column;
@ -1109,13 +1134,20 @@ public:
auto adaptor = FunctionOverloadResolverAdaptor(std::make_unique<DefaultOverloadResolver>(
FunctionComparison<Op, Name>::create(context)));
bool has_nullable = false;
size_t size = left_tuple->getElements().size();
for (size_t i = 0; i < size; ++i)
{
ColumnsWithTypeAndName args = {{nullptr, left_tuple->getElements()[i], ""},
{nullptr, right_tuple->getElements()[i], ""}};
adaptor.build(args);
has_nullable = has_nullable || adaptor.build(args)->getReturnType()->isNullable();
}
/// If any element comparison is nullable, return type will also be nullable.
/// We rely on useDefaultImplementationForNulls, but it doesn't work for tuples.
if (has_nullable)
return std::make_shared<DataTypeNullable>(std::make_shared<DataTypeUInt8>());
}
return std::make_shared<DataTypeUInt8>();
@ -1135,7 +1167,7 @@ public:
/// NOTE: Nullable types are special case.
/// (BTW, this function uses the default implementation for Nullable, so Nullable types cannot be here. Check just in case.)
/// NOTE: We consider NaN comparison to be implementation specific (and in our implementation NaNs are sometimes equal sometimes not).
if (left_type->equals(*right_type) && !left_type->isNullable() && col_left_untyped == col_right_untyped)
if (left_type->equals(*right_type) && !left_type->isNullable() && !isTuple(left_type) && col_left_untyped == col_right_untyped)
{
/// Always true: =, <=, >=
if constexpr (std::is_same_v<Op<int, int>, EqualsOp<int, int>>

View File

@ -0,0 +1,92 @@
single argument
1
0
1
0
1
0
- 1
1
1
1
0
0
0
0
0
0
1
1
1
- 2
1
1
1
0
0
0
0
0
1
1
1
1
- 3
1
1
1
1
1
1
- 4
\N
\N
\N
\N
\N
\N
two arguments
1
1
1
1
1
1
- 1
0
0
0
0
0
0
- 2
1
1
1
1
1
1
- 3
\N
\N
\N
\N
\N
1
\N
\N
0
many arguments
1
1
0
0
1
0
1
\N
\N
\N
\N
\N
\N

View File

@ -0,0 +1,107 @@
-- Single-element tuples. Each query selects from numbers(1) and compares two tuples
-- where toNullable is applied to neither side, the left side, the right side, or both.
select 'single argument';
-- Equality: number against number and against number + 1.
select tuple(number) = tuple(number) from numbers(1);
select tuple(number) = tuple(number + 1) from numbers(1);
select tuple(toNullable(number)) = tuple(number) from numbers(1);
select tuple(toNullable(number)) = tuple(number + 1) from numbers(1);
select tuple(toNullable(number)) = tuple(toNullable(number)) from numbers(1);
select tuple(toNullable(number)) = tuple(toNullable(number + 1)) from numbers(1);
-- Strict ordering (< and >) between number and number + 1, in both operand orders.
select '- 1';
-- Left operand holds number, right operand holds number + 1.
select tuple(toNullable(number)) < tuple(number + 1) from numbers(1);
select tuple(number) < tuple(toNullable(number + 1)) from numbers(1);
select tuple(toNullable(number)) < tuple(toNullable(number + 1)) from numbers(1);
select tuple(toNullable(number)) > tuple(number + 1) from numbers(1);
select tuple(number) > tuple(toNullable(number + 1)) from numbers(1);
select tuple(toNullable(number)) > tuple(toNullable(number + 1)) from numbers(1);
-- Left operand holds number + 1, right operand holds number.
select tuple(toNullable(number + 1)) < tuple(number) from numbers(1);
select tuple(number + 1) < tuple(toNullable(number)) from numbers(1);
-- NOTE(review): unlike its neighbours, the next query has number + 1 on both sides - confirm this is intended.
select tuple(toNullable(number + 1)) < tuple(toNullable(number + 1)) from numbers(1);
select tuple(toNullable(number + 1)) > tuple(number) from numbers(1);
select tuple(number + 1) > tuple(toNullable(number)) from numbers(1);
select tuple(toNullable(number + 1)) > tuple(toNullable(number)) from numbers(1);
-- Non-strict ordering (<= and >=) of single-element tuples holding number and number + 1,
-- with toNullable applied to the left side, the right side, or both.
select '- 2';
-- Left operand holds number, right operand holds number + 1.
select tuple(toNullable(number)) <= tuple(number + 1) from numbers(1);
select tuple(number) <= tuple(toNullable(number + 1)) from numbers(1);
select tuple(toNullable(number)) <= tuple(toNullable(number + 1)) from numbers(1);
select tuple(toNullable(number)) >= tuple(number + 1) from numbers(1);
-- Non-strict counterpart of the strict > query with toNullable on the right side only.
select tuple(number) >= tuple(toNullable(number + 1)) from numbers(1);
select tuple(toNullable(number)) >= tuple(toNullable(number + 1)) from numbers(1);
-- Left operand holds number + 1.
select tuple(toNullable(number + 1)) <= tuple(number) from numbers(1);
select tuple(number + 1) <= tuple(toNullable(number)) from numbers(1);
select tuple(toNullable(number + 1)) <= tuple(toNullable(number + 1)) from numbers(1);
select tuple(toNullable(number + 1)) >= tuple(number) from numbers(1);
select tuple(number + 1) >= tuple(toNullable(number)) from numbers(1);
select tuple(toNullable(number + 1)) >= tuple(toNullable(number)) from numbers(1);
-- Non-strict ordering (<= and >=) where both single-element tuples hold number,
-- with toNullable applied to the left side, the right side, or both.
select '- 3';
select tuple(toNullable(number)) <= tuple(number) from numbers(1);
select tuple(number) <= tuple(toNullable(number)) from numbers(1);
select tuple(toNullable(number)) <= tuple(toNullable(number)) from numbers(1);
select tuple(toNullable(number)) >= tuple(number) from numbers(1);
select tuple(number) >= tuple(toNullable(number)) from numbers(1);
select tuple(toNullable(number)) >= tuple(toNullable(number)) from numbers(1);
-- Comparisons (=, <=, >=) where the right side, or both sides, is materialize(toUInt64OrNull('')).
-- NOTE(review): presumably toUInt64OrNull('') evaluates to NULL, so these cover NULL elements - confirm.
select '- 4';
select tuple(number) = tuple(materialize(toUInt64OrNull(''))) from numbers(1);
select tuple(materialize(toUInt64OrNull(''))) = tuple(materialize(toUInt64OrNull(''))) from numbers(1);
select tuple(number) <= tuple(materialize(toUInt64OrNull(''))) from numbers(1);
select tuple(materialize(toUInt64OrNull(''))) <= tuple(materialize(toUInt64OrNull(''))) from numbers(1);
select tuple(number) >= tuple(materialize(toUInt64OrNull(''))) from numbers(1);
select tuple(materialize(toUInt64OrNull(''))) >= tuple(materialize(toUInt64OrNull(''))) from numbers(1);
-- Two-element tuples. Each group runs the same comparison over six placements of
-- toNullable across the four tuple elements.
select 'two arguments';
-- Equality of tuples whose elements all hold number.
select tuple(toNullable(number), number) = tuple(number, number) from numbers(1);
select tuple(toNullable(number), toNullable(number)) = tuple(number, number) from numbers(1);
select tuple(toNullable(number), toNullable(number)) = tuple(toNullable(number), number) from numbers(1);
select tuple(toNullable(number), toNullable(number)) = tuple(toNullable(number), toNullable(number)) from numbers(1);
select tuple(number, toNullable(number)) = tuple(toNullable(number), toNullable(number)) from numbers(1);
select tuple(number, toNullable(number)) = tuple(toNullable(number), number) from numbers(1);
-- Strict < over the same operands as the equality group above.
select '- 1';
select tuple(toNullable(number), number) < tuple(number, number) from numbers(1);
select tuple(toNullable(number), toNullable(number)) < tuple(number, number) from numbers(1);
select tuple(toNullable(number), toNullable(number)) < tuple(toNullable(number), number) from numbers(1);
select tuple(toNullable(number), toNullable(number)) < tuple(toNullable(number), toNullable(number)) from numbers(1);
select tuple(number, toNullable(number)) < tuple(toNullable(number), toNullable(number)) from numbers(1);
select tuple(number, toNullable(number)) < tuple(toNullable(number), number) from numbers(1);
-- Strict < where exactly one element of the right tuple is number + 1
-- (the second element in some queries, the first element in others).
select '- 2';
select tuple(toNullable(number), number) < tuple(number, number + 1) from numbers(1);
select tuple(toNullable(number), toNullable(number)) < tuple(number, number + 1) from numbers(1);
select tuple(toNullable(number), toNullable(number)) < tuple(toNullable(number + 1), number) from numbers(1);
select tuple(toNullable(number), toNullable(number)) < tuple(toNullable(number + 1), toNullable(number)) from numbers(1);
select tuple(number, toNullable(number)) < tuple(toNullable(number), toNullable(number + 1)) from numbers(1);
select tuple(number, toNullable(number)) < tuple(toNullable(number), number + 1) from numbers(1);
-- Comparisons where at least one element is toUInt64OrNull(''), materialized or not.
-- NOTE(review): presumably toUInt64OrNull('') evaluates to NULL - confirm.
select '- 3';
-- toUInt64OrNull('') appears in the first element of the left tuple.
select tuple(materialize(toUInt64OrNull('')), number) = tuple(number, number) from numbers(1);
select tuple(materialize(toUInt64OrNull('')), number) = tuple(number, toUInt64OrNull('')) from numbers(1);
select tuple(materialize(toUInt64OrNull('')), toUInt64OrNull('')) = tuple(toUInt64OrNull(''), toUInt64OrNull('')) from numbers(1);
-- toUInt64OrNull('') appears only in the second element of the left tuple. In the two queries
-- whose first elements differ (number against number + 1) the accompanying reference output
-- appears to list 1 and 0 instead of \N - verify against the reference file.
select tuple(number, materialize(toUInt64OrNull(''))) < tuple(number, number) from numbers(1);
select tuple(number, materialize(toUInt64OrNull(''))) <= tuple(number, number) from numbers(1);
select tuple(number, materialize(toUInt64OrNull(''))) < tuple(number + 1, number) from numbers(1);
select tuple(number, materialize(toUInt64OrNull(''))) > tuple(number, number) from numbers(1);
select tuple(number, materialize(toUInt64OrNull(''))) >= tuple(number, number) from numbers(1);
select tuple(number, materialize(toUInt64OrNull(''))) > tuple(number + 1, number) from numbers(1);
-- Three-element tuples. The first element of the left tuple is always toNullable(number).
-- The middle element is either number or the string materialize('a').
select 'many arguments';
-- Third element is a plain number (or number + 1 on the right side).
select tuple(toNullable(number), number, number) = tuple(number, number, number) from numbers(1);
select tuple(toNullable(number), materialize('a'), number) = tuple(number, materialize('a'), number) from numbers(1);
select tuple(toNullable(number), materialize('a'), number) = tuple(number, materialize('a'), number + 1) from numbers(1);
select tuple(toNullable(number), number, number) < tuple(number, number, number) from numbers(1);
select tuple(toNullable(number), number, number) <= tuple(number, number, number) from numbers(1);
select tuple(toNullable(number), materialize('a'), number) < tuple(number, materialize('a'), number) from numbers(1);
select tuple(toNullable(number), materialize('a'), number) < tuple(number, materialize('a'), number + 1) from numbers(1);
-- Third element of the left tuple is materialize(toUInt64OrNull('')), compared with = and <=.
-- NOTE(review): presumably toUInt64OrNull('') evaluates to NULL - confirm.
select tuple(toNullable(number), number, materialize(toUInt64OrNull(''))) = tuple(number, number, number) from numbers(1);
select tuple(toNullable(number), materialize('a'), materialize(toUInt64OrNull(''))) = tuple(number, materialize('a'), number) from numbers(1);
select tuple(toNullable(number), materialize('a'), materialize(toUInt64OrNull(''))) = tuple(number, materialize('a'), number + 1) from numbers(1);
select tuple(toNullable(number), number, materialize(toUInt64OrNull(''))) <= tuple(number, number, number) from numbers(1);
select tuple(toNullable(number), materialize('a'), materialize(toUInt64OrNull(''))) <= tuple(number, materialize('a'), number) from numbers(1);
select tuple(toNullable(number), materialize('a'), materialize(toUInt64OrNull(''))) <= tuple(number, materialize('a'), number + 1) from numbers(1);