diff --git a/src/Processors/Sources/MongoDBSource.cpp b/src/Processors/Sources/MongoDBSource.cpp index a8bfefdf8a6..8ebedc3e877 100644 --- a/src/Processors/Sources/MongoDBSource.cpp +++ b/src/Processors/Sources/MongoDBSource.cpp @@ -6,7 +6,9 @@ #include #include #include +#include +#include #include #include #include @@ -17,6 +19,9 @@ #include #include +#include +#include + // only after poco // naming conflict: // Poco/MongoDB/BSONWriter.h:54: void writeCString(const std::string & value); @@ -33,6 +38,11 @@ namespace ErrorCodes extern const int MONGODB_ERROR; } +namespace +{ + void prepareMongoDBArrayInfo( + std::unordered_map & array_info, size_t column_idx, const DataTypePtr data_type); +} std::unique_ptr createCursor(const std::string & database, const std::string & collection, const Block & sample_block_to_select) { @@ -58,6 +68,10 @@ MongoDBSource::MongoDBSource( , max_block_size{max_block_size_} { description.init(sample_block); + + for (const auto idx : collections::range(0, description.sample_block.columns())) + if (description.types[idx].first == ExternalResultDescription::ValueType::vtArray) + prepareMongoDBArrayInfo(array_info, idx, description.sample_block.getByPosition(idx).type); } @@ -68,6 +82,7 @@ namespace { using ValueType = ExternalResultDescription::ValueType; using ObjectId = Poco::MongoDB::ObjectId; + using MongoArray = Poco::MongoDB::Array; template void insertNumber(IColumn & column, const Poco::MongoDB::Element & value, const std::string & name) @@ -103,7 +118,129 @@ namespace } } - void insertValue(IColumn & column, const ValueType type, const Poco::MongoDB::Element & value, const std::string & name) + template + Field getNumber(const Poco::MongoDB::Element & value, const std::string & name) + { + switch (value.type()) + { + case Poco::MongoDB::ElementTraits::TypeId: + return static_cast(static_cast &>(value).value()); + case Poco::MongoDB::ElementTraits::TypeId: + return static_cast(static_cast &>(value).value()); + case Poco::MongoDB::ElementTraits::TypeId: + return static_cast(static_cast &>(value).value()); + case Poco::MongoDB::ElementTraits::TypeId: + return static_cast(static_cast &>(value).value()); + case Poco::MongoDB::ElementTraits::TypeId: + return Field(); + case Poco::MongoDB::ElementTraits::TypeId: + return parse(static_cast &>(value).value()); + default: + throw Exception(ErrorCodes::TYPE_MISMATCH, "Type mismatch, expected a number, got type id = {} for column {}", + toString(value.type()), name); + } + } + + void prepareMongoDBArrayInfo( + std::unordered_map & array_info, size_t column_idx, const DataTypePtr data_type) + { + const auto * array_type = typeid_cast(data_type.get()); + auto nested = array_type->getNestedType(); + + size_t count_dimensions = 1; + while (isArray(nested)) + { + ++count_dimensions; + nested = typeid_cast(nested.get())->getNestedType(); + } + + Field default_value = nested->getDefault(); + if (nested->isNullable()) + nested = static_cast(nested.get())->getNestedType(); + + WhichDataType which(nested); + std::function parser; + + if (which.isUInt8()) + parser = [](const Poco::MongoDB::Element & value, const std::string & name) -> Field { return getNumber(value, name); }; + else if (which.isUInt16()) + parser = [](const Poco::MongoDB::Element & value, const std::string & name) -> Field { return getNumber(value, name); }; + else if (which.isUInt32()) + parser = [](const Poco::MongoDB::Element & value, const std::string & name) -> Field { return getNumber(value, name); }; + else if (which.isUInt64()) + parser = [](const Poco::MongoDB::Element & value, const std::string & name) -> Field { return getNumber(value, name); }; + else if (which.isInt8()) + parser = [](const Poco::MongoDB::Element & value, const std::string & name) -> Field { return getNumber(value, name); }; + else if (which.isInt16()) + parser = [](const Poco::MongoDB::Element & value, const std::string & name) -> Field { return getNumber(value, name); }; + else if (which.isInt32()) + parser = [](const Poco::MongoDB::Element & value, const std::string & name) -> Field { return getNumber(value, name); }; + else if (which.isInt64()) + parser = [](const Poco::MongoDB::Element & value, const std::string & name) -> Field { return getNumber(value, name); }; + else if (which.isFloat32()) + parser = [](const Poco::MongoDB::Element & value, const std::string & name) -> Field { return getNumber(value, name); }; + else if (which.isFloat64()) + parser = [](const Poco::MongoDB::Element & value, const std::string & name) -> Field { return getNumber(value, name); }; + else if (which.isString() || which.isFixedString()) + parser = [](const Poco::MongoDB::Element & value, const std::string & name) -> Field + { + if (value.type() == Poco::MongoDB::ElementTraits::TypeId) + { + String string_id = value.toString(); + return Field(string_id.data(), string_id.size()); + } + else if (value.type() == Poco::MongoDB::ElementTraits::TypeId) + { + String string = static_cast &>(value).value(); + return Field(string.data(), string.size()); + } + + throw Exception(ErrorCodes::TYPE_MISMATCH, "Type mismatch, expected String, got type id = {} for column {}", + toString(value.type()), name); + }; + else if (which.isDate()) + parser = [](const Poco::MongoDB::Element & value, const std::string & name) -> Field + { + if (value.type() != Poco::MongoDB::ElementTraits::TypeId) + throw Exception(ErrorCodes::TYPE_MISMATCH, "Type mismatch, expected Timestamp, got type id = {} for column {}", + toString(value.type()), name); + + return static_cast(DateLUT::instance().toDayNum( + static_cast &>(value).value().epochTime())); + }; + else if (which.isDateTime()) + parser = [](const Poco::MongoDB::Element & value, const std::string & name) -> Field + { + if (value.type() != Poco::MongoDB::ElementTraits::TypeId) + throw Exception(ErrorCodes::TYPE_MISMATCH, "Type mismatch, expected Timestamp, got type id = {} for column {}", + toString(value.type()), name); + + return static_cast(static_cast &>(value).value().epochTime()); + }; + else if (which.isUUID()) + parser = [](const Poco::MongoDB::Element & value, const std::string & name) -> Field + { + if (value.type() != Poco::MongoDB::ElementTraits::TypeId) + throw Exception(ErrorCodes::TYPE_MISMATCH, "Type mismatch, expected String (UUID), got type id = {} for column {}", + toString(value.type()), name); + + String string = static_cast &>(value).value(); + return parse(string); + }; + else + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Type conversion to {} is not supported", nested->getName()); + + array_info[column_idx] = {count_dimensions, default_value, parser}; + + } + + void insertValue( + IColumn & column, + const ValueType type, + const Poco::MongoDB::Element & value, + const std::string & name, + std::unordered_map & array_info, + size_t idx) { switch (type) { @@ -192,8 +329,67 @@ namespace toString(value.type()), name); break; } + case ValueType::vtArray: + { + if (value.type() != Poco::MongoDB::ElementTraits::TypeId) + throw Exception(ErrorCodes::TYPE_MISMATCH, "Type mismatch, expected Array, got type id = {} for column {}", + toString(value.type()), name); + + size_t max_dimension = 0, expected_dimensions = array_info[idx].num_dimensions; + const auto parse_value = array_info[idx].parser; + std::vector dimensions(expected_dimensions + 1); + + auto array = static_cast &>(value).value(); + + std::vector> arrays; + arrays.emplace_back(&value, 0); + + while (!arrays.empty()) + { + size_t dimension = arrays.size(); + max_dimension = std::max(max_dimension, dimension); + + auto [element, i] = arrays.back(); + + auto parent = static_cast &>(*element).value(); + + if (i >= parent->size()) + { + dimensions[dimension].emplace_back(Array(dimensions[dimension + 1].begin(), dimensions[dimension + 1].end())); + dimensions[dimension + 1].clear(); + + arrays.pop_back(); + continue; + } + + Poco::MongoDB::Element::Ptr child = parent->get(static_cast(i)); + + if (child->type() == Poco::MongoDB::ElementTraits::TypeId) + { + if (dimension + 1 > expected_dimensions) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Got more dimensions than expected"); + + arrays.back().second += 1; + arrays.emplace_back(child.get(), 0); + } + else + { + dimensions[dimension].emplace_back(parse_value(*child, name)); + } + } + + if (max_dimension < expected_dimensions) + throw Exception(ErrorCodes::BAD_ARGUMENTS, + "Got less dimensions than expected. ({} instead of {})", max_dimension, expected_dimensions); + + // TODO: default value + + assert_cast(column).insert(Array(dimensions[1].begin(), dimensions[1].end())); + break; + + } default: - throw Exception(ErrorCodes::UNKNOWN_TYPE, "Value of unsupported type:{}", column.getName()); + throw Exception(ErrorCodes::UNKNOWN_TYPE, "Value of unsupported type: {}", column.getName()); } } @@ -252,11 +448,11 @@ Chunk MongoDBSource::generate() if (is_nullable) { ColumnNullable & column_nullable = assert_cast(*columns[idx]); - insertValue(column_nullable.getNestedColumn(), description.types[idx].first, *value, name); + insertValue(column_nullable.getNestedColumn(), description.types[idx].first, *value, name, array_info, idx); column_nullable.getNullMapData().emplace_back(0); } else - insertValue(*columns[idx], description.types[idx].first, *value, name); + insertValue(*columns[idx], description.types[idx].first, *value, name, array_info, idx); } } } diff --git a/src/Processors/Sources/MongoDBSource.h b/src/Processors/Sources/MongoDBSource.h index d03a7a45477..ec73f00f378 100644 --- a/src/Processors/Sources/MongoDBSource.h +++ b/src/Processors/Sources/MongoDBSource.h @@ -19,6 +19,13 @@ namespace MongoDB namespace DB { +struct MongoDBArrayInfo +{ + size_t num_dimensions; + Field default_value; + std::function parser; +}; + void authenticate(Poco::MongoDB::Connection & connection, const std::string & database, const std::string & user, const std::string & password); std::unique_ptr createCursor(const std::string & database, const std::string & collection, const Block & sample_block_to_select); @@ -45,6 +52,8 @@ private: const UInt64 max_block_size; ExternalResultDescription description; bool all_read = false; + + std::unordered_map array_info; }; } diff --git a/tests/integration/test_storage_mongodb/test.py b/tests/integration/test_storage_mongodb/test.py index 74b2b15fda0..cf843ddd489 100644 --- a/tests/integration/test_storage_mongodb/test.py +++ b/tests/integration/test_storage_mongodb/test.py @@ -70,6 +70,81 @@ def test_simple_select(started_cluster): simple_mongo_table.drop() +@pytest.mark.parametrize("started_cluster", [False], indirect=["started_cluster"]) +def test_arrays(started_cluster): + mongo_connection = get_mongo_connection(started_cluster) + db = mongo_connection["test"] + db.add_user("root", "clickhouse") + simple_mongo_table = db["simple_table"] + data = [] + for i in range(0, 100): + data.append({ + "key": i, + "arr_int64": [- (i + 1), - (i + 2), - (i + 3)], + "arr_int32": [- (i + 1), - (i + 2), - (i + 3)], + "arr_int16": [- (i + 1), - (i + 2), - (i + 3)], + "arr_int8": [- (i + 1), - (i + 2), - (i + 3)], + "arr_uint64": [i + 1, i + 2, i + 3], + "arr_uint32": [i + 1, i + 2, i + 3], + "arr_uint16": [i + 1, i + 2, i + 3], + "arr_uint8": [i + 1, i + 2, i + 3], + "arr_float32": [i + 1.125, i + 2.5, i + 3.750], + "arr_float64": [i + 1.125, i + 2.5, i + 3.750], + "arr_date": ['2023-11-01', '2023-06-19'], + "arr_datetime": ['2023-03-31 06:03:12', '2023-02-01 12:46:34'], + "arr_string": [str(i + 1), str(i + 2), str(i + 3)], + "arr_uuid": ['f0e77736-91d1-48ce-8f01-15123ca1c7ed', '93376a07-c044-4281-a76e-ad27cf6973c5'], + "arr_arr_bool": [[True, False, True]] + }) + + simple_mongo_table.insert_many(data) + + node = started_cluster.instances["node"] + node.query( + "CREATE TABLE simple_mongo_table(" + "key UInt64," + "arr_int64 Array(Int64)," + "arr_int32 Array(Int32)," + "arr_int16 Array(Int16)," + "arr_int8 Array(Int8)," + "arr_uint64 Array(UInt64)," + "arr_uint32 Array(UInt32)," + "arr_uint16 Array(UInt16)," + "arr_uint8 Array(UInt8)," + "arr_float32 Array(Float32)," + "arr_float64 Array(Float64)," + "arr_date Array(Date)," + "arr_datetime Array(DateTime)," + "arr_string Array(String)," + "arr_uuid Array(UUID)," + "arr_arr_bool Array(Array(Bool))" + ") ENGINE = MongoDB('mongo1:27017', 'test', 'simple_table', 'root', 'clickhouse')" + ) + + assert node.query("SELECT COUNT() FROM simple_mongo_table") == "100\n" + + for column_name in ["arr_int64", "arr_int32", "arr_int16", "arr_int8"]: + assert ( + node.query(f"SELECT {column_name} from simple_mongo_table where key = 42") + == "[-43,-44,-45]\n" + ) + + for column_name in ["arr_uint64", "arr_uint32", "arr_uint16", "arr_uint8"]: + assert ( + node.query(f"SELECT {column_name} from simple_mongo_table where key = 42") + == "[43,44,45]\n" + ) + + for column_name in ["arr_float32", "arr_float64"]: + assert ( + node.query(f"SELECT {column_name} from simple_mongo_table where key = 42") + == "[43,44,45]\n" + ) + + node.query("DROP TABLE simple_mongo_table") + simple_mongo_table.drop() + + @pytest.mark.parametrize("started_cluster", [False], indirect=["started_cluster"]) def test_complex_data_type(started_cluster): mongo_connection = get_mongo_connection(started_cluster)