mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-09-20 16:50:48 +00:00
Rewrite arrayProduct implementation
This commit is contained in:
parent
8a7599c4b1
commit
38bd455be9
@ -22,7 +22,8 @@ enum class AggregateOperation
|
||||
min,
|
||||
max,
|
||||
sum,
|
||||
average
|
||||
average,
|
||||
product
|
||||
};
|
||||
|
||||
/**
|
||||
@ -54,6 +55,12 @@ struct ArrayAggregateResultImpl<ArrayElement, AggregateOperation::average>
|
||||
using Result = Float64;
|
||||
};
|
||||
|
||||
template <typename ArrayElement>
|
||||
struct ArrayAggregateResultImpl<ArrayElement, AggregateOperation::product>
|
||||
{
|
||||
using Result = Float64;
|
||||
};
|
||||
|
||||
template <typename ArrayElement>
|
||||
struct ArrayAggregateResultImpl<ArrayElement, AggregateOperation::sum>
|
||||
{
|
||||
@ -86,7 +93,7 @@ struct ArrayAggregateImpl
|
||||
using Types = std::decay_t<decltype(types)>;
|
||||
using DataType = typename Types::LeftType;
|
||||
|
||||
if constexpr (aggregate_operation == AggregateOperation::average)
|
||||
if constexpr (aggregate_operation == AggregateOperation::average || aggregate_operation == AggregateOperation::product)
|
||||
{
|
||||
result = std::make_shared<DataTypeFloat64>();
|
||||
|
||||
@ -128,13 +135,13 @@ struct ArrayAggregateImpl
|
||||
using ColVecType = std::conditional_t<IsDecimalNumber<Element>, ColumnDecimal<Element>, ColumnVector<Element>>;
|
||||
using ColVecResult = std::conditional_t<IsDecimalNumber<Result>, ColumnDecimal<Result>, ColumnVector<Result>>;
|
||||
|
||||
/// For average of array we return Float64 as result, but we want to keep precision
|
||||
/// For average and product of array we return Float64 as result, but we want to keep precision
|
||||
/// so we convert to Float64 as last step, but intermediate sum is represented as result of sum operation
|
||||
static constexpr bool is_average_operation = aggregate_operation == AggregateOperation::average;
|
||||
static constexpr bool is_average_or_product_operation = aggregate_operation == AggregateOperation::average ||
|
||||
aggregate_operation == AggregateOperation::product;
|
||||
using SummAggregationType = ArrayAggregateResult<Element, AggregateOperation::sum>;
|
||||
|
||||
using AggregationType = std::conditional_t<is_average_operation, SummAggregationType, Result>;
|
||||
|
||||
using AggregationType = std::conditional_t<is_average_or_product_operation, SummAggregationType, Result>;
|
||||
|
||||
const ColVecType * column = checkAndGetColumn<ColVecType>(&*mapped);
|
||||
|
||||
@ -185,6 +192,22 @@ struct ArrayAggregateImpl
|
||||
res[i] = x;
|
||||
}
|
||||
}
|
||||
else if constexpr (aggregate_operation == AggregateOperation::product)
|
||||
{
|
||||
size_t array_size = offsets[i] - pos;
|
||||
AggregationType product = 1;
|
||||
for (i = 0; i < array_size; ++i)
|
||||
product = product * x;
|
||||
|
||||
if constexpr (IsDecimalNumber<Element>)
|
||||
{
|
||||
res[i] = DecimalUtils::convertTo<Result>(product, data.getScale() * array_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
res[i] = product;
|
||||
}
|
||||
}
|
||||
|
||||
pos = offsets[i];
|
||||
}
|
||||
@ -242,6 +265,10 @@ struct ArrayAggregateImpl
|
||||
s = element;
|
||||
}
|
||||
}
|
||||
else if constexpr (aggregate_operation == AggregateOperation::product)
|
||||
{
|
||||
s *= element;
|
||||
}
|
||||
|
||||
++count;
|
||||
}
|
||||
@ -258,6 +285,10 @@ struct ArrayAggregateImpl
|
||||
res[i] = static_cast<Result>(s) / count;
|
||||
}
|
||||
}
|
||||
else if constexpr (aggregate_operation == AggregateOperation::product && IsDecimalNumber<Element>)
|
||||
{
|
||||
res[i] = DecimalUtils::convertTo<Result>(s, data.getScale() * count);
|
||||
}
|
||||
else
|
||||
{
|
||||
res[i] = s;
|
||||
@ -307,12 +338,16 @@ using FunctionArraySum = FunctionArrayMapped<ArrayAggregateImpl<AggregateOperati
|
||||
struct NameArrayAverage { static constexpr auto name = "arrayAvg"; };
|
||||
using FunctionArrayAverage = FunctionArrayMapped<ArrayAggregateImpl<AggregateOperation::average>, NameArrayAverage>;
|
||||
|
||||
struct NameArrayProduct { static constexpr auto name = "arrayProduct"; };
|
||||
using FunctionArrayProduct = FunctionArrayMapped<ArrayAggregateImpl<AggregateOperation::product>, NameArrayProduct>;
|
||||
|
||||
void registerFunctionArrayAggregation(FunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction<FunctionArrayMin>();
|
||||
factory.registerFunction<FunctionArrayMax>();
|
||||
factory.registerFunction<FunctionArraySum>();
|
||||
factory.registerFunction<FunctionArrayAverage>();
|
||||
factory.registerFunction<FunctionArrayProduct>();
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,163 +0,0 @@
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnDecimal.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypesDecimal.h>
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int ILLEGAL_COLUMN;
|
||||
}
|
||||
|
||||
/// arrayProduct([1,2,3])=6
|
||||
class FunctionArrayProduct : public IFunction
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = "arrayProduct";
|
||||
static FunctionPtr create(const Context &) { return std::make_shared<FunctionArrayProduct>(); }
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return name;
|
||||
}
|
||||
|
||||
bool isVariadic() const override { return true; }
|
||||
size_t getNumberOfArguments() const override { return 1; }
|
||||
bool useDefaultImplementationForConstants() const override { return true; }
|
||||
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override;
|
||||
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
static ColumnPtr executeProductDecimal(const IColumn & col_from, const IColumn::Offsets & offsets);
|
||||
|
||||
template <typename T>
|
||||
static ColumnPtr executeProduct(const IColumn & src_data);
|
||||
};
|
||||
|
||||
DataTypePtr FunctionArrayProduct::getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const
|
||||
{
|
||||
if (arguments.size() != 1)
|
||||
throw Exception(
|
||||
"Function " + getName() + " needs one argument; passed " + toString(arguments.size()) + ".",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(arguments[0].type.get());
|
||||
if (!array_type)
|
||||
throw Exception(
|
||||
"Argument 0 of function " + getName() + " must be array. Found " + arguments[0].type->getName() + " instead.",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
const DataTypePtr & nest_type = array_type->getNestedType();
|
||||
const WhichDataType type(nest_type);
|
||||
if (!(type.isInt() || type.isUInt() || type.isFloat() || type.isDecimal()))
|
||||
throw Exception(
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"First argument for function '{}' must be an array of integer/float/decimal type, got '{}' instead",
|
||||
getName(),
|
||||
nest_type->getName());
|
||||
if (type.isDecimal())
|
||||
{
|
||||
UInt32 scale = getDecimalScale(*nest_type);
|
||||
return std::make_shared<DataTypeDecimal<Decimal128>>(DecimalUtils::max_precision<Decimal128>, scale);
|
||||
}
|
||||
|
||||
return nest_type;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ColumnPtr FunctionArrayProduct::executeProduct(const IColumn & col_from)
|
||||
{
|
||||
const ColumnVector<T> * col_from_data = checkAndGetColumn<ColumnVector<T>>(&col_from);
|
||||
if (!col_from_data)
|
||||
return nullptr;
|
||||
|
||||
size_t size = col_from_data->size();
|
||||
const PaddedPODArray<T> & values = col_from_data->getData();
|
||||
|
||||
auto col_res = ColumnVector<T>::create();
|
||||
typename ColumnVector<T>::Container & res = col_res->getData();
|
||||
if (size > 0)
|
||||
{
|
||||
res.resize(1);
|
||||
res[0] = values[0];
|
||||
for (size_t i = 1; i < size; i++)
|
||||
res[0] *= values[i];
|
||||
}
|
||||
return col_res;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ColumnPtr FunctionArrayProduct::executeProductDecimal(const IColumn & col_from, const IColumn::Offsets & offsets)
|
||||
{
|
||||
const ColumnDecimal<T> * col_from_data = checkAndGetColumn<ColumnDecimal<T>>(&col_from);
|
||||
if (!col_from_data)
|
||||
return nullptr;
|
||||
|
||||
size_t size = col_from_data->size();
|
||||
auto & values = col_from_data->getData();
|
||||
|
||||
UInt32 scale = (offsets.size() == 0) ? col_from_data->getScale() : (col_from_data->getScale() * size);
|
||||
auto col_res = ColumnDecimal<Decimal128>::create(offsets.size(), scale);
|
||||
typename ColumnDecimal<Decimal128>::Container & res = col_res->getData();
|
||||
|
||||
if (size > 0)
|
||||
{
|
||||
res.resize(1);
|
||||
res[0] = values[0];
|
||||
for (size_t i = 1; i < size; i++)
|
||||
res[0] *= values[i];
|
||||
}
|
||||
|
||||
return col_res;
|
||||
}
|
||||
|
||||
|
||||
ColumnPtr FunctionArrayProduct::executeImpl(
|
||||
const ColumnsWithTypeAndName & arguments, const DataTypePtr & /*result_type*/, size_t /*input_rows_count*/) const
|
||||
{
|
||||
const ColumnArray * column_array = checkAndGetColumn<ColumnArray>(arguments[0].column.get());
|
||||
|
||||
if (!column_array)
|
||||
throw Exception("Argument 0 of function " + getName() + " must be array", ErrorCodes::ILLEGAL_COLUMN);
|
||||
|
||||
const IColumn & src_data = column_array->getData();
|
||||
const IColumn::Offsets & offsets = column_array->getOffsets();
|
||||
|
||||
ColumnPtr res;
|
||||
if (!((res = executeProduct<UInt8>(src_data))
|
||||
|| (res = executeProduct<UInt16>(src_data))
|
||||
|| (res = executeProduct<UInt32>(src_data))
|
||||
|| (res = executeProduct<UInt64>(src_data))
|
||||
|| (res = executeProduct<Int8>(src_data))
|
||||
|| (res = executeProduct<Int16>(src_data))
|
||||
|| (res = executeProduct<Int32>(src_data))
|
||||
|| (res = executeProduct<Int64>(src_data))
|
||||
|| (res = executeProduct<Float32>(src_data))
|
||||
|| (res = executeProduct<Float64>(src_data))
|
||||
|| (res = executeProductDecimal<Decimal32>(src_data, offsets))
|
||||
|| (res = executeProductDecimal<Decimal64>(src_data, offsets))
|
||||
|| (res = executeProductDecimal<Decimal128>(src_data, offsets))))
|
||||
throw Exception(
|
||||
"Illegal column " + arguments[0].column->getName() + " of first argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
return res;
|
||||
}
|
||||
|
||||
void registerFunctionArrayProduct(FunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction<FunctionArrayProduct>();
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -37,7 +37,6 @@ void registerFunctionArrayAUC(FunctionFactory &);
|
||||
void registerFunctionArrayReduceInRanges(FunctionFactory &);
|
||||
void registerFunctionMapOp(FunctionFactory &);
|
||||
void registerFunctionMapPopulateSeries(FunctionFactory &);
|
||||
void registerFunctionArrayProduct(FunctionFactory &);
|
||||
|
||||
void registerFunctionsArray(FunctionFactory & factory)
|
||||
{
|
||||
@ -76,7 +75,6 @@ void registerFunctionsArray(FunctionFactory & factory)
|
||||
registerFunctionArrayAUC(factory);
|
||||
registerFunctionMapOp(factory);
|
||||
registerFunctionMapPopulateSeries(factory);
|
||||
registerFunctionArrayProduct(factory);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,2 +1,16 @@
|
||||
24
|
||||
27 Float64
|
||||
Array product 720
|
||||
28.799999999999997 Float64
|
||||
3.5 Float64
|
||||
20 Float64
|
||||
720
|
||||
0
|
||||
6
|
||||
15504.12105639681
|
||||
0
|
||||
6
|
||||
Types of aggregation result array product
|
||||
Float64 Float64 Float64 Float64
|
||||
Float64 Float64 Float64 Float64
|
||||
Float64 Float64 Float64
|
||||
Float64 Float64
|
||||
Float64 Float64 Float64
|
||||
|
@ -1,2 +1,22 @@
|
||||
select arrayProduct(array(1,2,3,4));
|
||||
select arrayProduct(array(1.0,2.0,3.0,4.5)) as k , toTypeName(k);
|
||||
SELECT 'Array product ', (arrayProduct(array(1,2,3,4,5,6)));
|
||||
select arrayProduct(array(1.0,2.0,3.0,4.8)) as k , toTypeName(k);
|
||||
select arrayProduct(array(1,3.5)) as k , toTypeName(k);
|
||||
SELECT arrayProduct([toDecimal32(2, 8), toDecimal32(10, 8)]) as a , toTypeName(a);
|
||||
|
||||
DROP TABLE IF EXISTS test_aggregation;
|
||||
CREATE TABLE test_aggregation (x Array(Int)) ENGINE=TinyLog;
|
||||
INSERT INTO test_aggregation VALUES ([1,2,3,4,5,6]), ([]), ([1,2,3]);
|
||||
SELECT arrayProduct(x) FROM test_aggregation;
|
||||
DROP TABLE test_aggregation;
|
||||
|
||||
CREATE TABLE test_aggregation (x Array(Decimal64(8))) ENGINE=TinyLog;
|
||||
INSERT INTO test_aggregation VALUES ([1,2,3,4,5,6]), ([]), ([1,2,3]);
|
||||
SELECT arrayProduct(x) FROM test_aggregation;
|
||||
DROP TABLE test_aggregation;
|
||||
|
||||
SELECT 'Types of aggregation result array product';
|
||||
SELECT toTypeName(arrayProduct([toInt8(0)])), toTypeName(arrayProduct([toInt16(0)])), toTypeName(arrayProduct([toInt32(0)])), toTypeName(arrayProduct([toInt64(0)]));
|
||||
SELECT toTypeName(arrayProduct([toUInt8(0)])), toTypeName(arrayProduct([toUInt16(0)])), toTypeName(arrayProduct([toUInt32(0)])), toTypeName(arrayProduct([toUInt64(0)]));
|
||||
SELECT toTypeName(arrayProduct([toInt128(0)])), toTypeName(arrayProduct([toInt256(0)])), toTypeName(arrayProduct([toUInt256(0)]));
|
||||
SELECT toTypeName(arrayProduct([toFloat32(0)])), toTypeName(arrayProduct([toFloat64(0)]));
|
||||
SELECT toTypeName(arrayProduct([toDecimal32(0, 8)])), toTypeName(arrayProduct([toDecimal64(0, 8)])), toTypeName(arrayProduct([toDecimal128(0, 8)]));
|
||||
|
Loading…
Reference in New Issue
Block a user