More correct

kssenii 2021-05-31 14:44:57 +00:00
parent c11ad44aad
commit e510c3839e
34 changed files with 245 additions and 167 deletions

View File

@ -85,7 +85,7 @@ public:
this->data(place).value.write(buf, *serialization_val);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
this->data(place).result.read(buf, *serialization_res, arena);
this->data(place).value.read(buf, *serialization_val, arena);
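The recurring change in this and the following headers is that deserialize now takes AggregateDataPtr __restrict place instead of plain AggregateDataPtr place. __restrict is a compiler extension (GCC/Clang/MSVC, not standard C++); a minimal sketch, not ClickHouse code, of what the qualifier promises:

// Minimal sketch (illustrative names only). __restrict tells the compiler that,
// for the duration of this call, the aggregate state is reached only through
// `place`, so it need not assume `place` and `src` alias and can keep values
// in registers across the stores.
void readCounters(unsigned long long * __restrict place, const unsigned long long * __restrict src)
{
    place[0] = src[0];  // accumulated sum
    place[1] = src[1];  // row count
}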

View File

@ -125,7 +125,7 @@ public:
nested_func->serialize(place, buf, version);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> version, Arena * arena) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> version, Arena * arena) const override
{
nested_func->deserialize(place, buf, version, arena);
}
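Combinator-style functions, like this one and the -Merge/-If wrappers further down, only forward the call, so the serialization version travels untouched to the nested function. A simplified, self-contained sketch of that forwarding (the interface below is a stand-in, not the real IAggregateFunction):

#include <cstddef>
#include <optional>

struct ReadBuffer;
struct Arena;
using AggregateDataPtr = char *;

// Simplified interface, not the real one.
struct IAggregateFunctionSketch
{
    virtual void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf,
                             std::optional<std::size_t> version, Arena * arena) const = 0;
    virtual ~IAggregateFunctionSketch() = default;
};

// A combinator has no format of its own: it hands the buffer and the version
// to the nested function, so only the innermost function interprets the bytes.
struct CombinatorSketch final : IAggregateFunctionSketch
{
    const IAggregateFunctionSketch * nested_func = nullptr;

    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf,
                     std::optional<std::size_t> version, Arena * arena) const override
    {
        nested_func->deserialize(place, buf, version, arena);
    }
};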

View File

@ -117,7 +117,7 @@ public:
writeBinary(this->data(place).denominator, buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
{
readBinary(this->data(place).numerator, buf);

View File

@ -72,7 +72,7 @@ public:
writeBinary(this->data(place).value, buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
{
readBinary(this->data(place).value, buf);
}

View File

@ -147,7 +147,7 @@ public:
data(place).serialize(buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
{
data(place).deserialize(buf);
}

View File

@ -146,7 +146,7 @@ public:
writeVarUInt(data(place).count, buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
{
readVarUInt(data(place).count, buf);
}

View File

@ -110,7 +110,7 @@ public:
writePODBinary<bool>(this->data(place).seen, buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
{
readIntBinary(this->data(place).sum, buf);
readIntBinary(this->data(place).first, buf);

View File

@ -153,7 +153,7 @@ public:
writePODBinary<bool>(this->data(place).seen, buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
{
readIntBinary(this->data(place).sum, buf);
readIntBinary(this->data(place).first, buf);

View File

@ -187,7 +187,7 @@ public:
this->data(place).serialize(buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
this->data(place).deserialize(buf, arena);
}

View File

@ -130,7 +130,7 @@ public:
this->data(const_cast<AggregateDataPtr>(place)).serialize(buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
{
this->data(place).deserialize(buf);
}

View File

@ -219,7 +219,7 @@ public:
}
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> version, Arena * arena) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> version, Arena * arena) const override
{
AggregateFunctionForEachData & state = data(place);

View File

@ -570,7 +570,7 @@ public:
// if constexpr (Trait::sampler == Sampler::DETERMINATOR)
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
UInt64 elems;
readVarUInt(elems, buf);

View File

@ -164,7 +164,7 @@ public:
}
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
{
size_t size = 0;
readVarUInt(size, buf);

View File

@ -147,7 +147,7 @@ public:
buf.write(reinterpret_cast<const char *>(value.data()), size * sizeof(value[0]));
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
size_t size = 0;
readVarUInt(size, buf);

View File

@ -96,7 +96,7 @@ public:
writeIntBinary(elem, buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
{
this->data(place).value.read(buf);
}

View File

@ -351,7 +351,7 @@ public:
this->data(place).write(buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
{
this->data(place).read(buf, max_bins);
}

View File

@ -139,7 +139,7 @@ public:
nested_func->serialize(place, buf, version);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> version, Arena * arena) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> version, Arena * arena) const override
{
nested_func->deserialize(place, buf, version, arena);
}

View File

@ -220,7 +220,7 @@ public:
this->data(place).write(buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
this->data(place).read(buf, arena);
}

View File

@ -118,7 +118,7 @@ public:
buf.write(reinterpret_cast<const char *>(value.data()), size * sizeof(value[0]));
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
size_t size = 0;
readVarUInt(size, buf);

View File

@ -100,7 +100,7 @@ public:
nested_func->serialize(place, buf, version);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> version, Arena * arena) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> version, Arena * arena) const override
{
nested_func->deserialize(place, buf, version, arena);
}

View File

@ -87,7 +87,7 @@ public:
this->data(place).write(buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
this->data(place).read(buf, arena);
}

View File

@ -120,7 +120,7 @@ public:
this->data(place).serialize(buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
{
this->data(place).deserialize(buf);
}

View File

@ -174,7 +174,7 @@ public:
this->data(place).serialize(buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
{
this->data(place).deserialize(buf);
}

View File

@ -269,30 +269,37 @@ public:
if (!version)
version = getDefaultVersion();
LOG_TRACE(&Poco::Logger::get("kssenii"), "version to serialize: {}, stack: {}", *version, StackTrace().toString());
const auto & merged_maps = this->data(place).merged_maps;
size_t size = merged_maps.size();
writeVarUInt(size, buf);
std::function<void(size_t, const Array &)> serialize;
switch (*version)
{
case 0:
{
serialize = [&](size_t col_idx, const Array & values){ values_serializations[col_idx]->serializeBinary(values[col_idx], buf); };
break;
}
case 1:
{
serialize = [&](size_t col_idx, const Array & values)
{
const auto & type = values_types[col_idx];
if (isInteger(type))
SerializationNumber<Int64>().serializeBinary(values[col_idx], buf);
else
values_serializations[col_idx]->serializeBinary(values[col_idx], buf);
};
break;
}
}
for (const auto & elem : merged_maps)
{
keys_serialization->serializeBinary(elem.first, buf);
for (size_t col = 0; col < values_types.size(); ++col)
{
switch (*version)
{
case 0:
{
values_serializations[col]->serializeBinary(elem.second[col], buf);
break;
}
case 1:
{
SerializationNumber<Int64>().serializeBinary(elem.second[col], buf);
break;
}
}
}
serialize(col, elem.second);
}
}
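The serialize path now selects a per-column lambda once, based on the version, instead of re-entering a switch for every value of every map entry. A standalone sketch of that reshaping, with a std::string standing in for WriteBuffer:

#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Sketch only: choose the per-column serializer before the row loop,
// rather than switching on the version for every single value.
void serializeRows(const std::vector<std::vector<long>> & rows, std::size_t version, std::string & out)
{
    std::function<void(std::size_t, const std::vector<long> &)> serialize;
    switch (version)
    {
        case 0:
            serialize = [&](std::size_t col, const std::vector<long> & values)
            {
                out += "legacy:" + std::to_string(values[col]) + ";";
            };
            break;
        case 1:
            serialize = [&](std::size_t col, const std::vector<long> & values)
            {
                out += "int64:" + std::to_string(values[col]) + ";";
            };
            break;
    }

    // The hot loop no longer re-evaluates the version per element.
    for (const auto & row : rows)
        for (std::size_t col = 0; col < row.size(); ++col)
            serialize(col, row);
}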
@ -301,11 +308,32 @@ public:
if (!version)
version = getDefaultVersion();
LOG_TRACE(&Poco::Logger::get("kssenii"), "version to deserialize: {}, stack: {}", *version, StackTrace().toString());
auto & merged_maps = this->data(place).merged_maps;
size_t size = 0;
readVarUInt(size, buf);
std::function<void(size_t, Array &)> deserialize;
switch (*version)
{
case 0:
{
deserialize = [&](size_t col_idx, Array & values){ values_serializations[col_idx]->deserializeBinary(values[col_idx], buf); };
break;
}
case 1:
{
deserialize = [&](size_t col_idx, Array & values)
{
const auto & type = values_types[col_idx];
if (isInteger(type))
SerializationNumber<Int64>().deserializeBinary(values[col_idx], buf);
else
values_serializations[col_idx]->deserializeBinary(values[col_idx], buf);
};
break;
}
}
for (size_t i = 0; i < size; ++i)
{
Field key;
@ -313,22 +341,9 @@ public:
Array values;
values.resize(values_types.size());
for (size_t col = 0; col < values_types.size(); ++col)
{
switch (*version)
{
case 0:
{
values_serializations[col]->deserializeBinary(values[col], buf);
break;
}
case 1:
{
SerializationNumber<Int64>().deserializeBinary(values[col], buf);
break;
}
}
}
deserialize(col, values);
if constexpr (IsDecimalNumber<T>)
merged_maps[key.get<DecimalField<T>>()] = values;
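The behavioural fix is inside the version 1 branch: only integer value types go through SerializationNumber<Int64>, while every other type keeps its own serialization; the earlier code forced the Int64 path for every column. A self-contained sketch of that rule, with std::variant standing in for Field and std::string for the buffer:

#include <cstdint>
#include <string>
#include <variant>

using Value = std::variant<int32_t, double>;

// Version 1 (sketch, not the DB serializations): integers are written in a
// fixed 64-bit layout, non-integer types keep their native representation.
void writeValueV1(const Value & v, std::string & out)
{
    if (const auto * i = std::get_if<int32_t>(&v))
    {
        int64_t widened = *i;
        out.append(reinterpret_cast<const char *>(&widened), sizeof(widened));
    }
    else
    {
        double d = std::get<double>(v);
        out.append(reinterpret_cast<const char *>(&d), sizeof(d));
    }
}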

View File

@ -28,7 +28,7 @@ namespace ErrorCodes
}
static std::string getTypeString(const AggregateFunctionPtr & func)
static String getTypeString(const AggregateFunctionPtr & func)
{
WriteBufferFromOwnString stream;
stream << "AggregateFunction(" << func->getName();
@ -55,8 +55,8 @@ static std::string getTypeString(const AggregateFunctionPtr & func)
}
ColumnAggregateFunction::ColumnAggregateFunction(const AggregateFunctionPtr & func_)
: func(func_), type_string(getTypeString(func))
ColumnAggregateFunction::ColumnAggregateFunction(const AggregateFunctionPtr & func_, std::optional<size_t> version_)
: func(func_), type_string(getTypeString(func)), version(version_)
{
}
@ -354,7 +354,7 @@ INSTANTIATE_INDEX_IMPL(ColumnAggregateFunction)
void ColumnAggregateFunction::updateHashWithValue(size_t n, SipHash & hash) const
{
WriteBufferFromOwnString wbuf;
func->serialize(data[n], wbuf);
func->serialize(data[n], wbuf, version);
hash.update(wbuf.str().c_str(), wbuf.str().size());
}
@ -371,7 +371,7 @@ void ColumnAggregateFunction::updateWeakHash32(WeakHash32 & hash) const
for (size_t i = 0; i < s; ++i)
{
WriteBufferFromVector<std::vector<UInt8>> wbuf(v);
func->serialize(data[i], wbuf);
func->serialize(data[i], wbuf, version);
wbuf.finalize();
hash_data[i] = ::updateWeakHash32(v.data(), v.size(), hash_data[i]);
}
@ -423,7 +423,7 @@ Field ColumnAggregateFunction::operator[](size_t n) const
field.get<AggregateFunctionStateData &>().name = type_string;
{
WriteBufferFromString buffer(field.get<AggregateFunctionStateData &>().data);
func->serialize(data[n], buffer);
func->serialize(data[n], buffer, version);
}
return field;
}
@ -434,7 +434,7 @@ void ColumnAggregateFunction::get(size_t n, Field & res) const
res.get<AggregateFunctionStateData &>().name = type_string;
{
WriteBufferFromString buffer(res.get<AggregateFunctionStateData &>().data);
func->serialize(data[n], buffer);
func->serialize(data[n], buffer, version);
}
}
@ -514,7 +514,7 @@ void ColumnAggregateFunction::insert(const Field & x)
Arena & arena = createOrGetArena();
pushBackAndCreateState(data, arena, func.get());
ReadBufferFromString read_buffer(x.get<const AggregateFunctionStateData &>().data);
func->deserialize(data.back(), read_buffer, std::nullopt, &arena);
func->deserialize(data.back(), read_buffer, version, &arena);
}
void ColumnAggregateFunction::insertDefault()
@ -527,7 +527,7 @@ void ColumnAggregateFunction::insertDefault()
StringRef ColumnAggregateFunction::serializeValueIntoArena(size_t n, Arena & arena, const char *& begin) const
{
WriteBufferFromArena out(arena, begin);
func->serialize(data[n], out);
func->serialize(data[n], out, version);
return out.finish();
}
@ -549,7 +549,7 @@ const char * ColumnAggregateFunction::deserializeAndInsertFromArena(const char *
* Probably this will not work under UBSan.
*/
ReadBufferFromMemory read_buffer(src_arena, std::numeric_limits<char *>::max() - src_arena - 1);
func->deserialize(data.back(), read_buffer, std::nullopt, &dst_arena);
func->deserialize(data.back(), read_buffer, version, &dst_arena);
return read_buffer.position();
}
@ -649,7 +649,7 @@ void ColumnAggregateFunction::getExtremes(Field & min, Field & max) const
try
{
WriteBufferFromString buffer(serialized.data);
func->serialize(place, buffer);
func->serialize(place, buffer, version);
}
catch (...)
{

View File

@ -82,6 +82,8 @@ private:
/// Name of the type to distinguish different aggregation states.
String type_string;
std::optional<size_t> version;
ColumnAggregateFunction() = default;
/// Create a new column that has another column as a source.
@ -92,7 +94,7 @@ private:
/// but ownership of different elements cannot be mixed by different columns.
void ensureOwnership();
ColumnAggregateFunction(const AggregateFunctionPtr & func_);
ColumnAggregateFunction(const AggregateFunctionPtr & func_, std::optional<size_t> version_ = std::nullopt);
ColumnAggregateFunction(const AggregateFunctionPtr & func_, const ConstArenas & arenas_);
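ColumnAggregateFunction now remembers the optional state version it was constructed with and passes it to every serialize/deserialize call (hashing, Field conversion, arena round-trips) instead of std::nullopt. A stubbed-out sketch of that wiring; none of these types are the real ones:

#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Stand-in for IAggregateFunction: the serialized format depends on `version`.
struct FuncStub
{
    void serialize(const std::string & state, std::string & out, std::optional<std::size_t> version) const
    {
        out += (version ? "v" + std::to_string(*version) : std::string("v-latest")) + ":" + state;
    }
};

class ColumnSketch
{
    FuncStub func;
    std::optional<std::size_t> version;  // empty: let the function use its default (latest) version
    std::vector<std::string> data;

public:
    explicit ColumnSketch(std::optional<std::size_t> version_ = std::nullopt) : version(version_) {}

    void add(std::string state) { data.push_back(std::move(state)); }

    // Every place the column serializes a state threads the same version
    // through, so all round-trips of this column use one format.
    std::string serializeValue(std::size_t n) const
    {
        std::string out;
        func.serialize(data[n], out, version);
        return out;
    }
};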

View File

@ -10,6 +10,7 @@
#include <DataStreams/NativeBlockInputStream.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <DataTypes/DataTypeAggregateFunction.h>
namespace DB
@ -71,7 +72,7 @@ void NativeBlockInputStream::resetParser()
is_killed.store(false);
}
void NativeBlockInputStream::readData(const IDataType & type, ColumnPtr & column, ReadBuffer & istr, size_t rows, double avg_value_size_hint)
void NativeBlockInputStream::readData(const IDataType & type, ColumnPtr & column, ReadBuffer & istr, size_t rows, double avg_value_size_hint, size_t revision)
{
ISerialization::DeserializeBinaryBulkSettings settings;
settings.getter = [&](ISerialization::SubstreamPath) -> ReadBuffer * { return &istr; };
@ -79,6 +80,14 @@ void NativeBlockInputStream::readData(const IDataType & type, ColumnPtr & column
settings.position_independent_encoding = false;
ISerialization::DeserializeBinaryBulkStatePtr state;
const auto * aggregate_function_data_type = typeid_cast<const DataTypeAggregateFunction *>(&type);
if (aggregate_function_data_type && aggregate_function_data_type->isVersioned())
{
auto version = aggregate_function_data_type->getVersionFromRevision(revision);
aggregate_function_data_type->setVersionIfEmpty(version);
}
auto serialization = type.getDefaultSerialization();
serialization->deserializeBinaryBulkStatePrefix(settings, state);
@ -164,7 +173,7 @@ Block NativeBlockInputStream::readImpl()
double avg_value_size_hint = avg_value_size_hints.empty() ? 0 : avg_value_size_hints[i];
if (rows) /// If no rows, nothing to read.
readData(*column.type, read_column, istr, rows, avg_value_size_hint);
readData(*column.type, read_column, istr, rows, avg_value_size_hint, server_revision);
column.column = std::move(read_column);
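For versioned AggregateFunction columns whose version was not fixed in the type name, readData (and the matching writeData in NativeBlockOutputStream below) now derives the state version from the peer's protocol revision before building the serialization. A hedged sketch of that lookup; the revision threshold here is an invented placeholder, the real mapping comes from the aggregate function's getVersionFromRevision:

#include <cstddef>
#include <optional>

// Assumed threshold, for illustration only.
constexpr std::size_t FIRST_REVISION_WITH_V1_STATE = 54450;

std::size_t getVersionFromRevision(std::size_t revision)
{
    return revision >= FIRST_REVISION_WITH_V1_STATE ? 1 : 0;
}

struct VersionedTypeSketch
{
    mutable std::optional<std::size_t> version;

    bool isVersioned() const { return true; }

    // Only fill in the version if the type name did not pin one already.
    void setVersionIfEmpty(std::size_t v) const
    {
        if (!version)
            version = v;
    }
};

// Called before reading or writing a Native block for a given peer revision.
void adjustForPeer(const VersionedTypeSketch & type, std::size_t peer_revision)
{
    if (type.isVersioned())
        type.setVersionIfEmpty(getVersionFromRevision(peer_revision));
}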

View File

@ -74,7 +74,7 @@ public:
String getName() const override { return "Native"; }
static void readData(const IDataType & type, ColumnPtr & column, ReadBuffer & istr, size_t rows, double avg_value_size_hint);
static void readData(const IDataType & type, ColumnPtr & column, ReadBuffer & istr, size_t rows, double avg_value_size_hint, size_t revision);
Block getHeader() const override;

View File

@ -10,6 +10,7 @@
#include <Common/typeid_cast.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <DataTypes/DataTypeAggregateFunction.h>
namespace DB
{
@ -41,7 +42,7 @@ void NativeBlockOutputStream::flush()
}
static void writeData(const IDataType & type, const ColumnPtr & column, WriteBuffer & ostr, UInt64 offset, UInt64 limit)
static void writeData(const IDataType & type, const ColumnPtr & column, WriteBuffer & ostr, UInt64 offset, UInt64 limit, size_t revision)
{
/** If there are columns-constants - then we materialize them.
* (Since the data type does not know how to serialize / deserialize constants.)
@ -53,6 +54,13 @@ static void writeData(const IDataType & type, const ColumnPtr & column, WriteBuf
settings.position_independent_encoding = false;
settings.low_cardinality_max_dictionary_size = 0; //-V1048
const auto * aggregate_function_data_type = typeid_cast<const DataTypeAggregateFunction *>(&type);
if (aggregate_function_data_type && aggregate_function_data_type->isVersioned())
{
auto version = aggregate_function_data_type->getVersionFromRevision(revision);
aggregate_function_data_type->setVersionIfEmpty(version);
}
auto serialization = type.getDefaultSerialization();
ISerialization::SerializeBinaryBulkStatePtr state;
@ -123,7 +131,7 @@ void NativeBlockOutputStream::write(const Block & block)
/// Data
if (rows) /// Zero items of data is always represented as zero number of bytes.
writeData(*column.type, column.column, ostr, 0, 0);
writeData(*column.type, column.column, ostr, 0, 0, client_revision);
if (index_ostr)
{

View File

@ -20,7 +20,6 @@
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTIdentifier.h>
#include <Common/ClickHouseRevision.h>
namespace DB
@ -89,10 +88,7 @@ String DataTypeAggregateFunction::getNameImpl(bool with_version) const
MutableColumnPtr DataTypeAggregateFunction::createColumn() const
{
/// FIXME: There are a lot of function->serialize inside ColumnAggregateFunction.
/// Looks like it also needs version.
LOG_TRACE(&Poco::Logger::get("kssenii"), "KSSENII COLUMN");
return ColumnAggregateFunction::create(function);
return ColumnAggregateFunction::create(function, version);
}
@ -153,7 +149,7 @@ static DataTypePtr create(const ASTPtr & arguments)
/* If aggregate function definition doesn't have version, it will have in AST children args [ASTFunction, types...] - in case
* it is parametric, or [ASTIdentifier, types...] - otherwise. If aggregate function has version in AST, then it will be:
* [ASTLitearl, ASTFunction (or ASTIdentifier), types].
* [ASTLiteral, ASTFunction (or ASTIdentifier), types...].
*/
if (auto version_ast = arguments->children[0]->as<ASTLiteral>())
{
@ -168,7 +164,6 @@ static DataTypePtr create(const ASTPtr & arguments)
throw Exception("Unexpected level of parameters to aggregate function", ErrorCodes::SYNTAX_ERROR);
function_name = parametric->name;
LOG_TRACE(&Poco::Logger::get("kssenii"), "Paramtric function name: {}", function_name);
if (parametric->arguments)
{

View File

@ -3,7 +3,6 @@
#include <AggregateFunctions/IAggregateFunction.h>
#include <DataTypes/IDataType.h>
#include <common/logger_useful.h>
namespace DB
@ -15,7 +14,7 @@ namespace DB
* Data type can support versioning for serialization of aggregate function state.
* Version 0 also means no versioning. When a table with versioned data type is attached, its version is parsed from AST. If
* there is no version in AST, then it is either attach with no version in metadata (then version is 0) or it
* is a new data type (then version is default). In distributed queries version of data type is known from data type name.
* is a new data type (then version is default - latest).
*/
class DataTypeAggregateFunction final : public IDataType
{
@ -26,6 +25,7 @@ private:
mutable std::optional<size_t> version;
String getNameImpl(bool with_version) const;
size_t getVersion() const;
public:
static constexpr bool is_parametric = true;
@ -65,13 +65,15 @@ public:
SerializationPtr doGetDefaultSerialization() const override;
/// Version of aggregate function state serialization.
size_t getVersion() const;
bool isVersioned() const { return function->isVersioned(); }
/// Version is not empty only if it was parsed from AST.
/// It is ok to have an empty version value here - then for serialization
/// a default (latest) version is used. This method is used to force some
/// zero version to be used instead of default - if there was no version in AST.
size_t getVersionFromRevision(size_t revision) const { return function->getVersionFromRevision(revision); }
/// Version is not empty only if it was parsed from AST or implicitly cast to 0 or version according
/// to server revision.
/// It is ok to have an empty version value here - then for serialization a default (latest)
/// version is used. This method is used to force some zero version to be used instead of
/// default, or to set version for serialization in distributed queries.
void setVersionIfEmpty(size_t version_) const
{
if (!version)

View File

@ -45,6 +45,7 @@
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <Databases/DatabaseFactory.h>
#include <Databases/DatabaseReplicated.h>
@ -394,6 +395,10 @@ ColumnsDescription InterpreterCreateQuery::getColumnsDescription(
{
column_type = DataTypeFactory::instance().get(col_decl.type);
auto aggregate_function_type = typeid_cast<const DataTypeAggregateFunction *>(column_type.get());
if (attach && aggregate_function_type && aggregate_function_type->isVersioned())
aggregate_function_type->setVersionIfEmpty(0);
if (col_decl.null_modifier)
{
if (column_type->isNullable())

View File

@ -399,9 +399,8 @@ void IMergeTreeDataPart::setColumns(const NamesAndTypesList & new_columns, bool
{
column_name_to_position.emplace(column.name, pos);
/// TODO: May be there is a better way or a better place for that.
const auto * aggregate_function_data_type = typeid_cast<const DataTypeAggregateFunction *>(column.type.get());
if (loaded_from_disk && aggregate_function_data_type)
if (loaded_from_disk && aggregate_function_data_type && aggregate_function_data_type->isVersioned())
aggregate_function_data_type->setVersionIfEmpty(0);
for (const auto & subcolumn : column.type->getSubcolumnNames())
@ -1040,16 +1039,6 @@ void IMergeTreeDataPart::loadColumns(bool require)
{
loaded_columns.readText(*volume->getDisk()->readFile(path));
loaded_from_disk = true;
for (auto & col : loaded_columns)
{
LOG_TRACE(&Poco::Logger::get("kssenii"), "Setting version for columns: {}, {}", col.name, col.type->getName());
}
}
for (auto & col : loaded_columns)
{
LOG_TRACE(&Poco::Logger::get("kssenii"), "Loaded columns: {}, {}", col.name, col.type->getName());
}
setColumns(loaded_columns, loaded_from_disk);
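This mirrors the ATTACH branch added in InterpreterCreateQuery above: column metadata loaded from disk predates versioning, so an otherwise unset version is pinned to 0, presumably so that previously written parts keep their layout. A compact sketch of the rule (stub type, not the MergeTree code):

#include <cstddef>
#include <optional>

struct AggregateTypeSketch
{
    mutable std::optional<std::size_t> version;
    bool versioned = true;

    bool isVersioned() const { return versioned; }
    void setVersionIfEmpty(std::size_t v) const { if (!version) version = v; }
};

// Columns read back from an existing part (or attached from old metadata)
// carry no version, so they are pinned to 0 and keep the legacy layout;
// explicitly versioned columns are left untouched.
void onLoadColumn(bool loaded_from_disk, const AggregateTypeSketch * type)
{
    if (loaded_from_disk && type && type->isVersioned())
        type->setVersionIfEmpty(0);
}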

View File

@ -2,35 +2,53 @@ import pytest
from helpers.cluster import ClickHouseCluster
from helpers.test_tools import assert_eq_with_retry, exec_query_with_retry
cluster = ClickHouseCluster(__file__)
node1 = cluster.add_instance('node1', main_configs=["configs/log_conf.xml"], stay_alive=True)
node2 = cluster.add_instance('node2', main_configs=["configs/log_conf.xml"],
image='yandex/clickhouse-server',
tag='21.5', with_installed_binary=True, stay_alive=True)
node3 = cluster.add_instance('node3', with_zookeeper=True, image='yandex/clickhouse-server', tag='21.2', with_installed_binary=True, stay_alive=True)
node2 = cluster.add_instance('node2', with_zookeeper=True, image='yandex/clickhouse-server', tag='21.2', with_installed_binary=True, stay_alive=True)
# Use different nodes because if a test calls node.restart_with_latest_version(), that node stays on the
# latest version for later tests, which it should not, since the order of tests in CI is shuffled.
node3 = cluster.add_instance('node3', main_configs=["configs/log_conf.xml"],
image='yandex/clickhouse-server', tag='21.5', with_installed_binary=True, stay_alive=True)
node4 = cluster.add_instance('node4', main_configs=["configs/log_conf.xml"],
image='yandex/clickhouse-server', tag='21.5', with_installed_binary=True, stay_alive=True)
node5 = cluster.add_instance('node5', main_configs=["configs/log_conf.xml"],
image='yandex/clickhouse-server', tag='21.5', with_installed_binary=True, stay_alive=True)
node6 = cluster.add_instance('node6', main_configs=["configs/log_conf.xml"],
image='yandex/clickhouse-server', tag='21.5', with_installed_binary=True, stay_alive=True)
def insert_data(node):
node.query(""" INSERT INTO test_table
SELECT toDateTime('2020-10-01 19:20:30'), 1,
sumMapState(arrayMap(i -> 1, range(300)), arrayMap(i -> 1, range(300)));""")
def insert_data(node, table_name='test_table', n=1, col2=1):
node.query(""" INSERT INTO {}
SELECT toDateTime(NOW()), {},
sumMapState(arrayMap(i -> 1, range(300)), arrayMap(i -> 1, range(300)))
FROM numbers({});""".format(table_name, col2, n))
def create_and_fill_table(node):
node.query("DROP TABLE IF EXISTS test_table;")
node.query("""
CREATE TABLE test_table
(
`col1` DateTime,
`col2` Int64,
`col3` AggregateFunction(sumMap, Array(UInt8), Array(UInt8))
)
ENGINE = AggregatingMergeTree() ORDER BY (col1, col2) """)
insert_data(node)
def create_table(node, name='test_table', version=None):
node.query("DROP TABLE IF EXISTS {};".format(name))
if version is None:
node.query("""
CREATE TABLE {}
(
`col1` DateTime,
`col2` Int64,
`col3` AggregateFunction(sumMap, Array(UInt8), Array(UInt8))
)
ENGINE = AggregatingMergeTree() ORDER BY (col1, col2) """.format(name))
else:
node.query("""
CREATE TABLE {}
(
`col1` DateTime,
`col2` Int64,
`col3` AggregateFunction({}, sumMap, Array(UInt8), Array(UInt8))
)
ENGINE = AggregatingMergeTree() ORDER BY (col1, col2) """.format(name, version))
@pytest.fixture(scope="module")
@ -43,104 +61,139 @@ def start_cluster():
def test_modulo_partition_key_issue_23508(start_cluster):
node3.query("CREATE TABLE test (id Int64, v UInt64, value String) ENGINE = ReplicatedReplacingMergeTree('/clickhouse/tables/table1', '1', v) PARTITION BY id % 20 ORDER BY (id, v)")
node3.query("INSERT INTO test SELECT number, number, toString(number) FROM numbers(10)")
node2.query("CREATE TABLE test (id Int64, v UInt64, value String) ENGINE = ReplicatedReplacingMergeTree('/clickhouse/tables/table1', '1', v) PARTITION BY id % 20 ORDER BY (id, v)")
node2.query("INSERT INTO test SELECT number, number, toString(number) FROM numbers(10)")
expected = node3.query("SELECT number, number, toString(number) FROM numbers(10)")
partition_data = node3.query("SELECT partition, name FROM system.parts WHERE table='test' ORDER BY partition")
assert(expected == node3.query("SELECT * FROM test ORDER BY id"))
expected = node2.query("SELECT number, number, toString(number) FROM numbers(10)")
partition_data = node2.query("SELECT partition, name FROM system.parts WHERE table='test' ORDER BY partition")
assert(expected == node2.query("SELECT * FROM test ORDER BY id"))
node3.restart_with_latest_version()
node2.restart_with_latest_version()
assert(expected == node3.query("SELECT * FROM test ORDER BY id"))
assert(partition_data == node3.query("SELECT partition, name FROM system.parts WHERE table='test' ORDER BY partition"))
assert(expected == node2.query("SELECT * FROM test ORDER BY id"))
assert(partition_data == node2.query("SELECT partition, name FROM system.parts WHERE table='test' ORDER BY partition"))
# Test from issue 16587
def test_aggregate_function_versioning_issue_16587(start_cluster):
for node in [node1, node2]:
for node in [node1, node3]:
node.query("DROP TABLE IF EXISTS test_table;")
node.query("""
CREATE TABLE test_table (`col1` DateTime, `col2` Int64)
ENGINE = MergeTree() ORDER BY col1""")
node.query("insert into test_table select '2020-10-26 00:00:00', 70724110 from numbers(300)")
node.query("insert into test_table select '2020-10-26 00:00:00', 1929292 from numbers(300)")
expected = "([1],[600])"
# Incorrect result on old server
result_on_old_version = node2.query("select sumMap(sm) from (select sumMap([1],[1]) as sm from remote('127.0.0.{1,2}', default.test_table) group by col1, col2);")
assert(result_on_old_version.strip() != expected)
result_on_old_version = node3.query("select sumMap(sm) from (select sumMap([1],[1]) as sm from remote('127.0.0.{1,2}', default.test_table) group by col1, col2);").strip()
assert(result_on_old_version != expected)
# Correct result on new server
result_on_new_version = node1.query("select sumMap(sm) from (select sumMap([1],[1]) as sm from remote('127.0.0.{1,2}', default.test_table) group by col1, col2);")
assert(result_on_new_version.strip() == expected)
result_on_new_version = node1.query("select sumMap(sm) from (select sumMap([1],[1]) as sm from remote('127.0.0.{1,2}', default.test_table) group by col1, col2);").strip()
assert(result_on_new_version == expected)
def test_aggregate_function_versioning_fetch_data_from_new_to_old_server(start_cluster):
for node in [node1, node2]:
create_and_fill_table(node)
def test_aggregate_function_versioning_fetch_data_from_old_to_new_server(start_cluster):
for node in [node1, node4]:
create_table(node)
insert_data(node)
expected = "([1],[300])"
new_server_data = node1.query("select finalizeAggregation(col3) from default.test_table;").strip()
assert(new_server_data == expected)
old_server_data = node2.query("select finalizeAggregation(col3) from default.test_table;").strip()
old_server_data = node4.query("select finalizeAggregation(col3) from default.test_table;").strip()
assert(old_server_data != expected)
data_from_old_to_new_server = node1.query("select finalizeAggregation(col3) from remote('node2', default.test_table);").strip()
data_from_old_to_new_server = node1.query("select finalizeAggregation(col3) from remote('node4', default.test_table);").strip()
assert(data_from_old_to_new_server == old_server_data)
def test_aggregate_function_versioning_server_upgrade(start_cluster):
for node in [node1, node2]:
create_and_fill_table(node)
for node in [node1, node5]:
create_table(node)
insert_data(node1, col2=5)
insert_data(node5, col2=1)
# Serialization with version 0.
old_server_data = node5.query("select finalizeAggregation(col3) from default.test_table;").strip()
assert(old_server_data == "([1],[44])")
# Upgrade server.
node5.restart_with_latest_version()
# Deserialized with version 0.
upgraded_server_data = node5.query("select finalizeAggregation(col3) from default.test_table;").strip()
assert(upgraded_server_data == "([1],[44])")
# Data from upgraded server to new server. Deserialize with version 0.
data_from_upgraded_to_new_server = node1.query("select finalizeAggregation(col3) from remote('node5', default.test_table);").strip()
assert(data_from_upgraded_to_new_server == upgraded_server_data == "([1],[44])")
upgraded_server_data = node5.query("select finalizeAggregation(col3) from remote('127.0.0.{1,2}', default.test_table);").strip()
assert(upgraded_server_data == "([1],[44])\n([1],[44])")
# Check insertion after the server upgrade.
insert_data(node5, col2=2)
upgraded_server_data = node5.query("select finalizeAggregation(col3) from default.test_table order by col2;").strip()
assert(upgraded_server_data == "([1],[44])\n([1],[44])")
new_server_data = node1.query("select finalizeAggregation(col3) from default.test_table;").strip()
assert(new_server_data == "([1],[300])")
old_server_data = node2.query("select finalizeAggregation(col3) from default.test_table;").strip()
assert(old_server_data == "([1],[44])")
node2.restart_with_latest_version()
# Insert from new server to upgraded server, data version 1.
node1.query("insert into table function remote('node5', default.test_table) select * from default.test_table;").strip()
upgraded_server_data = node5.query("select finalizeAggregation(col3) from default.test_table order by col2;").strip()
assert(upgraded_server_data == "([1],[44])\n([1],[44])\n([1],[44])")
# Check that after server upgrade aggregate function is serialized according to older version.
upgraded_server_data = node2.query("select finalizeAggregation(col3) from default.test_table;").strip()
assert(upgraded_server_data == "([1],[44])")
insert_data(node1)
new_server_data = node1.query("select finalizeAggregation(col3) from default.test_table;").strip()
assert(new_server_data == "([1],[300])\n([1],[300])")
# Remote fetches are still with older version.
data_from_upgraded_to_new_server = node1.query("select finalizeAggregation(col3) from remote('node2', default.test_table);").strip()
assert(data_from_upgraded_to_new_server == upgraded_server_data == "([1],[44])")
# Create a table whose column uses version 0 serialization, to be used for a further check.
create_table(node1, name='test_table_0', version=0)
insert_data(node1, table_name='test_table_0', col2=3)
data = node1.query("select finalizeAggregation(col3) from default.test_table_0;").strip()
assert(data == "([1],[44])")
# Check it is ok to write into table with older version of aggregate function.
insert_data(node2)
# Hm, should newly inserted data be serialized as old version?
upgraded_server_data = node2.query("select finalizeAggregation(col3) from default.test_table;").strip()
assert(upgraded_server_data == "([1],[300])\n([1],[44])")
# Insert from new server to upgraded server, data version 0.
node1.query("insert into table function remote('node5', default.test_table) select * from default.test_table_0;").strip()
upgraded_server_data = node5.query("select finalizeAggregation(col3) from default.test_table order by col2;").strip()
assert(upgraded_server_data == "([1],[44])\n([1],[44])\n([1],[44])\n([1],[44])")
def test_aggregate_function_versioning_persisting_metadata(start_cluster):
for node in [node1, node2]:
create_and_fill_table(node)
node2.restart_with_latest_version()
for node in [node1, node6]:
create_table(node)
insert_data(node)
data = node1.query("select finalizeAggregation(col3) from default.test_table;").strip()
assert(data == "([1],[300])")
data = node6.query("select finalizeAggregation(col3) from default.test_table;").strip()
assert(data == "([1],[44])")
for node in [node1, node2]:
node6.restart_with_latest_version()
for node in [node1, node6]:
node.query("DETACH TABLE test_table")
node.query("ATTACH TABLE test_table")
for node in [node1, node2]:
for node in [node1, node6]:
insert_data(node)
new_server_data = node1.query("select finalizeAggregation(col3) from default.test_table;").strip()
assert(new_server_data == "([1],[300])\n([1],[300])")
upgraded_server_data = node2.query("select finalizeAggregation(col3) from default.test_table;").strip()
upgraded_server_data = node6.query("select finalizeAggregation(col3) from default.test_table;").strip()
assert(upgraded_server_data == "([1],[44])\n([1],[44])")
for node in [node1, node2]:
for node in [node1, node6]:
node.restart_clickhouse()
insert_data(node)
result = node1.query("select finalizeAggregation(col3) from remote('127.0.0.{1,2}', default.test_table);").strip()
assert(result == "([1],[300])\n([1],[300])\n([1],[300])\n([1],[300])")
result = node2.query("select finalizeAggregation(col3) from remote('127.0.0.{1,2}', default.test_table);").strip()
assert(result == "([1],[44])\n([1],[44])\n([1],[44])\n([1],[44])")
assert(result == "([1],[300])\n([1],[300])\n([1],[300])\n([1],[300])\n([1],[300])\n([1],[300])")
result = node6.query("select finalizeAggregation(col3) from remote('127.0.0.{1,2}', default.test_table);").strip()
assert(result == "([1],[44])\n([1],[44])\n([1],[44])\n([1],[44])\n([1],[44])\n([1],[44])")