fixup! Add groupSortedArray() function

- Fix memory access
- Support any type as sorting parameter
- Fix tests
- Rewrite/simplify function addBatchSinglePlace
This commit is contained in:
Pablo Alegre 2022-02-04 15:53:27 +01:00
parent e813f6413f
commit 7f553d55ae
6 changed files with 162 additions and 131 deletions

View File

@ -18,8 +18,8 @@ If the parameter is omitted, default value 10 is used.
**Arguments**
- `x` The value.
- `expr` — Optional. The field or expresion to sort by. If not set values are sorted by themselves. [Integer](../../../sql-reference/data-types/int-uint.md).
- `column` The value.
- `expr` — Optional. The field or expresion to sort by. If not set values are sorted by themselves.
**Example**

View File

@ -27,21 +27,22 @@ namespace ErrorCodes
namespace
{
template <typename T, bool expr_sorted, typename TIndex>
class AggregateFunctionGroupSortedArrayNumeric : public AggregateFunctionGroupSortedArray<false, T, expr_sorted, TIndex>
template <typename T, bool expr_sorted, typename TColumnB, bool is_plain_b>
class AggregateFunctionGroupSortedArrayNumeric : public AggregateFunctionGroupSortedArray<T, false, expr_sorted, TColumnB, is_plain_b>
{
using AggregateFunctionGroupSortedArray<false, T, expr_sorted, TIndex>::AggregateFunctionGroupSortedArray;
using AggregateFunctionGroupSortedArray<T, false, expr_sorted, TColumnB, is_plain_b>::AggregateFunctionGroupSortedArray;
};
template <typename T, bool expr_sorted, typename TIndex>
template <typename T, bool expr_sorted, typename TColumnB, bool is_plain_b>
class AggregateFunctionGroupSortedArrayFieldType
: public AggregateFunctionGroupSortedArray<false, typename T::FieldType, expr_sorted, TIndex>
: public AggregateFunctionGroupSortedArray<typename T::FieldType, false, expr_sorted, TColumnB, is_plain_b>
{
using AggregateFunctionGroupSortedArray<false, typename T::FieldType, expr_sorted, TIndex>::AggregateFunctionGroupSortedArray;
using AggregateFunctionGroupSortedArray<typename T::FieldType, false, expr_sorted, TColumnB, is_plain_b>::
AggregateFunctionGroupSortedArray;
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<T>()); }
};
template <bool expr_sorted, typename TIndex>
template <bool expr_sorted, typename TColumnB, bool is_plain_b>
static IAggregateFunction * createWithExtraTypes(const DataTypes & argument_types, UInt64 threshold, const Array & params)
{
if (argument_types.empty())
@ -49,45 +50,56 @@ namespace
WhichDataType which(argument_types[0]);
if (which.idx == TypeIndex::Date)
return new AggregateFunctionGroupSortedArrayFieldType<DataTypeDate, expr_sorted, TIndex>(threshold, argument_types, params);
return new AggregateFunctionGroupSortedArrayFieldType<DataTypeDate, expr_sorted, TColumnB, is_plain_b>(
threshold, argument_types, params);
if (which.idx == TypeIndex::DateTime)
return new AggregateFunctionGroupSortedArrayFieldType<DataTypeDateTime, expr_sorted, TIndex>(threshold, argument_types, params);
return new AggregateFunctionGroupSortedArrayFieldType<DataTypeDateTime, expr_sorted, TColumnB, is_plain_b>(
threshold, argument_types, params);
if (argument_types[0]->isValueUnambiguouslyRepresentedInContiguousMemoryRegion())
{
return new AggregateFunctionGroupSortedArray<true, StringRef, expr_sorted, TIndex>(threshold, argument_types, params);
return new AggregateFunctionGroupSortedArray<StringRef, true, expr_sorted, TColumnB, is_plain_b>(
threshold, argument_types, params);
}
else
{
return new AggregateFunctionGroupSortedArray<false, StringRef, expr_sorted, TIndex>(threshold, argument_types, params);
return new AggregateFunctionGroupSortedArray<StringRef, false, expr_sorted, TColumnB, is_plain_b>(
threshold, argument_types, params);
}
}
template <template <typename, bool, typename> class AggregateFunctionTemplate, bool bool_param, typename TIndex, typename... TArgs>
template <
template <typename, bool, typename, bool>
class AggregateFunctionTemplate,
bool bool_param,
typename TColumnB,
bool is_plain_b,
typename... TArgs>
static IAggregateFunction * createWithNumericType2(const IDataType & argument_type, TArgs &&... args)
{
WhichDataType which(argument_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionTemplate<TYPE, bool_param, TIndex>(std::forward<TArgs>(args)...);
return new AggregateFunctionTemplate<TYPE, bool_param, TColumnB, is_plain_b>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Enum8)
return new AggregateFunctionTemplate<Int8, bool_param, TIndex>(std::forward<TArgs>(args)...);
return new AggregateFunctionTemplate<Int8, bool_param, TColumnB, is_plain_b>(std::forward<TArgs>(args)...);
if (which.idx == TypeIndex::Enum16)
return new AggregateFunctionTemplate<Int16, bool_param, TIndex>(std::forward<TArgs>(args)...);
return new AggregateFunctionTemplate<Int16, bool_param, TColumnB, is_plain_b>(std::forward<TArgs>(args)...);
return nullptr;
}
template <bool expr_sorted, typename TIndex>
template <bool expr_sorted = false, typename TColumnB = UInt64, bool is_plain_b = false>
AggregateFunctionPtr createAggregateFunctionGroupSortedArrayTyped(
const std::string & name, const DataTypes & argument_types, const Array & params, UInt64 threshold)
{
AggregateFunctionPtr res(createWithNumericType2<AggregateFunctionGroupSortedArrayNumeric, expr_sorted, TIndex>(
AggregateFunctionPtr res(createWithNumericType2<AggregateFunctionGroupSortedArrayNumeric, expr_sorted, TColumnB, is_plain_b>(
*argument_types[0], threshold, argument_types, params));
if (!res)
res = AggregateFunctionPtr(createWithExtraTypes<expr_sorted, TIndex>(argument_types, threshold, params));
res = AggregateFunctionPtr(createWithExtraTypes<expr_sorted, TColumnB, is_plain_b>(argument_types, threshold, params));
if (!res)
throw Exception(
@ -98,25 +110,11 @@ namespace
}
template <bool expr_sorted>
AggregateFunctionPtr createAggregateFunctionGroupSortedArray_custom(
AggregateFunctionPtr createAggregateFunctionGroupSortedArray(
const std::string & name, const DataTypes & argument_types, const Array & params, const Settings *)
{
UInt64 threshold = GROUP_SORTED_ARRAY_DEFAULT_THRESHOLD;
if constexpr (!expr_sorted)
{
assertUnary(name, argument_types);
}
else
{
assertBinary(name, argument_types);
if (!isInteger(argument_types[1]))
throw Exception(
"The second argument for aggregate function 'groupSortedArray' must have integer type",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
if (params.size() == 1)
{
UInt64 k = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[0]);
@ -131,26 +129,39 @@ namespace
threshold = k;
}
else if (params.size() != 0)
{
throw Exception("Aggregate function " + name + " only supports 1 parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
}
if (expr_sorted && isUnsignedInteger(argument_types[1]))
return createAggregateFunctionGroupSortedArrayTyped<expr_sorted, UInt64>(name, argument_types, params, threshold);
if (argument_types.size() == 2)
{
if (isUnsignedInteger(argument_types[1]))
{
return createAggregateFunctionGroupSortedArrayTyped<true, UInt64>(name, argument_types, params, threshold);
}
else if (isInteger(argument_types[1]))
{
return createAggregateFunctionGroupSortedArrayTyped<true, Int64>(name, argument_types, params, threshold);
}
else if (argument_types[1]->isValueUnambiguouslyRepresentedInContiguousMemoryRegion())
{
return createAggregateFunctionGroupSortedArrayTyped<true, StringRef, true>(name, argument_types, params, threshold);
}
else
{
return createAggregateFunctionGroupSortedArrayTyped<true, StringRef, false>(name, argument_types, params, threshold);
}
}
else if (argument_types.size() == 1)
{
return createAggregateFunctionGroupSortedArrayTyped<>(name, argument_types, params, threshold);
}
else
return createAggregateFunctionGroupSortedArrayTyped<expr_sorted, Int64>(name, argument_types, params, threshold);
}
AggregateFunctionPtr createAggregateFunctionGroupSortedArray(
const std::string & name, const DataTypes & argument_types, const Array & params, const Settings * settings)
{
if (argument_types.size() > 2 || argument_types.size() < 1)
{
throw Exception(
"Aggregate function " + name + " requires one or two parameters.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
}
if (argument_types.size() > 1)
return createAggregateFunctionGroupSortedArray_custom<true>(name, argument_types, params, settings);
else
return createAggregateFunctionGroupSortedArray_custom<false>(name, argument_types, params, settings);
}
}
@ -159,5 +170,4 @@ void registerAggregateFunctionGroupSortedArray(AggregateFunctionFactory & factor
AggregateFunctionProperties properties = {.returns_default_when_only_null = false, .is_order_dependent = true};
factory.registerFunction("groupSortedArray", {createAggregateFunctionGroupSortedArray, properties});
}
}

View File

@ -8,12 +8,12 @@
namespace DB
{
template <typename T, bool is_plain_column>
inline T readItem(const IColumn * column, Arena * arena, size_t row)
template <typename TColumn, bool is_plain>
inline TColumn readItem(const IColumn * column, Arena * arena, size_t row)
{
if constexpr (std::is_same_v<T, StringRef>)
if constexpr (std::is_same_v<TColumn, StringRef>)
{
if constexpr (is_plain_column)
if constexpr (is_plain)
{
StringRef str = column->getDataAt(row);
auto ptr = arena->alloc(str.size);
@ -28,12 +28,15 @@ inline T readItem(const IColumn * column, Arena * arena, size_t row)
}
else
{
return column->getUInt(row);
if constexpr (std::is_same_v<TColumn, UInt64>)
return column->getUInt(row);
else
return column->getInt(row);
}
}
template <typename T>
void getFirstNElements(const T * data, int num_elements, int threshold, size_t * results)
template <typename TColumn>
size_t getFirstNElements(const TColumn * data, int num_elements, int threshold, size_t * results, const UInt8 * filter = nullptr)
{
for (int i = 0; i < threshold; i++)
{
@ -46,14 +49,20 @@ void getFirstNElements(const T * data, int num_elements, int threshold, size_t *
int z;
for (int i = 0; i < num_elements; i++)
{
if (filter && (filter[i] == 0))
continue;
//Starting from the highest values and we look for the immediately lower than the given one
for (cur = current_max; cur > 0 && (data[i] < data[results[cur - 1]]); cur--)
;
for (cur = current_max; cur > 0; cur--)
{
if (!(data[i] < data[results[cur - 1]]))
break;
}
if (cur < threshold)
{
//Move all the higher values 1 position to the right
for (z = current_max - 1; z >= cur; z--)
for (z = current_max - 1; z > cur; z--)
results[z] = results[z - 1];
if (current_max < threshold)
@ -63,17 +72,20 @@ void getFirstNElements(const T * data, int num_elements, int threshold, size_t *
results[cur] = i;
}
}
return current_max;
}
template <bool is_plain_column, typename T, bool expr_sorted, typename TIndex>
template <typename TColumnA, bool is_plain_a, bool use_column_b, typename TColumnB, bool is_plain_b>
class AggregateFunctionGroupSortedArray : public IAggregateFunctionDataHelper<
AggregateFunctionGroupSortedArrayData<T, expr_sorted, TIndex>,
AggregateFunctionGroupSortedArray<is_plain_column, T, expr_sorted, TIndex>>
AggregateFunctionGroupSortedArrayData<TColumnA, use_column_b, TColumnB>,
AggregateFunctionGroupSortedArray<TColumnA, is_plain_a, use_column_b, TColumnB, is_plain_b>>
{
protected:
using State = AggregateFunctionGroupSortedArrayData<T, expr_sorted, TIndex>;
using Base
= IAggregateFunctionDataHelper<AggregateFunctionGroupSortedArrayData<T, expr_sorted, TIndex>, AggregateFunctionGroupSortedArray>;
using State = AggregateFunctionGroupSortedArrayData<TColumnA, use_column_b, TColumnB>;
using Base = IAggregateFunctionDataHelper<
AggregateFunctionGroupSortedArrayData<TColumnA, use_column_b, TColumnB>,
AggregateFunctionGroupSortedArray>;
UInt64 threshold;
DataTypePtr & input_data_type;
@ -83,8 +95,9 @@ protected:
public:
AggregateFunctionGroupSortedArray(UInt64 threshold_, const DataTypes & argument_types_, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionGroupSortedArrayData<T, expr_sorted, TIndex>, AggregateFunctionGroupSortedArray>(
argument_types_, params)
: IAggregateFunctionDataHelper<
AggregateFunctionGroupSortedArrayData<TColumnA, use_column_b, TColumnB>,
AggregateFunctionGroupSortedArray>(argument_types_, params)
, threshold(threshold_)
, input_data_type(this->argument_types[0])
{
@ -102,7 +115,7 @@ public:
bool allocatesMemoryInArena() const override
{
if constexpr (std::is_same_v<T, StringRef>)
if constexpr (std::is_same_v<TColumnA, StringRef>)
return true;
else
return false;
@ -111,13 +124,51 @@ public:
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
State & data = this->data(place);
if constexpr (expr_sorted)
if constexpr (use_column_b)
{
data.add(readItem<T, is_plain_column>(columns[0], arena, row_num), readItem<TIndex, false>(columns[1], arena, row_num));
data.add(
readItem<TColumnA, is_plain_a>(columns[0], arena, row_num), readItem<TColumnB, is_plain_b>(columns[1], arena, row_num));
}
else
{
data.add(readItem<T, is_plain_column>(columns[0], arena, row_num));
data.add(readItem<TColumnA, is_plain_a>(columns[0], arena, row_num));
}
}
template <typename TColumn, bool is_plain, typename TFunc>
void
forFirstRows(size_t batch_size, const IColumn ** columns, size_t data_column, Arena * arena, ssize_t if_argument_pos, TFunc func) const
{
const TColumn * values = nullptr;
std::unique_ptr<std::vector<TColumn>> values_vector;
std::vector<size_t> best_rows(threshold);
if constexpr (std::is_same_v<TColumn, StringRef>)
{
values_vector.reset(new std::vector<TColumn>(batch_size));
for (size_t i = 0; i < batch_size; i++)
(*values_vector)[i] = readItem<TColumn, is_plain>(columns[data_column], arena, i);
values = (*values_vector).data();
}
else
{
StringRef ref = columns[data_column]->getRawData();
values = reinterpret_cast<const TColumn *>(ref.data);
}
const UInt8 * filter = nullptr;
StringRef refFilter;
if (if_argument_pos >= 0)
{
refFilter = columns[if_argument_pos]->getRawData();
filter = reinterpret_cast<const UInt8 *>(refFilter.data);
}
size_t num_elements = getFirstNElements(values, batch_size, threshold, best_rows.data(), filter);
for (size_t i = 0; i < num_elements; i++)
{
func(best_rows[i], values);
}
}
@ -125,58 +176,22 @@ public:
size_t batch_size, AggregateDataPtr place, const IColumn ** columns, Arena * arena, ssize_t if_argument_pos) const override
{
State & data = this->data(place);
if constexpr (expr_sorted)
if constexpr (use_column_b)
{
StringRef ref = columns[1]->getRawData();
TIndex values[batch_size];
memcpy(values, ref.data, batch_size * sizeof(TIndex));
size_t num_results = std::min(this->threshold, batch_size);
size_t * bestRows = new size_t[batch_size];
//First store the first n elements with the column number
if (if_argument_pos >= 0)
{
TIndex * value_w = values;
const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
for (size_t i = 0; i < batch_size; ++i)
forFirstRows<TColumnB, is_plain_b>(
batch_size, columns, 1, arena, if_argument_pos, [columns, &arena, &data](size_t row, const TColumnB * values)
{
if (flags[i])
*(value_w++) = values[i];
}
batch_size = value_w - values;
}
num_results = std::min(this->threshold, batch_size);
getFirstNElements(values, batch_size, num_results, bestRows);
for (size_t i = 0; i < num_results; i++)
{
auto row = bestRows[i];
data.add(readItem<T, is_plain_column>(columns[0], arena, row), values[row]);
}
delete[] bestRows;
data.add(readItem<TColumnA, is_plain_a>(columns[0], arena, row), values[row]);
});
}
else
{
if (if_argument_pos >= 0)
{
const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
for (size_t i = 0; i < batch_size; ++i)
{
if (flags[i])
{
data.add(readItem<T, is_plain_column>(columns[0], arena, i));
}
}
}
else
{
for (size_t i = 0; i < batch_size; ++i)
{
data.add(readItem<T, is_plain_column>(columns[0], arena, i));
}
}
forFirstRows<TColumnA, is_plain_a>(
batch_size, columns, 0, arena, if_argument_pos, [&data](size_t row, const TColumnA * values)
{
data.add(values[row]);
});
}
}
@ -207,10 +222,10 @@ public:
IColumn & data_to = arr_to.getData();
for (auto value : values)
{
if constexpr (std::is_same_v<T, StringRef>)
if constexpr (std::is_same_v<TColumnA, StringRef>)
{
auto str = State::itemValue(value);
if constexpr (is_plain_column)
if constexpr (is_plain_a)
{
data_to.insertData(str.data, str.size);
}

View File

@ -125,7 +125,7 @@ struct AggregateFunctionGroupSortedArrayData<T, true, TIndex> : public Aggregate
virtual typename Base::ValueType deserializeItem(ReadBuffer & buf, Arena * arena) const override
{
Int64 first;
TIndex first;
T second;
readOneItem(buf, arena, first);
readOneItem(buf, arena, second);

View File

@ -5,10 +5,13 @@
['0','1','2','3','4']
['0','1','2','3','4']
['9','8','7','6','5']
[0,1,2,3,4,5,6,7,8,9]
[0,1,2,3,4,5,6,7,8,9]
[(0,'0'),(1,'1'),(2,'2'),(3,'3'),(4,'4')]
['0','1','10','11','12']
['0','1','2','3','4','5','6','7','8','9','10','11','12','13','14','15','16','17','18','19','20','21','22','23','24','25','26','27','28','29','30','31','32','33','34','35','36','37','38','39','40','41','42','43','44','45','46','47','48','49']
pablo [1,2]
luis [1,3]
pablo [1,2]
luis [1,3]
[4,5,6,7,8]
[10,11,12,13,14]
['10','11','12','13','14']

View File

@ -13,14 +13,13 @@ SELECT groupSortedArray(5)(text) FROM (select toString(number) as text from numb
SELECT groupSortedArray(5)(text, -number) FROM (select toString(number) as text, number from numbers(10));
SELECT groupSortedArray(10)(number, number) from numbers(100); -- { serverError 42 }
SELECT groupSortedArray(10)(number) from numbers(100); -- { serverError 42 }
SELECT groupSortedArray(5)(number, text) FROM (select toString(number) as text, number from numbers(10)); -- { serverError 43 }
SELECT groupSortedArray(5)((number,text)) from (SELECT toString(number) as text, number FROM numbers(100));
SELECT groupSortedArray(5)(text,text) from (SELECT toString(number) as text FROM numbers(100));
SELECT groupSortedArray(50)(text,(number,text)) from (SELECT toString(number) as text, number FROM numbers(100));
DROP TABLE IF EXISTS test;
DROP VIEW IF EXISTS mv_test;
CREATE TABLE test (`n` String, `h` Int64) ENGINE = MergeTree ORDER BY n;
@ -35,4 +34,8 @@ CREATE MATERIALIZED VIEW mv_test (`n` String, `h` AggregateFunction(groupSortedA
INSERT INTO test VALUES ('pablo',1)('pablo', 2)('luis', 1)('luis', 3)('pablo', 5)('pablo',4)('pablo', 5)('luis', 6)('luis', 7)('pablo', 8)('pablo',9)('pablo',10)('luis',11)('luis',12)('pablo',13);
SELECT n, groupSortedArrayMerge(2)(h) from mv_test GROUP BY n;
DROP TABLE test;
DROP VIEW mv_test;
DROP VIEW mv_test;
SELECT groupSortedArrayIf(5)(number, number, number>3) from numbers(100);
SELECT groupSortedArrayIf(5)(number, toString(number), number>3) from numbers(100);
SELECT groupSortedArrayIf(5)(toString(number), number>3) from numbers(100);