another attempt to fix UB in wide int lib

This commit is contained in:
myrrc 2020-11-10 19:35:14 +03:00
parent 9d3788f264
commit c80e05f18e

View File

@ -7,6 +7,7 @@
#include "throwError.h"
#include <cfloat>
#include <limits>
#include <cassert>
namespace wide
{
@ -229,35 +230,19 @@ struct integer<Bits, Signed>::_impl
constexpr static void wide_integer_from_bultin(integer<Bits, Signed> & self, double rhs) noexcept
{
constexpr uint64_t max_uint = std::numeric_limits<uint64_t>::max();
constexpr int64_t max_int = std::numeric_limits<int64_t>::max();
constexpr size_t max_sizet = std::numeric_limits<size_t>::max();
constexpr long double max_int_long_double = static_cast<long double>(max_int);
constexpr int64_t min_int = std::numeric_limits<int64_t>::min();
if ((rhs > 0 && rhs < max_uint) ||
(rhs < 0 && rhs > std::numeric_limits<int64_t>::min()))
constexpr long double max_int_long_double = static_cast<long double>(max_int);
constexpr long double min_int_long_double = static_cast<long double>(min_int);
if ((rhs > 0 && rhs < max_int) ||
(rhs < 0 && rhs > min_int))
{
self = to_Integral(rhs);
return;
}
long double r = rhs;
if (r < 0)
r = -r;
const long double div = r / max_int;
size_t count = max_sizet;
/// r / max_uint may not fit in size_t
if (div <= static_cast<long double>(max_sizet))
count = div;
self = count;
self *= max_uint;
long double to_diff = count;
to_diff *= max_uint;
/// There are values in int64 that have more than 53 significant bits (in terms of double
/// representation). Such values, being promoted to double, are rounded up or down. If they are rounded up,
/// the result may not fit in 64 bits.
@ -269,18 +254,32 @@ struct integer<Bits, Signed>::_impl
"On your system long double has less than 64 precision bits,"
"which may result in UB when initializing double from int64_t");
if (long double diff = r - to_diff; diff > max_int_long_double)
{
uint64_t diff_multiplier = max_uint;
/// Always >= 0
const long double rhs_long_double = (static_cast<long double>(rhs) < 0)
? -static_cast<long double>(rhs)
: rhs;
if (const long double multiplier = diff / max_int_long_double; multiplier < max_uint)
diff_multiplier = multiplier;
const long double rhs_max_int_count = rhs_long_double / max_int;
self += max_int_long_double * diff_multiplier;
self += static_cast<int64_t>(diff - max_int_long_double * diff_multiplier);
}
else
self += static_cast<int64_t>(diff);
// Won't fit only if long double can hold values >= 2^(64 * 3).
const uint64_t rhs_max_int_count_max_int_count = rhs_max_int_count / max_int;
long double rhs_max_int_count_acc = rhs_max_int_count;
self = 0;
for (uint64_t i = 0; i < rhs_max_int_count_max_int_count; ++i)
self += max_int;
self *= max_int;
const long double rhs_div_max_int = rhs_max_int_count * max_int;
const long double rhs_mod_max_int = rhs_long_double - rhs_div_max_int;
assert(rhs_mod_max_int < max_int_long_double);
assert(rhs_mod_max_int > min_int_long_double);
self += static_cast<int64_t>(rhs_mod_max_int);
if (rhs < 0)
self = -self;