Merge pull request #39420 from amosbird/better-projection1-fix1

Normalize AggregateFunction types and state representations
This commit is contained in:
Alexey Milovidov 2022-08-04 03:06:55 +03:00 committed by GitHub
commit 9e46abc560
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 282 additions and 41 deletions

View File

@ -49,6 +49,16 @@ public:
return nested_func->getReturnType();
}
const IAggregateFunction & getBaseAggregateFunctionWithSameStateRepresentation() const override
{
return nested_func->getBaseAggregateFunctionWithSameStateRepresentation();
}
DataTypePtr getNormalizedStateType() const override
{
return nested_func->getNormalizedStateType();
}
bool isVersioned() const override
{
return nested_func->isVersioned();

View File

@ -5,9 +5,11 @@
#include <array>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnsCommon.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <Common/assert_cast.h>
#include <Common/config.h>
@ -102,6 +104,19 @@ public:
}
}
bool haveSameStateRepresentationImpl(const IAggregateFunction & rhs) const override
{
return this->getName() == rhs.getName();
}
DataTypePtr getNormalizedStateType() const override
{
/// Return normalized state type: count()
AggregateFunctionProperties properties;
return std::make_shared<DataTypeAggregateFunction>(
AggregateFunctionFactory::instance().get(getName(), {}, {}, properties), DataTypes{}, Array{});
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
{
data(place).count += data(rhs).count;
@ -240,6 +255,19 @@ public:
}
}
bool haveSameStateRepresentationImpl(const IAggregateFunction & rhs) const override
{
return this->getName() == rhs.getName();
}
DataTypePtr getNormalizedStateType() const override
{
/// Return normalized state type: count()
AggregateFunctionProperties properties;
return std::make_shared<DataTypeAggregateFunction>(
AggregateFunctionFactory::instance().get(getName(), {}, {}, properties), DataTypes{}, Array{});
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
{
data(place).count += data(rhs).count;

View File

@ -56,6 +56,16 @@ public:
return nested_func->getReturnType();
}
const IAggregateFunction & getBaseAggregateFunctionWithSameStateRepresentation() const override
{
return nested_func->getBaseAggregateFunctionWithSameStateRepresentation();
}
DataTypePtr getNormalizedStateType() const override
{
return nested_func->getNormalizedStateType();
}
bool isVersioned() const override
{
return nested_func->isVersioned();

View File

@ -23,14 +23,20 @@ public:
DataTypes transformArguments(const DataTypes & arguments) const override
{
if (arguments.size() != 1)
throw Exception("Incorrect number of arguments for aggregate function with " + getName() + " suffix", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Incorrect number of arguments for aggregate function with {} suffix",
getName());
const DataTypePtr & argument = arguments[0];
const DataTypeAggregateFunction * function = typeid_cast<const DataTypeAggregateFunction *>(argument.get());
if (!function)
throw Exception("Illegal type " + argument->getName() + " of argument for aggregate function with " + getName() + " suffix"
+ " must be AggregateFunction(...)", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of argument for aggregate function with {} suffix. It must be AggregateFunction(...)",
argument->getName(),
getName());
return function->getArgumentsDataTypes();
}
@ -45,13 +51,21 @@ public:
const DataTypeAggregateFunction * function = typeid_cast<const DataTypeAggregateFunction *>(argument.get());
if (!function)
throw Exception("Illegal type " + argument->getName() + " of argument for aggregate function with " + getName() + " suffix"
+ " must be AggregateFunction(...)", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of argument for aggregate function with {} suffix. It must be AggregateFunction(...)",
argument->getName(),
getName());
if (nested_function->getName() != function->getFunctionName())
throw Exception("Illegal type " + argument->getName() + " of argument for aggregate function with " + getName() + " suffix"
+ ", because it corresponds to different aggregate function: " + function->getFunctionName() + " instead of " + nested_function->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!nested_function->haveSameStateRepresentation(*function->getFunction()))
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of argument for aggregate function with {} suffix. because it corresponds to different aggregate "
"function: {} instead of {}",
argument->getName(),
getName(),
function->getFunctionName(),
nested_function->getName());
return std::make_shared<AggregateFunctionMerge>(nested_function, argument, params);
}

