wip dealing with template magic

This commit is contained in:
myrrc 2020-10-15 13:36:00 +03:00
parent 53471315ff
commit 8e7e232387
6 changed files with 121 additions and 91 deletions

View File

@ -42,3 +42,20 @@ Result:
│ 8 │
└────────────────────────┘
```
**Example**
Query:
``` sql
SELECT avgWeighted(x, w)
FROM values('x Int8, w Float64', (4, 1), (1, 0), (10, 2))
```
Result:
``` text
┌─avgWeighted(x, weight)─┐
│ 8 │
└────────────────────────┘
```

View File

@ -13,26 +13,37 @@ namespace ErrorCodes
namespace
{
template <typename T>
struct AvgWeighted
constexpr bool allowTypes(const DataTypePtr& left, const DataTypePtr& right)
{
using FieldType = std::conditional_t<
IsDecimalNumber<T>,
std::conditional_t<std::is_same_v<T, Decimal256>,
Decimal256,
Decimal128>,
NearestFieldType<T>>;
const WhichDataType l_dt(left), r_dt(right);
using Function = AggregateFunctionAvgWeighted<T, AggregateFunctionAvgData<FieldType, FieldType>>;
constexpr auto allow = [](WhichDataType t)
{
return t.isInt() || t.isUInt() || t.isFloat() || t.isDecimal();
};
return allow(l_dt) && allow(r_dt);
}
template <class U> struct BiggerType
template <class U, class V> struct LargestType
{
using Type = bool;
};
template <typename T>
using AggregateFuncAvgWeighted = typename AvgWeighted<T>::Function;
bool allowTypes(const DataTypePtr& left, const DataTypePtr& right)
template <class U, class V> using AvgData = AggregateFunctionAvgData<
typename LargestType<U, V>::Type,
typename LargestType<U, V>::Type>;
template <class U, class V> using Function = AggregateFunctionAvgWeighted<
U, V, typename LargestType<U, V>::Type, AvgData<U, V>>;
template <typename... TArgs>
static IAggregateFunction * create(const IDataType & first_type, const IDataType & second_type, TArgs && ... args)
{
return (isInteger(left) || isFloat(left)) && (isInteger(right) || isFloat(right));
}
AggregateFunctionPtr createAggregateFunctionAvgWeighted(const std::string & name, const DataTypes & argument_types, const Array & parameters)
@ -40,8 +51,6 @@ AggregateFunctionPtr createAggregateFunctionAvgWeighted(const std::string & name
assertNoParameters(name, parameters);
assertBinary(name, argument_types);
AggregateFunctionPtr res;
const auto data_type = static_cast<const DataTypePtr>(argument_types[0]);
const auto data_type_weight = static_cast<const DataTypePtr>(argument_types[1]);
@ -52,10 +61,8 @@ AggregateFunctionPtr createAggregateFunctionAvgWeighted(const std::string & name
" are non-conforming as arguments for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (isDecimal(data_type))
res.reset(createWithDecimalType<AggregateFuncAvgWeighted>(*data_type, *data_type, argument_types));
else
res.reset(createWithNumericType<AggregateFuncAvgWeighted>(*data_type, argument_types));
AggregateFunctionPtr res;
res.reset(create(*data_type, *data_type_weight, argument_types));
if (!res)
throw Exception("Illegal type " + data_type->getName() + " of argument for aggregate function " + name,
@ -70,5 +77,4 @@ void registerAggregateFunctionAvgWeighted(AggregateFunctionFactory & factory)
{
factory.registerFunction("avgWeighted", createAggregateFunctionAvgWeighted, AggregateFunctionFactory::CaseSensitive);
}
}

View File

@ -4,21 +4,28 @@
namespace DB
{
template <typename T, typename Data>
class AggregateFunctionAvgWeighted final : public AggregateFunctionAvgBase<T, Data, AggregateFunctionAvgWeighted<T, Data>>
template <typename Value, typename Weight, typename Largest, typename Data>
class AggregateFunctionAvgWeighted final :
public AggregateFunctionAvgBase<Largest, Data, AggregateFunctionAvgWeighted<Value, Weight, Largest, Data>>
{
public:
using AggregateFunctionAvgBase<T, Data, AggregateFunctionAvgWeighted<T, Data>>::AggregateFunctionAvgBase;
using AggregateFunctionAvgBase<Largest, Data,
AggregateFunctionAvgWeighted<Value, Weight, Largest, Data>>::AggregateFunctionAvgBase;
template <class T>
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
const auto & values = static_cast<const ColVecType &>(*columns[0]);
const auto & weights = static_cast<const ColVecType &>(*columns[1]);
const auto & values = static_cast<const ColVecType<Value> &>(*columns[0]);
const auto & weights = static_cast<const ColVecType<Weight> &>(*columns[1]);
this->data(place).numerator += static_cast<typename Data::NumeratorType>(values.getData()[row_num]) * weights.getData()[row_num];
this->data(place).denominator += weights.getData()[row_num];
const auto value = values.getData()[row_num];
const auto weight = weights.getData()[row_num];
this->data(place).numerator += static_cast<typename Data::NumeratorType>(value) * weight;
this->data(place).denominator += weight;
}
String getName() const override { return "avgWeighted"; }

View File

@ -466,75 +466,64 @@ struct WhichDataType
{
TypeIndex idx;
WhichDataType(TypeIndex idx_ = TypeIndex::Nothing)
: idx(idx_)
{}
constexpr WhichDataType(TypeIndex idx_ = TypeIndex::Nothing) : idx(idx_) {}
constexpr WhichDataType(const IDataType & data_type) : idx(data_type.getTypeId()) {}
constexpr WhichDataType(const IDataType * data_type) : idx(data_type->getTypeId()) {}
constexpr WhichDataType(const DataTypePtr & data_type) : idx(data_type->getTypeId()) {}
WhichDataType(const IDataType & data_type)
: idx(data_type.getTypeId())
{}
constexpr bool isUInt8() const { return idx == TypeIndex::UInt8; }
constexpr bool isUInt16() const { return idx == TypeIndex::UInt16; }
constexpr bool isUInt32() const { return idx == TypeIndex::UInt32; }
constexpr bool isUInt64() const { return idx == TypeIndex::UInt64; }
constexpr bool isUInt128() const { return idx == TypeIndex::UInt128; }
constexpr bool isUInt256() const { return idx == TypeIndex::UInt256; }
constexpr bool isUInt() const { return isUInt8() || isUInt16() || isUInt32() || isUInt64() || isUInt128() || isUInt256(); }
constexpr bool isNativeUInt() const { return isUInt8() || isUInt16() || isUInt32() || isUInt64(); }
WhichDataType(const IDataType * data_type)
: idx(data_type->getTypeId())
{}
constexpr bool isInt8() const { return idx == TypeIndex::Int8; }
constexpr bool isInt16() const { return idx == TypeIndex::Int16; }
constexpr bool isInt32() const { return idx == TypeIndex::Int32; }
constexpr bool isInt64() const { return idx == TypeIndex::Int64; }
constexpr bool isInt128() const { return idx == TypeIndex::Int128; }
constexpr bool isInt256() const { return idx == TypeIndex::Int256; }
constexpr bool isInt() const { return isInt8() || isInt16() || isInt32() || isInt64() || isInt128() || isInt256(); }
constexpr bool isNativeInt() const { return isInt8() || isInt16() || isInt32() || isInt64(); }
WhichDataType(const DataTypePtr & data_type)
: idx(data_type->getTypeId())
{}
constexpr bool isDecimal32() const { return idx == TypeIndex::Decimal32; }
constexpr bool isDecimal64() const { return idx == TypeIndex::Decimal64; }
constexpr bool isDecimal128() const { return idx == TypeIndex::Decimal128; }
constexpr bool isDecimal256() const { return idx == TypeIndex::Decimal256; }
constexpr bool isDecimal() const { return isDecimal32() || isDecimal64() || isDecimal128() || isDecimal256(); }
bool isUInt8() const { return idx == TypeIndex::UInt8; }
bool isUInt16() const { return idx == TypeIndex::UInt16; }
bool isUInt32() const { return idx == TypeIndex::UInt32; }
bool isUInt64() const { return idx == TypeIndex::UInt64; }
bool isUInt128() const { return idx == TypeIndex::UInt128; }
bool isUInt256() const { return idx == TypeIndex::UInt256; }
bool isUInt() const { return isUInt8() || isUInt16() || isUInt32() || isUInt64() || isUInt128() || isUInt256(); }
bool isNativeUInt() const { return isUInt8() || isUInt16() || isUInt32() || isUInt64(); }
constexpr bool isFloat32() const { return idx == TypeIndex::Float32; }
constexpr bool isFloat64() const { return idx == TypeIndex::Float64; }
constexpr bool isFloat() const { return isFloat32() || isFloat64(); }
bool isInt8() const { return idx == TypeIndex::Int8; }
bool isInt16() const { return idx == TypeIndex::Int16; }
bool isInt32() const { return idx == TypeIndex::Int32; }
bool isInt64() const { return idx == TypeIndex::Int64; }
bool isInt128() const { return idx == TypeIndex::Int128; }
bool isInt256() const { return idx == TypeIndex::Int256; }
bool isInt() const { return isInt8() || isInt16() || isInt32() || isInt64() || isInt128() || isInt256(); }
bool isNativeInt() const { return isInt8() || isInt16() || isInt32() || isInt64(); }
constexpr bool isEnum8() const { return idx == TypeIndex::Enum8; }
constexpr bool isEnum16() const { return idx == TypeIndex::Enum16; }
constexpr bool isEnum() const { return isEnum8() || isEnum16(); }
bool isDecimal32() const { return idx == TypeIndex::Decimal32; }
bool isDecimal64() const { return idx == TypeIndex::Decimal64; }
bool isDecimal128() const { return idx == TypeIndex::Decimal128; }
bool isDecimal256() const { return idx == TypeIndex::Decimal256; }
bool isDecimal() const { return isDecimal32() || isDecimal64() || isDecimal128() || isDecimal256(); }
constexpr bool isDate() const { return idx == TypeIndex::Date; }
constexpr bool isDateTime() const { return idx == TypeIndex::DateTime; }
constexpr bool isDateTime64() const { return idx == TypeIndex::DateTime64; }
constexpr bool isDateOrDateTime() const { return isDate() || isDateTime() || isDateTime64(); }
bool isFloat32() const { return idx == TypeIndex::Float32; }
bool isFloat64() const { return idx == TypeIndex::Float64; }
bool isFloat() const { return isFloat32() || isFloat64(); }
constexpr bool isString() const { return idx == TypeIndex::String; }
constexpr bool isFixedString() const { return idx == TypeIndex::FixedString; }
constexpr bool isStringOrFixedString() const { return isString() || isFixedString(); }
bool isEnum8() const { return idx == TypeIndex::Enum8; }
bool isEnum16() const { return idx == TypeIndex::Enum16; }
bool isEnum() const { return isEnum8() || isEnum16(); }
constexpr bool isUUID() const { return idx == TypeIndex::UUID; }
constexpr bool isArray() const { return idx == TypeIndex::Array; }
constexpr bool isTuple() const { return idx == TypeIndex::Tuple; }
constexpr bool isSet() const { return idx == TypeIndex::Set; }
constexpr bool isInterval() const { return idx == TypeIndex::Interval; }
bool isDate() const { return idx == TypeIndex::Date; }
bool isDateTime() const { return idx == TypeIndex::DateTime; }
bool isDateTime64() const { return idx == TypeIndex::DateTime64; }
bool isDateOrDateTime() const { return isDate() || isDateTime() || isDateTime64(); }
constexpr bool isNothing() const { return idx == TypeIndex::Nothing; }
constexpr bool isNullable() const { return idx == TypeIndex::Nullable; }
constexpr bool isFunction() const { return idx == TypeIndex::Function; }
constexpr bool isAggregateFunction() const { return idx == TypeIndex::AggregateFunction; }
bool isString() const { return idx == TypeIndex::String; }
bool isFixedString() const { return idx == TypeIndex::FixedString; }
bool isStringOrFixedString() const { return isString() || isFixedString(); }
bool isUUID() const { return idx == TypeIndex::UUID; }
bool isArray() const { return idx == TypeIndex::Array; }
bool isTuple() const { return idx == TypeIndex::Tuple; }
bool isSet() const { return idx == TypeIndex::Set; }
bool isInterval() const { return idx == TypeIndex::Interval; }
bool isNothing() const { return idx == TypeIndex::Nothing; }
bool isNullable() const { return idx == TypeIndex::Nullable; }
bool isFunction() const { return idx == TypeIndex::Function; }
bool isAggregateFunction() const { return idx == TypeIndex::AggregateFunction; }
bool IsBigIntOrDeimal() const { return isInt128() || isInt256() || isUInt256() || isDecimal256(); }
constexpr bool IsBigIntOrDeimal() const { return isInt128() || isInt256() || isUInt256() || isDecimal256(); }
};
/// IDataType helpers (alternative for IDataType virtual methods with single point of truth)

View File

@ -149,7 +149,7 @@ def run_single_test(args, ext, server_logs_level, client_options, case_file, std
command = pattern.format(**params)
print(command)
# print(command)
proc = Popen(command, shell=True, env=os.environ)
start_time = datetime.now()

View File

@ -3,9 +3,20 @@
CUR_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
. "$CUR_DIR"/../shell_config.sh
${CLICKHOUSE_CLIENT} --query="SELECT avgWeighted(x, weight) FROM (SELECT t.1 AS x, t.2 AS weight FROM (SELECT arrayJoin([(1, 5), (2, 4), (3, 3), (4, 2), (5, 1)]) AS t));"
${CLICKHOUSE_CLIENT} --query="SELECT avgWeighted(x, weight) FROM (SELECT t.1 AS x, t.2 AS weight FROM (SELECT arrayJoin([(1, 0), (2, 0), (3, 0), (4, 0), (5, 0)]) AS t));"
echo "$(${CLICKHOUSE_CLIENT} --server_logs_file=/dev/null --query="SELECT avgWeighted(toDecimal64(0, 0), toFloat64(0))" 2>&1)" \
| grep -c 'Code: 43. DB::Exception: .* DB::Exception:.* Different types .* of arguments for aggregate function avgWeighted'
types=("Int8" "Int16" "Int32" "Int64" "UInt8" "UInt16" "UInt32" "UInt64" "Float32" "Float64")
for left in "${types[@]}"
do
for right in "${types[@]}"
do
${CLICKHOUSE_CLIENT} --query="SELECT avgWeighted(x, w) FROM values('x ${left}, w ${right}', (4, 1), (1, 0), (10, 2))"
${CLICKHOUSE_CLIENT} --query="SELECT avgWeighted(x, w) FROM values('x ${left}, w ${right}', (4, 1), (1, 0))"
${CLICKHOUSE_CLIENT} --query="SELECT avgWeighted(x, w) FROM values('x ${left}, w ${right}', (4, 0), (1, 0))"
done
done
echo "$(${CLICKHOUSE_CLIENT} --server_logs_file=/dev/null --query="SELECT avgWeighted(['string'], toFloat64(0))" 2>&1)" \
| grep -c 'Code: 43. DB::Exception: .* DB::Exception:.* Types .* of arguments are non-conforming as arguments for aggregate function avgWeighted'