Avoid division by zero when denominator is Nullable

This commit is contained in:
Pavel Kruglov 2021-08-30 15:11:05 +03:00 committed by avogar
parent 362e84a336
commit 82d17870f1
5 changed files with 316 additions and 58 deletions

View File

@ -24,6 +24,7 @@
#include <Columns/ColumnString.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Columns/ColumnNullable.h>
#include "Core/DecimalFunctions.h"
#include "IFunction.h"
#include "FunctionHelpers.h"
@ -51,7 +52,6 @@
#include <cassert>
namespace DB
{
@ -197,18 +197,41 @@ struct BinaryOperation
static const constexpr bool allow_string_integer = false;
template <OpCase op_case>
static void NO_INLINE process(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t size)
static void NO_INLINE process(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t size, const NullMap * right_nullmap = nullptr)
{
for (size_t i = 0; i < size; ++i)
if constexpr (op_case == OpCase::Vector)
c[i] = Op::template apply<ResultType>(a[i], b[i]);
else if constexpr (op_case == OpCase::LeftConstant)
c[i] = Op::template apply<ResultType>(*a, b[i]);
else
if constexpr (op_case == OpCase::RightConstant)
{
if (right_nullmap && (*right_nullmap)[0])
return;
for (size_t i = 0; i < size; ++i)
c[i] = Op::template apply<ResultType>(a[i], *b);
}
else
{
if (right_nullmap)
{
for (size_t i = 0; i < size; ++i)
if (!(*right_nullmap)[i])
apply<op_case>(a, b, c, i);
}
else
for (size_t i = 0; i < size; ++i)
apply<op_case>(a, b, c, i);
}
}
/// Scalar overload: applies Op to a single pair of values and returns the result.
static ResultType process(A a, B b) { return Op::template apply<ResultType>(a, b); }
private:
/// Computes one result element c[i] from a column on the right-hand side.
/// For OpCase::Vector the left operand is read per row (a[i]); for any other case
/// `a` is treated as a pointer to a single constant value.
template <OpCase op_case>
static inline void apply(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t i)
{
    /// Reading a[0] is the same as dereferencing the pointer to the constant.
    constexpr bool left_is_column = (op_case == OpCase::Vector);
    const size_t left_index = left_is_column ? i : 0;
    c[i] = Op::template apply<ResultType>(a[left_index], b[i]);
}
};
template <typename B, typename Op>
@ -371,7 +394,7 @@ private:
public:
template <OpCase op_case, bool is_decimal_a, bool is_decimal_b>
static void NO_INLINE process(const auto & a, const auto & b, ResultContainerType & c,
NativeResultType scale_a, NativeResultType scale_b)
NativeResultType scale_a, NativeResultType scale_b, const NullMap * right_nullmap = nullptr)
{
if constexpr (op_case == OpCase::LeftConstant) static_assert(!is_decimal<decltype(a)>);
if constexpr (op_case == OpCase::RightConstant) static_assert(!is_decimal<decltype(b)>);
@ -428,18 +451,62 @@ public:
}
else if constexpr (is_division && is_decimal_b)
{
for (size_t i = 0; i < size; ++i)
c[i] = applyScaledDiv<is_decimal_a>(
unwrap<op_case, OpCase::LeftConstant>(a, i),
unwrap<op_case, OpCase::RightConstant>(b, i),
scale_a);
if (right_nullmap)
{
if constexpr (op_case == OpCase::RightConstant)
{
if ((*right_nullmap)[0])
return;
for (size_t i = 0; i < size; ++i)
c[i] = applyScaledDiv<is_decimal_a>(
undec(a[i]), undec(b), scale_a);
}
else
{
for (size_t i = 0; i < size; ++i)
{
if (!(*right_nullmap)[i])
c[i] = applyScaledDiv<is_decimal_a>(
unwrap<op_case, OpCase::LeftConstant>(a, i), undec(b[i]), scale_a);
}
}
}
else
for (size_t i = 0; i < size; ++i)
c[i] = applyScaledDiv<is_decimal_a>(
unwrap<op_case, OpCase::LeftConstant>(a, i), unwrap<op_case, OpCase::RightConstant>(b, i), scale_a);
return;
}
for (size_t i = 0; i < size; ++i)
c[i] = apply(
unwrap<op_case, OpCase::LeftConstant>(a, i),
unwrap<op_case, OpCase::RightConstant>(b, i));
/// Generic (non-scaled) path.
/// When a null map for the right argument is supplied, rows whose right value is NULL are skipped,
/// so the operation is never applied to the value stored under a NULL; c[i] is left untouched for
/// such rows. Every non-NULL row is computed exactly once.
if (right_nullmap)
{
    if constexpr (op_case == OpCase::RightConstant)
    {
        /// Constant right argument: only entry 0 of the null map is consulted,
        /// and a NULL constant means there is nothing to compute at all.
        if ((*right_nullmap)[0])
            return;

        for (size_t i = 0; i < size; ++i)
            c[i] = apply(undec(a[i]), undec(b));
    }
    else
    {
        /// Right argument is a column: the null map is indexed per row.
        for (size_t i = 0; i < size; ++i)
        {
            if (!(*right_nullmap)[i])
                c[i] = apply(unwrap<op_case, OpCase::LeftConstant>(a, i), undec(b[i]));
        }
    }
}
else
    for (size_t i = 0; i < size; ++i)
        c[i] = apply(unwrap<op_case, OpCase::LeftConstant>(a, i), unwrap<op_case, OpCase::RightConstant>(b, i));
}
template <bool is_decimal_a, bool is_decimal_b, class A, class B>
@ -564,7 +631,7 @@ private:
using namespace traits_;
using namespace impl_;
template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true, bool valid_on_float_arguments = true>
template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true, bool valid_on_float_arguments = true, bool special_implementation_for_nulls = false>
class FunctionBinaryArithmetic : public IFunction
{
static constexpr const bool is_plus = IsOperation<Op>::plus;
@ -884,12 +951,12 @@ class FunctionBinaryArithmetic : public IFunction
}
template <OpCase op_case, bool left_decimal, bool right_decimal, typename OpImpl, typename OpImplCheck>
void helperInvokeEither(const auto& left, const auto& right, auto& vec_res, auto scale_a, auto scale_b) const
void helperInvokeEither(const auto& left, const auto& right, auto& vec_res, auto scale_a, auto scale_b, const NullMap * right_nullmap) const
{
if (check_decimal_overflow)
OpImplCheck::template process<op_case, left_decimal, right_decimal>(left, right, vec_res, scale_a, scale_b);
OpImplCheck::template process<op_case, left_decimal, right_decimal>(left, right, vec_res, scale_a, scale_b, right_nullmap);
else
OpImpl::template process<op_case, left_decimal, right_decimal>(left, right, vec_res, scale_a, scale_b);
OpImpl::template process<op_case, left_decimal, right_decimal>(left, right, vec_res, scale_a, scale_b, right_nullmap);
}
template <class LeftDataType, class RightDataType, class ResultDataType>
@ -897,7 +964,7 @@ class FunctionBinaryArithmetic : public IFunction
const auto & left, const auto & right,
const ColumnConst * const col_left_const, const ColumnConst * const col_right_const,
const auto * const col_left, const auto * const col_right,
size_t col_left_size) const
size_t col_left_size, const NullMap * right_nullmap) const
{
using T0 = typename LeftDataType::FieldType;
using T1 = typename RightDataType::FieldType;
@ -979,9 +1046,10 @@ class FunctionBinaryArithmetic : public IFunction
const NativeResultType const_a = helperGetOrConvert<T0, ResultDataType>(col_left_const, left);
const NativeResultType const_b = helperGetOrConvert<T1, ResultDataType>(col_right_const, right);
const ResultType res = check_decimal_overflow
? OpImplCheck::template process<left_is_decimal, right_is_decimal>(const_a, const_b, scale_a, scale_b)
: OpImpl::template process<left_is_decimal, right_is_decimal>(const_a, const_b, scale_a, scale_b);
ResultType res = {};
if (!right_nullmap || !(*right_nullmap)[0])
res = check_decimal_overflow ? OpImplCheck::template process<left_is_decimal, right_is_decimal>(const_a, const_b, scale_a, scale_b)
: OpImpl::template process<left_is_decimal, right_is_decimal>(const_a, const_b, scale_a, scale_b);
if constexpr (result_is_decimal)
return ResultDataType(type.getPrecision(), type.getScale()).createColumnConst(
@ -1001,21 +1069,21 @@ class FunctionBinaryArithmetic : public IFunction
if (col_left && col_right)
{
helperInvokeEither<OpCase::Vector, left_is_decimal, right_is_decimal, OpImpl, OpImplCheck>(
col_left->getData(), col_right->getData(), vec_res, scale_a, scale_b);
col_left->getData(), col_right->getData(), vec_res, scale_a, scale_b, right_nullmap);
}
else if (col_left_const && col_right)
{
const NativeResultType const_a = helperGetOrConvert<T0, ResultDataType>(col_left_const, left);
helperInvokeEither<OpCase::LeftConstant, left_is_decimal, right_is_decimal, OpImpl, OpImplCheck>(
const_a, col_right->getData(), vec_res, scale_a, scale_b);
const_a, col_right->getData(), vec_res, scale_a, scale_b, right_nullmap);
}
else if (col_left && col_right_const)
{
const NativeResultType const_b = helperGetOrConvert<T1, ResultDataType>(col_right_const, right);
helperInvokeEither<OpCase::RightConstant, left_is_decimal, right_is_decimal, OpImpl, OpImplCheck>(
col_left->getData(), const_b, vec_res, scale_a, scale_b);
col_left->getData(), const_b, vec_res, scale_a, scale_b, right_nullmap);
}
else
return nullptr;
@ -1036,6 +1104,11 @@ public:
size_t getNumberOfArguments() const override { return 2; }
/// Returns false when this instantiation was built with special_implementation_for_nulls,
/// and true otherwise.
bool useDefaultImplementationForNulls() const override
{
    if constexpr (special_implementation_for_nulls)
        return false;
    else
        return true;
}
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & arguments) const override
{
return ((IsOperation<Op>::div_int || IsOperation<Op>::modulo) && !arguments[1].is_const)
@ -1385,7 +1458,7 @@ public:
}
template <typename A, typename B>
ColumnPtr executeNumeric(const ColumnsWithTypeAndName & arguments, const A & left, const B & right) const
ColumnPtr executeNumeric(const ColumnsWithTypeAndName & arguments, const A & left, const B & right, const NullMap * right_nullmap) const
{
using LeftDataType = std::decay_t<decltype(left)>;
using RightDataType = std::decay_t<decltype(right)>;
@ -1420,7 +1493,8 @@ public:
left, right,
col_left_const, col_right_const,
col_left, col_right,
col_left_size);
col_left_size,
right_nullmap);
}
else // can't avoid else and another indentation level, otherwise the compiler would try to instantiate
// ColVecResult for Decimals which would lead to a compile error.
@ -1430,7 +1504,7 @@ public:
/// non-vector result
if (col_left_const && col_right_const)
{
const auto res = OpImpl::process(
const auto res = right_nullmap && (*right_nullmap)[0] ? ResultType() : OpImpl::process(
col_left_const->template getValue<T0>(),
col_right_const->template getValue<T1>());
@ -1448,7 +1522,8 @@ public:
col_left->getData().data(),
col_right->getData().data(),
vec_res.data(),
vec_res.size());
vec_res.size(),
right_nullmap);
}
else if (col_left_const && col_right)
{
@ -1458,7 +1533,8 @@ public:
&value,
col_right->getData().data(),
vec_res.data(),
vec_res.size());
vec_res.size(),
right_nullmap);
}
else if (col_left && col_right_const)
{
@ -1468,7 +1544,8 @@ public:
col_left->getData().data(),
&value,
vec_res.data(),
vec_res.size());
vec_res.size(),
right_nullmap);
}
else
return nullptr;
@ -1515,6 +1592,22 @@ public:
const auto & left_argument = arguments[0];
const auto & right_argument = arguments[1];
/// Special NULL handling for the right argument. Taken only when no null map was passed in;
/// the nested call below passes one, so this branch is not entered again for that call.
if (special_implementation_for_nulls && !right_nullmap)
{
assert(right_argument.type->isNullable());
/// The right column is either a constant holding a ColumnNullable or a ColumnNullable itself;
/// pick the matching accessor to reach the ColumnNullable.
bool is_const = checkColumnConst<ColumnNullable>(right_argument.column.get());
const ColumnNullable * nullable_column = is_const ? checkAndGetColumnConstData<ColumnNullable>(right_argument.column.get())
: checkAndGetColumn<ColumnNullable>(*right_argument.column);
assert(nullable_column);
/// NOTE(review): for a constant argument this is presumably the null map of the single nested value,
/// which would explain why the RightConstant code paths only look at entry 0 - confirm.
const auto & null_bytemap = nullable_column->getNullMapData();
/// Run the operation on the nested (non-Nullable) columns with the non-Nullable result type,
/// handing over the null map, then wrap the result back into Nullable.
auto res = executeImpl2(createBlockWithNestedColumns(arguments), removeNullable(result_type), input_rows_count, &null_bytemap);
return wrapInNullable(res, arguments, result_type, input_rows_count);
}
const auto * const left_generic = left_argument.type.get();
const auto * const right_generic = right_argument.type.get();
ColumnPtr res;
@ -1548,7 +1641,7 @@ public:
return (res = executeStringInteger<ColumnString>(arguments, left, right)) != nullptr;
}
else
return (res = executeNumeric(arguments, left, right)) != nullptr;
return (res = executeNumeric(arguments, left, right, right_nullmap)) != nullptr;
});
if (!valid)
@ -1619,11 +1712,11 @@ public:
};
template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true, bool valid_on_float_arguments = true>
class FunctionBinaryArithmeticWithConstants : public FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments>
template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true, bool valid_on_float_arguments = true, bool special_implementation_for_nulls = false>
class FunctionBinaryArithmeticWithConstants : public FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments, special_implementation_for_nulls>
{
public:
using Base = FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments>;
using Base = FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments, special_implementation_for_nulls>;
using Monotonicity = typename Base::Monotonicity;
static FunctionPtr create(
@ -1822,22 +1915,35 @@ public:
FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const override
{
bool special_implementation_for_nulls = arguments[1].type->isNullable()
&& (IsOperation<Op>::div_int || IsOperation<Op>::modulo
|| (IsOperation<Op>::div_floating
&& (isDecimalOrNullableDecimal(arguments[0].type) || isDecimalOrNullableDecimal(arguments[1].type))));
/// More efficient specialization for two numeric arguments.
if (arguments.size() == 2
&& ((arguments[0].column && isColumnConst(*arguments[0].column))
|| (arguments[1].column && isColumnConst(*arguments[1].column))))
{
auto func = special_implementation_for_nulls ? FunctionBinaryArithmeticWithConstants<Op, Name, valid_on_default_arguments, valid_on_float_arguments, true>::create(
arguments[0], arguments[1], return_type, context)
: FunctionBinaryArithmeticWithConstants<Op, Name, valid_on_default_arguments, valid_on_float_arguments, false>::create(
arguments[0], arguments[1], return_type, context);
return std::make_unique<FunctionToFunctionBaseAdaptor>(
FunctionBinaryArithmeticWithConstants<Op, Name, valid_on_default_arguments, valid_on_float_arguments>::create(
arguments[0], arguments[1], return_type, context),
func,
collections::map<DataTypes>(arguments, [](const auto & elem) { return elem.type; }),
return_type);
}
auto func = special_implementation_for_nulls
? FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments, true>::create(context)
: FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments, false>::create(context);
return std::make_unique<FunctionToFunctionBaseAdaptor>(
FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments>::create(context),
func,
collections::map<DataTypes>(arguments, [](const auto & elem) { return elem.type; }),
return_type);
}
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override

View File

@ -26,16 +26,27 @@ struct DivideIntegralByConstantImpl
static const constexpr bool allow_string_integer = false;
template <OpCase op_case>
static void NO_INLINE process(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t size)
static void NO_INLINE process(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t size, const NullMap * right_nullmap)
{
if constexpr (op_case == OpCase::Vector)
for (size_t i = 0; i < size; ++i)
c[i] = Op::template apply<ResultType>(a[i], b[i]);
else if constexpr (op_case == OpCase::LeftConstant)
for (size_t i = 0; i < size; ++i)
c[i] = Op::template apply<ResultType>(*a, b[i]);
else
if constexpr (op_case == OpCase::RightConstant)
{
if (right_nullmap && (*right_nullmap)[0])
return;
vectorConstant(a, *b, c, size);
}
else
{
if (right_nullmap)
{
for (size_t i = 0; i < size; ++i)
if (!(*right_nullmap)[i])
apply<op_case>(a, b, c, i);
}
else
for (size_t i = 0; i < size; ++i)
apply<op_case>(a, b, c, i);
}
}
/// Scalar overload: applies Op to a single pair of values and returns the result.
static ResultType process(A a, B b) { return Op::template apply<ResultType>(a, b); }
@ -69,6 +80,16 @@ struct DivideIntegralByConstantImpl
divideImpl(a_pos, b, c_pos, size);
}
private:
/// Computes one result element c[i] from a column on the right-hand side.
/// For OpCase::Vector the left operand is read per row (a[i]); for any other case
/// `a` is treated as a pointer to a single constant value.
template <OpCase op_case>
static inline void apply(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t i)
{
    /// Reading a[0] is the same as dereferencing the pointer to the constant.
    constexpr bool left_is_column = (op_case == OpCase::Vector);
    const size_t left_index = left_is_column ? i : 0;
    c[i] = Op::template apply<ResultType>(a[left_index], b[i]);
}
};
/** Specializations are specified for dividing numbers of the type UInt64, UInt32, Int64, Int32 by the numbers of the same sign.

View File

@ -30,16 +30,26 @@ struct ModuloByConstantImpl
static const constexpr bool allow_string_integer = false;
template <OpCase op_case>
static void NO_INLINE process(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t size)
static void NO_INLINE process(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t size, const NullMap * right_nullmap)
{
if constexpr (op_case == OpCase::Vector)
for (size_t i = 0; i < size; ++i)
c[i] = Op::template apply<ResultType>(a[i], b[i]);
else if constexpr (op_case == OpCase::LeftConstant)
for (size_t i = 0; i < size; ++i)
c[i] = Op::template apply<ResultType>(*a, b[i]);
else
if constexpr (op_case == OpCase::RightConstant)
{
if (right_nullmap && (*right_nullmap)[0])
return;
vectorConstant(a, *b, c, size);
}
else
{
if (right_nullmap)
{
for (size_t i = 0; i < size; ++i)
if (!(*right_nullmap)[i])
apply<op_case>(a, b, c, i);
}
else
for (size_t i = 0; i < size; ++i)
apply<op_case>(a, b, c, i);
}
}
/// Scalar overload: applies Op to a single pair of values and returns the result.
static ResultType process(A a, B b) { return Op::template apply<ResultType>(a, b); }
@ -95,6 +105,16 @@ struct ModuloByConstantImpl
dst[i] = src[i] & mask;
}
}
private:
/// Computes one result element c[i] from a column on the right-hand side.
/// For OpCase::Vector the left operand is read per row (a[i]); for any other case
/// `a` is treated as a pointer to a single constant value.
template <OpCase op_case>
static inline void apply(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t i)
{
    /// Reading a[0] is the same as dereferencing the pointer to the constant.
    constexpr bool left_is_column = (op_case == OpCase::Vector);
    const size_t left_index = left_is_column ? i : 0;
    c[i] = Op::template apply<ResultType>(a[left_index], b[i]);
}
};
template <typename A, typename B>

View File

@ -0,0 +1,52 @@
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
1
0
0
0
0
0
1
0
0
1
0
0
1
0
0
1
0
0
1
0
0
1
0
0

View File

@ -0,0 +1,59 @@
SELECT 1 / CAST(NULL, 'Nullable(Decimal(7, 2))');
SELECT materialize(1) / CAST(NULL, 'Nullable(Decimal(7, 2))');
SELECT 1 / CAST(materialize(NULL), 'Nullable(Decimal(7, 2))');
SELECT materialize(1) / CAST(materialize(NULL), 'Nullable(Decimal(7, 2))');
SELECT intDiv(1, CAST(NULL, 'Nullable(Decimal(7, 2))'));
SELECT intDiv(materialize(1), CAST(NULL, 'Nullable(Decimal(7, 2))'));
SELECT intDiv(1, CAST(materialize(NULL), 'Nullable(Decimal(7, 2))'));
SELECT intDiv(materialize(1), CAST(materialize(NULL), 'Nullable(Decimal(7, 2))'));
SELECT toDecimal32(1, 2) / CAST(NULL, 'Nullable(UInt32)');
SELECT materialize(toDecimal32(1, 2)) / CAST(NULL, 'Nullable(UInt32)');
SELECT toDecimal32(1, 2) / CAST(materialize(NULL), 'Nullable(UInt32)');
SELECT materialize(toDecimal32(1, 2)) / CAST(materialize(NULL), 'Nullable(UInt32)');
SELECT intDiv(1, CAST(NULL, 'Nullable(UInt32)'));
SELECT intDiv(materialize(1), CAST(NULL, 'Nullable(UInt32)'));
SELECT intDiv(1, CAST(materialize(NULL), 'Nullable(UInt32)'));
SELECT intDiv(materialize(1), CAST(materialize(NULL), 'Nullable(UInt32)'));
SELECT 1 % CAST(NULL, 'Nullable(UInt32)');
SELECT materialize(1) % CAST(NULL, 'Nullable(UInt32)');
SELECT 1 % CAST(materialize(NULL), 'Nullable(UInt32)');
SELECT materialize(1) % CAST(materialize(NULL), 'Nullable(UInt32)');
SELECT intDiv(1, CAST(NULL, 'Nullable(Float32)'));
SELECT intDiv(materialize(1), CAST(NULL, 'Nullable(Float32)'));
SELECT intDiv(1, CAST(materialize(NULL), 'Nullable(Float32)'));
SELECT intDiv(materialize(1), CAST(materialize(NULL), 'Nullable(Float32)'));
SELECT 1 % CAST(NULL, 'Nullable(Float32)');
SELECT materialize(1) % CAST(NULL, 'Nullable(Float32)');
SELECT 1 % CAST(materialize(NULL), 'Nullable(Float32)');
SELECT materialize(1) % CAST(materialize(NULL), 'Nullable(Float32)');
-- Column-based cases: divide/modulo by table columns, with the zero divisor guarded by if().
DROP TABLE IF EXISTS nullable_division;
-- NOTE(review): y and b are declared as plain UInt32 / Decimal(7, 2), not Nullable, although a NULL
-- is inserted into both below. The NULLs are presumably stored as default values, so these queries
-- would not exercise a Nullable denominator column - confirm whether Nullable(UInt32) and
-- Nullable(Decimal(7, 2)) were intended (the reference output would then need updating).
CREATE TABLE nullable_division (x UInt32, y UInt32, a Decimal(7, 2), b Decimal(7, 2)) ENGINE=MergeTree() order by x;
INSERT INTO nullable_division VALUES (1, 1, 1, 1), (1, NULL, 1, NULL), (1, 0, 1, 0);
-- Integer denominator column y.
SELECT if(y = 0, 0, intDiv(x, y)) from nullable_division;
SELECT if(y = 0, 0, x % y) from nullable_division;
SELECT if(y = 0, 0, intDiv(a, y)) from nullable_division;
SELECT if(y = 0, 0, a / y) from nullable_division;
-- Decimal denominator column b.
SELECT if(b = 0, 0, intDiv(a, b)) from nullable_division;
SELECT if(b = 0, 0, a / b) from nullable_division;
SELECT if(b = 0, 0, intDiv(x, b)) from nullable_division;
SELECT if(b = 0, 0, x / b) from nullable_division;
DROP TABLE nullable_division;