From 82d17870f15df704821d6d13bd69e8ff69eff9af Mon Sep 17 00:00:00 2001 From: Pavel Kruglov Date: Mon, 30 Aug 2021 15:11:05 +0300 Subject: [PATCH] Avoid division by zero when denominator is Nullable --- src/Functions/FunctionBinaryArithmetic.h | 190 ++++++++++++++---- src/Functions/intDiv.cpp | 37 +++- src/Functions/modulo.cpp | 36 +++- .../02015_division_by_nullable.reference | 52 +++++ .../02015_division_by_nullable.sql | 59 ++++++ 5 files changed, 316 insertions(+), 58 deletions(-) create mode 100644 tests/queries/0_stateless/02015_division_by_nullable.reference create mode 100644 tests/queries/0_stateless/02015_division_by_nullable.sql diff --git a/src/Functions/FunctionBinaryArithmetic.h b/src/Functions/FunctionBinaryArithmetic.h index f3f1956c01c..1ee2ca029a4 100644 --- a/src/Functions/FunctionBinaryArithmetic.h +++ b/src/Functions/FunctionBinaryArithmetic.h @@ -24,6 +24,7 @@ #include #include #include +#include #include "Core/DecimalFunctions.h" #include "IFunction.h" #include "FunctionHelpers.h" @@ -51,7 +52,6 @@ #include - namespace DB { @@ -197,18 +197,41 @@ struct BinaryOperation static const constexpr bool allow_string_integer = false; template - static void NO_INLINE process(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t size) + static void NO_INLINE process(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t size, const NullMap * right_nullmap = nullptr) { - for (size_t i = 0; i < size; ++i) - if constexpr (op_case == OpCase::Vector) - c[i] = Op::template apply(a[i], b[i]); - else if constexpr (op_case == OpCase::LeftConstant) - c[i] = Op::template apply(*a, b[i]); - else + if constexpr (op_case == OpCase::RightConstant) + { + if (right_nullmap && (*right_nullmap)[0]) + return; + + for (size_t i = 0; i < size; ++i) c[i] = Op::template apply(a[i], *b); + } + else + { + if (right_nullmap) + { + for (size_t i = 0; i < size; ++i) + if (!(*right_nullmap)[i]) + apply(a, b, c, i); + } + else + for (size_t i = 0; i < size; ++i) + apply(a, b, c, i); + } } static ResultType process(A a, B b) { return Op::template apply(a, b); } + +private: + template + static inline void apply(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t i) + { + if constexpr (op_case == OpCase::Vector) + c[i] = Op::template apply(a[i], b[i]); + else + c[i] = Op::template apply(*a, b[i]); + } }; template @@ -371,7 +394,7 @@ private: public: template static void NO_INLINE process(const auto & a, const auto & b, ResultContainerType & c, - NativeResultType scale_a, NativeResultType scale_b) + NativeResultType scale_a, NativeResultType scale_b, const NullMap * right_nullmap = nullptr) { if constexpr (op_case == OpCase::LeftConstant) static_assert(!is_decimal); if constexpr (op_case == OpCase::RightConstant) static_assert(!is_decimal); @@ -428,18 +451,62 @@ public: } else if constexpr (is_division && is_decimal_b) { - for (size_t i = 0; i < size; ++i) - c[i] = applyScaledDiv( - unwrap(a, i), - unwrap(b, i), - scale_a); + if (right_nullmap) + { + if constexpr (op_case == OpCase::RightConstant) + { + if ((*right_nullmap)[0]) + return; + + for (size_t i = 0; i < size; ++i) + c[i] = applyScaledDiv( + undec(a[i]), undec(b), scale_a); + } + else + { + for (size_t i = 0; i < size; ++i) + { + if (!(*right_nullmap)[i]) + c[i] = applyScaledDiv( + unwrap(a, i), undec(b[i]), scale_a); + } + } + } + else + for (size_t i = 0; i < size; ++i) + c[i] = applyScaledDiv( + unwrap(a, i), unwrap(b, i), scale_a); return; } - for (size_t i = 0; i < size; ++i) - c[i] = apply( - unwrap(a, i), - unwrap(b, i)); + if (right_nullmap) + { + if constexpr (op_case == OpCase::RightConstant) + { + if ((*right_nullmap)[0]) + return; + + for (size_t i = 0; i < size; ++i) + c[i] = apply(undec(a[i]), undec(b)); + } + else + { + for (size_t i = 0; i < size; ++i) + { + if (!(*right_nullmap)[i]) + c[i] = apply(unwrap(a, i), undec(b[i])); + } + } + + for (size_t i = 0; i < size; ++i) + { + if (!(*right_nullmap)[i]) + c[i] = apply(unwrap(a, i), unwrap(b, i)); + } + } + else + for (size_t i = 0; i < size; ++i) + c[i] = apply(unwrap(a, i), unwrap(b, i)); } template @@ -564,7 +631,7 @@ private: using namespace traits_; using namespace impl_; -template