mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 23:21:59 +00:00
Merge pull request #23782 from ClickHouse/merging-array-product-function
Merging array product function
This commit is contained in:
commit
fa1e9de7f7
@ -15,6 +15,8 @@ namespace ErrorCodes
|
|||||||
{
|
{
|
||||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||||
extern const int ILLEGAL_COLUMN;
|
extern const int ILLEGAL_COLUMN;
|
||||||
|
extern const int DECIMAL_OVERFLOW;
|
||||||
|
extern const int ARGUMENT_OUT_OF_BOUND;
|
||||||
}
|
}
|
||||||
|
|
||||||
enum class AggregateOperation
|
enum class AggregateOperation
|
||||||
@ -22,7 +24,8 @@ enum class AggregateOperation
|
|||||||
min,
|
min,
|
||||||
max,
|
max,
|
||||||
sum,
|
sum,
|
||||||
average
|
average,
|
||||||
|
product
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -54,6 +57,12 @@ struct ArrayAggregateResultImpl<ArrayElement, AggregateOperation::average>
|
|||||||
using Result = Float64;
|
using Result = Float64;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename ArrayElement>
|
||||||
|
struct ArrayAggregateResultImpl<ArrayElement, AggregateOperation::product>
|
||||||
|
{
|
||||||
|
using Result = Float64;
|
||||||
|
};
|
||||||
|
|
||||||
template <typename ArrayElement>
|
template <typename ArrayElement>
|
||||||
struct ArrayAggregateResultImpl<ArrayElement, AggregateOperation::sum>
|
struct ArrayAggregateResultImpl<ArrayElement, AggregateOperation::sum>
|
||||||
{
|
{
|
||||||
@ -86,7 +95,7 @@ struct ArrayAggregateImpl
|
|||||||
using Types = std::decay_t<decltype(types)>;
|
using Types = std::decay_t<decltype(types)>;
|
||||||
using DataType = typename Types::LeftType;
|
using DataType = typename Types::LeftType;
|
||||||
|
|
||||||
if constexpr (aggregate_operation == AggregateOperation::average)
|
if constexpr (aggregate_operation == AggregateOperation::average || aggregate_operation == AggregateOperation::product)
|
||||||
{
|
{
|
||||||
result = std::make_shared<DataTypeFloat64>();
|
result = std::make_shared<DataTypeFloat64>();
|
||||||
|
|
||||||
@ -124,17 +133,17 @@ struct ArrayAggregateImpl
|
|||||||
template <typename Element>
|
template <typename Element>
|
||||||
static NO_SANITIZE_UNDEFINED bool executeType(const ColumnPtr & mapped, const ColumnArray::Offsets & offsets, ColumnPtr & res_ptr)
|
static NO_SANITIZE_UNDEFINED bool executeType(const ColumnPtr & mapped, const ColumnArray::Offsets & offsets, ColumnPtr & res_ptr)
|
||||||
{
|
{
|
||||||
using Result = ArrayAggregateResult<Element, aggregate_operation>;
|
using ResultType = ArrayAggregateResult<Element, aggregate_operation>;
|
||||||
using ColVecType = std::conditional_t<IsDecimalNumber<Element>, ColumnDecimal<Element>, ColumnVector<Element>>;
|
using ColVecType = std::conditional_t<IsDecimalNumber<Element>, ColumnDecimal<Element>, ColumnVector<Element>>;
|
||||||
using ColVecResult = std::conditional_t<IsDecimalNumber<Result>, ColumnDecimal<Result>, ColumnVector<Result>>;
|
using ColVecResultType = std::conditional_t<IsDecimalNumber<ResultType>, ColumnDecimal<ResultType>, ColumnVector<ResultType>>;
|
||||||
|
|
||||||
/// For average of array we return Float64 as result, but we want to keep precision
|
/// For average and product of array we return Float64 as result, but we want to keep precision
|
||||||
/// so we convert to Float64 as last step, but intermediate sum is represented as result of sum operation
|
/// so we convert to Float64 as last step, but intermediate value is represented as result of sum operation
|
||||||
static constexpr bool is_average_operation = aggregate_operation == AggregateOperation::average;
|
static constexpr bool is_average_or_product_operation = aggregate_operation == AggregateOperation::average ||
|
||||||
|
aggregate_operation == AggregateOperation::product;
|
||||||
using SummAggregationType = ArrayAggregateResult<Element, AggregateOperation::sum>;
|
using SummAggregationType = ArrayAggregateResult<Element, AggregateOperation::sum>;
|
||||||
|
|
||||||
using AggregationType = std::conditional_t<is_average_operation, SummAggregationType, Result>;
|
using AggregationType = std::conditional_t<is_average_or_product_operation, SummAggregationType, ResultType>;
|
||||||
|
|
||||||
|
|
||||||
const ColVecType * column = checkAndGetColumn<ColVecType>(&*mapped);
|
const ColVecType * column = checkAndGetColumn<ColVecType>(&*mapped);
|
||||||
|
|
||||||
@ -147,18 +156,15 @@ struct ArrayAggregateImpl
|
|||||||
return false;
|
return false;
|
||||||
|
|
||||||
const AggregationType x = column_const->template getValue<Element>(); // NOLINT
|
const AggregationType x = column_const->template getValue<Element>(); // NOLINT
|
||||||
const typename ColVecType::Container & data
|
const auto & data = checkAndGetColumn<ColVecType>(&column_const->getDataColumn())->getData();
|
||||||
= checkAndGetColumn<ColVecType>(&column_const->getDataColumn())->getData();
|
|
||||||
|
|
||||||
typename ColVecResult::MutablePtr res_column;
|
typename ColVecResultType::MutablePtr res_column;
|
||||||
if constexpr (IsDecimalNumber<Element>)
|
if constexpr (IsDecimalNumber<Element>)
|
||||||
{
|
res_column = ColVecResultType::create(offsets.size(), data.getScale());
|
||||||
res_column = ColVecResult::create(offsets.size(), data.getScale());
|
|
||||||
}
|
|
||||||
else
|
else
|
||||||
res_column = ColVecResult::create(offsets.size());
|
res_column = ColVecResultType::create(offsets.size());
|
||||||
|
|
||||||
typename ColVecResult::Container & res = res_column->getData();
|
auto & res = res_column->getData();
|
||||||
|
|
||||||
size_t pos = 0;
|
size_t pos = 0;
|
||||||
for (size_t i = 0; i < offsets.size(); ++i)
|
for (size_t i = 0; i < offsets.size(); ++i)
|
||||||
@ -178,13 +184,45 @@ struct ArrayAggregateImpl
|
|||||||
{
|
{
|
||||||
if constexpr (IsDecimalNumber<Element>)
|
if constexpr (IsDecimalNumber<Element>)
|
||||||
{
|
{
|
||||||
res[i] = DecimalUtils::convertTo<Result>(x, data.getScale());
|
res[i] = DecimalUtils::convertTo<ResultType>(x, data.getScale());
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
res[i] = x;
|
res[i] = x;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
else if constexpr (aggregate_operation == AggregateOperation::product)
|
||||||
|
{
|
||||||
|
size_t array_size = offsets[i] - pos;
|
||||||
|
AggregationType product = x;
|
||||||
|
|
||||||
|
if constexpr (IsDecimalNumber<Element>)
|
||||||
|
{
|
||||||
|
using T = decltype(x.value);
|
||||||
|
T x_val = x.value;
|
||||||
|
|
||||||
|
for (size_t array_index = 1; array_index < array_size; ++array_index)
|
||||||
|
{
|
||||||
|
T product_val = product.value;
|
||||||
|
|
||||||
|
if (common::mulOverflow(x_val, product_val, product.value))
|
||||||
|
throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal math overflow");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto result_scale = data.getScale() * array_size;
|
||||||
|
if (unlikely(result_scale > DecimalUtils::max_precision<AggregationType>))
|
||||||
|
throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Scale {} is out of bounds", result_scale);
|
||||||
|
|
||||||
|
res[i] = DecimalUtils::convertTo<ResultType>(product, data.getScale() * array_size);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
for (size_t array_index = 1; array_index < array_size; ++array_index)
|
||||||
|
product = product * x;
|
||||||
|
|
||||||
|
res[i] = product;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pos = offsets[i];
|
pos = offsets[i];
|
||||||
}
|
}
|
||||||
@ -193,30 +231,30 @@ struct ArrayAggregateImpl
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
const typename ColVecType::Container & data = column->getData();
|
const auto & data = column->getData();
|
||||||
|
|
||||||
typename ColVecResult::MutablePtr res_column;
|
typename ColVecResultType::MutablePtr res_column;
|
||||||
if constexpr (IsDecimalNumber<Element>)
|
if constexpr (IsDecimalNumber<Element>)
|
||||||
res_column = ColVecResult::create(offsets.size(), data.getScale());
|
res_column = ColVecResultType::create(offsets.size(), data.getScale());
|
||||||
else
|
else
|
||||||
res_column = ColVecResult::create(offsets.size());
|
res_column = ColVecResultType::create(offsets.size());
|
||||||
|
|
||||||
typename ColVecResult::Container & res = res_column->getData();
|
typename ColVecResultType::Container & res = res_column->getData();
|
||||||
|
|
||||||
size_t pos = 0;
|
size_t pos = 0;
|
||||||
for (size_t i = 0; i < offsets.size(); ++i)
|
for (size_t i = 0; i < offsets.size(); ++i)
|
||||||
{
|
{
|
||||||
AggregationType s = 0;
|
AggregationType aggregate_value = 0;
|
||||||
|
|
||||||
/// Array is empty
|
/// Array is empty
|
||||||
if (offsets[i] == pos)
|
if (offsets[i] == pos)
|
||||||
{
|
{
|
||||||
res[i] = s;
|
res[i] = aggregate_value;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t count = 1;
|
size_t count = 1;
|
||||||
s = data[pos]; // NOLINT
|
aggregate_value = data[pos]; // NOLINT
|
||||||
++pos;
|
++pos;
|
||||||
|
|
||||||
for (; pos < offsets[i]; ++pos)
|
for (; pos < offsets[i]; ++pos)
|
||||||
@ -226,20 +264,36 @@ struct ArrayAggregateImpl
|
|||||||
if constexpr (aggregate_operation == AggregateOperation::sum ||
|
if constexpr (aggregate_operation == AggregateOperation::sum ||
|
||||||
aggregate_operation == AggregateOperation::average)
|
aggregate_operation == AggregateOperation::average)
|
||||||
{
|
{
|
||||||
s += element;
|
aggregate_value += element;
|
||||||
}
|
}
|
||||||
else if constexpr (aggregate_operation == AggregateOperation::min)
|
else if constexpr (aggregate_operation == AggregateOperation::min)
|
||||||
{
|
{
|
||||||
if (element < s)
|
if (element < aggregate_value)
|
||||||
{
|
{
|
||||||
s = element;
|
aggregate_value = element;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if constexpr (aggregate_operation == AggregateOperation::max)
|
else if constexpr (aggregate_operation == AggregateOperation::max)
|
||||||
{
|
{
|
||||||
if (element > s)
|
if (element > aggregate_value)
|
||||||
{
|
{
|
||||||
s = element;
|
aggregate_value = element;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if constexpr (aggregate_operation == AggregateOperation::product)
|
||||||
|
{
|
||||||
|
if constexpr (IsDecimalNumber<Element>)
|
||||||
|
{
|
||||||
|
using AggregateValueDecimalUnderlyingValue = decltype(aggregate_value.value);
|
||||||
|
AggregateValueDecimalUnderlyingValue current_aggregate_value = aggregate_value.value;
|
||||||
|
AggregateValueDecimalUnderlyingValue element_value = static_cast<AggregateValueDecimalUnderlyingValue>(element.value);
|
||||||
|
|
||||||
|
if (common::mulOverflow(current_aggregate_value, element_value, aggregate_value.value))
|
||||||
|
throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal math overflow");
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
aggregate_value *= element;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -250,17 +304,26 @@ struct ArrayAggregateImpl
|
|||||||
{
|
{
|
||||||
if constexpr (IsDecimalNumber<Element>)
|
if constexpr (IsDecimalNumber<Element>)
|
||||||
{
|
{
|
||||||
s = s / count;
|
aggregate_value = aggregate_value / count;
|
||||||
res[i] = DecimalUtils::convertTo<Result>(s, data.getScale());
|
res[i] = DecimalUtils::convertTo<ResultType>(aggregate_value, data.getScale());
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
res[i] = static_cast<Result>(s) / count;
|
res[i] = static_cast<ResultType>(aggregate_value) / count;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
else if constexpr (aggregate_operation == AggregateOperation::product && IsDecimalNumber<Element>)
|
||||||
|
{
|
||||||
|
auto result_scale = data.getScale() * count;
|
||||||
|
|
||||||
|
if (unlikely(result_scale > DecimalUtils::max_precision<AggregationType>))
|
||||||
|
throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Scale {} is out of bounds", result_scale);
|
||||||
|
|
||||||
|
res[i] = DecimalUtils::convertTo<ResultType>(aggregate_value, result_scale);
|
||||||
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
res[i] = s;
|
res[i] = aggregate_value;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -291,7 +354,7 @@ struct ArrayAggregateImpl
|
|||||||
executeType<Decimal128>(mapped, offsets, res))
|
executeType<Decimal128>(mapped, offsets, res))
|
||||||
return res;
|
return res;
|
||||||
else
|
else
|
||||||
throw Exception("Unexpected column for arraySum: " + mapped->getName(), ErrorCodes::ILLEGAL_COLUMN);
|
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Unexpected column for arraySum: {}" + mapped->getName());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -307,12 +370,16 @@ using FunctionArraySum = FunctionArrayMapped<ArrayAggregateImpl<AggregateOperati
|
|||||||
struct NameArrayAverage { static constexpr auto name = "arrayAvg"; };
|
struct NameArrayAverage { static constexpr auto name = "arrayAvg"; };
|
||||||
using FunctionArrayAverage = FunctionArrayMapped<ArrayAggregateImpl<AggregateOperation::average>, NameArrayAverage>;
|
using FunctionArrayAverage = FunctionArrayMapped<ArrayAggregateImpl<AggregateOperation::average>, NameArrayAverage>;
|
||||||
|
|
||||||
|
struct NameArrayProduct { static constexpr auto name = "arrayProduct"; };
|
||||||
|
using FunctionArrayProduct = FunctionArrayMapped<ArrayAggregateImpl<AggregateOperation::product>, NameArrayProduct>;
|
||||||
|
|
||||||
void registerFunctionArrayAggregation(FunctionFactory & factory)
|
void registerFunctionArrayAggregation(FunctionFactory & factory)
|
||||||
{
|
{
|
||||||
factory.registerFunction<FunctionArrayMin>();
|
factory.registerFunction<FunctionArrayMin>();
|
||||||
factory.registerFunction<FunctionArrayMax>();
|
factory.registerFunction<FunctionArrayMax>();
|
||||||
factory.registerFunction<FunctionArraySum>();
|
factory.registerFunction<FunctionArraySum>();
|
||||||
factory.registerFunction<FunctionArrayAverage>();
|
factory.registerFunction<FunctionArrayAverage>();
|
||||||
|
factory.registerFunction<FunctionArrayProduct>();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
18
tests/queries/0_stateless/01768_array_product.reference
Normal file
18
tests/queries/0_stateless/01768_array_product.reference
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
Array product with constant column
|
||||||
|
720 Float64
|
||||||
|
24 Float64
|
||||||
|
3.5 Float64
|
||||||
|
6 Float64
|
||||||
|
Array product with non constant column
|
||||||
|
24
|
||||||
|
0
|
||||||
|
6
|
||||||
|
24
|
||||||
|
0
|
||||||
|
6
|
||||||
|
Types of aggregation result array product
|
||||||
|
Float64 Float64 Float64 Float64
|
||||||
|
Float64 Float64 Float64 Float64
|
||||||
|
Float64 Float64 Float64
|
||||||
|
Float64 Float64
|
||||||
|
Float64 Float64 Float64
|
26
tests/queries/0_stateless/01768_array_product.sql
Normal file
26
tests/queries/0_stateless/01768_array_product.sql
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
SELECT 'Array product with constant column';
|
||||||
|
|
||||||
|
SELECT arrayProduct([1,2,3,4,5,6]) as a, toTypeName(a);
|
||||||
|
SELECT arrayProduct(array(1.0,2.0,3.0,4.0)) as a, toTypeName(a);
|
||||||
|
SELECT arrayProduct(array(1,3.5)) as a, toTypeName(a);
|
||||||
|
SELECT arrayProduct([toDecimal64(1,8), toDecimal64(2,8), toDecimal64(3,8)]) as a, toTypeName(a);
|
||||||
|
|
||||||
|
SELECT 'Array product with non constant column';
|
||||||
|
|
||||||
|
DROP TABLE IF EXISTS test_aggregation;
|
||||||
|
CREATE TABLE test_aggregation (x Array(Int)) ENGINE=TinyLog;
|
||||||
|
INSERT INTO test_aggregation VALUES ([1,2,3,4]), ([]), ([1,2,3]);
|
||||||
|
SELECT arrayProduct(x) FROM test_aggregation;
|
||||||
|
DROP TABLE test_aggregation;
|
||||||
|
|
||||||
|
CREATE TABLE test_aggregation (x Array(Decimal64(8))) ENGINE=TinyLog;
|
||||||
|
INSERT INTO test_aggregation VALUES ([1,2,3,4]), ([]), ([1,2,3]);
|
||||||
|
SELECT arrayProduct(x) FROM test_aggregation;
|
||||||
|
DROP TABLE test_aggregation;
|
||||||
|
|
||||||
|
SELECT 'Types of aggregation result array product';
|
||||||
|
SELECT toTypeName(arrayProduct([toInt8(0)])), toTypeName(arrayProduct([toInt16(0)])), toTypeName(arrayProduct([toInt32(0)])), toTypeName(arrayProduct([toInt64(0)]));
|
||||||
|
SELECT toTypeName(arrayProduct([toUInt8(0)])), toTypeName(arrayProduct([toUInt16(0)])), toTypeName(arrayProduct([toUInt32(0)])), toTypeName(arrayProduct([toUInt64(0)]));
|
||||||
|
SELECT toTypeName(arrayProduct([toInt128(0)])), toTypeName(arrayProduct([toInt256(0)])), toTypeName(arrayProduct([toUInt256(0)]));
|
||||||
|
SELECT toTypeName(arrayProduct([toFloat32(0)])), toTypeName(arrayProduct([toFloat64(0)]));
|
||||||
|
SELECT toTypeName(arrayProduct([toDecimal32(0, 8)])), toTypeName(arrayProduct([toDecimal64(0, 8)])), toTypeName(arrayProduct([toDecimal128(0, 8)]));
|
Loading…
Reference in New Issue
Block a user