fixed the build, added some comments

This commit is contained in:
myrrc 2020-11-03 17:56:07 +03:00
parent fcbc0fb91e
commit ab1b7267b6
6 changed files with 91 additions and 12 deletions

View File

@ -1,3 +1,4 @@
# Needed when using Apache Avro serialization format
option (ENABLE_AVRO "Enable Avro" ${ENABLE_LIBRARIES}) option (ENABLE_AVRO "Enable Avro" ${ENABLE_LIBRARIES})
if (NOT ENABLE_AVRO) if (NOT ENABLE_AVRO)

View File

@ -1,3 +1,5 @@
# Needed when securely connecting to an external server, e.g.
# clickhouse-client --host ... --secure
option(ENABLE_SSL "Enable ssl" ${ENABLE_LIBRARIES}) option(ENABLE_SSL "Enable ssl" ${ENABLE_LIBRARIES})
if(NOT ENABLE_SSL) if(NOT ENABLE_SSL)

View File

@ -25,11 +25,20 @@ AggregateFunctionPtr createAggregateFunctionAvg(const std::string & name, const
assertNoParameters(name, parameters); assertNoParameters(name, parameters);
assertUnary(name, argument_types); assertUnary(name, argument_types);
if (!allowType(argument_types[0])) const DataTypePtr& data_type = argument_types[0];
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,
if (!allowType(data_type))
throw Exception("Illegal type " + data_type->getName() + " of argument for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<AggregateFunctionAvg>(argument_types); AggregateFunctionPtr res;
if (isDecimal(data_type))
res.reset(createWithDecimalType<AggregateFunctionAvg>(*data_type, argument_types));
else
res.reset(createWithNumericType<AggregateFunctionAvg>(*data_type, argument_types));
return res;
} }
} }

View File

