Merge pull request #28081 from kssenii/pg-conflict

Support `on conflict` for postgres engine
This commit is contained in:
Kseniia Sumarokova 2021-08-26 16:30:30 +03:00 committed by GitHub
commit 31afd7d09c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 177 additions and 54 deletions

View File

@ -22,6 +22,7 @@ set (SRCS
"${LIBRARY_DIR}/src/transaction.cxx" "${LIBRARY_DIR}/src/transaction.cxx"
"${LIBRARY_DIR}/src/transaction_base.cxx" "${LIBRARY_DIR}/src/transaction_base.cxx"
"${LIBRARY_DIR}/src/row.cxx" "${LIBRARY_DIR}/src/row.cxx"
"${LIBRARY_DIR}/src/params.cxx"
"${LIBRARY_DIR}/src/util.cxx" "${LIBRARY_DIR}/src/util.cxx"
"${LIBRARY_DIR}/src/version.cxx" "${LIBRARY_DIR}/src/version.cxx"
) )
@ -31,6 +32,7 @@ set (SRCS
# conflicts with all includes of <array>. # conflicts with all includes of <array>.
set (HDRS set (HDRS
"${LIBRARY_DIR}/include/pqxx/array.hxx" "${LIBRARY_DIR}/include/pqxx/array.hxx"
"${LIBRARY_DIR}/include/pqxx/params.hxx"
"${LIBRARY_DIR}/include/pqxx/binarystring.hxx" "${LIBRARY_DIR}/include/pqxx/binarystring.hxx"
"${LIBRARY_DIR}/include/pqxx/composite.hxx" "${LIBRARY_DIR}/include/pqxx/composite.hxx"
"${LIBRARY_DIR}/include/pqxx/connection.hxx" "${LIBRARY_DIR}/include/pqxx/connection.hxx"
@ -75,4 +77,3 @@ set(CM_CONFIG_PQ "${LIBRARY_DIR}/include/pqxx/config-internal-libpq.h")
configure_file("${CM_CONFIG_H_IN}" "${CM_CONFIG_INT}" @ONLY) configure_file("${CM_CONFIG_H_IN}" "${CM_CONFIG_INT}" @ONLY)
configure_file("${CM_CONFIG_H_IN}" "${CM_CONFIG_PUB}" @ONLY) configure_file("${CM_CONFIG_H_IN}" "${CM_CONFIG_PUB}" @ONLY)
configure_file("${CM_CONFIG_H_IN}" "${CM_CONFIG_PQ}" @ONLY) configure_file("${CM_CONFIG_H_IN}" "${CM_CONFIG_PQ}" @ONLY)

View File

@ -34,6 +34,7 @@ The table structure can differ from the original PostgreSQL table structure:
- `user` — PostgreSQL user. - `user` — PostgreSQL user.
- `password` — User password. - `password` — User password.
- `schema` — Non-default table schema. Optional. - `schema` — Non-default table schema. Optional.
- `on conflict ...` — example: `ON CONFLICT DO NOTHING`. Optional. Note: adding this option will make insertion less efficient.
## Implementation Details {#implementation-details} ## Implementation Details {#implementation-details}

View File

@ -5,40 +5,16 @@
#include <Core/Field.h> #include <Core/Field.h>
#include <common/LocalDate.h> #include <common/LocalDate.h>
#include <common/LocalDateTime.h> #include <common/LocalDateTime.h>
#include <Parsers/ASTInsertQuery.h>
#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTIdentifier.h>
#include "getIdentifierQuote.h" #include "getIdentifierQuote.h"
#include <IO/WriteHelpers.h> #include <IO/WriteHelpers.h>
#include <IO/Operators.h> #include <IO/Operators.h>
#include <Formats/FormatFactory.h> #include <Formats/FormatFactory.h>
#include <Parsers/getInsertQuery.h>
namespace DB namespace DB
{ {
namespace
{
using ValueType = ExternalResultDescription::ValueType;
std::string getInsertQuery(const std::string & db_name, const std::string & table_name, const ColumnsWithTypeAndName & columns, IdentifierQuotingStyle quoting)
{
ASTInsertQuery query;
query.table_id.database_name = db_name;
query.table_id.table_name = table_name;
query.columns = std::make_shared<ASTExpressionList>(',');
query.children.push_back(query.columns);
for (const auto & column : columns)
query.columns->children.emplace_back(std::make_shared<ASTIdentifier>(column.name));
WriteBufferFromOwnString buf;
IAST::FormatSettings settings(buf, true);
settings.always_quote_identifiers = true;
settings.identifier_quoting_style = quoting;
query.IAST::format(settings);
return buf.str();
}
}
ODBCBlockOutputStream::ODBCBlockOutputStream(nanodbc::ConnectionHolderPtr connection_holder_, ODBCBlockOutputStream::ODBCBlockOutputStream(nanodbc::ConnectionHolderPtr connection_holder_,
const std::string & remote_database_name_, const std::string & remote_database_name_,

View File

@ -13,6 +13,7 @@ namespace DB
class ODBCBlockOutputStream : public IBlockOutputStream class ODBCBlockOutputStream : public IBlockOutputStream
{ {
using ValueType = ExternalResultDescription::ValueType;
public: public:
ODBCBlockOutputStream( ODBCBlockOutputStream(

View File

@ -0,0 +1,28 @@
#include <Parsers/getInsertQuery.h>
#include <Parsers/ASTInsertQuery.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTExpressionList.h>
#include <IO/WriteBufferFromString.h>
namespace DB
{
std::string getInsertQuery(const std::string & db_name, const std::string & table_name, const ColumnsWithTypeAndName & columns, IdentifierQuotingStyle quoting)
{
ASTInsertQuery query;
query.table_id.database_name = db_name;
query.table_id.table_name = table_name;
query.columns = std::make_shared<ASTExpressionList>(',');
query.children.push_back(query.columns);
for (const auto & column : columns)
query.columns->children.emplace_back(std::make_shared<ASTIdentifier>(column.name));
WriteBufferFromOwnString buf;
IAST::FormatSettings settings(buf, true);
settings.always_quote_identifiers = true;
settings.identifier_quoting_style = quoting;
query.IAST::format(settings);
return buf.str();
}
}

View File

@ -0,0 +1,8 @@
#pragma once
#include <Core/ColumnsWithTypeAndName.h>
#include <Parsers/IdentifierQuotingStyle.h>
namespace DB
{
std::string getInsertQuery(const std::string & db_name, const std::string & table_name, const ColumnsWithTypeAndName & columns, IdentifierQuotingStyle quoting);
}

View File

@ -29,6 +29,8 @@
#include <Processors/Pipe.h> #include <Processors/Pipe.h>
#include <Processors/Sinks/SinkToStorage.h> #include <Processors/Sinks/SinkToStorage.h>
#include <IO/WriteHelpers.h> #include <IO/WriteHelpers.h>
#include <Parsers/getInsertQuery.h>
#include <IO/Operators.h>
namespace DB namespace DB
@ -47,10 +49,12 @@ StoragePostgreSQL::StoragePostgreSQL(
const ColumnsDescription & columns_, const ColumnsDescription & columns_,
const ConstraintsDescription & constraints_, const ConstraintsDescription & constraints_,
const String & comment, const String & comment,
const String & remote_table_schema_) const String & remote_table_schema_,
const String & on_conflict_)
: IStorage(table_id_) : IStorage(table_id_)
, remote_table_name(remote_table_name_) , remote_table_name(remote_table_name_)
, remote_table_schema(remote_table_schema_) , remote_table_schema(remote_table_schema_)
, on_conflict(on_conflict_)
, pool(std::move(pool_)) , pool(std::move(pool_))
{ {
StorageInMemoryMetadata storage_metadata; StorageInMemoryMetadata storage_metadata;
@ -94,17 +98,22 @@ Pipe StoragePostgreSQL::read(
class PostgreSQLSink : public SinkToStorage class PostgreSQLSink : public SinkToStorage
{ {
using Row = std::vector<std::optional<std::string>>;
public: public:
explicit PostgreSQLSink( explicit PostgreSQLSink(
const StorageMetadataPtr & metadata_snapshot_, const StorageMetadataPtr & metadata_snapshot_,
postgres::ConnectionHolderPtr connection_holder_, postgres::ConnectionHolderPtr connection_holder_,
const String & remote_table_name_, const String & remote_table_name_,
const String & remote_table_schema_) const String & remote_table_schema_,
const String & on_conflict_)
: SinkToStorage(metadata_snapshot_->getSampleBlock()) : SinkToStorage(metadata_snapshot_->getSampleBlock())
, metadata_snapshot(metadata_snapshot_) , metadata_snapshot(metadata_snapshot_)
, connection_holder(std::move(connection_holder_)) , connection_holder(std::move(connection_holder_))
, remote_table_name(remote_table_name_) , remote_table_name(remote_table_name_)
, remote_table_schema(remote_table_schema_) , remote_table_schema(remote_table_schema_)
, on_conflict(on_conflict_)
{ {
} }
@ -113,11 +122,21 @@ public:
void consume(Chunk chunk) override void consume(Chunk chunk) override
{ {
auto block = getPort().getHeader().cloneWithColumns(chunk.detachColumns()); auto block = getPort().getHeader().cloneWithColumns(chunk.detachColumns());
if (!inserter) if (!inserter)
inserter = std::make_unique<StreamTo>(connection_holder->get(), {
remote_table_schema.empty() ? pqxx::table_path({remote_table_name}) if (on_conflict.empty())
: pqxx::table_path({remote_table_schema, remote_table_name}), {
block.getNames()); inserter = std::make_unique<StreamTo>(connection_holder->get(),
remote_table_schema.empty() ? pqxx::table_path({remote_table_name})
: pqxx::table_path({remote_table_schema, remote_table_name}), block.getNames());
}
else
{
inserter = std::make_unique<PreparedInsert>(connection_holder->get(), remote_table_name,
remote_table_schema, block.getColumnsWithTypeAndName(), on_conflict);
}
}
const auto columns = block.getColumns(); const auto columns = block.getColumns();
const size_t num_rows = block.rows(), num_cols = block.columns(); const size_t num_rows = block.rows(), num_cols = block.columns();
@ -151,7 +170,7 @@ public:
} }
} }
inserter->stream.write_values(row); inserter->insert(row);
} }
} }
@ -268,37 +287,92 @@ public:
} }
private: private:
struct StreamTo struct Inserter
{ {
pqxx::connection & connection;
pqxx::work tx; pqxx::work tx;
explicit Inserter(pqxx::connection & connection_)
: connection(connection_)
, tx(connection) {}
virtual ~Inserter() = default;
virtual void insert(const Row & row) = 0;
virtual void complete() = 0;
};
struct StreamTo : Inserter
{
Names columns; Names columns;
pqxx::stream_to stream; pqxx::stream_to stream;
StreamTo(pqxx::connection & connection, pqxx::table_path table_, Names columns_) StreamTo(pqxx::connection & connection_, pqxx::table_path table_, Names columns_)
: tx(connection) : Inserter(connection_)
, columns(std::move(columns_)) , columns(std::move(columns_))
, stream(pqxx::stream_to::raw_table(tx, connection.quote_table(table_), connection.quote_columns(columns))) , stream(pqxx::stream_to::raw_table(tx, connection.quote_table(table_), connection.quote_columns(columns)))
{ {
} }
void complete() void complete() override
{ {
stream.complete(); stream.complete();
tx.commit(); tx.commit();
} }
void insert(const Row & row) override
{
stream.write_values(row);
}
};
struct PreparedInsert : Inserter
{
PreparedInsert(pqxx::connection & connection_, const String & table, const String & schema,
const ColumnsWithTypeAndName & columns, const String & on_conflict_)
: Inserter(connection_)
{
WriteBufferFromOwnString buf;
buf << getInsertQuery(schema, table, columns, IdentifierQuotingStyle::DoubleQuotes);
buf << " (";
for (size_t i = 1; i <= columns.size(); ++i)
{
if (i > 1)
buf << ", ";
buf << "$" << i;
}
buf << ") ";
buf << on_conflict_;
connection.prepare("insert", buf.str());
}
void complete() override
{
connection.unprepare("insert");
tx.commit();
}
void insert(const Row & row) override
{
pqxx::params params;
params.reserve(row.size());
params.append_multi(row);
tx.exec_prepared("insert", params);
}
}; };
StorageMetadataPtr metadata_snapshot; StorageMetadataPtr metadata_snapshot;
postgres::ConnectionHolderPtr connection_holder; postgres::ConnectionHolderPtr connection_holder;
const String remote_table_name, remote_table_schema; const String remote_db_name, remote_table_name, remote_table_schema, on_conflict;
std::unique_ptr<StreamTo> inserter;
std::unique_ptr<Inserter> inserter;
}; };
SinkToStoragePtr StoragePostgreSQL::write( SinkToStoragePtr StoragePostgreSQL::write(
const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr /* context */) const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr /* context */)
{ {
return std::make_shared<PostgreSQLSink>(metadata_snapshot, pool->get(), remote_table_name, remote_table_schema); return std::make_shared<PostgreSQLSink>(metadata_snapshot, pool->get(), remote_table_name, remote_table_schema, on_conflict);
} }
@ -308,9 +382,9 @@ void registerStoragePostgreSQL(StorageFactory & factory)
{ {
ASTs & engine_args = args.engine_args; ASTs & engine_args = args.engine_args;
if (engine_args.size() < 5 || engine_args.size() > 6) if (engine_args.size() < 5 || engine_args.size() > 7)
throw Exception("Storage PostgreSQL requires from 5 to 6 parameters: " throw Exception("Storage PostgreSQL requires from 5 to 7 parameters: "
"PostgreSQL('host:port', 'database', 'table', 'username', 'password' [, 'schema']", "PostgreSQL('host:port', 'database', 'table', 'username', 'password' [, 'schema', 'ON CONFLICT ...']",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
for (auto & engine_arg : engine_args) for (auto & engine_arg : engine_args)
@ -326,9 +400,11 @@ void registerStoragePostgreSQL(StorageFactory & factory)
const String & username = engine_args[3]->as<ASTLiteral &>().value.safeGet<String>(); const String & username = engine_args[3]->as<ASTLiteral &>().value.safeGet<String>();
const String & password = engine_args[4]->as<ASTLiteral &>().value.safeGet<String>(); const String & password = engine_args[4]->as<ASTLiteral &>().value.safeGet<String>();
String remote_table_schema; String remote_table_schema, on_conflict;
if (engine_args.size() == 6) if (engine_args.size() >= 6)
remote_table_schema = engine_args[5]->as<ASTLiteral &>().value.safeGet<String>(); remote_table_schema = engine_args[5]->as<ASTLiteral &>().value.safeGet<String>();
if (engine_args.size() >= 7)
on_conflict = engine_args[6]->as<ASTLiteral &>().value.safeGet<String>();
auto pool = std::make_shared<postgres::PoolWithFailover>( auto pool = std::make_shared<postgres::PoolWithFailover>(
remote_database, remote_database,
@ -345,7 +421,8 @@ void registerStoragePostgreSQL(StorageFactory & factory)
args.columns, args.columns,
args.constraints, args.constraints,
args.comment, args.comment,
remote_table_schema); remote_table_schema,
on_conflict);
}, },
{ {
.source_access_type = AccessType::POSTGRES, .source_access_type = AccessType::POSTGRES,

View File

@ -27,7 +27,8 @@ public:
const ColumnsDescription & columns_, const ColumnsDescription & columns_,
const ConstraintsDescription & constraints_, const ConstraintsDescription & constraints_,
const String & comment, const String & comment,
const std::string & remote_table_schema_ = ""); const String & remote_table_schema_ = "",
const String & on_conflict = "");
String getName() const override { return "PostgreSQL"; } String getName() const override { return "PostgreSQL"; }
@ -47,6 +48,7 @@ private:
String remote_table_name; String remote_table_name;
String remote_table_schema; String remote_table_schema;
String on_conflict;
postgres::PoolWithFailoverPtr pool; postgres::PoolWithFailoverPtr pool;
}; };

View File

@ -37,7 +37,8 @@ StoragePtr TableFunctionPostgreSQL::executeImpl(const ASTPtr & /*ast_function*/,
columns, columns,
ConstraintsDescription{}, ConstraintsDescription{},
String{}, String{},
remote_table_schema); remote_table_schema,
on_conflict);
result->startup(); result->startup();
return result; return result;
@ -67,9 +68,9 @@ void TableFunctionPostgreSQL::parseArguments(const ASTPtr & ast_function, Contex
ASTs & args = func_args.arguments->children; ASTs & args = func_args.arguments->children;
if (args.size() < 5 || args.size() > 6) if (args.size() < 5 || args.size() > 7)
throw Exception("Table function 'PostgreSQL' requires from 5 to 6 parameters: " throw Exception("Table function 'PostgreSQL' requires from 5 to 7 parameters: "
"PostgreSQL('host:port', 'database', 'table', 'user', 'password', [, 'schema']).", "PostgreSQL('host:port', 'database', 'table', 'user', 'password', [, 'schema', 'ON CONFLICT ...']).",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
for (auto & arg : args) for (auto & arg : args)
@ -82,8 +83,10 @@ void TableFunctionPostgreSQL::parseArguments(const ASTPtr & ast_function, Contex
remote_table_name = args[2]->as<ASTLiteral &>().value.safeGet<String>(); remote_table_name = args[2]->as<ASTLiteral &>().value.safeGet<String>();
if (args.size() == 6) if (args.size() >= 6)
remote_table_schema = args[5]->as<ASTLiteral &>().value.safeGet<String>(); remote_table_schema = args[5]->as<ASTLiteral &>().value.safeGet<String>();
if (args.size() >= 7)
on_conflict = args[6]->as<ASTLiteral &>().value.safeGet<String>();
connection_pool = std::make_shared<postgres::PoolWithFailover>( connection_pool = std::make_shared<postgres::PoolWithFailover>(
args[1]->as<ASTLiteral &>().value.safeGet<String>(), args[1]->as<ASTLiteral &>().value.safeGet<String>(),

View File

@ -28,7 +28,7 @@ private:
void parseArguments(const ASTPtr & ast_function, ContextPtr context) override; void parseArguments(const ASTPtr & ast_function, ContextPtr context) override;
String connection_str; String connection_str;
String remote_table_name, remote_table_schema; String remote_table_name, remote_table_schema, on_conflict;
postgres::PoolWithFailoverPtr connection_pool; postgres::PoolWithFailoverPtr connection_pool;
}; };

View File

@ -291,7 +291,7 @@ def test_postgres_distributed(started_cluster):
node2.query('DROP TABLE test_shards') node2.query('DROP TABLE test_shards')
node2.query('DROP TABLE test_replicas') node2.query('DROP TABLE test_replicas')
def test_datetime_with_timezone(started_cluster): def test_datetime_with_timezone(started_cluster):
cursor = started_cluster.postgres_conn.cursor() cursor = started_cluster.postgres_conn.cursor()
cursor.execute("DROP TABLE IF EXISTS test_timezone") cursor.execute("DROP TABLE IF EXISTS test_timezone")
@ -328,6 +328,32 @@ def test_postgres_ndim(started_cluster):
cursor.execute("DROP TABLE arr1, arr2") cursor.execute("DROP TABLE arr1, arr2")
def test_postgres_on_conflict(started_cluster):
cursor = started_cluster.postgres_conn.cursor()
table = 'test_conflict'
cursor.execute(f'DROP TABLE IF EXISTS {table}')
cursor.execute(f'CREATE TABLE {table} (a integer PRIMARY KEY, b text, c integer)')
node1.query('''
CREATE TABLE test_conflict (a UInt32, b String, c Int32)
ENGINE PostgreSQL('postgres1:5432', 'postgres', 'test_conflict', 'postgres', 'mysecretpassword', '', 'ON CONFLICT DO NOTHING');
''')
node1.query(f''' INSERT INTO {table} SELECT number, concat('name_', toString(number)), 3 from numbers(100)''')
node1.query(f''' INSERT INTO {table} SELECT number, concat('name_', toString(number)), 4 from numbers(100)''')
check1 = f"SELECT count() FROM {table}"
assert (node1.query(check1)).rstrip() == '100'
table_func = f'''postgresql('{started_cluster.postgres_ip}:{started_cluster.postgres_port}', 'postgres', '{table}', 'postgres', 'mysecretpassword', '', 'ON CONFLICT DO NOTHING')'''
node1.query(f'''INSERT INTO TABLE FUNCTION {table_func} SELECT number, concat('name_', toString(number)), 3 from numbers(100)''')
node1.query(f'''INSERT INTO TABLE FUNCTION {table_func} SELECT number, concat('name_', toString(number)), 3 from numbers(100)''')
check1 = f"SELECT count() FROM {table}"
assert (node1.query(check1)).rstrip() == '100'
cursor.execute(f'DROP TABLE {table} ')
if __name__ == '__main__': if __name__ == '__main__':
cluster.start() cluster.start()
input("Cluster created, press any key to destroy...") input("Cluster created, press any key to destroy...")