Fix UBsan report in lcm()/gcd()

Changelog:
- Check against INT64_MIN/INT64_MAX
- Check against std::numeric_limits<>::min/max
- Move common code into separate header - GCDLCMImpl.h
- Forbid floats

See the UBSan report [1].

  [1]: https://clickhouse-test-reports.s3.yandex.net/19466/cb30a02540a0f223df6668c5f88ff84aa666ff54/fuzzer_ubsan/report.html#fail1
This commit is contained in:
Azat Khuzhin 2021-01-24 21:28:12 +03:00
parent f4a4d33c2d
commit 27a5794795
11 changed files with 149 additions and 90 deletions

View File

@ -0,0 +1,67 @@
#pragma once
#include <DataTypes/NumberTraits.h>
#include <Common/Exception.h>
#include <numeric>
#include <limits>
#include <type_traits>
#if !defined(ARCADIA_BUILD)
# include "config_core.h"
#endif
namespace DB
{
/// Error codes used by GCDLCMImpl below. They are only declared here (extern);
/// the definitions live in another translation unit.
namespace ErrorCodes
{
extern const int NOT_IMPLEMENTED;
extern const int DECIMAL_OVERFLOW;
}
template <class T>
inline constexpr bool is_gcd_lcm_implemeted = !(is_big_int_v<T> || std::is_floating_point_v<T>);
/** Common front-end shared by the gcd() and lcm() implementations.
  *
  * Template parameters:
  *   A, B - argument types;
  *   Impl - type providing a static `applyImpl(a, b)` that performs the actual computation;
  *   Name - type providing a static `name`, used in the "not implemented" error message.
  *
  * `apply` validates the arguments and then forwards them to `Impl::applyImpl`.
  */
template <typename A, typename B, typename Impl, typename Name>
struct GCDLCMImpl
{
    using ResultType = typename NumberTraits::ResultOfAdditionMultiplication<A, B>::Type;
    static constexpr bool allow_fixed_string = false;

    /// Overload chosen when Result is rejected by is_gcd_lcm_implemeted.
    /// Never returns: always throws NOT_IMPLEMENTED.
    template <typename Result = ResultType>
    static std::enable_if_t<!is_gcd_lcm_implemeted<Result>, Result> apply(A, B)
    {
        throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not implemented for big integers and floats", Name::name);
    }

    /// Overload chosen when Result is accepted by is_gcd_lcm_implemeted.
    /// Runs the division checks in both argument orders, then (for signed Result only)
    /// throws DECIMAL_OVERFLOW if either argument equals the minimum or the maximum of
    /// the integer type, and finally returns Impl::applyImpl(a, b).
    template <typename Result = ResultType>
    static std::enable_if_t<is_gcd_lcm_implemeted<Result>, Result> apply(A a, B b)
    {
        using IntA = typename NumberTraits::ToInteger<A>::Type;
        using IntB = typename NumberTraits::ToInteger<B>::Type;
        using Int = typename NumberTraits::ToInteger<Result>::Type;

        throwIfDivisionLeadsToFPE(IntA(a), IntB(b));
        throwIfDivisionLeadsToFPE(IntB(b), IntA(a));

        if constexpr (is_signed_v<Result>)
        {
            /// gcd() internally uses std::abs(), so values at the edge of the range are refused up front.
            Int lhs = static_cast<Int>(a);
            Int rhs = static_cast<Int>(b);
            Int lowest = std::numeric_limits<Int>::min();
            Int highest = std::numeric_limits<Int>::max();

            const bool lhs_at_limit = (lhs == lowest) || (lhs == highest);
            const bool rhs_at_limit = (rhs == lowest) || (rhs == highest);

            if (unlikely(lhs_at_limit || rhs_at_limit))
                throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Intermediate result overflow (signed a = {}, signed b = {}, min = {}, max = {})", lhs, rhs, lowest, highest);
        }

        return Impl::applyImpl(a, b);
    }

#if USE_EMBEDDED_COMPILER
    static constexpr bool compilable = false; /// exceptions (and a non-trivial algorithm)
#endif
};
}

