Merge pull request #28081 from kssenii/pg-conflict

Support `on conflict` for postgres engine
This commit is contained in:
Kseniia Sumarokova 2021-08-26 16:30:30 +03:00 committed by GitHub
commit 31afd7d09c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 177 additions and 54 deletions

View File

@ -22,6 +22,7 @@ set (SRCS
"${LIBRARY_DIR}/src/transaction.cxx"
"${LIBRARY_DIR}/src/transaction_base.cxx"
"${LIBRARY_DIR}/src/row.cxx"
"${LIBRARY_DIR}/src/params.cxx"
"${LIBRARY_DIR}/src/util.cxx"
"${LIBRARY_DIR}/src/version.cxx"
)
@ -31,6 +32,7 @@ set (SRCS
# conflicts with all includes of <array>.
set (HDRS
"${LIBRARY_DIR}/include/pqxx/array.hxx"
"${LIBRARY_DIR}/include/pqxx/params.hxx"
"${LIBRARY_DIR}/include/pqxx/binarystring.hxx"
"${LIBRARY_DIR}/include/pqxx/composite.hxx"
"${LIBRARY_DIR}/include/pqxx/connection.hxx"
@ -75,4 +77,3 @@ set(CM_CONFIG_PQ "${LIBRARY_DIR}/include/pqxx/config-internal-libpq.h")
configure_file("${CM_CONFIG_H_IN}" "${CM_CONFIG_INT}" @ONLY)
configure_file("${CM_CONFIG_H_IN}" "${CM_CONFIG_PUB}" @ONLY)
configure_file("${CM_CONFIG_H_IN}" "${CM_CONFIG_PQ}" @ONLY)

View File

@ -34,6 +34,7 @@ The table structure can differ from the original PostgreSQL table structure:
- `user` — PostgreSQL user.
- `password` — User password.
- `schema` — Non-default table schema. Optional.
- `on conflict ...` — example: `ON CONFLICT DO NOTHING`. Optional. Note: adding this option will make insertion less efficient.
## Implementation Details {#implementation-details}

View File

@ -5,40 +5,16 @@
#include <Core/Field.h>
#include <common/LocalDate.h>
#include <common/LocalDateTime.h>
#include <Parsers/ASTInsertQuery.h>
#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTIdentifier.h>
#include "getIdentifierQuote.h"
#include <IO/WriteHelpers.h>
#include <IO/Operators.h>
#include <Formats/FormatFactory.h>
#include <Parsers/getInsertQuery.h>
namespace DB
{
namespace
{
using ValueType = ExternalResultDescription::ValueType;
std::string getInsertQuery(const std::string & db_name, const std::string & table_name, const ColumnsWithTypeAndName & columns, IdentifierQuotingStyle quoting)
{
ASTInsertQuery query;
query.table_id.database_name = db_name;
query.table_id.table_name = table_name;
query.columns = std::make_shared<ASTExpressionList>(',');
query.children.push_back(query.columns);
for (const auto & column : columns)
query.columns->children.emplace_back(std::make_shared<ASTIdentifier>(column.name));
WriteBufferFromOwnString buf;
IAST::FormatSettings settings(buf, true);
settings.always_quote_identifiers = true;
settings.identifier_quoting_style = quoting;
query.IAST::format(settings);
return buf.str();
}
}
ODBCBlockOutputStream::ODBCBlockOutputStream(nanodbc::ConnectionHolderPtr connection_holder_,
const std::string & remote_database_name_,

View File

@ -13,6 +13,7 @@ namespace DB
class ODBCBlockOutputStream : public IBlockOutputStream
{
using ValueType = ExternalResultDescription::ValueType;
public:
ODBCBlockOutputStream(

View File

@ -0,0 +1,28 @@
#include <Parsers/getInsertQuery.h>
#include <Parsers/ASTInsertQuery.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTExpressionList.h>
#include <IO/WriteBufferFromString.h>
namespace DB
{
std::string getInsertQuery(const std::string & db_name, const std::string & table_name, const ColumnsWithTypeAndName & columns, IdentifierQuotingStyle quoting)
{
ASTInsertQuery query;
query.table_id.database_name = db_name;
query.table_id.table_name = table_name;
query.columns = std::make_shared<ASTExpressionList>(',');
query.children.push_back(query.columns);
for (const auto & column : columns)
query.columns->children.emplace_back(std::make_shared<ASTIdentifier>(column.name));
WriteBufferFromOwnString buf;
IAST::FormatSettings settings(buf, true);
settings.always_quote_identifiers = true;
settings.identifier_quoting_style = quoting;
query.IAST::format(settings);
return buf.str();
}
}

View File

@ -0,0 +1,8 @@
#pragma once
#include <Core/ColumnsWithTypeAndName.h>
#include <Parsers/IdentifierQuotingStyle.h>
namespace DB
{
std::string getInsertQuery(const std::string & db_name, const std::string & table_name, const ColumnsWithTypeAndName & columns, IdentifierQuotingStyle quoting);
}

View File

@ -29,6 +29,8 @@
#include <Processors/Pipe.h>
#include <Processors/Sinks/SinkToStorage.h>
#include <IO/WriteHelpers.h>
#include <Parsers/getInsertQuery.h>
#include <IO/Operators.h>
namespace DB
@ -47,10 +49,12 @@ StoragePostgreSQL::StoragePostgreSQL(
const ColumnsDescription & columns_,
const ConstraintsDescription & constraints_,
const String & comment,
const String & remote_table_schema_)
const String & remote_table_schema_,
const String & on_conflict_)
: IStorage(table_id_)
, remote_table_name(remote_table_name_)
, remote_table_schema(remote_table_schema_)
, on_conflict(on_conflict_)
, pool(std::move(pool_))
{
StorageInMemoryMetadata storage_metadata;
@ -94,17 +98,22 @@ Pipe StoragePostgreSQL::read(
class PostgreSQLSink : public SinkToStorage
{
using Row = std::vector<std::optional<std::string>>;
public:
explicit PostgreSQLSink(
const StorageMetadataPtr & metadata_snapshot_,
postgres::ConnectionHolderPtr connection_holder_,
const String & remote_table_name_,
const String & remote_table_schema_)
const String & remote_table_schema_,
const String & on_conflict_)
: SinkToStorage(metadata_snapshot_->getSampleBlock())
, metadata_snapshot(metadata_snapshot_)
, connection_holder(std::move(connection_holder_))
, remote_table_name(remote_table_name_)
, remote_table_schema(remote_table_schema_)
, on_conflict(on_conflict_)
{
}
@ -113,11 +122,21 @@ public:
void consume(Chunk chunk) override
{
auto block = getPort().getHeader().cloneWithColumns(chunk.detachColumns());
if (!inserter)
inserter = std::make_unique<StreamTo>(connection_holder->get(),
remote_table_schema.empty() ? pqxx::table_path({remote_table_name})
: pqxx::table_path({remote_table_schema, remote_table_name}),
block.getNames());
{
if (on_conflict.empty())
{
inserter = std::make_unique<StreamTo>(connection_holder->get(),
remote_table_schema.empty() ? pqxx::table_path({remote_table_name})
: pqxx::table_path({remote_table_schema, remote_table_name}), block.getNames());
}
else
{
inserter = std::make_unique<PreparedInsert>(connection_holder->get(), remote_table_name,
remote_table_schema, block.getColumnsWithTypeAndName(), on_conflict);
}
}
const auto columns = block.getColumns();
const size_t num_rows = block.rows(), num_cols = block.columns();
@ -151,7 +170,7 @@ public:
}
}
inserter->stream.write_values(row);
inserter->insert(row);
}
}
@ -268,37 +287,92 @@ public:
}
private:
struct StreamTo
struct Inserter
{
pqxx::connection & connection;
pqxx::work tx;
explicit Inserter(pqxx::connection & connection_)
: connection(connection_)
, tx(connection) {}
virtual ~Inserter() = default;
virtual void insert(const Row & row) = 0;
virtual void complete() = 0;
};
struct StreamTo : Inserter
{
Names columns;
pqxx::stream_to stream;
StreamTo(pqxx::connection & connection, pqxx::table_path table_, Names columns_)
: tx(connection)
StreamTo(pqxx::connection & connection_, pqxx::table_path table_, Names columns_)
: Inserter(connection_)
, columns(std::move(columns_))
, stream(pqxx::stream_to::raw_table(tx, connection.quote_table(table_), connection.quote_columns(columns)))
{
}
void complete()
void complete() override
{
stream.complete();
tx.commit();
}
void insert(const Row & row) override
{
stream.write_values(row);
}
};
struct PreparedInsert : Inserter
{
PreparedInsert(pqxx::connection & connection_, const String & table, const String & schema,
const ColumnsWithTypeAndName & columns, const String & on_conflict_)
: Inserter(connection_)
{
WriteBufferFromOwnString buf;
buf << getInsertQuery(schema, table, columns, IdentifierQuotingStyle::DoubleQuotes);
buf << " (";
for (size_t i = 1; i <= columns.size(); ++i)
{
if (i > 1)
buf << ", ";
buf << "$" << i;
}
buf << ") ";
buf << on_conflict_;
connection.prepare("insert", buf.str());
}
void complete() override
{
connection.unprepare("insert");
tx.commit();
}
void insert(const Row & row) override
{
pqxx::params params;
params.reserve(row.size());
params.append_multi(row);
tx.exec_prepared("insert", params);
}
};
StorageMetadataPtr metadata_snapshot;
postgres::ConnectionHolderPtr connection_holder;
const String remote_table_name, remote_table_schema;
std::unique_ptr<StreamTo> inserter;
const String remote_db_name, remote_table_name, remote_table_schema, on_conflict;
std::unique_ptr<Inserter> inserter;
};
SinkToStoragePtr StoragePostgreSQL::write(
const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr /* context */)
{
return std::make_shared<PostgreSQLSink>(metadata_snapshot, pool->get(), remote_table_name, remote_table_schema);
return std::make_shared<PostgreSQLSink>(metadata_snapshot, pool->get(), remote_table_name, remote_table_schema, on_conflict);
}
@ -308,9 +382,9 @@ void registerStoragePostgreSQL(StorageFactory & factory)
{
ASTs & engine_args = args.engine_args;
if (engine_args.size() < 5 || engine_args.size() > 6)
throw Exception("Storage PostgreSQL requires from 5 to 6 parameters: "
"PostgreSQL('host:port', 'database', 'table', 'username', 'password' [, 'schema']",
if (engine_args.size() < 5 || engine_args.size() > 7)
throw Exception("Storage PostgreSQL requires from 5 to 7 parameters: "
"PostgreSQL('host:port', 'database', 'table', 'username', 'password' [, 'schema', 'ON CONFLICT ...']",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
for (auto & engine_arg : engine_args)
@ -326,9 +400,11 @@ void registerStoragePostgreSQL(StorageFactory & factory)
const String & username = engine_args[3]->as<ASTLiteral &>().value.safeGet<String>();
const String & password = engine_args[4]->as<ASTLiteral &>().value.safeGet<String>();
String remote_table_schema;
if (engine_args.size() == 6)
String remote_table_schema, on_conflict;
if (engine_args.size() >= 6)
remote_table_schema = engine_args[5]->as<ASTLiteral &>().value.safeGet<String>();
if (engine_args.size() >= 7)
on_conflict = engine_args[6]->as<ASTLiteral &>().value.safeGet<String>();
auto pool = std::make_shared<postgres::PoolWithFailover>(
remote_database,
@ -345,7 +421,8 @@ void registerStoragePostgreSQL(StorageFactory & factory)
args.columns,
args.constraints,
args.comment,
remote_table_schema);
remote_table_schema,
on_conflict);
},
{
.source_access_type = AccessType::POSTGRES,

View File

@ -27,7 +27,8 @@ public:
const ColumnsDescription & columns_,
const ConstraintsDescription & constraints_,
const String & comment,
const std::string & remote_table_schema_ = "");
const String & remote_table_schema_ = "",
const String & on_conflict = "");
String getName() const override { return "PostgreSQL"; }
@ -47,6 +48,7 @@ private:
String remote_table_name;
String remote_table_schema;
String on_conflict;
postgres::PoolWithFailoverPtr pool;
};

View File

@ -37,7 +37,8 @@ StoragePtr TableFunctionPostgreSQL::executeImpl(const ASTPtr & /*ast_function*/,
columns,
ConstraintsDescription{},
String{},
remote_table_schema);
remote_table_schema,
on_conflict);
result->startup();
return result;
@ -67,9 +68,9 @@ void TableFunctionPostgreSQL::parseArguments(const ASTPtr & ast_function, Contex
ASTs & args = func_args.arguments->children;
if (args.size() < 5 || args.size() > 6)
throw Exception("Table function 'PostgreSQL' requires from 5 to 6 parameters: "
"PostgreSQL('host:port', 'database', 'table', 'user', 'password', [, 'schema']).",
if (args.size() < 5 || args.size() > 7)
throw Exception("Table function 'PostgreSQL' requires from 5 to 7 parameters: "
"PostgreSQL('host:port', 'database', 'table', 'user', 'password', [, 'schema', 'ON CONFLICT ...']).",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
for (auto & arg : args)
@ -82,8 +83,10 @@ void TableFunctionPostgreSQL::parseArguments(const ASTPtr & ast_function, Contex
remote_table_name = args[2]->as<ASTLiteral &>().value.safeGet<String>();
if (args.size() == 6)
if (args.size() >= 6)
remote_table_schema = args[5]->as<ASTLiteral &>().value.safeGet<String>();
if (args.size() >= 7)
on_conflict = args[6]->as<ASTLiteral &>().value.safeGet<String>();
connection_pool = std::make_shared<postgres::PoolWithFailover>(
args[1]->as<ASTLiteral &>().value.safeGet<String>(),

View File

@ -28,7 +28,7 @@ private:
void parseArguments(const ASTPtr & ast_function, ContextPtr context) override;
String connection_str;
String remote_table_name, remote_table_schema;
String remote_table_name, remote_table_schema, on_conflict;
postgres::PoolWithFailoverPtr connection_pool;
};

View File

@ -328,6 +328,32 @@ def test_postgres_ndim(started_cluster):
cursor.execute("DROP TABLE arr1, arr2")
def test_postgres_on_conflict(started_cluster):
cursor = started_cluster.postgres_conn.cursor()
table = 'test_conflict'
cursor.execute(f'DROP TABLE IF EXISTS {table}')
cursor.execute(f'CREATE TABLE {table} (a integer PRIMARY KEY, b text, c integer)')
node1.query('''
CREATE TABLE test_conflict (a UInt32, b String, c Int32)
ENGINE PostgreSQL('postgres1:5432', 'postgres', 'test_conflict', 'postgres', 'mysecretpassword', '', 'ON CONFLICT DO NOTHING');
''')
node1.query(f''' INSERT INTO {table} SELECT number, concat('name_', toString(number)), 3 from numbers(100)''')
node1.query(f''' INSERT INTO {table} SELECT number, concat('name_', toString(number)), 4 from numbers(100)''')
check1 = f"SELECT count() FROM {table}"
assert (node1.query(check1)).rstrip() == '100'
table_func = f'''postgresql('{started_cluster.postgres_ip}:{started_cluster.postgres_port}', 'postgres', '{table}', 'postgres', 'mysecretpassword', '', 'ON CONFLICT DO NOTHING')'''
node1.query(f'''INSERT INTO TABLE FUNCTION {table_func} SELECT number, concat('name_', toString(number)), 3 from numbers(100)''')
node1.query(f'''INSERT INTO TABLE FUNCTION {table_func} SELECT number, concat('name_', toString(number)), 3 from numbers(100)''')
check1 = f"SELECT count() FROM {table}"
assert (node1.query(check1)).rstrip() == '100'
cursor.execute(f'DROP TABLE {table} ')
if __name__ == '__main__':
cluster.start()
input("Cluster created, press any key to destroy...")