mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-24 16:42:05 +00:00
sumMap for number-based types [issue-3277]
This commit is contained in:
parent
4fa81f18d2
commit
5b987f02d7
@ -27,7 +27,7 @@ AggregateFunctionPtr createAggregateFunctionAvg(const std::string & name, const
|
||||
AggregateFunctionPtr res;
|
||||
DataTypePtr data_type = argument_types[0];
|
||||
if (isDecimal(data_type))
|
||||
res.reset(createWithDecimalType<AggregateFuncAvg>(*data_type));
|
||||
res.reset(createWithDecimalType<AggregateFuncAvg>(*data_type, *data_type));
|
||||
else
|
||||
res.reset(createWithNumericType<AggregateFuncAvg>(*data_type));
|
||||
|
||||
|
@ -50,7 +50,7 @@ AggregateFunctionPtr createAggregateFunctionSum(const std::string & name, const
|
||||
AggregateFunctionPtr res;
|
||||
DataTypePtr data_type = argument_types[0];
|
||||
if (isDecimal(data_type))
|
||||
res.reset(createWithDecimalType<Function>(*data_type));
|
||||
res.reset(createWithDecimalType<Function>(*data_type, *data_type));
|
||||
else
|
||||
res.reset(createWithNumericType<Function>(*data_type));
|
||||
|
||||
|
@ -37,7 +37,9 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
|
||||
values_types.push_back(array_type->getNestedType());
|
||||
}
|
||||
|
||||
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionSumMap>(*keys_type, keys_type, std::move(values_types)));
|
||||
AggregateFunctionPtr res(createWithNumericBasedType<AggregateFunctionSumMap>(*keys_type, keys_type, values_types));
|
||||
if (!res)
|
||||
res.reset(createWithDecimalType<AggregateFunctionSumMap>(*keys_type, keys_type, values_types));
|
||||
if (!res)
|
||||
throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
|
@ -8,6 +8,8 @@
|
||||
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnDecimal.h>
|
||||
|
||||
#include <Common/FieldVisitors.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
@ -53,6 +55,8 @@ class AggregateFunctionSumMap final : public IAggregateFunctionDataHelper<
|
||||
AggregateFunctionSumMapData<typename NearestFieldType<T>::Type>, AggregateFunctionSumMap<T>>
|
||||
{
|
||||
private:
|
||||
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
|
||||
|
||||
DataTypePtr keys_type;
|
||||
DataTypes values_types;
|
||||
|
||||
@ -78,7 +82,7 @@ public:
|
||||
// Column 0 contains array of keys of known type
|
||||
const ColumnArray & array_column = static_cast<const ColumnArray &>(*columns[0]);
|
||||
const IColumn::Offsets & offsets = array_column.getOffsets();
|
||||
const auto & keys_vec = static_cast<const ColumnVector<T> &>(array_column.getData());
|
||||
const auto & keys_vec = static_cast<const ColVecType &>(array_column.getData());
|
||||
const size_t keys_vec_offset = row_num == 0 ? 0 : offsets[row_num - 1];
|
||||
const size_t keys_vec_size = (offsets[row_num] - keys_vec_offset);
|
||||
|
||||
@ -99,9 +103,20 @@ public:
|
||||
// Insert column values for all keys
|
||||
for (size_t i = 0; i < keys_vec_size; ++i)
|
||||
{
|
||||
using MapType = std::decay_t<decltype(merged_maps)>;
|
||||
using IteratorType = typename MapType::iterator;
|
||||
|
||||
array_column.getData().get(values_vec_offset + i, value);
|
||||
const auto & key = keys_vec.getData()[keys_vec_offset + i];
|
||||
const auto & it = merged_maps.find(key);
|
||||
|
||||
IteratorType it;
|
||||
if constexpr (IsDecimalNumber<T>)
|
||||
{
|
||||
UInt32 scale = keys_vec.getData().getScale();
|
||||
it = merged_maps.find(DecimalField<T>(key, scale));
|
||||
}
|
||||
else
|
||||
it = merged_maps.find(key);
|
||||
|
||||
if (it != merged_maps.end())
|
||||
applyVisitor(FieldVisitorSum(value), it->second[col]);
|
||||
@ -113,7 +128,13 @@ public:
|
||||
for (size_t k = 0; k < new_values.size(); ++k)
|
||||
new_values[k] = (k == col) ? value : values_types[k]->getDefault();
|
||||
|
||||
merged_maps[key] = std::move(new_values);
|
||||
if constexpr (IsDecimalNumber<T>)
|
||||
{
|
||||
UInt32 scale = keys_vec.getData().getScale();
|
||||
merged_maps.emplace(DecimalField<T>(key, scale), std::move(new_values));
|
||||
}
|
||||
else
|
||||
merged_maps.emplace(key, std::move(new_values));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -167,7 +188,10 @@ public:
|
||||
for (size_t col = 0; col < values_types.size(); ++col)
|
||||
values_types[col]->deserializeBinary(values[col], buf);
|
||||
|
||||
merged_maps[key.get<T>()] = values;
|
||||
if constexpr (IsDecimalNumber<T>)
|
||||
merged_maps[key.get<DecimalField<T>>()] = values;
|
||||
else
|
||||
merged_maps[key.get<T>()] = values;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -24,9 +24,7 @@ AggregateFunctionPtr createAggregateFunctionStatisticsUnary(const std::string &
|
||||
AggregateFunctionPtr res;
|
||||
DataTypePtr data_type = argument_types[0];
|
||||
if (isDecimal(data_type))
|
||||
{
|
||||
res.reset(createWithDecimalType<FunctionTemplate>(*data_type));
|
||||
}
|
||||
res.reset(createWithDecimalType<FunctionTemplate>(*data_type, *data_type));
|
||||
else
|
||||
res.reset(createWithNumericType<FunctionTemplate>(*data_type));
|
||||
|
||||
|
@ -70,13 +70,28 @@ static IAggregateFunction * createWithUnsignedIntegerType(const IDataType & argu
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <template <typename> class AggregateFunctionTemplate, typename ... TArgs>
|
||||
static IAggregateFunction * createWithNumericBasedType(const IDataType & argument_type, TArgs && ... args)
|
||||
{
|
||||
IAggregateFunction * f = createWithNumericType<AggregateFunctionTemplate>(argument_type, std::forward<TArgs>(args)...);
|
||||
if (f)
|
||||
return f;
|
||||
|
||||
/// expects that DataTypeDate based on UInt16, DataTypeDateTime based on UInt32 and UUID based on UInt128
|
||||
WhichDataType which(argument_type);
|
||||
if (which.idx == TypeIndex::Date) return new AggregateFunctionTemplate<UInt16>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::DateTime) return new AggregateFunctionTemplate<UInt32>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::UUID) return new AggregateFunctionTemplate<UInt128>(std::forward<TArgs>(args)...);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <template <typename> class AggregateFunctionTemplate, typename ... TArgs>
|
||||
static IAggregateFunction * createWithDecimalType(const IDataType & argument_type, TArgs && ... args)
|
||||
{
|
||||
WhichDataType which(argument_type);
|
||||
if (which.idx == TypeIndex::Decimal32) return new AggregateFunctionTemplate<Decimal32>(argument_type, std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::Decimal64) return new AggregateFunctionTemplate<Decimal64>(argument_type, std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::Decimal128) return new AggregateFunctionTemplate<Decimal128>(argument_type, std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::Decimal32) return new AggregateFunctionTemplate<Decimal32>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::Decimal64) return new AggregateFunctionTemplate<Decimal64>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::Decimal128) return new AggregateFunctionTemplate<Decimal128>(std::forward<TArgs>(args)...);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -8,3 +8,10 @@
|
||||
2000-01-01 00:01:00 ([4,5,6,7,8],[10,10,20,10,10])
|
||||
2000-01-01 00:00:00 [1,2,3,4,5] [10,10,20,10,10]
|
||||
2000-01-01 00:01:00 [4,5,6,7,8] [10,10,20,10,10]
|
||||
([1],[1])
|
||||
([1],[1])
|
||||
(['a'],[1])
|
||||
(['1970-01-01 03:00:01'],[1])
|
||||
(['1970-01-02'],[1])
|
||||
(['01234567-89ab-cdef-0123-456789abcdef'],[1])
|
||||
([1.01],[1])
|
||||
|
@ -1,3 +1,4 @@
|
||||
CREATE DATABASE IF NOT EXISTS test;
|
||||
DROP TABLE IF EXISTS test.sum_map;
|
||||
CREATE TABLE test.sum_map(date Date, timeslot DateTime, statusMap Nested(status UInt16, requests UInt64)) ENGINE = Log;
|
||||
|
||||
@ -10,3 +11,15 @@ SELECT timeslot, sumMap(statusMap.status, statusMap.requests) FROM test.sum_map
|
||||
SELECT timeslot, sumMap(statusMap.status, statusMap.requests).1, sumMap(statusMap.status, statusMap.requests).2 FROM test.sum_map GROUP BY timeslot ORDER BY timeslot;
|
||||
|
||||
DROP TABLE test.sum_map;
|
||||
|
||||
select sumMap(val, cnt) from ( SELECT [ CAST(1, 'UInt64') ] as val, [1] as cnt );
|
||||
select sumMap(val, cnt) from ( SELECT [ CAST(1, 'Float64') ] as val, [1] as cnt );
|
||||
select sumMap(val, cnt) from ( SELECT [ CAST('a', 'Enum16(\'a\'=1)') ] as val, [1] as cnt );
|
||||
|
||||
select sumMap(val, cnt) from ( SELECT [ CAST(1, 'DateTime') ] as val, [1] as cnt );
|
||||
select sumMap(val, cnt) from ( SELECT [ CAST(1, 'Date') ] as val, [1] as cnt );
|
||||
select sumMap(val, cnt) from ( SELECT [ CAST('01234567-89ab-cdef-0123-456789abcdef', 'UUID') ] as val, [1] as cnt );
|
||||
select sumMap(val, cnt) from ( SELECT [ CAST(1.01, 'Decimal(10,2)') ] as val, [1] as cnt );
|
||||
|
||||
select sumMap(val, cnt) from ( SELECT [ CAST('a', 'FixedString(1)') ] as val, [1] as cnt ); -- { serverError 43 }
|
||||
select sumMap(val, cnt) from ( SELECT [ CAST('a', 'String') ] as val, [1] as cnt ); -- { serverError 43 }
|
||||
|
Loading…
Reference in New Issue
Block a user