Refactor GROUPING function

This commit is contained in:
Dmitry Novik 2022-05-18 15:23:31 +00:00
parent e5b395e054
commit 6356112a76
9 changed files with 297 additions and 332 deletions

View File

@ -1,5 +1,6 @@
#pragma once
#include <unordered_set>
#include <vector>
#include <string>
@ -8,6 +9,8 @@ namespace DB
{
using ColumnNumbers = std::vector<size_t>;
using ColumnNumbersSet = std::unordered_set<size_t>;
using ColumnNumbersList = std::vector<ColumnNumbers>;
using ColumnNumbersSetList = std::vector<ColumnNumbersSet>;
}

View File

@ -1,159 +0,0 @@
#include <base/types.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnFixedString.h>
#include <Common/FieldVisitors.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/FieldToDataType.h>
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
namespace DB
{
class FunctionGrouping : public IFunction
{
public:
static constexpr auto name = "grouping";
static FunctionPtr create(ContextPtr)
{
return std::make_shared<FunctionGrouping>();
}
bool isVariadic() const override
{
return true;
}
size_t getNumberOfArguments() const override
{
return 0;
}
bool useDefaultImplementationForNulls() const override { return false; }
bool isSuitableForConstantFolding() const override { return false; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }
String getName() const override
{
return name;
}
DataTypePtr getReturnTypeImpl(const DataTypes & /*arguments*/) const override
{
//TODO: add assert for argument types
return std::make_shared<DataTypeUInt64>();
}
ColumnPtr executeOrdinaryGroupBy(const ColumnsWithTypeAndName & arguments, size_t input_rows_count) const
{
auto grouping_set_map_column = checkAndGetColumnConst<ColumnUInt64>(arguments[0].column.get());
auto argument_keys_column = checkAndGetColumnConst<ColumnArray>(arguments[1].column.get());
auto aggregation_keys_number = (*grouping_set_map_column)[0].get<UInt64>();
auto result = std::make_shared<DataTypeUInt64>()->createColumn();
for (size_t i = 0; i < input_rows_count; ++i)
{
auto indexes = (*argument_keys_column)[i].get<Array>();
UInt64 value = 0;
for (auto index : indexes)
value = (value << 1) + (index.get<UInt64>() < aggregation_keys_number ? 1 : 0);
result->insert(Field(value));
}
return result;
}
ColumnPtr executeRollup(
const ColumnUInt64 * grouping_set_column,
const ColumnConst & argument_keys_column,
UInt64 keys,
size_t input_rows_count) const
{
auto result = std::make_shared<DataTypeUInt64>()->createColumn();
for (size_t i = 0; i < input_rows_count; ++i)
{
UInt64 set_index = grouping_set_column->get64(i);
auto indexes = argument_keys_column[i].get<Array>();
UInt64 value = 0;
for (auto index : indexes)
value = (value << 1) + (index.get<UInt64>() < keys - set_index ? 1 : 0);
result->insert(Field(value));
}
return result;
}
ColumnPtr executeCube(
const ColumnUInt64 * grouping_set_column,
const ColumnConst & argument_keys_column,
UInt64 keys,
size_t input_rows_count) const
{
static constexpr auto ONE = static_cast<UInt64>(1);
auto result = std::make_shared<DataTypeUInt64>()->createColumn();
auto mask_base = (ONE << keys) - 1;
for (size_t i = 0; i < input_rows_count; ++i)
{
UInt64 set_index = grouping_set_column->get64(i);
auto mask = mask_base - set_index;
auto indexes = argument_keys_column[i].get<Array>();
UInt64 value = 0;
for (auto index : indexes)
value = (value << 1) + (mask & (ONE << (keys - index.get<UInt64>() - 1)) ? 1 : 0);
result->insert(Field(value));
}
return result;
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & , size_t input_rows_count) const override
{
if (arguments.size() == 2)
return executeOrdinaryGroupBy(arguments, input_rows_count);
auto grouping_set_column = checkAndGetColumn<ColumnUInt64>(arguments[0].column.get());
auto grouping_set_map_column = checkAndGetColumnConst<ColumnArray>(arguments[1].column.get());
auto argument_keys_column = checkAndGetColumnConst<ColumnArray>(arguments[2].column.get());
auto masks = (*grouping_set_map_column)[0].get<Array>();
auto grouping_set_map_elem_type = applyVisitor(FieldToDataType(), masks[0]);
if (!isString(grouping_set_map_elem_type))
{
bool is_rollup = masks[0].get<UInt64>() == 0;
auto keys = masks[1].get<UInt64>();
if (is_rollup)
return executeRollup(grouping_set_column, *argument_keys_column, keys, input_rows_count);
else
return executeCube(grouping_set_column, *argument_keys_column, keys, input_rows_count);
}
auto result = std::make_shared<DataTypeUInt64>()->createColumn();
for (size_t i = 0; i < input_rows_count; ++i)
{
UInt64 set_index = grouping_set_column->get64(i);
auto mask = masks[set_index].get<const String &>();
auto indexes = (*argument_keys_column)[i].get<Array>();
UInt64 value = 0;
for (auto index : indexes)
value = (value << 1) + (mask[index.get<UInt64>()] == '1' ? 1 : 0);
result->insert(Field(value));
}
return result;
}
};
void registerFunctionGrouping(FunctionFactory & factory)
{
factory.registerFunction<FunctionGrouping>();
}
}

