Use tuples in arrayReduceInRanges

This commit is contained in:
hcz 2020-03-11 17:10:39 +08:00
parent 294f4af165
commit b634228947
4 changed files with 82 additions and 62 deletions

View File

@ -2,8 +2,10 @@
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeTuple.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnTuple.h>
#include <Columns/ColumnAggregateFunction.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionState.h>
@ -64,9 +66,9 @@ DataTypePtr FunctionArrayReduceInRanges::getReturnTypeImpl(const ColumnsWithType
/// The first argument is a constant string with the name of the aggregate function
/// (possibly with parameters in parentheses, for example: "quantile(0.99)").
if (arguments.size() < 4)
if (arguments.size() < 3)
throw Exception("Number of arguments for function " + getName() + " doesn't match: passed "
+ toString(arguments.size()) + ", should be at least 4.",
+ toString(arguments.size()) + ", should be at least 3.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
const ColumnConst * aggregate_function_name_column = checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
@ -74,25 +76,30 @@ DataTypePtr FunctionArrayReduceInRanges::getReturnTypeImpl(const ColumnsWithType
throw Exception("First argument for function " + getName() + " must be constant string: name of aggregate function.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const DataTypeArray * indices_type = checkAndGetDataType<DataTypeArray>(arguments[1].type.get());
if (!indices_type || !WhichDataType(*indices_type->getNestedType()).isNativeUInt())
throw Exception("Second argument for function " + getName() + " must be array of ints.",
const DataTypeArray * ranges_type_array = checkAndGetDataType<DataTypeArray>(arguments[1].type.get());
if (!ranges_type_array)
throw Exception("Second argument for function " + getName() + " must be an array of ranges.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const DataTypeTuple * ranges_type_tuple = checkAndGetDataType<DataTypeTuple>(ranges_type_array->getNestedType().get());
if (!ranges_type_tuple || ranges_type_tuple->getElements().size() != 2)
throw Exception("Each array element in the second argument for function " + getName() + " must be a tuple (index, length).",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!isNativeInteger(ranges_type_tuple->getElements()[0]))
throw Exception("First tuple member in the second argument for function " + getName() + " must be ints or uints.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!WhichDataType(ranges_type_tuple->getElements()[1]).isNativeUInt())
throw Exception("Second tuple member in the second argument for function " + getName() + " must be uints.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const DataTypeArray * lengths_type = checkAndGetDataType<DataTypeArray>(arguments[1].type.get());
if (!lengths_type || !WhichDataType(*lengths_type->getNestedType()).isNativeUInt())
throw Exception("Third argument for function " + getName() + " must be array of ints.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
DataTypes argument_types(arguments.size() - 3);
for (size_t i = 3, size = arguments.size(); i < size; ++i)
DataTypes argument_types(arguments.size() - 2);
for (size_t i = 2, size = arguments.size(); i < size; ++i)
{
const DataTypeArray * arg = checkAndGetDataType<DataTypeArray>(arguments[i].type.get());
if (!arg)
throw Exception("Argument " + toString(i) + " for function " + getName() + " must be an array but it has type "
+ arguments[i].type->getName() + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
argument_types[i - 3] = arg->getNestedType();
argument_types[i - 2] = arg->getNestedType();
}
if (!aggregate_function)
@ -123,59 +130,40 @@ void FunctionArrayReduceInRanges::executeImpl(Block & block, const ColumnNumbers
/// Aggregate functions do not support constant columns. Therefore, we materialize them.
std::vector<ColumnPtr> materialized_columns;
/// Handling indices
/// Handling ranges
const IColumn * indices_col = block.getByPosition(arguments[1]).column.get();
const IColumn * indices_data = nullptr;
const ColumnArray::Offsets * indices_offsets = nullptr;
if (const ColumnArray * arr = checkAndGetColumn<ColumnArray>(indices_col))
const IColumn * ranges_col_array = block.getByPosition(arguments[1]).column.get();
const IColumn * ranges_col_tuple = nullptr;
const ColumnArray::Offsets * ranges_offsets = nullptr;
if (const ColumnArray * arr = checkAndGetColumn<ColumnArray>(ranges_col_array))
{
indices_data = &arr->getData();
indices_offsets = &arr->getOffsets();
ranges_col_tuple = &arr->getData();
ranges_offsets = &arr->getOffsets();
}
else if (const ColumnConst * const_arr = checkAndGetColumnConst<ColumnArray>(indices_col))
else if (const ColumnConst * const_arr = checkAndGetColumnConst<ColumnArray>(ranges_col_array))
{
materialized_columns.emplace_back(const_arr->convertToFullColumn());
const auto & materialized_arr = typeid_cast<const ColumnArray &>(*materialized_columns.back());
indices_data = &materialized_arr.getData();
indices_offsets = &materialized_arr.getOffsets();
ranges_col_tuple = &materialized_arr.getData();
ranges_offsets = &materialized_arr.getOffsets();
}
else
throw Exception("Illegal column " + ranges_col_array->getName() + " as argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);
/// Handling lengths
const IColumn * lengths_col = block.getByPosition(arguments[2]).column.get();
const IColumn * lengths_data = nullptr;
const ColumnArray::Offsets * lengths_offsets = nullptr;
if (const ColumnArray * arr = checkAndGetColumn<ColumnArray>(lengths_col))
{
lengths_data = &arr->getData();
lengths_offsets = &arr->getOffsets();
}
else if (const ColumnConst * const_arr = checkAndGetColumnConst<ColumnArray>(lengths_col))
{
materialized_columns.emplace_back(const_arr->convertToFullColumn());
const auto & materialized_arr = typeid_cast<const ColumnArray &>(*materialized_columns.back());
lengths_data = &materialized_arr.getData();
lengths_offsets = &materialized_arr.getOffsets();
}
if (*indices_offsets != *lengths_offsets)
throw Exception("Lengths of `indices` and `lengths` passed to " + getName() + " must be equal.",
ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);
const IColumn & indices_col = static_cast<const ColumnTuple *>(ranges_col_tuple)->getColumn(0);
const IColumn & lengths_col = static_cast<const ColumnTuple *>(ranges_col_tuple)->getColumn(1);
/// Handling arguments
/// The code is mostly copied from `arrayReduce`. Maybe create a utility header?
const size_t num_arguments_columns = arguments.size() - 3;
const size_t num_arguments_columns = arguments.size() - 2;
std::vector<const IColumn *> aggregate_arguments_vec(num_arguments_columns);
const ColumnArray::Offsets * offsets = nullptr;
for (size_t i = 0; i < num_arguments_columns; ++i)
{
const IColumn * col = block.getByPosition(arguments[i + 3]).column.get();
const IColumn * col = block.getByPosition(arguments[i + 2]).column.get();
const ColumnArray::Offsets * offsets_i = nullptr;
if (const ColumnArray * arr = checkAndGetColumn<ColumnArray>(col))
@ -207,7 +195,7 @@ void FunctionArrayReduceInRanges::executeImpl(Block & block, const ColumnNumbers
ColumnArray * result_arr = static_cast<ColumnArray *>(result_holder.get());
IColumn & result_data = result_arr->getData();
result_arr->getOffsets().insert(indices_offsets->begin(), indices_offsets->end());
result_arr->getOffsets().insert(ranges_offsets->begin(), ranges_offsets->end());
/// AggregateFunction's states should be inserted into column using specific way
auto res_col_aggregate_function = typeid_cast<ColumnAggregateFunction *>(&result_data);
@ -228,7 +216,7 @@ void FunctionArrayReduceInRanges::executeImpl(Block & block, const ColumnNumbers
begin = end;
end = (*offsets)[i];
ranges_begin = ranges_end;
ranges_end = (*indices_offsets)[i];
ranges_end = (*ranges_offsets)[i];
/// We will allocate pre-aggregation places for each `minimun_place << level` rows.
/// The value of `level` starts from 0, and it will never exceed the number of bits in a `size_t`.
@ -305,8 +293,37 @@ void FunctionArrayReduceInRanges::executeImpl(Block & block, const ColumnNumbers
for (size_t j = ranges_begin; j < ranges_end; ++j)
{
size_t local_begin = std::max(indices_data->getUInt(j) - 1, size_t(0));
size_t local_end = std::min(local_begin + lengths_data->getUInt(j), end - begin);
size_t local_begin = 0;
size_t local_end = 0;
{
Int64 index = indices_col.getInt(j);
UInt64 length = lengths_col.getUInt(j);
/// Keep the same as in arraySlice
if (index > 0)
{
local_begin = index - 1;
if (local_begin + length < end - begin)
local_end = local_begin + length;
else
local_end = end - begin;
}
else if (index < 0)
{
if (end - begin + index > 0)
local_begin = end - begin + index;
else
local_begin = 0;
if (local_begin + length < end - begin)
local_end = local_begin + length;
else
local_end = end - begin;
}
}
size_t place_begin = (local_begin + minimum_step - 1) / minimum_step;
size_t place_end = local_end / minimum_step;

View File

@ -9,8 +9,8 @@
<query>SELECT arrayReduce('count', range(100000000))</query>
<query>SELECT arrayReduce('sum', range(100000000))</query>
<query>SELECT arrayReduceInRanges('count', [1], [100000000], range(100000000))</query>
<query>SELECT arrayReduceInRanges('sum', [1], [100000000], range(100000000))</query>
<query>SELECT arrayReduceInRanges('count', range(1000000), range(1000000), range(100000000))[123456]</query>
<query>SELECT arrayReduceInRanges('sum', range(1000000), range(1000000), range(100000000))[123456]</query>
<query>SELECT arrayReduceInRanges('count', [(1, 100000000)], range(100000000))</query>
<query>SELECT arrayReduceInRanges('sum', [(1, 100000000)], range(100000000))</query>
<query>SELECT arrayReduceInRanges('count', arrayZip(range(1000000), range(1000000)), range(100000000))[123456]</query>
<query>SELECT arrayReduceInRanges('sum', arrayZip(range(1000000), range(1000000)), range(100000000))[123456]</query>
</test>

View File

@ -1,5 +1,5 @@
[['a','b','c'],['b','c','d'],['c','d','e']]
[0,0,0,0,0,0,0,100,200,300,400,0,0,300,500,700,400,0,0,600,900,700,400,0]
[0,0,0,0,0,0,0,100,300,0,200,400,0,300,700,0,500,400,0,600,700,0,900,400]
1
1
1

View File

@ -1,23 +1,26 @@
SELECT
arrayReduceInRanges(
'groupArray',
[1, 2, 3],
[3, 3, 3],
[(1, 3), (2, 3), (3, 3)],
['a', 'b', 'c', 'd', 'e']
);
SELECT
arrayReduceInRanges(
'sum',
[0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5],
[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3],
[
(-6, 0), (-4, 0), (-2, 0), (0, 0), (2, 0), (4, 0),
(-6, 1), (-4, 1), (-2, 1), (0, 1), (2, 1), (4, 1),
(-6, 2), (-4, 2), (-2, 2), (0, 2), (2, 2), (4, 2),
(-6, 3), (-4, 3), (-2, 3), (0, 3), (2, 3), (4, 3)
],
[100, 200, 300, 400]
);
WITH
arrayMap(x -> x + 1, range(50)) as data
SELECT
arrayReduceInRanges('groupArray', [a, b], [c, d], data) =
arrayReduceInRanges('groupArray', [(a, c), (b, d)], data) =
[arraySlice(data, a, c), arraySlice(data, b, d)]
FROM (
SELECT