diff --git a/dbms/src/AggregateFunctions/AggregateFunctionAvg.cpp b/dbms/src/AggregateFunctions/AggregateFunctionAvg.cpp index e075cf9329a..36b29796b97 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionAvg.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionAvg.cpp @@ -27,7 +27,7 @@ AggregateFunctionPtr createAggregateFunctionAvg(const std::string & name, const AggregateFunctionPtr res; DataTypePtr data_type = argument_types[0]; if (isDecimal(data_type)) - res.reset(createWithDecimalType(*data_type)); + res.reset(createWithDecimalType(*data_type, *data_type)); else res.reset(createWithNumericType(*data_type)); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSum.cpp b/dbms/src/AggregateFunctions/AggregateFunctionSum.cpp index 2a7f1cfcc90..45a97d2bc83 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSum.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionSum.cpp @@ -50,7 +50,7 @@ AggregateFunctionPtr createAggregateFunctionSum(const std::string & name, const AggregateFunctionPtr res; DataTypePtr data_type = argument_types[0]; if (isDecimal(data_type)) - res.reset(createWithDecimalType(*data_type)); + res.reset(createWithDecimalType(*data_type, *data_type)); else res.reset(createWithNumericType(*data_type)); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.cpp b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.cpp index 7cad7c35092..571d6f5c0a1 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.cpp @@ -37,7 +37,9 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con values_types.push_back(array_type->getNestedType()); } - AggregateFunctionPtr res(createWithNumericType(*keys_type, keys_type, std::move(values_types))); + AggregateFunctionPtr res(createWithNumericBasedType(*keys_type, keys_type, values_types)); + if (!res) + res.reset(createWithDecimalType(*keys_type, keys_type, values_types)); if (!res) throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h index dc27660278a..e89c0bf8411 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h @@ -8,6 +8,8 @@ #include #include +#include +#include #include #include @@ -53,6 +55,8 @@ class AggregateFunctionSumMap final : public IAggregateFunctionDataHelper< AggregateFunctionSumMapData::Type>, AggregateFunctionSumMap> { private: + using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>; + DataTypePtr keys_type; DataTypes values_types; @@ -78,7 +82,7 @@ public: // Column 0 contains array of keys of known type const ColumnArray & array_column = static_cast(*columns[0]); const IColumn::Offsets & offsets = array_column.getOffsets(); - const auto & keys_vec = static_cast &>(array_column.getData()); + const auto & keys_vec = static_cast(array_column.getData()); const size_t keys_vec_offset = row_num == 0 ? 0 : offsets[row_num - 1]; const size_t keys_vec_size = (offsets[row_num] - keys_vec_offset); @@ -99,9 +103,20 @@ public: // Insert column values for all keys for (size_t i = 0; i < keys_vec_size; ++i) { + using MapType = std::decay_t; + using IteratorType = typename MapType::iterator; + array_column.getData().get(values_vec_offset + i, value); const auto & key = keys_vec.getData()[keys_vec_offset + i]; - const auto & it = merged_maps.find(key); + + IteratorType it; + if constexpr (IsDecimalNumber) + { + UInt32 scale = keys_vec.getData().getScale(); + it = merged_maps.find(DecimalField(key, scale)); + } + else + it = merged_maps.find(key); if (it != merged_maps.end()) applyVisitor(FieldVisitorSum(value), it->second[col]); @@ -113,7 +128,13 @@ public: for (size_t k = 0; k < new_values.size(); ++k) new_values[k] = (k == col) ? value : values_types[k]->getDefault(); - merged_maps[key] = std::move(new_values); + if constexpr (IsDecimalNumber) + { + UInt32 scale = keys_vec.getData().getScale(); + merged_maps.emplace(DecimalField(key, scale), std::move(new_values)); + } + else + merged_maps.emplace(key, std::move(new_values)); } } } @@ -167,7 +188,10 @@ public: for (size_t col = 0; col < values_types.size(); ++col) values_types[col]->deserializeBinary(values[col], buf); - merged_maps[key.get()] = values; + if constexpr (IsDecimalNumber) + merged_maps[key.get>()] = values; + else + merged_maps[key.get()] = values; } } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionsStatisticsSimple.cpp b/dbms/src/AggregateFunctions/AggregateFunctionsStatisticsSimple.cpp index 35bee73c9d8..4159403afc7 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionsStatisticsSimple.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionsStatisticsSimple.cpp @@ -24,9 +24,7 @@ AggregateFunctionPtr createAggregateFunctionStatisticsUnary(const std::string & AggregateFunctionPtr res; DataTypePtr data_type = argument_types[0]; if (isDecimal(data_type)) - { - res.reset(createWithDecimalType(*data_type)); - } + res.reset(createWithDecimalType(*data_type, *data_type)); else res.reset(createWithNumericType(*data_type)); diff --git a/dbms/src/AggregateFunctions/Helpers.h b/dbms/src/AggregateFunctions/Helpers.h index d8ef3240d64..e98da75cd53 100644 --- a/dbms/src/AggregateFunctions/Helpers.h +++ b/dbms/src/AggregateFunctions/Helpers.h @@ -70,13 +70,28 @@ static IAggregateFunction * createWithUnsignedIntegerType(const IDataType & argu return nullptr; } +template