151
src/Functions/grouping.h Normal file
View File

@ -0,0 +1,151 @@
#include <Functions/IFunction.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnFixedString.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <DataTypes/DataTypesNumber.h>
#include "Core/ColumnNumbers.h"
#include "DataTypes/Serializations/ISerialization.h"
#include "base/types.h"
namespace DB
{
class FunctionGroupingBase : public IFunction
{
protected:
static constexpr UInt64 ONE = 1;
const ColumnNumbers arguments_indexes;
public:
FunctionGroupingBase(ColumnNumbers arguments_indexes_)
: arguments_indexes(std::move(arguments_indexes_))
{}
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
bool useDefaultImplementationForNulls() const override { return false; }
bool isSuitableForConstantFolding() const override { return false; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }
DataTypePtr getReturnTypeImpl(const DataTypes & /*arguments*/) const override
{
return std::make_shared<DataTypeUInt64>();
}
template <typename AggregationKeyChecker>
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, size_t input_rows_count, AggregationKeyChecker checker) const
{
auto grouping_set_column = checkAndGetColumn<ColumnUInt64>(arguments[0].column.get());
auto result = std::make_shared<DataTypeUInt64>()->createColumn();
for (size_t i = 0; i < input_rows_count; ++i)
{
UInt64 set_index = grouping_set_column->get64(i);
UInt64 value = 0;
for (auto index : arguments_indexes)
value = (value << 1) + (checker(set_index, index) ? 1 : 0);
result->insert(Field(value));
}
return result;
}
};
class FunctionGroupingOrdinary : public FunctionGroupingBase
{
public:
explicit FunctionGroupingOrdinary(ColumnNumbers arguments_indexes_)
: FunctionGroupingBase(std::move(arguments_indexes_))
{}
String getName() const override { return "groupingOrdinary"; }
ColumnPtr executeImpl(const ColumnsWithTypeAndName &, const DataTypePtr &, size_t input_rows_count) const override
{
UInt64 value = (ONE << arguments_indexes.size()) - 1;
return ColumnUInt64::create(input_rows_count, value);
}
};
class FunctionGroupingForRollup : public FunctionGroupingBase
{
const UInt64 aggregation_keys_number;
public:
FunctionGroupingForRollup(ColumnNumbers arguments_indexes_, UInt64 aggregation_keys_number_)
: FunctionGroupingBase(std::move(arguments_indexes_))
, aggregation_keys_number(aggregation_keys_number_)
{}
String getName() const override { return "groupingForRollup"; }
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
return FunctionGroupingBase::executeImpl(arguments, input_rows_count,
[this](UInt64 set_index, UInt64 arg_index)
{
return arg_index < aggregation_keys_number - set_index;
}
);
}
};
class FunctionGroupingForCube : public FunctionGroupingBase
{
const UInt64 aggregation_keys_number;
public:
FunctionGroupingForCube(ColumnNumbers arguments_indexes_, UInt64 aggregation_keys_number_)
: FunctionGroupingBase(arguments_indexes_)
, aggregation_keys_number(aggregation_keys_number_)
{}
String getName() const override { return "groupingForCube"; }
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
return FunctionGroupingBase::executeImpl(arguments, input_rows_count,
[this](UInt64 set_index, UInt64 arg_index)
{
auto set_mask = (ONE << aggregation_keys_number) - 1 - set_index;
return set_mask & (ONE << (aggregation_keys_number - arg_index - 1));
}
);
}
};
class FunctionGroupingForGroupingSets : public FunctionGroupingBase
{
ColumnNumbersSetList grouping_sets;
public:
FunctionGroupingForGroupingSets(ColumnNumbers arguments_indexes_, ColumnNumbersList const & grouping_sets_)
: FunctionGroupingBase(std::move(arguments_indexes_))
{
for (auto const & set : grouping_sets_)
grouping_sets.emplace_back(set.begin(), set.end());
}
String getName() const override { return "groupingForGroupingSets"; }
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
return FunctionGroupingBase::executeImpl(arguments, input_rows_count,
[this](UInt64 set_index, UInt64 arg_index)
{
return grouping_sets[set_index].contains(arg_index);
}
);
}
};
}

