Forbid bit functions for floats

This will fix the following UB report [1]:

  "../src/Functions/bitAnd.cpp:24:61: runtime error: nan is outside the
range of representable values of type 'long' Received signal -3 Received
signal Unknown signal (-3)"

  [1]: https://clickhouse-test-reports.s3.yandex.net/19824/89c4055202b9d08459f90ee5791d4e3017b82fbf/fuzzer_ubsan/report.html#fail1
This commit is contained in:
Azat Khuzhin 2021-01-30 07:13:49 +03:00
parent 9d48e3ebd7
commit 7da4083237
12 changed files with 68 additions and 21 deletions

View File

@ -504,7 +504,7 @@ private:
using namespace traits_;
using namespace impl_;
template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true>
template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true, bool valid_on_float_arguments = true>
class FunctionBinaryArithmetic : public IFunction
{
static constexpr const bool is_plus = IsOperation<Op>::plus;
@ -542,8 +542,35 @@ class FunctionBinaryArithmetic : public IFunction
>(type, std::forward<F>(f));
}
template <typename F>
static bool castTypeNoFloats(const IDataType * type, F && f)
{
return castTypeToEither<
DataTypeUInt8,
DataTypeUInt16,
DataTypeUInt32,
DataTypeUInt64,
DataTypeUInt256,
DataTypeInt8,
DataTypeInt16,
DataTypeInt32,
DataTypeInt64,
DataTypeInt128,
DataTypeInt256,
DataTypeDate,
DataTypeDateTime,
DataTypeDecimal<Decimal32>,
DataTypeDecimal<Decimal64>,
DataTypeDecimal<Decimal128>,
DataTypeDecimal<Decimal256>,
DataTypeFixedString
>(type, std::forward<F>(f));
}
template <typename F>
static bool castBothTypes(const IDataType * left, const IDataType * right, F && f)
{
if constexpr (valid_on_float_arguments)
{
return castType(left, [&](const auto & left_)
{
@ -553,6 +580,17 @@ class FunctionBinaryArithmetic : public IFunction
});
});
}
else
{
return castTypeNoFloats(left, [&](const auto & left_)
{
return castTypeNoFloats(right, [&](const auto & right_)
{
return f(left_, right_);
});
});
}
}
static FunctionOverloadResolverPtr
getFunctionForIntervalArithmetic(const DataTypePtr & type0, const DataTypePtr & type1, const Context & context)
@ -1319,11 +1357,11 @@ public:
};
template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true>
class FunctionBinaryArithmeticWithConstants : public FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments>
template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true, bool valid_on_float_arguments = true>
class FunctionBinaryArithmeticWithConstants : public FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments>
{
public:
using Base = FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments>;
using Base = FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments>;
using Monotonicity = typename Base::Monotonicity;
static FunctionPtr create(
@ -1488,7 +1526,7 @@ private:
DataTypePtr return_type;
};
template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true>
template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true, bool valid_on_float_arguments = true>
class BinaryArithmeticOverloadResolver : public IFunctionOverloadResolverImpl
{
public:
@ -1512,14 +1550,14 @@ public:
|| (arguments[1].column && isColumnConst(*arguments[1].column))))
{
return std::make_unique<DefaultFunction>(
FunctionBinaryArithmeticWithConstants<Op, Name, valid_on_default_arguments>::create(
FunctionBinaryArithmeticWithConstants<Op, Name, valid_on_default_arguments, valid_on_float_arguments>::create(
arguments[0], arguments[1], return_type, context),
ext::map<DataTypes>(arguments, [](const auto & elem) { return elem.type; }),
return_type);
}
return std::make_unique<DefaultFunction>(
FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments>::create(context),
FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments>::create(context),
ext::map<DataTypes>(arguments, [](const auto & elem) { return elem.type; }),
return_type);
}
@ -1530,7 +1568,7 @@ public:
throw Exception(
"Number of arguments for function " + getName() + " doesn't match: passed " + toString(arguments.size()) + ", should be 2",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
return FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments>::getReturnTypeImplStatic(arguments, context);
return FunctionBinaryArithmetic<Op, Name, valid_on_default_arguments, valid_on_float_arguments>::getReturnTypeImplStatic(arguments, context);
}
private:

View File

@ -37,7 +37,7 @@ struct BitAndImpl
};
struct NameBitAnd { static constexpr auto name = "bitAnd"; };
using FunctionBitAnd = BinaryArithmeticOverloadResolver<BitAndImpl, NameBitAnd, true>;
using FunctionBitAnd = BinaryArithmeticOverloadResolver<BitAndImpl, NameBitAnd, true, false>;
}

