Make sumMap accept String keys

Adapt sumMap to accept an array of strings as the key column. This is useful when the keys cannot be, or should not be, represented as numbers.

Signed-off-by: Baudouin Giard <bgiard@bloomberg.net>
This commit is contained in:
Ubuntu 2020-01-29 15:36:32 +00:00 committed by Baudouin Giard
parent 9247bbac45
commit bc0fbd688a
5 changed files with 24 additions and 8 deletions

View File

@ -83,6 +83,8 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
AggregateFunctionPtr res(createWithNumericBasedType<Function>(*keys_type, keys_type, values_types, arguments));
if (!res)
res.reset(createWithDecimalType<Function>(*keys_type, keys_type, values_types, arguments));
if (!res)
res.reset(createWithStringType<Function>(*keys_type, keys_type, values_types, arguments));
if (!res)
throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -106,6 +108,8 @@ AggregateFunctionPtr createAggregateFunctionSumMapFiltered(const std::string & n
AggregateFunctionPtr res(createWithNumericBasedType<Function>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
if (!res)
res.reset(createWithDecimalType<Function>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
if (!res)
res.reset(createWithStringType<Function>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
if (!res)
throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -10,6 +10,7 @@
#include <Columns/ColumnTuple.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnString.h>
#include <Common/FieldVisitors.h>
#include <Common/assert_cast.h>
@ -56,8 +57,6 @@ class AggregateFunctionSumMapBase : public IAggregateFunctionDataHelper<
AggregateFunctionSumMapData<NearestFieldType<T>>, Derived>
{
private:
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
DataTypePtr keys_type;
DataTypes values_types;
@ -84,9 +83,10 @@ public:
void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override
{
// Column 0 contains array of keys of known type
Field key_field;
const ColumnArray & array_column0 = assert_cast<const ColumnArray &>(*columns[0]);
const IColumn::Offsets & offsets0 = array_column0.getOffsets();
const auto & keys_vec = static_cast<const ColVecType &>(array_column0.getData());
const IColumn & key_column = array_column0.getData();
const size_t keys_vec_offset = offsets0[row_num - 1];
const size_t keys_vec_size = (offsets0[row_num] - keys_vec_offset);
@ -111,7 +111,8 @@ public:
using IteratorType = typename MapType::iterator;
array_column.getData().get(values_vec_offset + i, value);
const auto & key = keys_vec.getElement(keys_vec_offset + i);
key_column.get(keys_vec_offset + i, key_field);
auto && key = key_field.get<T>();
if (!keepKey(key))
{
@ -121,7 +122,7 @@ public:
IteratorType it;
if constexpr (IsDecimalNumber<T>)
{
UInt32 scale = keys_vec.getData().getScale();
UInt32 scale = static_cast<const ColumnDecimal<T> &>(key_column).getData().getScale();
it = merged_maps.find(DecimalField<T>(key, scale));
}
else
@ -139,7 +140,7 @@ public:
if constexpr (IsDecimalNumber<T>)
{
UInt32 scale = keys_vec.getData().getScale();
UInt32 scale = static_cast<const ColumnDecimal<T> &>(key_column).getData().getScale();
merged_maps.emplace(DecimalField<T>(key, scale), std::move(new_values));
}
else

View File

@ -149,4 +149,13 @@ static IAggregateFunction * createWithTwoNumericTypes(const IDataType & first_ty
return nullptr;
}
/// Creates AggregateFunctionTemplate<String> when the argument type index is
/// either String or FixedString; both are served by the same String instantiation.
/// Returns nullptr for every other type index.
/// All extra arguments are perfectly forwarded to the aggregate function constructor.
template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction * createWithStringType(const IDataType & argument_type, TArgs && ... args)
{
    WhichDataType type_info(argument_type);

    const bool is_string_like = type_info.idx == TypeIndex::String
        || type_info.idx == TypeIndex::FixedString;

    /// Guard clause: anything that is not a string type is not handled here.
    if (!is_string_like)
        return nullptr;

    return new AggregateFunctionTemplate<String>(std::forward<TArgs>(args)...);
}
}

View File

@ -19,3 +19,5 @@
(['1970-01-02'],[1])
(['01234567-89ab-cdef-0123-456789abcdef'],[1])
([1.01],[1])
(['a','b'],[1,2])
(['a','ab','abc'],[3,2,1])

View File

@ -35,5 +35,5 @@ select sumMap(val, cnt) from ( SELECT [ CAST(1, 'Date') ] as val, [1] as cnt );
select sumMap(val, cnt) from ( SELECT [ CAST('01234567-89ab-cdef-0123-456789abcdef', 'UUID') ] as val, [1] as cnt );
select sumMap(val, cnt) from ( SELECT [ CAST(1.01, 'Decimal(10,2)') ] as val, [1] as cnt );
select sumMap(val, cnt) from ( SELECT [ CAST('a', 'FixedString(1)') ] as val, [1] as cnt ); -- { serverError 43 }
select sumMap(val, cnt) from ( SELECT [ CAST('a', 'String') ] as val, [1] as cnt ); -- { serverError 43 }
select sumMap(val, cnt) from ( SELECT [ CAST('a', 'FixedString(1)'), CAST('b', 'FixedString(1)' ) ] as val, [1, 2] as cnt );
select sumMap(val, cnt) from ( SELECT [ CAST('abc', 'String'), CAST('ab', 'String'), CAST('a', 'String') ] as val, [1, 2, 3] as cnt );