View File

@ -83,7 +83,6 @@ void registerFunctionZooKeeperSessionUptime(FunctionFactory &);
void registerFunctionGetOSKernelVersion(FunctionFactory &);
void registerFunctionGetTypeSerializationStreams(FunctionFactory &);
void registerFunctionFlattenTuple(FunctionFactory &);
void registerFunctionGrouping(FunctionFactory &);
#if USE_ICU
void registerFunctionConvertCharset(FunctionFactory &);
@ -173,7 +172,6 @@ void registerFunctionsMiscellaneous(FunctionFactory & factory)
registerFunctionGetOSKernelVersion(factory);
registerFunctionGetTypeSerializationStreams(factory);
registerFunctionFlattenTuple(factory);
registerFunctionGrouping(factory);
#if USE_ICU
registerFunctionConvertCharset(factory);

View File

@ -1,3 +1,4 @@
#include <memory>
#include <Common/quoteString.h>
#include <Common/typeid_cast.h>
#include <Columns/ColumnArray.h>
@ -5,6 +6,7 @@
#include <Core/ColumnNumbers.h>
#include <Core/ColumnWithTypeAndName.h>
#include <Functions/grouping.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionsMiscellaneous.h>
@ -839,89 +841,39 @@ void ActionsMatcher::visit(const ASTFunction & node, const ASTPtr & ast, Data &
if (node.name == "grouping")
{
auto arguments_column_name = data.getUniqueName("__grouping_args");
ColumnNumbers arguments_indexes;
auto aggregation_keys_number = data.aggregation_keys.size();
for (auto const & arg : node.arguments->children)
{
if (!data.hasColumn("__grouping_set_map"))
{
ColumnWithTypeAndName column;
column.name = "__grouping_set_map";
switch (data.group_by_kind)
{
case GroupByKind::GROUPING_SETS:
{
size_t map_size = data.aggregation_keys.size() + 1;
column.type = std::make_shared<DataTypeArray>(std::make_shared<DataTypeFixedString>(map_size));
Array maps_per_set;
for (auto & grouping_set : data.grouping_set_keys)
{
std::string key_map(map_size, '0');
for (auto index : grouping_set)
key_map[index] = '1';
maps_per_set.push_back(key_map);
}
auto grouping_set_map_column = ColumnArray::create(ColumnFixedString::create(map_size));
grouping_set_map_column->insert(maps_per_set);
column.column = ColumnConst::create(std::move(grouping_set_map_column), 1);
break;
}
case GroupByKind::ROLLUP:
case GroupByKind::CUBE:
{
column.type = std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt64>());
auto grouping_set_map_column = ColumnArray::create(ColumnUInt64::create());
Array kind_and_keys_size;
kind_and_keys_size.push_back(data.group_by_kind == GroupByKind::ROLLUP ? 0 : 1);
kind_and_keys_size.push_back(data.aggregation_keys.size());
grouping_set_map_column->insert(kind_and_keys_size);
column.column = ColumnConst::create(std::move(grouping_set_map_column), 1);
break;
}
case GroupByKind::ORDINARY:
{
column.type = std::make_shared<DataTypeUInt64>();
auto grouping_set_map_column = ColumnUInt64::create(1, data.aggregation_keys.size());
column.column = ColumnConst::create(std::move(grouping_set_map_column), 1);
break;
}
default:
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Unexpected kind of GROUP BY clause for GROUPING function: {}", data.group_by_kind);
}
data.addColumn(column);
}
ColumnWithTypeAndName column;
column.name = arguments_column_name;
column.type = std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt64>());
Array arguments_to_keys_map;
for (auto const & arg : node.arguments->children)
{
size_t pos = data.aggregation_keys.getPosByName(arg->getColumnName());
arguments_to_keys_map.push_back(pos);
}
auto arguments_column = ColumnArray::create(ColumnUInt64::create());
arguments_column->insert(Field{arguments_to_keys_map});
column.column = ColumnConst::create(ColumnPtr(std::move(arguments_column)), 1);
data.addColumn(column);
size_t pos = data.aggregation_keys.getPosByName(arg->getColumnName());
if (pos == aggregation_keys_number)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument of GROUPING function {} is not a part of GROUP BY clause", arg->getColumnName());
arguments_indexes.push_back(pos);
}
if (data.group_by_kind != GroupByKind::ORDINARY)
switch (data.group_by_kind)
{
data.addFunction(
FunctionFactory::instance().get("grouping", data.getContext()),
{ "__grouping_set", "__grouping_set_map", arguments_column_name },
column_name
);
}
else
{
data.addFunction(
FunctionFactory::instance().get("grouping", data.getContext()),
{ "__grouping_set_map", arguments_column_name },
column_name
);
case GroupByKind::GROUPING_SETS:
{
data.addFunction(std::make_shared<FunctionToOverloadResolverAdaptor>(std::make_shared<FunctionGroupingForGroupingSets>(std::move(arguments_indexes), data.grouping_set_keys)), { "__grouping_set" }, column_name);
break;
}
case GroupByKind::ROLLUP:
data.addFunction(std::make_shared<FunctionToOverloadResolverAdaptor>(std::make_shared<FunctionGroupingForRollup>(std::move(arguments_indexes), data.aggregation_keys.size())), { "__grouping_set" }, column_name);
break;
case GroupByKind::CUBE:
{
data.addFunction(std::make_shared<FunctionToOverloadResolverAdaptor>(std::make_shared<FunctionGroupingForCube>(std::move(arguments_indexes), data.aggregation_keys.size())), { "__grouping_set" }, column_name);
break;
}
case GroupByKind::ORDINARY:
{
data.addFunction(std::make_shared<FunctionToOverloadResolverAdaptor>(std::make_shared<FunctionGroupingOrdinary>(std::move(arguments_indexes))), {}, column_name);
break;
}
default:
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Unexpected kind of GROUP BY clause for GROUPING function: {}", data.group_by_kind);
}
return;
}

