Merge pull request #3281 from 4ertus2/summap

sumMap for number-based types
This commit is contained in:
alexey-milovidov 2018-10-05 00:12:57 +03:00 committed by GitHub
commit 47dd173007
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 74 additions and 13 deletions

View File

@ -27,7 +27,7 @@ AggregateFunctionPtr createAggregateFunctionAvg(const std::string & name, const
AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
if (isDecimal(data_type))
res.reset(createWithDecimalType<AggregateFuncAvg>(*data_type));
res.reset(createWithDecimalType<AggregateFuncAvg>(*data_type, *data_type));
else
res.reset(createWithNumericType<AggregateFuncAvg>(*data_type));

View File

@ -50,7 +50,7 @@ AggregateFunctionPtr createAggregateFunctionSum(const std::string & name, const
AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
if (isDecimal(data_type))
res.reset(createWithDecimalType<Function>(*data_type));
res.reset(createWithDecimalType<Function>(*data_type, *data_type));
else
res.reset(createWithNumericType<Function>(*data_type));

View File

@ -37,7 +37,9 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
values_types.push_back(array_type->getNestedType());
}
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionSumMap>(*keys_type, keys_type, std::move(values_types)));
AggregateFunctionPtr res(createWithNumericBasedType<AggregateFunctionSumMap>(*keys_type, keys_type, values_types));
if (!res)
res.reset(createWithDecimalType<AggregateFunctionSumMap>(*keys_type, keys_type, values_types));
if (!res)
throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -8,6 +8,8 @@
#include <Columns/ColumnArray.h>
#include <Columns/ColumnTuple.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnDecimal.h>
#include <Common/FieldVisitors.h>
#include <AggregateFunctions/IAggregateFunction.h>
@ -53,6 +55,8 @@ class AggregateFunctionSumMap final : public IAggregateFunctionDataHelper<
AggregateFunctionSumMapData<typename NearestFieldType<T>::Type>, AggregateFunctionSumMap<T>>
{
private:
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
DataTypePtr keys_type;
DataTypes values_types;
@ -78,7 +82,7 @@ public:
// Column 0 contains array of keys of known type
const ColumnArray & array_column = static_cast<const ColumnArray &>(*columns[0]);
const IColumn::Offsets & offsets = array_column.getOffsets();
const auto & keys_vec = static_cast<const ColumnVector<T> &>(array_column.getData());
const auto & keys_vec = static_cast<const ColVecType &>(array_column.getData());
const size_t keys_vec_offset = row_num == 0 ? 0 : offsets[row_num - 1];
const size_t keys_vec_size = (offsets[row_num] - keys_vec_offset);
@ -99,9 +103,20 @@ public:
// Insert column values for all keys
for (size_t i = 0; i < keys_vec_size; ++i)
{
using MapType = std::decay_t<decltype(merged_maps)>;
using IteratorType = typename MapType::iterator;
array_column.getData().get(values_vec_offset + i, value);
const auto & key = keys_vec.getData()[keys_vec_offset + i];
const auto & it = merged_maps.find(key);
IteratorType it;
if constexpr (IsDecimalNumber<T>)
{
UInt32 scale = keys_vec.getData().getScale();
it = merged_maps.find(DecimalField<T>(key, scale));
}
else
it = merged_maps.find(key);
if (it != merged_maps.end())
applyVisitor(FieldVisitorSum(value), it->second[col]);
@ -113,7 +128,13 @@ public:
for (size_t k = 0; k < new_values.size(); ++k)
new_values[k] = (k == col) ? value : values_types[k]->getDefault();
merged_maps[key] = std::move(new_values);
if constexpr (IsDecimalNumber<T>)
{
UInt32 scale = keys_vec.getData().getScale();
merged_maps.emplace(DecimalField<T>(key, scale), std::move(new_values));
}
else
merged_maps.emplace(key, std::move(new_values));
}
}
}
@ -167,7 +188,10 @@ public:
for (size_t col = 0; col < values_types.size(); ++col)
values_types[col]->deserializeBinary(values[col], buf);
merged_maps[key.get<T>()] = values;
if constexpr (IsDecimalNumber<T>)
merged_maps[key.get<DecimalField<T>>()] = values;
else
merged_maps[key.get<T>()] = values;
}
}

View File

