Merge pull request #28352 from Avogar/div-null

Avoid division by zero when denominator is Nullable
This commit is contained in:
Anton Popov 2021-10-07 14:47:53 +03:00 committed by GitHub
commit 4bc14dedfb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 386 additions and 64 deletions

View File

@ -24,6 +24,7 @@
#include <Columns/ColumnString.h> #include <Columns/ColumnString.h>
#include <Columns/ColumnConst.h> #include <Columns/ColumnConst.h>
#include <Columns/ColumnAggregateFunction.h> #include <Columns/ColumnAggregateFunction.h>
#include <Columns/ColumnNullable.h>
#include "Core/DecimalFunctions.h" #include "Core/DecimalFunctions.h"
#include "IFunction.h" #include "IFunction.h"
#include "FunctionHelpers.h" #include "FunctionHelpers.h"
@ -51,7 +52,6 @@
#include <cassert> #include <cassert>
namespace DB namespace DB
{ {
@ -197,18 +197,43 @@ struct BinaryOperation
static const constexpr bool allow_string_integer = false; static const constexpr bool allow_string_integer = false;
template <OpCase op_case> template <OpCase op_case>
static void NO_INLINE process(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t size) static void NO_INLINE process(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t size, const NullMap * right_nullmap = nullptr)
{ {
for (size_t i = 0; i < size; ++i) if constexpr (op_case == OpCase::RightConstant)
if constexpr (op_case == OpCase::Vector) {
c[i] = Op::template apply<ResultType>(a[i], b[i]); if (right_nullmap && (*right_nullmap)[0])
else if constexpr (op_case == OpCase::LeftConstant) return;
c[i] = Op::template apply<ResultType>(*a, b[i]);
else for (size_t i = 0; i < size; ++i)
c[i] = Op::template apply<ResultType>(a[i], *b); c[i] = Op::template apply<ResultType>(a[i], *b);
}
else
{
if (right_nullmap)
{
for (size_t i = 0; i < size; ++i)
if ((*right_nullmap)[i])
c[i] = ResultType();
else
apply<op_case>(a, b, c, i);
}
else
for (size_t i = 0; i < size; ++i)
apply<op_case>(a, b, c, i);
}
} }
static ResultType process(A a, B b) { return Op::template apply<ResultType>(a, b); } static ResultType process(A a, B b) { return Op::template apply<ResultType>(a, b); }
private:
template <OpCase op_case>
static inline void apply(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t i)
{
if constexpr (op_case == OpCase::Vector)
c[i] = Op::template apply<ResultType>(a[i], b[i]);
else
c[i] = Op::template apply<ResultType>(*a, b[i]);
}
}; };
template <typename B, typename Op> template <typename B, typename Op>
@ -371,7 +396,7 @@ private:
public: public:
template <OpCase op_case, bool is_decimal_a, bool is_decimal_b> template <OpCase op_case, bool is_decimal_a, bool is_decimal_b>
static void NO_INLINE process(const auto & a, const auto & b, ResultContainerType & c, static void NO_INLINE process(const auto & a, const auto & b, ResultContainerType & c,
NativeResultType scale_a, NativeResultType scale_b) NativeResultType scale_a, NativeResultType scale_b, const NullMap * right_nullmap = nullptr)
{ {
if constexpr (op_case == OpCase::LeftConstant) static_assert(!is_decimal<decltype(a)>); if constexpr (op_case == OpCase::LeftConstant) static_assert(!is_decimal<decltype(a)>);
if constexpr (op_case == OpCase::RightConstant) static_assert(!is_decimal<decltype(b)>); if constexpr (op_case == OpCase::RightConstant) static_assert(!is_decimal<decltype(b)>);
@ -428,18 +453,14 @@ public:
} }
else if constexpr (is_division && is_decimal_b) else if constexpr (is_division && is_decimal_b)
{ {
for (size_t i = 0; i < size; ++i) processWithRightNullmapImpl<op_case>(a, b, c, size, right_nullmap, [&scale_a](const auto & left, const auto & right)
c[i] = applyScaledDiv<is_decimal_a>( {
unwrap<op_case, OpCase::LeftConstant>(a, i), return applyScaledDiv<is_decimal_a>(left, right, scale_a);
unwrap<op_case, OpCase::RightConstant>(b, i), });
scale_a);
return; return;
} }
for (size_t i = 0; i < size; ++i) processWithRightNullmapImpl<op_case>(a, b, c, size, right_nullmap, [](const auto & left, const auto & right){ return apply(left, right); });
c[i] = apply(
unwrap<op_case, OpCase::LeftConstant>(a, i),
unwrap<op_case, OpCase::RightConstant>(b, i));
} }
template <bool is_decimal_a, bool is_decimal_b, class A, class B> template <bool is_decimal_a, bool is_decimal_b, class A, class B>
@ -460,6 +481,35 @@ public:
} }
private: private:
template <OpCase op_case, typename ApplyFunc>
static inline void processWithRightNullmapImpl(const auto & a, const auto & b, ResultContainerType & c, size_t size, const NullMap * right_nullmap, ApplyFunc apply_func)
{
if (right_nullmap)
{
if constexpr (op_case == OpCase::RightConstant)
{
if ((*right_nullmap)[0])
return;
for (size_t i = 0; i < size; ++i)
c[i] = apply_func(undec(a[i]), undec(b));
}
else
{
for (size_t i = 0; i < size; ++i)
{
if ((*right_nullmap)[i])
c[i] = ResultType();
else
c[i] = apply_func(unwrap<op_case, OpCase::LeftConstant>(a, i), undec(b[i]));
}
}
}
else
for (size_t i = 0; i < size; ++i)
c[i] = apply_func(unwrap<op_case, OpCase::LeftConstant>(a, i), unwrap<op_case, OpCase::RightConstant>(b, i));
}
static constexpr bool is_plus_minus = IsOperation<Operation>::plus || static constexpr bool is_plus_minus = IsOperation<Operation>::plus ||
IsOperation<Operation>::minus; IsOperation<Operation>::minus;
static constexpr bool is_multiply = IsOperation<Operation>::multiply; static constexpr bool is_multiply = IsOperation<Operation>::multiply;
@ -564,7 +614,7 @@ private:
using namespace traits_; using namespace traits_;
using namespace impl_; using namespace impl_;
template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true, bool valid_on_float_arguments = true> template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true, bool valid_on_float_arguments = true, bool division_by_nullable = false>
class FunctionBinaryArithmetic : public IFunction class FunctionBinaryArithmetic : public IFunction
{ {
static constexpr const bool is_plus = IsOperation<Op>::plus; static constexpr const bool is_plus = IsOperation<Op>::plus;
@ -884,12 +934,12 @@ class FunctionBinaryArithmetic : public IFunction
} }
template <OpCase op_case, bool left_decimal, bool right_decimal, typename OpImpl, typename OpImplCheck> template <OpCase op_case, bool left_decimal, bool right_decimal, typename OpImpl, typename OpImplCheck>
void helperInvokeEither(const auto& left, const auto& right, auto& vec_res, auto scale_a, auto scale_b) const void helperInvokeEither(const auto& left, const auto& right, auto& vec_res, auto scale_a, auto scale_b, const NullMap * right_nullmap) const
{ {
if (check_decimal_overflow) if (check_decimal_overflow)
OpImplCheck::template process<op_case, left_decimal, right_decimal>(left, right, vec_res, scale_a, scale_b); OpImplCheck::template process<op_case, left_decimal, right_decimal>(left, right, vec_res, scale_a, scale_b, right_nullmap);
else else
OpImpl::template process<op_case, left_decimal, right_decimal>(left, right, vec_res, scale_a, scale_b); OpImpl::template process<op_case, left_decimal, right_decimal>(left, right, vec_res, scale_a, scale_b, right_nullmap);
} }
template <class LeftDataType, class RightDataType, class ResultDataType> template <class LeftDataType, class RightDataType, class ResultDataType>
@ -897,7 +947,7 @@ class FunctionBinaryArithmetic : public IFunction
const auto & left, const auto & right, const auto & left, const auto & right,
const ColumnConst * const col_left_const, const ColumnConst * const col_right_const, const ColumnConst * const col_left_const, const ColumnConst * const col_right_const,
const auto * const col_left, const auto * const col_right, const auto * const col_left, const auto * const col_right,
size_t col_left_size) const size_t col_left_size, const NullMap * right_nullmap) const
{ {
using T0 = typename LeftDataType::FieldType; using T0 = typename LeftDataType::FieldType;
using T1 = typename RightDataType::FieldType; using T1 = typename RightDataType::FieldType;
@ -979,9 +1029,10 @@ class FunctionBinaryArithmetic : public IFunction
const NativeResultType const_a = helperGetOrConvert<T0, ResultDataType>(col_left_const, left); const NativeResultType const_a = helperGetOrConvert<T0, ResultDataType>(col_left_const, left);
const NativeResultType const_b = helperGetOrConvert<T1, ResultDataType>(col_right_const, right); const NativeResultType const_b = helperGetOrConvert<T1, ResultDataType>(col_right_const, right);
const ResultType res = check_decimal_overflow ResultType res = {};
? OpImplCheck::template process<left_is_decimal, right_is_decimal>(const_a, const_b, scale_a, scale_b) if (!right_nullmap || !(*right_nullmap)[0])
: OpImpl::template process<left_is_decimal, right_is_decimal>(const_a, const_b, scale_a, scale_b); res = check_decimal_overflow ? OpImplCheck::template process<left_is_decimal, right_is_decimal>(const_a, const_b, scale_a, scale_b)
: OpImpl::template process<left_is_decimal, right_is_decimal>(const_a, const_b, scale_a, scale_b);
if constexpr (result_is_decimal) if constexpr (result_is_decimal)
return ResultDataType(type.getPrecision(), type.getScale()).createColumnConst( return ResultDataType(type.getPrecision(), type.getScale()).createColumnConst(
@ -1001,21 +1052,21 @@ class FunctionBinaryArithmetic : public IFunction
if (col_left && col_right) if (col_left && col_right)
{ {
helperInvokeEither<OpCase::Vector, left_is_decimal, right_is_decimal, OpImpl, OpImplCheck>( helperInvokeEither<OpCase::Vector, left_is_decimal, right_is_decimal, OpImpl, OpImplCheck>(
col_left->getData(), col_right->getData(), vec_res, scale_a, scale_b); col_left->getData(), col_right->getData(), vec_res, scale_a, scale_b, right_nullmap);
} }
else if (col_left_const && col_right) else if (col_left_const && col_right)
{ {
const NativeResultType const_a = helperGetOrConvert<T0, ResultDataType>(col_left_const, left); const NativeResultType const_a = helperGetOrConvert<T0, ResultDataType>(col_left_const, left);
helperInvokeEither<OpCase::LeftConstant, left_is_decimal, right_is_decimal, OpImpl, OpImplCheck>( helperInvokeEither<OpCase::LeftConstant, left_is_decimal, right_is_decimal, OpImpl, OpImplCheck>(
const_a, col_right->getData(), vec_res, scale_a, scale_b); const_a, col_right->getData(), vec_res, scale_a, scale_b, right_nullmap);
} }
else if (col_left && col_right_const) else if (col_left && col_right_const)
{ {
const NativeResultType const_b = helperGetOrConvert<T1, ResultDataType>(col_right_const, right); const NativeResultType const_b = helperGetOrConvert<T1, ResultDataType>(col_right_const, right);
helperInvokeEither<OpCase::RightConstant, left_is_decimal, right_is_decimal, OpImpl, OpImplCheck>( helperInvokeEither<OpCase::RightConstant, left_is_decimal, right_is_decimal, OpImpl, OpImplCheck>(
col_left->getData(), const_b, vec_res, scale_a, scale_b); col_left->getData(), const_b, vec_res, scale_a, scale_b, right_nullmap);
} }
else else
return nullptr; return nullptr;
@ -1036,6 +1087,14 @@ public:
size_t getNumberOfArguments() const override { return 2; } size_t getNumberOfArguments() const override { return 2; }
bool useDefaultImplementationForNulls() const override
{
/// We shouldn't use default implementation for nulls for the case when operation is divide,
/// intDiv or modulo and denominator is Nullable(Something), because it may cause division
/// by zero error (when value is Null we store default value 0 in nested column).
return !division_by_nullable;
}
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & arguments) const override bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & arguments) const override
{ {
return ((IsOperation<Op>::div_int || IsOperation<Op>::modulo) && !arguments[1].is_const) return ((IsOperation<Op>::div_int || IsOperation<Op>::modulo) && !arguments[1].is_const)
@ -1385,7 +1444,7 @@ public:
} }
template <typename A, typename B> template <typename A, typename B>
ColumnPtr executeNumeric(const ColumnsWithTypeAndName & arguments, const A & left, const B & right) const ColumnPtr executeNumeric(const ColumnsWithTypeAndName & arguments, const A & left, const B & right, const NullMap * right_nullmap) const
{ {
using LeftDataType = std::decay_t<decltype(left)>; using LeftDataType = std::decay_t<decltype(left)>;
using RightDataType = std::decay_t<decltype(right)>; using RightDataType = std::decay_t<decltype(right)>;
@ -1420,7 +1479,8 @@ public:
left, right, left, right,
col_left_const, col_right_const, col_left_const, col_right_const,
col_left, col_right, col_left, col_right,
col_left_size); col_left_size,
right_nullmap);
} }
else // can't avoid else and another indentation level, otherwise the compiler would try to instantiate else // can't avoid else and another indentation level, otherwise the compiler would try to instantiate
// ColVecResult for Decimals which would lead to a compile error. // ColVecResult for Decimals which would lead to a compile error.
@ -1430,7 +1490,7 @@ public:
/// non-vector result /// non-vector result
if (col_left_const && col_right_const) if (col_left_const && col_right_const)
{ {
const auto res = OpImpl::process( const auto res = right_nullmap && (*right_nullmap)[0] ? ResultType() : OpImpl::process(
col_left_const->template getValue<T0>(), col_left_const->template getValue<T0>(),
col_right_const->template getValue<T1>()); col_right_const->template getValue<T1>());
@ -1448,7 +1508,8 @@ public:
col_left->getData().data(), col_left->getData().data(),
col_right->getData().data(), col_right->getData().data(),
vec_res.data(), vec_res.data(),
vec_res.size()); vec_res.size(),
right_nullmap);
} }
else if (col_left_const && col_right) else if (col_left_const && col_right)
{ {
@ -1458,7 +1519,8 @@ public:
&value, &value,
col_right->getData().data(), col_right->getData().data(),
vec_res.data(), vec_res.data(),
vec_res.size()); vec_res.size(),
right_nullmap);
} }
else if (col_left && col_right_const) else if (col_left && col_right_const)
{ {
@ -1468,7 +1530,8 @@ public:
col_left->getData().data(), col_left->getData().data(),
&value, &value,
vec_res.data(), vec_res.data(),
vec_res.size()); vec_res.size(),
right_nullmap);
} }
else else
return nullptr; return nullptr;
@ -1493,28 +1556,46 @@ public:
} }
/// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval. /// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
if (auto function_builder if (auto function_builder = getFunctionForIntervalArithmetic(arguments[0].type, arguments[1].type, context))
= getFunctionForIntervalArithmetic(arguments[0].type, arguments[1].type, context))
{ {
return executeDateTimeIntervalPlusMinus(arguments, result_type, input_rows_count, function_builder); return executeDateTimeIntervalPlusMinus(arguments, result_type, input_rows_count, function_builder);
} }
/// Special case when the function is plus, minus or multiply, both arguments are tuples. /// Special case when the function is plus, minus or multiply, both arguments are tuples.
if (auto function_builder if (auto function_builder = getFunctionForTupleArithmetic(arguments[0].type, arguments[1].type, context))
= getFunctionForTupleArithmetic(arguments[0].type, arguments[1].type, context))
{ {
return function_builder->build(arguments)->execute(arguments, result_type, input_rows_count); return function_builder->build(arguments)->execute(arguments, result_type, input_rows_count);
} }
/// Special case when the function is multiply or divide, one of arguments is Tuple and another is Number. /// Special case when the function is multiply or divide, one of arguments is Tuple and another is Number.
if (auto function_builder if (auto function_builder = getFunctionForTupleAndNumberArithmetic(arguments[0].type, arguments[1].type, context))
= getFunctionForTupleAndNumberArithmetic(arguments[0].type, arguments[1].type, context))
{ {
return executeTupleNumberOperator(arguments, result_type, input_rows_count, function_builder); return executeTupleNumberOperator(arguments, result_type, input_rows_count, function_builder);
} }
return executeImpl2(arguments, result_type, input_rows_count);
}
ColumnPtr executeImpl2(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count, const NullMap * right_nullmap = nullptr) const
{
const auto & left_argument = arguments[0]; const auto & left_argument = arguments[0];
const auto & right_argument = arguments[1]; const auto & right_argument = arguments[1];
/// Process special case when operation is divide, intDiv or modulo and denominator
/// is Nullable(Something) to prevent division by zero error.
if (division_by_nullable && !right_nullmap)
{
assert(right_argument.type->isNullable());
bool is_const = checkColumnConst<ColumnNullable>(right_argument.column.get());
const ColumnNullable * nullable_column = is_const ? checkAndGetColumnConstData<ColumnNullable>(right_argument.column.get())
: checkAndGetColumn<ColumnNullable>(*right_argument.column);
const auto & null_bytemap = nullable_column->getNullMapData();
auto res = executeImpl2(createBlockWithNestedColumns(arguments), removeNullable(result_type), input_rows_count, &null_bytemap);
return wrapInNullable(res, arguments, result_type, input_rows_count);
}
const auto * const left_generic = left_argument.type.get(); const auto * const left_generic = left_argument.type.get();
const auto * const right_generic = right_argument.type.get(); const auto * const right_generic = right_argument.type.get();
ColumnPtr res; ColumnPtr res;
@ -1548,7 +1629,7 @@ public:
return (res = executeStringInteger<ColumnString>(arguments, left, right)) != nullptr; return (res = executeStringInteger<ColumnString>(arguments, left, right)) != nullptr;
} }
else else
return (res = executeNumeric(arguments, left, right)) != nullptr; return (res = executeNumeric(arguments, left, right, right_nullmap)) != nullptr;
}); });
if (!valid) if (!valid)
@ -1619,11 +1700,11 @@ public:
}; };
template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true, bool valid_on_float_arguments = true> template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true, bool valid_on_float_arguments = true, bool division_by_nullable = false>
class FunctionBinaryArithmeticWithConstants : public FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments> class FunctionBinaryArithmeticWithConstants : public FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments, division_by_nullable>
{ {
public: public:
using Base = FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments>; using Base = FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments, division_by_nullable>;
using Monotonicity = typename Base::Monotonicity; using Monotonicity = typename Base::Monotonicity;
static FunctionPtr create( static FunctionPtr create(
@ -1822,22 +1903,37 @@ public:
FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const override FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const override
{ {
/// Check the case when operation is divide, intDiv or modulo and denominator is Nullable(Something).
/// For divide operation we should check only Nullable(Decimal), because only this case can throw division by zero error.
bool division_by_nullable = !arguments[0].type->onlyNull() && !arguments[1].type->onlyNull() && arguments[1].type->isNullable()
&& (IsOperation<Op>::div_int || IsOperation<Op>::modulo
|| (IsOperation<Op>::div_floating
&& (isDecimalOrNullableDecimal(arguments[0].type) || isDecimalOrNullableDecimal(arguments[1].type))));
/// More efficient specialization for two numeric arguments. /// More efficient specialization for two numeric arguments.
if (arguments.size() == 2 if (arguments.size() == 2
&& ((arguments[0].column && isColumnConst(*arguments[0].column)) && ((arguments[0].column && isColumnConst(*arguments[0].column))
|| (arguments[1].column && isColumnConst(*arguments[1].column)))) || (arguments[1].column && isColumnConst(*arguments[1].column))))
{ {
auto function = division_by_nullable ? FunctionBinaryArithmeticWithConstants<Op, Name, valid_on_default_arguments, valid_on_float_arguments, true>::create(
arguments[0], arguments[1], return_type, context)
: FunctionBinaryArithmeticWithConstants<Op, Name, valid_on_default_arguments, valid_on_float_arguments, false>::create(
arguments[0], arguments[1], return_type, context);
return std::make_unique<FunctionToFunctionBaseAdaptor>( return std::make_unique<FunctionToFunctionBaseAdaptor>(
FunctionBinaryArithmeticWithConstants<Op, Name, valid_on_default_arguments, valid_on_float_arguments>::create( function,
arguments[0], arguments[1], return_type, context),
collections::map<DataTypes>(arguments, [](const auto & elem) { return elem.type; }), collections::map<DataTypes>(arguments, [](const auto & elem) { return elem.type; }),
return_type); return_type);
} }
auto function = division_by_nullable
? FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments, true>::create(context)
: FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments, false>::create(context);
return std::make_unique<FunctionToFunctionBaseAdaptor>( return std::make_unique<FunctionToFunctionBaseAdaptor>(
FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments>::create(context), function,
collections::map<DataTypes>(arguments, [](const auto & elem) { return elem.type; }), collections::map<DataTypes>(arguments, [](const auto & elem) { return elem.type; }),
return_type); return_type);
} }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override

View File

@ -26,16 +26,29 @@ struct DivideIntegralByConstantImpl
static const constexpr bool allow_string_integer = false; static const constexpr bool allow_string_integer = false;
template <OpCase op_case> template <OpCase op_case>
static void NO_INLINE process(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t size) static void NO_INLINE process(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t size, const NullMap * right_nullmap)
{ {
if constexpr (op_case == OpCase::Vector) if constexpr (op_case == OpCase::RightConstant)
for (size_t i = 0; i < size; ++i) {
c[i] = Op::template apply<ResultType>(a[i], b[i]); if (right_nullmap && (*right_nullmap)[0])
else if constexpr (op_case == OpCase::LeftConstant) return;
for (size_t i = 0; i < size; ++i)
c[i] = Op::template apply<ResultType>(*a, b[i]);
else
vectorConstant(a, *b, c, size); vectorConstant(a, *b, c, size);
}
else
{
if (right_nullmap)
{
for (size_t i = 0; i < size; ++i)
if ((*right_nullmap)[i])
c[i] = ResultType();
else
apply<op_case>(a, b, c, i);
}
else
for (size_t i = 0; i < size; ++i)
apply<op_case>(a, b, c, i);
}
} }
static ResultType process(A a, B b) { return Op::template apply<ResultType>(a, b); } static ResultType process(A a, B b) { return Op::template apply<ResultType>(a, b); }
@ -69,6 +82,16 @@ struct DivideIntegralByConstantImpl
divideImpl(a_pos, b, c_pos, size); divideImpl(a_pos, b, c_pos, size);
} }
private:
template <OpCase op_case>
static inline void apply(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t i)
{
if constexpr (op_case == OpCase::Vector)
c[i] = Op::template apply<ResultType>(a[i], b[i]);
else
c[i] = Op::template apply<ResultType>(*a, b[i]);
}
}; };
/** Specializations are specified for dividing numbers of the type UInt64, UInt32, Int64, Int32 by the numbers of the same sign. /** Specializations are specified for dividing numbers of the type UInt64, UInt32, Int64, Int32 by the numbers of the same sign.

View File

@ -30,16 +30,28 @@ struct ModuloByConstantImpl
static const constexpr bool allow_string_integer = false; static const constexpr bool allow_string_integer = false;
template <OpCase op_case> template <OpCase op_case>
static void NO_INLINE process(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t size) static void NO_INLINE process(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t size, const NullMap * right_nullmap)
{ {
if constexpr (op_case == OpCase::Vector) if constexpr (op_case == OpCase::RightConstant)
for (size_t i = 0; i < size; ++i) {
c[i] = Op::template apply<ResultType>(a[i], b[i]); if (right_nullmap && (*right_nullmap)[0])
else if constexpr (op_case == OpCase::LeftConstant) return;
for (size_t i = 0; i < size; ++i)
c[i] = Op::template apply<ResultType>(*a, b[i]);
else
vectorConstant(a, *b, c, size); vectorConstant(a, *b, c, size);
}
else
{
if (right_nullmap)
{
for (size_t i = 0; i < size; ++i)
if ((*right_nullmap)[i])
c[i] = ResultType();
else
apply<op_case>(a, b, c, i);
}
else
for (size_t i = 0; i < size; ++i)
apply<op_case>(a, b, c, i);
}
} }
static ResultType process(A a, B b) { return Op::template apply<ResultType>(a, b); } static ResultType process(A a, B b) { return Op::template apply<ResultType>(a, b); }
@ -95,6 +107,16 @@ struct ModuloByConstantImpl
dst[i] = src[i] & mask; dst[i] = src[i] & mask;
} }
} }
private:
template <OpCase op_case>
static inline void apply(const A * __restrict a, const B * __restrict b, ResultType * __restrict c, size_t i)
{
if constexpr (op_case == OpCase::Vector)
c[i] = Op::template apply<ResultType>(a[i], b[i]);
else
c[i] = Op::template apply<ResultType>(*a, b[i]);
}
}; };
template <typename A, typename B> template <typename A, typename B>

View File

@ -0,0 +1,80 @@
\N
\N
\N
\N
1
1
1
1
\N
\N
\N
\N
1
1
1
1
\N
\N
\N
\N
1
1
1
1
\N
\N
\N
\N
1
1
1
1
\N
\N
\N
\N
0
0
0
0
\N
\N
\N
\N
1
1
1
1
\N
\N
\N
\N
0
0
0
0
1
\N
0
0
\N
0
1
\N
0
1
\N
0
1
\N
0
1
\N
0
1
\N
0
1
\N
0

View File

@ -0,0 +1,101 @@
SELECT 1 / CAST(NULL, 'Nullable(Decimal(7, 2))');
SELECT materialize(1) / CAST(NULL, 'Nullable(Decimal(7, 2))');
SELECT 1 / CAST(materialize(NULL), 'Nullable(Decimal(7, 2))');
SELECT materialize(1) / CAST(materialize(NULL), 'Nullable(Decimal(7, 2))');
SELECT 1 / CAST(1, 'Nullable(Decimal(7, 2))');
SELECT materialize(1) / CAST(1, 'Nullable(Decimal(7, 2))');
SELECT 1 / CAST(materialize(1), 'Nullable(Decimal(7, 2))');
SELECT materialize(1) / CAST(materialize(1), 'Nullable(Decimal(7, 2))');
SELECT intDiv(1, CAST(NULL, 'Nullable(Decimal(7, 2))'));
SELECT intDiv(materialize(1), CAST(NULL, 'Nullable(Decimal(7, 2))'));
SELECT intDiv(1, CAST(materialize(NULL), 'Nullable(Decimal(7, 2))'));
SELECT intDiv(materialize(1), CAST(materialize(NULL), 'Nullable(Decimal(7, 2))'));
SELECT intDiv(1, CAST(1, 'Nullable(Decimal(7, 2))'));
SELECT intDiv(materialize(1), CAST(1, 'Nullable(Decimal(7, 2))'));
SELECT intDiv(1, CAST(materialize(1), 'Nullable(Decimal(7, 2))'));
SELECT intDiv(materialize(1), CAST(materialize(1), 'Nullable(Decimal(7, 2))'));
SELECT toDecimal32(1, 2) / CAST(NULL, 'Nullable(UInt32)');
SELECT materialize(toDecimal32(1, 2)) / CAST(NULL, 'Nullable(UInt32)');
SELECT toDecimal32(1, 2) / CAST(materialize(NULL), 'Nullable(UInt32)');
SELECT materialize(toDecimal32(1, 2)) / CAST(materialize(NULL), 'Nullable(UInt32)');
SELECT toDecimal32(1, 2) / CAST(1, 'Nullable(UInt32)');
SELECT materialize(toDecimal32(1, 2)) / CAST(1, 'Nullable(UInt32)');
SELECT toDecimal32(1, 2) / CAST(materialize(1), 'Nullable(UInt32)');
SELECT materialize(toDecimal32(1, 2)) / CAST(materialize(1), 'Nullable(UInt32)');
SELECT intDiv(1, CAST(NULL, 'Nullable(UInt32)'));
SELECT intDiv(materialize(1), CAST(NULL, 'Nullable(UInt32)'));
SELECT intDiv(1, CAST(materialize(NULL), 'Nullable(UInt32)'));
SELECT intDiv(materialize(1), CAST(materialize(NULL), 'Nullable(UInt32)'));
SELECT intDiv(1, CAST(1, 'Nullable(UInt32)'));
SELECT intDiv(materialize(1), CAST(1, 'Nullable(UInt32)'));
SELECT intDiv(1, CAST(materialize(1), 'Nullable(UInt32)'));
SELECT intDiv(materialize(1), CAST(materialize(1), 'Nullable(UInt32)'));
SELECT 1 % CAST(NULL, 'Nullable(UInt32)');
SELECT materialize(1) % CAST(NULL, 'Nullable(UInt32)');
SELECT 1 % CAST(materialize(NULL), 'Nullable(UInt32)');
SELECT materialize(1) % CAST(materialize(NULL), 'Nullable(UInt32)');
SELECT 1 % CAST(1, 'Nullable(UInt32)');
SELECT materialize(1) % CAST(1, 'Nullable(UInt32)');
SELECT 1 % CAST(materialize(1), 'Nullable(UInt32)');
SELECT materialize(1) % CAST(materialize(1), 'Nullable(UInt32)');
SELECT intDiv(1, CAST(NULL, 'Nullable(Float32)'));
SELECT intDiv(materialize(1), CAST(NULL, 'Nullable(Float32)'));
SELECT intDiv(1, CAST(materialize(NULL), 'Nullable(Float32)'));
SELECT intDiv(materialize(1), CAST(materialize(NULL), 'Nullable(Float32)'));
SELECT intDiv(1, CAST(1, 'Nullable(Float32)'));
SELECT intDiv(materialize(1), CAST(1, 'Nullable(Float32)'));
SELECT intDiv(1, CAST(materialize(1), 'Nullable(Float32)'));
SELECT intDiv(materialize(1), CAST(materialize(1), 'Nullable(Float32)'));
SELECT 1 % CAST(NULL, 'Nullable(Float32)');
SELECT materialize(1) % CAST(NULL, 'Nullable(Float32)');
SELECT 1 % CAST(materialize(NULL), 'Nullable(Float32)');
SELECT materialize(1) % CAST(materialize(NULL), 'Nullable(Float32)');
SELECT 1 % CAST(1, 'Nullable(Float32)');
SELECT materialize(1) % CAST(1, 'Nullable(Float32)');
SELECT 1 % CAST(materialize(1), 'Nullable(Float32)');
SELECT materialize(1) % CAST(materialize(1), 'Nullable(Float32)');
DROP TABLE IF EXISTS nullable_division;
CREATE TABLE nullable_division (x UInt32, y Nullable(UInt32), a Decimal(7, 2), b Nullable(Decimal(7, 2))) ENGINE=MergeTree() order by x;
INSERT INTO nullable_division VALUES (1, 1, 1, 1), (1, NULL, 1, NULL), (1, 0, 1, 0);
SELECT if(y = 0, 0, intDiv(x, y)) from nullable_division;
SELECT if(y = 0, 0, x % y) from nullable_division;
SELECT if(y = 0, 0, intDiv(a, y)) from nullable_division;
SELECT if(y = 0, 0, a / y) from nullable_division;
SELECT if(b = 0, 0, intDiv(a, b)) from nullable_division;
SELECT if(b = 0, 0, a / b) from nullable_division;
SELECT if(b = 0, 0, intDiv(x, b)) from nullable_division;
SELECT if(b = 0, 0, x / b) from nullable_division;
DROP TABLE nullable_division;