View File

@ -50,6 +50,11 @@ public:
return nested_func->getReturnType();
}
const IAggregateFunction & getBaseAggregateFunctionWithSameStateRepresentation() const override
{
return nested_func->getBaseAggregateFunctionWithSameStateRepresentation();
}
bool isVersioned() const override
{
return nested_func->isVersioned();

View File

@ -1,6 +1,7 @@
#pragma once
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
/// These must be exposed in header for the purpose of dynamic compilation.
#include <AggregateFunctions/QuantileReservoirSampler.h>
@ -20,9 +21,11 @@
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Common/assert_cast.h>
#include <Interpreters/GatherFunctionQuantileVisitor.h>
#include <type_traits>
@ -61,10 +64,9 @@ template <
typename FloatReturnType,
/// If true, the function will accept multiple parameters with quantile levels
/// and return an Array filled with many values of that quantiles.
bool returns_many
>
class AggregateFunctionQuantile final : public IAggregateFunctionDataHelper<Data,
AggregateFunctionQuantile<Value, Data, Name, has_second_arg, FloatReturnType, returns_many>>
bool returns_many>
class AggregateFunctionQuantile final
: public IAggregateFunctionDataHelper<Data, AggregateFunctionQuantile<Value, Data, Name, has_second_arg, FloatReturnType, returns_many>>
{
private:
using ColVecType = ColumnVectorOrDecimal<Value>;
@ -81,11 +83,14 @@ private:
public:
AggregateFunctionQuantile(const DataTypes & argument_types_, const Array & params)
: IAggregateFunctionDataHelper<Data, AggregateFunctionQuantile<Value, Data, Name, has_second_arg, FloatReturnType, returns_many>>(argument_types_, params)
, levels(params, returns_many), level(levels.levels[0]), argument_type(this->argument_types[0])
: IAggregateFunctionDataHelper<Data, AggregateFunctionQuantile<Value, Data, Name, has_second_arg, FloatReturnType, returns_many>>(
argument_types_, params)
, levels(params, returns_many)
, level(levels.levels[0])
, argument_type(this->argument_types[0])
{
if (!returns_many && levels.size() > 1)
throw Exception("Aggregate function " + getName() + " require one parameter or less", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} require one parameter or less", getName());
}
String getName() const override { return Name::name; }
@ -105,9 +110,22 @@ public:
return res;
}
bool haveSameStateRepresentation(const IAggregateFunction & rhs) const override
bool haveSameStateRepresentationImpl(const IAggregateFunction & rhs) const override
{
return getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);
return GatherFunctionQuantileData::toFusedNameOrSelf(getName()) == GatherFunctionQuantileData::toFusedNameOrSelf(rhs.getName())
&& this->haveEqualArgumentTypes(rhs);
}
DataTypePtr getNormalizedStateType() const override
{
/// Return normalized state type: quantiles*(1)(...)
Array params{1};
AggregateFunctionProperties properties;
return std::make_shared<DataTypeAggregateFunction>(
AggregateFunctionFactory::instance().get(
GatherFunctionQuantileData::toFusedNameOrSelf(getName()), this->argument_types, params, properties),
this->argument_types,
params);
}
bool allocatesMemoryInArena() const override { return false; }
@ -124,9 +142,7 @@ public:
}
if constexpr (has_second_arg)
this->data(place).add(
value,
columns[1]->getUInt(row_num));
this->data(place).add(value, columns[1]->getUInt(row_num));
else
this->data(place).add(value);
}
@ -149,7 +165,6 @@ public:
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
/// const_cast is required because some data structures apply finalizaton (like sorting) for obtain a result.
auto & data = this->data(place);
if constexpr (returns_many)
@ -195,7 +210,11 @@ public:
{
assertBinary(Name::name, types);
if (!isUnsignedInteger(types[1]))
throw Exception("Second argument (weight) for function " + std::string(Name::name) + " must be unsigned integer, but it has type " + types[1]->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Second argument (weight) for function {} must be unsigned integer, but it has type {}",
Name::name,
types[1]->getName());
}
else
assertUnary(Name::name, types);

