diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.cpp b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.cpp index cf5c8254887..41150799815 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.cpp @@ -83,6 +83,8 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con AggregateFunctionPtr res(createWithNumericBasedType(*keys_type, keys_type, values_types, arguments)); if (!res) res.reset(createWithDecimalType(*keys_type, keys_type, values_types, arguments)); + if (!res) + res.reset(createWithStringType(*keys_type, keys_type, values_types, arguments)); if (!res) throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); @@ -106,6 +108,8 @@ AggregateFunctionPtr createAggregateFunctionSumMapFiltered(const std::string & n AggregateFunctionPtr res(createWithNumericBasedType(*keys_type, keys_type, values_types, keys_to_keep, arguments, params)); if (!res) res.reset(createWithDecimalType(*keys_type, keys_type, values_types, keys_to_keep, arguments, params)); + if (!res) + res.reset(createWithStringType(*keys_type, keys_type, values_types, keys_to_keep, arguments, params)); if (!res) throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h index c201e8e3370..9b1f164bdd2 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -56,8 +57,6 @@ class AggregateFunctionSumMapBase : public IAggregateFunctionDataHelper< AggregateFunctionSumMapData>, Derived> { private: - using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>; - DataTypePtr keys_type; DataTypes values_types; @@ -84,9 +83,10 @@ public: void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override { // Column 0 contains array of keys of known type + Field key_field; const ColumnArray & array_column0 = assert_cast(*columns[0]); const IColumn::Offsets & offsets0 = array_column0.getOffsets(); - const auto & keys_vec = static_cast(array_column0.getData()); + const IColumn & key_column = array_column0.getData(); const size_t keys_vec_offset = offsets0[row_num - 1]; const size_t keys_vec_size = (offsets0[row_num] - keys_vec_offset); @@ -111,7 +111,8 @@ public: using IteratorType = typename MapType::iterator; array_column.getData().get(values_vec_offset + i, value); - const auto & key = keys_vec.getElement(keys_vec_offset + i); + key_column.get(keys_vec_offset + i, key_field); + auto && key = key_field.get(); if (!keepKey(key)) { @@ -121,7 +122,7 @@ public: IteratorType it; if constexpr (IsDecimalNumber) { - UInt32 scale = keys_vec.getData().getScale(); + UInt32 scale = static_cast &>(key_column).getData().getScale(); it = merged_maps.find(DecimalField(key, scale)); } else @@ -139,7 +140,7 @@ public: if constexpr (IsDecimalNumber) { - UInt32 scale = keys_vec.getData().getScale(); + UInt32 scale = static_cast &>(key_column).getData().getScale(); merged_maps.emplace(DecimalField(key, scale), std::move(new_values)); } else diff --git a/dbms/src/AggregateFunctions/Helpers.h b/dbms/src/AggregateFunctions/Helpers.h index 8d42654811a..6c03d25e0b1 100644 --- a/dbms/src/AggregateFunctions/Helpers.h +++ b/dbms/src/AggregateFunctions/Helpers.h @@ -149,4 +149,13 @@ static IAggregateFunction * createWithTwoNumericTypes(const IDataType & first_ty return nullptr; } +template