decimal compare overflow

This commit is contained in:
chertus 2018-08-08 16:57:16 +03:00
parent 16ad0caf37
commit 5f93ab73fa
4 changed files with 84 additions and 23 deletions

View File

@ -379,6 +379,7 @@ namespace ErrorCodes
extern const int CANNOT_IOSETUP = 402;
extern const int INVALID_JOIN_ON_EXPRESSION = 403;
extern const int BAD_ODBC_CONNECTION_STRING = 404;
extern const int DECIMAL_OVERFLOW = 405;
extern const int KEEPER_EXCEPTION = 999;
extern const int POCO_EXCEPTION = 1000;

View File

@ -36,6 +36,12 @@
namespace DB
{
namespace ErrorCodes
{
extern const int DECIMAL_OVERFLOW;
}
/** Comparison functions: ==, !=, <, >, <=, >=.
* The comparison functions always return 0 or 1 (UInt8).
*
@ -232,13 +238,19 @@ public:
{
if constexpr (_actual)
{
ColumnPtr c_res;
Shift shift = getScales<A, B>(col_left.type, col_right.type);
if (ColumnPtr c_res = apply(col_left.column, col_right.column, shift))
{
if (shift.a == 1 && shift.b == 1)
c_res = apply<false, false>(col_left.column, col_right.column, 1);
else if (shift.a == 1)
c_res = apply<false, true>(col_left.column, col_right.column, shift.b);
else
c_res = apply<true, false>(col_left.column, col_right.column, shift.a);
if (c_res)
block.getByPosition(result).column = std::move(c_res);
return true;
}
return true;
}
return false;
}
@ -294,7 +306,8 @@ private:
return shift;
}
static ColumnPtr apply(const ColumnPtr & c0, const ColumnPtr & c1, const Shift & shift)
template <bool scale_left, bool scale_right>
static ColumnPtr apply(const ColumnPtr & c0, const ColumnPtr & c1, CompareInt scale)
{
using ColVecA = ColumnVector<A>;
using ColVecB = ColumnVector<B>;
@ -313,7 +326,7 @@ private:
A a = c0_const->template getValue<A>();
B b = c1_const->template getValue<B>();
UInt8 res = apply(a, b, shift);
UInt8 res = apply<scale_left, scale_right>(a, b, scale);
return DataTypeUInt8().createColumnConst(c0->size(), toField(res));
}
@ -325,34 +338,34 @@ private:
const ColumnConst * c0_const = checkAndGetColumnConst<ColVecA>(c0.get());
A a = c0_const->template getValue<A>();
if (const ColVecB * c1_vec = checkAndGetColumn<ColVecB>(c1.get()))
constant_vector(a, c1_vec->getData(), vec_res, shift);
constant_vector<scale_left, scale_right>(a, c1_vec->getData(), vec_res, scale);
else if (const ColVecB * c1_vec = checkAndGetColumn<ColVecB>(c1.get()))
constant_vector(a, c1_vec->getData(), vec_res, shift);
constant_vector<scale_left, scale_right>(a, c1_vec->getData(), vec_res, scale);
}
else if (c1_const)
{
const ColumnConst * c1_const = checkAndGetColumnConst<ColVecB>(c1.get());
B b = c1_const->template getValue<B>();
if (const ColVecA * c0_vec = checkAndGetColumn<ColVecA>(c0.get()))
vector_constant(c0_vec->getData(), b, vec_res, shift);
vector_constant<scale_left, scale_right>(c0_vec->getData(), b, vec_res, scale);
else if (const ColVecA * c0_vec = checkAndGetColumn<ColVecA>(c0.get()))
vector_constant(c0_vec->getData(), b, vec_res, shift);
vector_constant<scale_left, scale_right>(c0_vec->getData(), b, vec_res, scale);
}
else
{
if (const ColVecA * c0_vec = checkAndGetColumn<ColVecA>(c0.get()))
{
if (const ColVecB * c1_vec = checkAndGetColumn<ColVecB>(c1.get()))
vector_vector(c0_vec->getData(), c1_vec->getData(), vec_res, shift);
vector_vector<scale_left, scale_right>(c0_vec->getData(), c1_vec->getData(), vec_res, scale);
else if (const ColVecB * c1_vec = checkAndGetColumn<ColVecB>(c1.get()))
vector_vector(c0_vec->getData(), c1_vec->getData(), vec_res, shift);
vector_vector<scale_left, scale_right>(c0_vec->getData(), c1_vec->getData(), vec_res, scale);
}
else if (const ColVecA * c0_vec = checkAndGetColumn<ColVecA>(c0.get()))
{
if (const ColVecB * c1_vec = checkAndGetColumn<ColVecB>(c1.get()))
vector_vector(c0_vec->getData(), c1_vec->getData(), vec_res, shift);
vector_vector<scale_left, scale_right>(c0_vec->getData(), c1_vec->getData(), vec_res, scale);
else if (const ColVecB * c1_vec = checkAndGetColumn<ColVecB>(c1.get()))
vector_vector(c0_vec->getData(), c1_vec->getData(), vec_res, shift);
vector_vector<scale_left, scale_right>(c0_vec->getData(), c1_vec->getData(), vec_res, scale);
}
}
}
@ -360,14 +373,36 @@ private:
return c_res;
}
/// TODO: there's a special case then sizeof(A) or sizeof(B) > sizeof(CompareInt)
static NO_INLINE UInt8 apply(A a, B b, const Shift & shift)
template <bool scale_left, bool scale_right>
static NO_INLINE UInt8 apply(A a, B b, CompareInt scale [[maybe_unused]])
{
return Op::apply(a * shift.a, b * shift.b);
CompareInt x = a;
CompareInt y = b;
bool overflow = false;
if constexpr (sizeof(A) > sizeof(CompareInt))
overflow |= (A(x) != a);
if constexpr (sizeof(B) > sizeof(CompareInt))
overflow |= (B(y) != b);
if constexpr (std::is_unsigned_v<A>)
overflow |= (x < 0);
if constexpr (std::is_unsigned_v<B>)
overflow |= (y < 0);
if constexpr (scale_left)
overflow |= __builtin_mul_overflow(x, scale, &x);
if constexpr (scale_right)
overflow |= __builtin_mul_overflow(y, scale, &y);
if (overflow)
throw Exception("Can't compare", ErrorCodes::DECIMAL_OVERFLOW);
return Op::apply(x, y);
}
template <bool scale_left, bool scale_right>
static void NO_INLINE vector_vector(const PaddedPODArray<A> & a, const PaddedPODArray<B> & b, PaddedPODArray<UInt8> & c,
const Shift & shift)
CompareInt scale)
{
size_t size = a.size();
const A * a_pos = &a[0];
@ -377,14 +412,15 @@ private:
while (a_pos < a_end)
{
*c_pos = apply(*a_pos, *b_pos, shift);
*c_pos = apply<scale_left, scale_right>(*a_pos, *b_pos, scale);
++a_pos;
++b_pos;
++c_pos;
}
}
static void NO_INLINE vector_constant(const PaddedPODArray<A> & a, B b, PaddedPODArray<UInt8> & c, const Shift & shift)
template <bool scale_left, bool scale_right>
static void NO_INLINE vector_constant(const PaddedPODArray<A> & a, B b, PaddedPODArray<UInt8> & c, CompareInt scale)
{
size_t size = a.size();
const A * a_pos = &a[0];
@ -393,13 +429,14 @@ private:
while (a_pos < a_end)
{
*c_pos = apply(*a_pos, b, shift);
*c_pos = apply<scale_left, scale_right>(*a_pos, b, scale);
++a_pos;
++c_pos;
}
}
static void NO_INLINE constant_vector(A a, const PaddedPODArray<B> & b, PaddedPODArray<UInt8> & c, const Shift & shift)
template <bool scale_left, bool scale_right>
static void NO_INLINE constant_vector(A a, const PaddedPODArray<B> & b, PaddedPODArray<UInt8> & c, CompareInt scale)
{
size_t size = b.size();
const B * b_pos = &b[0];
@ -408,7 +445,7 @@ private:
while (b_pos < b_end)
{
*c_pos = apply(a, *b_pos, shift);
*c_pos = apply<scale_left, scale_right>(a, *b_pos, scale);
++b_pos;
++c_pos;
}

View File

@ -25,3 +25,11 @@
1 0
1 0
1 0
2147483648 0 1
9223372036854775808 0 1
0 1
0 1
0 1
0 1
0 1
0 1

View File

@ -47,4 +47,19 @@ SELECT greatest(a, 0), greatest(b, 0), greatest(g, 0) FROM test.decimal ORDER BY
SELECT (a, d, g) = (b, e, h), (a, d, g) != (b, e, h) FROM test.decimal ORDER BY a;
SELECT (a, d, g) = (c, f, i), (a, d, g) != (c, f, i) FROM test.decimal ORDER BY a;
SELECT toUInt32(2147483648) AS x, a == x FROM test.decimal WHERE a = 42; -- { serverError 405 }
SELECT toUInt64(2147483648) AS x, b == x, x == ((b - 42) + x) FROM test.decimal WHERE a = 42;
SELECT toUInt64(9223372036854775808) AS x, b == x FROM test.decimal WHERE a = 42; -- { serverError 405 }
SELECT toUInt64(9223372036854775808) AS x, c == x, x == ((c - 42) + x) FROM test.decimal WHERE a = 42;
SELECT g = 10000, (g - g + 10000) == 10000 FROM test.decimal WHERE a = 42;
SELECT 10000 = g, 10000 = (g - g + 10000) FROM test.decimal WHERE a = 42;
SELECT g = 30000 FROM test.decimal WHERE a = 42; -- { serverError 405 }
SELECT 30000 = g FROM test.decimal WHERE a = 42; -- { serverError 405 }
SELECT h = 30000, (h - g + 30000) = 30000 FROM test.decimal WHERE a = 42;
SELECT 30000 = h, 30000 = (h - g + 30000) FROM test.decimal WHERE a = 42;
SELECT h = 10000000000 FROM test.decimal WHERE a = 42; -- { serverError 405 }
SELECT i = 10000000000, (i - g + 10000000000) = 10000000000 FROM test.decimal WHERE a = 42;
SELECT 10000000000 = i, 10000000000 = (i - g + 10000000000) FROM test.decimal WHERE a = 42;
DROP TABLE IF EXISTS test.decimal;