Rewrite arrayProduct implementation

This commit is contained in:
hexiaoting 2021-04-06 17:41:54 +08:00
parent 8a7599c4b1
commit 38bd455be9
5 changed files with 79 additions and 175 deletions

View File

@ -22,7 +22,8 @@ enum class AggregateOperation
min,
max,
sum,
average
average,
product
};
/**
@ -54,6 +55,12 @@ struct ArrayAggregateResultImpl<ArrayElement, AggregateOperation::average>
using Result = Float64;
};
/// Result type of the product aggregation: `Result` is Float64 for every array element type.
template <typename ArrayElement>
struct ArrayAggregateResultImpl<ArrayElement, AggregateOperation::product>
{
using Result = Float64;
};
template <typename ArrayElement>
struct ArrayAggregateResultImpl<ArrayElement, AggregateOperation::sum>
{
@ -86,7 +93,7 @@ struct ArrayAggregateImpl
using Types = std::decay_t<decltype(types)>;
using DataType = typename Types::LeftType;
if constexpr (aggregate_operation == AggregateOperation::average)
if constexpr (aggregate_operation == AggregateOperation::average || aggregate_operation == AggregateOperation::product)
{
result = std::make_shared<DataTypeFloat64>();
@ -128,13 +135,13 @@ struct ArrayAggregateImpl
using ColVecType = std::conditional_t<IsDecimalNumber<Element>, ColumnDecimal<Element>, ColumnVector<Element>>;
using ColVecResult = std::conditional_t<IsDecimalNumber<Result>, ColumnDecimal<Result>, ColumnVector<Result>>;
/// For average of array we return Float64 as result, but we want to keep precision
/// For average and product of array we return Float64 as result, but we want to keep precision
/// so we convert to Float64 as last step, but intermediate sum is represented as result of sum operation
static constexpr bool is_average_operation = aggregate_operation == AggregateOperation::average;
static constexpr bool is_average_or_product_operation = aggregate_operation == AggregateOperation::average ||
aggregate_operation == AggregateOperation::product;
using SummAggregationType = ArrayAggregateResult<Element, AggregateOperation::sum>;
using AggregationType = std::conditional_t<is_average_operation, SummAggregationType, Result>;
using AggregationType = std::conditional_t<is_average_or_product_operation, SummAggregationType, Result>;
const ColVecType * column = checkAndGetColumn<ColVecType>(&*mapped);
@ -185,6 +192,22 @@ struct ArrayAggregateImpl
res[i] = x;
}
}
else if constexpr (aggregate_operation == AggregateOperation::product)
{
    /// Number of elements in the current array: distance from `pos` to this row's end offset.
    const size_t array_size = offsets[i] - pos;

    /// `x` is multiplied in once per element (presumably `x` is the single value shared by all
    /// elements on this code path; confirm against the enclosing code).
    AggregationType product = 1;

    /// Use a dedicated counter here. Reusing `i` (which indexes `offsets` and `res`) as the loop
    /// variable would overwrite it, so the result would be written to `res[array_size]` and the
    /// following `pos = offsets[i]` would read the wrong row.
    for (size_t j = 0; j < array_size; ++j)
        product = product * x;

    if constexpr (IsDecimalNumber<Element>)
    {
        /// The accumulated value is converted with scale `data.getScale() * array_size`
        /// (presumably because every decimal multiplication adds the operand's scale).
        res[i] = DecimalUtils::convertTo<Result>(product, data.getScale() * array_size);
    }
    else
    {
        res[i] = product;
    }
}
pos = offsets[i];
}
@ -242,6 +265,10 @@ struct ArrayAggregateImpl
s = element;
}
}
else if constexpr (aggregate_operation == AggregateOperation::product)
{
s *= element;
}
++count;
}
@ -258,6 +285,10 @@ struct ArrayAggregateImpl
res[i] = static_cast<Result>(s) / count;
}
}
else if constexpr (aggregate_operation == AggregateOperation::product && IsDecimalNumber<Element>)
{
res[i] = DecimalUtils::convertTo<Result>(s, data.getScale() * count);
}
else
{
res[i] = s;
@ -307,12 +338,16 @@ using FunctionArraySum = FunctionArrayMapped<ArrayAggregateImpl<AggregateOperati
struct NameArrayAverage { static constexpr auto name = "arrayAvg"; };
using FunctionArrayAverage = FunctionArrayMapped<ArrayAggregateImpl<AggregateOperation::average>, NameArrayAverage>;
/// Name tag and alias for the `arrayProduct` function, built from ArrayAggregateImpl
/// instantiated with AggregateOperation::product.
struct NameArrayProduct { static constexpr auto name = "arrayProduct"; };
using FunctionArrayProduct = FunctionArrayMapped<ArrayAggregateImpl<AggregateOperation::product>, NameArrayProduct>;
/// Registers the array aggregation functions (FunctionArrayMin, FunctionArrayMax, FunctionArraySum,
/// FunctionArrayAverage and FunctionArrayProduct) in the given function factory.
void registerFunctionArrayAggregation(FunctionFactory & factory)
{
factory.registerFunction<FunctionArrayMin>();
factory.registerFunction<FunctionArrayMax>();
factory.registerFunction<FunctionArraySum>();
factory.registerFunction<FunctionArrayAverage>();
/// FunctionArrayProduct is the ArrayAggregateImpl<AggregateOperation::product> alias named "arrayProduct".
factory.registerFunction<FunctionArrayProduct>();
}
}

View File

@ -1,163 +0,0 @@
#include <Columns/ColumnTuple.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnDecimal.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <IO/WriteHelpers.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int ILLEGAL_COLUMN;
}
/** Function `arrayProduct`, e.g. arrayProduct([1,2,3]) = 6.
  *
  * Takes one array argument whose elements are integers, floats or decimals.
  * The heavy lifting is done by the two private per-type helpers declared at the bottom.
  */