@ -24,9 +24,7 @@ AggregateFunctionPtr createAggregateFunctionStatisticsUnary(const std::string &
AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
if (isDecimal(data_type))
{
res.reset(createWithDecimalType<FunctionTemplate>(*data_type));
}
res.reset(createWithDecimalType<FunctionTemplate>(*data_type, *data_type));
else
res.reset(createWithNumericType<FunctionTemplate>(*data_type));

View File

@ -70,13 +70,28 @@ static IAggregateFunction * createWithUnsignedIntegerType(const IDataType & argu
return nullptr;
}
/** Creates an aggregate function for an argument type that is numeric or is stored as a number.
  *
  * The lookup is delegated to createWithNumericType first. When that returns nullptr,
  * Date, DateTime and UUID are mapped to the unsigned integer type assumed to back them.
  *
  * @param argument_type  data type used to select the template instantiation
  * @param args           arguments forwarded to the constructor of the aggregate function
  * @return newly allocated aggregate function, or nullptr if argument_type matches none of the handled types
  */
template <template <typename> class AggregateFunctionTemplate, typename ... TArgs>
static IAggregateFunction * createWithNumericBasedType(const IDataType & argument_type, TArgs && ... args)
{
    /// args are forwarded a second time below.
    /// NOTE(review): this presumes createWithNumericType does not consume them when it returns nullptr - confirm.
    if (IAggregateFunction * numeric = createWithNumericType<AggregateFunctionTemplate>(argument_type, std::forward<TArgs>(args)...))
        return numeric;

    /// Expects that DataTypeDate is based on UInt16, DataTypeDateTime on UInt32 and UUID on UInt128.
    WhichDataType which(argument_type);
    switch (which.idx)
    {
        case TypeIndex::Date:
            return new AggregateFunctionTemplate<UInt16>(std::forward<TArgs>(args)...);
        case TypeIndex::DateTime:
            return new AggregateFunctionTemplate<UInt32>(std::forward<TArgs>(args)...);
        case TypeIndex::UUID:
            return new AggregateFunctionTemplate<UInt128>(std::forward<TArgs>(args)...);
        default:
            return nullptr;
    }
}
/// Creates AggregateFunctionTemplate instantiated with Decimal32, Decimal64 or Decimal128,
/// chosen by the type index of argument_type. Returns nullptr for any other type.
template <template <typename> class AggregateFunctionTemplate, typename ... TArgs>
static IAggregateFunction * createWithDecimalType(const IDataType & argument_type, TArgs && ... args)
{
WhichDataType which(argument_type);
/// NOTE(review): the next three returns pass argument_type as the first constructor argument, ahead of args.
if (which.idx == TypeIndex::Decimal32) return new AggregateFunctionTemplate<Decimal32>(argument_type, std::forward<TArgs>(args)...);
if (which.idx == TypeIndex::Decimal64) return new AggregateFunctionTemplate<Decimal64>(argument_type, std::forward<TArgs>(args)...);
if (which.idx == TypeIndex::Decimal128) return new AggregateFunctionTemplate<Decimal128>(argument_type, std::forward<TArgs>(args)...);
/// NOTE(review): these three repeat the same conditions but forward only args, so as listed they can never be reached.
/// This looks like the pre-change and post-change lines of a diff shown together - confirm which set is current.
if (which.idx == TypeIndex::Decimal32) return new AggregateFunctionTemplate<Decimal32>(std::forward<TArgs>(args)...);
if (which.idx == TypeIndex::Decimal64) return new AggregateFunctionTemplate<Decimal64>(std::forward<TArgs>(args)...);
if (which.idx == TypeIndex::Decimal128) return new AggregateFunctionTemplate<Decimal128>(std::forward<TArgs>(args)...);
return nullptr;
}

View File

@ -8,3 +8,10 @@
2000-01-01 00:01:00 ([4,5,6,7,8],[10,10,20,10,10])
2000-01-01 00:00:00 [1,2,3,4,5] [10,10,20,10,10]
2000-01-01 00:01:00 [4,5,6,7,8] [10,10,20,10,10]
([1],[1])
([1],[1])
(['a'],[1])
(['1970-01-01 03:00:01'],[1])
(['1970-01-02'],[1])
(['01234567-89ab-cdef-0123-456789abcdef'],[1])
([1.01],[1])

View File

@ -1,3 +1,6 @@
SET send_logs_level = 'none';
CREATE DATABASE IF NOT EXISTS test;
DROP TABLE IF EXISTS test.sum_map;
CREATE TABLE test.sum_map(date Date, timeslot DateTime, statusMap Nested(status UInt16, requests UInt64)) ENGINE = Log;
@ -10,3 +13,15 @@ SELECT timeslot, sumMap(statusMap.status, statusMap.requests) FROM test.sum_map
SELECT timeslot, sumMap(statusMap.status, statusMap.requests).1, sumMap(statusMap.status, statusMap.requests).2 FROM test.sum_map GROUP BY timeslot ORDER BY timeslot;
DROP TABLE test.sum_map;
-- Each query calls sumMap on a one-element key array of a given type and a one-element value array [1].
-- Key types without an error hint: the query is expected to succeed.
select sumMap(val, cnt) from ( SELECT [ CAST(1, 'UInt64') ] as val, [1] as cnt );
select sumMap(val, cnt) from ( SELECT [ CAST(1, 'Float64') ] as val, [1] as cnt );
select sumMap(val, cnt) from ( SELECT [ CAST('a', 'Enum16(\'a\'=1)') ] as val, [1] as cnt );
select sumMap(val, cnt) from ( SELECT [ CAST(1, 'DateTime') ] as val, [1] as cnt );
select sumMap(val, cnt) from ( SELECT [ CAST(1, 'Date') ] as val, [1] as cnt );
select sumMap(val, cnt) from ( SELECT [ CAST('01234567-89ab-cdef-0123-456789abcdef', 'UUID') ] as val, [1] as cnt );
select sumMap(val, cnt) from ( SELECT [ CAST(1.01, 'Decimal(10,2)') ] as val, [1] as cnt );
-- String-like key types: the trailing hint expects server error 43
-- (presumably ILLEGAL_TYPE_OF_ARGUMENT - confirm against the error code list).
select sumMap(val, cnt) from ( SELECT [ CAST('a', 'FixedString(1)') ] as val, [1] as cnt ); -- { serverError 43 }
select sumMap(val, cnt) from ( SELECT [ CAST('a', 'String') ] as val, [1] as cnt ); -- { serverError 43 }