#include #include #include #include #include #include #include #include #include #include #include #include namespace DB { AggregateFunctionFactory::AggregateFunctionFactory() { } AggregateFunctionPtr AggregateFunctionFactory::get(const String & name, const DataTypes & argument_types) const { if (name == "count") return new AggregateFunctionCount; else if (name == "any") return new AggregateFunctionAny; else if (name == "anyLast") return new AggregateFunctionAnyLast; else if (name == "min") return new AggregateFunctionMin; else if (name == "max") return new AggregateFunctionMax; else if (name == "groupArray") return new AggregateFunctionGroupArray; else if (name == "sum") { if (argument_types.size() != 1) throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); String argument_type_name = argument_types[0]->getName(); if (argument_type_name == "UInt8" || argument_type_name == "UInt16" || argument_type_name == "UInt32" || argument_type_name == "UInt64" || argument_type_name == "VarUInt") return new AggregateFunctionSum; else if (argument_type_name == "Int8" || argument_type_name == "Int16" || argument_type_name == "Int32" || argument_type_name == "Int64" || argument_type_name == "VarInt") return new AggregateFunctionSum; else if (argument_type_name == "Float32" || argument_type_name == "Float64") return new AggregateFunctionSum; else throw Exception("Illegal type " + argument_type_name + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } else if (name == "avg") { if (argument_types.size() != 1) throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); String argument_type_name = argument_types[0]->getName(); if (argument_type_name == "UInt8" || argument_type_name == "UInt16" || argument_type_name == "UInt32" || argument_type_name == "UInt64" || argument_type_name == "VarUInt") return new AggregateFunctionAvg; else if (argument_type_name == "Int8" || argument_type_name == "Int16" || argument_type_name == "Int32" || argument_type_name == "Int64" || argument_type_name == "VarInt") return new AggregateFunctionAvg; else if (argument_type_name == "Float32" || argument_type_name == "Float64") return new AggregateFunctionAvg; else throw Exception("Illegal type " + argument_type_name + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } else if (name == "uniq") { if (argument_types.size() != 1) throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); String argument_type_name = argument_types[0]->getName(); if (argument_type_name == "UInt8" || argument_type_name == "UInt16" || argument_type_name == "UInt32" || argument_type_name == "UInt64" || argument_type_name == "VarUInt" || argument_type_name == "Date" || argument_type_name == "DateTime") return new AggregateFunctionUniq; else if (argument_type_name == "Int8" || argument_type_name == "Int16" || argument_type_name == "Int32" || argument_type_name == "Int64" || argument_type_name == "VarInt") return new AggregateFunctionUniq; else if (argument_type_name == "Float32" || argument_type_name == "Float64") return new AggregateFunctionUniq; else if (argument_type_name == "String" || 0 == argument_type_name.compare(0, strlen("FixedString"), "FixedString")) return new AggregateFunctionUniq; else throw Exception("Illegal type " + argument_type_name + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } else if (name == "median") { if (argument_types.size() != 1) throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); String argument_type_name = argument_types[0]->getName(); if (argument_type_name == "UInt8" || argument_type_name == "UInt16" || argument_type_name == "UInt32" || argument_type_name == "UInt64" || argument_type_name == "VarUInt") return new AggregateFunctionMedian(new DataTypeUInt64); else if (argument_type_name == "Int8" || argument_type_name == "Int16" || argument_type_name == "Int32" || argument_type_name == "Int64" || argument_type_name == "VarInt") return new AggregateFunctionMedian(new DataTypeInt64); else if (argument_type_name == "Float32" || argument_type_name == "Float64") return new AggregateFunctionMedian(new DataTypeFloat64); else if (argument_type_name == "Date") return new AggregateFunctionMedian(new DataTypeDate); else if (argument_type_name == "DateTime") return new AggregateFunctionMedian(new DataTypeDateTime); else throw Exception("Illegal type " + argument_type_name + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } else throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION); } AggregateFunctionPtr AggregateFunctionFactory::getByTypeID(const String & type_id) const { if (type_id == "count") return new AggregateFunctionCount; else if (type_id == "any") return new AggregateFunctionAny; else if (type_id == "anyLast") return new AggregateFunctionAnyLast; else if (type_id == "min") return new AggregateFunctionMin; else if (type_id == "max") return new AggregateFunctionMax; else if (type_id == "groupArray") return new AggregateFunctionGroupArray; else if (0 == type_id.compare(0, strlen("sum_"), "sum_")) { if (0 == type_id.compare(strlen("sum_"), strlen("UInt64"), "UInt64")) return new AggregateFunctionSum; else if (0 == type_id.compare(strlen("sum_"), strlen("Int64"), "Int64")) return new AggregateFunctionSum; else if (0 == type_id.compare(strlen("sum_"), strlen("Float64"), "Float64")) return new AggregateFunctionSum; else throw Exception("Unknown type id of aggregate function " + type_id, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION); } else if (0 == type_id.compare(0, strlen("avg_"), "avg_")) { if (0 == type_id.compare(strlen("avg_"), strlen("UInt64"), "UInt64")) return new AggregateFunctionAvg; else if (0 == type_id.compare(strlen("avg_"), strlen("Int64"), "Int64")) return new AggregateFunctionAvg; else if (0 == type_id.compare(strlen("avg_"), strlen("Float64"), "Float64")) return new AggregateFunctionAvg; else throw Exception("Unknown type id of aggregate function " + type_id, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION); } else if (0 == type_id.compare(0, strlen("uniq_"), "uniq_")) { if (0 == type_id.compare(strlen("uniq_"), strlen("UInt64"), "UInt64")) return new AggregateFunctionUniq; else if (0 == type_id.compare(strlen("uniq_"), strlen("Int64"), "Int64")) return new AggregateFunctionUniq; else if (0 == type_id.compare(strlen("uniq_"), strlen("Float64"), "Float64")) return new AggregateFunctionUniq; else if (0 == type_id.compare(strlen("uniq_"), strlen("String"), "String")) return new AggregateFunctionUniq; else throw Exception("Unknown type id of aggregate function " + type_id, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION); } else if (0 == type_id.compare(0, strlen("median_"), "median_")) { if (0 == type_id.compare(strlen("median_"), strlen("UInt64"), "UInt64")) return new AggregateFunctionMedian(new DataTypeUInt64); else if (0 == type_id.compare(strlen("median_"), strlen("Int64"), "Int64")) return new AggregateFunctionMedian(new DataTypeInt64); else if (0 == type_id.compare(strlen("median_"), strlen("Float64"), "Float64")) return new AggregateFunctionMedian(new DataTypeFloat64); else if (0 == type_id.compare(strlen("median_"), strlen("Date"), "Date")) return new AggregateFunctionMedian(new DataTypeDate); else if (0 == type_id.compare(strlen("median_"), strlen("DateTime"), "DateTime")) return new AggregateFunctionMedian(new DataTypeDateTime); else throw Exception("Unknown type id of aggregate function " + type_id, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION); } else throw Exception("Unknown type id of aggregate function " + type_id, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION); } AggregateFunctionPtr AggregateFunctionFactory::tryGet(const String & name, const DataTypes & argument_types) const { std::set names; boost::assign::insert(names) ("count") ("any") ("anyLast") ("min") ("max") ("sum") ("avg") ("uniq") ("groupArray") ("median"); return names.end() != names.find(name) ? get(name, argument_types) : NULL; } }