Add map function

This commit is contained in:
hexiaoting 2020-11-02 14:05:53 +08:00
parent 483be134b2
commit 2ba68d7494
17 changed files with 493 additions and 521 deletions

View File

@ -208,6 +208,78 @@ String FieldVisitorToString::operator() (const Map & x) const
return wb.str();
}
void FieldVisitorWriteBinary::operator() (const Null &, WriteBuffer &) const { return ; }
void FieldVisitorWriteBinary::operator() (const UInt64 & x, WriteBuffer & buf) const { DB::writeVarUInt(x, buf); }
void FieldVisitorWriteBinary::operator() (const Int64 & x, WriteBuffer & buf) const { DB::writeVarInt(x, buf); }
void FieldVisitorWriteBinary::operator() (const Float64 & x, WriteBuffer & buf) const { DB::writeFloatBinary(x, buf); }
void FieldVisitorWriteBinary::operator() (const String & x, WriteBuffer & buf) const { DB::writeStringBinary(x, buf); }
void FieldVisitorWriteBinary::operator() (const UInt128 & x, WriteBuffer & buf) const { DB::writeBinary(x, buf); }
void FieldVisitorWriteBinary::operator() (const Int128 & x, WriteBuffer & buf) const { DB::writeVarInt(x, buf); }
void FieldVisitorWriteBinary::operator() (const UInt256 & x, WriteBuffer & buf) const { DB::writeBinary(x, buf); }
void FieldVisitorWriteBinary::operator() (const Int256 & x, WriteBuffer & buf) const { DB::writeBinary(x, buf); }
void FieldVisitorWriteBinary::operator() (const DecimalField<Decimal32> & x, WriteBuffer & buf) const { DB::writeBinary(x.getValue(), buf); }
void FieldVisitorWriteBinary::operator() (const DecimalField<Decimal64> & x, WriteBuffer & buf) const { DB::writeBinary(x.getValue(), buf); }
void FieldVisitorWriteBinary::operator() (const DecimalField<Decimal128> & x, WriteBuffer & buf) const { DB::writeBinary(x.getValue(), buf); }
void FieldVisitorWriteBinary::operator() (const DecimalField<Decimal256> & x, WriteBuffer & buf) const { DB::writeBinary(x.getValue(), buf); }
void FieldVisitorWriteBinary::operator() (const AggregateFunctionStateData & x, WriteBuffer & buf) const
{
DB::writeStringBinary(x.name, buf);
DB::writeStringBinary(x.data, buf);
}
void FieldVisitorWriteBinary::operator() (const Array & x, WriteBuffer & buf) const
{
const size_t size = x.size();
DB::writeBinary(size, buf);
for (auto it = x.begin(); it != x.end(); ++it)
{
const UInt8 type = it->getType();
DB::writeBinary(type, buf);
Field::dispatch(
[&buf](const auto & value) {
DB::FieldVisitorWriteBinary()(value, buf);
},
*it);
}
}
void FieldVisitorWriteBinary::operator() (const Tuple & x, WriteBuffer & buf) const
{
const size_t size = x.size();
DB::writeBinary(size, buf);
for (auto it = x.begin(); it != x.end(); ++it)
{
const UInt8 type = it->getType();
DB::writeBinary(type, buf);
Field::dispatch(
[&buf](const auto & value) {
DB::FieldVisitorWriteBinary()(value, buf);
},
*it);
}
}
void FieldVisitorWriteBinary::operator() (const Map & x, WriteBuffer & buf) const
{
const size_t size = x.size();
DB::writeBinary(size, buf);
for (auto it = x.begin(); it != x.end(); ++it)
{
const UInt8 type = it->getType();
writeBinary(type, buf);
Field::dispatch(
[&buf](const auto & value) {
DB::FieldVisitorWriteBinary()(value, buf);
},
*it);
}
}
FieldVisitorHash::FieldVisitorHash(SipHash & hash_) : hash(hash_) {}
void FieldVisitorHash::operator() (const Null &) const

View File

