Fix tests

This commit is contained in:
Amos Bird 2022-07-22 14:27:45 +08:00
parent f84e5b6827
commit 09c99d8440
No known key found for this signature in database
GPG Key ID: 80D430DCBECFEDB4
7 changed files with 37 additions and 12 deletions

View File

@ -54,9 +54,9 @@ public:
return nested_func->getBaseAggregateFunctionWithSameStateRepresentation();
}
DataTypePtr getStateType() const override
DataTypePtr getNormalizedStateType() const override
{
return nested_func->getStateType();
return nested_func->getNormalizedStateType();
}
bool isVersioned() const override

View File

@ -109,8 +109,9 @@ public:
return this->getName() == rhs.getName();
}
DataTypePtr getStateType() const override
DataTypePtr getNormalizedStateType() const override
{
/// Return normalized state type: count()
AggregateFunctionProperties properties;
return std::make_shared<DataTypeAggregateFunction>(
AggregateFunctionFactory::instance().get(getName(), {}, {}, properties), DataTypes{}, Array{});
@ -259,8 +260,9 @@ public:
return this->getName() == rhs.getName();
}
DataTypePtr getStateType() const override
DataTypePtr getNormalizedStateType() const override
{
/// Return normalized state type: count()
AggregateFunctionProperties properties;
return std::make_shared<DataTypeAggregateFunction>(
AggregateFunctionFactory::instance().get(getName(), {}, {}, properties), DataTypes{}, Array{});

View File

@ -61,9 +61,9 @@ public:
return nested_func->getBaseAggregateFunctionWithSameStateRepresentation();
}
DataTypePtr getStateType() const override
DataTypePtr getNormalizedStateType() const override
{
return nested_func->getStateType();
return nested_func->getNormalizedStateType();
}
bool isVersioned() const override

View File

@ -116,8 +116,9 @@ public:
&& this->haveEqualArgumentTypes(rhs);
}
DataTypePtr getStateType() const override
DataTypePtr getNormalizedStateType() const override
{
/// Return normalized state type: quantiles*(1)(...)
Array params{1};
AggregateFunctionProperties properties;
return std::make_shared<DataTypeAggregateFunction>(

View File

@ -73,6 +73,9 @@ public:
/// Get the data type of internal state. By default it is AggregateFunction(name(params), argument_types...).
virtual DataTypePtr getStateType() const;
/// Same as the above but normalize state types so that variants with the same binary representation will use the same type.
virtual DataTypePtr getNormalizedStateType() const { return getStateType(); }
/// Returns true if two aggregate functions have the same state representation in memory and the same serialization,
/// so state of one aggregate function can be safely used with another.
/// Examples:

View File

@ -8,6 +8,7 @@
#include <Common/assert_cast.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnSparse.h>
@ -68,9 +69,27 @@ static ReturnType checkColumnStructure(const ColumnWithTypeAndName & actual, con
actual_column = &column_sparse->getValuesColumn();
}
if (actual_column->getName() != expected.column->getName())
return onError<ReturnType>("Block structure mismatch in " + std::string(context_description) + " stream: different columns:\n"
+ actual.dumpStructure() + "\n" + expected.dumpStructure(), code);
const auto * actual_column_maybe_agg = typeid_cast<const ColumnAggregateFunction *>(actual_column);
const auto * expected_column_maybe_agg = typeid_cast<const ColumnAggregateFunction *>(expected.column.get());
if (actual_column_maybe_agg && expected_column_maybe_agg)
{
if (!actual_column_maybe_agg->getAggregateFunction()->haveSameStateRepresentation(*expected_column_maybe_agg->getAggregateFunction()))
return onError<ReturnType>(
fmt::format(
"Block structure mismatch in {} stream: different columns:\n{}\n{}",
context_description,
actual.dumpStructure(),
expected.dumpStructure()),
code);
}
else if (actual_column->getName() != expected.column->getName())
return onError<ReturnType>(
fmt::format(
"Block structure mismatch in {} stream: different columns:\n{}\n{}",
context_description,
actual.dumpStructure(),
expected.dumpStructure()),
code);
if (isColumnConst(*actual.column) && isColumnConst(*expected.column))
{

View File

@ -122,8 +122,8 @@ bool DataTypeAggregateFunction::equals(const IDataType & rhs) const
if (typeid(rhs) != typeid(*this))
return false;
auto lhs_state_type = function->getStateType();
auto rhs_state_type = typeid_cast<const DataTypeAggregateFunction &>(rhs).function->getStateType();
auto lhs_state_type = function->getNormalizedStateType();
auto rhs_state_type = typeid_cast<const DataTypeAggregateFunction &>(rhs).function->getNormalizedStateType();
if (typeid(lhs_state_type.get()) != typeid(rhs_state_type.get()))
return false;