View File

@ -36,7 +36,7 @@ struct BitOrImpl
};
struct NameBitOr { static constexpr auto name = "bitOr"; };
using FunctionBitOr = BinaryArithmeticOverloadResolver<BitOrImpl, NameBitOr, true>;
using FunctionBitOr = BinaryArithmeticOverloadResolver<BitOrImpl, NameBitOr, true, false>;
}

View File

@ -43,7 +43,7 @@ struct BitRotateLeftImpl
};
struct NameBitRotateLeft { static constexpr auto name = "bitRotateLeft"; };
using FunctionBitRotateLeft = BinaryArithmeticOverloadResolver<BitRotateLeftImpl, NameBitRotateLeft>;
using FunctionBitRotateLeft = BinaryArithmeticOverloadResolver<BitRotateLeftImpl, NameBitRotateLeft, true, false>;
}

View File

@ -42,7 +42,7 @@ struct BitRotateRightImpl
};
struct NameBitRotateRight { static constexpr auto name = "bitRotateRight"; };
using FunctionBitRotateRight = BinaryArithmeticOverloadResolver<BitRotateRightImpl, NameBitRotateRight>;
using FunctionBitRotateRight = BinaryArithmeticOverloadResolver<BitRotateRightImpl, NameBitRotateRight, true, false>;
}

View File

@ -42,7 +42,7 @@ struct BitShiftLeftImpl
};
struct NameBitShiftLeft { static constexpr auto name = "bitShiftLeft"; };
using FunctionBitShiftLeft = BinaryArithmeticOverloadResolver<BitShiftLeftImpl, NameBitShiftLeft>;
using FunctionBitShiftLeft = BinaryArithmeticOverloadResolver<BitShiftLeftImpl, NameBitShiftLeft, true, false>;
}

View File

@ -42,7 +42,7 @@ struct BitShiftRightImpl
};
struct NameBitShiftRight { static constexpr auto name = "bitShiftRight"; };
using FunctionBitShiftRight = BinaryArithmeticOverloadResolver<BitShiftRightImpl, NameBitShiftRight>;
using FunctionBitShiftRight = BinaryArithmeticOverloadResolver<BitShiftRightImpl, NameBitShiftRight, true, false>;
}

View File

@ -34,7 +34,7 @@ struct BitTestImpl
};
struct NameBitTest { static constexpr auto name = "bitTest"; };
using FunctionBitTest = BinaryArithmeticOverloadResolver<BitTestImpl, NameBitTest>;
using FunctionBitTest = BinaryArithmeticOverloadResolver<BitTestImpl, NameBitTest, true, false>;
}

View File

@ -36,7 +36,7 @@ struct BitXorImpl
};
struct NameBitXor { static constexpr auto name = "bitXor"; };
using FunctionBitXor = BinaryArithmeticOverloadResolver<BitXorImpl, NameBitXor, true>;
using FunctionBitXor = BinaryArithmeticOverloadResolver<BitXorImpl, NameBitXor, true, false>;
}

View File

@ -3,5 +3,5 @@ SET max_bytes_before_external_group_by = 200000000;
SET max_memory_usage = 1500000000;
SET max_threads = 12;
SELECT bitAnd(number, pow(2, 20) - 1) as k, argMaxIf(k, number % 2 = 0 ? number : Null, number > 42), uniq(number) AS u FROM numbers(1000000) GROUP BY k format Null;
SELECT bitAnd(number, toUInt64(pow(2, 20) - 1)) as k, argMaxIf(k, number % 2 = 0 ? number : Null, number > 42), uniq(number) AS u FROM numbers(1000000) GROUP BY k format Null;

View File

@ -0,0 +1,9 @@
SELECT bitAnd(0, inf); -- { serverError 43 }
SELECT bitXor(0, inf); -- { serverError 43 }
SELECT bitOr(0, inf); -- { serverError 43 }
SELECT bitTest(inf, 0); -- { serverError 43 }
SELECT bitTest(0, inf); -- { serverError 43 }
SELECT bitRotateLeft(inf, 0); -- { serverError 43 }
SELECT bitRotateRight(inf, 0); -- { serverError 43 }
SELECT bitShiftLeft(inf, 0); -- { serverError 43 }
SELECT bitShiftRight(inf, 0); -- { serverError 43 }