mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-12-03 13:02:00 +00:00
Removed const path for arrayAggregation
This commit is contained in:
parent
aa71b4a6df
commit
44b966af5a
@ -132,131 +132,47 @@ struct ArrayAggregateImpl
|
||||
template <typename Element>
|
||||
static NO_SANITIZE_UNDEFINED bool executeType(const ColumnPtr & mapped, const ColumnArray::Offsets & offsets, ColumnPtr & res_ptr)
|
||||
{
|
||||
using Result = ArrayAggregateResult<Element, aggregate_operation>;
|
||||
using ResultType = ArrayAggregateResult<Element, aggregate_operation>;
|
||||
using ColVecType = std::conditional_t<IsDecimalNumber<Element>, ColumnDecimal<Element>, ColumnVector<Element>>;
|
||||
using ColVecResult = std::conditional_t<IsDecimalNumber<Result>, ColumnDecimal<Result>, ColumnVector<Result>>;
|
||||
using ColVecResultType = std::conditional_t<IsDecimalNumber<ResultType>, ColumnDecimal<ResultType>, ColumnVector<ResultType>>;
|
||||
|
||||
/// For average and product of array we return Float64 as result, but we want to keep precision
|
||||
/// so we convert to Float64 as last step, but intermediate sum is represented as result of sum operation
|
||||
/// so we convert to Float64 as last step, but intermediate value is represented as result of sum operation
|
||||
static constexpr bool is_average_or_product_operation = aggregate_operation == AggregateOperation::average ||
|
||||
aggregate_operation == AggregateOperation::product;
|
||||
using SummAggregationType = ArrayAggregateResult<Element, AggregateOperation::sum>;
|
||||
|
||||
using AggregationType = std::conditional_t<is_average_or_product_operation, SummAggregationType, Result>;
|
||||
using AggregationType = std::conditional_t<is_average_or_product_operation, SummAggregationType, ResultType>;
|
||||
|
||||
const ColVecType * column = checkAndGetColumn<ColVecType>(&*mapped);
|
||||
|
||||
/// Constant case.
|
||||
if (!column)
|
||||
{
|
||||
const ColumnConst * column_const = checkAndGetColumnConst<ColVecType>(&*mapped);
|
||||
return false;
|
||||
|
||||
if (!column_const)
|
||||
return false;
|
||||
const auto & data = column->getData();
|
||||
|
||||
const AggregationType x = column_const->template getValue<Element>(); // NOLINT
|
||||
const typename ColVecType::Container & data
|
||||
= checkAndGetColumn<ColVecType>(&column_const->getDataColumn())->getData();
|
||||
|
||||
typename ColVecResult::MutablePtr res_column;
|
||||
if constexpr (IsDecimalNumber<Element>)
|
||||
{
|
||||
res_column = ColVecResult::create(offsets.size(), data.getScale());
|
||||
}
|
||||
else
|
||||
res_column = ColVecResult::create(offsets.size());
|
||||
|
||||
typename ColVecResult::Container & res = res_column->getData();
|
||||
|
||||
size_t pos = 0;
|
||||
for (size_t i = 0; i < offsets.size(); ++i)
|
||||
{
|
||||
if constexpr (aggregate_operation == AggregateOperation::sum)
|
||||
{
|
||||
size_t array_size = offsets[i] - pos;
|
||||
/// Just multiply the value by array size.
|
||||
res[i] = x * array_size;
|
||||
}
|
||||
else if constexpr (aggregate_operation == AggregateOperation::min ||
|
||||
aggregate_operation == AggregateOperation::max)
|
||||
{
|
||||
res[i] = x;
|
||||
}
|
||||
else if constexpr (aggregate_operation == AggregateOperation::average)
|
||||
{
|
||||
if constexpr (IsDecimalNumber<Element>)
|
||||
{
|
||||
res[i] = DecimalUtils::convertTo<Result>(x, data.getScale());
|
||||
}
|
||||
else
|
||||
{
|
||||
res[i] = x;
|
||||
}
|
||||
}
|
||||
else if constexpr (aggregate_operation == AggregateOperation::product)
|
||||
{
|
||||
size_t array_size = offsets[i] - pos;
|
||||
size_t num = array_size;
|
||||
AggregationType product = x;
|
||||
|
||||
if constexpr (IsDecimalNumber<Element>)
|
||||
{
|
||||
using T = decltype(x.value);
|
||||
T x_val = x.value;
|
||||
|
||||
for (i = 1; i < array_size; ++i)
|
||||
{
|
||||
T product_val = product.value;
|
||||
while (common::mulOverflow(x_val, product_val, product.value))
|
||||
{
|
||||
x_val = x_val / DecimalUtils::scaleMultiplier<T>(data.getScale());
|
||||
if (num == 1)
|
||||
throw Exception("arrayProduct for decimal type overflow", ErrorCodes::DECIMAL_OVERFLOW);
|
||||
--num;
|
||||
}
|
||||
}
|
||||
|
||||
res[i] = DecimalUtils::convertTo<Result>(product, data.getScale() * num);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (i = 1; i < array_size; ++i)
|
||||
product = product * x;
|
||||
res[i] = product;
|
||||
}
|
||||
}
|
||||
|
||||
pos = offsets[i];
|
||||
}
|
||||
|
||||
res_ptr = std::move(res_column);
|
||||
return true;
|
||||
}
|
||||
|
||||
const typename ColVecType::Container & data = column->getData();
|
||||
|
||||
typename ColVecResult::MutablePtr res_column;
|
||||
typename ColVecResultType::MutablePtr res_column;
|
||||
if constexpr (IsDecimalNumber<Element>)
|
||||
res_column = ColVecResult::create(offsets.size(), data.getScale());
|
||||
res_column = ColVecResultType::create(offsets.size(), data.getScale());
|
||||
else
|
||||
res_column = ColVecResult::create(offsets.size());
|
||||
res_column = ColVecResultType::create(offsets.size());
|
||||
|
||||
typename ColVecResult::Container & res = res_column->getData();
|
||||
typename ColVecResultType::Container & res = res_column->getData();
|
||||
|
||||
size_t pos = 0;
|
||||
for (size_t i = 0; i < offsets.size(); ++i)
|
||||
{
|
||||
AggregationType s = 0;
|
||||
AggregationType aggregate_value = 0;
|
||||
|
||||
/// Array is empty
|
||||
if (offsets[i] == pos)
|
||||
{
|
||||
res[i] = s;
|
||||
res[i] = aggregate_value;
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t count = 1;
|
||||
s = data[pos]; // NOLINT
|
||||
aggregate_value = data[pos]; // NOLINT
|
||||
++pos;
|
||||
|
||||
for (; pos < offsets[i]; ++pos)
|
||||
@ -266,40 +182,36 @@ struct ArrayAggregateImpl
|
||||
if constexpr (aggregate_operation == AggregateOperation::sum ||
|
||||
aggregate_operation == AggregateOperation::average)
|
||||
{
|
||||
s += element;
|
||||
aggregate_value += element;
|
||||
}
|
||||
else if constexpr (aggregate_operation == AggregateOperation::min)
|
||||
{
|
||||
if (element < s)
|
||||
if (element < aggregate_value)
|
||||
{
|
||||
s = element;
|
||||
aggregate_value = element;
|
||||
}
|
||||
}
|
||||
else if constexpr (aggregate_operation == AggregateOperation::max)
|
||||
{
|
||||
if (element > s)
|
||||
if (element > aggregate_value)
|
||||
{
|
||||
s = element;
|
||||
aggregate_value = element;
|
||||
}
|
||||
}
|
||||
else if constexpr (aggregate_operation == AggregateOperation::product)
|
||||
{
|
||||
if constexpr (IsDecimalNumber<Element>)
|
||||
{
|
||||
using T = decltype(s.value);
|
||||
T s_val = s.value;
|
||||
T element_val = element.value;
|
||||
while (common::mulOverflow(s_val, element_val, s.value))
|
||||
{
|
||||
s_val = s_val / DecimalUtils::scaleMultiplier<T>(data.getScale());
|
||||
if (count == 0)
|
||||
throw Exception("arrayProduct for decimal type overflow", ErrorCodes::DECIMAL_OVERFLOW);
|
||||
--count;
|
||||
}
|
||||
using AggregateValueDecimalUnderlyingValue = decltype(aggregate_value.value);
|
||||
AggregateValueDecimalUnderlyingValue current_aggregate_value = aggregate_value.value;
|
||||
AggregateValueDecimalUnderlyingValue element_value = static_cast<AggregateValueDecimalUnderlyingValue>(element.value);
|
||||
|
||||
if (common::mulOverflow(current_aggregate_value, element_value, aggregate_value.value))
|
||||
throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal math overflow");
|
||||
}
|
||||
else
|
||||
{
|
||||
s *= element;
|
||||
aggregate_value *= element;
|
||||
}
|
||||
}
|
||||
|
||||
@ -310,21 +222,26 @@ struct ArrayAggregateImpl
|
||||
{
|
||||
if constexpr (IsDecimalNumber<Element>)
|
||||
{
|
||||
s = s / count;
|
||||
res[i] = DecimalUtils::convertTo<Result>(s, data.getScale());
|
||||
aggregate_value = aggregate_value / count;
|
||||
res[i] = DecimalUtils::convertTo<ResultType>(aggregate_value, data.getScale());
|
||||
}
|
||||
else
|
||||
{
|
||||
res[i] = static_cast<Result>(s) / count;
|
||||
res[i] = static_cast<ResultType>(aggregate_value) / count;
|
||||
}
|
||||
}
|
||||
else if constexpr (aggregate_operation == AggregateOperation::product && IsDecimalNumber<Element>)
|
||||
{
|
||||
res[i] = DecimalUtils::convertTo<Result>(s, data.getScale() * count);
|
||||
auto result_scale = data.getScale() * count;
|
||||
|
||||
if (unlikely(result_scale > DecimalUtils::max_precision<AggregationType>))
|
||||
throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Scale {} is out of bounds", result_scale);
|
||||
|
||||
res[i] = DecimalUtils::convertTo<ResultType>(aggregate_value, result_scale);
|
||||
}
|
||||
else
|
||||
{
|
||||
res[i] = s;
|
||||
res[i] = aggregate_value;
|
||||
}
|
||||
}
|
||||
|
||||
@ -355,7 +272,7 @@ struct ArrayAggregateImpl
|
||||
executeType<Decimal128>(mapped, offsets, res))
|
||||
return res;
|
||||
else
|
||||
throw Exception("Unexpected column for arraySum: " + mapped->getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Unexpected column for arraySum: {}" + mapped->getName());
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1,11 +1,13 @@
|
||||
Array product 720
|
||||
28.799999999999997 Float64
|
||||
Array product with constant column
|
||||
720 Float64
|
||||
24 Float64
|
||||
3.5 Float64
|
||||
20 Float64
|
||||
720
|
||||
6 Float64
|
||||
Array product with non constant column
|
||||
24
|
||||
0
|
||||
6
|
||||
720
|
||||
24
|
||||
0
|
||||
6
|
||||
Types of aggregation result array product
|
||||
|
@ -1,16 +1,20 @@
|
||||
SELECT 'Array product ', (arrayProduct(array(1,2,3,4,5,6)));
|
||||
select arrayProduct(array(1.0,2.0,3.0,4.8)) as k , toTypeName(k);
|
||||
select arrayProduct(array(1,3.5)) as k , toTypeName(k);
|
||||
SELECT arrayProduct([toDecimal32(2, 8), toDecimal32(10, 8)]) as a , toTypeName(a);
|
||||
SELECT 'Array product with constant column';
|
||||
|
||||
SELECT arrayProduct([1,2,3,4,5,6]) as a, toTypeName(a);
|
||||
SELECT arrayProduct(array(1.0,2.0,3.0,4.0)) as a, toTypeName(a);
|
||||
SELECT arrayProduct(array(1,3.5)) as a, toTypeName(a);
|
||||
SELECT arrayProduct([toDecimal64(1,8), toDecimal64(2,8), toDecimal64(3,8)]) as a, toTypeName(a);
|
||||
|
||||
SELECT 'Array product with non constant column';
|
||||
|
||||
DROP TABLE IF EXISTS test_aggregation;
|
||||
CREATE TABLE test_aggregation (x Array(Int)) ENGINE=TinyLog;
|
||||
INSERT INTO test_aggregation VALUES ([1,2,3,4,5,6]), ([]), ([1,2,3]);
|
||||
INSERT INTO test_aggregation VALUES ([1,2,3,4]), ([]), ([1,2,3]);
|
||||
SELECT arrayProduct(x) FROM test_aggregation;
|
||||
DROP TABLE test_aggregation;
|
||||
|
||||
CREATE TABLE test_aggregation (x Array(Decimal64(8))) ENGINE=TinyLog;
|
||||
INSERT INTO test_aggregation VALUES ([1,2,3,4,5,6]), ([]), ([1,2,3]);
|
||||
INSERT INTO test_aggregation VALUES ([1,2,3,4]), ([]), ([1,2,3]);
|
||||
SELECT arrayProduct(x) FROM test_aggregation;
|
||||
DROP TABLE test_aggregation;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user