diff --git a/src/AggregateFunctions/AggregateFunctionsArgMinArgMax.cpp b/src/AggregateFunctions/AggregateFunctionsArgMinArgMax.cpp index cc72a26af16..d44a5d13b98 100644 --- a/src/AggregateFunctions/AggregateFunctionsArgMinArgMax.cpp +++ b/src/AggregateFunctions/AggregateFunctionsArgMinArgMax.cpp @@ -13,18 +13,36 @@ struct Settings; namespace ErrorCodes { - -extern const int INCORRECT_DATA; -extern const int ILLEGAL_TYPE_OF_ARGUMENT; -extern const int LOGICAL_ERROR; + extern const int INCORRECT_DATA; + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int LOGICAL_ERROR; } namespace { -template +template struct AggregateFunctionArgMinMaxData { +private: + ResultType result_data; + ValueType value_data; + +public: + ResultType & result() { return result_data; } + const ResultType & result() const { return result_data; } + ValueType & value() { return value_data; } + const ValueType & value() const { return value_data; } + + static bool allocatesMemoryInArena(TypeIndex) + { + return ResultType::allocatesMemoryInArena() || ValueType::allocatesMemoryInArena(); + } +}; + +template +struct AggregateFunctionArgMinMaxDataGeneric +{ private: SingleValueDataBaseMemoryBlock result_data; ValueType value_data; @@ -35,27 +53,22 @@ public: ValueType & value() { return value_data; } const ValueType & value() const { return value_data; } - [[noreturn]] explicit AggregateFunctionArgMinMaxData() + static bool allocatesMemoryInArena(TypeIndex result_type_index) { - throw Exception(ErrorCodes::LOGICAL_ERROR, "AggregateFunctionArgMinMaxData initialized empty"); + return singleValueTypeAllocatesMemoryInArena(result_type_index) || ValueType::allocatesMemoryInArena(); } - explicit AggregateFunctionArgMinMaxData(TypeIndex result_type) : value_data() - { - generateSingleValueFromTypeIndex(result_type, result_data); - } - - ~AggregateFunctionArgMinMaxData() { result().~SingleValueDataBase(); } + ~AggregateFunctionArgMinMaxDataGeneric() { result().~SingleValueDataBase(); } }; static_assert( - sizeof(AggregateFunctionArgMinMaxData) <= 2 * SingleValueDataBase::MAX_STORAGE_SIZE, + sizeof(AggregateFunctionArgMinMaxDataGeneric) <= 2 * SingleValueDataBase::MAX_STORAGE_SIZE, "Incorrect size of AggregateFunctionArgMinMaxData struct"); /// Returns the first arg value found for the minimum/maximum value. Example: argMin(arg, value). -template +template class AggregateFunctionArgMinMax final - : public IAggregateFunctionDataHelper, AggregateFunctionArgMinMax> + : public IAggregateFunctionDataHelper> { private: const DataTypePtr & type_val; @@ -63,7 +76,8 @@ private: const SerializationPtr serialization_val; const TypeIndex result_type_index; - using Base = IAggregateFunctionDataHelper, AggregateFunctionArgMinMax>; + + using Base = IAggregateFunctionDataHelper>; public: explicit AggregateFunctionArgMinMax(const DataTypes & argument_types_) @@ -91,7 +105,7 @@ public: void create(AggregateDataPtr __restrict place) const override /// NOLINT { - new (place) AggregateFunctionArgMinMaxData(result_type_index); + new (place) Data(); } String getName() const override @@ -215,7 +229,7 @@ public: bool allocatesMemoryInArena() const override { - return singleValueTypeAllocatesMemoryInArena(result_type_index) || ValueData::allocatesMemoryInArena(); + return Data::allocatesMemoryInArena(result_type_index); } void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override @@ -224,12 +238,77 @@ public: } }; -template -AggregateFunctionPtr createAggregateFunctionArgMinMax( - const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings) +using AllTypes = std::tuple; + + +template +IAggregateFunction * createForPair(const TypeIndex & result_type, const TypeIndex & value_type, const DataTypes & argument_types) { - return AggregateFunctionPtr(createAggregateFunctionSingleValue( - name, argument_types, parameters, settings)); + if (TypeToTypeIndex == result_type && value_type == TypeToTypeIndex) + { + using Data = AggregateFunctionArgMinMaxData, SingleValueDataFixed>; + return new AggregateFunctionArgMinMax(argument_types); + } + return nullptr; +} + +template +IAggregateFunction * tryValueTypes(const DataTypes & argument_types, const TypeIndex & result_type, const TypeIndex & value_type, std::tuple) +{ + IAggregateFunction * result = nullptr; + ((result = result ? result : createForPair(result_type, value_type, argument_types)), ...); // Fold expression + return result; +} + +template +IAggregateFunction * tryResultTypes(const DataTypes & argument_types, const TypeIndex result_idx, const TypeIndex value_idx, std::tuple, std::tuple value_tuple) +{ + IAggregateFunction * result = nullptr; + ((result = result ? result : tryValueTypes(argument_types, result_idx, value_idx, value_tuple)), ...); // Fold expression + return result; +} + +template +AggregateFunctionPtr createAggregateFunctionArgMinMax(const std::string &, const DataTypes & argument_types, const Array &, const Settings *) { + using AllTypesTuple = AllTypes; + + const DataTypePtr & result_type = argument_types[0]; + const DataTypePtr & value_type = argument_types[1]; + + WhichDataType which_result(result_type); + WhichDataType which_value(value_type); + + auto convert_date_type = [] (TypeIndex type_index) + { + if (type_index == TypeIndex::Date) + return TypeToTypeIndex; + else if (type_index == TypeIndex::DateTime) + return TypeToTypeIndex; + else + return type_index; + }; + + AggregateFunctionPtr result = AggregateFunctionPtr(tryResultTypes(argument_types, convert_date_type(which_result.idx), convert_date_type(which_value.idx), AllTypesTuple{}, AllTypesTuple{})); + if (!result) + { + WhichDataType which(value_type); +#define DISPATCH(TYPE) \ + if (which_value.idx == TypeIndex::TYPE) \ + return AggregateFunctionPtr(new AggregateFunctionArgMinMax>, isMin>(argument_types)); /// NOLINT + FOR_SINGLE_VALUE_NUMERIC_TYPES(DISPATCH) +#undef DISPATCH + + if (which.idx == TypeIndex::Date) + return AggregateFunctionPtr(new AggregateFunctionArgMinMax>, isMin>(argument_types)); + if (which.idx == TypeIndex::DateTime) + return AggregateFunctionPtr(new AggregateFunctionArgMinMax>, isMin>(argument_types)); + if (which.idx == TypeIndex::String) + return AggregateFunctionPtr(new AggregateFunctionArgMinMax, isMin>(argument_types)); + + return AggregateFunctionPtr(new AggregateFunctionArgMinMax, isMin>(argument_types)); + } + return result; } }