From 2ba68d74947a84d2f6008fa6a3bceab3bad0ed5b Mon Sep 17 00:00:00 2001 From: hexiaoting Date: Mon, 2 Nov 2020 14:05:53 +0800 Subject: [PATCH] Add map function --- src/Common/FieldVisitors.cpp | 72 +++ src/Common/FieldVisitors.h | 24 + src/Core/Field.cpp | 500 +++--------------- src/DataTypes/DataTypeMap.cpp | 20 +- src/DataTypes/DataTypeMap.h | 5 +- src/Functions/FunctionsConversion.h | 28 +- src/Functions/array/arrayElement.cpp | 187 +++++-- src/Functions/registerFunctions.cpp | 2 + src/Functions/ya.make | 1 + src/Interpreters/convertFieldToType.cpp | 7 +- src/Parsers/ASTFunction.cpp | 18 +- src/Parsers/ExpressionElementParsers.cpp | 31 ++ src/Parsers/ExpressionElementParsers.h | 7 + src/Parsers/ExpressionListParsers.cpp | 45 ++ src/Parsers/ExpressionListParsers.h | 7 + .../01550_create_map_type.reference | 24 +- .../0_stateless/01550_create_map_type.sql | 36 +- 17 files changed, 493 insertions(+), 521 deletions(-) diff --git a/src/Common/FieldVisitors.cpp b/src/Common/FieldVisitors.cpp index e35e719c7c9..928755be3eb 100644 --- a/src/Common/FieldVisitors.cpp +++ b/src/Common/FieldVisitors.cpp @@ -208,6 +208,78 @@ String FieldVisitorToString::operator() (const Map & x) const return wb.str(); } + +void FieldVisitorWriteBinary::operator() (const Null &, WriteBuffer &) const { return ; } +void FieldVisitorWriteBinary::operator() (const UInt64 & x, WriteBuffer & buf) const { DB::writeVarUInt(x, buf); } +void FieldVisitorWriteBinary::operator() (const Int64 & x, WriteBuffer & buf) const { DB::writeVarInt(x, buf); } +void FieldVisitorWriteBinary::operator() (const Float64 & x, WriteBuffer & buf) const { DB::writeFloatBinary(x, buf); } +void FieldVisitorWriteBinary::operator() (const String & x, WriteBuffer & buf) const { DB::writeStringBinary(x, buf); } +void FieldVisitorWriteBinary::operator() (const UInt128 & x, WriteBuffer & buf) const { DB::writeBinary(x, buf); } +void FieldVisitorWriteBinary::operator() (const Int128 & x, WriteBuffer & buf) const { DB::writeVarInt(x, buf); } +void FieldVisitorWriteBinary::operator() (const UInt256 & x, WriteBuffer & buf) const { DB::writeBinary(x, buf); } +void FieldVisitorWriteBinary::operator() (const Int256 & x, WriteBuffer & buf) const { DB::writeBinary(x, buf); } +void FieldVisitorWriteBinary::operator() (const DecimalField & x, WriteBuffer & buf) const { DB::writeBinary(x.getValue(), buf); } +void FieldVisitorWriteBinary::operator() (const DecimalField & x, WriteBuffer & buf) const { DB::writeBinary(x.getValue(), buf); } +void FieldVisitorWriteBinary::operator() (const DecimalField & x, WriteBuffer & buf) const { DB::writeBinary(x.getValue(), buf); } +void FieldVisitorWriteBinary::operator() (const DecimalField & x, WriteBuffer & buf) const { DB::writeBinary(x.getValue(), buf); } +void FieldVisitorWriteBinary::operator() (const AggregateFunctionStateData & x, WriteBuffer & buf) const +{ + DB::writeStringBinary(x.name, buf); + DB::writeStringBinary(x.data, buf); +} + +void FieldVisitorWriteBinary::operator() (const Array & x, WriteBuffer & buf) const +{ + const size_t size = x.size(); + DB::writeBinary(size, buf); + + for (auto it = x.begin(); it != x.end(); ++it) + { + const UInt8 type = it->getType(); + DB::writeBinary(type, buf); + Field::dispatch( + [&buf](const auto & value) { + DB::FieldVisitorWriteBinary()(value, buf); + }, + *it); + } +} + +void FieldVisitorWriteBinary::operator() (const Tuple & x, WriteBuffer & buf) const +{ + const size_t size = x.size(); + DB::writeBinary(size, buf); + + for (auto it = x.begin(); it != x.end(); ++it) + { + const UInt8 type = it->getType(); + DB::writeBinary(type, buf); + Field::dispatch( + [&buf](const auto & value) { + DB::FieldVisitorWriteBinary()(value, buf); + }, + *it); + } +} + + +void FieldVisitorWriteBinary::operator() (const Map & x, WriteBuffer & buf) const +{ + const size_t size = x.size(); + DB::writeBinary(size, buf); + for (auto it = x.begin(); it != x.end(); ++it) + { + const UInt8 type = it->getType(); + writeBinary(type, buf); + Field::dispatch( + [&buf](const auto & value) { + DB::FieldVisitorWriteBinary()(value, buf); + }, + *it); + } +} + + FieldVisitorHash::FieldVisitorHash(SipHash & hash_) : hash(hash_) {} void FieldVisitorHash::operator() (const Null &) const diff --git a/src/Common/FieldVisitors.h b/src/Common/FieldVisitors.h index 5d86750dd49..55602acb2b1 100644 --- a/src/Common/FieldVisitors.h +++ b/src/Common/FieldVisitors.h @@ -88,6 +88,30 @@ public: }; +class FieldVisitorWriteBinary +{ +public: + void operator() (const Null & x, WriteBuffer & buf) const; + void operator() (const UInt64 & x, WriteBuffer & buf) const; + void operator() (const UInt128 & x, WriteBuffer & buf) const; + void operator() (const Int64 & x, WriteBuffer & buf) const; + void operator() (const Int128 & x, WriteBuffer & buf) const; + void operator() (const Float64 & x, WriteBuffer & buf) const; + void operator() (const String & x, WriteBuffer & buf) const; + void operator() (const Array & x, WriteBuffer & buf) const; + void operator() (const Tuple & x, WriteBuffer & buf) const; + void operator() (const Map & x, WriteBuffer & buf) const; + void operator() (const DecimalField & x, WriteBuffer & buf) const; + void operator() (const DecimalField & x, WriteBuffer & buf) const; + void operator() (const DecimalField & x, WriteBuffer & buf) const; + void operator() (const DecimalField & x, WriteBuffer & buf) const; + void operator() (const AggregateFunctionStateData & x, WriteBuffer & buf) const; + + void operator() (const UInt256 & x, WriteBuffer & buf) const; + void operator() (const Int256 & x, WriteBuffer & buf) const; +}; + + /** Print readable and unique text dump of field type and value. */ class FieldVisitorDump : public StaticVisitor { diff --git a/src/Core/Field.cpp b/src/Core/Field.cpp index e4082e6df8c..b91cbe84f9d 100644 --- a/src/Core/Field.cpp +++ b/src/Core/Field.cpp @@ -17,6 +17,63 @@ namespace ErrorCodes extern const int DECIMAL_OVERFLOW; } +inline Field getBinaryValue(UInt8 type, ReadBuffer & buf) +{ + switch (type) + { + case Field::Types::Null: { + return DB::Field(); + } + case Field::Types::UInt64: { + UInt64 value; + DB::readVarUInt(value, buf); + return value; + } + case Field::Types::UInt128: { + UInt128 value; + DB::readBinary(value, buf); + return value; + } + case Field::Types::Int64: { + Int64 value; + DB::readVarInt(value, buf); + return value; + } + case Field::Types::Float64: { + Float64 value; + DB::readFloatBinary(value, buf); + return value; + } + case Field::Types::String: { + std::string value; + DB::readStringBinary(value, buf); + return value; + } + case Field::Types::Array: { + Array value; + DB::readBinary(value, buf); + return value; + } + case Field::Types::Tuple: { + Tuple value; + DB::readBinary(value, buf); + return value; + } + case Field::Types::Map: { + Map value; + DB::readBinary(value, buf); + return value; + } + case Field::Types::AggregateFunctionState: { + AggregateFunctionStateData value; + DB::readStringBinary(value.name, buf); + DB::readStringBinary(value.data, buf); + return value; + } + } + return DB::Field(); +} + void readBinary(Array & x, ReadBuffer & buf) { size_t size; @@ -25,80 +82,7 @@ void readBinary(Array & x, ReadBuffer & buf) DB::readBinary(size, buf); for (size_t index = 0; index < size; ++index) - { - switch (type) - { - case Field::Types::Null: - { - x.push_back(DB::Field()); - break; - } - case Field::Types::UInt64: - { - UInt64 value; - DB::readVarUInt(value, buf); - x.push_back(value); - break; - } - case Field::Types::UInt128: - { - UInt128 value; - DB::readBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::Int64: - { - Int64 value; - DB::readVarInt(value, buf); - x.push_back(value); - break; - } - case Field::Types::Float64: - { - Float64 value; - DB::readFloatBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::String: - { - std::string value; - DB::readStringBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::Array: - { - Array value; - DB::readBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::Tuple: - { - Tuple value; - DB::readBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::Map: - { - Map value; - DB::readBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::AggregateFunctionState: - { - AggregateFunctionStateData value; - DB::readStringBinary(value.name, buf); - DB::readStringBinary(value.data, buf); - x.push_back(value); - break; - } - } - } + x.push_back(getBinaryValue(type, buf)); } void writeBinary(const Array & x, WriteBuffer & buf) @@ -111,58 +95,7 @@ void writeBinary(const Array & x, WriteBuffer & buf) DB::writeBinary(size, buf); for (const auto & elem : x) - { - switch (type) - { - case Field::Types::Null: break; - case Field::Types::UInt64: - { - DB::writeVarUInt(get(elem), buf); - break; - } - case Field::Types::UInt128: - { - DB::writeBinary(get(elem), buf); - break; - } - case Field::Types::Int64: - { - DB::writeVarInt(get(elem), buf); - break; - } - case Field::Types::Float64: - { - DB::writeFloatBinary(get(elem), buf); - break; - } - case Field::Types::String: - { - DB::writeStringBinary(get(elem), buf); - break; - } - case Field::Types::Array: - { - DB::writeBinary(get(elem), buf); - break; - } - case Field::Types::Tuple: - { - DB::writeBinary(get(elem), buf); - break; - } - case Field::Types::Map: - { - DB::writeBinary(get(elem), buf); - break; - } - case Field::Types::AggregateFunctionState: - { - DB::writeStringBinary(elem.get().name, buf); - DB::writeStringBinary(elem.get().data, buf); - break; - } - } - } + Field::dispatch([&buf](const auto & value) { DB::FieldVisitorWriteBinary()(value, buf); }, elem); } void writeText(const Array & x, WriteBuffer & buf) @@ -180,100 +113,7 @@ void readBinary(Tuple & x, ReadBuffer & buf) { UInt8 type; DB::readBinary(type, buf); - - switch (type) - { - case Field::Types::Null: - { - x.push_back(DB::Field()); - break; - } - case Field::Types::UInt64: - { - UInt64 value; - DB::readVarUInt(value, buf); - x.push_back(value); - break; - } - case Field::Types::UInt128: - { - UInt128 value; - DB::readBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::Int64: - { - Int64 value; - DB::readVarInt(value, buf); - x.push_back(value); - break; - } - case Field::Types::Int128: - { - Int64 value; - DB::readVarInt(value, buf); - x.push_back(value); - break; - } - case Field::Types::Float64: - { - Float64 value; - DB::readFloatBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::String: - { - std::string value; - DB::readStringBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::UInt256: - { - UInt256 value; - DB::readBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::Int256: - { - Int256 value; - DB::readBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::Array: - { - Array value; - DB::readBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::Tuple: - { - Tuple value; - DB::readBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::Map: - { - Map value; - DB::readBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::AggregateFunctionState: - { - AggregateFunctionStateData value; - DB::readStringBinary(value.name, buf); - DB::readStringBinary(value.data, buf); - x.push_back(value); - break; - } - } + x.push_back(getBinaryValue(type, buf)); } } @@ -286,67 +126,11 @@ void writeBinary(const Tuple & x, WriteBuffer & buf) { const UInt8 type = elem.getType(); DB::writeBinary(type, buf); - - switch (type) - { - case Field::Types::Null: break; - case Field::Types::UInt64: - { - DB::writeVarUInt(get(elem), buf); - break; - } - case Field::Types::UInt128: - { - DB::writeBinary(get(elem), buf); - break; - } - case Field::Types::Int64: - { - DB::writeVarInt(get(elem), buf); - break; - } - case Field::Types::Int128: - { - DB::writeVarInt(get(elem), buf); - break; - } - case Field::Types::Float64: - { - DB::writeFloatBinary(get(elem), buf); - break; - } - case Field::Types::String: - { - DB::writeStringBinary(get(elem), buf); - break; - } - case Field::Types::UInt256: - { - DB::writeBinary(get(elem), buf); - break; - } - case Field::Types::Int256: - { - DB::writeBinary(get(elem), buf); - break; - } - case Field::Types::Array: - { - DB::writeBinary(get(elem), buf); - break; - } - case Field::Types::Tuple: - { - DB::writeBinary(get(elem), buf); - break; - } - case Field::Types::AggregateFunctionState: - { - DB::writeStringBinary(elem.get().name, buf); - DB::writeStringBinary(elem.get().data, buf); - break; - } - } + Field::dispatch( + [&buf](const auto & value) { + DB::FieldVisitorWriteBinary()(value, buf); + }, + elem); } } @@ -364,93 +148,7 @@ void readBinary(Map & x, ReadBuffer & buf) { UInt8 type; DB::readBinary(type, buf); - - switch (type) - { - case Field::Types::Null: - { - x.push_back(DB::Field()); - break; - } - case Field::Types::UInt64: - { - UInt64 value; - DB::readVarUInt(value, buf); - x.push_back(value); - break; - } - case Field::Types::UInt128: - { - UInt128 value; - DB::readBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::Int64: - { - Int64 value; - DB::readVarInt(value, buf); - x.push_back(value); - break; - } - case Field::Types::Int128: - { - Int64 value; - DB::readVarInt(value, buf); - x.push_back(value); - break; - } - case Field::Types::Float64: - { - Float64 value; - DB::readFloatBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::String: - { - std::string value; - DB::readStringBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::UInt256: - { - UInt256 value; - DB::readBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::Int256: - { - Int256 value; - DB::readBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::Array: - { - Array value; - DB::readBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::Tuple: - { - Tuple value; - DB::readBinary(value, buf); - x.push_back(value); - break; - } - case Field::Types::AggregateFunctionState: - { - AggregateFunctionStateData value; - DB::readStringBinary(value.name, buf); - DB::readStringBinary(value.data, buf); - x.push_back(value); - break; - } - } + x.push_back(getBinaryValue(type, buf)); } } @@ -463,67 +161,11 @@ void writeBinary(const Map & x, WriteBuffer & buf) { const UInt8 type = elem.getType(); DB::writeBinary(type, buf); - - switch (type) - { - case Field::Types::Null: break; - case Field::Types::UInt64: - { - DB::writeVarUInt(get(elem), buf); - break; - } - case Field::Types::UInt128: - { - DB::writeBinary(get(elem), buf); - break; - } - case Field::Types::Int64: - { - DB::writeVarInt(get(elem), buf); - break; - } - case Field::Types::Int128: - { - DB::writeVarInt(get(elem), buf); - break; - } - case Field::Types::Float64: - { - DB::writeFloatBinary(get(elem), buf); - break; - } - case Field::Types::String: - { - DB::writeStringBinary(get(elem), buf); - break; - } - case Field::Types::UInt256: - { - DB::writeBinary(get(elem), buf); - break; - } - case Field::Types::Int256: - { - DB::writeBinary(get(elem), buf); - break; - } - case Field::Types::Array: - { - DB::writeBinary(get(elem), buf); - break; - } - case Field::Types::Tuple: - { - DB::writeBinary(get(elem), buf); - break; - } - case Field::Types::AggregateFunctionState: - { - DB::writeStringBinary(elem.get().name, buf); - DB::writeStringBinary(elem.get().data, buf); - break; - } - } + Field::dispatch( + [&buf](const auto & value) { + DB::FieldVisitorWriteBinary()(value, buf); + }, + elem); } } diff --git a/src/DataTypes/DataTypeMap.cpp b/src/DataTypes/DataTypeMap.cpp index 270c92991f3..4a82f33d9ed 100644 --- a/src/DataTypes/DataTypeMap.cpp +++ b/src/DataTypes/DataTypeMap.cpp @@ -35,6 +35,7 @@ namespace ErrorCodes DataTypeMap::DataTypeMap(const DataTypes & elems_) { + assert(elems_.size() < 3); key_type = elems_.size() == 1 ? DataTypeFactory::instance().get("String") : elems_[0]; value_type = elems_.size() == 1 ? elems_[0] : elems_[1]; @@ -47,8 +48,7 @@ DataTypeMap::DataTypeMap(const DataTypes & elems_) std::string DataTypeMap::doGetName() const { WriteBufferFromOwnString s; - s << "Map(" << (typeid_cast(keys.get()))->getNestedType()->getName() - << "," << (typeid_cast(values.get()))->getNestedType()->getName() << ")"; + s << "Map(" << key_type->getName() << "," << value_type->getName() << ")"; return s.str(); } @@ -217,20 +217,12 @@ void DataTypeMap::deserializeText(IColumn & column, ReadBuffer & istr, const For void DataTypeMap::serializeTextJSON(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const { - writeChar('[', ostr); - keys->serializeAsTextJSON(extractElementColumn(column, 0), row_num, ostr, settings); - writeChar(',', ostr); - values->serializeAsTextJSON(extractElementColumn(column, 1), row_num, ostr, settings); - writeChar(']', ostr); + serializeText(column, row_num, ostr, settings); } void DataTypeMap::deserializeTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const { - assertChar('[', istr); - keys->deserializeAsTextJSON(extractElementColumn(column, 0), istr, settings); - assertChar(',', istr); - values->deserializeAsTextJSON(extractElementColumn(column, 1), istr, settings); - assertChar(']', istr); + deserializeText(column, istr, settings); } void DataTypeMap::serializeTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const @@ -305,11 +297,12 @@ static DeserializeBinaryBulkStateMap * checkAndGetMapDeserializeState(IDataType: void DataTypeMap::enumerateStreams(const StreamCallback & callback, SubstreamPath & path) const { + // path.push_back(Substream::ArraySizes); path.push_back(Substream::MapElement); path.back().map_element_name = "keys"; keys->enumerateStreams(callback, path); path.back().map_element_name = "values"; - keys->enumerateStreams(callback, path); + values->enumerateStreams(callback, path); path.pop_back(); } @@ -374,6 +367,7 @@ void DataTypeMap::serializeBinaryBulkWithMultipleStreams( const auto & keys_col = extractElementColumn(column, 0); settings.path.back().map_element_name = "keys"; + keys->serializeBinaryBulkWithMultipleStreams(keys_col, offset, limit, settings, map_state->states[0]); const auto & values_col = extractElementColumn(column, 1); settings.path.back().map_element_name = "values"; diff --git a/src/DataTypes/DataTypeMap.h b/src/DataTypes/DataTypeMap.h index 4f9de4c8b9e..846531fad41 100644 --- a/src/DataTypes/DataTypeMap.h +++ b/src/DataTypes/DataTypeMap.h @@ -86,9 +86,8 @@ public: bool isParametric() const override { return true; } bool haveSubtypes() const override { return true; } - const DataTypePtr & getKeyType() const { return keys; } - const DataTypePtr & getValueType() const { return values; } - const DataTypePtr & getVType() const { return value_type; } + const DataTypePtr & getKeyType() const { return key_type; } + const DataTypePtr & getValueType() const { return value_type; } const DataTypes & getElements() const {return kv; } }; diff --git a/src/Functions/FunctionsConversion.h b/src/Functions/FunctionsConversion.h index 157d7de3842..bbae8e805ab 100644 --- a/src/Functions/FunctionsConversion.h +++ b/src/Functions/FunctionsConversion.h @@ -2173,15 +2173,31 @@ private: throw Exception{"CAST AS Map can only be performed between map types with the same number of elements.\n" "Left type: " + from_type->getName() + ", right type: " + to_type->getName(), ErrorCodes::TYPE_MISMATCH}; - return [] - (ColumnsWithTypeAndName & arguments, const DataTypePtr &, const ColumnNullable * /*nullable_source*/, size_t /*input_rows_count*/) + const auto & from_kv_types = from_type->getElements(); + const auto & to_kv_types = to_type->getElements(); + std::vector element_wrappers; + element_wrappers.reserve(2); + + /// Create conversion wrapper for each element in tuple + for (const auto idx_type : ext::enumerate(from_kv_types)) + element_wrappers.push_back(prepareUnpackDictionaries(idx_type.second, to_kv_types[idx_type.first])); + + return [element_wrappers, from_kv_types, to_kv_types] + (ColumnsWithTypeAndName & arguments, const DataTypePtr &, const ColumnNullable * nullable_source, size_t input_rows_count) -> ColumnPtr { - const auto col = arguments.front().column.get(); - const ColumnMap & column_map = typeid_cast(*col); + const auto * col = arguments.front().column.get(); + + // size_t tuple_size = from_kv_types.size(); + const ColumnMap & column_tuple = typeid_cast(*col); Columns converted_columns(2); - converted_columns[0] = column_map.getColumns()[0]; - converted_columns[1] = column_map.getColumns()[1]; + + /// invoke conversion for each element + for (size_t i = 0; i < 2; ++i) + { + ColumnsWithTypeAndName element = {{column_tuple.getColumns()[i], from_kv_types[i], "" }}; + converted_columns[i] = element_wrappers[i](element, to_kv_types[i], nullable_source, input_rows_count); + } return ColumnMap::create(converted_columns); }; diff --git a/src/Functions/array/arrayElement.cpp b/src/Functions/array/arrayElement.cpp index f495d9af339..a5e880f0a04 100644 --- a/src/Functions/array/arrayElement.cpp +++ b/src/Functions/array/arrayElement.cpp @@ -85,15 +85,22 @@ private: /** For a Map, the function is to find the matched key's value */ - static bool executeMappedKeyString(const ColumnArray * column, Field & index, std::vector &matched_idxs); + static bool executeMappedKeyStringArgument(const ColumnArray * column, ColumnsWithTypeAndName & arguments, std::vector &matched_idxs); + + static bool executeMappedKeyStringConst(const ColumnArray * column, Field & index, std::vector &matched_idxs); template - static bool executeMappedKeyNumber(const ColumnArray * column, Field & index, std::vector &matched_idxs); - - static bool getMappedKey(const ColumnArray * column, Field & index, std::vector &matched_idxs); + static bool executeMappedKeyNumberArgument(const ColumnArray * column, ColumnsWithTypeAndName & arguments, std::vector &matched_idxs); template - static bool executeMappedValueNumber(const ColumnArray * column, std::vector matched_idxs, + static bool executeMappedKeyNumberConst(const ColumnArray * column, Field & index, std::vector &matched_idxs); + + static bool getMappedKey(const ColumnArray * column, ColumnsWithTypeAndName & arguments, std::vector &matched_idxs); + + static bool getMappedKeyConst(const ColumnArray * column, Field & index, std::vector &matched_idxs); + + template + static bool executeMappedValueNumber(const ColumnArray * column, const std::vector & matched_idxs, IColumn * col_res_untyped); static bool executeMappedValueString(const ColumnArray * column, std::vector matched_idxs, @@ -104,7 +111,7 @@ private: static bool getMappedValue(const ColumnArray * column, std::vector matched_idxs, IColumn * col_res_untyped); - static ColumnPtr executeMap(ColumnsWithTypeAndName & columns, size_t input_rows_count); + static ColumnPtr executeMap(ColumnsWithTypeAndName & arguments, size_t input_rows_count); }; @@ -647,10 +654,8 @@ ColumnPtr FunctionArrayElement::executeArgument( ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, ArrayImpl::NullMapBuilder & builder, size_t input_rows_count) const { auto index = checkAndGetColumn>(arguments[1].column.get()); - if (!index) return nullptr; - const auto & index_data = index->getData(); if (builder) @@ -725,7 +730,7 @@ ColumnPtr FunctionArrayElement::executeTuple(ColumnsWithTypeAndName & arguments, return ColumnTuple::create(result_tuple_columns); } -bool FunctionArrayElement::executeMappedKeyString(const ColumnArray * column, Field & index, +bool FunctionArrayElement::executeMappedKeyStringConst(const ColumnArray * column, Field & index, std::vector &matched_idxs) { const ColumnString * keys = checkAndGetColumn(&column->getData()); @@ -757,8 +762,45 @@ bool FunctionArrayElement::executeMappedKeyString(const ColumnArray * column, Fi return true; } +bool FunctionArrayElement::executeMappedKeyStringArgument(const ColumnArray * column, ColumnsWithTypeAndName & arguments, std::vector &matched_idxs) +{ + auto index = checkAndGetColumn(arguments[1].column.get()); + if (!index) + return false; + + const ColumnString * keys = checkAndGetColumn(&column->getData()); + const ColumnArray::Offsets & offsets = column->getOffsets(); + size_t rows = offsets.size(); + + if (!keys) + return false; + + // String str = index.get(); + for (size_t i = 0; i < rows; i++) + { + bool matched = false; + size_t begin = offsets[i - 1]; + size_t end = offsets[i]; + for (size_t j = begin; j < end; j++) + { + if (strcmp(keys->getDataAt(j).data, index->getDataAt(i).data) == 0) + { + matched_idxs.push_back(j); + matched = true; + break; + } + } + if (!matched) + matched_idxs.push_back(-1); + } + + return true; +} + + + template -bool FunctionArrayElement::executeMappedKeyNumber(const ColumnArray * column, Field & index, +bool FunctionArrayElement::executeMappedKeyNumberConst(const ColumnArray * column, Field & index, std::vector &matched_idxs) { const ColumnVector * col_nested = checkAndGetColumn>(&column->getData()); @@ -795,17 +837,72 @@ bool FunctionArrayElement::executeMappedKeyNumber(const ColumnArray * column, Fi return true; } -bool FunctionArrayElement::getMappedKey(const ColumnArray * column, Field & index, std::vector &matched_idxs) +template +bool FunctionArrayElement::executeMappedKeyNumberArgument(const ColumnArray * column, ColumnsWithTypeAndName & arguments, std::vector &matched_idxs) { - if (!(executeMappedKeyNumber(column, index, matched_idxs) - || executeMappedKeyNumber(column, index, matched_idxs) - || executeMappedKeyNumber(column, index, matched_idxs) - || executeMappedKeyNumber(column, index, matched_idxs) - || executeMappedKeyNumber(column, index, matched_idxs) - || executeMappedKeyNumber(column, index, matched_idxs) - || executeMappedKeyNumber(column, index, matched_idxs) - || executeMappedKeyNumber(column, index, matched_idxs) - || executeMappedKeyString(column, index, matched_idxs))) + auto index = checkAndGetColumn>(arguments[1].column.get()); + if (!index) + return false; + + const PaddedPODArray & index_data = index->getData(); + + const ColumnVector * col_nested = checkAndGetColumn>(&column->getData()); + if (!col_nested) + return false; + + const ColumnArray::Offsets & offsets = column->getOffsets(); + + for (size_t i = 0; i < offsets.size(); i++) + { + bool matched = false; + size_t begin = offsets[i - 1]; + size_t end = offsets[i]; + + for (size_t j = begin; j < end; j++) + { + DataType ele = col_nested->getElement(j); + + if (!CompareHelper::compare(ele, index_data[i], 0)) + { + matched_idxs.push_back(j); + matched = true; + break; + } + } + if (!matched) + matched_idxs.push_back(-1); + } + + return true; +} + +bool FunctionArrayElement::getMappedKey(const ColumnArray * column, ColumnsWithTypeAndName & arguments, std::vector &matched_idxs) +{ + if (!(executeMappedKeyNumberArgument(column, arguments, matched_idxs) + || executeMappedKeyNumberArgument(column, arguments, matched_idxs) + || executeMappedKeyNumberArgument(column, arguments, matched_idxs) + || executeMappedKeyNumberArgument(column, arguments, matched_idxs) + || executeMappedKeyNumberArgument(column, arguments, matched_idxs) + || executeMappedKeyNumberArgument(column, arguments, matched_idxs) + || executeMappedKeyNumberArgument(column, arguments, matched_idxs) + || executeMappedKeyNumberArgument(column, arguments, matched_idxs) + || executeMappedKeyStringArgument(column, arguments, matched_idxs))) + throw Exception("Second argument for function " + column->getName() + " for Map must must have UInt or Int or String type.", + ErrorCodes::ILLEGAL_COLUMN); + return true; +} + +bool FunctionArrayElement::getMappedKeyConst(const ColumnArray * column, Field & index, std::vector &matched_idxs) +{ + if (!(executeMappedKeyNumberConst(column, index, matched_idxs) + || executeMappedKeyNumberConst(column, index, matched_idxs) + || executeMappedKeyNumberConst(column, index, matched_idxs) + || executeMappedKeyNumberConst(column, index, matched_idxs) + || executeMappedKeyNumberConst(column, index, matched_idxs) + || executeMappedKeyNumberConst(column, index, matched_idxs) + || executeMappedKeyNumberConst(column, index, matched_idxs) + || executeMappedKeyNumberConst(column, index, matched_idxs) + || executeMappedKeyStringConst(column, index, matched_idxs))) throw Exception("Illegal column" + column->getName() + "of first argument , type not match", ErrorCodes::ILLEGAL_COLUMN); @@ -813,7 +910,7 @@ bool FunctionArrayElement::getMappedKey(const ColumnArray * column, Field & inde } template -bool FunctionArrayElement::executeMappedValueNumber(const ColumnArray * column, std::vector matched_idxs, +bool FunctionArrayElement::executeMappedValueNumber(const ColumnArray * column, const std::vector & matched_idxs, IColumn * col_res_untyped) { const ColumnVector * col_nested = checkAndGetColumn>(&column->getData()); @@ -830,15 +927,9 @@ bool FunctionArrayElement::executeMappedValueNumber(const ColumnArray * column, for (size_t i = 0; i < rows; i++) { if (matched_idxs[i] != -1) - { col_res->insertFrom(*col_nested, matched_idxs[i]); - } else - { - // Default value for unmatched keys - DataType default_value = -1; - col_res->insertValue(default_value); - } + col_res->insertDefault(); } return true; @@ -867,8 +958,7 @@ bool FunctionArrayElement::executeMappedValueString(const ColumnArray * column, } else { - // Default value for unmatched keys - col_res->insertData("null", 4); + col_res->insertDefault(); } } return true; @@ -897,7 +987,7 @@ bool FunctionArrayElement::executeMappedValueArray(const ColumnArray * column, s } else { - col_res->insertData("", 0); + col_res->insertDefault(); } } return true; @@ -921,36 +1011,39 @@ bool FunctionArrayElement::getMappedValue(const ColumnArray * column, std::vecto return true; } -ColumnPtr FunctionArrayElement::executeMap(ColumnsWithTypeAndName & arguments, size_t input_rows_count) +ColumnPtr FunctionArrayElement::executeMap(ColumnsWithTypeAndName & arguments, size_t /*input_rows_count*/) { const ColumnMap * col_map = typeid_cast(arguments[0].column.get()); if (!col_map) return nullptr; - const DataTypes & kv_types = (typeid_cast(*arguments[0].type)).getElements(); - const DataTypePtr & key_type = (typeid_cast(kv_types[0].get()))->getNestedType(); - const DataTypePtr & value_type = (typeid_cast(kv_types[1].get()))->getNestedType(); + const DataTypePtr & key_type = (typeid_cast(*arguments[0].type)).getKeyType(); + const DataTypePtr & value_type = (typeid_cast(*arguments[0].type)).getValueType(); - Field index = (*arguments[1].column)[0]; - - // Get Matched key's value const ColumnArray * col_keys_untyped = typeid_cast(&col_map->getColumn(0)); const ColumnArray * col_values_untyped = typeid_cast(&col_map->getColumn(1)); size_t rows = col_keys_untyped->getOffsets().size(); auto col_res_untyped = value_type->createColumn(); - if (rows > 0) + std::vector matched_idxs; + matched_idxs.reserve(rows); + + if (!isColumnConst(*arguments[1].column)) { - if (input_rows_count) - assert(input_rows_count == rows); - - std::vector matched_idxs; - if (!getMappedKey(col_keys_untyped, index, matched_idxs)) + if (rows > 0 && !getMappedKey(col_keys_untyped, arguments, matched_idxs)) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "key type unmatched, we need type '{}' failed", key_type->getName()); - - if (!getMappedValue(col_values_untyped, matched_idxs, col_res_untyped.get())) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "value type unmatched, we need type '{}' failed", value_type->getName()); } + else + { + Field index = (*arguments[1].column)[0]; + + // Get Matched key's value + if (rows > 0 && !getMappedKeyConst(col_keys_untyped, index, matched_idxs)) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "key type unmatched, we need type '{}' failed", key_type->getName()); + } + + if (rows > 0 && !getMappedValue(col_values_untyped, matched_idxs, col_res_untyped.get())) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "value type unmatched, we need type '{}' failed", value_type->getName()); return col_res_untyped; } diff --git a/src/Functions/registerFunctions.cpp b/src/Functions/registerFunctions.cpp index 6dfebc13665..3f75746f861 100644 --- a/src/Functions/registerFunctions.cpp +++ b/src/Functions/registerFunctions.cpp @@ -10,6 +10,7 @@ namespace DB void registerFunctionsArithmetic(FunctionFactory &); void registerFunctionsArray(FunctionFactory &); void registerFunctionsTuple(FunctionFactory &); +void registerFunctionsMap(FunctionFactory &); void registerFunctionsBitmap(FunctionFactory &); void registerFunctionsCoding(FunctionFactory &); void registerFunctionsComparison(FunctionFactory &); @@ -64,6 +65,7 @@ void registerFunctions() registerFunctionsArithmetic(factory); registerFunctionsArray(factory); registerFunctionsTuple(factory); + registerFunctionsMap(factory); #if !defined(ARCADIA_BUILD) registerFunctionsBitmap(factory); #endif diff --git a/src/Functions/ya.make b/src/Functions/ya.make index ed03f5175ab..5534e8582fc 100644 --- a/src/Functions/ya.make +++ b/src/Functions/ya.make @@ -280,6 +280,7 @@ SRCS( lowCardinalityKeys.cpp lower.cpp lowerUTF8.cpp + map.cpp match.cpp materialize.cpp minus.cpp diff --git a/src/Interpreters/convertFieldToType.cpp b/src/Interpreters/convertFieldToType.cpp index a04b7892375..13259164e63 100644 --- a/src/Interpreters/convertFieldToType.cpp +++ b/src/Interpreters/convertFieldToType.cpp @@ -259,8 +259,7 @@ Field convertFieldToTypeImpl(const Field & src, const IDataType & type, const ID if (src_map_size % 2) throw Exception("Bad size of map in In or VALUES section, Expected size must %2==0", ErrorCodes::BAD_ARGUMENTS); Map res(2); - const auto & key_type = *(type_map->getKeyType()); - const auto & value_type = *(type_map->getValueType()); + const auto & kv_type = type_map->getElements(); Map keys(count); Map values(count); @@ -271,8 +270,8 @@ Field convertFieldToTypeImpl(const Field & src, const IDataType & type, const ID values[i] = src_map[i * 2 + 1]; } - res[0] = convertFieldToType(keys, key_type); - res[1] = convertFieldToType(values, value_type); + res[0] = convertFieldToType(keys, *kv_type[0].get()); + res[1] = convertFieldToType(values, *kv_type[1].get()); if (res[0].isNull()) throw Exception("Bad type of key", ErrorCodes::BAD_TYPE_OF_FIELD); diff --git a/src/Parsers/ASTFunction.cpp b/src/Parsers/ASTFunction.cpp index 66565eeaf8f..5b983cb69a6 100644 --- a/src/Parsers/ASTFunction.cpp +++ b/src/Parsers/ASTFunction.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -27,14 +28,14 @@ void ASTFunction::appendColumnNameImpl(WriteBuffer & ostr) const writeChar(')', ostr); } - writeChar('(', ostr); + writeChar(name == "map" ? '{' : '(', ostr); for (auto it = arguments->children.begin(); it != arguments->children.end(); ++it) { if (it != arguments->children.begin()) writeCString(", ", ostr); (*it)->appendColumnName(ostr); } - writeChar(')', ostr); + writeChar(name == "map" ? '}' : ')', ostr); } /** Get the text that identifies this element. */ @@ -364,6 +365,19 @@ void ASTFunction::formatImplWithoutAlias(const FormatSettings & settings, Format settings.ostr << (settings.hilite ? hilite_operator : "") << ')' << (settings.hilite ? hilite_none : ""); written = true; } + + if (!written && 0 == strcmp(name.c_str(), "map")) + { + settings.ostr << (settings.hilite ? hilite_operator : "") << '{' << (settings.hilite ? hilite_none : ""); + for (size_t i = 0; i < arguments->children.size(); ++i) + { + if (i != 0) + settings.ostr << ", "; + arguments->children[i]->formatImpl(settings, state, nested_dont_need_parens); + } + settings.ostr << (settings.hilite ? hilite_operator : "") << '}' << (settings.hilite ? hilite_none : ""); + written = true; + } } if (!written) diff --git a/src/Parsers/ExpressionElementParsers.cpp b/src/Parsers/ExpressionElementParsers.cpp index dcdd1a3c1c1..a7b54fffb15 100644 --- a/src/Parsers/ExpressionElementParsers.cpp +++ b/src/Parsers/ExpressionElementParsers.cpp @@ -123,6 +123,36 @@ bool ParserParenthesisExpression::parseImpl(Pos & pos, ASTPtr & node, Expected & return true; } +bool ParserMap::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) +{ + ASTPtr contents_node; + ParserMapExpressionList contents; + if (pos->type != TokenType::OpeningCurlyBrace) + return false; + ++pos; + + if (!contents.parse(pos, contents_node, expected)) + return false; + if (pos->type != TokenType::ClosingCurlyBrace) + return false; + ++pos; + + const auto & expr_list = contents_node->as(); + + /// empty expression in parentheses is not allowed + if (expr_list.children.empty()) + { + expected.add(pos, "non-empty curlyBraced list of expressions"); + return false; + } + + auto function_node = std::make_shared(); + function_node->name = "map"; + function_node->arguments = contents_node; + function_node->children.push_back(contents_node); + node = function_node; + return true; +} bool ParserSubquery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) { @@ -1478,6 +1508,7 @@ bool ParserExpressionElement::parseImpl(Pos & pos, ASTPtr & node, Expected & exp return ParserSubquery().parse(pos, node, expected) || ParserTupleOfLiterals().parse(pos, node, expected) || ParserMapOfLiterals().parse(pos, node, expected) + || ParserMap().parse(pos, node, expected) || ParserParenthesisExpression().parse(pos, node, expected) || ParserArrayOfLiterals().parse(pos, node, expected) || ParserArray().parse(pos, node, expected) diff --git a/src/Parsers/ExpressionElementParsers.h b/src/Parsers/ExpressionElementParsers.h index ff9a1420588..0a7cc01735e 100644 --- a/src/Parsers/ExpressionElementParsers.h +++ b/src/Parsers/ExpressionElementParsers.h @@ -27,6 +27,13 @@ protected: bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; }; +class ParserMap : public IParserBase +{ +protected: + const char * getName() const override { return "map curlybraced expression"; } + bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; +}; + /** The SELECT subquery is in parenthesis. */ diff --git a/src/Parsers/ExpressionListParsers.cpp b/src/Parsers/ExpressionListParsers.cpp index 26affe020b1..612ff3ebd32 100644 --- a/src/Parsers/ExpressionListParsers.cpp +++ b/src/Parsers/ExpressionListParsers.cpp @@ -96,6 +96,7 @@ bool ParserList::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) auto list = std::make_shared(result_separator); list->children = std::move(elements); node = list; + return true; } @@ -516,6 +517,50 @@ ParserExpressionWithOptionalAlias::ParserExpressionWithOptionalAlias(bool allow_ { } +bool ParserMapExpressionList::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) +{ + ParserPtr && separator_parser_ = std::make_unique(TokenType::Comma); + ASTs elements; + + auto parse_key_value = [&] + { + ASTPtr key, value; + ParserExpression key_parser, value_parser; + + if (!key_parser.parse(pos, key, expected)) + return false; + elements.push_back(key); + + if (pos->type != TokenType::Colon) + return false; + ++pos; + + if (!value_parser.parse(pos, value, expected)) + return false; + elements.push_back(value); + + return true; + }; + + Pos begin = pos; + if (!parse_key_value()) + return false; + + while (true) + { + begin = pos; + if (!separator_parser_->ignore(pos, expected) || !parse_key_value()) + { + pos = begin; + break; + } + } + + auto list = std::make_shared(); + list->children = std::move(elements); + node = list; + return true; +} bool ParserExpressionList::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) { diff --git a/src/Parsers/ExpressionListParsers.h b/src/Parsers/ExpressionListParsers.h index 93a47648a0b..3dd22e7f01c 100644 --- a/src/Parsers/ExpressionListParsers.h +++ b/src/Parsers/ExpressionListParsers.h @@ -396,6 +396,13 @@ protected: } }; +/** A comma-separated list of map expressions, probably empty. */ +class ParserMapExpressionList : public IParserBase +{ +protected: + const char * getName() const override { return "list of map expression"; } + bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; +}; /** A comma-separated list of expressions, probably empty. */ class ParserExpressionList : public IParserBase diff --git a/tests/queries/0_stateless/01550_create_map_type.reference b/tests/queries/0_stateless/01550_create_map_type.reference index 4bd91d44a53..239b93a23d6 100644 --- a/tests/queries/0_stateless/01550_create_map_type.reference +++ b/tests/queries/0_stateless/01550_create_map_type.reference @@ -1,9 +1,25 @@ zhangsan lisi -lisi +1111 2222 +1112 2224 +1113 2226 +female zhangsan -100 -60 -100 +1116 +1117 +1118 +1119 +[] +[] +[] +[] +[] +[] +[0,2,0] [1,2,3] +[1,3,2] +[2,4,4] +[3,5,6] +[4,6,8] +[5,7,10] [100,20,90] diff --git a/tests/queries/0_stateless/01550_create_map_type.sql b/tests/queries/0_stateless/01550_create_map_type.sql index 5b5e618838b..dc0975b5945 100644 --- a/tests/queries/0_stateless/01550_create_map_type.sql +++ b/tests/queries/0_stateless/01550_create_map_type.sql @@ -5,23 +5,33 @@ insert into table_map values ({'name':'zhangsan', 'gender':'male'}), ({'name':'l select a['name'] from table_map; drop table if exists table_map; + +drop table if exists table_map; +create table table_map (a Map(String, UInt64)) engine = MergeTree() order by a; +insert into table_map select map('key1', number, 'key2', number * 2) from numbers(1111, 3); +select a['key1'], a['key2'] from table_map; +drop table if exists table_map; + -- MergeTree Engine -create table table_map (a Map(String, String)) engine = MergeTree() order by a; -insert into table_map values ({'name':'zhangsan', 'gender':'male'}), ({'name':'lisi', 'gender':'female'}); -select a['name'] from table_map; +drop table if exists table_map; +create table table_map (a Map(String, String), b String) engine = MergeTree() order by a; +insert into table_map values ({'name':'zhangsan', 'gender':'male'}, 'name'), ({'name':'lisi', 'gender':'female'}, 'gender'); +select a[b] from table_map; drop table if exists table_map; -- Int type -drop table if exists table_map2; -create table table_map2(a Map(UInt8, UInt64), b UInt8) Engine = MergeTree() order by b partition by a; -insert into table_map2 values({1: 100, 1000: 300}, 1), ({1: 60}, 2), ({2: 40, 7:90, 1:100}, 3); -select a[1] from table_map2; -drop table if exists table_map2; +drop table if exists table_map; +create table table_map(a Map(UInt8, UInt64), b UInt8) Engine = MergeTree() order by b; +insert into table_map select {number:number+5}, number from numbers(1111,4); +select a[b] from table_map; +drop table if exists table_map; -- Array Type -drop table if exists table_map3; -create table table_map3(a Map(String, Array(UInt8))) Engine = Memory; -insert into table_map3 values({'k1':[1,2,3], 'k2':[4,5,6]}), ({'k0':[], 'k1':[100,20,90]}); -select a['k1'] from table_map3; -drop table if exists table_map3; +drop table if exists table_map; +create table table_map(a Map(String, Array(UInt8))) Engine = MergeTree() order by a; +insert into table_map values({'k1':[1,2,3], 'k2':[4,5,6]}), ({'k0':[], 'k1':[100,20,90]}); +insert into table_map select {'k1' : [number, number + 2, number * 2]} from numbers(6); +insert into table_map select map('k2' , [number, number + 2, number * 2]) from numbers(6); +select a['k1'] as col1 from table_map order by col1; +drop table if exists table_map;