View File

@ -1,27 +1,27 @@
0 2
0 2
0 4
1 4
2 4
3 4
4 4
5 4
6 4
7 4
8 4
9 4
0 1
0 1
0 4
1 4
2 4
3 4
4 4
5 4
6 4
7 4
8 4
9 4
0 2
1 2
2 2
3 2
4 2
5 2
6 2
7 2
8 2
9 2
0 1
0 2
0 2
1 1
2 1
3 1
4 1
5 1
6 1
7 1
8 1
9 1
0 0
0 1
0 1
@ -47,26 +47,26 @@
8
9
0 10 0
0 1 4
1 1 4
2 1 4
3 1 4
4 1 4
5 1 4
6 1 4
7 1 4
8 1 4
9 1 4
0 1 6
1 1 6
2 1 6
3 1 6
4 1 6
5 1 6
6 1 6
7 1 6
8 1 6
9 1 6
0 1 2
1 1 2
2 1 2
3 1 2
4 1 2
5 1 2
6 1 2
7 1 2
8 1 2
9 1 2
0 1 3
1 1 3
2 1 3
3 1 3
4 1 3
5 1 3
6 1 3
7 1 3
8 1 3
9 1 3
0
1
2

View File

@ -7,11 +7,11 @@ GROUP BY
(number),
(number % 2)
)
ORDER BY number, gr;
ORDER BY number, gr; -- { serverError BAD_ARGUMENTS }
SELECT
number,
grouping(number, number % 3, number % 2) AS gr
grouping(number, number % 2) AS gr
FROM numbers(10)
GROUP BY
GROUPING SETS (
@ -22,7 +22,18 @@ ORDER BY number, gr;
SELECT
number,
grouping(number, number % 2, number % 3) = 2 AS gr
grouping(number % 2, number) AS gr
FROM numbers(10)
GROUP BY
GROUPING SETS (
(number),
(number % 2)
)
ORDER BY number, gr;
SELECT
number,
grouping(number, number % 2) = 1 AS gr
FROM numbers(10)
GROUP BY
GROUPING SETS (
@ -39,12 +50,12 @@ GROUP BY
(number),
(number % 2)
)
ORDER BY number, grouping(number, number % 2, number % 3) = 2;
ORDER BY number, grouping(number, number % 2) = 1;
SELECT
number,
count(),
grouping(number, number % 2, number % 3) AS gr
grouping(number, number % 2) AS gr
FROM numbers(10)
GROUP BY
GROUPING SETS (
@ -62,7 +73,7 @@ GROUP BY
(number),
(number % 2)
)
HAVING grouping(number, number % 2, number % 3) = 4
HAVING grouping(number, number % 2) = 2
ORDER BY number
SETTINGS enable_optimize_predicate_expression = 0;
@ -74,13 +85,13 @@ GROUP BY
(number),
(number % 2)
)
HAVING grouping(number, number % 2, number % 3) = 2
HAVING grouping(number, number % 2) = 1
ORDER BY number
SETTINGS enable_optimize_predicate_expression = 0;
SELECT
number,
GROUPING(number, number % 2, number % 3) = 2 as gr
GROUPING(number, number % 2) = 1 as gr
FROM remote('127.0.0.{2,3}', numbers(10))
GROUP BY
GROUPING SETS (

View File

@ -19,47 +19,47 @@
8 1 1
9 1 1
0 0
0 4
0 6
1 4
1 6
2 4
2 6
3 4
3 6
4 4
4 6
5 4
5 6
6 4
6 6
7 4
7 6
8 4
8 6
9 4
9 6
0 2
0 3
1 2
1 3
2 2
2 3
3 2
3 3
4 2
4 3
5 2
5 3
6 2
6 3
7 2
7 3
8 2
8 3
9 2
9 3
0 0
0 4
0 6
1 4
1 6
2 4
2 6
3 4
3 6
4 4
4 6
5 4
5 6
6 4
6 6
7 4
7 6
8 4
8 6
9 4
9 6
0 2
0 3
1 2
1 3
2 2
2 3
3 2
3 3
4 2
4 3
5 2
5 3
6 2
6 3
7 2
7 3
8 2
8 3
9 2
9 3
0 0
0 1
0 1

View File

@ -2,6 +2,15 @@ SELECT
number,
grouping(number, number % 2, number % 3) = 6
FROM remote('127.0.0.{2,3}', numbers(10))
GROUP BY
number,
number % 2
ORDER BY number; -- { serverError BAD_ARGUMENTS }
SELECT
number,
grouping(number, number % 2) = 3
FROM remote('127.0.0.{2,3}', numbers(10))
GROUP BY
number,
number % 2
@ -19,7 +28,7 @@ ORDER BY number;
SELECT
number,
grouping(number, number % 2, number % 3) AS gr
grouping(number, number % 2) AS gr
FROM remote('127.0.0.{2,3}', numbers(10))
GROUP BY
number,
@ -30,7 +39,7 @@ ORDER BY
SELECT
number,
grouping(number, number % 2, number % 3) AS gr
grouping(number, number % 2) AS gr
FROM remote('127.0.0.{2,3}', numbers(10))
GROUP BY
ROLLUP(number, number % 2)