From 2523053b27746517ab2e2093249a8d944f322783 Mon Sep 17 00:00:00 2001 From: avogar Date: Mon, 4 Oct 2021 18:22:06 +0300 Subject: [PATCH] Remove code duplication, fix bug, add more tests --- src/Functions/FunctionBinaryArithmetic.h | 97 ++++++++----------- .../02015_division_by_nullable.reference | 84 ++++++++++------ .../02015_division_by_nullable.sql | 44 ++++++++- tests/queries/0_stateless/tmp | 4 + 4 files changed, 142 insertions(+), 87 deletions(-) create mode 100644 tests/queries/0_stateless/tmp diff --git a/src/Functions/FunctionBinaryArithmetic.h b/src/Functions/FunctionBinaryArithmetic.h index 7d63b4974c8..2d87e6a6819 100644 --- a/src/Functions/FunctionBinaryArithmetic.h +++ b/src/Functions/FunctionBinaryArithmetic.h @@ -451,62 +451,14 @@ public: } else if constexpr (is_division && is_decimal_b) { - if (right_nullmap) + processWithRightNullmapImpl(a, b, c, size, right_nullmap, [&scale_a](const auto & left, const auto & right) { - if constexpr (op_case == OpCase::RightConstant) - { - if ((*right_nullmap)[0]) - return; - - for (size_t i = 0; i < size; ++i) - c[i] = applyScaledDiv( - undec(a[i]), undec(b), scale_a); - } - else - { - for (size_t i = 0; i < size; ++i) - { - if (!(*right_nullmap)[i]) - c[i] = applyScaledDiv( - unwrap(a, i), undec(b[i]), scale_a); - } - } - } - else - for (size_t i = 0; i < size; ++i) - c[i] = applyScaledDiv( - unwrap(a, i), unwrap(b, i), scale_a); + return applyScaledDiv(left, right, scale_a); + }); return; } - if (right_nullmap) - { - if constexpr (op_case == OpCase::RightConstant) - { - if ((*right_nullmap)[0]) - return; - - for (size_t i = 0; i < size; ++i) - c[i] = apply(undec(a[i]), undec(b)); - } - else - { - for (size_t i = 0; i < size; ++i) - { - if (!(*right_nullmap)[i]) - c[i] = apply(unwrap(a, i), undec(b[i])); - } - } - - for (size_t i = 0; i < size; ++i) - { - if (!(*right_nullmap)[i]) - c[i] = apply(unwrap(a, i), unwrap(b, i)); - } - } - else - for (size_t i = 0; i < size; ++i) - c[i] = apply(unwrap(a, i), unwrap(b, i)); + processWithRightNullmapImpl(a, b, c, size, right_nullmap, [](const auto & left, const auto & right){ return apply(left, right); }); } template @@ -527,6 +479,33 @@ public: } private: + template + static inline void processWithRightNullmapImpl(const auto & a, const auto & b, ResultContainerType & c, size_t size, const NullMap * right_nullmap, ApplyFunc apply_func) + { + if (right_nullmap) + { + if constexpr (op_case == OpCase::RightConstant) + { + if ((*right_nullmap)[0]) + return; + + for (size_t i = 0; i < size; ++i) + c[i] = apply_func(undec(a[i]), undec(b)); + } + else + { + for (size_t i = 0; i < size; ++i) + { + if (!(*right_nullmap)[i]) + c[i] = apply_func(unwrap(a, i), undec(b[i])); + } + } + } + else + for (size_t i = 0; i < size; ++i) + c[i] = apply_func(unwrap(a, i), unwrap(b, i)); + } + static constexpr bool is_plus_minus = IsOperation::plus || IsOperation::minus; static constexpr bool is_multiply = IsOperation::multiply; @@ -1573,26 +1552,28 @@ public: } /// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval. - if (auto function_builder - = getFunctionForIntervalArithmetic(arguments[0].type, arguments[1].type, context)) + if (auto function_builder = getFunctionForIntervalArithmetic(arguments[0].type, arguments[1].type, context)) { return executeDateTimeIntervalPlusMinus(arguments, result_type, input_rows_count, function_builder); } /// Special case when the function is plus, minus or multiply, both arguments are tuples. - if (auto function_builder - = getFunctionForTupleArithmetic(arguments[0].type, arguments[1].type, context)) + if (auto function_builder = getFunctionForTupleArithmetic(arguments[0].type, arguments[1].type, context)) { return function_builder->build(arguments)->execute(arguments, result_type, input_rows_count); } /// Special case when the function is multiply or divide, one of arguments is Tuple and another is Number. - if (auto function_builder - = getFunctionForTupleAndNumberArithmetic(arguments[0].type, arguments[1].type, context)) + if (auto function_builder = getFunctionForTupleAndNumberArithmetic(arguments[0].type, arguments[1].type, context)) { return executeTupleNumberOperator(arguments, result_type, input_rows_count, function_builder); } + return executeImpl2(arguments, result_type, input_rows_count); + } + + ColumnPtr executeImpl2(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count, const NullMap * right_nullmap = nullptr) const + { const auto & left_argument = arguments[0]; const auto & right_argument = arguments[1]; diff --git a/tests/queries/0_stateless/02015_division_by_nullable.reference b/tests/queries/0_stateless/02015_division_by_nullable.reference index 92a8ceb6dfc..d85e2a48b71 100644 --- a/tests/queries/0_stateless/02015_division_by_nullable.reference +++ b/tests/queries/0_stateless/02015_division_by_nullable.reference @@ -2,51 +2,79 @@ \N \N \N -\N -\N -\N -\N -\N -\N -\N -\N -\N -\N -\N -\N -\N -\N -\N -\N -\N -\N -\N -\N -\N -\N -\N -\N 1 +1 +1 +1 +\N +\N +\N +\N +1 +1 +1 +1 +\N +\N +\N +\N +1 +1 +1 +1 +\N +\N +\N +\N +1 +1 +1 +1 +\N +\N +\N +\N 0 0 0 0 +\N +\N +\N +\N +1 +1 +1 +1 +\N +\N +\N +\N +0 +0 +0 0 1 +\N 0 0 +\N +0 1 -0 +\N 0 1 -0 +\N 0 1 -0 +\N 0 1 -0 +\N 0 1 +\N 0 +1 +\N 0 diff --git a/tests/queries/0_stateless/02015_division_by_nullable.sql b/tests/queries/0_stateless/02015_division_by_nullable.sql index a3cf27f81c2..16a01061070 100644 --- a/tests/queries/0_stateless/02015_division_by_nullable.sql +++ b/tests/queries/0_stateless/02015_division_by_nullable.sql @@ -4,44 +4,86 @@ SELECT 1 / CAST(materialize(NULL), 'Nullable(Decimal(7, 2))'); SELECT materialize(1) / CAST(materialize(NULL), 'Nullable(Decimal(7, 2))'); +SELECT 1 / CAST(1, 'Nullable(Decimal(7, 2))'); +SELECT materialize(1) / CAST(1, 'Nullable(Decimal(7, 2))'); +SELECT 1 / CAST(materialize(1), 'Nullable(Decimal(7, 2))'); +SELECT materialize(1) / CAST(materialize(1), 'Nullable(Decimal(7, 2))'); + + SELECT intDiv(1, CAST(NULL, 'Nullable(Decimal(7, 2))')); SELECT intDiv(materialize(1), CAST(NULL, 'Nullable(Decimal(7, 2))')); SELECT intDiv(1, CAST(materialize(NULL), 'Nullable(Decimal(7, 2))')); SELECT intDiv(materialize(1), CAST(materialize(NULL), 'Nullable(Decimal(7, 2))')); +SELECT intDiv(1, CAST(1, 'Nullable(Decimal(7, 2))')); +SELECT intDiv(materialize(1), CAST(1, 'Nullable(Decimal(7, 2))')); +SELECT intDiv(1, CAST(materialize(1), 'Nullable(Decimal(7, 2))')); +SELECT intDiv(materialize(1), CAST(materialize(1), 'Nullable(Decimal(7, 2))')); + + SELECT toDecimal32(1, 2) / CAST(NULL, 'Nullable(UInt32)'); SELECT materialize(toDecimal32(1, 2)) / CAST(NULL, 'Nullable(UInt32)'); SELECT toDecimal32(1, 2) / CAST(materialize(NULL), 'Nullable(UInt32)'); SELECT materialize(toDecimal32(1, 2)) / CAST(materialize(NULL), 'Nullable(UInt32)'); +SELECT toDecimal32(1, 2) / CAST(1, 'Nullable(UInt32)'); +SELECT materialize(toDecimal32(1, 2)) / CAST(1, 'Nullable(UInt32)'); +SELECT toDecimal32(1, 2) / CAST(materialize(1), 'Nullable(UInt32)'); +SELECT materialize(toDecimal32(1, 2)) / CAST(materialize(1), 'Nullable(UInt32)'); + + SELECT intDiv(1, CAST(NULL, 'Nullable(UInt32)')); SELECT intDiv(materialize(1), CAST(NULL, 'Nullable(UInt32)')); SELECT intDiv(1, CAST(materialize(NULL), 'Nullable(UInt32)')); SELECT intDiv(materialize(1), CAST(materialize(NULL), 'Nullable(UInt32)')); +SELECT intDiv(1, CAST(1, 'Nullable(UInt32)')); +SELECT intDiv(materialize(1), CAST(1, 'Nullable(UInt32)')); +SELECT intDiv(1, CAST(materialize(1), 'Nullable(UInt32)')); +SELECT intDiv(materialize(1), CAST(materialize(1), 'Nullable(UInt32)')); + + SELECT 1 % CAST(NULL, 'Nullable(UInt32)'); SELECT materialize(1) % CAST(NULL, 'Nullable(UInt32)'); SELECT 1 % CAST(materialize(NULL), 'Nullable(UInt32)'); SELECT materialize(1) % CAST(materialize(NULL), 'Nullable(UInt32)'); +SELECT 1 % CAST(1, 'Nullable(UInt32)'); +SELECT materialize(1) % CAST(1, 'Nullable(UInt32)'); +SELECT 1 % CAST(materialize(1), 'Nullable(UInt32)'); +SELECT materialize(1) % CAST(materialize(1), 'Nullable(UInt32)'); + + SELECT intDiv(1, CAST(NULL, 'Nullable(Float32)')); SELECT intDiv(materialize(1), CAST(NULL, 'Nullable(Float32)')); SELECT intDiv(1, CAST(materialize(NULL), 'Nullable(Float32)')); SELECT intDiv(materialize(1), CAST(materialize(NULL), 'Nullable(Float32)')); +SELECT intDiv(1, CAST(1, 'Nullable(Float32)')); +SELECT intDiv(materialize(1), CAST(1, 'Nullable(Float32)')); +SELECT intDiv(1, CAST(materialize(1), 'Nullable(Float32)')); +SELECT intDiv(materialize(1), CAST(materialize(1), 'Nullable(Float32)')); + + SELECT 1 % CAST(NULL, 'Nullable(Float32)'); SELECT materialize(1) % CAST(NULL, 'Nullable(Float32)'); SELECT 1 % CAST(materialize(NULL), 'Nullable(Float32)'); SELECT materialize(1) % CAST(materialize(NULL), 'Nullable(Float32)'); +SELECT 1 % CAST(1, 'Nullable(Float32)'); +SELECT materialize(1) % CAST(1, 'Nullable(Float32)'); +SELECT 1 % CAST(materialize(1), 'Nullable(Float32)'); +SELECT materialize(1) % CAST(materialize(1), 'Nullable(Float32)'); + + DROP TABLE IF EXISTS nullable_division; -CREATE TABLE nullable_division (x UInt32, y UInt32, a Decimal(7, 2), b Decimal(7, 2)) ENGINE=MergeTree() order by x; +CREATE TABLE nullable_division (x UInt32, y Nullable(UInt32), a Decimal(7, 2), b Nullable(Decimal(7, 2))) ENGINE=MergeTree() order by x; INSERT INTO nullable_division VALUES (1, 1, 1, 1), (1, NULL, 1, NULL), (1, 0, 1, 0); SELECT if(y = 0, 0, intDiv(x, y)) from nullable_division; diff --git a/tests/queries/0_stateless/tmp b/tests/queries/0_stateless/tmp new file mode 100644 index 00000000000..03e803e575b --- /dev/null +++ b/tests/queries/0_stateless/tmp @@ -0,0 +1,4 @@ +-1 1 -1000 1000 -10000000 1000000 -1000000000 1000000000 123.123 123123123.12312312 Some string fixed Some data 2000-01-06 2000-06-01 19:42:42 2000-04-01 11:21:33.123 +1 (2,(3,4)) (((5))) +1 [1,2,3] [[[1,2,3],[4,5,6]],[[7,8,9],[]],[]] +1 ((2,[[3,4],[5,6],[]]),[([[(7,8),(9,10)],[(11,12),(13,14)],[]],[([15,16,17]),([])])])