mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
wip dealing with template magic
This commit is contained in:
parent
53471315ff
commit
8e7e232387
@ -42,3 +42,20 @@ Result:
|
||||
│ 8 │
|
||||
└────────────────────────┘
|
||||
```
|
||||
|
||||
**Example**
|
||||
|
||||
Query:
|
||||
|
||||
``` sql
|
||||
SELECT avgWeighted(x, w)
|
||||
FROM values('x Int8, w Float64', (4, 1), (1, 0), (10, 2))
|
||||
```
|
||||
|
||||
Result:
|
||||
|
||||
``` text
|
||||
┌─avgWeighted(x, weight)─┐
|
||||
│ 8 │
|
||||
└────────────────────────┘
|
||||
```
|
||||
|
@ -13,26 +13,37 @@ namespace ErrorCodes
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
struct AvgWeighted
|
||||
constexpr bool allowTypes(const DataTypePtr& left, const DataTypePtr& right)
|
||||
{
|
||||
using FieldType = std::conditional_t<
|
||||
IsDecimalNumber<T>,
|
||||
std::conditional_t<std::is_same_v<T, Decimal256>,
|
||||
Decimal256,
|
||||
Decimal128>,
|
||||
NearestFieldType<T>>;
|
||||
const WhichDataType l_dt(left), r_dt(right);
|
||||
|
||||
using Function = AggregateFunctionAvgWeighted<T, AggregateFunctionAvgData<FieldType, FieldType>>;
|
||||
constexpr auto allow = [](WhichDataType t)
|
||||
{
|
||||
return t.isInt() || t.isUInt() || t.isFloat() || t.isDecimal();
|
||||
};
|
||||
|
||||
return allow(l_dt) && allow(r_dt);
|
||||
}
|
||||
|
||||
template <class U> struct BiggerType
|
||||
|
||||
template <class U, class V> struct LargestType
|
||||
{
|
||||
using Type = bool;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using AggregateFuncAvgWeighted = typename AvgWeighted<T>::Function;
|
||||
|
||||
bool allowTypes(const DataTypePtr& left, const DataTypePtr& right)
|
||||
template <class U, class V> using AvgData = AggregateFunctionAvgData<
|
||||
typename LargestType<U, V>::Type,
|
||||
typename LargestType<U, V>::Type>;
|
||||
|
||||
template <class U, class V> using Function = AggregateFunctionAvgWeighted<
|
||||
U, V, typename LargestType<U, V>::Type, AvgData<U, V>>;
|
||||
|
||||
template <typename... TArgs>
|
||||
static IAggregateFunction * create(const IDataType & first_type, const IDataType & second_type, TArgs && ... args)
|
||||
{
|
||||
return (isInteger(left) || isFloat(left)) && (isInteger(right) || isFloat(right));
|
||||
|
||||
}
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionAvgWeighted(const std::string & name, const DataTypes & argument_types, const Array & parameters)
|
||||
@ -40,8 +51,6 @@ AggregateFunctionPtr createAggregateFunctionAvgWeighted(const std::string & name
|
||||
assertNoParameters(name, parameters);
|
||||
assertBinary(name, argument_types);
|
||||
|
||||
AggregateFunctionPtr res;
|
||||
|
||||
const auto data_type = static_cast<const DataTypePtr>(argument_types[0]);
|
||||
const auto data_type_weight = static_cast<const DataTypePtr>(argument_types[1]);
|
||||
|
||||
@ -52,10 +61,8 @@ AggregateFunctionPtr createAggregateFunctionAvgWeighted(const std::string & name
|
||||
" are non-conforming as arguments for aggregate function " + name,
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
if (isDecimal(data_type))
|
||||
res.reset(createWithDecimalType<AggregateFuncAvgWeighted>(*data_type, *data_type, argument_types));
|
||||
else
|
||||
res.reset(createWithNumericType<AggregateFuncAvgWeighted>(*data_type, argument_types));
|
||||
AggregateFunctionPtr res;
|
||||
res.reset(create(*data_type, *data_type_weight, argument_types));
|
||||
|
||||
if (!res)
|
||||
throw Exception("Illegal type " + data_type->getName() + " of argument for aggregate function " + name,
|
||||
@ -70,5 +77,4 @@ void registerAggregateFunctionAvgWeighted(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("avgWeighted", createAggregateFunctionAvgWeighted, AggregateFunctionFactory::CaseSensitive);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -4,21 +4,28 @@
|
||||
|
||||
namespace DB
|
||||
{
|
||||
template <typename T, typename Data>
|
||||
class AggregateFunctionAvgWeighted final : public AggregateFunctionAvgBase<T, Data, AggregateFunctionAvgWeighted<T, Data>>
|
||||
|
||||
template <typename Value, typename Weight, typename Largest, typename Data>
|
||||
class AggregateFunctionAvgWeighted final :
|
||||
public AggregateFunctionAvgBase<Largest, Data, AggregateFunctionAvgWeighted<Value, Weight, Largest, Data>>
|
||||
{
|
||||
public:
|
||||
using AggregateFunctionAvgBase<T, Data, AggregateFunctionAvgWeighted<T, Data>>::AggregateFunctionAvgBase;
|
||||
using AggregateFunctionAvgBase<Largest, Data,
|
||||
AggregateFunctionAvgWeighted<Value, Weight, Largest, Data>>::AggregateFunctionAvgBase;
|
||||
|
||||
template <class T>
|
||||
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
|
||||
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
const auto & values = static_cast<const ColVecType &>(*columns[0]);
|
||||
const auto & weights = static_cast<const ColVecType &>(*columns[1]);
|
||||
const auto & values = static_cast<const ColVecType<Value> &>(*columns[0]);
|
||||
const auto & weights = static_cast<const ColVecType<Weight> &>(*columns[1]);
|
||||
|
||||
this->data(place).numerator += static_cast<typename Data::NumeratorType>(values.getData()[row_num]) * weights.getData()[row_num];
|
||||
this->data(place).denominator += weights.getData()[row_num];
|
||||
const auto value = values.getData()[row_num];
|
||||
const auto weight = weights.getData()[row_num];
|
||||
|
||||
this->data(place).numerator += static_cast<typename Data::NumeratorType>(value) * weight;
|
||||
this->data(place).denominator += weight;
|
||||
}
|
||||
|
||||
String getName() const override { return "avgWeighted"; }
|
||||
|
@ -466,75 +466,64 @@ struct WhichDataType
|
||||
{
|
||||
TypeIndex idx;
|
||||
|
||||
WhichDataType(TypeIndex idx_ = TypeIndex::Nothing)
|
||||
: idx(idx_)
|
||||
{}
|
||||
constexpr WhichDataType(TypeIndex idx_ = TypeIndex::Nothing) : idx(idx_) {}
|
||||
constexpr WhichDataType(const IDataType & data_type) : idx(data_type.getTypeId()) {}
|
||||
constexpr WhichDataType(const IDataType * data_type) : idx(data_type->getTypeId()) {}
|
||||
constexpr WhichDataType(const DataTypePtr & data_type) : idx(data_type->getTypeId()) {}
|
||||
|
||||
WhichDataType(const IDataType & data_type)
|
||||
: idx(data_type.getTypeId())
|
||||
{}
|
||||
constexpr bool isUInt8() const { return idx == TypeIndex::UInt8; }
|
||||
constexpr bool isUInt16() const { return idx == TypeIndex::UInt16; }
|
||||
constexpr bool isUInt32() const { return idx == TypeIndex::UInt32; }
|
||||
constexpr bool isUInt64() const { return idx == TypeIndex::UInt64; }
|
||||
constexpr bool isUInt128() const { return idx == TypeIndex::UInt128; }
|
||||
constexpr bool isUInt256() const { return idx == TypeIndex::UInt256; }
|
||||
constexpr bool isUInt() const { return isUInt8() || isUInt16() || isUInt32() || isUInt64() || isUInt128() || isUInt256(); }
|
||||
constexpr bool isNativeUInt() const { return isUInt8() || isUInt16() || isUInt32() || isUInt64(); }
|
||||
|
||||
WhichDataType(const IDataType * data_type)
|
||||
: idx(data_type->getTypeId())
|
||||
{}
|
||||
constexpr bool isInt8() const { return idx == TypeIndex::Int8; }
|
||||
constexpr bool isInt16() const { return idx == TypeIndex::Int16; }
|
||||
constexpr bool isInt32() const { return idx == TypeIndex::Int32; }
|
||||
constexpr bool isInt64() const { return idx == TypeIndex::Int64; }
|
||||
constexpr bool isInt128() const { return idx == TypeIndex::Int128; }
|
||||
constexpr bool isInt256() const { return idx == TypeIndex::Int256; }
|
||||
constexpr bool isInt() const { return isInt8() || isInt16() || isInt32() || isInt64() || isInt128() || isInt256(); }
|
||||
constexpr bool isNativeInt() const { return isInt8() || isInt16() || isInt32() || isInt64(); }
|
||||
|
||||
WhichDataType(const DataTypePtr & data_type)
|
||||
: idx(data_type->getTypeId())
|
||||
{}
|
||||
constexpr bool isDecimal32() const { return idx == TypeIndex::Decimal32; }
|
||||
constexpr bool isDecimal64() const { return idx == TypeIndex::Decimal64; }
|
||||
constexpr bool isDecimal128() const { return idx == TypeIndex::Decimal128; }
|
||||
constexpr bool isDecimal256() const { return idx == TypeIndex::Decimal256; }
|
||||
constexpr bool isDecimal() const { return isDecimal32() || isDecimal64() || isDecimal128() || isDecimal256(); }
|
||||
|
||||
bool isUInt8() const { return idx == TypeIndex::UInt8; }
|
||||
bool isUInt16() const { return idx == TypeIndex::UInt16; }
|
||||
bool isUInt32() const { return idx == TypeIndex::UInt32; }
|
||||
bool isUInt64() const { return idx == TypeIndex::UInt64; }
|
||||
bool isUInt128() const { return idx == TypeIndex::UInt128; }
|
||||
bool isUInt256() const { return idx == TypeIndex::UInt256; }
|
||||
bool isUInt() const { return isUInt8() || isUInt16() || isUInt32() || isUInt64() || isUInt128() || isUInt256(); }
|
||||
bool isNativeUInt() const { return isUInt8() || isUInt16() || isUInt32() || isUInt64(); }
|
||||
constexpr bool isFloat32() const { return idx == TypeIndex::Float32; }
|
||||
constexpr bool isFloat64() const { return idx == TypeIndex::Float64; }
|
||||
constexpr bool isFloat() const { return isFloat32() || isFloat64(); }
|
||||
|
||||
bool isInt8() const { return idx == TypeIndex::Int8; }
|
||||
bool isInt16() const { return idx == TypeIndex::Int16; }
|
||||
bool isInt32() const { return idx == TypeIndex::Int32; }
|
||||
bool isInt64() const { return idx == TypeIndex::Int64; }
|
||||
bool isInt128() const { return idx == TypeIndex::Int128; }
|
||||
bool isInt256() const { return idx == TypeIndex::Int256; }
|
||||
bool isInt() const { return isInt8() || isInt16() || isInt32() || isInt64() || isInt128() || isInt256(); }
|
||||
bool isNativeInt() const { return isInt8() || isInt16() || isInt32() || isInt64(); }
|
||||
constexpr bool isEnum8() const { return idx == TypeIndex::Enum8; }
|
||||
constexpr bool isEnum16() const { return idx == TypeIndex::Enum16; }
|
||||
constexpr bool isEnum() const { return isEnum8() || isEnum16(); }
|
||||
|
||||
bool isDecimal32() const { return idx == TypeIndex::Decimal32; }
|
||||
bool isDecimal64() const { return idx == TypeIndex::Decimal64; }
|
||||
bool isDecimal128() const { return idx == TypeIndex::Decimal128; }
|
||||
bool isDecimal256() const { return idx == TypeIndex::Decimal256; }
|
||||
bool isDecimal() const { return isDecimal32() || isDecimal64() || isDecimal128() || isDecimal256(); }
|
||||
constexpr bool isDate() const { return idx == TypeIndex::Date; }
|
||||
constexpr bool isDateTime() const { return idx == TypeIndex::DateTime; }
|
||||
constexpr bool isDateTime64() const { return idx == TypeIndex::DateTime64; }
|
||||
constexpr bool isDateOrDateTime() const { return isDate() || isDateTime() || isDateTime64(); }
|
||||
|
||||
bool isFloat32() const { return idx == TypeIndex::Float32; }
|
||||
bool isFloat64() const { return idx == TypeIndex::Float64; }
|
||||
bool isFloat() const { return isFloat32() || isFloat64(); }
|
||||
constexpr bool isString() const { return idx == TypeIndex::String; }
|
||||
constexpr bool isFixedString() const { return idx == TypeIndex::FixedString; }
|
||||
constexpr bool isStringOrFixedString() const { return isString() || isFixedString(); }
|
||||
|
||||
bool isEnum8() const { return idx == TypeIndex::Enum8; }
|
||||
bool isEnum16() const { return idx == TypeIndex::Enum16; }
|
||||
bool isEnum() const { return isEnum8() || isEnum16(); }
|
||||
constexpr bool isUUID() const { return idx == TypeIndex::UUID; }
|
||||
constexpr bool isArray() const { return idx == TypeIndex::Array; }
|
||||
constexpr bool isTuple() const { return idx == TypeIndex::Tuple; }
|
||||
constexpr bool isSet() const { return idx == TypeIndex::Set; }
|
||||
constexpr bool isInterval() const { return idx == TypeIndex::Interval; }
|
||||
|
||||
bool isDate() const { return idx == TypeIndex::Date; }
|
||||
bool isDateTime() const { return idx == TypeIndex::DateTime; }
|
||||
bool isDateTime64() const { return idx == TypeIndex::DateTime64; }
|
||||
bool isDateOrDateTime() const { return isDate() || isDateTime() || isDateTime64(); }
|
||||
constexpr bool isNothing() const { return idx == TypeIndex::Nothing; }
|
||||
constexpr bool isNullable() const { return idx == TypeIndex::Nullable; }
|
||||
constexpr bool isFunction() const { return idx == TypeIndex::Function; }
|
||||
constexpr bool isAggregateFunction() const { return idx == TypeIndex::AggregateFunction; }
|
||||
|
||||
bool isString() const { return idx == TypeIndex::String; }
|
||||
bool isFixedString() const { return idx == TypeIndex::FixedString; }
|
||||
bool isStringOrFixedString() const { return isString() || isFixedString(); }
|
||||
|
||||
bool isUUID() const { return idx == TypeIndex::UUID; }
|
||||
bool isArray() const { return idx == TypeIndex::Array; }
|
||||
bool isTuple() const { return idx == TypeIndex::Tuple; }
|
||||
bool isSet() const { return idx == TypeIndex::Set; }
|
||||
bool isInterval() const { return idx == TypeIndex::Interval; }
|
||||
|
||||
bool isNothing() const { return idx == TypeIndex::Nothing; }
|
||||
bool isNullable() const { return idx == TypeIndex::Nullable; }
|
||||
bool isFunction() const { return idx == TypeIndex::Function; }
|
||||
bool isAggregateFunction() const { return idx == TypeIndex::AggregateFunction; }
|
||||
|
||||
bool IsBigIntOrDeimal() const { return isInt128() || isInt256() || isUInt256() || isDecimal256(); }
|
||||
constexpr bool IsBigIntOrDeimal() const { return isInt128() || isInt256() || isUInt256() || isDecimal256(); }
|
||||
};
|
||||
|
||||
/// IDataType helpers (alternative for IDataType virtual methods with single point of truth)
|
||||
|
@ -149,7 +149,7 @@ def run_single_test(args, ext, server_logs_level, client_options, case_file, std
|
||||
|
||||
command = pattern.format(**params)
|
||||
|
||||
print(command)
|
||||
# print(command)
|
||||
|
||||
proc = Popen(command, shell=True, env=os.environ)
|
||||
start_time = datetime.now()
|
||||
|
@ -3,9 +3,20 @@
|
||||
CUR_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
|
||||
. "$CUR_DIR"/../shell_config.sh
|
||||
|
||||
|
||||
${CLICKHOUSE_CLIENT} --query="SELECT avgWeighted(x, weight) FROM (SELECT t.1 AS x, t.2 AS weight FROM (SELECT arrayJoin([(1, 5), (2, 4), (3, 3), (4, 2), (5, 1)]) AS t));"
|
||||
${CLICKHOUSE_CLIENT} --query="SELECT avgWeighted(x, weight) FROM (SELECT t.1 AS x, t.2 AS weight FROM (SELECT arrayJoin([(1, 0), (2, 0), (3, 0), (4, 0), (5, 0)]) AS t));"
|
||||
|
||||
echo "$(${CLICKHOUSE_CLIENT} --server_logs_file=/dev/null --query="SELECT avgWeighted(toDecimal64(0, 0), toFloat64(0))" 2>&1)" \
|
||||
| grep -c 'Code: 43. DB::Exception: .* DB::Exception:.* Different types .* of arguments for aggregate function avgWeighted'
|
||||
types=("Int8" "Int16" "Int32" "Int64" "UInt8" "UInt16" "UInt32" "UInt64" "Float32" "Float64")
|
||||
|
||||
for left in "${types[@]}"
|
||||
do
|
||||
for right in "${types[@]}"
|
||||
do
|
||||
${CLICKHOUSE_CLIENT} --query="SELECT avgWeighted(x, w) FROM values('x ${left}, w ${right}', (4, 1), (1, 0), (10, 2))"
|
||||
${CLICKHOUSE_CLIENT} --query="SELECT avgWeighted(x, w) FROM values('x ${left}, w ${right}', (4, 1), (1, 0))"
|
||||
${CLICKHOUSE_CLIENT} --query="SELECT avgWeighted(x, w) FROM values('x ${left}, w ${right}', (4, 0), (1, 0))"
|
||||
done
|
||||
done
|
||||
|
||||
echo "$(${CLICKHOUSE_CLIENT} --server_logs_file=/dev/null --query="SELECT avgWeighted(['string'], toFloat64(0))" 2>&1)" \
|
||||
| grep -c 'Code: 43. DB::Exception: .* DB::Exception:.* Types .* of arguments are non-conforming as arguments for aggregate function avgWeighted'
|
||||
|
Loading…
Reference in New Issue
Block a user