@ -88,6 +88,30 @@ public:
};
class FieldVisitorWriteBinary
{
public:
void operator() (const Null & x, WriteBuffer & buf) const;
void operator() (const UInt64 & x, WriteBuffer & buf) const;
void operator() (const UInt128 & x, WriteBuffer & buf) const;
void operator() (const Int64 & x, WriteBuffer & buf) const;
void operator() (const Int128 & x, WriteBuffer & buf) const;
void operator() (const Float64 & x, WriteBuffer & buf) const;
void operator() (const String & x, WriteBuffer & buf) const;
void operator() (const Array & x, WriteBuffer & buf) const;
void operator() (const Tuple & x, WriteBuffer & buf) const;
void operator() (const Map & x, WriteBuffer & buf) const;
void operator() (const DecimalField<Decimal32> & x, WriteBuffer & buf) const;
void operator() (const DecimalField<Decimal64> & x, WriteBuffer & buf) const;
void operator() (const DecimalField<Decimal128> & x, WriteBuffer & buf) const;
void operator() (const DecimalField<Decimal256> & x, WriteBuffer & buf) const;
void operator() (const AggregateFunctionStateData & x, WriteBuffer & buf) const;
void operator() (const UInt256 & x, WriteBuffer & buf) const;
void operator() (const Int256 & x, WriteBuffer & buf) const;
};
/** Print readable and unique text dump of field type and value. */
class FieldVisitorDump : public StaticVisitor<String>
{

View File

@ -17,6 +17,63 @@ namespace ErrorCodes
extern const int DECIMAL_OVERFLOW;
}
inline Field getBinaryValue(UInt8 type, ReadBuffer & buf)
{
switch (type)
{
case Field::Types::Null: {
return DB::Field();
}
case Field::Types::UInt64: {
UInt64 value;
DB::readVarUInt(value, buf);
return value;
}
case Field::Types::UInt128: {
UInt128 value;
DB::readBinary(value, buf);
return value;
}
case Field::Types::Int64: {
Int64 value;
DB::readVarInt(value, buf);
return value;
}
case Field::Types::Float64: {
Float64 value;
DB::readFloatBinary(value, buf);
return value;
}
case Field::Types::String: {
std::string value;
DB::readStringBinary(value, buf);
return value;
}
case Field::Types::Array: {
Array value;
DB::readBinary(value, buf);
return value;
}
case Field::Types::Tuple: {
Tuple value;
DB::readBinary(value, buf);
return value;
}
case Field::Types::Map: {
Map value;
DB::readBinary(value, buf);
return value;
}
case Field::Types::AggregateFunctionState: {
AggregateFunctionStateData value;
DB::readStringBinary(value.name, buf);
DB::readStringBinary(value.data, buf);
return value;
}
}
return DB::Field();
}
void readBinary(Array & x, ReadBuffer & buf)
{
size_t size;
@ -25,80 +82,7 @@ void readBinary(Array & x, ReadBuffer & buf)
DB::readBinary(size, buf);
for (size_t index = 0; index < size; ++index)
{
switch (type)
{
case Field::Types::Null:
{
x.push_back(DB::Field());
break;
}
case Field::Types::UInt64:
{
UInt64 value;
DB::readVarUInt(value, buf);
x.push_back(value);
break;
}
case Field::Types::UInt128:
{
UInt128 value;
DB::readBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::Int64:
{
Int64 value;
DB::readVarInt(value, buf);
x.push_back(value);
break;
}
case Field::Types::Float64:
{
Float64 value;
DB::readFloatBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::String:
{
std::string value;
DB::readStringBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::Array:
{
Array value;
DB::readBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::Tuple:
{
Tuple value;
DB::readBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::Map:
{
Map value;
DB::readBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::AggregateFunctionState:
{
AggregateFunctionStateData value;
DB::readStringBinary(value.name, buf);
DB::readStringBinary(value.data, buf);
x.push_back(value);
break;
}
}
}
x.push_back(getBinaryValue(type, buf));
}
void writeBinary(const Array & x, WriteBuffer & buf)
@ -111,58 +95,7 @@ void writeBinary(const Array & x, WriteBuffer & buf)
DB::writeBinary(size, buf);
for (const auto & elem : x)
{
switch (type)
{
case Field::Types::Null: break;
case Field::Types::UInt64:
{
DB::writeVarUInt(get<UInt64>(elem), buf);
break;
}
case Field::Types::UInt128:
{
DB::writeBinary(get<UInt128>(elem), buf);
break;
}
case Field::Types::Int64:
{
DB::writeVarInt(get<Int64>(elem), buf);
break;
}
case Field::Types::Float64:
{
DB::writeFloatBinary(get<Float64>(elem), buf);
break;
}
case Field::Types::String:
{
DB::writeStringBinary(get<std::string>(elem), buf);
break;
}
case Field::Types::Array:
{
DB::writeBinary(get<Array>(elem), buf);
break;
}
case Field::Types::Tuple:
{
DB::writeBinary(get<Tuple>(elem), buf);
break;
}
case Field::Types::Map:
{
DB::writeBinary(get<Map>(elem), buf);
break;
}
case Field::Types::AggregateFunctionState:
{
DB::writeStringBinary(elem.get<AggregateFunctionStateData>().name, buf);
DB::writeStringBinary(elem.get<AggregateFunctionStateData>().data, buf);
break;
}
}
}
Field::dispatch([&buf](const auto & value) { DB::FieldVisitorWriteBinary()(value, buf); }, elem);
}
void writeText(const Array & x, WriteBuffer & buf)
@ -180,100 +113,7 @@ void readBinary(Tuple & x, ReadBuffer & buf)
{
UInt8 type;
DB::readBinary(type, buf);
switch (type)
{
case Field::Types::Null:
{
x.push_back(DB::Field());
break;
}
case Field::Types::UInt64:
{
UInt64 value;
DB::readVarUInt(value, buf);
x.push_back(value);
break;
}
case Field::Types::UInt128:
{
UInt128 value;
DB::readBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::Int64:
{
Int64 value;
DB::readVarInt(value, buf);
x.push_back(value);
break;
}
case Field::Types::Int128:
{
Int64 value;
DB::readVarInt(value, buf);
x.push_back(value);
break;
}
case Field::Types::Float64:
{
Float64 value;
DB::readFloatBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::String:
{
std::string value;
DB::readStringBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::UInt256:
{
UInt256 value;
DB::readBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::Int256:
{
Int256 value;
DB::readBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::Array:
{
Array value;
DB::readBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::Tuple:
{
Tuple value;
DB::readBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::Map:
{
Map value;
DB::readBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::AggregateFunctionState:
{
AggregateFunctionStateData value;
DB::readStringBinary(value.name, buf);
DB::readStringBinary(value.data, buf);
x.push_back(value);
break;
}
}
x.push_back(getBinaryValue(type, buf));
}
}
@ -286,67 +126,11 @@ void writeBinary(const Tuple & x, WriteBuffer & buf)
{
const UInt8 type = elem.getType();
DB::writeBinary(type, buf);
switch (type)
{
case Field::Types::Null: break;
case Field::Types::UInt64:
{
DB::writeVarUInt(get<UInt64>(elem), buf);
break;
}
case Field::Types::UInt128:
{
DB::writeBinary(get<UInt128>(elem), buf);
break;
}
case Field::Types::Int64:
{
DB::writeVarInt(get<Int64>(elem), buf);
break;
}
case Field::Types::Int128:
{
DB::writeVarInt(get<Int64>(elem), buf);
break;
}
case Field::Types::Float64:
{
DB::writeFloatBinary(get<Float64>(elem), buf);
break;
}
case Field::Types::String:
{
DB::writeStringBinary(get<std::string>(elem), buf);
break;
}
case Field::Types::UInt256:
{
DB::writeBinary(get<UInt256>(elem), buf);
break;
}
case Field::Types::Int256:
{
DB::writeBinary(get<Int256>(elem), buf);
break;
}
case Field::Types::Array:
{
DB::writeBinary(get<Array>(elem), buf);
break;
}
case Field::Types::Tuple:
{
DB::writeBinary(get<Tuple>(elem), buf);
break;
}
case Field::Types::AggregateFunctionState:
{
DB::writeStringBinary(elem.get<AggregateFunctionStateData>().name, buf);
DB::writeStringBinary(elem.get<AggregateFunctionStateData>().data, buf);
break;
}
}
Field::dispatch(
[&buf](const auto & value) {
DB::FieldVisitorWriteBinary()(value, buf);
},
elem);
}
}
@ -364,93 +148,7 @@ void readBinary(Map & x, ReadBuffer & buf)
{
UInt8 type;
DB::readBinary(type, buf);
switch (type)
{
case Field::Types::Null:
{
x.push_back(DB::Field());
break;
}
case Field::Types::UInt64:
{
UInt64 value;
DB::readVarUInt(value, buf);
x.push_back(value);
break;
}
case Field::Types::UInt128:
{
UInt128 value;
DB::readBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::Int64:
{
Int64 value;
DB::readVarInt(value, buf);
x.push_back(value);
break;
}
case Field::Types::Int128:
{
Int64 value;
DB::readVarInt(value, buf);
x.push_back(value);
break;
}
case Field::Types::Float64:
{
Float64 value;
DB::readFloatBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::String:
{
std::string value;
DB::readStringBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::UInt256:
{
UInt256 value;
DB::readBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::Int256:
{
Int256 value;
DB::readBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::Array:
{
Array value;
DB::readBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::Tuple:
{
Tuple value;
DB::readBinary(value, buf);
x.push_back(value);
break;
}
case Field::Types::AggregateFunctionState:
{
AggregateFunctionStateData value;
DB::readStringBinary(value.name, buf);
DB::readStringBinary(value.data, buf);
x.push_back(value);
break;
}
}
x.push_back(getBinaryValue(type, buf));
}
}
@ -463,67 +161,11 @@ void writeBinary(const Map & x, WriteBuffer & buf)
{
const UInt8 type = elem.getType();
DB::writeBinary(type, buf);
switch (type)
{
case Field::Types::Null: break;
case Field::Types::UInt64:
{
DB::writeVarUInt(get<UInt64>(elem), buf);
break;
}
case Field::Types::UInt128:
{
DB::writeBinary(get<UInt128>(elem), buf);
break;
}
case Field::Types::Int64:
{
DB::writeVarInt(get<Int64>(elem), buf);
break;
}
case Field::Types::Int128:
{
DB::writeVarInt(get<Int64>(elem), buf);
break;
}
case Field::Types::Float64:
{
DB::writeFloatBinary(get<Float64>(elem), buf);
break;
}
case Field::Types::String:
{
DB::writeStringBinary(get<std::string>(elem), buf);
break;
}
case Field::Types::UInt256:
{
DB::writeBinary(get<UInt256>(elem), buf);
break;
}
case Field::Types::Int256:
{
DB::writeBinary(get<Int256>(elem), buf);
break;
}
case Field::Types::Array:
{
DB::writeBinary(get<Array>(elem), buf);
break;
}
case Field::Types::Tuple:
{
DB::writeBinary(get<Tuple>(elem), buf);
break;
}
case Field::Types::AggregateFunctionState:
{
DB::writeStringBinary(elem.get<AggregateFunctionStateData>().name, buf);
DB::writeStringBinary(elem.get<AggregateFunctionStateData>().data, buf);
break;
}
}
Field::dispatch(
[&buf](const auto & value) {
DB::FieldVisitorWriteBinary()(value, buf);
},
elem);
}
}

View File

@ -35,6 +35,7 @@ namespace ErrorCodes
DataTypeMap::DataTypeMap(const DataTypes & elems_)
{
assert(elems_.size() < 3);
key_type = elems_.size() == 1 ? DataTypeFactory::instance().get("String") : elems_[0];
value_type = elems_.size() == 1 ? elems_[0] : elems_[1];
@ -47,8 +48,7 @@ DataTypeMap::DataTypeMap(const DataTypes & elems_)
std::string DataTypeMap::doGetName() const
{
WriteBufferFromOwnString s;
s << "Map(" << (typeid_cast<const DataTypeArray *>(keys.get()))->getNestedType()->getName()
<< "," << (typeid_cast<const DataTypeArray *>(values.get()))->getNestedType()->getName() << ")";
s << "Map(" << key_type->getName() << "," << value_type->getName() << ")";
return s.str();
}
@ -217,20 +217,12 @@ void DataTypeMap::deserializeText(IColumn & column, ReadBuffer & istr, const For
void DataTypeMap::serializeTextJSON(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
{
writeChar('[', ostr);
keys->serializeAsTextJSON(extractElementColumn(column, 0), row_num, ostr, settings);
writeChar(',', ostr);
values->serializeAsTextJSON(extractElementColumn(column, 1), row_num, ostr, settings);
writeChar(']', ostr);
serializeText(column, row_num, ostr, settings);
}
void DataTypeMap::deserializeTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
{
assertChar('[', istr);
keys->deserializeAsTextJSON(extractElementColumn(column, 0), istr, settings);
assertChar(',', istr);
values->deserializeAsTextJSON(extractElementColumn(column, 1), istr, settings);
assertChar(']', istr);
deserializeText(column, istr, settings);
}
void DataTypeMap::serializeTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
@ -305,11 +297,12 @@ static DeserializeBinaryBulkStateMap * checkAndGetMapDeserializeState(IDataType:
void DataTypeMap::enumerateStreams(const StreamCallback & callback, SubstreamPath & path) const
{
// path.push_back(Substream::ArraySizes);
path.push_back(Substream::MapElement);
path.back().map_element_name = "keys";
keys->enumerateStreams(callback, path);
path.back().map_element_name = "values";
keys->enumerateStreams(callback, path);
values->enumerateStreams(callback, path);
path.pop_back();
}
@ -374,6 +367,7 @@ void DataTypeMap::serializeBinaryBulkWithMultipleStreams(
const auto & keys_col = extractElementColumn(column, 0);
settings.path.back().map_element_name = "keys";
keys->serializeBinaryBulkWithMultipleStreams(keys_col, offset, limit, settings, map_state->states[0]);
const auto & values_col = extractElementColumn(column, 1);
settings.path.back().map_element_name = "values";

View File

@ -86,9 +86,8 @@ public:
bool isParametric() const override { return true; }
bool haveSubtypes() const override { return true; }
const DataTypePtr & getKeyType() const { return keys; }
const DataTypePtr & getValueType() const { return values; }
const DataTypePtr & getVType() const { return value_type; }
const DataTypePtr & getKeyType() const { return key_type; }
const DataTypePtr & getValueType() const { return value_type; }
const DataTypes & getElements() const {return kv; }
};

View File

@ -2173,15 +2173,31 @@ private:
throw Exception{"CAST AS Map can only be performed between map types with the same number of elements.\n"
"Left type: " + from_type->getName() + ", right type: " + to_type->getName(), ErrorCodes::TYPE_MISMATCH};
return []
(ColumnsWithTypeAndName & arguments, const DataTypePtr &, const ColumnNullable * /*nullable_source*/, size_t /*input_rows_count*/)
const auto & from_kv_types = from_type->getElements();
const auto & to_kv_types = to_type->getElements();
std::vector<WrapperType> element_wrappers;
element_wrappers.reserve(2);
/// Create conversion wrapper for each element in tuple
for (const auto idx_type : ext::enumerate(from_kv_types))
element_wrappers.push_back(prepareUnpackDictionaries(idx_type.second, to_kv_types[idx_type.first]));
return [element_wrappers, from_kv_types, to_kv_types]
(ColumnsWithTypeAndName & arguments, const DataTypePtr &, const ColumnNullable * nullable_source, size_t input_rows_count) -> ColumnPtr
{
const auto col = arguments.front().column.get();
const ColumnMap & column_map = typeid_cast<const ColumnMap &>(*col);
const auto * col = arguments.front().column.get();
// size_t tuple_size = from_kv_types.size();
const ColumnMap & column_tuple = typeid_cast<const ColumnMap &>(*col);
Columns converted_columns(2);
converted_columns[0] = column_map.getColumns()[0];
converted_columns[1] = column_map.getColumns()[1];
/// invoke conversion for each element
for (size_t i = 0; i < 2; ++i)
{
ColumnsWithTypeAndName element = {{column_tuple.getColumns()[i], from_kv_types[i], "" }};
converted_columns[i] = element_wrappers[i](element, to_kv_types[i], nullable_source, input_rows_count);
}
return ColumnMap::create(converted_columns);
};

View File

@ -85,15 +85,22 @@ private:
/** For a Map, the function is to find the matched key's value
*/
static bool executeMappedKeyString(const ColumnArray * column, Field & index, std::vector<int> &matched_idxs);
static bool executeMappedKeyStringArgument(const ColumnArray * column, ColumnsWithTypeAndName & arguments, std::vector<int> &matched_idxs);
static bool executeMappedKeyStringConst(const ColumnArray * column, Field & index, std::vector<int> &matched_idxs);
template <typename DataType>
static bool executeMappedKeyNumber(const ColumnArray * column, Field & index, std::vector<int> &matched_idxs);
static bool getMappedKey(const ColumnArray * column, Field & index, std::vector<int> &matched_idxs);
static bool executeMappedKeyNumberArgument(const ColumnArray * column, ColumnsWithTypeAndName & arguments, std::vector<int> &matched_idxs);
template <typename DataType>
static bool executeMappedValueNumber(const ColumnArray * column, std::vector<int> matched_idxs,
static bool executeMappedKeyNumberConst(const ColumnArray * column, Field & index, std::vector<int> &matched_idxs);
static bool getMappedKey(const ColumnArray * column, ColumnsWithTypeAndName & arguments, std::vector<int> &matched_idxs);
static bool getMappedKeyConst(const ColumnArray * column, Field & index, std::vector<int> &matched_idxs);
template <typename DataType>
static bool executeMappedValueNumber(const ColumnArray * column, const std::vector<int> & matched_idxs,
IColumn * col_res_untyped);
static bool executeMappedValueString(const ColumnArray * column, std::vector<int> matched_idxs,
@ -104,7 +111,7 @@ private:
static bool getMappedValue(const ColumnArray * column, std::vector<int> matched_idxs, IColumn * col_res_untyped);
static ColumnPtr executeMap(ColumnsWithTypeAndName & columns, size_t input_rows_count);
static ColumnPtr executeMap(ColumnsWithTypeAndName & arguments, size_t input_rows_count);
};
@ -647,10 +654,8 @@ ColumnPtr FunctionArrayElement::executeArgument(
ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, ArrayImpl::NullMapBuilder & builder, size_t input_rows_count) const
{
auto index = checkAndGetColumn<ColumnVector<IndexType>>(arguments[1].column.get());
if (!index)
return nullptr;
const auto & index_data = index->getData();
if (builder)
@ -725,7 +730,7 @@ ColumnPtr FunctionArrayElement::executeTuple(ColumnsWithTypeAndName & arguments,
return ColumnTuple::create(result_tuple_columns);
}
bool FunctionArrayElement::executeMappedKeyString(const ColumnArray * column, Field & index,
bool FunctionArrayElement::executeMappedKeyStringConst(const ColumnArray * column, Field & index,
std::vector<int> &matched_idxs)
{
const ColumnString * keys = checkAndGetColumn<ColumnString>(&column->getData());
@ -757,8 +762,45 @@ bool FunctionArrayElement::executeMappedKeyString(const ColumnArray * column, Fi
return true;
}
bool FunctionArrayElement::executeMappedKeyStringArgument(const ColumnArray * column, ColumnsWithTypeAndName & arguments, std::vector<int> &matched_idxs)
{
auto index = checkAndGetColumn<ColumnString>(arguments[1].column.get());
if (!index)
return false;
const ColumnString * keys = checkAndGetColumn<ColumnString>(&column->getData());
const ColumnArray::Offsets & offsets = column->getOffsets();
size_t rows = offsets.size();
if (!keys)
return false;
// String str = index.get<String>();
for (size_t i = 0; i < rows; i++)
{
bool matched = false;
size_t begin = offsets[i - 1];
size_t end = offsets[i];
for (size_t j = begin; j < end; j++)
{
if (strcmp(keys->getDataAt(j).data, index->getDataAt(i).data) == 0)
{
matched_idxs.push_back(j);
matched = true;
break;
}
}
if (!matched)
matched_idxs.push_back(-1);
}
return true;
}
template <typename DataType>
bool FunctionArrayElement::executeMappedKeyNumber(const ColumnArray * column, Field & index,
bool FunctionArrayElement::executeMappedKeyNumberConst(const ColumnArray * column, Field & index,
std::vector<int> &matched_idxs)
{
const ColumnVector<DataType> * col_nested = checkAndGetColumn<ColumnVector<DataType>>(&column->getData());
@ -795,17 +837,72 @@ bool FunctionArrayElement::executeMappedKeyNumber(const ColumnArray * column, Fi
return true;
}
bool FunctionArrayElement::getMappedKey(const ColumnArray * column, Field & index, std::vector<int> &matched_idxs)
template <typename DataType>
bool FunctionArrayElement::executeMappedKeyNumberArgument(const ColumnArray * column, ColumnsWithTypeAndName & arguments, std::vector<int> &matched_idxs)
{
if (!(executeMappedKeyNumber<UInt8>(column, index, matched_idxs)
|| executeMappedKeyNumber<UInt16>(column, index, matched_idxs)
|| executeMappedKeyNumber<UInt32>(column, index, matched_idxs)
|| executeMappedKeyNumber<UInt64>(column, index, matched_idxs)
|| executeMappedKeyNumber<Int8>(column, index, matched_idxs)
|| executeMappedKeyNumber<Int16>(column, index, matched_idxs)
|| executeMappedKeyNumber<Int32>(column, index, matched_idxs)
|| executeMappedKeyNumber<Int64>(column, index, matched_idxs)
|| executeMappedKeyString(column, index, matched_idxs)))
auto index = checkAndGetColumn<ColumnVector<DataType>>(arguments[1].column.get());
if (!index)
return false;
const PaddedPODArray<DataType> & index_data = index->getData();
const ColumnVector<DataType> * col_nested = checkAndGetColumn<ColumnVector<DataType>>(&column->getData());
if (!col_nested)
return false;
const ColumnArray::Offsets & offsets = column->getOffsets();
for (size_t i = 0; i < offsets.size(); i++)
{
bool matched = false;
size_t begin = offsets[i - 1];
size_t end = offsets[i];
for (size_t j = begin; j < end; j++)
{
DataType ele = col_nested->getElement(j);
if (!CompareHelper<DataType>::compare(ele, index_data[i], 0))
{
matched_idxs.push_back(j);
matched = true;
break;
}
}
if (!matched)
matched_idxs.push_back(-1);
}
return true;
}
bool FunctionArrayElement::getMappedKey(const ColumnArray * column, ColumnsWithTypeAndName & arguments, std::vector<int> &matched_idxs)
{
if (!(executeMappedKeyNumberArgument<UInt8>(column, arguments, matched_idxs)
|| executeMappedKeyNumberArgument<UInt16>(column, arguments, matched_idxs)
|| executeMappedKeyNumberArgument<UInt32>(column, arguments, matched_idxs)
|| executeMappedKeyNumberArgument<UInt64>(column, arguments, matched_idxs)
|| executeMappedKeyNumberArgument<Int8>(column, arguments, matched_idxs)
|| executeMappedKeyNumberArgument<Int16>(column, arguments, matched_idxs)
|| executeMappedKeyNumberArgument<Int32>(column, arguments, matched_idxs)
|| executeMappedKeyNumberArgument<Int64>(column, arguments, matched_idxs)
|| executeMappedKeyStringArgument(column, arguments, matched_idxs)))
throw Exception("Second argument for function " + column->getName() + " for Map must must have UInt or Int or String type.",
ErrorCodes::ILLEGAL_COLUMN);
return true;
}
bool FunctionArrayElement::getMappedKeyConst(const ColumnArray * column, Field & index, std::vector<int> &matched_idxs)
{
if (!(executeMappedKeyNumberConst<UInt8>(column, index, matched_idxs)
|| executeMappedKeyNumberConst<UInt16>(column, index, matched_idxs)
|| executeMappedKeyNumberConst<UInt32>(column, index, matched_idxs)
|| executeMappedKeyNumberConst<UInt64>(column, index, matched_idxs)
|| executeMappedKeyNumberConst<Int8>(column, index, matched_idxs)
|| executeMappedKeyNumberConst<Int16>(column, index, matched_idxs)
|| executeMappedKeyNumberConst<Int32>(column, index, matched_idxs)
|| executeMappedKeyNumberConst<Int64>(column, index, matched_idxs)
|| executeMappedKeyStringConst(column, index, matched_idxs)))
throw Exception("Illegal column" + column->getName()
+ "of first argument , type not match", ErrorCodes::ILLEGAL_COLUMN);
@ -813,7 +910,7 @@ bool FunctionArrayElement::getMappedKey(const ColumnArray * column, Field & inde
}
template <typename DataType>
bool FunctionArrayElement::executeMappedValueNumber(const ColumnArray * column, std::vector<int> matched_idxs,
bool FunctionArrayElement::executeMappedValueNumber(const ColumnArray * column, const std::vector<int> & matched_idxs,
IColumn * col_res_untyped)
{
const ColumnVector<DataType> * col_nested = checkAndGetColumn<ColumnVector<DataType>>(&column->getData());
@ -830,15 +927,9 @@ bool FunctionArrayElement::executeMappedValueNumber(const ColumnArray * column,
for (size_t i = 0; i < rows; i++)
{
if (matched_idxs[i] != -1)
{
col_res->insertFrom(*col_nested, matched_idxs[i]);
}
else
{
// Default value for unmatched keys
DataType default_value = -1;
col_res->insertValue(default_value);
}
col_res->insertDefault();
}
return true;
@ -867,8 +958,7 @@ bool FunctionArrayElement::executeMappedValueString(const ColumnArray * column,
}
else
{
// Default value for unmatched keys
col_res->insertData("null", 4);
col_res->insertDefault();
}
}
return true;
@ -897,7 +987,7 @@ bool FunctionArrayElement::executeMappedValueArray(const ColumnArray * column, s
}
else
{
col_res->insertData("", 0);
col_res->insertDefault();
}
}
return true;
@ -921,36 +1011,39 @@ bool FunctionArrayElement::getMappedValue(const ColumnArray * column, std::vecto
return true;
}
ColumnPtr FunctionArrayElement::executeMap(ColumnsWithTypeAndName & arguments, size_t input_rows_count)
ColumnPtr FunctionArrayElement::executeMap(ColumnsWithTypeAndName & arguments, size_t /*input_rows_count*/)
{
const ColumnMap * col_map = typeid_cast<const ColumnMap *>(arguments[0].column.get());
if (!col_map)
return nullptr;
const DataTypes & kv_types = (typeid_cast<const DataTypeMap &>(*arguments[0].type)).getElements();
const DataTypePtr & key_type = (typeid_cast<const DataTypeArray *>(kv_types[0].get()))->getNestedType();
const DataTypePtr & value_type = (typeid_cast<const DataTypeArray *>(kv_types[1].get()))->getNestedType();
const DataTypePtr & key_type = (typeid_cast<const DataTypeMap &>(*arguments[0].type)).getKeyType();
const DataTypePtr & value_type = (typeid_cast<const DataTypeMap &>(*arguments[0].type)).getValueType();
Field index = (*arguments[1].column)[0];
// Get Matched key's value
const ColumnArray * col_keys_untyped = typeid_cast<const ColumnArray *>(&col_map->getColumn(0));
const ColumnArray * col_values_untyped = typeid_cast<const ColumnArray *>(&col_map->getColumn(1));
size_t rows = col_keys_untyped->getOffsets().size();
auto col_res_untyped = value_type->createColumn();
if (rows > 0)
{
if (input_rows_count)
assert(input_rows_count == rows);
std::vector<int> matched_idxs;
if (!getMappedKey(col_keys_untyped, index, matched_idxs))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "key type unmatched, we need type '{}' failed", key_type->getName());
matched_idxs.reserve(rows);
if (!getMappedValue(col_values_untyped, matched_idxs, col_res_untyped.get()))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "value type unmatched, we need type '{}' failed", value_type->getName());
if (!isColumnConst(*arguments[1].column))
{
if (rows > 0 && !getMappedKey(col_keys_untyped, arguments, matched_idxs))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "key type unmatched, we need type '{}' failed", key_type->getName());
}
else
{
Field index = (*arguments[1].column)[0];
// Get Matched key's value
if (rows > 0 && !getMappedKeyConst(col_keys_untyped, index, matched_idxs))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "key type unmatched, we need type '{}' failed", key_type->getName());
}
if (rows > 0 && !getMappedValue(col_values_untyped, matched_idxs, col_res_untyped.get()))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "value type unmatched, we need type '{}' failed", value_type->getName());
return col_res_untyped;
}

View File

@ -10,6 +10,7 @@ namespace DB
void registerFunctionsArithmetic(FunctionFactory &);
void registerFunctionsArray(FunctionFactory &);
void registerFunctionsTuple(FunctionFactory &);
void registerFunctionsMap(FunctionFactory &);
void registerFunctionsBitmap(FunctionFactory &);
void registerFunctionsCoding(FunctionFactory &);
void registerFunctionsComparison(FunctionFactory &);
@ -64,6 +65,7 @@ void registerFunctions()
registerFunctionsArithmetic(factory);
registerFunctionsArray(factory);
registerFunctionsTuple(factory);
registerFunctionsMap(factory);
#if !defined(ARCADIA_BUILD)
registerFunctionsBitmap(factory);
#endif

View File

@ -280,6 +280,7 @@ SRCS(
lowCardinalityKeys.cpp
lower.cpp
lowerUTF8.cpp
map.cpp
match.cpp
materialize.cpp
minus.cpp

View File

@ -259,8 +259,7 @@ Field convertFieldToTypeImpl(const Field & src, const IDataType & type, const ID
if (src_map_size % 2)
throw Exception("Bad size of map in In or VALUES section, Expected size must %2==0", ErrorCodes::BAD_ARGUMENTS);
Map res(2);
const auto & key_type = *(type_map->getKeyType());
const auto & value_type = *(type_map->getValueType());
const auto & kv_type = type_map->getElements();
Map keys(count);
Map values(count);
@ -271,8 +270,8 @@ Field convertFieldToTypeImpl(const Field & src, const IDataType & type, const ID
values[i] = src_map[i * 2 + 1];
}
res[0] = convertFieldToType(keys, key_type);
res[1] = convertFieldToType(values, value_type);
res[0] = convertFieldToType(keys, *kv_type[0].get());
res[1] = convertFieldToType(values, *kv_type[1].get());
if (res[0].isNull())
throw Exception("Bad type of key", ErrorCodes::BAD_TYPE_OF_FIELD);

View File

@ -3,6 +3,7 @@
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTWithAlias.h>
#include <Parsers/ASTSubquery.h>
#include <Parsers/ASTExpressionList.h>
#include <IO/WriteHelpers.h>
#include <IO/WriteBufferFromString.h>
#include <Common/SipHash.h>
@ -27,14 +28,14 @@ void ASTFunction::appendColumnNameImpl(WriteBuffer & ostr) const
writeChar(')', ostr);
}
writeChar('(', ostr);
writeChar(name == "map" ? '{' : '(', ostr);
for (auto it = arguments->children.begin(); it != arguments->children.end(); ++it)
{
if (it != arguments->children.begin())
writeCString(", ", ostr);
(*it)->appendColumnName(ostr);
}
writeChar(')', ostr);
writeChar(name == "map" ? '}' : ')', ostr);
}
/** Get the text that identifies this element. */
@ -364,6 +365,19 @@ void ASTFunction::formatImplWithoutAlias(const FormatSettings & settings, Format
settings.ostr << (settings.hilite ? hilite_operator : "") << ')' << (settings.hilite ? hilite_none : "");
written = true;
}
if (!written && 0 == strcmp(name.c_str(), "map"))
{
settings.ostr << (settings.hilite ? hilite_operator : "") << '{' << (settings.hilite ? hilite_none : "");
for (size_t i = 0; i < arguments->children.size(); ++i)
{
if (i != 0)
settings.ostr << ", ";
arguments->children[i]->formatImpl(settings, state, nested_dont_need_parens);
}
settings.ostr << (settings.hilite ? hilite_operator : "") << '}' << (settings.hilite ? hilite_none : "");
written = true;
}
}
if (!written)

View File

@ -123,6 +123,36 @@ bool ParserParenthesisExpression::parseImpl(Pos & pos, ASTPtr & node, Expected &
return true;
}
bool ParserMap::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
ASTPtr contents_node;
ParserMapExpressionList contents;
if (pos->type != TokenType::OpeningCurlyBrace)
return false;
++pos;
if (!contents.parse(pos, contents_node, expected))
return false;
if (pos->type != TokenType::ClosingCurlyBrace)
return false;
++pos;
const auto & expr_list = contents_node->as<ASTExpressionList &>();
/// empty expression in parentheses is not allowed
if (expr_list.children.empty())
{
expected.add(pos, "non-empty curlyBraced list of expressions");
return false;
}
auto function_node = std::make_shared<ASTFunction>();
function_node->name = "map";
function_node->arguments = contents_node;
function_node->children.push_back(contents_node);
node = function_node;
return true;
}
bool ParserSubquery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
@ -1478,6 +1508,7 @@ bool ParserExpressionElement::parseImpl(Pos & pos, ASTPtr & node, Expected & exp
return ParserSubquery().parse(pos, node, expected)
|| ParserTupleOfLiterals().parse(pos, node, expected)
|| ParserMapOfLiterals().parse(pos, node, expected)
|| ParserMap().parse(pos, node, expected)
|| ParserParenthesisExpression().parse(pos, node, expected)
|| ParserArrayOfLiterals().parse(pos, node, expected)
|| ParserArray().parse(pos, node, expected)

View File

@ -27,6 +27,13 @@ protected:
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override;
};
class ParserMap : public IParserBase
{
protected:
const char * getName() const override { return "map curlybraced expression"; }
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override;
};
/** The SELECT subquery is in parenthesis.
*/

View File

@ -96,6 +96,7 @@ bool ParserList::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
auto list = std::make_shared<ASTExpressionList>(result_separator);
list->children = std::move(elements);
node = list;
return true;
}
@ -516,6 +517,50 @@ ParserExpressionWithOptionalAlias::ParserExpressionWithOptionalAlias(bool allow_
{
}
bool ParserMapExpressionList::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
ParserPtr && separator_parser_ = std::make_unique<ParserToken>(TokenType::Comma);
ASTs elements;
auto parse_key_value = [&]
{
ASTPtr key, value;
ParserExpression key_parser, value_parser;
if (!key_parser.parse(pos, key, expected))
return false;
elements.push_back(key);
if (pos->type != TokenType::Colon)
return false;
++pos;
if (!value_parser.parse(pos, value, expected))
return false;
elements.push_back(value);
return true;
};
Pos begin = pos;
if (!parse_key_value())
return false;
while (true)
{
begin = pos;
if (!separator_parser_->ignore(pos, expected) || !parse_key_value())
{
pos = begin;
break;
}
}
auto list = std::make_shared<ASTExpressionList>();
list->children = std::move(elements);
node = list;
return true;
}
bool ParserExpressionList::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{

View File

@ -396,6 +396,13 @@ protected:
}
};
/** A comma-separated list of map expressions, probably empty. */
class ParserMapExpressionList : public IParserBase
{
protected:
const char * getName() const override { return "list of map expression"; }
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override;
};
/** A comma-separated list of expressions, probably empty. */
class ParserExpressionList : public IParserBase

View File

@ -1,9 +1,25 @@
zhangsan
lisi
lisi
1111 2222
1112 2224
1113 2226
female
zhangsan
100
60
100
1116
1117
1118
1119
[]
[]
[]
[]
[]
[]
[0,2,0]
[1,2,3]
[1,3,2]
[2,4,4]
[3,5,6]
[4,6,8]
[5,7,10]
[100,20,90]

View File

@ -5,23 +5,33 @@ insert into table_map values ({'name':'zhangsan', 'gender':'male'}), ({'name':'l
select a['name'] from table_map;
drop table if exists table_map;
drop table if exists table_map;
create table table_map (a Map(String, UInt64)) engine = MergeTree() order by a;
insert into table_map select map('key1', number, 'key2', number * 2) from numbers(1111, 3);
select a['key1'], a['key2'] from table_map;
drop table if exists table_map;
-- MergeTree Engine
create table table_map (a Map(String, String)) engine = MergeTree() order by a;
insert into table_map values ({'name':'zhangsan', 'gender':'male'}), ({'name':'lisi', 'gender':'female'});
select a['name'] from table_map;
drop table if exists table_map;
create table table_map (a Map(String, String), b String) engine = MergeTree() order by a;
insert into table_map values ({'name':'zhangsan', 'gender':'male'}, 'name'), ({'name':'lisi', 'gender':'female'}, 'gender');
select a[b] from table_map;
drop table if exists table_map;
-- Int type
drop table if exists table_map2;
create table table_map2(a Map(UInt8, UInt64), b UInt8) Engine = MergeTree() order by b partition by a;
insert into table_map2 values({1: 100, 1000: 300}, 1), ({1: 60}, 2), ({2: 40, 7:90, 1:100}, 3);
select a[1] from table_map2;
drop table if exists table_map2;
drop table if exists table_map;
create table table_map(a Map(UInt8, UInt64), b UInt8) Engine = MergeTree() order by b;
insert into table_map select {number:number+5}, number from numbers(1111,4);
select a[b] from table_map;
drop table if exists table_map;
-- Array Type
drop table if exists table_map3;
create table table_map3(a Map(String, Array(UInt8))) Engine = Memory;
insert into table_map3 values({'k1':[1,2,3], 'k2':[4,5,6]}), ({'k0':[], 'k1':[100,20,90]});
select a['k1'] from table_map3;
drop table if exists table_map3;
drop table if exists table_map;
create table table_map(a Map(String, Array(UInt8))) Engine = MergeTree() order by a;
insert into table_map values({'k1':[1,2,3], 'k2':[4,5,6]}), ({'k0':[], 'k1':[100,20,90]});
insert into table_map select {'k1' : [number, number + 2, number * 2]} from numbers(6);
insert into table_map select map('k2' , [number, number + 2, number * 2]) from numbers(6);
select a['k1'] as col1 from table_map order by col1;
drop table if exists table_map;