diff --git a/src/Functions/array/arrayAggregation.cpp b/src/Functions/array/arrayAggregation.cpp index c63215819a6..c3f3abfaa77 100644 --- a/src/Functions/array/arrayAggregation.cpp +++ b/src/Functions/array/arrayAggregation.cpp @@ -132,131 +132,47 @@ struct ArrayAggregateImpl template static NO_SANITIZE_UNDEFINED bool executeType(const ColumnPtr & mapped, const ColumnArray::Offsets & offsets, ColumnPtr & res_ptr) { - using Result = ArrayAggregateResult; + using ResultType = ArrayAggregateResult; using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>; - using ColVecResult = std::conditional_t, ColumnDecimal, ColumnVector>; + using ColVecResultType = std::conditional_t, ColumnDecimal, ColumnVector>; /// For average and product of array we return Float64 as result, but we want to keep precision - /// so we convert to Float64 as last step, but intermediate sum is represented as result of sum operation + /// so we convert to Float64 as last step, but intermediate value is represented as result of sum operation static constexpr bool is_average_or_product_operation = aggregate_operation == AggregateOperation::average || aggregate_operation == AggregateOperation::product; using SummAggregationType = ArrayAggregateResult; - using AggregationType = std::conditional_t; + using AggregationType = std::conditional_t; const ColVecType * column = checkAndGetColumn(&*mapped); - /// Constant case. if (!column) - { - const ColumnConst * column_const = checkAndGetColumnConst(&*mapped); + return false; - if (!column_const) - return false; + const auto & data = column->getData(); - const AggregationType x = column_const->template getValue(); // NOLINT - const typename ColVecType::Container & data - = checkAndGetColumn(&column_const->getDataColumn())->getData(); - - typename ColVecResult::MutablePtr res_column; - if constexpr (IsDecimalNumber) - { - res_column = ColVecResult::create(offsets.size(), data.getScale()); - } - else - res_column = ColVecResult::create(offsets.size()); - - typename ColVecResult::Container & res = res_column->getData(); - - size_t pos = 0; - for (size_t i = 0; i < offsets.size(); ++i) - { - if constexpr (aggregate_operation == AggregateOperation::sum) - { - size_t array_size = offsets[i] - pos; - /// Just multiply the value by array size. - res[i] = x * array_size; - } - else if constexpr (aggregate_operation == AggregateOperation::min || - aggregate_operation == AggregateOperation::max) - { - res[i] = x; - } - else if constexpr (aggregate_operation == AggregateOperation::average) - { - if constexpr (IsDecimalNumber) - { - res[i] = DecimalUtils::convertTo(x, data.getScale()); - } - else - { - res[i] = x; - } - } - else if constexpr (aggregate_operation == AggregateOperation::product) - { - size_t array_size = offsets[i] - pos; - size_t num = array_size; - AggregationType product = x; - - if constexpr (IsDecimalNumber) - { - using T = decltype(x.value); - T x_val = x.value; - - for (i = 1; i < array_size; ++i) - { - T product_val = product.value; - while (common::mulOverflow(x_val, product_val, product.value)) - { - x_val = x_val / DecimalUtils::scaleMultiplier(data.getScale()); - if (num == 1) - throw Exception("arrayProduct for decimal type overflow", ErrorCodes::DECIMAL_OVERFLOW); - --num; - } - } - - res[i] = DecimalUtils::convertTo(product, data.getScale() * num); - } - else - { - for (i = 1; i < array_size; ++i) - product = product * x; - res[i] = product; - } - } - - pos = offsets[i]; - } - - res_ptr = std::move(res_column); - return true; - } - - const typename ColVecType::Container & data = column->getData(); - - typename ColVecResult::MutablePtr res_column; + typename ColVecResultType::MutablePtr res_column; if constexpr (IsDecimalNumber) - res_column = ColVecResult::create(offsets.size(), data.getScale()); + res_column = ColVecResultType::create(offsets.size(), data.getScale()); else - res_column = ColVecResult::create(offsets.size()); + res_column = ColVecResultType::create(offsets.size()); - typename ColVecResult::Container & res = res_column->getData(); + typename ColVecResultType::Container & res = res_column->getData(); size_t pos = 0; for (size_t i = 0; i < offsets.size(); ++i) { - AggregationType s = 0; + AggregationType aggregate_value = 0; /// Array is empty if (offsets[i] == pos) { - res[i] = s; + res[i] = aggregate_value; continue; } size_t count = 1; - s = data[pos]; // NOLINT + aggregate_value = data[pos]; // NOLINT ++pos; for (; pos < offsets[i]; ++pos) @@ -266,40 +182,36 @@ struct ArrayAggregateImpl if constexpr (aggregate_operation == AggregateOperation::sum || aggregate_operation == AggregateOperation::average) { - s += element; + aggregate_value += element; } else if constexpr (aggregate_operation == AggregateOperation::min) { - if (element < s) + if (element < aggregate_value) { - s = element; + aggregate_value = element; } } else if constexpr (aggregate_operation == AggregateOperation::max) { - if (element > s) + if (element > aggregate_value) { - s = element; + aggregate_value = element; } } else if constexpr (aggregate_operation == AggregateOperation::product) { if constexpr (IsDecimalNumber) { - using T = decltype(s.value); - T s_val = s.value; - T element_val = element.value; - while (common::mulOverflow(s_val, element_val, s.value)) - { - s_val = s_val / DecimalUtils::scaleMultiplier(data.getScale()); - if (count == 0) - throw Exception("arrayProduct for decimal type overflow", ErrorCodes::DECIMAL_OVERFLOW); - --count; - } + using AggregateValueDecimalUnderlyingValue = decltype(aggregate_value.value); + AggregateValueDecimalUnderlyingValue current_aggregate_value = aggregate_value.value; + AggregateValueDecimalUnderlyingValue element_value = static_cast(element.value); + + if (common::mulOverflow(current_aggregate_value, element_value, aggregate_value.value)) + throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal math overflow"); } else { - s *= element; + aggregate_value *= element; } } @@ -310,21 +222,26 @@ struct ArrayAggregateImpl { if constexpr (IsDecimalNumber) { - s = s / count; - res[i] = DecimalUtils::convertTo(s, data.getScale()); + aggregate_value = aggregate_value / count; + res[i] = DecimalUtils::convertTo(aggregate_value, data.getScale()); } else { - res[i] = static_cast(s) / count; + res[i] = static_cast(aggregate_value) / count; } } else if constexpr (aggregate_operation == AggregateOperation::product && IsDecimalNumber) { - res[i] = DecimalUtils::convertTo(s, data.getScale() * count); + auto result_scale = data.getScale() * count; + + if (unlikely(result_scale > DecimalUtils::max_precision)) + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Scale {} is out of bounds", result_scale); + + res[i] = DecimalUtils::convertTo(aggregate_value, result_scale); } else { - res[i] = s; + res[i] = aggregate_value; } } @@ -355,7 +272,7 @@ struct ArrayAggregateImpl executeType(mapped, offsets, res)) return res; else - throw Exception("Unexpected column for arraySum: " + mapped->getName(), ErrorCodes::ILLEGAL_COLUMN); + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Unexpected column for arraySum: {}" + mapped->getName()); } }; diff --git a/tests/queries/0_stateless/01768_array_product.reference b/tests/queries/0_stateless/01768_array_product.reference index df66e2afc77..af2508e1aff 100644 --- a/tests/queries/0_stateless/01768_array_product.reference +++ b/tests/queries/0_stateless/01768_array_product.reference @@ -1,11 +1,13 @@ -Array product 720 -28.799999999999997 Float64 +Array product with constant column +720 Float64 +24 Float64 3.5 Float64 -20 Float64 -720 +6 Float64 +Array product with non constant column +24 0 6 -720 +24 0 6 Types of aggregation result array product diff --git a/tests/queries/0_stateless/01768_array_product.sql b/tests/queries/0_stateless/01768_array_product.sql index 787e89b2def..75056888ef2 100644 --- a/tests/queries/0_stateless/01768_array_product.sql +++ b/tests/queries/0_stateless/01768_array_product.sql @@ -1,16 +1,20 @@ -SELECT 'Array product ', (arrayProduct(array(1,2,3,4,5,6))); -select arrayProduct(array(1.0,2.0,3.0,4.8)) as k , toTypeName(k); -select arrayProduct(array(1,3.5)) as k , toTypeName(k); -SELECT arrayProduct([toDecimal32(2, 8), toDecimal32(10, 8)]) as a , toTypeName(a); +SELECT 'Array product with constant column'; + +SELECT arrayProduct([1,2,3,4,5,6]) as a, toTypeName(a); +SELECT arrayProduct(array(1.0,2.0,3.0,4.0)) as a, toTypeName(a); +SELECT arrayProduct(array(1,3.5)) as a, toTypeName(a); +SELECT arrayProduct([toDecimal64(1,8), toDecimal64(2,8), toDecimal64(3,8)]) as a, toTypeName(a); + +SELECT 'Array product with non constant column'; DROP TABLE IF EXISTS test_aggregation; CREATE TABLE test_aggregation (x Array(Int)) ENGINE=TinyLog; -INSERT INTO test_aggregation VALUES ([1,2,3,4,5,6]), ([]), ([1,2,3]); +INSERT INTO test_aggregation VALUES ([1,2,3,4]), ([]), ([1,2,3]); SELECT arrayProduct(x) FROM test_aggregation; DROP TABLE test_aggregation; CREATE TABLE test_aggregation (x Array(Decimal64(8))) ENGINE=TinyLog; -INSERT INTO test_aggregation VALUES ([1,2,3,4,5,6]), ([]), ([1,2,3]); +INSERT INTO test_aggregation VALUES ([1,2,3,4]), ([]), ([1,2,3]); SELECT arrayProduct(x) FROM test_aggregation; DROP TABLE test_aggregation;