#include #include #include #include #include #include #define TOP_K_MAX_SIZE 0xFFFFFF namespace DB { namespace ErrorCodes { extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int ARGUMENT_OUT_OF_BOUND; } namespace { /// Substitute return type for Date and DateTime class AggregateFunctionTopKDate : public AggregateFunctionTopK { using AggregateFunctionTopK::AggregateFunctionTopK; DataTypePtr getReturnType() const override { return std::make_shared(std::make_shared()); } }; class AggregateFunctionTopKDateTime : public AggregateFunctionTopK { using AggregateFunctionTopK::AggregateFunctionTopK; DataTypePtr getReturnType() const override { return std::make_shared(std::make_shared()); } }; static IAggregateFunction * createWithExtraTypes(const DataTypePtr & argument_type, UInt64 threshold) { if (typeid_cast(argument_type.get())) return new AggregateFunctionTopKDate(threshold); if (typeid_cast(argument_type.get())) return new AggregateFunctionTopKDateTime(threshold); /// Check that we can use plain version of AggregateFunctionTopKGeneric if (argument_type->isValueUnambiguouslyRepresentedInContiguousMemoryRegion()) return new AggregateFunctionTopKGeneric(threshold, argument_type); else return new AggregateFunctionTopKGeneric(threshold, argument_type); } AggregateFunctionPtr createAggregateFunctionTopK(const std::string & name, const DataTypes & argument_types, const Array & params) { assertUnary(name, argument_types); UInt64 threshold = 10; /// default value if (!params.empty()) { if (params.size() != 1) throw Exception("Aggregate function " + name + " requires one parameter or less.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); UInt64 k = applyVisitor(FieldVisitorConvertToNumber(), params[0]); if (k > TOP_K_MAX_SIZE) throw Exception("Too large parameter for aggregate function " + name + ". Maximum: " + toString(TOP_K_MAX_SIZE), ErrorCodes::ARGUMENT_OUT_OF_BOUND); if (k == 0) throw Exception("Parameter 0 is illegal for aggregate function " + name, ErrorCodes::ARGUMENT_OUT_OF_BOUND); threshold = k; } AggregateFunctionPtr res(createWithNumericType(*argument_types[0], threshold)); if (!res) res = AggregateFunctionPtr(createWithExtraTypes(argument_types[0], threshold)); if (!res) throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); return res; } } void registerAggregateFunctionTopK(AggregateFunctionFactory & factory) { factory.registerFunction("topK", createAggregateFunctionTopK); } }