diff --git a/dbms/src/Common/ErrorCodes.cpp b/dbms/src/Common/ErrorCodes.cpp index 9794e32949a..c8141b96537 100644 --- a/dbms/src/Common/ErrorCodes.cpp +++ b/dbms/src/Common/ErrorCodes.cpp @@ -379,6 +379,7 @@ namespace ErrorCodes extern const int CANNOT_IOSETUP = 402; extern const int INVALID_JOIN_ON_EXPRESSION = 403; extern const int BAD_ODBC_CONNECTION_STRING = 404; + extern const int DECIMAL_OVERFLOW = 405; extern const int KEEPER_EXCEPTION = 999; extern const int POCO_EXCEPTION = 1000; diff --git a/dbms/src/Functions/FunctionsComparison.h b/dbms/src/Functions/FunctionsComparison.h index a225b7e97ea..6cf8c856fdf 100644 --- a/dbms/src/Functions/FunctionsComparison.h +++ b/dbms/src/Functions/FunctionsComparison.h @@ -36,6 +36,12 @@ namespace DB { +namespace ErrorCodes +{ + extern const int DECIMAL_OVERFLOW; +} + + /** Comparison functions: ==, !=, <, >, <=, >=. * The comparison functions always return 0 or 1 (UInt8). * @@ -232,13 +238,19 @@ public: { if constexpr (_actual) { + ColumnPtr c_res; Shift shift = getScales(col_left.type, col_right.type); - if (ColumnPtr c_res = apply(col_left.column, col_right.column, shift)) - { + if (shift.a == 1 && shift.b == 1) + c_res = apply(col_left.column, col_right.column, 1); + else if (shift.a == 1) + c_res = apply(col_left.column, col_right.column, shift.b); + else + c_res = apply(col_left.column, col_right.column, shift.a); + + if (c_res) block.getByPosition(result).column = std::move(c_res); - return true; - } + return true; } return false; } @@ -294,7 +306,8 @@ private: return shift; } - static ColumnPtr apply(const ColumnPtr & c0, const ColumnPtr & c1, const Shift & shift) + template + static ColumnPtr apply(const ColumnPtr & c0, const ColumnPtr & c1, CompareInt scale) { using ColVecA = ColumnVector; using ColVecB = ColumnVector; @@ -313,7 +326,7 @@ private: A a = c0_const->template getValue(); B b = c1_const->template getValue(); - UInt8 res = apply(a, b, shift); + UInt8 res = apply(a, b, scale); return DataTypeUInt8().createColumnConst(c0->size(), toField(res)); } @@ -325,34 +338,34 @@ private: const ColumnConst * c0_const = checkAndGetColumnConst(c0.get()); A a = c0_const->template getValue(); if (const ColVecB * c1_vec = checkAndGetColumn(c1.get())) - constant_vector(a, c1_vec->getData(), vec_res, shift); + constant_vector(a, c1_vec->getData(), vec_res, scale); else if (const ColVecB * c1_vec = checkAndGetColumn(c1.get())) - constant_vector(a, c1_vec->getData(), vec_res, shift); + constant_vector(a, c1_vec->getData(), vec_res, scale); } else if (c1_const) { const ColumnConst * c1_const = checkAndGetColumnConst(c1.get()); B b = c1_const->template getValue(); if (const ColVecA * c0_vec = checkAndGetColumn(c0.get())) - vector_constant(c0_vec->getData(), b, vec_res, shift); + vector_constant(c0_vec->getData(), b, vec_res, scale); else if (const ColVecA * c0_vec = checkAndGetColumn(c0.get())) - vector_constant(c0_vec->getData(), b, vec_res, shift); + vector_constant(c0_vec->getData(), b, vec_res, scale); } else { if (const ColVecA * c0_vec = checkAndGetColumn(c0.get())) { if (const ColVecB * c1_vec = checkAndGetColumn(c1.get())) - vector_vector(c0_vec->getData(), c1_vec->getData(), vec_res, shift); + vector_vector(c0_vec->getData(), c1_vec->getData(), vec_res, scale); else if (const ColVecB * c1_vec = checkAndGetColumn(c1.get())) - vector_vector(c0_vec->getData(), c1_vec->getData(), vec_res, shift); + vector_vector(c0_vec->getData(), c1_vec->getData(), vec_res, scale); } else if (const ColVecA * c0_vec = checkAndGetColumn(c0.get())) { if (const ColVecB * c1_vec = checkAndGetColumn(c1.get())) - vector_vector(c0_vec->getData(), c1_vec->getData(), vec_res, shift); + vector_vector(c0_vec->getData(), c1_vec->getData(), vec_res, scale); else if (const ColVecB * c1_vec = checkAndGetColumn(c1.get())) - vector_vector(c0_vec->getData(), c1_vec->getData(), vec_res, shift); + vector_vector(c0_vec->getData(), c1_vec->getData(), vec_res, scale); } } } @@ -360,14 +373,36 @@ private: return c_res; } - /// TODO: there's a special case then sizeof(A) or sizeof(B) > sizeof(CompareInt) - static NO_INLINE UInt8 apply(A a, B b, const Shift & shift) + template + static NO_INLINE UInt8 apply(A a, B b, CompareInt scale [[maybe_unused]]) { - return Op::apply(a * shift.a, b * shift.b); + CompareInt x = a; + CompareInt y = b; + bool overflow = false; + + if constexpr (sizeof(A) > sizeof(CompareInt)) + overflow |= (A(x) != a); + if constexpr (sizeof(B) > sizeof(CompareInt)) + overflow |= (B(y) != b); + if constexpr (std::is_unsigned_v) + overflow |= (x < 0); + if constexpr (std::is_unsigned_v) + overflow |= (y < 0); + + if constexpr (scale_left) + overflow |= __builtin_mul_overflow(x, scale, &x); + if constexpr (scale_right) + overflow |= __builtin_mul_overflow(y, scale, &y); + + if (overflow) + throw Exception("Can't compare", ErrorCodes::DECIMAL_OVERFLOW); + + return Op::apply(x, y); } + template static void NO_INLINE vector_vector(const PaddedPODArray & a, const PaddedPODArray & b, PaddedPODArray & c, - const Shift & shift) + CompareInt scale) { size_t size = a.size(); const A * a_pos = &a[0]; @@ -377,14 +412,15 @@ private: while (a_pos < a_end) { - *c_pos = apply(*a_pos, *b_pos, shift); + *c_pos = apply(*a_pos, *b_pos, scale); ++a_pos; ++b_pos; ++c_pos; } } - static void NO_INLINE vector_constant(const PaddedPODArray & a, B b, PaddedPODArray & c, const Shift & shift) + template + static void NO_INLINE vector_constant(const PaddedPODArray & a, B b, PaddedPODArray & c, CompareInt scale) { size_t size = a.size(); const A * a_pos = &a[0]; @@ -393,13 +429,14 @@ private: while (a_pos < a_end) { - *c_pos = apply(*a_pos, b, shift); + *c_pos = apply(*a_pos, b, scale); ++a_pos; ++c_pos; } } - static void NO_INLINE constant_vector(A a, const PaddedPODArray & b, PaddedPODArray & c, const Shift & shift) + template + static void NO_INLINE constant_vector(A a, const PaddedPODArray & b, PaddedPODArray & c, CompareInt scale) { size_t size = b.size(); const B * b_pos = &b[0]; @@ -408,7 +445,7 @@ private: while (b_pos < b_end) { - *c_pos = apply(a, *b_pos, shift); + *c_pos = apply(a, *b_pos, scale); ++b_pos; ++c_pos; } diff --git a/dbms/tests/queries/0_stateless/00700_decimal_compare.reference b/dbms/tests/queries/0_stateless/00700_decimal_compare.reference index 8aaaa12354b..2616a0afa92 100644 --- a/dbms/tests/queries/0_stateless/00700_decimal_compare.reference +++ b/dbms/tests/queries/0_stateless/00700_decimal_compare.reference @@ -25,3 +25,11 @@ 1 0 1 0 1 0 +2147483648 0 1 +9223372036854775808 0 1 +0 1 +0 1 +0 1 +0 1 +0 1 +0 1 diff --git a/dbms/tests/queries/0_stateless/00700_decimal_compare.sql b/dbms/tests/queries/0_stateless/00700_decimal_compare.sql index 8252f85784f..3edcc86dd77 100644 --- a/dbms/tests/queries/0_stateless/00700_decimal_compare.sql +++ b/dbms/tests/queries/0_stateless/00700_decimal_compare.sql @@ -47,4 +47,19 @@ SELECT greatest(a, 0), greatest(b, 0), greatest(g, 0) FROM test.decimal ORDER BY SELECT (a, d, g) = (b, e, h), (a, d, g) != (b, e, h) FROM test.decimal ORDER BY a; SELECT (a, d, g) = (c, f, i), (a, d, g) != (c, f, i) FROM test.decimal ORDER BY a; +SELECT toUInt32(2147483648) AS x, a == x FROM test.decimal WHERE a = 42; -- { serverError 405 } +SELECT toUInt64(2147483648) AS x, b == x, x == ((b - 42) + x) FROM test.decimal WHERE a = 42; +SELECT toUInt64(9223372036854775808) AS x, b == x FROM test.decimal WHERE a = 42; -- { serverError 405 } +SELECT toUInt64(9223372036854775808) AS x, c == x, x == ((c - 42) + x) FROM test.decimal WHERE a = 42; + +SELECT g = 10000, (g - g + 10000) == 10000 FROM test.decimal WHERE a = 42; +SELECT 10000 = g, 10000 = (g - g + 10000) FROM test.decimal WHERE a = 42; +SELECT g = 30000 FROM test.decimal WHERE a = 42; -- { serverError 405 } +SELECT 30000 = g FROM test.decimal WHERE a = 42; -- { serverError 405 } +SELECT h = 30000, (h - g + 30000) = 30000 FROM test.decimal WHERE a = 42; +SELECT 30000 = h, 30000 = (h - g + 30000) FROM test.decimal WHERE a = 42; +SELECT h = 10000000000 FROM test.decimal WHERE a = 42; -- { serverError 405 } +SELECT i = 10000000000, (i - g + 10000000000) = 10000000000 FROM test.decimal WHERE a = 42; +SELECT 10000000000 = i, 10000000000 = (i - g + 10000000000) FROM test.decimal WHERE a = 42; + DROP TABLE IF EXISTS test.decimal;