diff --git a/dbms/src/Common/FieldVisitors.h b/dbms/src/Common/FieldVisitors.h index 727db50edd1..cb43f5fed7f 100644 --- a/dbms/src/Common/FieldVisitors.h +++ b/dbms/src/Common/FieldVisitors.h @@ -3,6 +3,7 @@ #include #include #include +#include class SipHash; @@ -135,22 +136,22 @@ class FieldVisitorConvertToNumber : public StaticVisitor public: T operator() (const Null &) const { - throw Exception("Cannot convert NULL to " + String(TypeName::get()), ErrorCodes::CANNOT_CONVERT_TYPE); + throw Exception("Cannot convert NULL to " + demangle(typeid(T).name()), ErrorCodes::CANNOT_CONVERT_TYPE); } T operator() (const String &) const { - throw Exception("Cannot convert String to " + String(TypeName::get()), ErrorCodes::CANNOT_CONVERT_TYPE); + throw Exception("Cannot convert String to " + demangle(typeid(T).name()), ErrorCodes::CANNOT_CONVERT_TYPE); } T operator() (const Array &) const { - throw Exception("Cannot convert Array to " + String(TypeName::get()), ErrorCodes::CANNOT_CONVERT_TYPE); + throw Exception("Cannot convert Array to " + demangle(typeid(T).name()), ErrorCodes::CANNOT_CONVERT_TYPE); } T operator() (const Tuple &) const { - throw Exception("Cannot convert Tuple to " + String(TypeName::get()), ErrorCodes::CANNOT_CONVERT_TYPE); + throw Exception("Cannot convert Tuple to " + demangle(typeid(T).name()), ErrorCodes::CANNOT_CONVERT_TYPE); } T operator() (const UInt64 & x) const { return x; } diff --git a/dbms/src/Functions/FunctionsLogical.h b/dbms/src/Functions/FunctionsLogical.h index eab4c370de1..067ae067a4a 100644 --- a/dbms/src/Functions/FunctionsLogical.h +++ b/dbms/src/Functions/FunctionsLogical.h @@ -3,6 +3,8 @@ #include #include #include +#include +#include #include #include #include @@ -19,10 +21,17 @@ namespace ErrorCodes extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; } - -/** Return an UInt8 containing 0 or 1. +/** Behaviour in presence of NULLs: + * + * Functions AND, XOR, NOT use default implementation for NULLs: + * - if one of arguments is Nullable, they return Nullable result where NULLs are returned when at least one argument was NULL. + * + * But function OR is different. + * It always return non-Nullable result and NULL are equivalent to 0 (false). + * For example, 1 OR NULL returns 1, not NULL. */ + struct AndImpl { static inline bool isSaturable() @@ -30,15 +39,17 @@ struct AndImpl return true; } - static inline bool isSaturatedValue(UInt8 a) + static inline bool isSaturatedValue(bool a) { return !a; } - static inline UInt8 apply(UInt8 a, UInt8 b) + static inline bool apply(bool a, bool b) { return a && b; } + + static inline bool specialImplementationForNulls() { return false; } }; struct OrImpl @@ -48,15 +59,17 @@ struct OrImpl return true; } - static inline bool isSaturatedValue(UInt8 a) + static inline bool isSaturatedValue(bool a) { return a; } - static inline UInt8 apply(UInt8 a, UInt8 b) + static inline bool apply(bool a, bool b) { return a || b; } + + static inline bool specialImplementationForNulls() { return true; } }; struct XorImpl @@ -66,15 +79,17 @@ struct XorImpl return false; } - static inline bool isSaturatedValue(UInt8) + static inline bool isSaturatedValue(bool) { return false; } - static inline UInt8 apply(UInt8 a, UInt8 b) + static inline bool apply(bool a, bool b) { - return (!a) != (!b); + return a != b; } + + static inline bool specialImplementationForNulls() { return false; } }; template @@ -170,21 +185,23 @@ private: bool has_res = false; for (int i = static_cast(in.size()) - 1; i >= 0; --i) { - if (in[i]->isColumnConst()) - { - UInt8 x = !!in[i]->getUInt(0); - if (has_res) - { - res = Impl::apply(res, x); - } - else - { - res = x; - has_res = true; - } + if (!in[i]->isColumnConst()) + continue; - in.erase(in.begin() + i); + Field value = (*in[i])[0]; + + UInt8 x = !value.isNull() && applyVisitor(FieldVisitorConvertToNumber(), value); + if (has_res) + { + res = Impl::apply(res, x); } + else + { + res = x; + has_res = true; + } + + in.erase(in.begin() + i); } return has_res; } @@ -195,7 +212,7 @@ private: auto col = checkAndGetColumn>(column); if (!col) return false; - const typename ColumnVector::Container & vec = col->getData(); + const auto & vec = col->getData(); size_t n = res.size(); for (size_t i = 0; i < n; ++i) res[i] = !!vec[i]; @@ -203,6 +220,25 @@ private: return true; } + template + bool convertNullableTypeToUInt8(const IColumn * column, UInt8Container & res) + { + auto col_nullable = checkAndGetColumn(column); + + auto col = checkAndGetColumn>(&col_nullable->getNestedColumn()); + if (!col) + return false; + + const auto & vec = col->getData(); + const auto & null_map = col_nullable->getNullMapData(); + + size_t n = res.size(); + for (size_t i = 0; i < n; ++i) + res[i] = !!vec[i] && !null_map[i]; + + return true; + } + void convertToUInt8(const IColumn * column, UInt8Container & res) { if (!convertTypeToUInt8(column, res) && @@ -213,7 +249,17 @@ private: !convertTypeToUInt8(column, res) && !convertTypeToUInt8(column, res) && !convertTypeToUInt8(column, res) && - !convertTypeToUInt8(column, res)) + !convertTypeToUInt8(column, res) && + !convertNullableTypeToUInt8(column, res) && + !convertNullableTypeToUInt8(column, res) && + !convertNullableTypeToUInt8(column, res) && + !convertNullableTypeToUInt8(column, res) && + !convertNullableTypeToUInt8(column, res) && + !convertNullableTypeToUInt8(column, res) && + !convertNullableTypeToUInt8(column, res) && + !convertNullableTypeToUInt8(column, res) && + !convertNullableTypeToUInt8(column, res) && + !convertNullableTypeToUInt8(column, res)) throw Exception("Unexpected type of column: " + column->getName(), ErrorCodes::ILLEGAL_COLUMN); } @@ -226,6 +272,8 @@ public: bool isVariadic() const override { return true; } size_t getNumberOfArguments() const override { return 0; } + bool useDefaultImplementationForNulls() const override { return !Impl::specialImplementationForNulls(); } + /// Get result types by argument types. If the function does not apply to these arguments, throw an exception. DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { @@ -235,7 +283,8 @@ public: ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); for (size_t i = 0; i < arguments.size(); ++i) - if (!arguments[i]->isNumber()) + if (!(arguments[i]->isNumber() + || (Impl::specialImplementationForNulls() && (arguments[i]->onlyNull() || removeNullable(arguments[i])->isNumber())))) throw Exception("Illegal type (" + arguments[i]->getName() + ") of " + toString(i + 1) + " argument of function " + getName(), diff --git a/dbms/tests/queries/0_stateless/00552_or_nullable.reference b/dbms/tests/queries/0_stateless/00552_or_nullable.reference new file mode 100644 index 00000000000..0dbd06680cc --- /dev/null +++ b/dbms/tests/queries/0_stateless/00552_or_nullable.reference @@ -0,0 +1,28 @@ +0 1 0 1 0 1 1 1 +\N 0 1 0 0 1 0 1 1 1 +1 1 1 1 1 1 1 1 1 1 +\N 0 1 0 0 1 0 1 1 1 +0 0 1 0 0 1 0 1 1 1 +\N 0 1 0 0 1 0 1 1 1 +2 1 1 1 1 1 1 1 1 1 +\N 0 1 0 0 1 0 1 1 1 +1 1 1 1 1 1 1 1 1 1 +\N 0 1 0 0 1 0 1 1 1 +0 0 1 0 0 1 0 1 1 1 +\N \N \N \N \N \N \N \N \N \N +1 0 1 1 0 1 0 1 1 0 +\N \N \N \N \N \N \N \N \N \N +0 0 0 0 0 0 0 0 0 0 +\N \N \N \N \N \N \N \N \N \N +2 0 2 2 0 2 0 2 2 0 +\N \N \N \N \N \N \N \N \N \N +1 0 1 1 0 1 0 1 1 0 +\N \N \N \N \N \N \N \N \N \N +0 0 0 0 0 0 0 0 0 0 +1 +0 +\N +1 +1 +\N +0 diff --git a/dbms/tests/queries/0_stateless/00552_or_nullable.sql b/dbms/tests/queries/0_stateless/00552_or_nullable.sql new file mode 100644 index 00000000000..1bc1fb3bbc0 --- /dev/null +++ b/dbms/tests/queries/0_stateless/00552_or_nullable.sql @@ -0,0 +1,51 @@ +SELECT + 0 OR NULL, + 1 OR NULL, + toNullable(0) OR NULL, + toNullable(1) OR NULL, + 0.0 OR NULL, + 0.1 OR NULL, + NULL OR 1 OR NULL, + 0 OR NULL OR 1 OR NULL; + +SELECT + x, + 0 OR x, + 1 OR x, + x OR x, + toNullable(0) OR x, + toNullable(1) OR x, + 0.0 OR x, + 0.1 OR x, + x OR 1 OR x, + 0 OR x OR 1 OR x +FROM (SELECT number % 2 ? number % 3 : NULL AS x FROM system.numbers LIMIT 10); + +SELECT + x, + 0 AND x, + 1 AND x, + x AND x, + toNullable(0) AND x, + toNullable(1) AND x, + 0.0 AND x, + 0.1 AND x, + x AND 1 AND x, + 0 AND x AND 1 AND x +FROM (SELECT number % 2 ? number % 3 : NULL AS x FROM system.numbers LIMIT 10); + +DROP TABLE IF EXISTS test.test; + +CREATE TABLE test.test +( + x Nullable(Int32) +) ENGINE = Log; + +INSERT INTO test.test VALUES(1), (0), (null); + +SELECT * FROM test.test; +SELECT x FROM test.test WHERE x != 0; +SELECT x FROM test.test WHERE x != 0 OR isNull(x); +SELECT x FROM test.test WHERE x != 1; + +DROP TABLE test.test;