diff --git a/src/Functions/GCDLCMImpl.h b/src/Functions/GCDLCMImpl.h new file mode 100644 index 00000000000..dffd91f8d6a --- /dev/null +++ b/src/Functions/GCDLCMImpl.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include +#include +#include + +#if !defined(ARCADIA_BUILD) +# include "config_core.h" +#endif + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int NOT_IMPLEMENTED; + extern const int DECIMAL_OVERFLOW; +} + +template +inline constexpr bool is_gcd_lcm_implemeted = !(is_big_int_v || std::is_floating_point_v); + +template +struct GCDLCMImpl +{ + using ResultType = typename NumberTraits::ResultOfAdditionMultiplication::Type; + static const constexpr bool allow_fixed_string = false; + + template + static inline std::enable_if_t, Result> + apply(A, B) + { + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not implemented for big integers and floats", Name::name); + } + + template + static inline std::enable_if_t, Result> + apply(A a, B b) + { + throwIfDivisionLeadsToFPE(typename NumberTraits::ToInteger::Type(a), typename NumberTraits::ToInteger::Type(b)); + throwIfDivisionLeadsToFPE(typename NumberTraits::ToInteger::Type(b), typename NumberTraits::ToInteger::Type(a)); + + using Int = typename NumberTraits::ToInteger::Type; + + if constexpr (is_signed_v) + { + /// gcd() internally uses std::abs() + Int a_s = static_cast(a); + Int b_s = static_cast(b); + Int min = std::numeric_limits::min(); + Int max = std::numeric_limits::max(); + if (unlikely((a_s == min || a_s == max) || (b_s == min || b_s == max))) + throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Intermediate result overflow (signed a = {}, signed b = {}, min = {}, max = {})", a_s, b_s, min, max); + } + + return Impl::applyImpl(a, b); + } + +#if USE_EMBEDDED_COMPILER + static constexpr bool compilable = false; /// exceptions (and a non-trivial algorithm) +#endif +}; + +} diff --git a/src/Functions/gcd.cpp b/src/Functions/gcd.cpp index 7c8a28c83f6..9cb53212c7f 100644 --- a/src/Functions/gcd.cpp +++ b/src/Functions/gcd.cpp @@ -1,45 +1,28 @@ #include #include -#include +#include namespace DB { -namespace ErrorCodes -{ - extern const int NOT_IMPLEMENTED; -} namespace { +struct NameGCD { static constexpr auto name = "gcd"; }; + template -struct GCDImpl +struct GCDImpl : public GCDLCMImpl, NameGCD> { - using ResultType = typename NumberTraits::ResultOfAdditionMultiplication::Type; - static const constexpr bool allow_fixed_string = false; + using ResultType = typename GCDLCMImpl::ResultType; - template - static inline Result apply([[maybe_unused]] A a, [[maybe_unused]] B b) + static ResultType applyImpl(A a, B b) { - if constexpr (is_big_int_v || is_big_int_v || is_big_int_v) - throw Exception("GCD is not implemented for big integers", ErrorCodes::NOT_IMPLEMENTED); - else - { - throwIfDivisionLeadsToFPE(typename NumberTraits::ToInteger::Type(a), typename NumberTraits::ToInteger::Type(b)); - throwIfDivisionLeadsToFPE(typename NumberTraits::ToInteger::Type(b), typename NumberTraits::ToInteger::Type(a)); - return std::gcd( - typename NumberTraits::ToInteger::Type(a), - typename NumberTraits::ToInteger::Type(b)); - } + using Int = typename NumberTraits::ToInteger::Type; + return std::gcd(Int(a), Int(b)); } - -#if USE_EMBEDDED_COMPILER - static constexpr bool compilable = false; /// exceptions (and a non-trivial algorithm) -#endif }; -struct NameGCD { static constexpr auto name = "gcd"; }; using FunctionGCD = BinaryArithmeticOverloadResolver; } diff --git a/src/Functions/lcm.cpp b/src/Functions/lcm.cpp index 81406861c52..5155a80e6cd 100644 --- a/src/Functions/lcm.cpp +++ b/src/Functions/lcm.cpp @@ -1,10 +1,6 @@ #include #include - -#include -#include -#include - +#include namespace { @@ -27,33 +23,21 @@ constexpr T abs(T value) noexcept namespace DB { -namespace ErrorCodes -{ - extern const int NOT_IMPLEMENTED; -} namespace { +struct NameLCM { static constexpr auto name = "lcm"; }; + template -struct LCMImpl +struct LCMImpl : public GCDLCMImpl, NameLCM> { - using ResultType = typename NumberTraits::ResultOfAdditionMultiplication::Type; - static const constexpr bool allow_fixed_string = false; + using ResultType = typename GCDLCMImpl, NameLCM>::ResultType; - template - static inline std::enable_if_t || is_big_int_v || is_big_int_v, Result> - apply([[maybe_unused]] A a, [[maybe_unused]] B b) + static ResultType applyImpl(A a, B b) { - throw Exception("LCM is not implemented for big integers", ErrorCodes::NOT_IMPLEMENTED); - } - - template - static inline std::enable_if_t && !is_big_int_v && !is_big_int_v, Result> - apply([[maybe_unused]] A a, [[maybe_unused]] B b) - { - throwIfDivisionLeadsToFPE(typename NumberTraits::ToInteger::Type(a), typename NumberTraits::ToInteger::Type(b)); - throwIfDivisionLeadsToFPE(typename NumberTraits::ToInteger::Type(b), typename NumberTraits::ToInteger::Type(a)); + using Int = typename NumberTraits::ToInteger::Type; + using Unsigned = make_unsigned_t; /** It's tempting to use std::lcm function. * But it has undefined behaviour on overflow. @@ -62,22 +46,14 @@ struct LCMImpl * (example: throw an exception or overflow in implementation specific way). */ - using Int = typename NumberTraits::ToInteger::Type; - using Unsigned = make_unsigned_t; - Unsigned val1 = abs(a) / std::gcd(Int(a), Int(b)); Unsigned val2 = abs(b); /// Overflow in implementation specific way. - return Result(val1 * val2); + return ResultType(val1 * val2); } - -#if USE_EMBEDDED_COMPILER - static constexpr bool compilable = false; /// exceptions (and a non-trivial algorithm) -#endif }; -struct NameLCM { static constexpr auto name = "lcm"; }; using FunctionLCM = BinaryArithmeticOverloadResolver; } diff --git a/tests/queries/0_stateless/00515_gcd_lcm.reference b/tests/queries/0_stateless/00515_gcd_lcm.reference index a24649ba97b..5289404dbf2 100644 --- a/tests/queries/0_stateless/00515_gcd_lcm.reference +++ b/tests/queries/0_stateless/00515_gcd_lcm.reference @@ -21,17 +21,3 @@ 4611686011984936962 4611686011984936962 2147483648 -256 -11 -64 -1 -2 -1 -1 -5120 -121 -256 -1 -4 -735 -64770 diff --git a/tests/queries/0_stateless/00515_gcd_lcm.sql b/tests/queries/0_stateless/00515_gcd_lcm.sql index 51da49e9c40..c3bf3275bb8 100644 --- a/tests/queries/0_stateless/00515_gcd_lcm.sql +++ b/tests/queries/0_stateless/00515_gcd_lcm.sql @@ -23,19 +23,19 @@ select lcm(255, 254); select lcm(2147483647, 2147483646); select lcm(4611686011984936962, 2147483647); select lcm(-2147483648, 1); --- test gcd float will cast to int -select gcd(1280.1, 1024.1); -select gcd(11.1, 121.1); -select gcd(-256.1, 64.1); -select gcd(1.1, 1.1); -select gcd(4.1, 2.1); -select gcd(15.1, 49.1); -select gcd(255.1, 254.1); --- test lcm float cast to int -select lcm(1280.1, 1024.1); -select lcm(11.1, 121.1); -select lcm(-256.1, 64.1); -select lcm(1.1, 1.1); -select lcm(4.1, 2.1); -select lcm(15.1, 49.1); -select lcm(255.1, 254.1); +-- test gcd float +select gcd(1280.1, 1024.1); -- { serverError 48 } +select gcd(11.1, 121.1); -- { serverError 48 } +select gcd(-256.1, 64.1); -- { serverError 48 } +select gcd(1.1, 1.1); -- { serverError 48 } +select gcd(4.1, 2.1); -- { serverError 48 } +select gcd(15.1, 49.1); -- { serverError 48 } +select gcd(255.1, 254.1); -- { serverError 48 } +-- test lcm float +select lcm(1280.1, 1024.1); -- { serverError 48 } +select lcm(11.1, 121.1); -- { serverError 48 } +select lcm(-256.1, 64.1); -- { serverError 48 } +select lcm(1.1, 1.1); -- { serverError 48 } +select lcm(4.1, 2.1); -- { serverError 48 } +select lcm(15.1, 49.1); -- { serverError 48 } +select lcm(255.1, 254.1); -- { serverError 48 } diff --git a/tests/queries/0_stateless/01435_lcm_overflow.reference b/tests/queries/0_stateless/01435_lcm_overflow.reference index eebd14705df..cb1cdf296a9 100644 --- a/tests/queries/0_stateless/01435_lcm_overflow.reference +++ b/tests/queries/0_stateless/01435_lcm_overflow.reference @@ -5,4 +5,3 @@ 0 0 0 -0 diff --git a/tests/queries/0_stateless/01435_lcm_overflow.sql b/tests/queries/0_stateless/01435_lcm_overflow.sql index f70200eb2d8..b069c0642bc 100644 --- a/tests/queries/0_stateless/01435_lcm_overflow.sql +++ b/tests/queries/0_stateless/01435_lcm_overflow.sql @@ -6,5 +6,5 @@ SELECT lcm(-15, -10); -- Implementation specific result on overflow: SELECT ignore(lcm(256, 9223372036854775807)); SELECT ignore(lcm(256, -9223372036854775807)); -SELECT ignore(lcm(-256, 9223372036854775807)); +SELECT ignore(lcm(-256, 9223372036854775807)); -- { serverError 407 } SELECT ignore(lcm(-256, -9223372036854775807)); diff --git a/tests/queries/0_stateless/01666_gcd_ubsan.reference b/tests/queries/0_stateless/01666_gcd_ubsan.reference new file mode 100644 index 00000000000..2500ef1deae --- /dev/null +++ b/tests/queries/0_stateless/01666_gcd_ubsan.reference @@ -0,0 +1,13 @@ +-- { echo } +SELECT gcd(9223372036854775807, -9223372036854775808); -- { serverError 407 } +SELECT gcd(9223372036854775808, -9223372036854775807); -- { serverError 407 } +SELECT gcd(-9223372036854775808, 9223372036854775807); -- { serverError 407 } +SELECT gcd(-9223372036854775807, 9223372036854775808); -- { serverError 407 } +SELECT gcd(9223372036854775808, -1); -- { serverError 407 } +SELECT lcm(-170141183460469231731687303715884105728, -170141183460469231731687303715884105728); -- { serverError 48 } +SELECT lcm(toInt128(-170141183460469231731687303715884105728), toInt128(-170141183460469231731687303715884105728)); -- { serverError 407 } +SELECT lcm(toInt128(-170141183460469231731687303715884105720), toInt128(-170141183460469231731687303715884105720)); -- { serverError 407 } +SELECT lcm(toInt128('-170141183460469231731687303715884105720'), toInt128('-170141183460469231731687303715884105720')); +170141183460469231731687303715884105720 +SELECT gcd(-9223372036854775806, -9223372036854775806); +9223372036854775806 diff --git a/tests/queries/0_stateless/01666_gcd_ubsan.sql b/tests/queries/0_stateless/01666_gcd_ubsan.sql new file mode 100644 index 00000000000..bde2b624cc0 --- /dev/null +++ b/tests/queries/0_stateless/01666_gcd_ubsan.sql @@ -0,0 +1,11 @@ +-- { echo } +SELECT gcd(9223372036854775807, -9223372036854775808); -- { serverError 407 } +SELECT gcd(9223372036854775808, -9223372036854775807); -- { serverError 407 } +SELECT gcd(-9223372036854775808, 9223372036854775807); -- { serverError 407 } +SELECT gcd(-9223372036854775807, 9223372036854775808); -- { serverError 407 } +SELECT gcd(9223372036854775808, -1); -- { serverError 407 } +SELECT lcm(-170141183460469231731687303715884105728, -170141183460469231731687303715884105728); -- { serverError 48 } +SELECT lcm(toInt128(-170141183460469231731687303715884105728), toInt128(-170141183460469231731687303715884105728)); -- { serverError 407 } +SELECT lcm(toInt128(-170141183460469231731687303715884105720), toInt128(-170141183460469231731687303715884105720)); -- { serverError 407 } +SELECT lcm(toInt128('-170141183460469231731687303715884105720'), toInt128('-170141183460469231731687303715884105720')); +SELECT gcd(-9223372036854775806, -9223372036854775806); diff --git a/tests/queries/0_stateless/01666_lcm_ubsan.reference b/tests/queries/0_stateless/01666_lcm_ubsan.reference new file mode 100644 index 00000000000..ed9a6aed42b --- /dev/null +++ b/tests/queries/0_stateless/01666_lcm_ubsan.reference @@ -0,0 +1,13 @@ +-- { echo } +SELECT lcm(9223372036854775807, -9223372036854775808); -- { serverError 407 } +SELECT lcm(9223372036854775808, -9223372036854775807); -- { serverError 407 } +SELECT lcm(-9223372036854775808, 9223372036854775807); -- { serverError 407 } +SELECT lcm(-9223372036854775807, 9223372036854775808); -- { serverError 407 } +SELECT lcm(9223372036854775808, -1); -- { serverError 407 } +SELECT lcm(-170141183460469231731687303715884105728, -170141183460469231731687303715884105728); -- { serverError 48 } +SELECT lcm(toInt128(-170141183460469231731687303715884105728), toInt128(-170141183460469231731687303715884105728)); -- { serverError 407 } +SELECT lcm(toInt128(-170141183460469231731687303715884105720), toInt128(-170141183460469231731687303715884105720)); -- { serverError 407 } +SELECT lcm(toInt128('-170141183460469231731687303715884105720'), toInt128('-170141183460469231731687303715884105720')); +170141183460469231731687303715884105720 +SELECT lcm(-9223372036854775806, -9223372036854775806); +9223372036854775806 diff --git a/tests/queries/0_stateless/01666_lcm_ubsan.sql b/tests/queries/0_stateless/01666_lcm_ubsan.sql new file mode 100644 index 00000000000..5cc3546e941 --- /dev/null +++ b/tests/queries/0_stateless/01666_lcm_ubsan.sql @@ -0,0 +1,11 @@ +-- { echo } +SELECT lcm(9223372036854775807, -9223372036854775808); -- { serverError 407 } +SELECT lcm(9223372036854775808, -9223372036854775807); -- { serverError 407 } +SELECT lcm(-9223372036854775808, 9223372036854775807); -- { serverError 407 } +SELECT lcm(-9223372036854775807, 9223372036854775808); -- { serverError 407 } +SELECT lcm(9223372036854775808, -1); -- { serverError 407 } +SELECT lcm(-170141183460469231731687303715884105728, -170141183460469231731687303715884105728); -- { serverError 48 } +SELECT lcm(toInt128(-170141183460469231731687303715884105728), toInt128(-170141183460469231731687303715884105728)); -- { serverError 407 } +SELECT lcm(toInt128(-170141183460469231731687303715884105720), toInt128(-170141183460469231731687303715884105720)); -- { serverError 407 } +SELECT lcm(toInt128('-170141183460469231731687303715884105720'), toInt128('-170141183460469231731687303715884105720')); +SELECT lcm(-9223372036854775806, -9223372036854775806);