@ -21,6 +21,13 @@ struct RationalFraction
Float64 NO_SANITIZE_UNDEFINED result() const { return numerator / denominator; } Float64 NO_SANITIZE_UNDEFINED result() const { return numerator / denominator; }
}; };
template <class T> constexpr bool DecimalOrExtendedInt =
IsDecimalNumber<T>
|| std::is_same_v<T, Int128>
|| std::is_same_v<T, Int256>
|| std::is_same_v<T, UInt128>
|| std::is_same_v<T, UInt256>;
/** /**
* The discussion showed that the easiest (and simplest) way is to cast both the columns of numerator and denominator * The discussion showed that the easiest (and simplest) way is to cast both the columns of numerator and denominator
* to Float64. Another way would be to write some template magic that figures out the appropriate numerator * to Float64. Another way would be to write some template magic that figures out the appropriate numerator
@ -78,14 +85,22 @@ public:
} }
}; };
class AggregateFunctionAvg final : public AggregateFunctionAvgBase<UInt64, AggregateFunctionAvg> template <class T>
class AggregateFunctionAvg final : public AggregateFunctionAvgBase<UInt64, AggregateFunctionAvg<T>>
{ {
public: public:
using AggregateFunctionAvgBase<UInt64, AggregateFunctionAvg>::AggregateFunctionAvgBase; using AggregateFunctionAvgBase<UInt64, AggregateFunctionAvg<T>>::AggregateFunctionAvgBase;
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const final void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const final
{ {
this->data(place).numerator += columns[0]->getFloat64(row_num); if constexpr(IsDecimalNumber<T>)
this->data(place).numerator += columns[0]->getFloat64(row_num);
else if constexpr(DecimalOrExtendedInt<T>)
this->data(place).numerator += static_cast<Float64>(
static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
else
this->data(place).numerator += static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];
++this->data(place).denominator; ++this->data(place).denominator;
} }

View File

@ -27,6 +27,39 @@ bool allowTypes(const DataTypePtr& left, const DataTypePtr& right) noexcept
return allow(l_dt) && allow(r_dt); return allow(l_dt) && allow(r_dt);
} }
#define AT_SWITCH(LINE) \
switch (which.idx) \
{ \
LINE(Int8); LINE(Int16); LINE(Int32); LINE(Int64); LINE(Int128); LINE(Int256); \
LINE(UInt8); LINE(UInt16); LINE(UInt32); LINE(UInt64); LINE(UInt128); LINE(UInt256); \
LINE(Decimal32); LINE(Decimal64); LINE(Decimal128); LINE(Decimal256); \
LINE(Float32); LINE(Float64); \
default: return nullptr; \
}
template <class First, class ... TArgs>
static IAggregateFunction * create(const IDataType & second_type, TArgs && ... args)
{
const WhichDataType which(second_type);
#define LINE(Type) \
case TypeIndex::Type: return new AggregateFunctionAvgWeighted<First, Type>(std::forward<TArgs>(args)...)
AT_SWITCH(LINE)
#undef LINE
}
// Not using helper functions because there are no templates for binary decimal/numeric function.
template <class... TArgs>
static IAggregateFunction * create(const IDataType & first_type, const IDataType & second_type, TArgs && ... args)
{
const WhichDataType which(first_type);
#define LINE(Type) \
case TypeIndex::Type: return create<Type, TArgs...>(second_type, std::forward<TArgs>(args)...)
AT_SWITCH(LINE)
#undef LINE
}
AggregateFunctionPtr createAggregateFunctionAvgWeighted(const std::string & name, const DataTypes & argument_types, const Array & parameters) AggregateFunctionPtr createAggregateFunctionAvgWeighted(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{ {
assertNoParameters(name, parameters); assertNoParameters(name, parameters);
@ -42,7 +75,9 @@ AggregateFunctionPtr createAggregateFunctionAvgWeighted(const std::string & name
" are non-conforming as arguments for aggregate function " + name, " are non-conforming as arguments for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<AggregateFunctionAvgWeighted>(argument_types); AggregateFunctionPtr ptr;
ptr.reset(create(*data_type, *data_type_weight, argument_types));
return ptr;
} }
} }

View File

@ -5,17 +5,34 @@
namespace DB namespace DB
{ {
class AggregateFunctionAvgWeighted final : public AggregateFunctionAvgBase<Float64, AggregateFunctionAvgWeighted> template <class Value, class Weight>
class AggregateFunctionAvgWeighted final :
public AggregateFunctionAvgBase<Float64, AggregateFunctionAvgWeighted<Value, Weight>>
{ {
public: public:
using AggregateFunctionAvgBase<Float64, AggregateFunctionAvgWeighted>::AggregateFunctionAvgBase; using Base = AggregateFunctionAvgBase<Float64, AggregateFunctionAvgWeighted<Value, Weight>>;
using Base::Base;
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{ {
const auto value = columns[0]->getFloat64(row_num); const Float64 value = [&columns, row_num] {
const auto weight = columns[1]->getFloat64(row_num); if constexpr(IsDecimalNumber<Value>)
return columns[0]->getFloat64(row_num);
else
return static_cast<Float64>(static_cast<const ColumnVector<Value>&>(*columns[0]).getData()[row_num]);
}();
this->data(place).numerator += value * weight; using WeightRet = std::conditional_t<DecimalOrExtendedInt<Weight>, Float64, Weight>;
const WeightRet weight = [&columns, row_num]() -> WeightRet {
if constexpr(IsDecimalNumber<Weight>)
return columns[1]->getFloat64(row_num);
else if constexpr(DecimalOrExtendedInt<Weight>)
return static_cast<Float64>(static_cast<const ColumnVector<Weight>&>(*columns[1]).getData()[row_num]);
else
return static_cast<const ColumnVector<Weight>&>(*columns[1]).getData()[row_num];
}();
this->data(place).numerator += weight * value;
this->data(place).denominator += weight; this->data(place).denominator += weight;
} }