Faster 256-bit multiplication (#15418)

This commit is contained in:
Artem Zuikov 2020-09-29 20:52:34 +03:00 committed by GitHub
parent 86cfc6f914
commit 4fd1db73a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 102 additions and 34 deletions

View File

@ -436,46 +436,94 @@ private:
}
/// Multiplies the wide integer `lhs` by `rhs` and returns the product as integer<Bits, Signed>.
/// NOTE(review): this listing reads like a diff hunk rendered without +/- markers: both the old
/// (`auto`) and the new (`integer<Bits, Signed>`) signature are present, and lines of the previous
/// shift-and-add loop are interleaved with the 256-bit fast path. Confirm against the real file.
template <typename T>
constexpr static auto multiply(const integer<Bits, Signed> & lhs, const T & rhs)
constexpr static integer<Bits, Signed>
multiply(const integer<Bits, Signed> & lhs, const T & rhs)
{
integer<Bits, Signed> res{};
#if 1
integer<Bits, Signed> lhs2 = plus(lhs, shift_left(lhs, 1));
integer<Bits, Signed> lhs3 = plus(lhs2, shift_left(lhs, 2));
#endif
for (unsigned i = 0; i < item_count; ++i)
/// Fast path, compiled only for 256-bit integers made of 8-byte items:
/// the product is assembled from 128-bit partial products.
if constexpr (Bits == 256 && sizeof(base_type) == 8)
{
base_type rhs_item = get_item(rhs, i);
unsigned pos = i * base_bits;
/// @sa https://github.com/abseil/abseil-cpp/blob/master/absl/numeric/int128.h
using HalfType = unsigned __int128;
while (rhs_item)
/// lhs split into 128-bit halves (a01 = items 1:0, a23 = items 3:2);
/// a0 and a1 are the two lowest items widened to 128 bits.
HalfType a01 = (HalfType(lhs.items[little(1)]) << 64) + lhs.items[little(0)];
HalfType a23 = (HalfType(lhs.items[little(3)]) << 64) + lhs.items[little(2)];
HalfType a0 = lhs.items[little(0)];
HalfType a1 = lhs.items[little(1)];
/// rhs converted to HalfType; b0 is its low 64 bits.
/// b1 and b23 stay zero unless sizeof(T) exceeds 8 and 16 bytes respectively.
HalfType b01 = rhs;
uint64_t b0 = b01;
uint64_t b1 = 0;
HalfType b23 = 0;
if constexpr (sizeof(T) > 8)
b1 = b01 >> 64;
if constexpr (sizeof(T) > 16)
b23 = (HalfType(rhs.items[little(3)]) << 64) + rhs.items[little(2)];
/// r23: sum of the partial products that land in result items 2 and 3
/// (unsigned 128-bit arithmetic, so it wraps).
HalfType r23 = a23 * b01 + a01 * b23 + a1 * b1;
/// r01: product of the lowest items; its low 64 bits become item 0.
HalfType r01 = a0 * b0;
/// r12 spans result items 1 and 2: high half of r01 below, low half of r23 above.
HalfType r12 = (r01 >> 64) + (r23 << 64);
/// Cross term a1 * b0; it is added into r12 after the optional a0 * b1 term.
HalfType r12_x = a1 * b0;
integer<Bits, Signed> res;
res.items[little(0)] = r01;
res.items[little(3)] = r23 >> 64;
if constexpr (sizeof(T) > 8)
{
#if 1 /// optimization
if ((rhs_item & 0x7) == 0x7)
{
res = plus(res, shift_left(lhs3, pos));
rhs_item >>= 3;
pos += 3;
continue;
}
if ((rhs_item & 0x3) == 0x3)
{
res = plus(res, shift_left(lhs2, pos));
rhs_item >>= 2;
pos += 2;
continue;
}
#endif
if (rhs_item & 1)
res = plus(res, shift_left(lhs, pos));
rhs_item >>= 1;
++pos;
/// Second cross term a0 * b1. If the unsigned sum wraps (result smaller than an addend),
/// the carry goes into item 3.
HalfType r12_y = a0 * b1;
r12_x += r12_y;
if (r12_x < r12_y)
++res.items[little(3)];
}
}
return res;
/// Fold the cross terms into r12, again propagating a wrap-around carry into item 3.
r12 += r12_x;
if (r12 < r12_x)
++res.items[little(3)];
/// Store the middle 128 bits as items 1 and 2.
res.items[little(1)] = r12;
res.items[little(2)] = r12 >> 64;
return res;
}
/// Generic fallback for every other configuration: shift-and-add over the bits of rhs.
else
{
integer<Bits, Signed> res{};
/// lhs2 = lhs + (lhs << 1) and lhs3 = lhs2 + (lhs << 2), i.e. 3 * lhs and 7 * lhs
/// provided plus/shift_left are plain addition and left shift.
#if 1
integer<Bits, Signed> lhs2 = plus(lhs, shift_left(lhs, 1));
integer<Bits, Signed> lhs3 = plus(lhs2, shift_left(lhs, 2));
#endif
for (unsigned i = 0; i < item_count; ++i)
{
/// Walk rhs one item at a time; pos is the bit offset of the bit currently examined.
base_type rhs_item = get_item(rhs, i);
unsigned pos = i * base_bits;
while (rhs_item)
{
#if 1 /// optimization
/// Low three bits all set: add lhs3 at this position and consume three bits in one step.
if ((rhs_item & 0x7) == 0x7)
{
res = plus(res, shift_left(lhs3, pos));
rhs_item >>= 3;
pos += 3;
continue;
}
/// Low two bits set: add lhs2 and consume two bits.
if ((rhs_item & 0x3) == 0x3)
{
res = plus(res, shift_left(lhs2, pos));
rhs_item >>= 2;
pos += 2;
continue;
}
#endif
/// Single-bit step: add lhs shifted to pos when the low bit is set, then advance by one bit.
if (rhs_item & 1)
res = plus(res, shift_left(lhs, pos));
rhs_item >>= 1;
++pos;
}
}
return res;
}
}
public:

View File

@ -0,0 +1,15 @@
<test>
<!--
    Arithmetic on wide integers: each query applies one operation (+, -, *, / or intDiv)
    to toInt128(number) or toInt256(number) and discards the output with FORMAT Null.
    The Int128 queries read 1000000000 rows from numbers_mt, the Int256 queries 100000000.
    The intDiv query uses number + 1 on both sides, so its divisor is never zero.
-->
<settings>
<max_memory_usage>10G</max_memory_usage>
</settings>
<query>SELECT toInt128(number) + number FROM numbers_mt(1000000000) FORMAT Null</query>
<query>SELECT toInt128(number) - number FROM numbers_mt(1000000000) FORMAT Null</query>
<query>SELECT toInt128(number) * number FROM numbers_mt(1000000000) FORMAT Null</query>
<query>SELECT toInt128(number) / number FROM numbers_mt(1000000000) FORMAT Null</query>
<query>SELECT toInt256(number) + number FROM numbers_mt(100000000) FORMAT Null</query>
<query>SELECT toInt256(number) - number FROM numbers_mt(100000000) FORMAT Null</query>
<query>SELECT toInt256(number) * number FROM numbers_mt(100000000) FORMAT Null</query>
<query>SELECT intDiv(toInt256(number + 1), number + 1) FROM numbers_mt(100000000) FORMAT Null</query>
</test>

View File

@ -0,0 +1 @@
0

View File

@ -0,0 +1,4 @@
-- Counts the rows for which the Int128 product (x) and the Int256 product (y) of
-- number * number differ, over 100000000 rows of numbers_mt.
-- NOTE(review): the one-line file containing 0 in this commit is presumably the expected output.
select count() from
(
select toInt128(number) * number x, toInt256(number) * number y from numbers_mt(100000000) where x != y
);