View File

@ -1,45 +1,28 @@
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionBinaryArithmetic.h>
#include <numeric>
#include <Functions/GCDLCMImpl.h>
namespace DB
{
namespace ErrorCodes
{
extern const int NOT_IMPLEMENTED;
}
namespace
{
struct NameGCD { static constexpr auto name = "gcd"; };
template <typename A, typename B>
struct GCDImpl
struct GCDImpl : public GCDLCMImpl<A, B, GCDImpl<A, B>, NameGCD>
{
using ResultType = typename NumberTraits::ResultOfAdditionMultiplication<A, B>::Type;
static const constexpr bool allow_fixed_string = false;
using ResultType = typename GCDLCMImpl<A, B, GCDImpl, NameGCD>::ResultType;
template <typename Result = ResultType>
static inline Result apply([[maybe_unused]] A a, [[maybe_unused]] B b)
static ResultType applyImpl(A a, B b)
{
if constexpr (is_big_int_v<A> || is_big_int_v<B> || is_big_int_v<Result>)
throw Exception("GCD is not implemented for big integers", ErrorCodes::NOT_IMPLEMENTED);
else
{
throwIfDivisionLeadsToFPE(typename NumberTraits::ToInteger<A>::Type(a), typename NumberTraits::ToInteger<B>::Type(b));
throwIfDivisionLeadsToFPE(typename NumberTraits::ToInteger<B>::Type(b), typename NumberTraits::ToInteger<A>::Type(a));
return std::gcd(
typename NumberTraits::ToInteger<Result>::Type(a),
typename NumberTraits::ToInteger<Result>::Type(b));
}
using Int = typename NumberTraits::ToInteger<ResultType>::Type;
return std::gcd(Int(a), Int(b));
}
#if USE_EMBEDDED_COMPILER
static constexpr bool compilable = false; /// exceptions (and a non-trivial algorithm)
#endif
};
struct NameGCD { static constexpr auto name = "gcd"; };
using FunctionGCD = BinaryArithmeticOverloadResolver<GCDImpl, NameGCD, false>;
}

View File

@ -1,10 +1,6 @@
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionBinaryArithmetic.h>
#include <numeric>
#include <limits>
#include <type_traits>
#include <Functions/GCDLCMImpl.h>
namespace
{
@ -27,33 +23,21 @@ constexpr T abs(T value) noexcept
namespace DB
{
namespace ErrorCodes
{
extern const int NOT_IMPLEMENTED;
}
namespace
{
struct NameLCM { static constexpr auto name = "lcm"; };
template <typename A, typename B>
struct LCMImpl
struct LCMImpl : public GCDLCMImpl<A, B, LCMImpl<A, B>, NameLCM>
{
using ResultType = typename NumberTraits::ResultOfAdditionMultiplication<A, B>::Type;
static const constexpr bool allow_fixed_string = false;
using ResultType = typename GCDLCMImpl<A, B, LCMImpl<A, B>, NameLCM>::ResultType;
template <typename Result = ResultType>
static inline std::enable_if_t<is_big_int_v<A> || is_big_int_v<B> || is_big_int_v<Result>, Result>
apply([[maybe_unused]] A a, [[maybe_unused]] B b)
static ResultType applyImpl(A a, B b)
{
throw Exception("LCM is not implemented for big integers", ErrorCodes::NOT_IMPLEMENTED);
}
template <typename Result = ResultType>
static inline std::enable_if_t<!is_big_int_v<A> && !is_big_int_v<B> && !is_big_int_v<Result>, Result>
apply([[maybe_unused]] A a, [[maybe_unused]] B b)
{
throwIfDivisionLeadsToFPE(typename NumberTraits::ToInteger<A>::Type(a), typename NumberTraits::ToInteger<B>::Type(b));
throwIfDivisionLeadsToFPE(typename NumberTraits::ToInteger<B>::Type(b), typename NumberTraits::ToInteger<A>::Type(a));
using Int = typename NumberTraits::ToInteger<ResultType>::Type;
using Unsigned = make_unsigned_t<Int>;
/** It's tempting to use std::lcm function.
* But it has undefined behaviour on overflow.
@ -62,22 +46,14 @@ struct LCMImpl
* (example: throw an exception or overflow in implementation specific way).
*/
using Int = typename NumberTraits::ToInteger<Result>::Type;
using Unsigned = make_unsigned_t<Int>;
Unsigned val1 = abs<Int>(a) / std::gcd(Int(a), Int(b));
Unsigned val2 = abs<Int>(b);
/// Overflow in implementation specific way.
return Result(val1 * val2);
return ResultType(val1 * val2);
}
#if USE_EMBEDDED_COMPILER
static constexpr bool compilable = false; /// exceptions (and a non-trivial algorithm)
#endif
};
struct NameLCM { static constexpr auto name = "lcm"; };
using FunctionLCM = BinaryArithmeticOverloadResolver<LCMImpl, NameLCM, false>;
}

View File

@ -21,17 +21,3 @@
4611686011984936962
4611686011984936962
2147483648
256
11
64
1
2
1
1
5120
121
256
1
4
735
64770

View File

@ -23,19 +23,19 @@ select lcm(255, 254);
select lcm(2147483647, 2147483646);
select lcm(4611686011984936962, 2147483647);
select lcm(-2147483648, 1);
-- test gcd float will cast to int
select gcd(1280.1, 1024.1);
select gcd(11.1, 121.1);
select gcd(-256.1, 64.1);
select gcd(1.1, 1.1);
select gcd(4.1, 2.1);
select gcd(15.1, 49.1);
select gcd(255.1, 254.1);
-- test lcm float cast to int
select lcm(1280.1, 1024.1);
select lcm(11.1, 121.1);
select lcm(-256.1, 64.1);
select lcm(1.1, 1.1);
select lcm(4.1, 2.1);
select lcm(15.1, 49.1);
select lcm(255.1, 254.1);
-- test gcd float
select gcd(1280.1, 1024.1); -- { serverError 48 }
select gcd(11.1, 121.1); -- { serverError 48 }
select gcd(-256.1, 64.1); -- { serverError 48 }
select gcd(1.1, 1.1); -- { serverError 48 }
select gcd(4.1, 2.1); -- { serverError 48 }
select gcd(15.1, 49.1); -- { serverError 48 }
select gcd(255.1, 254.1); -- { serverError 48 }
-- test lcm float
select lcm(1280.1, 1024.1); -- { serverError 48 }
select lcm(11.1, 121.1); -- { serverError 48 }
select lcm(-256.1, 64.1); -- { serverError 48 }
select lcm(1.1, 1.1); -- { serverError 48 }
select lcm(4.1, 2.1); -- { serverError 48 }
select lcm(15.1, 49.1); -- { serverError 48 }
select lcm(255.1, 254.1); -- { serverError 48 }

View File

@ -5,4 +5,3 @@
0
0
0
0

View File

@ -6,5 +6,5 @@ SELECT lcm(-15, -10);
-- Implementation specific result on overflow:
SELECT ignore(lcm(256, 9223372036854775807));
SELECT ignore(lcm(256, -9223372036854775807));
SELECT ignore(lcm(-256, 9223372036854775807));
SELECT ignore(lcm(-256, 9223372036854775807)); -- { serverError 407 }
SELECT ignore(lcm(-256, -9223372036854775807));

View File

@ -0,0 +1,13 @@
-- { echo }
SELECT gcd(9223372036854775807, -9223372036854775808); -- { serverError 407 }
SELECT gcd(9223372036854775808, -9223372036854775807); -- { serverError 407 }
SELECT gcd(-9223372036854775808, 9223372036854775807); -- { serverError 407 }
SELECT gcd(-9223372036854775807, 9223372036854775808); -- { serverError 407 }
SELECT gcd(9223372036854775808, -1); -- { serverError 407 }
SELECT lcm(-170141183460469231731687303715884105728, -170141183460469231731687303715884105728); -- { serverError 48 }
SELECT lcm(toInt128(-170141183460469231731687303715884105728), toInt128(-170141183460469231731687303715884105728)); -- { serverError 407 }
SELECT lcm(toInt128(-170141183460469231731687303715884105720), toInt128(-170141183460469231731687303715884105720)); -- { serverError 407 }
SELECT lcm(toInt128('-170141183460469231731687303715884105720'), toInt128('-170141183460469231731687303715884105720'));
170141183460469231731687303715884105720
SELECT gcd(-9223372036854775806, -9223372036854775806);
9223372036854775806

View File

@ -0,0 +1,11 @@
-- { echo }
SELECT gcd(9223372036854775807, -9223372036854775808); -- { serverError 407 }
SELECT gcd(9223372036854775808, -9223372036854775807); -- { serverError 407 }
SELECT gcd(-9223372036854775808, 9223372036854775807); -- { serverError 407 }
SELECT gcd(-9223372036854775807, 9223372036854775808); -- { serverError 407 }
SELECT gcd(9223372036854775808, -1); -- { serverError 407 }
SELECT lcm(-170141183460469231731687303715884105728, -170141183460469231731687303715884105728); -- { serverError 48 }
SELECT lcm(toInt128(-170141183460469231731687303715884105728), toInt128(-170141183460469231731687303715884105728)); -- { serverError 407 }
SELECT lcm(toInt128(-170141183460469231731687303715884105720), toInt128(-170141183460469231731687303715884105720)); -- { serverError 407 }
SELECT lcm(toInt128('-170141183460469231731687303715884105720'), toInt128('-170141183460469231731687303715884105720'));
SELECT gcd(-9223372036854775806, -9223372036854775806);

View File

@ -0,0 +1,13 @@
-- { echo }
SELECT lcm(9223372036854775807, -9223372036854775808); -- { serverError 407 }
SELECT lcm(9223372036854775808, -9223372036854775807); -- { serverError 407 }
SELECT lcm(-9223372036854775808, 9223372036854775807); -- { serverError 407 }
SELECT lcm(-9223372036854775807, 9223372036854775808); -- { serverError 407 }
SELECT lcm(9223372036854775808, -1); -- { serverError 407 }
SELECT lcm(-170141183460469231731687303715884105728, -170141183460469231731687303715884105728); -- { serverError 48 }
SELECT lcm(toInt128(-170141183460469231731687303715884105728), toInt128(-170141183460469231731687303715884105728)); -- { serverError 407 }
SELECT lcm(toInt128(-170141183460469231731687303715884105720), toInt128(-170141183460469231731687303715884105720)); -- { serverError 407 }
SELECT lcm(toInt128('-170141183460469231731687303715884105720'), toInt128('-170141183460469231731687303715884105720'));
170141183460469231731687303715884105720
SELECT lcm(-9223372036854775806, -9223372036854775806);
9223372036854775806

View File

@ -0,0 +1,11 @@
-- { echo }
SELECT lcm(9223372036854775807, -9223372036854775808); -- { serverError 407 }
SELECT lcm(9223372036854775808, -9223372036854775807); -- { serverError 407 }
SELECT lcm(-9223372036854775808, 9223372036854775807); -- { serverError 407 }
SELECT lcm(-9223372036854775807, 9223372036854775808); -- { serverError 407 }
SELECT lcm(9223372036854775808, -1); -- { serverError 407 }
SELECT lcm(-170141183460469231731687303715884105728, -170141183460469231731687303715884105728); -- { serverError 48 }
SELECT lcm(toInt128(-170141183460469231731687303715884105728), toInt128(-170141183460469231731687303715884105728)); -- { serverError 407 }
SELECT lcm(toInt128(-170141183460469231731687303715884105720), toInt128(-170141183460469231731687303715884105720)); -- { serverError 407 }
SELECT lcm(toInt128('-170141183460469231731687303715884105720'), toInt128('-170141183460469231731687303715884105720'));
SELECT lcm(-9223372036854775806, -9223372036854775806);