View File

@ -163,7 +163,7 @@ public:
this->data(place).deserialize(buf);
}
bool haveSameStateRepresentation(const IAggregateFunction & rhs) const override
bool haveSameStateRepresentationImpl(const IAggregateFunction & rhs) const override
{
return this->getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);
}

View File

@ -194,7 +194,7 @@ public:
DataTypePtr getReturnType() const override { return data_type; }
bool haveSameStateRepresentation(const IAggregateFunction & rhs) const override
bool haveSameStateRepresentationImpl(const IAggregateFunction & rhs) const override
{
return this->getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);
}

View File

@ -37,6 +37,11 @@ public:
return getStateType();
}
const IAggregateFunction & getBaseAggregateFunctionWithSameStateRepresentation() const override
{
return nested_func->getBaseAggregateFunctionWithSameStateRepresentation();
}
DataTypePtr getStateType() const override
{
return nested_func->getStateType();

View File

@ -59,6 +59,13 @@ bool IAggregateFunction::haveEqualArgumentTypes(const IAggregateFunction & rhs)
}
bool IAggregateFunction::haveSameStateRepresentation(const IAggregateFunction & rhs) const
{
const auto & lhs_base = getBaseAggregateFunctionWithSameStateRepresentation();
const auto & rhs_base = rhs.getBaseAggregateFunctionWithSameStateRepresentation();
return lhs_base.haveSameStateRepresentationImpl(rhs_base);
}
bool IAggregateFunction::haveSameStateRepresentationImpl(const IAggregateFunction & rhs) const
{
bool res = getName() == rhs.getName()
&& parameters == rhs.parameters

View File

@ -73,13 +73,19 @@ public:
/// Get the data type of internal state. By default it is AggregateFunction(name(params), argument_types...).
virtual DataTypePtr getStateType() const;
/// Same as the above but normalize state types so that variants with the same binary representation will use the same type.
virtual DataTypePtr getNormalizedStateType() const { return getStateType(); }
/// Returns true if two aggregate functions have the same state representation in memory and the same serialization,
/// so state of one aggregate function can be safely used with another.
/// Examples:
/// - quantile(x), quantile(a)(x), quantile(b)(x) - parameter doesn't affect state and used for finalization only
/// - foo(x) and fooIf(x) - If combinator doesn't affect state
/// By default returns true only if functions have exactly the same names, combinators and parameters.
virtual bool haveSameStateRepresentation(const IAggregateFunction & rhs) const;
bool haveSameStateRepresentation(const IAggregateFunction & rhs) const;
virtual bool haveSameStateRepresentationImpl(const IAggregateFunction & rhs) const;
virtual const IAggregateFunction & getBaseAggregateFunctionWithSameStateRepresentation() const { return *this; }
bool haveEqualArgumentTypes(const IAggregateFunction & rhs) const;

View File

@ -8,6 +8,7 @@
#include <Common/assert_cast.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnSparse.h>
@ -68,9 +69,27 @@ static ReturnType checkColumnStructure(const ColumnWithTypeAndName & actual, con
actual_column = &column_sparse->getValuesColumn();
}
if (actual_column->getName() != expected.column->getName())
return onError<ReturnType>("Block structure mismatch in " + std::string(context_description) + " stream: different columns:\n"
+ actual.dumpStructure() + "\n" + expected.dumpStructure(), code);
const auto * actual_column_maybe_agg = typeid_cast<const ColumnAggregateFunction *>(actual_column);
const auto * expected_column_maybe_agg = typeid_cast<const ColumnAggregateFunction *>(expected.column.get());
if (actual_column_maybe_agg && expected_column_maybe_agg)
{
if (!actual_column_maybe_agg->getAggregateFunction()->haveSameStateRepresentation(*expected_column_maybe_agg->getAggregateFunction()))
return onError<ReturnType>(
fmt::format(
"Block structure mismatch in {} stream: different columns:\n{}\n{}",
context_description,
actual.dumpStructure(),
expected.dumpStructure()),
code);
}
else if (actual_column->getName() != expected.column->getName())
return onError<ReturnType>(
fmt::format(
"Block structure mismatch in {} stream: different columns:\n{}\n{}",
context_description,
actual.dumpStructure(),
expected.dumpStructure()),
code);
if (isColumnConst(*actual.column) && isColumnConst(*expected.column))
{

View File

@ -119,7 +119,40 @@ Field DataTypeAggregateFunction::getDefault() const
bool DataTypeAggregateFunction::equals(const IDataType & rhs) const
{
return typeid(rhs) == typeid(*this) && getNameWithoutVersion() == typeid_cast<const DataTypeAggregateFunction &>(rhs).getNameWithoutVersion();
if (typeid(rhs) != typeid(*this))
return false;
auto lhs_state_type = function->getNormalizedStateType();
auto rhs_state_type = typeid_cast<const DataTypeAggregateFunction &>(rhs).function->getNormalizedStateType();
if (typeid(lhs_state_type.get()) != typeid(rhs_state_type.get()))
return false;
if (const auto * lhs_state = typeid_cast<const DataTypeAggregateFunction *>(lhs_state_type.get()))
{
const auto & rhs_state = typeid_cast<const DataTypeAggregateFunction &>(*rhs_state_type);
if (lhs_state->function->getName() != rhs_state.function->getName())
return false;
if (lhs_state->parameters.size() != lhs_state->parameters.size())
return false;
for (size_t i = 0; i < lhs_state->parameters.size(); ++i)
if (lhs_state->parameters[i] != rhs_state.parameters[i])
return false;
if (lhs_state->argument_types.size() != lhs_state->argument_types.size())
return false;
for (size_t i = 0; i < lhs_state->argument_types.size(); ++i)
if (!lhs_state->argument_types[i]->equals(*rhs_state.argument_types[i]))
return false;
return true;
}
return lhs_state_type->equals(*rhs_state_type);
}

View File

@ -599,3 +599,26 @@ template <typename T> inline constexpr bool IsDataTypeEnum<DataTypeEnum<T>> = tr
M(Float32) \
M(Float64)
}
/// See https://fmt.dev/latest/api.html#formatting-user-defined-types
template <>
struct fmt::formatter<DB::DataTypePtr>
{
constexpr static auto parse(format_parse_context & ctx)
{
const auto * it = ctx.begin();
const auto * end = ctx.end();
/// Only support {}.
if (it != end && *it != '}')
throw format_error("invalid format");
return it;
}
template <typename FormatContext>
auto format(const DB::DataTypePtr & type, FormatContext & ctx)
{
return format_to(ctx.out(), "{}", type->getName());
}
};

View File

@ -1,8 +1,9 @@
#include <string>
#include <Interpreters/GatherFunctionQuantileVisitor.h>
#include <AggregateFunctions/AggregateFunctionQuantile.h>
#include <Parsers/ASTFunction.h>
#include <Common/Exception.h>
#include <base/types.h>
#include <Common/Exception.h>
namespace DB
{
@ -30,6 +31,13 @@ static const std::unordered_map<String, String> quantile_fuse_name_mapping = {
{NameQuantileTimingWeighted::name, NameQuantilesTimingWeighted::name},
};
String GatherFunctionQuantileData::toFusedNameOrSelf(const String & func_name)
{
if (auto it = quantile_fuse_name_mapping.find(func_name); it != quantile_fuse_name_mapping.end())
return it->second;
return func_name;
}
String GatherFunctionQuantileData::getFusedName(const String & func_name)
{
if (auto it = quantile_fuse_name_mapping.find(func_name); it != quantile_fuse_name_mapping.end())
@ -53,11 +61,9 @@ void GatherFunctionQuantileData::FuseQuantileAggregatesData::addFuncNode(ASTPtr
const auto & arguments = func->arguments->children;
bool need_two_args = func->name == NameQuantileDeterministic::name
|| func->name == NameQuantileExactWeighted::name
|| func->name == NameQuantileTimingWeighted::name
|| func->name == NameQuantileTDigestWeighted::name
|| func->name == NameQuantileBFloat16Weighted::name;
bool need_two_args = func->name == NameQuantileDeterministic::name || func->name == NameQuantileExactWeighted::name
|| func->name == NameQuantileTimingWeighted::name || func->name == NameQuantileTDigestWeighted::name
|| func->name == NameQuantileBFloat16Weighted::name;
if (arguments.size() != (need_two_args ? 2 : 1))
return;
@ -83,4 +89,3 @@ bool GatherFunctionQuantileData::needChild(const ASTPtr & node, const ASTPtr &)
}
}

View File

@ -1,6 +1,5 @@
#pragma once
#include <AggregateFunctions/AggregateFunctionQuantile.h>
#include <Interpreters/InDepthNodeVisitor.h>
#include <Parsers/IAST_fwd.h>
@ -26,6 +25,8 @@ public:
void visit(ASTFunction & function, ASTPtr & ast);
static String toFusedNameOrSelf(const String & func_name);
static String getFusedName(const String & func_name);
static bool needChild(const ASTPtr & node, const ASTPtr &);

View File

@ -0,0 +1,45 @@
import pytest
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__)
node1 = cluster.add_instance("node1", with_zookeeper=False)
node2 = cluster.add_instance(
"node2",
with_zookeeper=False,
image="yandex/clickhouse-server",
tag="21.7.2.7",
stay_alive=True,
with_installed_binary=True,
)
@pytest.fixture(scope="module")
def start_cluster():
try:
cluster.start()
yield cluster
finally:
cluster.shutdown()
def test_select_aggregate_alias_column(start_cluster):
node1.query(
"create table tab (x UInt64, y String, z Nullable(Int64)) engine = Memory"
)
node2.query(
"create table tab (x UInt64, y String, z Nullable(Int64)) engine = Memory"
)
node1.query("insert into tab values (1, 'a', null)")
node2.query("insert into tab values (1, 'a', null)")
node1.query(
"select count(), count(1), count(x), count(y), count(z) from remote('node{1,2}', default, tab)"
)
node2.query(
"select count(), count(1), count(x), count(y), count(z) from remote('node{1,2}', default, tab)"
)
node1.query("drop table tab")
node2.query("drop table tab")

View File

@ -1,5 +1,3 @@
-- Tags: no-s3-storage
drop table if exists t;
create table t (n int) engine MergeTree order by n;
insert into t values (1);

View File

@ -0,0 +1,5 @@
drop table if exists t;
create table t (n int, s String) engine MergeTree order by n;
insert into t values (1, 'a');
select count(), count(n), count(s) from cluster('test_cluster_two_shards', currentDatabase(), t);
drop table t;

View File

@ -0,0 +1,3 @@
SELECT countMerge(*) FROM (SELECT countState(0.5) AS a UNION ALL SELECT countState() UNION ALL SELECT countIfState(2, 1) UNION ALL SELECT countArrayState([1, 2]) UNION ALL SELECT countArrayIfState([1, 2], 1));
SELECT quantileMerge(*) FROM (SELECT quantilesState(0.5)(1) AS a UNION ALL SELECT quantileStateIf(2, identity(1)));

View File

@ -0,0 +1 @@
2 2

View File

@ -0,0 +1 @@
select sum(a), sum(b) from cluster(test_cluster_two_shards, view(select cast(number as Decimal(7, 2)) a, 0 as b from numbers(2) union all select 0, cast(number as Decimal(7, 2)) as b from numbers(2)));