class FunctionArrayProduct : public IFunction
{
public:
    static constexpr auto name = "arrayProduct";

    static FunctionPtr create(const Context &) { return std::make_shared<FunctionArrayProduct>(); }

    String getName() const override { return name; }

    /// Reported as variadic while also reporting one argument; the actual argument count
    /// is validated explicitly in getReturnTypeImpl.
    bool isVariadic() const override { return true; }
    size_t getNumberOfArguments() const override { return 1; }

    bool useDefaultImplementationForConstants() const override { return true; }

    /// Validates the argument and returns the result data type.
    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override;

    /// Dispatches over the supported nested column types and returns the computed column.
    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override;

private:
    /// Product over a ColumnVector<T>; returns nullptr if the column is of another type.
    template <typename T>
    static ColumnPtr executeProduct(const IColumn & src_data);

    /// Product over a ColumnDecimal<T>; returns nullptr if the column is of another type.
    template <typename T>
    static ColumnPtr executeProductDecimal(const IColumn & col_from, const IColumn::Offsets & offsets);
};
/// Checks that exactly one argument of array type was passed and derives the result type:
///  - array of Decimal: Decimal128 with maximum precision and the scale of the element type;
///  - array of Int/UInt/Float: the element type itself;
///  - anything else: ILLEGAL_TYPE_OF_ARGUMENT.
/// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH when the argument count is not one.
DataTypePtr FunctionArrayProduct::getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const
{
    const size_t passed = arguments.size();
    if (passed != 1)
        throw Exception(
            "Function " + getName() + " needs one argument; passed " + toString(passed) + ".",
            ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

    const auto & argument_type = arguments[0].type;
    const DataTypeArray * as_array = checkAndGetDataType<DataTypeArray>(argument_type.get());
    if (as_array == nullptr)
        throw Exception(
            "Argument 0 of function " + getName() + " must be array. Found " + argument_type->getName() + " instead.",
            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

    const DataTypePtr & element_type = as_array->getNestedType();
    const WhichDataType which(element_type);

    if (which.isDecimal())
    {
        /// Decimal products are always widened to Decimal128, keeping the element scale.
        const UInt32 scale = getDecimalScale(*element_type);
        return std::make_shared<DataTypeDecimal<Decimal128>>(DecimalUtils::max_precision<Decimal128>, scale);
    }

    if (which.isInt() || which.isUInt() || which.isFloat())
        return element_type;

    throw Exception(
        ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
        "First argument for function '{}' must be an array of integer/float/decimal type, got '{}' instead",
        getName(),
        element_type->getName());
}
/// Multiplies together all values of `col_from` when it is a ColumnVector<T>.
///
/// Returns nullptr if `col_from` is not a ColumnVector<T>. Otherwise returns a ColumnVector<T>
/// that is empty for an empty input and holds exactly one element (the product) for a non-empty one.
///
/// NOTE(review): no array offsets are passed in, so the values of every row are folded into the
/// same single result; confirm whether a per-array product was intended.
template <typename T>
ColumnPtr FunctionArrayProduct::executeProduct(const IColumn & col_from)
{
    const auto * typed_column = checkAndGetColumn<ColumnVector<T>>(&col_from);
    if (typed_column == nullptr)
        return nullptr;

    const PaddedPODArray<T> & source = typed_column->getData();
    const size_t total = typed_column->size();

    auto result_column = ColumnVector<T>::create();
    if (total == 0)
        return result_column;

    typename ColumnVector<T>::Container & out = result_column->getData();
    out.resize(1);

    /// Seed with the first value, then fold in the remaining ones.
    out[0] = source[0];
    for (size_t idx = 1; idx != total; ++idx)
        out[0] *= source[idx];

    return result_column;
}
/// Multiplies together all values of `col_from` when it is a ColumnDecimal<T>, accumulating in Decimal128.
///
/// Returns nullptr if `col_from` is not a ColumnDecimal<T>. The result column is created with
/// `offsets.size()` rows; its scale is the source scale when `offsets` is empty and
/// source scale * number of source values otherwise. For a non-empty source the column is then
/// resized to one element holding the product.
///
/// NOTE(review): `offsets` is used only for its size, so values of all rows are folded into one
/// result; confirm whether a per-array product was intended.
template <typename T>
ColumnPtr FunctionArrayProduct::executeProductDecimal(const IColumn & col_from, const IColumn::Offsets & offsets)
{
    const auto * decimal_column = checkAndGetColumn<ColumnDecimal<T>>(&col_from);
    if (decimal_column == nullptr)
        return nullptr;

    const auto & source = decimal_column->getData();
    const size_t total = decimal_column->size();

    const bool no_rows = offsets.size() == 0;
    const UInt32 scale = no_rows ? decimal_column->getScale() : (decimal_column->getScale() * total);

    auto result_column = ColumnDecimal<Decimal128>::create(offsets.size(), scale);
    if (total == 0)
        return result_column;

    typename ColumnDecimal<Decimal128>::Container & out = result_column->getData();
    out.resize(1);

    /// Seed with the first value, then fold in the remaining ones.
    out[0] = source[0];
    for (size_t idx = 1; idx != total; ++idx)
        out[0] *= source[idx];

    return result_column;
}
/// Entry point: requires the first argument to be a ColumnArray, then tries each supported
/// nested column type in turn (unsigned and signed integers, floats, decimals) until one of the
/// helpers returns a non-null column.
/// Throws ILLEGAL_COLUMN when the argument is not an array column or no helper accepts its data.
ColumnPtr FunctionArrayProduct::executeImpl(
    const ColumnsWithTypeAndName & arguments, const DataTypePtr & /*result_type*/, size_t /*input_rows_count*/) const
{
    const auto * array_column = checkAndGetColumn<ColumnArray>(arguments[0].column.get());
    if (array_column == nullptr)
        throw Exception("Argument 0 of function " + getName() + " must be array", ErrorCodes::ILLEGAL_COLUMN);

    const IColumn & nested = array_column->getData();
    const IColumn::Offsets & array_offsets = array_column->getOffsets();

    /// Each helper returns nullptr on a type mismatch, so the first non-null result wins.
    ColumnPtr product = executeProduct<UInt8>(nested);
    if (!product)
        product = executeProduct<UInt16>(nested);
    if (!product)
        product = executeProduct<UInt32>(nested);
    if (!product)
        product = executeProduct<UInt64>(nested);
    if (!product)
        product = executeProduct<Int8>(nested);
    if (!product)
        product = executeProduct<Int16>(nested);
    if (!product)
        product = executeProduct<Int32>(nested);
    if (!product)
        product = executeProduct<Int64>(nested);
    if (!product)
        product = executeProduct<Float32>(nested);
    if (!product)
        product = executeProduct<Float64>(nested);
    if (!product)
        product = executeProductDecimal<Decimal32>(nested, array_offsets);
    if (!product)
        product = executeProductDecimal<Decimal64>(nested, array_offsets);
    if (!product)
        product = executeProductDecimal<Decimal128>(nested, array_offsets);

    if (!product)
        throw Exception(
            "Illegal column " + arguments[0].column->getName() + " of first argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);

    return product;
}
/// Registers FunctionArrayProduct (named "arrayProduct") in the given function factory.
void registerFunctionArrayProduct(FunctionFactory & factory)
{
factory.registerFunction<FunctionArrayProduct>();
}
}

View File

@ -37,7 +37,6 @@ void registerFunctionArrayAUC(FunctionFactory &);
void registerFunctionArrayReduceInRanges(FunctionFactory &);
void registerFunctionMapOp(FunctionFactory &);
void registerFunctionMapPopulateSeries(FunctionFactory &);
void registerFunctionArrayProduct(FunctionFactory &);
void registerFunctionsArray(FunctionFactory & factory)
{
@ -76,7 +75,6 @@ void registerFunctionsArray(FunctionFactory & factory)
registerFunctionArrayAUC(factory);
registerFunctionMapOp(factory);
registerFunctionMapPopulateSeries(factory);
registerFunctionArrayProduct(factory);
}
}

View File

@ -1,2 +1,16 @@
24
27 Float64
Array product 720
28.799999999999997 Float64
3.5 Float64
20 Float64
720
0
6
15504.12105639681
0
6
Types of aggregation result array product
Float64 Float64 Float64 Float64
Float64 Float64 Float64 Float64
Float64 Float64 Float64
Float64 Float64
Float64 Float64 Float64

View File

@ -1,2 +1,22 @@
select arrayProduct(array(1,2,3,4));
select arrayProduct(array(1.0,2.0,3.0,4.5)) as k , toTypeName(k);
-- Products of literal arrays: integers, floats, mixed integer/float, and decimals.
SELECT 'Array product ', (arrayProduct(array(1,2,3,4,5,6)));
select arrayProduct(array(1.0,2.0,3.0,4.8)) as k , toTypeName(k);
select arrayProduct(array(1,3.5)) as k , toTypeName(k);
SELECT arrayProduct([toDecimal32(2, 8), toDecimal32(10, 8)]) as a , toTypeName(a);
-- Products over a table column of integer arrays, including an empty array.
DROP TABLE IF EXISTS test_aggregation;
CREATE TABLE test_aggregation (x Array(Int)) ENGINE=TinyLog;
INSERT INTO test_aggregation VALUES ([1,2,3,4,5,6]), ([]), ([1,2,3]);
SELECT arrayProduct(x) FROM test_aggregation;
DROP TABLE test_aggregation;
-- Same rows stored as Decimal64(8).
-- NOTE(review): the recorded reference output for the first row of this query appears to be
-- 15504.12105639681 rather than 720, presumably an overflow of the intermediate decimal product; confirm.
CREATE TABLE test_aggregation (x Array(Decimal64(8))) ENGINE=TinyLog;
INSERT INTO test_aggregation VALUES ([1,2,3,4,5,6]), ([]), ([1,2,3]);
SELECT arrayProduct(x) FROM test_aggregation;
DROP TABLE test_aggregation;
-- Result type name of arrayProduct for each supported element type.
SELECT 'Types of aggregation result array product';
SELECT toTypeName(arrayProduct([toInt8(0)])), toTypeName(arrayProduct([toInt16(0)])), toTypeName(arrayProduct([toInt32(0)])), toTypeName(arrayProduct([toInt64(0)]));
SELECT toTypeName(arrayProduct([toUInt8(0)])), toTypeName(arrayProduct([toUInt16(0)])), toTypeName(arrayProduct([toUInt32(0)])), toTypeName(arrayProduct([toUInt64(0)]));
SELECT toTypeName(arrayProduct([toInt128(0)])), toTypeName(arrayProduct([toInt256(0)])), toTypeName(arrayProduct([toUInt256(0)]));
SELECT toTypeName(arrayProduct([toFloat32(0)])), toTypeName(arrayProduct([toFloat64(0)]));
SELECT toTypeName(arrayProduct([toDecimal32(0, 8)])), toTypeName(arrayProduct([toDecimal64(0, 8)])), toTypeName(arrayProduct([toDecimal128(0, 8)]));