#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace DB { struct Settings; namespace ErrorCodes { extern const int BAD_ARGUMENTS; extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int LOGICAL_ERROR; } template struct AggregateFunctionMapData { // Map needs to be ordered to maintain function properties std::map merged_maps; }; /** Aggregate function, that takes at least two arguments: keys and values, and as a result, builds a tuple of at least 2 arrays - * ordered keys and variable number of argument values aggregated by corresponding keys. * * sumMap function is the most useful when using SummingMergeTree to sum Nested columns, which name ends in "Map". * * Example: sumMap(k, v...) of: * k v * [1,2,3] [10,10,10] * [3,4,5] [10,10,10] * [4,5,6] [10,10,10] * [6,7,8] [10,10,10] * [7,5,3] [5,15,25] * [8,9,10] [20,20,20] * will return: * ([1,2,3,4,5,6,7,8,9,10],[10,10,45,20,35,20,15,30,20,20]) * * minMap and maxMap share the same idea, but calculate min and max correspondingly. * * NOTE: The implementation of these functions are "amateur grade" - not efficient and low quality. */ template class AggregateFunctionMapBase : public IAggregateFunctionDataHelper< AggregateFunctionMapData>, Derived> { private: static constexpr auto STATE_VERSION_1_MIN_REVISION = 54451; DataTypePtr keys_type; SerializationPtr keys_serialization; DataTypes values_types; Serializations values_serializations; Serializations promoted_values_serializations; public: using Base = IAggregateFunctionDataHelper< AggregateFunctionMapData>, Derived>; AggregateFunctionMapBase(const DataTypePtr & keys_type_, const DataTypes & values_types_, const DataTypes & argument_types_) : Base(argument_types_, {} /* parameters */) , keys_type(keys_type_) , keys_serialization(keys_type->getDefaultSerialization()) , values_types(values_types_) { values_serializations.reserve(values_types.size()); promoted_values_serializations.reserve(values_types.size()); for (const auto & type : values_types) { values_serializations.emplace_back(type->getDefaultSerialization()); if (type->canBePromoted()) { if (type->isNullable()) promoted_values_serializations.emplace_back( makeNullable(removeNullable(type)->promoteNumericType())->getDefaultSerialization()); else promoted_values_serializations.emplace_back(type->promoteNumericType()->getDefaultSerialization()); } else { promoted_values_serializations.emplace_back(type->getDefaultSerialization()); } } } bool isVersioned() const override { return true; } size_t getDefaultVersion() const override { return 1; } size_t getVersionFromRevision(size_t revision) const override { if (revision >= STATE_VERSION_1_MIN_REVISION) return 1; else return 0; } DataTypePtr getReturnType() const override { DataTypes types; types.emplace_back(std::make_shared(keys_type)); for (const auto & value_type : values_types) { if constexpr (std::is_same_v) { if (!value_type->isSummable()) throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Values for {} cannot be summed, passed type {}", getName(), value_type->getName()}; } DataTypePtr result_type; if constexpr (overflow) { if (value_type->onlyNull()) throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Cannot calculate {} of type {}", getName(), value_type->getName()}; // Overflow, meaning that the returned type is the same as // the input type. Nulls are skipped. result_type = removeNullable(value_type); } else { auto value_type_without_nullable = removeNullable(value_type); // No overflow, meaning we promote the types if necessary. if (!value_type_without_nullable->canBePromoted()) throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Values for {} are expected to be Numeric, Float or Decimal, passed type {}", getName(), value_type->getName()}; WhichDataType value_type_to_check(value_type); /// Do not promote decimal because of implementation issues of this function design /// Currently we cannot get result column type in case of decimal we cannot get decimal scale /// in method void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override /// If we decide to make this function more efficient we should promote decimal type during summ if (value_type_to_check.isDecimal()) result_type = value_type_without_nullable; else result_type = value_type_without_nullable->promoteNumericType(); } types.emplace_back(std::make_shared(result_type)); } return std::make_shared(types); } bool allocatesMemoryInArena() const override { return false; } static const auto & getArgumentColumns(const IColumn**& columns) { if constexpr (tuple_argument) { return assert_cast(columns[0])->getColumns(); } else { return columns; } } void add(AggregateDataPtr __restrict place, const IColumn ** columns_, const size_t row_num, Arena *) const override { const auto & columns = getArgumentColumns(columns_); // Column 0 contains array of keys of known type const ColumnArray & array_column0 = assert_cast(*columns[0]); const IColumn::Offsets & offsets0 = array_column0.getOffsets(); const IColumn & key_column = array_column0.getData(); const size_t keys_vec_offset = offsets0[row_num - 1]; const size_t keys_vec_size = (offsets0[row_num] - keys_vec_offset); // Columns 1..n contain arrays of numeric values to sum auto & merged_maps = this->data(place).merged_maps; for (size_t col = 0, size = values_types.size(); col < size; ++col) { const auto & array_column = assert_cast(*columns[col + 1]); const IColumn & value_column = array_column.getData(); const IColumn::Offsets & offsets = array_column.getOffsets(); const size_t values_vec_offset = offsets[row_num - 1]; const size_t values_vec_size = (offsets[row_num] - values_vec_offset); // Expect key and value arrays to be of same length if (keys_vec_size != values_vec_size) throw Exception("Sizes of keys and values arrays do not match", ErrorCodes::BAD_ARGUMENTS); // Insert column values for all keys for (size_t i = 0; i < keys_vec_size; ++i) { auto value = value_column[values_vec_offset + i]; auto key = key_column[keys_vec_offset + i].get(); if (!keepKey(key)) continue; decltype(merged_maps.begin()) it; if constexpr (is_decimal) { // FIXME why is storing NearestFieldType not enough, and we // have to check for decimals again here? UInt32 scale = static_cast &>(key_column).getData().getScale(); it = merged_maps.find(DecimalField(key, scale)); } else it = merged_maps.find(key); if (it != merged_maps.end()) { if (!value.isNull()) { if (it->second[col].isNull()) it->second[col] = value; else applyVisitor(Visitor(value), it->second[col]); } } else { // Create a value array for this key Array new_values; new_values.resize(size); new_values[col] = value; if constexpr (is_decimal) { UInt32 scale = static_cast &>(key_column).getData().getScale(); merged_maps.emplace(DecimalField(key, scale), std::move(new_values)); } else { merged_maps.emplace(key, std::move(new_values)); } } } } } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { auto & merged_maps = this->data(place).merged_maps; const auto & rhs_maps = this->data(rhs).merged_maps; for (const auto & elem : rhs_maps) { const auto & it = merged_maps.find(elem.first); if (it != merged_maps.end()) { for (size_t col = 0; col < values_types.size(); ++col) if (!elem.second[col].isNull()) applyVisitor(Visitor(elem.second[col]), it->second[col]); } else merged_maps[elem.first] = elem.second; } } void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional version) const override { if (!version) version = getDefaultVersion(); const auto & merged_maps = this->data(place).merged_maps; size_t size = merged_maps.size(); writeVarUInt(size, buf); std::function serialize; switch (*version) { case 0: { serialize = [&](size_t col_idx, const Array & values){ values_serializations[col_idx]->serializeBinary(values[col_idx], buf); }; break; } case 1: { serialize = [&](size_t col_idx, const Array & values){ promoted_values_serializations[col_idx]->serializeBinary(values[col_idx], buf); }; break; } } for (const auto & elem : merged_maps) { keys_serialization->serializeBinary(elem.first, buf); for (size_t col = 0; col < values_types.size(); ++col) serialize(col, elem.second); } } void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional version, Arena *) const override { if (!version) version = getDefaultVersion(); auto & merged_maps = this->data(place).merged_maps; size_t size = 0; readVarUInt(size, buf); std::function deserialize; switch (*version) { case 0: { deserialize = [&](size_t col_idx, Array & values){ values_serializations[col_idx]->deserializeBinary(values[col_idx], buf); }; break; } case 1: { deserialize = [&](size_t col_idx, Array & values){ promoted_values_serializations[col_idx]->deserializeBinary(values[col_idx], buf); }; break; } } for (size_t i = 0; i < size; ++i) { Field key; keys_serialization->deserializeBinary(key, buf); Array values; values.resize(values_types.size()); for (size_t col = 0; col < values_types.size(); ++col) deserialize(col, values); if constexpr (is_decimal) merged_maps[key.get>()] = values; else merged_maps[key.get()] = values; } } void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override { size_t num_columns = values_types.size(); // Final step does compaction of keys that have zero values, this mutates the state auto & merged_maps = this->data(place).merged_maps; // Remove keys which are zeros or empty. This should be enabled only for sumMap. if constexpr (compact) { for (auto it = merged_maps.cbegin(); it != merged_maps.cend();) { // Key is not compacted if it has at least one non-zero value bool erase = true; for (size_t col = 0; col < num_columns; ++col) { if (!it->second[col].isNull() && it->second[col] != values_types[col]->getDefault()) { erase = false; break; } } if (erase) it = merged_maps.erase(it); else ++it; } } size_t size = merged_maps.size(); auto & to_tuple = assert_cast(to); auto & to_keys_arr = assert_cast(to_tuple.getColumn(0)); auto & to_keys_col = to_keys_arr.getData(); // Advance column offsets auto & to_keys_offsets = to_keys_arr.getOffsets(); to_keys_offsets.push_back(to_keys_offsets.back() + size); to_keys_col.reserve(size); for (size_t col = 0; col < num_columns; ++col) { auto & to_values_arr = assert_cast(to_tuple.getColumn(col + 1)); auto & to_values_offsets = to_values_arr.getOffsets(); to_values_offsets.push_back(to_values_offsets.back() + size); to_values_arr.getData().reserve(size); } // Write arrays of keys and values for (const auto & elem : merged_maps) { // Write array of keys into column to_keys_col.insert(elem.first); // Write 0..n arrays of values for (size_t col = 0; col < num_columns; ++col) { auto & to_values_col = assert_cast(to_tuple.getColumn(col + 1)).getData(); if (elem.second[col].isNull()) to_values_col.insertDefault(); else to_values_col.insert(elem.second[col]); } } } bool keepKey(const T & key) const { return static_cast(*this).keepKey(key); } String getName() const override { return static_cast(*this).getName(); } }; template class AggregateFunctionSumMap final : public AggregateFunctionMapBase, FieldVisitorSum, overflow, tuple_argument, true> { private: using Self = AggregateFunctionSumMap; using Base = AggregateFunctionMapBase; public: AggregateFunctionSumMap(const DataTypePtr & keys_type_, DataTypes & values_types_, const DataTypes & argument_types_, const Array & params_) : Base{keys_type_, values_types_, argument_types_} { // The constructor accepts parameters to have a uniform interface with // sumMapFiltered, but this function doesn't have any parameters. assertNoParameters(getName(), params_); } String getName() const override { if constexpr (overflow) { return "sumMapWithOverflow"; } else { return "sumMap"; } } bool keepKey(const T &) const { return true; } }; template class AggregateFunctionSumMapFiltered final : public AggregateFunctionMapBase, FieldVisitorSum, overflow, tuple_argument, true> { private: using Self = AggregateFunctionSumMapFiltered; using Base = AggregateFunctionMapBase; using ContainerT = std::unordered_set; ContainerT keys_to_keep; public: AggregateFunctionSumMapFiltered(const DataTypePtr & keys_type_, const DataTypes & values_types_, const DataTypes & argument_types_, const Array & params_) : Base{keys_type_, values_types_, argument_types_} { if (params_.size() != 1) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function '{}' requires exactly one parameter " "of Array type", getName()); Array keys_to_keep_; if (!params_.front().tryGet(keys_to_keep_)) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} requires an Array as a parameter", getName()); keys_to_keep.reserve(keys_to_keep_.size()); for (const Field & f : keys_to_keep_) keys_to_keep.emplace(f.safeGet()); } String getName() const override { return overflow ? "sumMapFilteredWithOverflow" : "sumMapFiltered"; } bool keepKey(const T & key) const { return keys_to_keep.count(key); } }; /** Implements `Max` operation. * Returns true if changed */ class FieldVisitorMax : public StaticVisitor { private: const Field & rhs; template bool compareImpl(FieldType & x) const { auto val = get(rhs); if (val > x) { x = val; return true; } return false; } public: explicit FieldVisitorMax(const Field & rhs_) : rhs(rhs_) {} bool operator() (Null &) const { throw Exception("Cannot compare Nulls", ErrorCodes::LOGICAL_ERROR); } bool operator() (AggregateFunctionStateData &) const { throw Exception("Cannot compare AggregateFunctionStates", ErrorCodes::LOGICAL_ERROR); } bool operator() (Array & x) const { return compareImpl(x); } bool operator() (Tuple & x) const { return compareImpl(x); } template bool operator() (DecimalField & x) const { return compareImpl>(x); } template bool operator() (T & x) const { return compareImpl(x); } }; /** Implements `Min` operation. * Returns true if changed */ class FieldVisitorMin : public StaticVisitor { private: const Field & rhs; template bool compareImpl(FieldType & x) const { auto val = get(rhs); if (val < x) { x = val; return true; } return false; } public: explicit FieldVisitorMin(const Field & rhs_) : rhs(rhs_) {} bool operator() (Null &) const { throw Exception("Cannot compare Nulls", ErrorCodes::LOGICAL_ERROR); } bool operator() (AggregateFunctionStateData &) const { throw Exception("Cannot sum AggregateFunctionStates", ErrorCodes::LOGICAL_ERROR); } bool operator() (Array & x) const { return compareImpl(x); } bool operator() (Tuple & x) const { return compareImpl(x); } template bool operator() (DecimalField & x) const { return compareImpl>(x); } template bool operator() (T & x) const { return compareImpl(x); } }; template class AggregateFunctionMinMap final : public AggregateFunctionMapBase, FieldVisitorMin, true, tuple_argument, false> { private: using Self = AggregateFunctionMinMap; using Base = AggregateFunctionMapBase; public: AggregateFunctionMinMap(const DataTypePtr & keys_type_, DataTypes & values_types_, const DataTypes & argument_types_, const Array & params_) : Base{keys_type_, values_types_, argument_types_} { // The constructor accepts parameters to have a uniform interface with // sumMapFiltered, but this function doesn't have any parameters. assertNoParameters(getName(), params_); } String getName() const override { return "minMap"; } bool keepKey(const T &) const { return true; } }; template class AggregateFunctionMaxMap final : public AggregateFunctionMapBase, FieldVisitorMax, true, tuple_argument, false> { private: using Self = AggregateFunctionMaxMap; using Base = AggregateFunctionMapBase; public: AggregateFunctionMaxMap(const DataTypePtr & keys_type_, DataTypes & values_types_, const DataTypes & argument_types_, const Array & params_) : Base{keys_type_, values_types_, argument_types_} { // The constructor accepts parameters to have a uniform interface with // sumMapFiltered, but this function doesn't have any parameters. assertNoParameters(getName(), params_); } String getName() const override { return "maxMap"; } bool keepKey(const T &) const { return true; } }; }