diff --git a/contrib/libpqxx-cmake/CMakeLists.txt b/contrib/libpqxx-cmake/CMakeLists.txt index ae35538ccf4..65fa94cb3fd 100644 --- a/contrib/libpqxx-cmake/CMakeLists.txt +++ b/contrib/libpqxx-cmake/CMakeLists.txt @@ -22,6 +22,7 @@ set (SRCS "${LIBRARY_DIR}/src/transaction.cxx" "${LIBRARY_DIR}/src/transaction_base.cxx" "${LIBRARY_DIR}/src/row.cxx" + "${LIBRARY_DIR}/src/params.cxx" "${LIBRARY_DIR}/src/util.cxx" "${LIBRARY_DIR}/src/version.cxx" ) @@ -31,6 +32,7 @@ set (SRCS # conflicts with all includes of . set (HDRS "${LIBRARY_DIR}/include/pqxx/array.hxx" + "${LIBRARY_DIR}/include/pqxx/params.hxx" "${LIBRARY_DIR}/include/pqxx/binarystring.hxx" "${LIBRARY_DIR}/include/pqxx/composite.hxx" "${LIBRARY_DIR}/include/pqxx/connection.hxx" @@ -75,4 +77,3 @@ set(CM_CONFIG_PQ "${LIBRARY_DIR}/include/pqxx/config-internal-libpq.h") configure_file("${CM_CONFIG_H_IN}" "${CM_CONFIG_INT}" @ONLY) configure_file("${CM_CONFIG_H_IN}" "${CM_CONFIG_PUB}" @ONLY) configure_file("${CM_CONFIG_H_IN}" "${CM_CONFIG_PQ}" @ONLY) - diff --git a/docs/en/engines/table-engines/integrations/postgresql.md b/docs/en/engines/table-engines/integrations/postgresql.md index 4c763153a36..53ab3f5088c 100644 --- a/docs/en/engines/table-engines/integrations/postgresql.md +++ b/docs/en/engines/table-engines/integrations/postgresql.md @@ -34,6 +34,7 @@ The table structure can differ from the original PostgreSQL table structure: - `user` — PostgreSQL user. - `password` — User password. - `schema` — Non-default table schema. Optional. +- `on conflict ...` — example: `ON CONFLICT DO NOTHING`. Optional. Note: adding this option will make insertion less efficient. ## Implementation Details {#implementation-details} diff --git a/programs/odbc-bridge/ODBCBlockOutputStream.cpp b/programs/odbc-bridge/ODBCBlockOutputStream.cpp index b4b514d1473..8a4387c2389 100644 --- a/programs/odbc-bridge/ODBCBlockOutputStream.cpp +++ b/programs/odbc-bridge/ODBCBlockOutputStream.cpp @@ -5,40 +5,16 @@ #include #include #include -#include -#include -#include #include "getIdentifierQuote.h" #include #include #include +#include namespace DB { -namespace -{ - using ValueType = ExternalResultDescription::ValueType; - - std::string getInsertQuery(const std::string & db_name, const std::string & table_name, const ColumnsWithTypeAndName & columns, IdentifierQuotingStyle quoting) - { - ASTInsertQuery query; - query.table_id.database_name = db_name; - query.table_id.table_name = table_name; - query.columns = std::make_shared(','); - query.children.push_back(query.columns); - for (const auto & column : columns) - query.columns->children.emplace_back(std::make_shared(column.name)); - - WriteBufferFromOwnString buf; - IAST::FormatSettings settings(buf, true); - settings.always_quote_identifiers = true; - settings.identifier_quoting_style = quoting; - query.IAST::format(settings); - return buf.str(); - } -} ODBCBlockOutputStream::ODBCBlockOutputStream(nanodbc::ConnectionHolderPtr connection_holder_, const std::string & remote_database_name_, diff --git a/programs/odbc-bridge/ODBCBlockOutputStream.h b/programs/odbc-bridge/ODBCBlockOutputStream.h index 1b42119e490..16a1602d3cd 100644 --- a/programs/odbc-bridge/ODBCBlockOutputStream.h +++ b/programs/odbc-bridge/ODBCBlockOutputStream.h @@ -13,6 +13,7 @@ namespace DB class ODBCBlockOutputStream : public IBlockOutputStream { +using ValueType = ExternalResultDescription::ValueType; public: ODBCBlockOutputStream( diff --git a/src/Parsers/getInsertQuery.cpp b/src/Parsers/getInsertQuery.cpp new file mode 100644 index 00000000000..6f52056dfe2 --- /dev/null +++ b/src/Parsers/getInsertQuery.cpp @@ -0,0 +1,28 @@ +#include + +#include +#include +#include +#include + + +namespace DB +{ +std::string getInsertQuery(const std::string & db_name, const std::string & table_name, const ColumnsWithTypeAndName & columns, IdentifierQuotingStyle quoting) +{ + ASTInsertQuery query; + query.table_id.database_name = db_name; + query.table_id.table_name = table_name; + query.columns = std::make_shared(','); + query.children.push_back(query.columns); + for (const auto & column : columns) + query.columns->children.emplace_back(std::make_shared(column.name)); + + WriteBufferFromOwnString buf; + IAST::FormatSettings settings(buf, true); + settings.always_quote_identifiers = true; + settings.identifier_quoting_style = quoting; + query.IAST::format(settings); + return buf.str(); +} +} diff --git a/src/Parsers/getInsertQuery.h b/src/Parsers/getInsertQuery.h new file mode 100644 index 00000000000..0bcb5e3660b --- /dev/null +++ b/src/Parsers/getInsertQuery.h @@ -0,0 +1,8 @@ +#pragma once +#include +#include + +namespace DB +{ +std::string getInsertQuery(const std::string & db_name, const std::string & table_name, const ColumnsWithTypeAndName & columns, IdentifierQuotingStyle quoting); +} diff --git a/src/Storages/StoragePostgreSQL.cpp b/src/Storages/StoragePostgreSQL.cpp index 603a52b2801..3617e964734 100644 --- a/src/Storages/StoragePostgreSQL.cpp +++ b/src/Storages/StoragePostgreSQL.cpp @@ -29,6 +29,8 @@ #include #include #include +#include +#include namespace DB @@ -47,10 +49,12 @@ StoragePostgreSQL::StoragePostgreSQL( const ColumnsDescription & columns_, const ConstraintsDescription & constraints_, const String & comment, - const String & remote_table_schema_) + const String & remote_table_schema_, + const String & on_conflict_) : IStorage(table_id_) , remote_table_name(remote_table_name_) , remote_table_schema(remote_table_schema_) + , on_conflict(on_conflict_) , pool(std::move(pool_)) { StorageInMemoryMetadata storage_metadata; @@ -94,17 +98,22 @@ Pipe StoragePostgreSQL::read( class PostgreSQLSink : public SinkToStorage { + +using Row = std::vector>; + public: explicit PostgreSQLSink( const StorageMetadataPtr & metadata_snapshot_, postgres::ConnectionHolderPtr connection_holder_, const String & remote_table_name_, - const String & remote_table_schema_) + const String & remote_table_schema_, + const String & on_conflict_) : SinkToStorage(metadata_snapshot_->getSampleBlock()) , metadata_snapshot(metadata_snapshot_) , connection_holder(std::move(connection_holder_)) , remote_table_name(remote_table_name_) , remote_table_schema(remote_table_schema_) + , on_conflict(on_conflict_) { } @@ -113,11 +122,21 @@ public: void consume(Chunk chunk) override { auto block = getPort().getHeader().cloneWithColumns(chunk.detachColumns()); + if (!inserter) - inserter = std::make_unique(connection_holder->get(), - remote_table_schema.empty() ? pqxx::table_path({remote_table_name}) - : pqxx::table_path({remote_table_schema, remote_table_name}), - block.getNames()); + { + if (on_conflict.empty()) + { + inserter = std::make_unique(connection_holder->get(), + remote_table_schema.empty() ? pqxx::table_path({remote_table_name}) + : pqxx::table_path({remote_table_schema, remote_table_name}), block.getNames()); + } + else + { + inserter = std::make_unique(connection_holder->get(), remote_table_name, + remote_table_schema, block.getColumnsWithTypeAndName(), on_conflict); + } + } const auto columns = block.getColumns(); const size_t num_rows = block.rows(), num_cols = block.columns(); @@ -151,7 +170,7 @@ public: } } - inserter->stream.write_values(row); + inserter->insert(row); } } @@ -268,37 +287,92 @@ public: } private: - struct StreamTo + struct Inserter { + pqxx::connection & connection; pqxx::work tx; + + explicit Inserter(pqxx::connection & connection_) + : connection(connection_) + , tx(connection) {} + + virtual ~Inserter() = default; + + virtual void insert(const Row & row) = 0; + virtual void complete() = 0; + }; + + struct StreamTo : Inserter + { Names columns; pqxx::stream_to stream; - StreamTo(pqxx::connection & connection, pqxx::table_path table_, Names columns_) - : tx(connection) + StreamTo(pqxx::connection & connection_, pqxx::table_path table_, Names columns_) + : Inserter(connection_) , columns(std::move(columns_)) , stream(pqxx::stream_to::raw_table(tx, connection.quote_table(table_), connection.quote_columns(columns))) { } - void complete() + void complete() override { stream.complete(); tx.commit(); } + + void insert(const Row & row) override + { + stream.write_values(row); + } + }; + + struct PreparedInsert : Inserter + { + PreparedInsert(pqxx::connection & connection_, const String & table, const String & schema, + const ColumnsWithTypeAndName & columns, const String & on_conflict_) + : Inserter(connection_) + { + WriteBufferFromOwnString buf; + buf << getInsertQuery(schema, table, columns, IdentifierQuotingStyle::DoubleQuotes); + buf << " ("; + for (size_t i = 1; i <= columns.size(); ++i) + { + if (i > 1) + buf << ", "; + buf << "$" << i; + } + buf << ") "; + buf << on_conflict_; + connection.prepare("insert", buf.str()); + } + + void complete() override + { + connection.unprepare("insert"); + tx.commit(); + } + + void insert(const Row & row) override + { + pqxx::params params; + params.reserve(row.size()); + params.append_multi(row); + tx.exec_prepared("insert", params); + } }; StorageMetadataPtr metadata_snapshot; postgres::ConnectionHolderPtr connection_holder; - const String remote_table_name, remote_table_schema; - std::unique_ptr inserter; + const String remote_db_name, remote_table_name, remote_table_schema, on_conflict; + + std::unique_ptr inserter; }; SinkToStoragePtr StoragePostgreSQL::write( const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr /* context */) { - return std::make_shared(metadata_snapshot, pool->get(), remote_table_name, remote_table_schema); + return std::make_shared(metadata_snapshot, pool->get(), remote_table_name, remote_table_schema, on_conflict); } @@ -308,9 +382,9 @@ void registerStoragePostgreSQL(StorageFactory & factory) { ASTs & engine_args = args.engine_args; - if (engine_args.size() < 5 || engine_args.size() > 6) - throw Exception("Storage PostgreSQL requires from 5 to 6 parameters: " - "PostgreSQL('host:port', 'database', 'table', 'username', 'password' [, 'schema']", + if (engine_args.size() < 5 || engine_args.size() > 7) + throw Exception("Storage PostgreSQL requires from 5 to 7 parameters: " + "PostgreSQL('host:port', 'database', 'table', 'username', 'password' [, 'schema', 'ON CONFLICT ...']", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); for (auto & engine_arg : engine_args) @@ -326,9 +400,11 @@ void registerStoragePostgreSQL(StorageFactory & factory) const String & username = engine_args[3]->as().value.safeGet(); const String & password = engine_args[4]->as().value.safeGet(); - String remote_table_schema; - if (engine_args.size() == 6) + String remote_table_schema, on_conflict; + if (engine_args.size() >= 6) remote_table_schema = engine_args[5]->as().value.safeGet(); + if (engine_args.size() >= 7) + on_conflict = engine_args[6]->as().value.safeGet(); auto pool = std::make_shared( remote_database, @@ -345,7 +421,8 @@ void registerStoragePostgreSQL(StorageFactory & factory) args.columns, args.constraints, args.comment, - remote_table_schema); + remote_table_schema, + on_conflict); }, { .source_access_type = AccessType::POSTGRES, diff --git a/src/Storages/StoragePostgreSQL.h b/src/Storages/StoragePostgreSQL.h index bd5cd317c3d..a12b52e6e48 100644 --- a/src/Storages/StoragePostgreSQL.h +++ b/src/Storages/StoragePostgreSQL.h @@ -27,7 +27,8 @@ public: const ColumnsDescription & columns_, const ConstraintsDescription & constraints_, const String & comment, - const std::string & remote_table_schema_ = ""); + const String & remote_table_schema_ = "", + const String & on_conflict = ""); String getName() const override { return "PostgreSQL"; } @@ -47,6 +48,7 @@ private: String remote_table_name; String remote_table_schema; + String on_conflict; postgres::PoolWithFailoverPtr pool; }; diff --git a/src/TableFunctions/TableFunctionPostgreSQL.cpp b/src/TableFunctions/TableFunctionPostgreSQL.cpp index d701728479b..568cc6171fd 100644 --- a/src/TableFunctions/TableFunctionPostgreSQL.cpp +++ b/src/TableFunctions/TableFunctionPostgreSQL.cpp @@ -37,7 +37,8 @@ StoragePtr TableFunctionPostgreSQL::executeImpl(const ASTPtr & /*ast_function*/, columns, ConstraintsDescription{}, String{}, - remote_table_schema); + remote_table_schema, + on_conflict); result->startup(); return result; @@ -67,9 +68,9 @@ void TableFunctionPostgreSQL::parseArguments(const ASTPtr & ast_function, Contex ASTs & args = func_args.arguments->children; - if (args.size() < 5 || args.size() > 6) - throw Exception("Table function 'PostgreSQL' requires from 5 to 6 parameters: " - "PostgreSQL('host:port', 'database', 'table', 'user', 'password', [, 'schema']).", + if (args.size() < 5 || args.size() > 7) + throw Exception("Table function 'PostgreSQL' requires from 5 to 7 parameters: " + "PostgreSQL('host:port', 'database', 'table', 'user', 'password', [, 'schema', 'ON CONFLICT ...']).", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); for (auto & arg : args) @@ -82,8 +83,10 @@ void TableFunctionPostgreSQL::parseArguments(const ASTPtr & ast_function, Contex remote_table_name = args[2]->as().value.safeGet(); - if (args.size() == 6) + if (args.size() >= 6) remote_table_schema = args[5]->as().value.safeGet(); + if (args.size() >= 7) + on_conflict = args[6]->as().value.safeGet(); connection_pool = std::make_shared( args[1]->as().value.safeGet(), diff --git a/src/TableFunctions/TableFunctionPostgreSQL.h b/src/TableFunctions/TableFunctionPostgreSQL.h index c31d02fa955..e3810a0e391 100644 --- a/src/TableFunctions/TableFunctionPostgreSQL.h +++ b/src/TableFunctions/TableFunctionPostgreSQL.h @@ -28,7 +28,7 @@ private: void parseArguments(const ASTPtr & ast_function, ContextPtr context) override; String connection_str; - String remote_table_name, remote_table_schema; + String remote_table_name, remote_table_schema, on_conflict; postgres::PoolWithFailoverPtr connection_pool; }; diff --git a/tests/integration/test_storage_postgresql/test.py b/tests/integration/test_storage_postgresql/test.py index 28a76631c0f..bb0e284eac9 100644 --- a/tests/integration/test_storage_postgresql/test.py +++ b/tests/integration/test_storage_postgresql/test.py @@ -291,7 +291,7 @@ def test_postgres_distributed(started_cluster): node2.query('DROP TABLE test_shards') node2.query('DROP TABLE test_replicas') - + def test_datetime_with_timezone(started_cluster): cursor = started_cluster.postgres_conn.cursor() cursor.execute("DROP TABLE IF EXISTS test_timezone") @@ -328,6 +328,32 @@ def test_postgres_ndim(started_cluster): cursor.execute("DROP TABLE arr1, arr2") +def test_postgres_on_conflict(started_cluster): + cursor = started_cluster.postgres_conn.cursor() + table = 'test_conflict' + cursor.execute(f'DROP TABLE IF EXISTS {table}') + cursor.execute(f'CREATE TABLE {table} (a integer PRIMARY KEY, b text, c integer)') + + node1.query(''' + CREATE TABLE test_conflict (a UInt32, b String, c Int32) + ENGINE PostgreSQL('postgres1:5432', 'postgres', 'test_conflict', 'postgres', 'mysecretpassword', '', 'ON CONFLICT DO NOTHING'); + ''') + node1.query(f''' INSERT INTO {table} SELECT number, concat('name_', toString(number)), 3 from numbers(100)''') + node1.query(f''' INSERT INTO {table} SELECT number, concat('name_', toString(number)), 4 from numbers(100)''') + + check1 = f"SELECT count() FROM {table}" + assert (node1.query(check1)).rstrip() == '100' + + table_func = f'''postgresql('{started_cluster.postgres_ip}:{started_cluster.postgres_port}', 'postgres', '{table}', 'postgres', 'mysecretpassword', '', 'ON CONFLICT DO NOTHING')''' + node1.query(f'''INSERT INTO TABLE FUNCTION {table_func} SELECT number, concat('name_', toString(number)), 3 from numbers(100)''') + node1.query(f'''INSERT INTO TABLE FUNCTION {table_func} SELECT number, concat('name_', toString(number)), 3 from numbers(100)''') + + check1 = f"SELECT count() FROM {table}" + assert (node1.query(check1)).rstrip() == '100' + + cursor.execute(f'DROP TABLE {table} ') + + if __name__ == '__main__': cluster.start() input("Cluster created, press any key to destroy...")