Merge branch 'master' into system-metrics

This commit is contained in:
Alexey Milovidov 2021-07-05 03:03:07 +03:00
commit d437ec2e1d
139 changed files with 10310 additions and 843 deletions

View File

@ -0,0 +1,41 @@
#include <functional>
/** Adapt functor to static method where functor passed as context.
* Main use case to convert lambda into function that can be passed into JIT code.
*/
template <typename Functor>
class FunctorToStaticMethodAdaptor : public FunctorToStaticMethodAdaptor<decltype(&Functor::operator())>
{
};
template <typename R, typename C, typename ...Args>
class FunctorToStaticMethodAdaptor<R (C::*)(Args...) const>
{
public:
static R call(C * ptr, Args &&... arguments)
{
return std::invoke(&C::operator(), ptr, std::forward<Args>(arguments)...);
}
static R unsafeCall(char * ptr, Args &&... arguments)
{
C * ptr_typed = reinterpret_cast<C*>(ptr);
return std::invoke(&C::operator(), ptr_typed, std::forward<Args>(arguments)...);
}
};
template <typename R, typename C, typename ...Args>
class FunctorToStaticMethodAdaptor<R (C::*)(Args...)>
{
public:
static R call(C * ptr, Args &&... arguments)
{
return std::invoke(&C::operator(), ptr, std::forward<Args>(arguments)...);
}
static R unsafeCall(char * ptr, Args &&... arguments)
{
C * ptr_typed = static_cast<C*>(ptr);
return std::invoke(&C::operator(), ptr_typed, std::forward<Args>(arguments)...);
}
};

2
contrib/libunwind vendored

@ -1 +1 @@
Subproject commit cdcc3d8c6f6e80a0886082704a0902d61d8d3ffe Subproject commit 6b816d2fba3991f8fd6aaec17d92f68947eab667

View File

@ -373,12 +373,6 @@ function run_tests
# Depends on AWS # Depends on AWS
01801_s3_cluster 01801_s3_cluster
# Depends on LLVM JIT
01072_nullable_jit
01852_jit_if
01865_jit_comparison_constant_result
01871_merge_tree_compile_expressions
# needs psql # needs psql
01889_postgresql_protocol_null_fields 01889_postgresql_protocol_null_fields

View File

@ -11,6 +11,7 @@ services:
interval: 10s interval: 10s
timeout: 5s timeout: 5s
retries: 5 retries: 5
command: [ "postgres", "-c", "wal_level=logical", "-c", "max_replication_slots=2"]
networks: networks:
default: default:
aliases: aliases:
@ -22,4 +23,4 @@ services:
volumes: volumes:
- type: ${POSTGRES_LOGS_FS:-tmpfs} - type: ${POSTGRES_LOGS_FS:-tmpfs}
source: ${POSTGRES_DIR:-} source: ${POSTGRES_DIR:-}
target: /postgres/ target: /postgres/

View File

@ -0,0 +1,71 @@
---
toc_priority: 30
toc_title: MaterializedPostgreSQL
---
# MaterializedPostgreSQL {#materialize-postgresql}
## Creating a Database {#creating-a-database}
``` sql
CREATE DATABASE test_database
ENGINE = MaterializedPostgreSQL('postgres1:5432', 'postgres_database', 'postgres_user', 'postgres_password'
SELECT * FROM test_database.postgres_table;
```
## Settings {#settings}
1. `materialized_postgresql_max_block_size` - Number of rows collected before flushing data into table. Default: `65536`.
2. `materialized_postgresql_tables_list` - List of tables for MaterializedPostgreSQL database engine. Default: `whole database`.
3. `materialized_postgresql_allow_automatic_update` - Allow to reload table in the background, when schema changes are detected. Default: `0` (`false`).
``` sql
CREATE DATABASE test_database
ENGINE = MaterializedPostgreSQL('postgres1:5432', 'postgres_database', 'postgres_user', 'postgres_password'
SETTINGS materialized_postgresql_max_block_size = 65536,
materialized_postgresql_tables_list = 'table1,table2,table3';
SELECT * FROM test_database.table1;
```
## Requirements {#requirements}
- Setting `wal_level`to `logical` and `max_replication_slots` to at least `2` in the postgresql config file.
- Each replicated table must have one of the following **replica identity**:
1. **default** (primary key)
2. **index**
``` bash
postgres# CREATE TABLE postgres_table (a Integer NOT NULL, b Integer, c Integer NOT NULL, d Integer, e Integer NOT NULL);
postgres# CREATE unique INDEX postgres_table_index on postgres_table(a, c, e);
postgres# ALTER TABLE postgres_table REPLICA IDENTITY USING INDEX postgres_table_index;
```
Primary key is always checked first. If it is absent, then index, defined as replica identity index, is checked.
If index is used as replica identity, there has to be only one such index in a table.
You can check what type is used for a specific table with the following command:
``` bash
postgres# SELECT CASE relreplident
WHEN 'd' THEN 'default'
WHEN 'n' THEN 'nothing'
WHEN 'f' THEN 'full'
WHEN 'i' THEN 'index'
END AS replica_identity
FROM pg_class
WHERE oid = 'postgres_table'::regclass;
```
## Warning {#warning}
1. **TOAST** values convertion is not supported. Default value for the data type will be used.

View File

@ -0,0 +1,46 @@
---
toc_priority: 12
toc_title: MateriaziePostgreSQL
---
# MaterializedPostgreSQL {#materialize-postgresql}
## Creating a Table {#creating-a-table}
``` sql
CREATE TABLE test.postgresql_replica (key UInt64, value UInt64)
ENGINE = MaterializedPostgreSQL('postgres1:5432', 'postgres_database', 'postgresql_replica', 'postgres_user', 'postgres_password')
PRIMARY KEY key;
```
## Requirements {#requirements}
- Setting `wal_level`to `logical` and `max_replication_slots` to at least `2` in the postgresql config file.
- A table with engine `MaterializedPostgreSQL` must have a primary key - the same as a replica identity index (default: primary key) of a postgres table (See [details on replica identity index](../../database-engines/materialized-postgresql.md#requirements)).
- Only database `Atomic` is allowed.
## Virtual columns {#creating-a-table}
- `_version` (`UInt64`)
- `_sign` (`Int8`)
These columns do not need to be added, when table is created. They are always accessible in `SELECT` query.
`_version` column equals `LSN` position in `WAL`, so it might be used to check how up-to-date replication is.
``` sql
CREATE TABLE test.postgresql_replica (key UInt64, value UInt64)
ENGINE = MaterializedPostgreSQL('postgres1:5432', 'postgres_database', 'postgresql_replica', 'postgres_user', 'postgres_password')
PRIMARY KEY key;
SELECT key, value, _version FROM test.postgresql_replica;
```
## Warning {#warning}
1. **TOAST** values convertion is not supported. Default value for the data type will be used.

View File

@ -50,7 +50,7 @@
#include <Interpreters/DNSCacheUpdater.h> #include <Interpreters/DNSCacheUpdater.h>
#include <Interpreters/ExternalLoaderXMLConfigRepository.h> #include <Interpreters/ExternalLoaderXMLConfigRepository.h>
#include <Interpreters/InterserverCredentials.h> #include <Interpreters/InterserverCredentials.h>
#include <Interpreters/ExpressionJIT.h> #include <Interpreters/JIT/CompiledExpressionCache.h>
#include <Access/AccessControlManager.h> #include <Access/AccessControlManager.h>
#include <Storages/StorageReplicatedMergeTree.h> #include <Storages/StorageReplicatedMergeTree.h>
#include <Storages/System/attachSystemTables.h> #include <Storages/System/attachSystemTables.h>

View File

@ -44,6 +44,7 @@
--table-header-color: #F8F8F8; --table-header-color: #F8F8F8;
--table-hover-color: #FFF8EF; --table-hover-color: #FFF8EF;
--null-color: #A88; --null-color: #A88;
--link-color: #06D;
} }
[data-theme="dark"] { [data-theme="dark"] {
@ -61,6 +62,7 @@
--table-header-color: #102020; --table-header-color: #102020;
--table-hover-color: #003333; --table-hover-color: #003333;
--null-color: #A88; --null-color: #A88;
--link-color: #4BDAF7;
} }
html, body html, body
@ -275,6 +277,12 @@
font-size: 110%; font-size: 110%;
color: #080; color: #080;
} }
a, a:visited
{
color: var(--link-color);
text-decoration: none;
}
</style> </style>
</head> </head>
@ -482,6 +490,7 @@
let cell = response.data[row_idx][col_idx]; let cell = response.data[row_idx][col_idx];
let is_null = (cell === null); let is_null = (cell === null);
let is_link = false;
/// Test: SELECT number, toString(number) AS str, number % 2 ? number : NULL AS nullable, range(number) AS arr, CAST((['hello', 'world'], [number, number % 2]) AS Map(String, UInt64)) AS map FROM numbers(10) /// Test: SELECT number, toString(number) AS str, number % 2 ? number : NULL AS nullable, range(number) AS arr, CAST((['hello', 'world'], [number, number % 2]) AS Map(String, UInt64)) AS map FROM numbers(10)
let text; let text;
@ -491,9 +500,23 @@
text = JSON.stringify(cell); text = JSON.stringify(cell);
} else { } else {
text = cell; text = cell;
/// If it looks like URL, create a link. This is for convenience.
if (typeof(cell) == 'string' && cell.match(/^https?:\/\/\S+$/)) {
is_link = true;
}
} }
td.appendChild(document.createTextNode(text)); let node = document.createTextNode(text);
if (is_link) {
let link = document.createElement('a');
link.appendChild(node);
link.href = text;
link.setAttribute('target', '_blank');
node = link;
}
td.appendChild(node);
td.className = column_classes[col_idx]; td.className = column_classes[col_idx];
if (is_null) { if (is_null) {
td.className += ' null'; td.className += ' null';

View File

@ -9,6 +9,14 @@
#include <AggregateFunctions/IAggregateFunction.h> #include <AggregateFunctions/IAggregateFunction.h>
#include <Core/DecimalFunctions.h> #include <Core/DecimalFunctions.h>
#if !defined(ARCADIA_BUILD)
# include <Common/config.h>
#endif
#if USE_EMBEDDED_COMPILER
# include <llvm/IR/IRBuilder.h>
# include <DataTypes/Native.h>
#endif
namespace DB namespace DB
{ {
@ -85,13 +93,15 @@ struct AvgFraction
* @tparam Derived When deriving from this class, use the child class name as in CRTP, e.g. * @tparam Derived When deriving from this class, use the child class name as in CRTP, e.g.
* class Self : Agg<char, bool, bool, Self>. * class Self : Agg<char, bool, bool, Self>.
*/ */
template <typename Numerator, typename Denominator, typename Derived> template <typename TNumerator, typename TDenominator, typename Derived>
class AggregateFunctionAvgBase : public class AggregateFunctionAvgBase : public
IAggregateFunctionDataHelper<AvgFraction<Numerator, Denominator>, Derived> IAggregateFunctionDataHelper<AvgFraction<TNumerator, TDenominator>, Derived>
{ {
public: public:
using Base = IAggregateFunctionDataHelper<AvgFraction<TNumerator, TDenominator>, Derived>;
using Numerator = TNumerator;
using Denominator = TDenominator;
using Fraction = AvgFraction<Numerator, Denominator>; using Fraction = AvgFraction<Numerator, Denominator>;
using Base = IAggregateFunctionDataHelper<Fraction, Derived>;
explicit AggregateFunctionAvgBase(const DataTypes & argument_types_, explicit AggregateFunctionAvgBase(const DataTypes & argument_types_,
UInt32 num_scale_ = 0, UInt32 denom_scale_ = 0) UInt32 num_scale_ = 0, UInt32 denom_scale_ = 0)
@ -135,6 +145,77 @@ public:
else else
assert_cast<ColumnVector<Float64> &>(to).getData().push_back(this->data(place).divide()); assert_cast<ColumnVector<Float64> &>(to).getData().push_back(this->data(place).divide());
} }
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
bool can_be_compiled = true;
for (const auto & argument : this->argument_types)
can_be_compiled &= canBeNativeType(*argument);
auto return_type = getReturnType();
can_be_compiled &= canBeNativeType(*return_type);
return can_be_compiled;
}
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
b.CreateMemSet(aggregate_data_ptr, llvm::ConstantInt::get(b.getInt8Ty(), 0), sizeof(Fraction), llvm::assumeAligned(this->alignOfData()));
}
void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * numerator_type = toNativeType<Numerator>(b);
auto * numerator_dst_ptr = b.CreatePointerCast(aggregate_data_dst_ptr, numerator_type->getPointerTo());
auto * numerator_dst_value = b.CreateLoad(numerator_type, numerator_dst_ptr);
auto * numerator_src_ptr = b.CreatePointerCast(aggregate_data_src_ptr, numerator_type->getPointerTo());
auto * numerator_src_value = b.CreateLoad(numerator_type, numerator_src_ptr);
auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_dst_value, numerator_src_value) : b.CreateFAdd(numerator_dst_value, numerator_src_value);
b.CreateStore(numerator_result_value, numerator_dst_ptr);
auto * denominator_type = toNativeType<Denominator>(b);
static constexpr size_t denominator_offset = offsetof(Fraction, denominator);
auto * denominator_dst_ptr = b.CreatePointerCast(b.CreateConstGEP1_32(nullptr, aggregate_data_dst_ptr, denominator_offset), denominator_type->getPointerTo());
auto * denominator_src_ptr = b.CreatePointerCast(b.CreateConstGEP1_32(nullptr, aggregate_data_src_ptr, denominator_offset), denominator_type->getPointerTo());
auto * denominator_dst_value = b.CreateLoad(denominator_type, denominator_dst_ptr);
auto * denominator_src_value = b.CreateLoad(denominator_type, denominator_src_ptr);
auto * denominator_result_value = denominator_type->isIntegerTy() ? b.CreateAdd(denominator_src_value, denominator_dst_value) : b.CreateFAdd(denominator_src_value, denominator_dst_value);
b.CreateStore(denominator_result_value, denominator_dst_ptr);
}
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * numerator_type = toNativeType<Numerator>(b);
auto * numerator_ptr = b.CreatePointerCast(aggregate_data_ptr, numerator_type->getPointerTo());
auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr);
auto * denominator_type = toNativeType<Denominator>(b);
static constexpr size_t denominator_offset = offsetof(Fraction, denominator);
auto * denominator_ptr = b.CreatePointerCast(b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, denominator_offset), denominator_type->getPointerTo());
auto * denominator_value = b.CreateLoad(denominator_type, denominator_ptr);
auto * double_numerator = nativeCast<Numerator>(b, numerator_value, b.getDoubleTy());
auto * double_denominator = nativeCast<Denominator>(b, denominator_value, b.getDoubleTy());
return b.CreateFDiv(double_numerator, double_denominator);
}
#endif
private: private:
UInt32 num_scale; UInt32 num_scale;
UInt32 denom_scale; UInt32 denom_scale;
@ -149,7 +230,12 @@ template <typename T>
class AggregateFunctionAvg final : public AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionAvg<T>> class AggregateFunctionAvg final : public AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionAvg<T>>
{ {
public: public:
using AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionAvg<T>>::AggregateFunctionAvgBase; using Base = AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionAvg<T>>;
using Base::Base;
using Numerator = typename Base::Numerator;
using Denominator = typename Base::Denominator;
using Fraction = typename Base::Fraction;
void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const final void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const final
{ {
@ -158,5 +244,29 @@ public:
} }
String getName() const final { return "avg"; } String getName() const final { return "avg"; }
#if USE_EMBEDDED_COMPILER
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * numerator_type = toNativeType<Numerator>(b);
auto * numerator_ptr = b.CreatePointerCast(aggregate_data_ptr, numerator_type->getPointerTo());
auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr);
auto * value_cast_to_numerator = nativeCast(b, arguments_types[0], argument_values[0], numerator_type);
auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_value, value_cast_to_numerator) : b.CreateFAdd(numerator_value, value_cast_to_numerator);
b.CreateStore(numerator_result_value, numerator_ptr);
auto * denominator_type = toNativeType<Denominator>(b);
static constexpr size_t denominator_offset = offsetof(Fraction, denominator);
auto * denominator_ptr = b.CreatePointerCast(b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, denominator_offset), denominator_type->getPointerTo());
auto * denominator_value_updated = b.CreateAdd(b.CreateLoad(denominator_type, denominator_ptr), llvm::ConstantInt::get(denominator_type, 1));
b.CreateStore(denominator_value_updated, denominator_ptr);
}
#endif
}; };
} }

View File

@ -28,19 +28,64 @@ public:
MaxFieldType<Value, Weight>, AvgWeightedFieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>; MaxFieldType<Value, Weight>, AvgWeightedFieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>;
using Base::Base; using Base::Base;
using ValueT = MaxFieldType<Value, Weight>; using Numerator = typename Base::Numerator;
using Denominator = typename Base::Denominator;
using Fraction = typename Base::Fraction;
void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{ {
const auto& weights = static_cast<const DecimalOrVectorCol<Weight> &>(*columns[1]); const auto& weights = static_cast<const DecimalOrVectorCol<Weight> &>(*columns[1]);
this->data(place).numerator += static_cast<ValueT>( this->data(place).numerator += static_cast<Numerator>(
static_cast<const DecimalOrVectorCol<Value> &>(*columns[0]).getData()[row_num]) * static_cast<const DecimalOrVectorCol<Value> &>(*columns[0]).getData()[row_num]) *
static_cast<ValueT>(weights.getData()[row_num]); static_cast<Numerator>(weights.getData()[row_num]);
this->data(place).denominator += static_cast<AvgWeightedFieldType<Weight>>(weights.getData()[row_num]); this->data(place).denominator += static_cast<Denominator>(weights.getData()[row_num]);
} }
String getName() const override { return "avgWeighted"; } String getName() const override { return "avgWeighted"; }
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
bool can_be_compiled = Base::isCompilable();
can_be_compiled &= canBeNativeType<Weight>();
return can_be_compiled;
}
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * numerator_type = toNativeType<Numerator>(b);
auto * numerator_ptr = b.CreatePointerCast(aggregate_data_ptr, numerator_type->getPointerTo());
auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr);
auto * argument = nativeCast(b, arguments_types[0], argument_values[0], numerator_type);
auto * weight = nativeCast(b, arguments_types[1], argument_values[1], numerator_type);
llvm::Value * value_weight_multiplication = argument->getType()->isIntegerTy() ? b.CreateMul(argument, weight) : b.CreateFMul(argument, weight);
auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_value, value_weight_multiplication) : b.CreateFAdd(numerator_value, value_weight_multiplication);
b.CreateStore(numerator_result_value, numerator_ptr);
auto * denominator_type = toNativeType<Denominator>(b);
static constexpr size_t denominator_offset = offsetof(Fraction, denominator);
auto * denominator_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, denominator_offset);
auto * denominator_ptr = b.CreatePointerCast(denominator_offset_ptr, denominator_type->getPointerTo());
auto * weight_cast_to_denominator = nativeCast(b, arguments_types[1], argument_values[1], denominator_type);
auto * denominator_value = b.CreateLoad(denominator_type, denominator_ptr);
auto * denominator_value_updated = denominator_type->isIntegerTy() ? b.CreateAdd(denominator_value, weight_cast_to_denominator) : b.CreateFAdd(denominator_value, weight_cast_to_denominator);
b.CreateStore(denominator_value_updated, denominator_ptr);
}
#endif
}; };
} }

View File

@ -10,6 +10,15 @@
#include <AggregateFunctions/IAggregateFunction.h> #include <AggregateFunctions/IAggregateFunction.h>
#include <Common/assert_cast.h> #include <Common/assert_cast.h>
#if !defined(ARCADIA_BUILD)
# include <Common/config.h>
#endif
#if USE_EMBEDDED_COMPILER
# include <llvm/IR/IRBuilder.h>
# include <DataTypes/Native.h>
#endif
namespace DB namespace DB
{ {
@ -107,6 +116,66 @@ public:
AggregateFunctionPtr getOwnNullAdapter( AggregateFunctionPtr getOwnNullAdapter(
const AggregateFunctionPtr &, const DataTypes & types, const Array & params, const AggregateFunctionProperties & /*properties*/) const override; const AggregateFunctionPtr &, const DataTypes & types, const Array & params, const AggregateFunctionProperties & /*properties*/) const override;
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
bool is_compilable = true;
for (const auto & argument_type : argument_types)
is_compilable &= canBeNativeType(*argument_type);
return is_compilable;
}
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
b.CreateMemSet(aggregate_data_ptr, llvm::ConstantInt::get(b.getInt8Ty(), 0), sizeof(AggregateFunctionCountData), llvm::assumeAligned(this->alignOfData()));
}
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes &, const std::vector<llvm::Value *> &) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * count_value_ptr = b.CreatePointerCast(aggregate_data_ptr, return_type->getPointerTo());
auto * count_value = b.CreateLoad(return_type, count_value_ptr);
auto * updated_count_value = b.CreateAdd(count_value, llvm::ConstantInt::get(return_type, 1));
b.CreateStore(updated_count_value, count_value_ptr);
}
void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * count_value_dst_ptr = b.CreatePointerCast(aggregate_data_dst_ptr, return_type->getPointerTo());
auto * count_value_dst = b.CreateLoad(return_type, count_value_dst_ptr);
auto * count_value_src_ptr = b.CreatePointerCast(aggregate_data_src_ptr, return_type->getPointerTo());
auto * count_value_src = b.CreateLoad(return_type, count_value_src_ptr);
auto * count_value_dst_updated = b.CreateAdd(count_value_dst, count_value_src);
b.CreateStore(count_value_dst_updated, count_value_dst_ptr);
}
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * count_value_ptr = b.CreatePointerCast(aggregate_data_ptr, return_type->getPointerTo());
return b.CreateLoad(return_type, count_value_ptr);
}
#endif
}; };
@ -155,6 +224,71 @@ public:
{ {
assert_cast<ColumnUInt64 &>(to).getData().push_back(data(place).count); assert_cast<ColumnUInt64 &>(to).getData().push_back(data(place).count);
} }
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
bool is_compilable = true;
for (const auto & argument_type : argument_types)
is_compilable &= canBeNativeType(*argument_type);
return is_compilable;
}
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
b.CreateMemSet(aggregate_data_ptr, llvm::ConstantInt::get(b.getInt8Ty(), 0), sizeof(AggregateFunctionCountData), llvm::assumeAligned(this->alignOfData()));
}
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes &, const std::vector<llvm::Value *> & values) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * is_null_value = b.CreateExtractValue(values[0], {1});
auto * increment_value = b.CreateSelect(is_null_value, llvm::ConstantInt::get(return_type, 0), llvm::ConstantInt::get(return_type, 1));
auto * count_value_ptr = b.CreatePointerCast(aggregate_data_ptr, return_type->getPointerTo());
auto * count_value = b.CreateLoad(return_type, count_value_ptr);
auto * updated_count_value = b.CreateAdd(count_value, increment_value);
b.CreateStore(updated_count_value, count_value_ptr);
}
void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * count_value_dst_ptr = b.CreatePointerCast(aggregate_data_dst_ptr, return_type->getPointerTo());
auto * count_value_dst = b.CreateLoad(return_type, count_value_dst_ptr);
auto * count_value_src_ptr = b.CreatePointerCast(aggregate_data_src_ptr, return_type->getPointerTo());
auto * count_value_src = b.CreateLoad(return_type, count_value_src_ptr);
auto * count_value_dst_updated = b.CreateAdd(count_value_dst, count_value_src);
b.CreateStore(count_value_dst_updated, count_value_dst_ptr);
}
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * count_value_ptr = b.CreatePointerCast(aggregate_data_ptr, return_type->getPointerTo());
return b.CreateLoad(return_type, count_value_ptr);
}
#endif
}; };
} }

View File

@ -106,6 +106,48 @@ public:
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena); this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);
} }
} }
#if USE_EMBEDDED_COMPILER
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
const auto & nullable_type = arguments_types[0];
const auto & nullable_value = argument_values[0];
auto * wrapped_value = b.CreateExtractValue(nullable_value, {0});
auto * is_null_value = b.CreateExtractValue(nullable_value, {1});
const auto & predicate_type = arguments_types[argument_values.size() - 1];
auto * predicate_value = argument_values[argument_values.size() - 1];
auto * is_predicate_true = nativeBoolCast(b, predicate_type, predicate_value);
auto * head = b.GetInsertBlock();
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
auto * if_null = llvm::BasicBlock::Create(head->getContext(), "if_null", head->getParent());
auto * if_not_null = llvm::BasicBlock::Create(head->getContext(), "if_not_null", head->getParent());
b.CreateCondBr(b.CreateAnd(b.CreateNot(is_null_value), is_predicate_true), if_not_null, if_null);
b.SetInsertPoint(if_null);
b.CreateBr(join_block);
b.SetInsertPoint(if_not_null);
if constexpr (result_is_nullable)
b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr);
auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, this->prefix_size);
this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, { removeNullable(nullable_type) }, { wrapped_value });
b.CreateBr(join_block);
b.SetInsertPoint(join_block);
}
#endif
}; };
template <bool result_is_nullable, bool serialize_flag, bool null_is_skipped> template <bool result_is_nullable, bool serialize_flag, bool null_is_skipped>
@ -168,6 +210,95 @@ public:
} }
} }
#if USE_EMBEDDED_COMPILER
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
{
/// TODO: Check
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
size_t arguments_size = arguments_types.size();
DataTypes non_nullable_types;
std::vector<llvm::Value * > wrapped_values;
std::vector<llvm::Value * > is_null_values;
non_nullable_types.resize(arguments_size);
wrapped_values.resize(arguments_size);
is_null_values.resize(arguments_size);
for (size_t i = 0; i < arguments_size; ++i)
{
const auto & argument_value = argument_values[i];
if (is_nullable[i])
{
auto * wrapped_value = b.CreateExtractValue(argument_value, {0});
if constexpr (null_is_skipped)
is_null_values[i] = b.CreateExtractValue(argument_value, {1});
wrapped_values[i] = wrapped_value;
non_nullable_types[i] = removeNullable(arguments_types[i]);
}
else
{
wrapped_values[i] = argument_value;
non_nullable_types[i] = arguments_types[i];
}
}
auto * head = b.GetInsertBlock();
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
auto * join_block_after_null_checks = llvm::BasicBlock::Create(head->getContext(), "join_block_after_null_checks", head->getParent());
if constexpr (null_is_skipped)
{
auto * values_have_null_ptr = b.CreateAlloca(b.getInt1Ty());
b.CreateStore(b.getInt1(false), values_have_null_ptr);
for (auto * is_null_value : is_null_values)
{
if (!is_null_value)
continue;
auto * values_have_null = b.CreateLoad(b.getInt1Ty(), values_have_null_ptr);
b.CreateStore(b.CreateOr(values_have_null, is_null_value), values_have_null_ptr);
}
b.CreateCondBr(b.CreateLoad(b.getInt1Ty(), values_have_null_ptr), join_block, join_block_after_null_checks);
}
b.SetInsertPoint(join_block_after_null_checks);
const auto & predicate_type = arguments_types[argument_values.size() - 1];
auto * predicate_value = argument_values[argument_values.size() - 1];
auto * is_predicate_true = nativeBoolCast(b, predicate_type, predicate_value);
auto * if_true = llvm::BasicBlock::Create(head->getContext(), "if_true", head->getParent());
auto * if_false = llvm::BasicBlock::Create(head->getContext(), "if_false", head->getParent());
b.CreateCondBr(is_predicate_true, if_true, if_false);
b.SetInsertPoint(if_false);
b.CreateBr(join_block);
b.SetInsertPoint(if_true);
if constexpr (result_is_nullable)
b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr);
auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, this->prefix_size);
this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, non_nullable_types, wrapped_values);
b.CreateBr(join_block);
b.SetInsertPoint(join_block);
}
#endif
private: private:
using Base = AggregateFunctionNullBase<result_is_nullable, serialize_flag, using Base = AggregateFunctionNullBase<result_is_nullable, serialize_flag,
AggregateFunctionIfNullVariadic<result_is_nullable, serialize_flag, null_is_skipped>>; AggregateFunctionIfNullVariadic<result_is_nullable, serialize_flag, null_is_skipped>>;

View File

@ -5,6 +5,14 @@
#include <Common/assert_cast.h> #include <Common/assert_cast.h>
#include <AggregateFunctions/IAggregateFunction.h> #include <AggregateFunctions/IAggregateFunction.h>
#if !defined(ARCADIA_BUILD)
# include <Common/config.h>
#endif
#if USE_EMBEDDED_COMPILER
# include <llvm/IR/IRBuilder.h>
# include <DataTypes/Native.h>
#endif
namespace DB namespace DB
{ {
@ -154,6 +162,76 @@ public:
const Array & params, const AggregateFunctionProperties & properties) const override; const Array & params, const AggregateFunctionProperties & properties) const override;
AggregateFunctionPtr getNestedFunction() const override { return nested_func; } AggregateFunctionPtr getNestedFunction() const override { return nested_func; }
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
return nested_func->isCompilable();
}
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
nested_func->compileCreate(builder, aggregate_data_ptr);
}
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
const auto & predicate_type = arguments_types[argument_values.size() - 1];
auto * predicate_value = argument_values[argument_values.size() - 1];
auto * head = b.GetInsertBlock();
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
auto * if_true = llvm::BasicBlock::Create(head->getContext(), "if_true", head->getParent());
auto * if_false = llvm::BasicBlock::Create(head->getContext(), "if_false", head->getParent());
auto * is_predicate_true = nativeBoolCast(b, predicate_type, predicate_value);
b.CreateCondBr(is_predicate_true, if_true, if_false);
b.SetInsertPoint(if_true);
size_t arguments_size_without_predicate = arguments_types.size() - 1;
DataTypes argument_types_without_predicate;
std::vector<llvm::Value *> argument_values_without_predicate;
argument_types_without_predicate.resize(arguments_size_without_predicate);
argument_values_without_predicate.resize(arguments_size_without_predicate);
for (size_t i = 0; i < arguments_size_without_predicate; ++i)
{
argument_types_without_predicate[i] = arguments_types[i];
argument_values_without_predicate[i] = argument_values[i];
}
nested_func->compileAdd(builder, aggregate_data_ptr, argument_types_without_predicate, argument_values_without_predicate);
b.CreateBr(join_block);
b.SetInsertPoint(if_false);
b.CreateBr(join_block);
b.SetInsertPoint(join_block);
}
void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
{
nested_func->compileMerge(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
}
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
return nested_func->compileGetResult(builder, aggregate_data_ptr);
}
#endif
}; };
} }

View File

@ -7,11 +7,20 @@
#include <Columns/ColumnDecimal.h> #include <Columns/ColumnDecimal.h>
#include <Columns/ColumnString.h> #include <Columns/ColumnString.h>
#include <DataTypes/IDataType.h> #include <DataTypes/IDataType.h>
#include <DataTypes/DataTypesNumber.h>
#include <common/StringRef.h> #include <common/StringRef.h>
#include <Common/assert_cast.h> #include <Common/assert_cast.h>
#include <AggregateFunctions/IAggregateFunction.h> #include <AggregateFunctions/IAggregateFunction.h>
#if !defined(ARCADIA_BUILD)
# include <Common/config.h>
#endif
#if USE_EMBEDDED_COMPILER
# include <llvm/IR/IRBuilder.h>
# include <DataTypes/Native.h>
#endif
namespace DB namespace DB
{ {
@ -20,6 +29,7 @@ struct Settings;
namespace ErrorCodes namespace ErrorCodes
{ {
extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NOT_IMPLEMENTED;
} }
/** Aggregate functions that store one of passed values. /** Aggregate functions that store one of passed values.
@ -177,6 +187,265 @@ public:
{ {
return false; return false;
} }
#if USE_EMBEDDED_COMPILER
static constexpr bool is_compilable = true;
static llvm::Value * getValuePtrFromAggregateDataPtr(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr)
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
static constexpr size_t value_offset_from_structure = offsetof(SingleValueDataFixed<T>, value);
auto * type = toNativeType<T>(builder);
auto * value_ptr_with_offset = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, value_offset_from_structure);
auto * value_ptr = b.CreatePointerCast(value_ptr_with_offset, type->getPointerTo());
return value_ptr;
}
static llvm::Value * getValueFromAggregateDataPtr(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr)
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * type = toNativeType<T>(builder);
auto * value_ptr = getValuePtrFromAggregateDataPtr(builder, aggregate_data_ptr);
return b.CreateLoad(type, value_ptr);
}
static void compileChange(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check)
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * has_value_ptr = b.CreatePointerCast(aggregate_data_ptr, b.getInt1Ty()->getPointerTo());
b.CreateStore(b.getInt1(true), has_value_ptr);
auto * value_ptr = getValuePtrFromAggregateDataPtr(b, aggregate_data_ptr);
b.CreateStore(value_to_check, value_ptr);
}
static void compileChangeMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr)
{
auto * value_src = getValueFromAggregateDataPtr(builder, aggregate_data_src_ptr);
compileChange(builder, aggregate_data_dst_ptr, value_src);
}
static void compileChangeFirstTime(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check)
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * has_value_ptr = b.CreatePointerCast(aggregate_data_ptr, b.getInt1Ty()->getPointerTo());
auto * has_value_value = b.CreateLoad(b.getInt1Ty(), has_value_ptr);
auto * head = b.GetInsertBlock();
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
auto * if_should_change = llvm::BasicBlock::Create(head->getContext(), "if_should_change", head->getParent());
auto * if_should_not_change = llvm::BasicBlock::Create(head->getContext(), "if_should_not_change", head->getParent());
b.CreateCondBr(has_value_value, if_should_not_change, if_should_change);
b.SetInsertPoint(if_should_not_change);
b.CreateBr(join_block);
b.SetInsertPoint(if_should_change);
compileChange(builder, aggregate_data_ptr, value_to_check);
b.CreateBr(join_block);
b.SetInsertPoint(join_block);
}
static void compileChangeFirstTimeMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr)
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * has_value_dst_ptr = b.CreatePointerCast(aggregate_data_dst_ptr, b.getInt1Ty()->getPointerTo());
auto * has_value_dst = b.CreateLoad(b.getInt1Ty(), has_value_dst_ptr);
auto * has_value_src_ptr = b.CreatePointerCast(aggregate_data_src_ptr, b.getInt1Ty()->getPointerTo());
auto * has_value_src = b.CreateLoad(b.getInt1Ty(), has_value_src_ptr);
auto * head = b.GetInsertBlock();
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
auto * if_should_change = llvm::BasicBlock::Create(head->getContext(), "if_should_change", head->getParent());
auto * if_should_not_change = llvm::BasicBlock::Create(head->getContext(), "if_should_not_change", head->getParent());
b.CreateCondBr(b.CreateAnd(b.CreateNot(has_value_dst), has_value_src), if_should_change, if_should_not_change);
b.SetInsertPoint(if_should_change);
compileChangeMerge(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
b.CreateBr(join_block);
b.SetInsertPoint(if_should_not_change);
b.CreateBr(join_block);
b.SetInsertPoint(join_block);
}
static void compileChangeEveryTime(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check)
{
compileChange(builder, aggregate_data_ptr, value_to_check);
}
static void compileChangeEveryTimeMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr)
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * has_value_src_ptr = b.CreatePointerCast(aggregate_data_src_ptr, b.getInt1Ty()->getPointerTo());
auto * has_value_src = b.CreateLoad(b.getInt1Ty(), has_value_src_ptr);
auto * head = b.GetInsertBlock();
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
auto * if_should_change = llvm::BasicBlock::Create(head->getContext(), "if_should_change", head->getParent());
auto * if_should_not_change = llvm::BasicBlock::Create(head->getContext(), "if_should_not_change", head->getParent());
b.CreateCondBr(has_value_src, if_should_change, if_should_not_change);
b.SetInsertPoint(if_should_change);
compileChangeMerge(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
b.CreateBr(join_block);
b.SetInsertPoint(if_should_not_change);
b.CreateBr(join_block);
b.SetInsertPoint(join_block);
}
template <bool is_less>
static void compileChangeComparison(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check)
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * has_value_ptr = b.CreatePointerCast(aggregate_data_ptr, b.getInt1Ty()->getPointerTo());
auto * has_value_value = b.CreateLoad(b.getInt1Ty(), has_value_ptr);
auto * value = getValueFromAggregateDataPtr(b, aggregate_data_ptr);
auto * head = b.GetInsertBlock();
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
auto * if_should_change = llvm::BasicBlock::Create(head->getContext(), "if_should_change", head->getParent());
auto * if_should_not_change = llvm::BasicBlock::Create(head->getContext(), "if_should_not_change", head->getParent());
auto is_signed = std::numeric_limits<T>::is_signed;
llvm::Value * should_change_after_comparison = nullptr;
if constexpr (is_less)
{
if (value_to_check->getType()->isIntegerTy())
should_change_after_comparison = is_signed ? b.CreateICmpSLT(value_to_check, value) : b.CreateICmpULT(value_to_check, value);
else
should_change_after_comparison = b.CreateFCmpOLT(value_to_check, value);
}
else
{
if (value_to_check->getType()->isIntegerTy())
should_change_after_comparison = is_signed ? b.CreateICmpSGT(value_to_check, value) : b.CreateICmpUGT(value_to_check, value);
else
should_change_after_comparison = b.CreateFCmpOGT(value_to_check, value);
}
b.CreateCondBr(b.CreateOr(b.CreateNot(has_value_value), should_change_after_comparison), if_should_change, if_should_not_change);
b.SetInsertPoint(if_should_change);
compileChange(builder, aggregate_data_ptr, value_to_check);
b.CreateBr(join_block);
b.SetInsertPoint(if_should_not_change);
b.CreateBr(join_block);
b.SetInsertPoint(join_block);
}
template <bool is_less>
static void compileChangeComparisonMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr)
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * has_value_dst_ptr = b.CreatePointerCast(aggregate_data_dst_ptr, b.getInt1Ty()->getPointerTo());
auto * has_value_dst = b.CreateLoad(b.getInt1Ty(), has_value_dst_ptr);
auto * value_dst = getValueFromAggregateDataPtr(b, aggregate_data_dst_ptr);
auto * has_value_src_ptr = b.CreatePointerCast(aggregate_data_src_ptr, b.getInt1Ty()->getPointerTo());
auto * has_value_src = b.CreateLoad(b.getInt1Ty(), has_value_src_ptr);
auto * value_src = getValueFromAggregateDataPtr(b, aggregate_data_src_ptr);
auto * head = b.GetInsertBlock();
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
auto * if_should_change = llvm::BasicBlock::Create(head->getContext(), "if_should_change", head->getParent());
auto * if_should_not_change = llvm::BasicBlock::Create(head->getContext(), "if_should_not_change", head->getParent());
auto is_signed = std::numeric_limits<T>::is_signed;
llvm::Value * should_change_after_comparison = nullptr;
if constexpr (is_less)
{
if (value_src->getType()->isIntegerTy())
should_change_after_comparison = is_signed ? b.CreateICmpSLT(value_src, value_dst) : b.CreateICmpULT(value_src, value_dst);
else
should_change_after_comparison = b.CreateFCmpOLT(value_src, value_dst);
}
else
{
if (value_src->getType()->isIntegerTy())
should_change_after_comparison = is_signed ? b.CreateICmpSGT(value_src, value_dst) : b.CreateICmpUGT(value_src, value_dst);
else
should_change_after_comparison = b.CreateFCmpOGT(value_src, value_dst);
}
b.CreateCondBr(b.CreateAnd(has_value_src, b.CreateOr(b.CreateNot(has_value_dst), should_change_after_comparison)), if_should_change, if_should_not_change);
b.SetInsertPoint(if_should_change);
compileChangeMerge(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
b.CreateBr(join_block);
b.SetInsertPoint(if_should_not_change);
b.CreateBr(join_block);
b.SetInsertPoint(join_block);
}
static void compileChangeIfLess(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check)
{
static constexpr bool is_less = true;
compileChangeComparison<is_less>(builder, aggregate_data_ptr, value_to_check);
}
static void compileChangeIfLessMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr)
{
static constexpr bool is_less = true;
compileChangeComparisonMerge<is_less>(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
}
static void compileChangeIfGreater(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check)
{
static constexpr bool is_less = false;
compileChangeComparison<is_less>(builder, aggregate_data_ptr, value_to_check);
}
static void compileChangeIfGreaterMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr)
{
static constexpr bool is_less = false;
compileChangeComparisonMerge<is_less>(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
}
static llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr)
{
return getValueFromAggregateDataPtr(builder, aggregate_data_ptr);
}
#endif
}; };
@ -400,6 +669,13 @@ public:
{ {
return true; return true;
} }
#if USE_EMBEDDED_COMPILER
static constexpr bool is_compilable = false;
#endif
}; };
static_assert( static_assert(
@ -576,6 +852,13 @@ public:
{ {
return false; return false;
} }
#if USE_EMBEDDED_COMPILER
static constexpr bool is_compilable = false;
#endif
}; };
@ -593,6 +876,22 @@ struct AggregateFunctionMinData : Data
bool changeIfBetter(const Self & to, Arena * arena) { return this->changeIfLess(to, arena); } bool changeIfBetter(const Self & to, Arena * arena) { return this->changeIfLess(to, arena); }
static const char * name() { return "min"; } static const char * name() { return "min"; }
#if USE_EMBEDDED_COMPILER
static constexpr bool is_compilable = Data::is_compilable;
static void compileChangeIfBetter(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check)
{
Data::compileChangeIfLess(builder, aggregate_data_ptr, value_to_check);
}
static void compileChangeIfBetterMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr)
{
Data::compileChangeIfLessMerge(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
}
#endif
}; };
template <typename Data> template <typename Data>
@ -604,6 +903,22 @@ struct AggregateFunctionMaxData : Data
bool changeIfBetter(const Self & to, Arena * arena) { return this->changeIfGreater(to, arena); } bool changeIfBetter(const Self & to, Arena * arena) { return this->changeIfGreater(to, arena); }
static const char * name() { return "max"; } static const char * name() { return "max"; }
#if USE_EMBEDDED_COMPILER
static constexpr bool is_compilable = Data::is_compilable;
static void compileChangeIfBetter(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check)
{
Data::compileChangeIfGreater(builder, aggregate_data_ptr, value_to_check);
}
static void compileChangeIfBetterMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr)
{
Data::compileChangeIfGreaterMerge(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
}
#endif
}; };
template <typename Data> template <typename Data>
@ -615,6 +930,22 @@ struct AggregateFunctionAnyData : Data
bool changeIfBetter(const Self & to, Arena * arena) { return this->changeFirstTime(to, arena); } bool changeIfBetter(const Self & to, Arena * arena) { return this->changeFirstTime(to, arena); }
static const char * name() { return "any"; } static const char * name() { return "any"; }
#if USE_EMBEDDED_COMPILER
static constexpr bool is_compilable = Data::is_compilable;
static void compileChangeIfBetter(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check)
{
Data::compileChangeFirstTime(builder, aggregate_data_ptr, value_to_check);
}
static void compileChangeIfBetterMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr)
{
Data::compileChangeFirstTimeMerge(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
}
#endif
}; };
template <typename Data> template <typename Data>
@ -626,6 +957,22 @@ struct AggregateFunctionAnyLastData : Data
bool changeIfBetter(const Self & to, Arena * arena) { return this->changeEveryTime(to, arena); } bool changeIfBetter(const Self & to, Arena * arena) { return this->changeEveryTime(to, arena); }
static const char * name() { return "anyLast"; } static const char * name() { return "anyLast"; }
#if USE_EMBEDDED_COMPILER
static constexpr bool is_compilable = Data::is_compilable;
static void compileChangeIfBetter(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check)
{
Data::compileChangeEveryTime(builder, aggregate_data_ptr, value_to_check);
}
static void compileChangeIfBetterMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr)
{
Data::compileChangeEveryTimeMerge(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
}
#endif
}; };
@ -693,6 +1040,13 @@ struct AggregateFunctionAnyHeavyData : Data
} }
static const char * name() { return "anyHeavy"; } static const char * name() { return "anyHeavy"; }
#if USE_EMBEDDED_COMPILER
static constexpr bool is_compilable = false;
#endif
}; };
@ -752,6 +1106,62 @@ public:
{ {
this->data(place).insertResultInto(to); this->data(place).insertResultInto(to);
} }
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
if constexpr (!Data::is_compilable)
return false;
return canBeNativeType(*this->argument_types[0]);
}
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
b.CreateMemSet(aggregate_data_ptr, llvm::ConstantInt::get(b.getInt8Ty(), 0), this->sizeOfData(), llvm::assumeAligned(this->alignOfData()));
}
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes &, const std::vector<llvm::Value *> & argument_values) const override
{
if constexpr (Data::is_compilable)
{
Data::compileChangeIfBetter(builder, aggregate_data_ptr, argument_values[0]);
}
else
{
throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED);
}
}
void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
{
if constexpr (Data::is_compilable)
{
Data::compileChangeIfBetterMerge(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
}
else
{
throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED);
}
}
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
if constexpr (Data::is_compilable)
{
return Data::compileGetResult(builder, aggregate_data_ptr);
}
else
{
throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED);
}
}
#endif
}; };
} }

View File

@ -6,9 +6,18 @@
#include <Common/assert_cast.h> #include <Common/assert_cast.h>
#include <Columns/ColumnsCommon.h> #include <Columns/ColumnsCommon.h>
#include <DataTypes/DataTypeNullable.h> #include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <IO/ReadHelpers.h> #include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h> #include <IO/WriteHelpers.h>
#if !defined(ARCADIA_BUILD)
# include <Common/config.h>
#endif
#if USE_EMBEDDED_COMPILER
# include <llvm/IR/IRBuilder.h>
# include <DataTypes/Native.h>
#endif
namespace DB namespace DB
{ {
@ -183,6 +192,93 @@ public:
} }
AggregateFunctionPtr getNestedFunction() const override { return nested_function; } AggregateFunctionPtr getNestedFunction() const override { return nested_function; }
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
return this->nested_function->isCompilable();
}
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
if constexpr (result_is_nullable)
b.CreateMemSet(aggregate_data_ptr, llvm::ConstantInt::get(b.getInt8Ty(), 0), this->prefix_size, llvm::assumeAligned(this->alignOfData()));
auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, this->prefix_size);
this->nested_function->compileCreate(b, aggregate_data_ptr_with_prefix_size_offset);
}
void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
if constexpr (result_is_nullable)
{
auto * aggregate_data_is_null_dst_value = b.CreateLoad(aggregate_data_dst_ptr);
auto * aggregate_data_is_null_src_value = b.CreateLoad(aggregate_data_src_ptr);
auto * is_src_null = nativeBoolCast(b, std::make_shared<DataTypeUInt8>(), aggregate_data_is_null_src_value);
auto * is_null_result_value = b.CreateSelect(is_src_null, llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_is_null_dst_value);
b.CreateStore(is_null_result_value, aggregate_data_dst_ptr);
}
auto * aggregate_data_dst_ptr_with_prefix_size_offset = b.CreateConstGEP1_32(nullptr, aggregate_data_dst_ptr, this->prefix_size);
auto * aggregate_data_src_ptr_with_prefix_size_offset = b.CreateConstGEP1_32(nullptr, aggregate_data_src_ptr, this->prefix_size);
this->nested_function->compileMerge(b, aggregate_data_dst_ptr_with_prefix_size_offset, aggregate_data_src_ptr_with_prefix_size_offset);
}
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, this->getReturnType());
llvm::Value * result = nullptr;
if constexpr (result_is_nullable)
{
auto * place = b.CreateLoad(b.getInt8Ty(), aggregate_data_ptr);
auto * head = b.GetInsertBlock();
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
auto * if_null = llvm::BasicBlock::Create(head->getContext(), "if_null", head->getParent());
auto * if_not_null = llvm::BasicBlock::Create(head->getContext(), "if_not_null", head->getParent());
auto * nullable_value_ptr = b.CreateAlloca(return_type);
b.CreateStore(llvm::ConstantInt::getNullValue(return_type), nullable_value_ptr);
auto * nullable_value = b.CreateLoad(return_type, nullable_value_ptr);
b.CreateCondBr(nativeBoolCast(b, std::make_shared<DataTypeUInt8>(), place), if_not_null, if_null);
b.SetInsertPoint(if_null);
b.CreateStore(b.CreateInsertValue(nullable_value, b.getInt1(true), {1}), nullable_value_ptr);
b.CreateBr(join_block);
b.SetInsertPoint(if_not_null);
auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, this->prefix_size);
auto * nested_result = this->nested_function->compileGetResult(builder, aggregate_data_ptr_with_prefix_size_offset);
b.CreateStore(b.CreateInsertValue(nullable_value, nested_result, {0}), nullable_value_ptr);
b.CreateBr(join_block);
b.SetInsertPoint(join_block);
result = b.CreateLoad(return_type, nullable_value_ptr);
}
else
{
result = this->nested_function->compileGetResult(b, aggregate_data_ptr);
}
return result;
}
#endif
}; };
@ -226,6 +322,44 @@ public:
if (!memoryIsByte(null_map, batch_size, 1)) if (!memoryIsByte(null_map, batch_size, 1))
this->setFlag(place); this->setFlag(place);
} }
#if USE_EMBEDDED_COMPILER
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
const auto & nullable_type = arguments_types[0];
const auto & nullable_value = argument_values[0];
auto * wrapped_value = b.CreateExtractValue(nullable_value, {0});
auto * is_null_value = b.CreateExtractValue(nullable_value, {1});
auto * head = b.GetInsertBlock();
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
auto * if_null = llvm::BasicBlock::Create(head->getContext(), "if_null", head->getParent());
auto * if_not_null = llvm::BasicBlock::Create(head->getContext(), "if_not_null", head->getParent());
b.CreateCondBr(is_null_value, if_null, if_not_null);
b.SetInsertPoint(if_null);
b.CreateBr(join_block);
b.SetInsertPoint(if_not_null);
if constexpr (result_is_nullable)
b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr);
auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, this->prefix_size);
this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, { removeNullable(nullable_type) }, { wrapped_value });
b.CreateBr(join_block);
b.SetInsertPoint(join_block);
}
#endif
}; };
@ -277,6 +411,90 @@ public:
this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena); this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena);
} }
#if USE_EMBEDDED_COMPILER
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
size_t arguments_size = arguments_types.size();
DataTypes non_nullable_types;
std::vector<llvm::Value * > wrapped_values;
std::vector<llvm::Value * > is_null_values;
non_nullable_types.resize(arguments_size);
wrapped_values.resize(arguments_size);
is_null_values.resize(arguments_size);
for (size_t i = 0; i < arguments_size; ++i)
{
const auto & argument_value = argument_values[i];
if (is_nullable[i])
{
auto * wrapped_value = b.CreateExtractValue(argument_value, {0});
if constexpr (null_is_skipped)
is_null_values[i] = b.CreateExtractValue(argument_value, {1});
wrapped_values[i] = wrapped_value;
non_nullable_types[i] = removeNullable(arguments_types[i]);
}
else
{
wrapped_values[i] = argument_value;
non_nullable_types[i] = arguments_types[i];
}
}
if constexpr (null_is_skipped)
{
auto * head = b.GetInsertBlock();
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
auto * if_null = llvm::BasicBlock::Create(head->getContext(), "if_null", head->getParent());
auto * if_not_null = llvm::BasicBlock::Create(head->getContext(), "if_not_null", head->getParent());
auto * values_have_null_ptr = b.CreateAlloca(b.getInt1Ty());
b.CreateStore(b.getInt1(false), values_have_null_ptr);
for (auto * is_null_value : is_null_values)
{
if (!is_null_value)
continue;
auto * values_have_null = b.CreateLoad(b.getInt1Ty(), values_have_null_ptr);
b.CreateStore(b.CreateOr(values_have_null, is_null_value), values_have_null_ptr);
}
b.CreateCondBr(b.CreateLoad(b.getInt1Ty(), values_have_null_ptr), if_null, if_not_null);
b.SetInsertPoint(if_null);
b.CreateBr(join_block);
b.SetInsertPoint(if_not_null);
if constexpr (result_is_nullable)
b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr);
auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, this->prefix_size);
this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, arguments_types, wrapped_values);
b.CreateBr(join_block);
b.SetInsertPoint(join_block);
}
else
{
b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr);
auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, this->prefix_size);
this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, non_nullable_types, wrapped_values);
}
}
#endif
private: private:
enum { MAX_ARGS = 8 }; enum { MAX_ARGS = 8 };
size_t number_of_arguments = 0; size_t number_of_arguments = 0;

View File

@ -12,6 +12,14 @@
#include <AggregateFunctions/IAggregateFunction.h> #include <AggregateFunctions/IAggregateFunction.h>
#if !defined(ARCADIA_BUILD)
# include <Common/config.h>
#endif
#if USE_EMBEDDED_COMPILER
# include <llvm/IR/IRBuilder.h>
# include <DataTypes/Native.h>
#endif
namespace DB namespace DB
{ {
@ -385,6 +393,80 @@ public:
column.getData().push_back(this->data(place).get()); column.getData().push_back(this->data(place).get());
} }
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
if constexpr (Type == AggregateFunctionTypeSumKahan)
return false;
bool can_be_compiled = true;
for (const auto & argument_type : this->argument_types)
can_be_compiled &= canBeNativeType(*argument_type);
auto return_type = getReturnType();
can_be_compiled &= canBeNativeType(*return_type);
return can_be_compiled;
}
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * aggregate_sum_ptr = b.CreatePointerCast(aggregate_data_ptr, return_type->getPointerTo());
b.CreateStore(llvm::Constant::getNullValue(return_type), aggregate_sum_ptr);
}
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * sum_value_ptr = b.CreatePointerCast(aggregate_data_ptr, return_type->getPointerTo());
auto * sum_value = b.CreateLoad(return_type, sum_value_ptr);
const auto & argument_type = arguments_types[0];
const auto & argument_value = argument_values[0];
auto * value_cast_to_result = nativeCast(b, argument_type, argument_value, return_type);
auto * sum_result_value = sum_value->getType()->isIntegerTy() ? b.CreateAdd(sum_value, value_cast_to_result) : b.CreateFAdd(sum_value, value_cast_to_result);
b.CreateStore(sum_result_value, sum_value_ptr);
}
void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * sum_value_dst_ptr = b.CreatePointerCast(aggregate_data_dst_ptr, return_type->getPointerTo());
auto * sum_value_dst = b.CreateLoad(return_type, sum_value_dst_ptr);
auto * sum_value_src_ptr = b.CreatePointerCast(aggregate_data_src_ptr, return_type->getPointerTo());
auto * sum_value_src = b.CreateLoad(return_type, sum_value_src_ptr);
auto * sum_return_value = sum_value_dst->getType()->isIntegerTy() ? b.CreateAdd(sum_value_dst, sum_value_src) : b.CreateFAdd(sum_value_dst, sum_value_src);
b.CreateStore(sum_return_value, sum_value_dst_ptr);
}
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * sum_value_ptr = b.CreatePointerCast(aggregate_data_ptr, return_type->getPointerTo());
return b.CreateLoad(return_type, sum_value_ptr);
}
#endif
private: private:
UInt32 scale; UInt32 scale;
}; };

View File

@ -48,6 +48,15 @@ public:
String getName() const final { return "sumCount"; } String getName() const final { return "sumCount"; }
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
return false;
}
#endif
private: private:
UInt32 scale; UInt32 scale;
}; };

View File

@ -10,4 +10,44 @@ DataTypePtr IAggregateFunction::getStateType() const
return std::make_shared<DataTypeAggregateFunction>(shared_from_this(), argument_types, parameters); return std::make_shared<DataTypeAggregateFunction>(shared_from_this(), argument_types, parameters);
} }
String IAggregateFunction::getDescription() const
{
String description;
description += getName();
description += '(';
for (const auto & parameter : parameters)
{
description += parameter.dump();
description += ", ";
}
if (!parameters.empty())
{
description.pop_back();
description.pop_back();
}
description += ')';
description += '(';
for (const auto & argument_type : argument_types)
{
description += argument_type->getName();
description += ", ";
}
if (!argument_types.empty())
{
description.pop_back();
description.pop_back();
}
description += ')';
return description;
}
} }

View File

@ -9,11 +9,21 @@
#include <Common/Exception.h> #include <Common/Exception.h>
#include <common/types.h> #include <common/types.h>
#if !defined(ARCADIA_BUILD)
# include "config_core.h"
#endif
#include <cstddef> #include <cstddef>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <type_traits> #include <type_traits>
namespace llvm
{
class LLVMContext;
class Value;
class IRBuilderBase;
}
namespace DB namespace DB
{ {
@ -208,6 +218,26 @@ public:
const IColumn ** columns, const IColumn ** columns,
Arena * arena) const = 0; Arena * arena) const = 0;
/** Insert result of aggregate function into result column with batch size.
* If destroy_place_after_insert is true. Then implementation of this method
* must destroy aggregate place if insert state into result column was successful.
* All places that were not inserted must be destroyed if there was exception during insert into result column.
*/
virtual void insertResultIntoBatch(
size_t batch_size,
AggregateDataPtr * places,
size_t place_offset,
IColumn & to,
Arena * arena,
bool destroy_place_after_insert) const = 0;
/** Destroy batch of aggregate places.
*/
virtual void destroyBatch(
size_t batch_size,
AggregateDataPtr * places,
size_t place_offset) const noexcept = 0;
/** By default all NULLs are skipped during aggregation. /** By default all NULLs are skipped during aggregation.
* If it returns nullptr, the default one will be used. * If it returns nullptr, the default one will be used.
* If an aggregate function wants to use something instead of the default one, it overrides this function and returns its own null adapter. * If an aggregate function wants to use something instead of the default one, it overrides this function and returns its own null adapter.
@ -241,6 +271,40 @@ public:
// of true window functions, so this hack-ish interface suffices. // of true window functions, so this hack-ish interface suffices.
virtual bool isOnlyWindowFunction() const { return false; } virtual bool isOnlyWindowFunction() const { return false; }
/// Description of AggregateFunction in form of name(parameters)(argument_types).
String getDescription() const;
#if USE_EMBEDDED_COMPILER
/// Is function JIT compilable
virtual bool isCompilable() const { return false; }
/// compileCreate should generate code for initialization of aggregate function state in aggregate_data_ptr
virtual void compileCreate(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_data_ptr*/) const
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
/// compileAdd should generate code for updating aggregate function state stored in aggregate_data_ptr
virtual void compileAdd(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_data_ptr*/, const DataTypes & /*arguments_types*/, const std::vector<llvm::Value *> & /*arguments_values*/) const
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
/// compileMerge should generate code for merging aggregate function states stored in aggregate_data_dst_ptr and aggregate_data_src_ptr
virtual void compileMerge(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_data_dst_ptr*/, llvm::Value * /*aggregate_data_src_ptr*/) const
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
/// compileGetResult should generate code for getting result value from aggregate function state stored in aggregate_data_ptr
virtual llvm::Value * compileGetResult(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_data_ptr*/) const
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
#endif
protected: protected:
DataTypes argument_types; DataTypes argument_types;
Array parameters; Array parameters;
@ -415,6 +479,37 @@ public:
static_cast<const Derived *>(this)->add(place + place_offset, columns, i, arena); static_cast<const Derived *>(this)->add(place + place_offset, columns, i, arena);
} }
} }
void insertResultIntoBatch(size_t batch_size, AggregateDataPtr * places, size_t place_offset, IColumn & to, Arena * arena, bool destroy_place_after_insert) const override
{
size_t batch_index = 0;
try
{
for (; batch_index < batch_size; ++batch_index)
{
static_cast<const Derived *>(this)->insertResultInto(places[batch_index] + place_offset, to, arena);
if (destroy_place_after_insert)
static_cast<const Derived *>(this)->destroy(places[batch_index] + place_offset);
}
}
catch (...)
{
for (size_t destroy_index = batch_index; destroy_index < batch_size; ++destroy_index)
static_cast<const Derived *>(this)->destroy(places[destroy_index] + place_offset);
throw;
}
}
void destroyBatch(size_t batch_size, AggregateDataPtr * places, size_t place_offset) const noexcept override
{
for (size_t i = 0; i < batch_size; ++i)
{
static_cast<const Derived *>(this)->destroy(places[i] + place_offset);
}
}
}; };

View File

@ -85,6 +85,7 @@ if (USE_AMQPCPP)
endif() endif()
if (USE_LIBPQXX) if (USE_LIBPQXX)
add_headers_and_sources(dbms Core/PostgreSQL)
add_headers_and_sources(dbms Databases/PostgreSQL) add_headers_and_sources(dbms Databases/PostgreSQL)
add_headers_and_sources(dbms Storages/PostgreSQL) add_headers_and_sources(dbms Storages/PostgreSQL)
endif() endif()

View File

@ -7,6 +7,11 @@
#pragma clang diagnostic ignored "-Wreserved-id-macro" #pragma clang diagnostic ignored "-Wreserved-id-macro"
#endif #endif
#undef __msan_unpoison
#undef __msan_test_shadow
#undef __msan_print_shadow
#undef __msan_unpoison_string
#define __msan_unpoison(X, Y) #define __msan_unpoison(X, Y)
#define __msan_test_shadow(X, Y) (false) #define __msan_test_shadow(X, Y) (false)
#define __msan_print_shadow(X, Y) #define __msan_print_shadow(X, Y)

View File

@ -0,0 +1,73 @@
#include "Connection.h"
#include <common/logger_useful.h>
namespace postgres
{
Connection::Connection(const ConnectionInfo & connection_info_, bool replication_, size_t num_tries_)
: connection_info(connection_info_), replication(replication_), num_tries(num_tries_)
, log(&Poco::Logger::get("PostgreSQLReplicaConnection"))
{
if (replication)
{
connection_info = std::make_pair(
fmt::format("{} replication=database", connection_info.first), connection_info.second);
}
}
void Connection::execWithRetry(const std::function<void(pqxx::nontransaction &)> & exec)
{
for (size_t try_no = 0; try_no < num_tries; ++try_no)
{
try
{
pqxx::nontransaction tx(getRef());
exec(tx);
}
catch (const pqxx::broken_connection & e)
{
LOG_DEBUG(log, "Cannot execute query due to connection failure, attempt: {}/{}. (Message: {})",
try_no, num_tries, e.what());
if (try_no == num_tries)
throw;
}
}
}
pqxx::connection & Connection::getRef()
{
connect();
assert(connection != nullptr);
return *connection;
}
void Connection::tryUpdateConnection()
{
try
{
updateConnection();
}
catch (const pqxx::broken_connection & e)
{
LOG_ERROR(log, "Unable to update connection: {}", e.what());
}
}
void Connection::updateConnection()
{
if (connection)
connection->close();
/// Always throws if there is no connection.
connection = std::make_unique<pqxx::connection>(connection_info.first);
if (replication)
connection->set_variable("default_transaction_isolation", "'repeatable read'");
LOG_DEBUG(&Poco::Logger::get("PostgreSQLConnection"), "New connection to {}", connection_info.second);
}
void Connection::connect()
{
if (!connection || !connection->is_open())
updateConnection();
}
}

View File

@ -0,0 +1,47 @@
#pragma once
#include <pqxx/pqxx> // Y_IGNORE
#include <Core/Types.h>
#include <boost/noncopyable.hpp>
/* Methods to work with PostgreSQL connection object.
* Should only be used in case there has to be a single connection object, which
* is long-lived and there are no concurrent connection queries.
* Now only use case - for replication handler for replication from PostgreSQL.
* In all other integration engine use pool with failover.
**/
namespace Poco { class Logger; }
namespace postgres
{
using ConnectionInfo = std::pair<String, String>;
using ConnectionPtr = std::unique_ptr<pqxx::connection>;
class Connection : private boost::noncopyable
{
public:
Connection(const ConnectionInfo & connection_info_, bool replication_ = false, size_t num_tries = 3);
void execWithRetry(const std::function<void(pqxx::nontransaction &)> & exec);
pqxx::connection & getRef();
void connect();
void tryUpdateConnection();
const ConnectionInfo & getConnectionInfo() { return connection_info; }
private:
void updateConnection();
ConnectionPtr connection;
ConnectionInfo connection_info;
bool replication;
size_t num_tries;
Poco::Logger * log;
};
}

View File

@ -1,7 +1,7 @@
#include <Storages/PostgreSQL/PoolWithFailover.h> #include "PoolWithFailover.h"
#include "Utils.h"
#include <Common/parseRemoteDescription.h> #include <Common/parseRemoteDescription.h>
#include <Common/Exception.h> #include <Common/Exception.h>
#include <IO/Operators.h>
namespace DB namespace DB
{ {
@ -14,18 +14,6 @@ namespace ErrorCodes
namespace postgres namespace postgres
{ {
String formatConnectionString(String dbname, String host, UInt16 port, String user, String password)
{
DB::WriteBufferFromOwnString out;
out << "dbname=" << DB::quote << dbname
<< " host=" << DB::quote << host
<< " port=" << port
<< " user=" << DB::quote << user
<< " password=" << DB::quote << password
<< " connect_timeout=10";
return out.str();
}
PoolWithFailover::PoolWithFailover( PoolWithFailover::PoolWithFailover(
const Poco::Util::AbstractConfiguration & config, const String & config_prefix, const Poco::Util::AbstractConfiguration & config, const String & config_prefix,
size_t pool_size, size_t pool_wait_timeout_, size_t max_tries_) size_t pool_size, size_t pool_wait_timeout_, size_t max_tries_)
@ -58,14 +46,14 @@ PoolWithFailover::PoolWithFailover(
auto replica_user = config.getString(replica_name + ".user", user); auto replica_user = config.getString(replica_name + ".user", user);
auto replica_password = config.getString(replica_name + ".password", password); auto replica_password = config.getString(replica_name + ".password", password);
auto connection_string = formatConnectionString(db, replica_host, replica_port, replica_user, replica_password); auto connection_string = formatConnectionString(db, replica_host, replica_port, replica_user, replica_password).first;
replicas_with_priority[priority].emplace_back(connection_string, pool_size); replicas_with_priority[priority].emplace_back(connection_string, pool_size);
} }
} }
} }
else else
{ {
auto connection_string = formatConnectionString(db, host, port, user, password); auto connection_string = formatConnectionString(db, host, port, user, password).first;
replicas_with_priority[0].emplace_back(connection_string, pool_size); replicas_with_priority[0].emplace_back(connection_string, pool_size);
} }
} }
@ -85,7 +73,7 @@ PoolWithFailover::PoolWithFailover(
for (const auto & [host, port] : addresses) for (const auto & [host, port] : addresses)
{ {
LOG_DEBUG(&Poco::Logger::get("PostgreSQLPoolWithFailover"), "Adding address host: {}, port: {} to connection pool", host, port); LOG_DEBUG(&Poco::Logger::get("PostgreSQLPoolWithFailover"), "Adding address host: {}, port: {} to connection pool", host, port);
auto connection_string = formatConnectionString(database, host, port, user, password); auto connection_string = formatConnectionString(database, host, port, user, password).first;
replicas_with_priority[0].emplace_back(connection_string, pool_size); replicas_with_priority[0].emplace_back(connection_string, pool_size);
} }
} }

View File

@ -1,16 +1,14 @@
#pragma once #pragma once
#include "ConnectionHolder.h"
#include <mutex> #include <mutex>
#include <Poco/Util/AbstractConfiguration.h> #include <Poco/Util/AbstractConfiguration.h>
#include <Storages/PostgreSQL/ConnectionHolder.h>
#include <common/logger_useful.h> #include <common/logger_useful.h>
namespace postgres namespace postgres
{ {
String formatConnectionString(String dbname, String host, UInt16 port, String user, String password);
class PoolWithFailover class PoolWithFailover
{ {

View File

@ -0,0 +1,19 @@
#include "Utils.h"
#include <IO/Operators.h>
namespace postgres
{
ConnectionInfo formatConnectionString(String dbname, String host, UInt16 port, String user, String password)
{
DB::WriteBufferFromOwnString out;
out << "dbname=" << DB::quote << dbname
<< " host=" << DB::quote << host
<< " port=" << port
<< " user=" << DB::quote << user
<< " password=" << DB::quote << password
<< " connect_timeout=10";
return std::make_pair(out.str(), host + ':' + DB::toString(port));
}
}

View File

@ -0,0 +1,17 @@
#pragma once
#include <pqxx/pqxx> // Y_IGNORE
#include <Core/Types.h>
#include "Connection.h"
#include <Common/Exception.h>
namespace pqxx
{
using ReadTransaction = pqxx::read_transaction;
using ReplicationTransaction = pqxx::transaction<isolation_level::repeatable_read, write_policy::read_only>;
}
namespace postgres
{
ConnectionInfo formatConnectionString(String dbname, String host, UInt16 port, String user, String password);
}

View File

@ -0,0 +1,241 @@
#include "insertPostgreSQLValue.h"
#if USE_LIBPQXX
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnDecimal.h>
#include <DataTypes/IDataType.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Interpreters/convertFieldToType.h>
#include <IO/ReadHelpers.h>
#include <IO/ReadBufferFromString.h>
#include <Common/assert_cast.h>
#include <pqxx/pqxx> // Y_IGNORE
namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
}
void insertDefaultPostgreSQLValue(IColumn & column, const IColumn & sample_column)
{
column.insertFrom(sample_column, 0);
}
void insertPostgreSQLValue(
IColumn & column, std::string_view value,
const ExternalResultDescription::ValueType type, const DataTypePtr data_type,
std::unordered_map<size_t, PostgreSQLArrayInfo> & array_info, size_t idx)
{
switch (type)
{
case ExternalResultDescription::ValueType::vtUInt8:
{
if (value == "t")
assert_cast<ColumnUInt8 &>(column).insertValue(1);
else if (value == "f")
assert_cast<ColumnUInt8 &>(column).insertValue(0);
else
assert_cast<ColumnUInt8 &>(column).insertValue(pqxx::from_string<uint16_t>(value));
break;
}
case ExternalResultDescription::ValueType::vtUInt16:
assert_cast<ColumnUInt16 &>(column).insertValue(pqxx::from_string<uint16_t>(value));
break;
case ExternalResultDescription::ValueType::vtUInt32:
assert_cast<ColumnUInt32 &>(column).insertValue(pqxx::from_string<uint32_t>(value));
break;
case ExternalResultDescription::ValueType::vtUInt64:
assert_cast<ColumnUInt64 &>(column).insertValue(pqxx::from_string<uint64_t>(value));
break;
case ExternalResultDescription::ValueType::vtInt8:
assert_cast<ColumnInt8 &>(column).insertValue(pqxx::from_string<int16_t>(value));
break;
case ExternalResultDescription::ValueType::vtInt16:
assert_cast<ColumnInt16 &>(column).insertValue(pqxx::from_string<int16_t>(value));
break;
case ExternalResultDescription::ValueType::vtInt32:
assert_cast<ColumnInt32 &>(column).insertValue(pqxx::from_string<int32_t>(value));
break;
case ExternalResultDescription::ValueType::vtInt64:
assert_cast<ColumnInt64 &>(column).insertValue(pqxx::from_string<int64_t>(value));
break;
case ExternalResultDescription::ValueType::vtFloat32:
assert_cast<ColumnFloat32 &>(column).insertValue(pqxx::from_string<float>(value));
break;
case ExternalResultDescription::ValueType::vtFloat64:
assert_cast<ColumnFloat64 &>(column).insertValue(pqxx::from_string<double>(value));
break;
case ExternalResultDescription::ValueType::vtEnum8:[[fallthrough]];
case ExternalResultDescription::ValueType::vtEnum16:[[fallthrough]];
case ExternalResultDescription::ValueType::vtFixedString:[[fallthrough]];
case ExternalResultDescription::ValueType::vtString:
assert_cast<ColumnString &>(column).insertData(value.data(), value.size());
break;
case ExternalResultDescription::ValueType::vtUUID:
assert_cast<ColumnUInt128 &>(column).insert(parse<UUID>(value.data(), value.size()));
break;
case ExternalResultDescription::ValueType::vtDate:
assert_cast<ColumnUInt16 &>(column).insertValue(UInt16{LocalDate{std::string(value)}.getDayNum()});
break;
case ExternalResultDescription::ValueType::vtDateTime:
{
ReadBufferFromString in(value);
time_t time = 0;
readDateTimeText(time, in, assert_cast<const DataTypeDateTime *>(data_type.get())->getTimeZone());
if (time < 0)
time = 0;
assert_cast<ColumnUInt32 &>(column).insertValue(time);
break;
}
case ExternalResultDescription::ValueType::vtDateTime64:[[fallthrough]];
case ExternalResultDescription::ValueType::vtDecimal32: [[fallthrough]];
case ExternalResultDescription::ValueType::vtDecimal64: [[fallthrough]];
case ExternalResultDescription::ValueType::vtDecimal128: [[fallthrough]];
case ExternalResultDescription::ValueType::vtDecimal256:
{
ReadBufferFromString istr(value);
data_type->getDefaultSerialization()->deserializeWholeText(column, istr, FormatSettings{});
break;
}
case ExternalResultDescription::ValueType::vtArray:
{
pqxx::array_parser parser{value};
std::pair<pqxx::array_parser::juncture, std::string> parsed = parser.get_next();
size_t dimension = 0, max_dimension = 0, expected_dimensions = array_info[idx].num_dimensions;
const auto parse_value = array_info[idx].pqxx_parser;
std::vector<Row> dimensions(expected_dimensions + 1);
while (parsed.first != pqxx::array_parser::juncture::done)
{
if ((parsed.first == pqxx::array_parser::juncture::row_start) && (++dimension > expected_dimensions))
throw Exception("Got more dimensions than expected", ErrorCodes::BAD_ARGUMENTS);
else if (parsed.first == pqxx::array_parser::juncture::string_value)
dimensions[dimension].emplace_back(parse_value(parsed.second));
else if (parsed.first == pqxx::array_parser::juncture::null_value)
dimensions[dimension].emplace_back(array_info[idx].default_value);
else if (parsed.first == pqxx::array_parser::juncture::row_end)
{
max_dimension = std::max(max_dimension, dimension);
--dimension;
if (dimension == 0)
break;
dimensions[dimension].emplace_back(Array(dimensions[dimension + 1].begin(), dimensions[dimension + 1].end()));
dimensions[dimension + 1].clear();
}
parsed = parser.get_next();
}
if (max_dimension < expected_dimensions)
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"Got less dimensions than expected. ({} instead of {})", max_dimension, expected_dimensions);
assert_cast<ColumnArray &>(column).insert(Array(dimensions[1].begin(), dimensions[1].end()));
break;
}
}
}
void preparePostgreSQLArrayInfo(
std::unordered_map<size_t, PostgreSQLArrayInfo> & array_info, size_t column_idx, const DataTypePtr data_type)
{
const auto * array_type = typeid_cast<const DataTypeArray *>(data_type.get());
auto nested = array_type->getNestedType();
size_t count_dimensions = 1;
while (isArray(nested))
{
++count_dimensions;
nested = typeid_cast<const DataTypeArray *>(nested.get())->getNestedType();
}
Field default_value = nested->getDefault();
if (nested->isNullable())
nested = static_cast<const DataTypeNullable *>(nested.get())->getNestedType();
WhichDataType which(nested);
std::function<Field(std::string & fields)> parser;
if (which.isUInt8() || which.isUInt16())
parser = [](std::string & field) -> Field { return pqxx::from_string<uint16_t>(field); };
else if (which.isInt8() || which.isInt16())
parser = [](std::string & field) -> Field { return pqxx::from_string<int16_t>(field); };
else if (which.isUInt32())
parser = [](std::string & field) -> Field { return pqxx::from_string<uint32_t>(field); };
else if (which.isInt32())
parser = [](std::string & field) -> Field { return pqxx::from_string<int32_t>(field); };
else if (which.isUInt64())
parser = [](std::string & field) -> Field { return pqxx::from_string<uint64_t>(field); };
else if (which.isInt64())
parser = [](std::string & field) -> Field { return pqxx::from_string<int64_t>(field); };
else if (which.isFloat32())
parser = [](std::string & field) -> Field { return pqxx::from_string<float>(field); };
else if (which.isFloat64())
parser = [](std::string & field) -> Field { return pqxx::from_string<double>(field); };
else if (which.isString() || which.isFixedString())
parser = [](std::string & field) -> Field { return field; };
else if (which.isDate())
parser = [](std::string & field) -> Field { return UInt16{LocalDate{field}.getDayNum()}; };
else if (which.isDateTime())
parser = [nested](std::string & field) -> Field
{
ReadBufferFromString in(field);
time_t time = 0;
readDateTimeText(time, in, assert_cast<const DataTypeDateTime *>(nested.get())->getTimeZone());
return time;
};
else if (which.isDecimal32())
parser = [nested](std::string & field) -> Field
{
const auto & type = typeid_cast<const DataTypeDecimal<Decimal32> *>(nested.get());
DataTypeDecimal<Decimal32> res(getDecimalPrecision(*type), getDecimalScale(*type));
return convertFieldToType(field, res);
};
else if (which.isDecimal64())
parser = [nested](std::string & field) -> Field
{
const auto & type = typeid_cast<const DataTypeDecimal<Decimal64> *>(nested.get());
DataTypeDecimal<Decimal64> res(getDecimalPrecision(*type), getDecimalScale(*type));
return convertFieldToType(field, res);
};
else if (which.isDecimal128())
parser = [nested](std::string & field) -> Field
{
const auto & type = typeid_cast<const DataTypeDecimal<Decimal128> *>(nested.get());
DataTypeDecimal<Decimal128> res(getDecimalPrecision(*type), getDecimalScale(*type));
return convertFieldToType(field, res);
};
else if (which.isDecimal256())
parser = [nested](std::string & field) -> Field
{
const auto & type = typeid_cast<const DataTypeDecimal<Decimal256> *>(nested.get());
DataTypeDecimal<Decimal256> res(getDecimalPrecision(*type), getDecimalScale(*type));
return convertFieldToType(field, res);
};
else
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Type conversion to {} is not supported", nested->getName());
array_info[column_idx] = {count_dimensions, default_value, parser};
}
}
#endif

View File

@ -0,0 +1,38 @@
#pragma once
#if !defined(ARCADIA_BUILD)
#include "config_core.h"
#endif
#if USE_LIBPQXX
#include <Core/Block.h>
#include <DataStreams/IBlockInputStream.h>
#include <Core/ExternalResultDescription.h>
#include <Core/Field.h>
namespace DB
{
struct PostgreSQLArrayInfo
{
size_t num_dimensions;
Field default_value;
std::function<Field(std::string & field)> pqxx_parser;
};
void insertPostgreSQLValue(
IColumn & column, std::string_view value,
const ExternalResultDescription::ValueType type, const DataTypePtr data_type,
std::unordered_map<size_t, PostgreSQLArrayInfo> & array_info, size_t idx);
void preparePostgreSQLArrayInfo(
std::unordered_map<size_t, PostgreSQLArrayInfo> & array_info, size_t column_idx, const DataTypePtr data_type);
void insertDefaultPostgreSQLValue(IColumn & column, const IColumn & sample_column);
}
#endif

View File

@ -93,6 +93,7 @@ class IColumn;
M(Bool, distributed_directory_monitor_split_batch_on_failure, false, "Should StorageDistributed DirectoryMonitors try to split batch into smaller in case of failures.", 0) \ M(Bool, distributed_directory_monitor_split_batch_on_failure, false, "Should StorageDistributed DirectoryMonitors try to split batch into smaller in case of failures.", 0) \
\ \
M(Bool, optimize_move_to_prewhere, true, "Allows disabling WHERE to PREWHERE optimization in SELECT queries from MergeTree.", 0) \ M(Bool, optimize_move_to_prewhere, true, "Allows disabling WHERE to PREWHERE optimization in SELECT queries from MergeTree.", 0) \
M(Bool, optimize_move_to_prewhere_if_final, false, "If query has `FINAL`, the optimization `move_to_prewhere` is not always correct and it is enabled only if both settings `optimize_move_to_prewhere` and `optimize_move_to_prewhere_if_final` are turned on", 0) \
\ \
M(UInt64, replication_alter_partitions_sync, 1, "Wait for actions to manipulate the partitions. 0 - do not wait, 1 - wait for execution only of itself, 2 - wait for everyone.", 0) \ M(UInt64, replication_alter_partitions_sync, 1, "Wait for actions to manipulate the partitions. 0 - do not wait, 1 - wait for execution only of itself, 2 - wait for everyone.", 0) \
M(UInt64, replication_alter_columns_timeout, 60, "Wait for actions to change the table structure within the specified number of seconds. 0 - wait unlimited time.", 0) \ M(UInt64, replication_alter_columns_timeout, 60, "Wait for actions to change the table structure within the specified number of seconds. 0 - wait unlimited time.", 0) \
@ -106,6 +107,8 @@ class IColumn;
M(Bool, allow_suspicious_low_cardinality_types, false, "In CREATE TABLE statement allows specifying LowCardinality modifier for types of small fixed size (8 or less). Enabling this may increase merge times and memory consumption.", 0) \ M(Bool, allow_suspicious_low_cardinality_types, false, "In CREATE TABLE statement allows specifying LowCardinality modifier for types of small fixed size (8 or less). Enabling this may increase merge times and memory consumption.", 0) \
M(Bool, compile_expressions, true, "Compile some scalar functions and operators to native code.", 0) \ M(Bool, compile_expressions, true, "Compile some scalar functions and operators to native code.", 0) \
M(UInt64, min_count_to_compile_expression, 3, "The number of identical expressions before they are JIT-compiled", 0) \ M(UInt64, min_count_to_compile_expression, 3, "The number of identical expressions before they are JIT-compiled", 0) \
M(Bool, compile_aggregate_expressions, true, "Compile aggregate functions to native code.", 0) \
M(UInt64, min_count_to_compile_aggregate_expression, 0, "The number of identical aggreagte expressions before they are JIT-compiled", 0) \
M(UInt64, group_by_two_level_threshold, 100000, "From what number of keys, a two-level aggregation starts. 0 - the threshold is not set.", 0) \ M(UInt64, group_by_two_level_threshold, 100000, "From what number of keys, a two-level aggregation starts. 0 - the threshold is not set.", 0) \
M(UInt64, group_by_two_level_threshold_bytes, 50000000, "From what size of the aggregation state in bytes, a two-level aggregation begins to be used. 0 - the threshold is not set. Two-level aggregation is used when at least one of the thresholds is triggered.", 0) \ M(UInt64, group_by_two_level_threshold_bytes, 50000000, "From what size of the aggregation state in bytes, a two-level aggregation begins to be used. 0 - the threshold is not set. Two-level aggregation is used when at least one of the thresholds is triggered.", 0) \
M(Bool, distributed_aggregation_memory_efficient, true, "Is the memory-saving mode of distributed aggregation enabled.", 0) \ M(Bool, distributed_aggregation_memory_efficient, true, "Is the memory-saving mode of distributed aggregation enabled.", 0) \
@ -429,6 +432,7 @@ class IColumn;
M(Bool, cast_keep_nullable, false, "CAST operator keep Nullable for result data type", 0) \ M(Bool, cast_keep_nullable, false, "CAST operator keep Nullable for result data type", 0) \
M(Bool, alter_partition_verbose_result, false, "Output information about affected parts. Currently works only for FREEZE and ATTACH commands.", 0) \ M(Bool, alter_partition_verbose_result, false, "Output information about affected parts. Currently works only for FREEZE and ATTACH commands.", 0) \
M(Bool, allow_experimental_database_materialize_mysql, false, "Allow to create database with Engine=MaterializeMySQL(...).", 0) \ M(Bool, allow_experimental_database_materialize_mysql, false, "Allow to create database with Engine=MaterializeMySQL(...).", 0) \
M(Bool, allow_experimental_database_materialized_postgresql, false, "Allow to create database with Engine=MaterializedPostgreSQL(...).", 0) \
M(Bool, system_events_show_zero_values, false, "Include all metrics, even with zero values", 0) \ M(Bool, system_events_show_zero_values, false, "Include all metrics, even with zero values", 0) \
M(MySQLDataTypesSupport, mysql_datatypes_support_level, 0, "Which MySQL types should be converted to corresponding ClickHouse types (rather than being represented as String). Can be empty or any combination of 'decimal' or 'datetime64'. When empty MySQL's DECIMAL and DATETIME/TIMESTAMP with non-zero precision are seen as String on ClickHouse's side.", 0) \ M(MySQLDataTypesSupport, mysql_datatypes_support_level, 0, "Which MySQL types should be converted to corresponding ClickHouse types (rather than being represented as String). Can be empty or any combination of 'decimal' or 'datetime64'. When empty MySQL's DECIMAL and DATETIME/TIMESTAMP with non-zero precision are seen as String on ClickHouse's side.", 0) \
M(Bool, optimize_trivial_insert_select, true, "Optimize trivial 'INSERT INTO table SELECT ... FROM TABLES' query", 0) \ M(Bool, optimize_trivial_insert_select, true, "Optimize trivial 'INSERT INTO table SELECT ... FROM TABLES' query", 0) \

View File

@ -22,12 +22,9 @@
namespace DB namespace DB
{ {
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
}
PostgreSQLBlockInputStream::PostgreSQLBlockInputStream( template<typename T>
PostgreSQLBlockInputStream<T>::PostgreSQLBlockInputStream(
postgres::ConnectionHolderPtr connection_holder_, postgres::ConnectionHolderPtr connection_holder_,
const std::string & query_str_, const std::string & query_str_,
const Block & sample_block, const Block & sample_block,
@ -35,25 +32,52 @@ PostgreSQLBlockInputStream::PostgreSQLBlockInputStream(
: query_str(query_str_) : query_str(query_str_)
, max_block_size(max_block_size_) , max_block_size(max_block_size_)
, connection_holder(std::move(connection_holder_)) , connection_holder(std::move(connection_holder_))
{
init(sample_block);
}
template<typename T>
PostgreSQLBlockInputStream<T>::PostgreSQLBlockInputStream(
std::shared_ptr<T> tx_,
const std::string & query_str_,
const Block & sample_block,
const UInt64 max_block_size_,
bool auto_commit_)
: query_str(query_str_)
, tx(std::move(tx_))
, max_block_size(max_block_size_)
, auto_commit(auto_commit_)
{
init(sample_block);
}
template<typename T>
void PostgreSQLBlockInputStream<T>::init(const Block & sample_block)
{ {
description.init(sample_block); description.init(sample_block);
for (const auto idx : collections::range(0, description.sample_block.columns())) for (const auto idx : collections::range(0, description.sample_block.columns()))
if (description.types[idx].first == ValueType::vtArray) if (description.types[idx].first == ExternalResultDescription::ValueType::vtArray)
prepareArrayInfo(idx, description.sample_block.getByPosition(idx).type); preparePostgreSQLArrayInfo(array_info, idx, description.sample_block.getByPosition(idx).type);
/// pqxx::stream_from uses COPY command, will get error if ';' is present /// pqxx::stream_from uses COPY command, will get error if ';' is present
if (query_str.ends_with(';')) if (query_str.ends_with(';'))
query_str.resize(query_str.size() - 1); query_str.resize(query_str.size() - 1);
} }
void PostgreSQLBlockInputStream::readPrefix() template<typename T>
void PostgreSQLBlockInputStream<T>::readPrefix()
{ {
tx = std::make_unique<pqxx::read_transaction>(connection_holder->get()); tx = std::make_shared<T>(connection_holder->get());
stream = std::make_unique<pqxx::stream_from>(*tx, pqxx::from_query, std::string_view(query_str)); stream = std::make_unique<pqxx::stream_from>(*tx, pqxx::from_query, std::string_view(query_str));
} }
Block PostgreSQLBlockInputStream::readImpl() template<typename T>
Block PostgreSQLBlockInputStream<T>::readImpl()
{ {
/// Check if pqxx::stream_from is finished /// Check if pqxx::stream_from is finished
if (!stream || !(*stream)) if (!stream || !(*stream))
@ -81,17 +105,22 @@ Block PostgreSQLBlockInputStream::readImpl()
{ {
ColumnNullable & column_nullable = assert_cast<ColumnNullable &>(*columns[idx]); ColumnNullable & column_nullable = assert_cast<ColumnNullable &>(*columns[idx]);
const auto & data_type = assert_cast<const DataTypeNullable &>(*sample.type); const auto & data_type = assert_cast<const DataTypeNullable &>(*sample.type);
insertValue(column_nullable.getNestedColumn(), (*row)[idx], description.types[idx].first, data_type.getNestedType(), idx);
insertPostgreSQLValue(
column_nullable.getNestedColumn(), (*row)[idx],
description.types[idx].first, data_type.getNestedType(), array_info, idx);
column_nullable.getNullMapData().emplace_back(0); column_nullable.getNullMapData().emplace_back(0);
} }
else else
{ {
insertValue(*columns[idx], (*row)[idx], description.types[idx].first, sample.type, idx); insertPostgreSQLValue(
*columns[idx], (*row)[idx], description.types[idx].first, sample.type, array_info, idx);
} }
} }
else else
{ {
insertDefaultValue(*columns[idx], *sample.column); insertDefaultPostgreSQLValue(*columns[idx], *sample.column);
} }
} }
@ -104,216 +133,23 @@ Block PostgreSQLBlockInputStream::readImpl()
} }
void PostgreSQLBlockInputStream::readSuffix() template<typename T>
void PostgreSQLBlockInputStream<T>::readSuffix()
{ {
if (stream) if (stream)
{ {
stream->complete(); stream->complete();
tx->commit();
if (auto_commit)
tx->commit();
} }
} }
template
class PostgreSQLBlockInputStream<pqxx::ReplicationTransaction>;
void PostgreSQLBlockInputStream::insertValue(IColumn & column, std::string_view value, template
const ExternalResultDescription::ValueType type, const DataTypePtr data_type, size_t idx) class PostgreSQLBlockInputStream<pqxx::ReadTransaction>;
{
switch (type)
{
case ValueType::vtUInt8:
{
if (value == "t")
assert_cast<ColumnUInt8 &>(column).insertValue(1);
else if (value == "f")
assert_cast<ColumnUInt8 &>(column).insertValue(0);
else
assert_cast<ColumnUInt8 &>(column).insertValue(pqxx::from_string<uint16_t>(value));
break;
}
case ValueType::vtUInt16:
assert_cast<ColumnUInt16 &>(column).insertValue(pqxx::from_string<uint16_t>(value));
break;
case ValueType::vtUInt32:
assert_cast<ColumnUInt32 &>(column).insertValue(pqxx::from_string<uint32_t>(value));
break;
case ValueType::vtUInt64:
assert_cast<ColumnUInt64 &>(column).insertValue(pqxx::from_string<uint64_t>(value));
break;
case ValueType::vtInt8:
assert_cast<ColumnInt8 &>(column).insertValue(pqxx::from_string<int16_t>(value));
break;
case ValueType::vtInt16:
assert_cast<ColumnInt16 &>(column).insertValue(pqxx::from_string<int16_t>(value));
break;
case ValueType::vtInt32:
assert_cast<ColumnInt32 &>(column).insertValue(pqxx::from_string<int32_t>(value));
break;
case ValueType::vtInt64:
assert_cast<ColumnInt64 &>(column).insertValue(pqxx::from_string<int64_t>(value));
break;
case ValueType::vtFloat32:
assert_cast<ColumnFloat32 &>(column).insertValue(pqxx::from_string<float>(value));
break;
case ValueType::vtFloat64:
assert_cast<ColumnFloat64 &>(column).insertValue(pqxx::from_string<double>(value));
break;
case ValueType::vtFixedString:[[fallthrough]];
case ValueType::vtEnum8:
case ValueType::vtEnum16:
case ValueType::vtString:
assert_cast<ColumnString &>(column).insertData(value.data(), value.size());
break;
case ValueType::vtUUID:
assert_cast<ColumnUUID &>(column).insert(parse<UUID>(value.data(), value.size()));
break;
case ValueType::vtDate:
assert_cast<ColumnUInt16 &>(column).insertValue(UInt16{LocalDate{std::string(value)}.getDayNum()});
break;
case ValueType::vtDateTime:
{
ReadBufferFromString in(value);
time_t time = 0;
readDateTimeText(time, in, assert_cast<const DataTypeDateTime *>(data_type.get())->getTimeZone());
if (time < 0)
time = 0;
assert_cast<ColumnUInt32 &>(column).insertValue(time);
break;
}
case ValueType::vtDateTime64:[[fallthrough]];
case ValueType::vtDecimal32: [[fallthrough]];
case ValueType::vtDecimal64: [[fallthrough]];
case ValueType::vtDecimal128: [[fallthrough]];
case ValueType::vtDecimal256:
{
ReadBufferFromString istr(value);
data_type->getDefaultSerialization()->deserializeWholeText(column, istr, FormatSettings{});
break;
}
case ValueType::vtArray:
{
pqxx::array_parser parser{value};
std::pair<pqxx::array_parser::juncture, std::string> parsed = parser.get_next();
size_t dimension = 0, max_dimension = 0, expected_dimensions = array_info[idx].num_dimensions;
const auto parse_value = array_info[idx].pqxx_parser;
std::vector<Row> dimensions(expected_dimensions + 1);
while (parsed.first != pqxx::array_parser::juncture::done)
{
if ((parsed.first == pqxx::array_parser::juncture::row_start) && (++dimension > expected_dimensions))
throw Exception("Got more dimensions than expected", ErrorCodes::BAD_ARGUMENTS);
else if (parsed.first == pqxx::array_parser::juncture::string_value)
dimensions[dimension].emplace_back(parse_value(parsed.second));
else if (parsed.first == pqxx::array_parser::juncture::null_value)
dimensions[dimension].emplace_back(array_info[idx].default_value);
else if (parsed.first == pqxx::array_parser::juncture::row_end)
{
max_dimension = std::max(max_dimension, dimension);
--dimension;
if (dimension == 0)
break;
dimensions[dimension].emplace_back(Array(dimensions[dimension + 1].begin(), dimensions[dimension + 1].end()));
dimensions[dimension + 1].clear();
}
parsed = parser.get_next();
}
if (max_dimension < expected_dimensions)
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"Got less dimensions than expected. ({} instead of {})", max_dimension, expected_dimensions);
assert_cast<ColumnArray &>(column).insert(Array(dimensions[1].begin(), dimensions[1].end()));
break;
}
}
}
void PostgreSQLBlockInputStream::prepareArrayInfo(size_t column_idx, const DataTypePtr data_type)
{
const auto * array_type = typeid_cast<const DataTypeArray *>(data_type.get());
auto nested = array_type->getNestedType();
size_t count_dimensions = 1;
while (isArray(nested))
{
++count_dimensions;
nested = typeid_cast<const DataTypeArray *>(nested.get())->getNestedType();
}
Field default_value = nested->getDefault();
if (nested->isNullable())
nested = static_cast<const DataTypeNullable *>(nested.get())->getNestedType();
WhichDataType which(nested);
std::function<Field(std::string & fields)> parser;
if (which.isUInt8() || which.isUInt16())
parser = [](std::string & field) -> Field { return pqxx::from_string<uint16_t>(field); };
else if (which.isInt8() || which.isInt16())
parser = [](std::string & field) -> Field { return pqxx::from_string<int16_t>(field); };
else if (which.isUInt32())
parser = [](std::string & field) -> Field { return pqxx::from_string<uint32_t>(field); };
else if (which.isInt32())
parser = [](std::string & field) -> Field { return pqxx::from_string<int32_t>(field); };
else if (which.isUInt64())
parser = [](std::string & field) -> Field { return pqxx::from_string<uint64_t>(field); };
else if (which.isInt64())
parser = [](std::string & field) -> Field { return pqxx::from_string<int64_t>(field); };
else if (which.isFloat32())
parser = [](std::string & field) -> Field { return pqxx::from_string<float>(field); };
else if (which.isFloat64())
parser = [](std::string & field) -> Field { return pqxx::from_string<double>(field); };
else if (which.isString() || which.isFixedString())
parser = [](std::string & field) -> Field { return field; };
else if (which.isDate())
parser = [](std::string & field) -> Field { return UInt16{LocalDate{field}.getDayNum()}; };
else if (which.isDateTime())
parser = [nested](std::string & field) -> Field
{
ReadBufferFromString in(field);
time_t time = 0;
readDateTimeText(time, in, assert_cast<const DataTypeDateTime *>(nested.get())->getTimeZone());
return time;
};
else if (which.isDecimal32())
parser = [nested](std::string & field) -> Field
{
const auto & type = typeid_cast<const DataTypeDecimal<Decimal32> *>(nested.get());
DataTypeDecimal<Decimal32> res(getDecimalPrecision(*type), getDecimalScale(*type));
return convertFieldToType(field, res);
};
else if (which.isDecimal64())
parser = [nested](std::string & field) -> Field
{
const auto & type = typeid_cast<const DataTypeDecimal<Decimal64> *>(nested.get());
DataTypeDecimal<Decimal64> res(getDecimalPrecision(*type), getDecimalScale(*type));
return convertFieldToType(field, res);
};
else if (which.isDecimal128())
parser = [nested](std::string & field) -> Field
{
const auto & type = typeid_cast<const DataTypeDecimal<Decimal128> *>(nested.get());
DataTypeDecimal<Decimal128> res(getDecimalPrecision(*type), getDecimalScale(*type));
return convertFieldToType(field, res);
};
else if (which.isDecimal256())
parser = [nested](std::string & field) -> Field
{
const auto & type = typeid_cast<const DataTypeDecimal<Decimal256> *>(nested.get());
DataTypeDecimal<Decimal256> res(getDecimalPrecision(*type), getDecimalScale(*type));
return convertFieldToType(field, res);
};
else
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Type conversion to {} is not supported", nested->getName());
array_info[column_idx] = {count_dimensions, default_value, parser};
}
} }

View File

@ -9,54 +9,76 @@
#include <DataStreams/IBlockInputStream.h> #include <DataStreams/IBlockInputStream.h>
#include <Core/ExternalResultDescription.h> #include <Core/ExternalResultDescription.h>
#include <Core/Field.h> #include <Core/Field.h>
#include <Storages/PostgreSQL/ConnectionHolder.h> #include <Core/PostgreSQL/insertPostgreSQLValue.h>
#include <Core/PostgreSQL/ConnectionHolder.h>
#include <Core/PostgreSQL/Utils.h>
namespace DB namespace DB
{ {
template <typename T = pqxx::ReadTransaction>
class PostgreSQLBlockInputStream : public IBlockInputStream class PostgreSQLBlockInputStream : public IBlockInputStream
{ {
public: public:
PostgreSQLBlockInputStream( PostgreSQLBlockInputStream(
postgres::ConnectionHolderPtr connection_holder_, postgres::ConnectionHolderPtr connection_holder_,
const std::string & query_str, const String & query_str_,
const Block & sample_block, const Block & sample_block,
const UInt64 max_block_size_); const UInt64 max_block_size_);
String getName() const override { return "PostgreSQL"; } String getName() const override { return "PostgreSQL"; }
Block getHeader() const override { return description.sample_block.cloneEmpty(); } Block getHeader() const override { return description.sample_block.cloneEmpty(); }
private:
using ValueType = ExternalResultDescription::ValueType;
void readPrefix() override; void readPrefix() override;
protected:
PostgreSQLBlockInputStream(
std::shared_ptr<T> tx_,
const std::string & query_str_,
const Block & sample_block,
const UInt64 max_block_size_,
bool auto_commit_);
String query_str;
std::shared_ptr<T> tx;
std::unique_ptr<pqxx::stream_from> stream;
private:
Block readImpl() override; Block readImpl() override;
void readSuffix() override; void readSuffix() override;
void insertValue(IColumn & column, std::string_view value, void init(const Block & sample_block);
const ExternalResultDescription::ValueType type, const DataTypePtr data_type, size_t idx);
void insertDefaultValue(IColumn & column, const IColumn & sample_column)
{
column.insertFrom(sample_column, 0);
}
void prepareArrayInfo(size_t column_idx, const DataTypePtr data_type);
String query_str;
const UInt64 max_block_size; const UInt64 max_block_size;
bool auto_commit = true;
ExternalResultDescription description; ExternalResultDescription description;
postgres::ConnectionHolderPtr connection_holder; postgres::ConnectionHolderPtr connection_holder;
std::unique_ptr<pqxx::read_transaction> tx;
std::unique_ptr<pqxx::stream_from> stream;
struct ArrayInfo std::unordered_map<size_t, PostgreSQLArrayInfo> array_info;
};
/// Passes transaction object into PostgreSQLBlockInputStream and does not close transaction after read is finished.
template <typename T>
class PostgreSQLTransactionBlockInputStream : public PostgreSQLBlockInputStream<T>
{
public:
using Base = PostgreSQLBlockInputStream<T>;
PostgreSQLTransactionBlockInputStream(
std::shared_ptr<T> tx_,
const std::string & query_str_,
const Block & sample_block_,
const UInt64 max_block_size_)
: PostgreSQLBlockInputStream<T>(tx_, query_str_, sample_block_, max_block_size_, false) {}
void readPrefix() override
{ {
size_t num_dimensions; Base::stream = std::make_unique<pqxx::stream_from>(*Base::tx, pqxx::from_query, std::string_view(Base::query_str));
Field default_value; }
std::function<Field(std::string & field)> pqxx_parser;
};
std::unordered_map<size_t, ArrayInfo> array_info;
}; };
} }

View File

@ -33,7 +33,8 @@ TTLAggregationAlgorithm::TTLAggregationAlgorithm(
Aggregator::Params params(header, keys, aggregates, Aggregator::Params params(header, keys, aggregates,
false, settings.max_rows_to_group_by, settings.group_by_overflow_mode, 0, 0, false, settings.max_rows_to_group_by, settings.group_by_overflow_mode, 0, 0,
settings.max_bytes_before_external_group_by, settings.empty_result_for_aggregation_by_empty_set, settings.max_bytes_before_external_group_by, settings.empty_result_for_aggregation_by_empty_set,
storage_.getContext()->getTemporaryVolume(), settings.max_threads, settings.min_free_disk_space_for_temporary_data); storage_.getContext()->getTemporaryVolume(), settings.max_threads, settings.min_free_disk_space_for_temporary_data,
settings.compile_aggregate_expressions, settings.min_count_to_compile_aggregate_expression);
aggregator = std::make_unique<Aggregator>(params); aggregator = std::make_unique<Aggregator>(params);
} }

View File

@ -61,6 +61,44 @@ static inline llvm::Type * toNativeType(llvm::IRBuilderBase & builder, const IDa
return nullptr; return nullptr;
} }
template <typename ToType>
static inline llvm::Type * toNativeType(llvm::IRBuilderBase & builder)
{
if constexpr (std::is_same_v<ToType, Int8> || std::is_same_v<ToType, UInt8>)
return builder.getInt8Ty();
else if constexpr (std::is_same_v<ToType, Int16> || std::is_same_v<ToType, UInt16>)
return builder.getInt16Ty();
else if constexpr (std::is_same_v<ToType, Int32> || std::is_same_v<ToType, UInt32>)
return builder.getInt32Ty();
else if constexpr (std::is_same_v<ToType, Int64> || std::is_same_v<ToType, UInt64>)
return builder.getInt64Ty();
else if constexpr (std::is_same_v<ToType, Float32>)
return builder.getFloatTy();
else if constexpr (std::is_same_v<ToType, Float64>)
return builder.getDoubleTy();
return nullptr;
}
template <typename Type>
static inline bool canBeNativeType()
{
if constexpr (std::is_same_v<Type, Int8> || std::is_same_v<Type, UInt8>)
return true;
else if constexpr (std::is_same_v<Type, Int16> || std::is_same_v<Type, UInt16>)
return true;
else if constexpr (std::is_same_v<Type, Int32> || std::is_same_v<Type, UInt32>)
return true;
else if constexpr (std::is_same_v<Type, Int64> || std::is_same_v<Type, UInt64>)
return true;
else if constexpr (std::is_same_v<Type, Float32>)
return true;
else if constexpr (std::is_same_v<Type, Float64>)
return true;
return false;
}
static inline bool canBeNativeType(const IDataType & type) static inline bool canBeNativeType(const IDataType & type)
{ {
WhichDataType data_type(type); WhichDataType data_type(type);
@ -79,40 +117,62 @@ static inline llvm::Type * toNativeType(llvm::IRBuilderBase & builder, const Dat
return toNativeType(builder, *type); return toNativeType(builder, *type);
} }
static inline llvm::Value * nativeBoolCast(llvm::IRBuilder<> & b, const DataTypePtr & from, llvm::Value * value) static inline llvm::Value * nativeBoolCast(llvm::IRBuilder<> & b, const DataTypePtr & from_type, llvm::Value * value)
{ {
if (from->isNullable()) if (from_type->isNullable())
{ {
auto * inner = nativeBoolCast(b, removeNullable(from), b.CreateExtractValue(value, {0})); auto * inner = nativeBoolCast(b, removeNullable(from_type), b.CreateExtractValue(value, {0}));
return b.CreateAnd(b.CreateNot(b.CreateExtractValue(value, {1})), inner); return b.CreateAnd(b.CreateNot(b.CreateExtractValue(value, {1})), inner);
} }
auto * zero = llvm::Constant::getNullValue(value->getType()); auto * zero = llvm::Constant::getNullValue(value->getType());
if (value->getType()->isIntegerTy()) if (value->getType()->isIntegerTy())
return b.CreateICmpNE(value, zero); return b.CreateICmpNE(value, zero);
if (value->getType()->isFloatingPointTy()) if (value->getType()->isFloatingPointTy())
return b.CreateFCmpONE(value, zero); /// QNaN is false return b.CreateFCmpONE(value, zero); /// QNaN is false
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Cannot cast non-number {} to bool", from->getName()); throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Cannot cast non-number {} to bool", from_type->getName());
} }
static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, const DataTypePtr & from, llvm::Value * value, llvm::Type * to) static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, const DataTypePtr & from, llvm::Value * value, llvm::Type * to_type)
{ {
auto * n_from = value->getType(); auto * from_type = value->getType();
if (n_from == to) if (from_type == to_type)
return value; return value;
else if (n_from->isIntegerTy() && to->isFloatingPointTy()) else if (from_type->isIntegerTy() && to_type->isFloatingPointTy())
return typeIsSigned(*from) ? b.CreateSIToFP(value, to) : b.CreateUIToFP(value, to); return typeIsSigned(*from) ? b.CreateSIToFP(value, to_type) : b.CreateUIToFP(value, to_type);
else if (n_from->isFloatingPointTy() && to->isIntegerTy()) else if (from_type->isFloatingPointTy() && to_type->isIntegerTy())
return typeIsSigned(*from) ? b.CreateFPToSI(value, to) : b.CreateFPToUI(value, to); return typeIsSigned(*from) ? b.CreateFPToSI(value, to_type) : b.CreateFPToUI(value, to_type);
else if (n_from->isIntegerTy() && to->isIntegerTy()) else if (from_type->isIntegerTy() && to_type->isIntegerTy())
return b.CreateIntCast(value, to, typeIsSigned(*from)); return b.CreateIntCast(value, to_type, typeIsSigned(*from));
else if (n_from->isFloatingPointTy() && to->isFloatingPointTy()) else if (from_type->isFloatingPointTy() && to_type->isFloatingPointTy())
return b.CreateFPCast(value, to); return b.CreateFPCast(value, to_type);
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Cannot cast {} to requested type", from->getName()); throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Cannot cast {} to requested type", from->getName());
} }
template <typename FromType>
static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, llvm::Value * value, llvm::Type * to_type)
{
auto * from_type = value->getType();
static constexpr bool from_type_is_signed = std::numeric_limits<FromType>::is_signed;
if (from_type == to_type)
return value;
else if (from_type->isIntegerTy() && to_type->isFloatingPointTy())
return from_type_is_signed ? b.CreateSIToFP(value, to_type) : b.CreateUIToFP(value, to_type);
else if (from_type->isFloatingPointTy() && to_type->isIntegerTy())
return from_type_is_signed ? b.CreateFPToSI(value, to_type) : b.CreateFPToUI(value, to_type);
else if (from_type->isIntegerTy() && to_type->isIntegerTy())
return b.CreateIntCast(value, to_type, from_type_is_signed);
else if (from_type->isFloatingPointTy() && to_type->isFloatingPointTy())
return b.CreateFPCast(value, to_type);
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Cannot cast {} to requested type", TypeName<FromType>);
}
static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, const DataTypePtr & from, llvm::Value * value, const DataTypePtr & to) static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, const DataTypePtr & from, llvm::Value * value, const DataTypePtr & to)
{ {
auto * n_to = toNativeType(b, to); auto * n_to = toNativeType(b, to);
@ -139,6 +199,37 @@ static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, const DataTypePtr
return nativeCast(b, from, value, n_to); return nativeCast(b, from, value, n_to);
} }
static inline std::pair<llvm::Value *, llvm::Value *> nativeCastToCommon(llvm::IRBuilder<> & b, const DataTypePtr & lhs_type, llvm::Value * lhs, const DataTypePtr & rhs_type, llvm::Value * rhs)
{
llvm::Type * common;
bool lhs_is_signed = typeIsSigned(*lhs_type);
bool rhs_is_signed = typeIsSigned(*rhs_type);
if (lhs->getType()->isIntegerTy() && rhs->getType()->isIntegerTy())
{
/// if one integer has a sign bit, make sure the other does as well. llvm generates optimal code
/// (e.g. uses overflow flag on x86) for (word size + 1)-bit integer operations.
size_t lhs_bit_width = lhs->getType()->getIntegerBitWidth() + (!lhs_is_signed && rhs_is_signed);
size_t rhs_bit_width = rhs->getType()->getIntegerBitWidth() + (!rhs_is_signed && lhs_is_signed);
size_t max_bit_width = std::max(lhs_bit_width, rhs_bit_width);
common = b.getIntNTy(max_bit_width);
}
else
{
/// TODO: Check
/// (double, float) or (double, int_N where N <= double's mantissa width) -> double
common = b.getDoubleTy();
}
auto * cast_lhs_to_common = nativeCast(b, lhs_type, lhs, common);
auto * cast_rhs_to_common = nativeCast(b, rhs_type, rhs, common);
return std::make_pair(cast_lhs_to_common, cast_rhs_to_common);
}
static inline llvm::Constant * getColumnNativeValue(llvm::IRBuilderBase & builder, const DataTypePtr & column_type, const IColumn & column, size_t index) static inline llvm::Constant * getColumnNativeValue(llvm::IRBuilderBase & builder, const DataTypePtr & column_type, const IColumn & column, size_t index)
{ {
if (const auto * constant = typeid_cast<const ColumnConst *>(&column)) if (const auto * constant = typeid_cast<const ColumnConst *>(&column))

View File

@ -109,12 +109,11 @@ StoragePtr DatabaseAtomic::detachTable(const String & name)
void DatabaseAtomic::dropTable(ContextPtr local_context, const String & table_name, bool no_delay) void DatabaseAtomic::dropTable(ContextPtr local_context, const String & table_name, bool no_delay)
{ {
if (auto * mv = dynamic_cast<StorageMaterializedView *>(tryGetTable(table_name, local_context).get())) auto storage = tryGetTable(table_name, local_context);
{ /// Remove the inner table (if any) to avoid deadlock
/// Remove the inner table (if any) to avoid deadlock /// (due to attempt to execute DROP from the worker thread)
/// (due to attempt to execute DROP from the worker thread) if (storage)
mv->dropInnerTable(no_delay, local_context); storage->dropInnerTableIfAny(no_delay, local_context);
}
String table_metadata_path = getObjectMetadataPath(table_name); String table_metadata_path = getObjectMetadataPath(table_name);
String table_metadata_path_drop; String table_metadata_path_drop;
@ -568,4 +567,3 @@ void DatabaseAtomic::checkDetachedTableNotInUse(const UUID & uuid)
} }
} }

View File

@ -36,7 +36,8 @@
#if USE_LIBPQXX #if USE_LIBPQXX
#include <Databases/PostgreSQL/DatabasePostgreSQL.h> // Y_IGNORE #include <Databases/PostgreSQL/DatabasePostgreSQL.h> // Y_IGNORE
#include <Storages/PostgreSQL/PoolWithFailover.h> #include <Databases/PostgreSQL/DatabaseMaterializedPostgreSQL.h>
#include <Storages/PostgreSQL/MaterializedPostgreSQLSettings.h>
#endif #endif
namespace fs = std::filesystem; namespace fs = std::filesystem;
@ -99,14 +100,14 @@ DatabasePtr DatabaseFactory::getImpl(const ASTCreateQuery & create, const String
const UUID & uuid = create.uuid; const UUID & uuid = create.uuid;
bool engine_may_have_arguments = engine_name == "MySQL" || engine_name == "MaterializeMySQL" || engine_name == "Lazy" || bool engine_may_have_arguments = engine_name == "MySQL" || engine_name == "MaterializeMySQL" || engine_name == "Lazy" ||
engine_name == "Replicated" || engine_name == "PostgreSQL"; engine_name == "Replicated" || engine_name == "PostgreSQL" || engine_name == "MaterializedPostgreSQL";
if (engine_define->engine->arguments && !engine_may_have_arguments) if (engine_define->engine->arguments && !engine_may_have_arguments)
throw Exception("Database engine " + engine_name + " cannot have arguments", ErrorCodes::BAD_ARGUMENTS); throw Exception("Database engine " + engine_name + " cannot have arguments", ErrorCodes::BAD_ARGUMENTS);
bool has_unexpected_element = engine_define->engine->parameters || engine_define->partition_by || bool has_unexpected_element = engine_define->engine->parameters || engine_define->partition_by ||
engine_define->primary_key || engine_define->order_by || engine_define->primary_key || engine_define->order_by ||
engine_define->sample_by; engine_define->sample_by;
bool may_have_settings = endsWith(engine_name, "MySQL") || engine_name == "Replicated"; bool may_have_settings = endsWith(engine_name, "MySQL") || engine_name == "Replicated" || engine_name == "MaterializedPostgreSQL";
if (has_unexpected_element || (!may_have_settings && engine_define->settings)) if (has_unexpected_element || (!may_have_settings && engine_define->settings))
throw Exception("Database engine " + engine_name + " cannot have parameters, primary_key, order_by, sample_by, settings", throw Exception("Database engine " + engine_name + " cannot have parameters, primary_key, order_by, sample_by, settings",
ErrorCodes::UNKNOWN_ELEMENT_IN_AST); ErrorCodes::UNKNOWN_ELEMENT_IN_AST);
@ -262,6 +263,41 @@ DatabasePtr DatabaseFactory::getImpl(const ASTCreateQuery & create, const String
return std::make_shared<DatabasePostgreSQL>( return std::make_shared<DatabasePostgreSQL>(
context, metadata_path, engine_define, database_name, postgres_database_name, connection_pool, use_table_cache); context, metadata_path, engine_define, database_name, postgres_database_name, connection_pool, use_table_cache);
} }
else if (engine_name == "MaterializedPostgreSQL")
{
const ASTFunction * engine = engine_define->engine;
if (!engine->arguments || engine->arguments->children.size() != 4)
{
throw Exception(
fmt::format("{} Database require host:port, database_name, username, password arguments ", engine_name),
ErrorCodes::BAD_ARGUMENTS);
}
ASTs & engine_args = engine->arguments->children;
for (auto & engine_arg : engine_args)
engine_arg = evaluateConstantExpressionOrIdentifierAsLiteral(engine_arg, context);
const auto & host_port = safeGetLiteralValue<String>(engine_args[0], engine_name);
const auto & postgres_database_name = safeGetLiteralValue<String>(engine_args[1], engine_name);
const auto & username = safeGetLiteralValue<String>(engine_args[2], engine_name);
const auto & password = safeGetLiteralValue<String>(engine_args[3], engine_name);
auto parsed_host_port = parseAddress(host_port, 5432);
auto connection_info = postgres::formatConnectionString(postgres_database_name, parsed_host_port.first, parsed_host_port.second, username, password);
auto postgresql_replica_settings = std::make_unique<MaterializedPostgreSQLSettings>();
if (engine_define->settings)
postgresql_replica_settings->loadFromQuery(*engine_define);
return std::make_shared<DatabaseMaterializedPostgreSQL>(
context, metadata_path, uuid, engine_define,
database_name, postgres_database_name, connection_info,
std::move(postgresql_replica_settings));
}
#endif #endif

View File

@ -0,0 +1,212 @@
#include <Databases/PostgreSQL/DatabaseMaterializedPostgreSQL.h>
#if USE_LIBPQXX
#include <Storages/PostgreSQL/StorageMaterializedPostgreSQL.h>
#include <Databases/PostgreSQL/fetchPostgreSQLTableStructure.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeArray.h>
#include <Databases/DatabaseOrdinary.h>
#include <Databases/DatabaseAtomic.h>
#include <Storages/StoragePostgreSQL.h>
#include <Interpreters/Context.h>
#include <Parsers/ASTCreateQuery.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ParserCreateQuery.h>
#include <Parsers/parseQuery.h>
#include <Parsers/queryToString.h>
#include <Common/escapeForFileName.h>
#include <Poco/DirectoryIterator.h>
#include <Poco/File.h>
#include <Common/Macros.h>
#include <common/logger_useful.h>
namespace DB
{
namespace ErrorCodes
{
extern const int NOT_IMPLEMENTED;
extern const int LOGICAL_ERROR;
}
DatabaseMaterializedPostgreSQL::DatabaseMaterializedPostgreSQL(
ContextPtr context_,
const String & metadata_path_,
UUID uuid_,
const ASTStorage * database_engine_define_,
const String & database_name_,
const String & postgres_database_name,
const postgres::ConnectionInfo & connection_info_,
std::unique_ptr<MaterializedPostgreSQLSettings> settings_)
: DatabaseAtomic(database_name_, metadata_path_, uuid_, "DatabaseMaterializedPostgreSQL (" + database_name_ + ")", context_)
, database_engine_define(database_engine_define_->clone())
, remote_database_name(postgres_database_name)
, connection_info(connection_info_)
, settings(std::move(settings_))
{
}
void DatabaseMaterializedPostgreSQL::startSynchronization()
{
replication_handler = std::make_unique<PostgreSQLReplicationHandler>(
/* replication_identifier */database_name,
remote_database_name,
database_name,
connection_info,
getContext(),
settings->materialized_postgresql_max_block_size.value,
settings->materialized_postgresql_allow_automatic_update,
/* is_materialized_postgresql_database = */ true,
settings->materialized_postgresql_tables_list.value);
postgres::Connection connection(connection_info);
NameSet tables_to_replicate;
try
{
tables_to_replicate = replication_handler->fetchRequiredTables(connection);
}
catch (...)
{
LOG_ERROR(log, "Unable to load replicated tables list");
throw;
}
if (tables_to_replicate.empty())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Got empty list of tables to replicate");
for (const auto & table_name : tables_to_replicate)
{
/// Check nested ReplacingMergeTree table.
auto storage = DatabaseAtomic::tryGetTable(table_name, getContext());
if (storage)
{
/// Nested table was already created and synchronized.
storage = StorageMaterializedPostgreSQL::create(storage, getContext());
}
else
{
/// Nested table does not exist and will be created by replication thread.
storage = StorageMaterializedPostgreSQL::create(StorageID(database_name, table_name), getContext());
}
/// Cache MaterializedPostgreSQL wrapper over nested table.
materialized_tables[table_name] = storage;
/// Let replication thread know, which tables it needs to keep in sync.
replication_handler->addStorage(table_name, storage->as<StorageMaterializedPostgreSQL>());
}
LOG_TRACE(log, "Loaded {} tables. Starting synchronization", materialized_tables.size());
replication_handler->startup();
}
void DatabaseMaterializedPostgreSQL::loadStoredObjects(ContextMutablePtr local_context, bool has_force_restore_data_flag, bool force_attach)
{
DatabaseAtomic::loadStoredObjects(local_context, has_force_restore_data_flag, force_attach);
try
{
startSynchronization();
}
catch (...)
{
tryLogCurrentException(log, "Cannot load nested database objects for PostgreSQL database engine.");
if (!force_attach)
throw;
}
}
StoragePtr DatabaseMaterializedPostgreSQL::tryGetTable(const String & name, ContextPtr local_context) const
{
/// In otder to define which table access is needed - to MaterializedPostgreSQL table (only in case of SELECT queries) or
/// to its nested ReplacingMergeTree table (in all other cases), the context of a query os modified.
/// Also if materialzied_tables set is empty - it means all access is done to ReplacingMergeTree tables - it is a case after
/// replication_handler was shutdown.
if (local_context->isInternalQuery() || materialized_tables.empty())
{
return DatabaseAtomic::tryGetTable(name, local_context);
}
/// Note: In select query we call MaterializedPostgreSQL table and it calls tryGetTable from its nested.
/// So the only point, where synchronization is needed - access to MaterializedPostgreSQL table wrapper over nested table.
std::lock_guard lock(tables_mutex);
auto table = materialized_tables.find(name);
/// Return wrapper over ReplacingMergeTree table. If table synchronization just started, table will not
/// be accessible immediately. Table is considered to exist once its nested table was created.
if (table != materialized_tables.end() && table->second->as <StorageMaterializedPostgreSQL>()->hasNested())
{
return table->second;
}
return StoragePtr{};
}
void DatabaseMaterializedPostgreSQL::createTable(ContextPtr local_context, const String & table_name, const StoragePtr & table, const ASTPtr & query)
{
/// Create table query can only be called from replication thread.
if (local_context->isInternalQuery())
{
DatabaseAtomic::createTable(local_context, table_name, table, query);
return;
}
throw Exception(ErrorCodes::NOT_IMPLEMENTED,
"Create table query allowed only for ReplacingMergeTree engine and from synchronization thread");
}
void DatabaseMaterializedPostgreSQL::shutdown()
{
stopReplication();
DatabaseAtomic::shutdown();
}
void DatabaseMaterializedPostgreSQL::stopReplication()
{
if (replication_handler)
replication_handler->shutdown();
/// Clear wrappers over nested, all access is not done to nested tables directly.
materialized_tables.clear();
}
void DatabaseMaterializedPostgreSQL::dropTable(ContextPtr local_context, const String & table_name, bool no_delay)
{
/// Modify context into nested_context and pass query to Atomic database.
DatabaseAtomic::dropTable(StorageMaterializedPostgreSQL::makeNestedTableContext(local_context), table_name, no_delay);
}
void DatabaseMaterializedPostgreSQL::drop(ContextPtr local_context)
{
if (replication_handler)
replication_handler->shutdownFinal();
DatabaseAtomic::drop(StorageMaterializedPostgreSQL::makeNestedTableContext(local_context));
}
DatabaseTablesIteratorPtr DatabaseMaterializedPostgreSQL::getTablesIterator(
ContextPtr local_context, const DatabaseOnDisk::FilterByNameFunction & filter_by_table_name)
{
/// Modify context into nested_context and pass query to Atomic database.
return DatabaseAtomic::getTablesIterator(StorageMaterializedPostgreSQL::makeNestedTableContext(local_context), filter_by_table_name);
}
}
#endif

View File

@ -0,0 +1,77 @@
#pragma once
#if !defined(ARCADIA_BUILD)
#include "config_core.h"
#endif
#if USE_LIBPQXX
#include <Storages/PostgreSQL/PostgreSQLReplicationHandler.h>
#include <Storages/PostgreSQL/MaterializedPostgreSQLSettings.h>
#include <Databases/DatabasesCommon.h>
#include <Core/BackgroundSchedulePool.h>
#include <Parsers/ASTCreateQuery.h>
#include <Databases/IDatabase.h>
#include <Databases/DatabaseOnDisk.h>
#include <Databases/DatabaseAtomic.h>
namespace DB
{
class PostgreSQLConnection;
using PostgreSQLConnectionPtr = std::shared_ptr<PostgreSQLConnection>;
class DatabaseMaterializedPostgreSQL : public DatabaseAtomic
{
public:
DatabaseMaterializedPostgreSQL(
ContextPtr context_,
const String & metadata_path_,
UUID uuid_,
const ASTStorage * database_engine_define_,
const String & database_name_,
const String & postgres_database_name,
const postgres::ConnectionInfo & connection_info,
std::unique_ptr<MaterializedPostgreSQLSettings> settings_);
String getEngineName() const override { return "MaterializedPostgreSQL"; }
String getMetadataPath() const override { return metadata_path; }
void loadStoredObjects(ContextMutablePtr, bool, bool force_attach) override;
DatabaseTablesIteratorPtr getTablesIterator(
ContextPtr context, const DatabaseOnDisk::FilterByNameFunction & filter_by_table_name) override;
StoragePtr tryGetTable(const String & name, ContextPtr context) const override;
void createTable(ContextPtr context, const String & name, const StoragePtr & table, const ASTPtr & query) override;
void dropTable(ContextPtr local_context, const String & name, bool no_delay) override;
void drop(ContextPtr local_context) override;
void stopReplication();
void shutdown() override;
private:
void startSynchronization();
ASTPtr database_engine_define;
String remote_database_name;
postgres::ConnectionInfo connection_info;
std::unique_ptr<MaterializedPostgreSQLSettings> settings;
std::shared_ptr<PostgreSQLReplicationHandler> replication_handler;
std::map<std::string, StoragePtr> materialized_tables;
mutable std::mutex tables_mutex;
};
}
#endif

View File

@ -40,14 +40,14 @@ DatabasePostgreSQL::DatabasePostgreSQL(
const ASTStorage * database_engine_define_, const ASTStorage * database_engine_define_,
const String & dbname_, const String & dbname_,
const String & postgres_dbname, const String & postgres_dbname,
postgres::PoolWithFailoverPtr connection_pool_, postgres::PoolWithFailoverPtr pool_,
const bool cache_tables_) bool cache_tables_)
: IDatabase(dbname_) : IDatabase(dbname_)
, WithContext(context_->getGlobalContext()) , WithContext(context_->getGlobalContext())
, metadata_path(metadata_path_) , metadata_path(metadata_path_)
, database_engine_define(database_engine_define_->clone()) , database_engine_define(database_engine_define_->clone())
, dbname(postgres_dbname) , dbname(postgres_dbname)
, connection_pool(std::move(connection_pool_)) , pool(std::move(pool_))
, cache_tables(cache_tables_) , cache_tables(cache_tables_)
{ {
cleaner_task = getContext()->getSchedulePool().createTask("PostgreSQLCleanerTask", [this]{ removeOutdatedTables(); }); cleaner_task = getContext()->getSchedulePool().createTask("PostgreSQLCleanerTask", [this]{ removeOutdatedTables(); });
@ -59,7 +59,8 @@ bool DatabasePostgreSQL::empty() const
{ {
std::lock_guard<std::mutex> lock(mutex); std::lock_guard<std::mutex> lock(mutex);
auto tables_list = fetchTablesList(); auto connection_holder = pool->get();
auto tables_list = fetchPostgreSQLTablesList(connection_holder->get());
for (const auto & table_name : tables_list) for (const auto & table_name : tables_list)
if (!detached_or_dropped.count(table_name)) if (!detached_or_dropped.count(table_name))
@ -74,7 +75,8 @@ DatabaseTablesIteratorPtr DatabasePostgreSQL::getTablesIterator(ContextPtr local
std::lock_guard<std::mutex> lock(mutex); std::lock_guard<std::mutex> lock(mutex);
Tables tables; Tables tables;
auto table_names = fetchTablesList(); auto connection_holder = pool->get();
auto table_names = fetchPostgreSQLTablesList(connection_holder->get());
for (const auto & table_name : table_names) for (const auto & table_name : table_names)
if (!detached_or_dropped.count(table_name)) if (!detached_or_dropped.count(table_name))
@ -84,21 +86,6 @@ DatabaseTablesIteratorPtr DatabasePostgreSQL::getTablesIterator(ContextPtr local
} }
std::unordered_set<std::string> DatabasePostgreSQL::fetchTablesList() const
{
std::unordered_set<std::string> tables;
std::string query = "SELECT tablename FROM pg_catalog.pg_tables "
"WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'";
auto connection_holder = connection_pool->get();
pqxx::read_transaction tx(connection_holder->get());
for (auto table_name : tx.stream<std::string>(query))
tables.insert(std::get<0>(table_name));
return tables;
}
bool DatabasePostgreSQL::checkPostgresTable(const String & table_name) const bool DatabasePostgreSQL::checkPostgresTable(const String & table_name) const
{ {
if (table_name.find('\'') != std::string::npos if (table_name.find('\'') != std::string::npos
@ -108,7 +95,7 @@ bool DatabasePostgreSQL::checkPostgresTable(const String & table_name) const
"PostgreSQL table name cannot contain single quote or backslash characters, passed {}", table_name); "PostgreSQL table name cannot contain single quote or backslash characters, passed {}", table_name);
} }
auto connection_holder = connection_pool->get(); auto connection_holder = pool->get();
pqxx::nontransaction tx(connection_holder->get()); pqxx::nontransaction tx(connection_holder->get());
try try
@ -163,20 +150,15 @@ StoragePtr DatabasePostgreSQL::fetchTable(const String & table_name, ContextPtr
if (!table_checked && !checkPostgresTable(table_name)) if (!table_checked && !checkPostgresTable(table_name))
return StoragePtr{}; return StoragePtr{};
auto use_nulls = local_context->getSettingsRef().external_table_functions_use_nulls; auto connection_holder = pool->get();
auto columns = fetchPostgreSQLTableStructure(connection_pool->get(), doubleQuoteString(table_name), use_nulls); auto columns = fetchPostgreSQLTableStructure(connection_holder->get(), doubleQuoteString(table_name)).columns;
if (!columns) if (!columns)
return StoragePtr{}; return StoragePtr{};
auto storage = StoragePostgreSQL::create( auto storage = StoragePostgreSQL::create(
StorageID(database_name, table_name), StorageID(database_name, table_name), pool, table_name,
connection_pool, ColumnsDescription{*columns}, ConstraintsDescription{}, String{}, local_context);
table_name,
ColumnsDescription{*columns},
ConstraintsDescription{},
String{},
local_context);
if (cache_tables) if (cache_tables)
cached_tables[table_name] = storage; cached_tables[table_name] = storage;
@ -298,7 +280,8 @@ void DatabasePostgreSQL::loadStoredObjects(ContextMutablePtr /* context */, bool
void DatabasePostgreSQL::removeOutdatedTables() void DatabasePostgreSQL::removeOutdatedTables()
{ {
std::lock_guard<std::mutex> lock{mutex}; std::lock_guard<std::mutex> lock{mutex};
auto actual_tables = fetchTablesList(); auto connection_holder = pool->get();
auto actual_tables = fetchPostgreSQLTablesList(connection_holder->get());
if (cache_tables) if (cache_tables)
{ {

View File

@ -9,7 +9,7 @@
#include <Databases/DatabasesCommon.h> #include <Databases/DatabasesCommon.h>
#include <Core/BackgroundSchedulePool.h> #include <Core/BackgroundSchedulePool.h>
#include <Parsers/ASTCreateQuery.h> #include <Parsers/ASTCreateQuery.h>
#include <Storages/PostgreSQL/PoolWithFailover.h> #include <Core/PostgreSQL/PoolWithFailover.h>
namespace DB namespace DB
@ -33,7 +33,7 @@ public:
const ASTStorage * database_engine_define, const ASTStorage * database_engine_define,
const String & dbname_, const String & dbname_,
const String & postgres_dbname, const String & postgres_dbname,
postgres::PoolWithFailoverPtr connection_pool_, postgres::PoolWithFailoverPtr pool_,
bool cache_tables_); bool cache_tables_);
String getEngineName() const override { return "PostgreSQL"; } String getEngineName() const override { return "PostgreSQL"; }
@ -70,7 +70,7 @@ private:
String metadata_path; String metadata_path;
ASTPtr database_engine_define; ASTPtr database_engine_define;
String dbname; String dbname;
postgres::PoolWithFailoverPtr connection_pool; postgres::PoolWithFailoverPtr pool;
const bool cache_tables; const bool cache_tables;
mutable Tables cached_tables; mutable Tables cached_tables;
@ -78,9 +78,11 @@ private:
BackgroundSchedulePool::TaskHolder cleaner_task; BackgroundSchedulePool::TaskHolder cleaner_task;
bool checkPostgresTable(const String & table_name) const; bool checkPostgresTable(const String & table_name) const;
std::unordered_set<std::string> fetchTablesList() const;
StoragePtr fetchTable(const String & table_name, ContextPtr context, bool table_checked) const; StoragePtr fetchTable(const String & table_name, ContextPtr context, const bool table_checked) const;
void removeOutdatedTables(); void removeOutdatedTables();
ASTPtr getColumnDeclaration(const DataTypePtr & data_type) const; ASTPtr getColumnDeclaration(const DataTypePtr & data_type) const;
}; };

View File

@ -12,7 +12,8 @@
#include <DataTypes/DataTypeDateTime.h> #include <DataTypes/DataTypeDateTime.h>
#include <boost/algorithm/string/split.hpp> #include <boost/algorithm/string/split.hpp>
#include <boost/algorithm/string/trim.hpp> #include <boost/algorithm/string/trim.hpp>
#include <pqxx/pqxx> #include <Common/quoteString.h>
#include <Core/PostgreSQL/Utils.h>
namespace DB namespace DB
@ -25,7 +26,21 @@ namespace ErrorCodes
} }
static DataTypePtr convertPostgreSQLDataType(String & type, bool is_nullable, uint16_t dimensions, const std::function<void()> & recheck_array) template<typename T>
std::unordered_set<std::string> fetchPostgreSQLTablesList(T & tx)
{
std::unordered_set<std::string> tables;
std::string query = "SELECT tablename FROM pg_catalog.pg_tables "
"WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'";
for (auto table_name : tx.template stream<std::string>(query))
tables.insert(std::get<0>(table_name));
return tables;
}
static DataTypePtr convertPostgreSQLDataType(String & type, const std::function<void()> & recheck_array, bool is_nullable = false, uint16_t dimensions = 0)
{ {
DataTypePtr res; DataTypePtr res;
bool is_array = false; bool is_array = false;
@ -116,52 +131,51 @@ static DataTypePtr convertPostgreSQLDataType(String & type, bool is_nullable, ui
} }
std::shared_ptr<NamesAndTypesList> fetchPostgreSQLTableStructure( template<typename T>
postgres::ConnectionHolderPtr connection_holder, const String & postgres_table_name, bool use_nulls) std::shared_ptr<NamesAndTypesList> readNamesAndTypesList(
T & tx, const String & postgres_table_name, const String & query, bool use_nulls, bool only_names_and_types)
{ {
auto columns = NamesAndTypes(); auto columns = NamesAndTypes();
if (postgres_table_name.find('\'') != std::string::npos
|| postgres_table_name.find('\\') != std::string::npos)
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "PostgreSQL table name cannot contain single quote or backslash characters, passed {}",
postgres_table_name);
}
std::string query = fmt::format(
"SELECT attname AS name, format_type(atttypid, atttypmod) AS type, "
"attnotnull AS not_null, attndims AS dims "
"FROM pg_attribute "
"WHERE attrelid = '{}'::regclass "
"AND NOT attisdropped AND attnum > 0", postgres_table_name);
try try
{ {
std::set<size_t> recheck_arrays_indexes; std::set<size_t> recheck_arrays_indexes;
{ {
pqxx::read_transaction tx(connection_holder->get());
auto stream{pqxx::stream_from::query(tx, query)}; auto stream{pqxx::stream_from::query(tx, query)};
std::tuple<std::string, std::string, std::string, uint16_t> row;
size_t i = 0; size_t i = 0;
auto recheck_array = [&]() { recheck_arrays_indexes.insert(i); }; auto recheck_array = [&]() { recheck_arrays_indexes.insert(i); };
while (stream >> row)
if (only_names_and_types)
{ {
auto data_type = convertPostgreSQLDataType(std::get<1>(row), std::tuple<std::string, std::string> row;
use_nulls && (std::get<2>(row) == "f"), /// 'f' means that postgres `not_null` is false, i.e. value is nullable while (stream >> row)
std::get<3>(row), {
recheck_array); columns.push_back(NameAndTypePair(std::get<0>(row), convertPostgreSQLDataType(std::get<1>(row), recheck_array)));
columns.push_back(NameAndTypePair(std::get<0>(row), data_type)); ++i;
++i; }
} }
else
{
std::tuple<std::string, std::string, std::string, uint16_t> row;
while (stream >> row)
{
auto data_type = convertPostgreSQLDataType(std::get<1>(row),
recheck_array,
use_nulls && (std::get<2>(row) == "f"), /// 'f' means that postgres `not_null` is false, i.e. value is nullable
std::get<3>(row));
columns.push_back(NameAndTypePair(std::get<0>(row), data_type));
++i;
}
}
stream.complete(); stream.complete();
tx.commit();
} }
for (const auto & i : recheck_arrays_indexes) for (const auto & i : recheck_arrays_indexes)
{ {
const auto & name_and_type = columns[i]; const auto & name_and_type = columns[i];
pqxx::nontransaction tx(connection_holder->get());
/// All rows must contain the same number of dimensions, so limit 1 is ok. If number of dimensions in all rows is not the same - /// All rows must contain the same number of dimensions, so limit 1 is ok. If number of dimensions in all rows is not the same -
/// such arrays are not able to be used as ClickHouse Array at all. /// such arrays are not able to be used as ClickHouse Array at all.
pqxx::result result{tx.exec(fmt::format("SELECT array_ndims({}) FROM {} LIMIT 1", name_and_type.name, postgres_table_name))}; pqxx::result result{tx.exec(fmt::format("SELECT array_ndims({}) FROM {} LIMIT 1", name_and_type.name, postgres_table_name))};
@ -178,9 +192,7 @@ std::shared_ptr<NamesAndTypesList> fetchPostgreSQLTableStructure(
catch (const pqxx::undefined_table &) catch (const pqxx::undefined_table &)
{ {
throw Exception(fmt::format( throw Exception(ErrorCodes::UNKNOWN_TABLE, "PostgreSQL table {} does not exist", postgres_table_name);
"PostgreSQL table {}.{} does not exist",
connection_holder->get().dbname(), postgres_table_name), ErrorCodes::UNKNOWN_TABLE);
} }
catch (Exception & e) catch (Exception & e)
{ {
@ -188,12 +200,101 @@ std::shared_ptr<NamesAndTypesList> fetchPostgreSQLTableStructure(
throw; throw;
} }
if (columns.empty()) return !columns.empty() ? std::make_shared<NamesAndTypesList>(columns.begin(), columns.end()) : nullptr;
return nullptr;
return std::make_shared<NamesAndTypesList>(NamesAndTypesList(columns.begin(), columns.end()));
} }
template<typename T>
PostgreSQLTableStructure fetchPostgreSQLTableStructure(
T & tx, const String & postgres_table_name, bool use_nulls, bool with_primary_key, bool with_replica_identity_index)
{
PostgreSQLTableStructure table;
std::string query = fmt::format(
"SELECT attname AS name, format_type(atttypid, atttypmod) AS type, "
"attnotnull AS not_null, attndims AS dims "
"FROM pg_attribute "
"WHERE attrelid = {}::regclass "
"AND NOT attisdropped AND attnum > 0", quoteString(postgres_table_name));
table.columns = readNamesAndTypesList(tx, postgres_table_name, query, use_nulls, false);
if (with_primary_key)
{
/// wiki.postgresql.org/wiki/Retrieve_primary_key_columns
query = fmt::format(
"SELECT a.attname, format_type(a.atttypid, a.atttypmod) AS data_type "
"FROM pg_index i "
"JOIN pg_attribute a ON a.attrelid = i.indrelid "
"AND a.attnum = ANY(i.indkey) "
"WHERE i.indrelid = {}::regclass AND i.indisprimary", quoteString(postgres_table_name));
table.primary_key_columns = readNamesAndTypesList(tx, postgres_table_name, query, use_nulls, true);
}
if (with_replica_identity_index && !table.primary_key_columns)
{
query = fmt::format(
"SELECT "
"a.attname AS column_name, " /// column name
"format_type(a.atttypid, a.atttypmod) as type " /// column type
"FROM "
"pg_class t, "
"pg_class i, "
"pg_index ix, "
"pg_attribute a "
"WHERE "
"t.oid = ix.indrelid "
"and i.oid = ix.indexrelid "
"and a.attrelid = t.oid "
"and a.attnum = ANY(ix.indkey) "
"and t.relkind = 'r' " /// simple tables
"and t.relname = {} " /// Connection is already done to a needed database, only table name is needed.
"and ix.indisreplident = 't' " /// index is is replica identity index
"ORDER BY a.attname", /// column names
quoteString(postgres_table_name));
table.replica_identity_columns = readNamesAndTypesList(tx, postgres_table_name, query, use_nulls, true);
}
return table;
}
PostgreSQLTableStructure fetchPostgreSQLTableStructure(pqxx::connection & connection, const String & postgres_table_name, bool use_nulls)
{
pqxx::ReadTransaction tx(connection);
auto result = fetchPostgreSQLTableStructure(tx, postgres_table_name, use_nulls, false, false);
tx.commit();
return result;
}
std::unordered_set<std::string> fetchPostgreSQLTablesList(pqxx::connection & connection)
{
pqxx::ReadTransaction tx(connection);
auto result = fetchPostgreSQLTablesList(tx);
tx.commit();
return result;
}
template
PostgreSQLTableStructure fetchPostgreSQLTableStructure(
pqxx::ReadTransaction & tx, const String & postgres_table_name, bool use_nulls,
bool with_primary_key, bool with_replica_identity_index);
template
PostgreSQLTableStructure fetchPostgreSQLTableStructure(
pqxx::ReplicationTransaction & tx, const String & postgres_table_name, bool use_nulls,
bool with_primary_key, bool with_replica_identity_index);
template
std::unordered_set<std::string> fetchPostgreSQLTablesList(pqxx::work & tx);
template
std::unordered_set<std::string> fetchPostgreSQLTablesList(pqxx::ReadTransaction & tx);
} }
#endif #endif

View File

@ -5,15 +5,34 @@
#endif #endif
#if USE_LIBPQXX #if USE_LIBPQXX
#include <Storages/PostgreSQL/ConnectionHolder.h> #include <Core/PostgreSQL/ConnectionHolder.h>
#include <Core/NamesAndTypes.h> #include <Core/NamesAndTypes.h>
namespace DB namespace DB
{ {
std::shared_ptr<NamesAndTypesList> fetchPostgreSQLTableStructure( struct PostgreSQLTableStructure
postgres::ConnectionHolderPtr connection_holder, const String & postgres_table_name, bool use_nulls); {
std::shared_ptr<NamesAndTypesList> columns = nullptr;
std::shared_ptr<NamesAndTypesList> primary_key_columns = nullptr;
std::shared_ptr<NamesAndTypesList> replica_identity_columns = nullptr;
};
using PostgreSQLTableStructurePtr = std::unique_ptr<PostgreSQLTableStructure>;
std::unordered_set<std::string> fetchPostgreSQLTablesList(pqxx::connection & connection);
PostgreSQLTableStructure fetchPostgreSQLTableStructure(
pqxx::connection & connection, const String & postgres_table_name, bool use_nulls = true);
template<typename T>
PostgreSQLTableStructure fetchPostgreSQLTableStructure(
T & tx, const String & postgres_table_name, bool use_nulls = true,
bool with_primary_key = false, bool with_replica_identity_index = false);
template<typename T>
std::unordered_set<std::string> fetchPostgreSQLTablesList(T & tx);
} }

View File

@ -107,9 +107,10 @@ BlockInputStreamPtr PostgreSQLDictionarySource::loadKeys(const Columns & key_col
BlockInputStreamPtr PostgreSQLDictionarySource::loadBase(const String & query) BlockInputStreamPtr PostgreSQLDictionarySource::loadBase(const String & query)
{ {
return std::make_shared<PostgreSQLBlockInputStream>(pool->get(), query, sample_block, max_block_size); return std::make_shared<PostgreSQLBlockInputStream<>>(pool->get(), query, sample_block, max_block_size);
} }
bool PostgreSQLDictionarySource::isModified() const bool PostgreSQLDictionarySource::isModified() const
{ {
if (!configuration.invalidate_query.empty()) if (!configuration.invalidate_query.empty())
@ -128,7 +129,7 @@ std::string PostgreSQLDictionarySource::doInvalidateQuery(const std::string & re
Block invalidate_sample_block; Block invalidate_sample_block;
ColumnPtr column(ColumnString::create()); ColumnPtr column(ColumnString::create());
invalidate_sample_block.insert(ColumnWithTypeAndName(column, std::make_shared<DataTypeString>(), "Sample Block")); invalidate_sample_block.insert(ColumnWithTypeAndName(column, std::make_shared<DataTypeString>(), "Sample Block"));
PostgreSQLBlockInputStream block_input_stream(pool->get(), request, invalidate_sample_block, 1); PostgreSQLBlockInputStream<> block_input_stream(pool->get(), request, invalidate_sample_block, 1);
return readInvalidateQuery(block_input_stream); return readInvalidateQuery(block_input_stream);
} }

View File

@ -11,8 +11,7 @@
#include <Core/Block.h> #include <Core/Block.h>
#include <common/LocalDateTime.h> #include <common/LocalDateTime.h>
#include <common/logger_useful.h> #include <common/logger_useful.h>
#include <Storages/PostgreSQL/PoolWithFailover.h> #include <Core/PostgreSQL/PoolWithFailover.h>
#include <pqxx/pqxx>
namespace DB namespace DB

View File

@ -1265,23 +1265,7 @@ public:
assert(2 == types.size() && 2 == values.size()); assert(2 == types.size() && 2 == values.size());
auto & b = static_cast<llvm::IRBuilder<> &>(builder); auto & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * x = values[0]; auto [x, y] = nativeCastToCommon(b, types[0], values[0], types[1], values[1]);
auto * y = values[1];
if (!types[0]->equals(*types[1]))
{
llvm::Type * common;
if (x->getType()->isIntegerTy() && y->getType()->isIntegerTy())
common = b.getIntNTy(std::max(
/// if one integer has a sign bit, make sure the other does as well. llvm generates optimal code
/// (e.g. uses overflow flag on x86) for (word size + 1)-bit integer operations.
x->getType()->getIntegerBitWidth() + (!typeIsSigned(*types[0]) && typeIsSigned(*types[1])),
y->getType()->getIntegerBitWidth() + (!typeIsSigned(*types[1]) && typeIsSigned(*types[0]))));
else
/// (double, float) or (double, int_N where N <= double's mantissa width) -> double
common = b.getDoubleTy();
x = nativeCast(b, types[0], x, common);
y = nativeCast(b, types[1], y, common);
}
auto * result = CompileOp<Op>::compile(b, x, y, typeIsSigned(*types[0]) || typeIsSigned(*types[1])); auto * result = CompileOp<Op>::compile(b, x, y, typeIsSigned(*types[0]) || typeIsSigned(*types[1]));
return b.CreateSelect(result, b.getInt8(1), b.getInt8(0)); return b.CreateSelect(result, b.getInt8(1), b.getInt8(0));
} }

View File

@ -10,14 +10,20 @@
#include <Common/assert_cast.h> #include <Common/assert_cast.h>
#include <Core/Settings.h> #include <Core/Settings.h>
#include <Columns/ColumnConst.h> #include <Columns/ColumnConst.h>
#include <Columns/ColumnLowCardinality.h>
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnString.h> #include <Columns/ColumnString.h>
#include <Columns/ColumnVector.h> #include <Columns/ColumnVector.h>
#include <Columns/ColumnFixedString.h> #include <Columns/ColumnFixedString.h>
#include <Columns/ColumnNullable.h> #include <Columns/ColumnNullable.h>
#include <Columns/ColumnArray.h> #include <Columns/ColumnArray.h>
#include <Columns/ColumnTuple.h> #include <Columns/ColumnTuple.h>
#include <DataTypes/Serializations/SerializationDecimal.h>
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <DataTypes/DataTypeString.h> #include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypeUUID.h>
#include <DataTypes/DataTypeEnum.h> #include <DataTypes/DataTypeEnum.h>
#include <DataTypes/DataTypeFactory.h> #include <DataTypes/DataTypeFactory.h>
#include <DataTypes/DataTypeNullable.h> #include <DataTypes/DataTypeNullable.h>
@ -528,6 +534,7 @@ public:
} }
}; };
template <typename JSONParser> template <typename JSONParser>
using JSONExtractInt8Impl = JSONExtractNumericImpl<JSONParser, Int8>; using JSONExtractInt8Impl = JSONExtractNumericImpl<JSONParser, Int8>;
template <typename JSONParser> template <typename JSONParser>
@ -625,6 +632,60 @@ struct JSONExtractTree
} }
}; };
class LowCardinalityNode : public Node
{
public:
LowCardinalityNode(DataTypePtr dictionary_type_, std::unique_ptr<Node> impl_)
: dictionary_type(dictionary_type_), impl(std::move(impl_)) {}
bool insertResultToColumn(IColumn & dest, const Element & element) override
{
auto from_col = dictionary_type->createColumn();
if (impl->insertResultToColumn(*from_col, element))
{
StringRef value = from_col->getDataAt(0);
assert_cast<ColumnLowCardinality &>(dest).insertData(value.data, value.size);
return true;
}
return false;
}
private:
DataTypePtr dictionary_type;
std::unique_ptr<Node> impl;
};
class UUIDNode : public Node
{
public:
bool insertResultToColumn(IColumn & dest, const Element & element) override
{
if (!element.isString())
return false;
auto uuid = parseFromString<UUID>(element.getString());
assert_cast<ColumnUUID &>(dest).insert(uuid);
return true;
}
};
template <typename DecimalType>
class DecimalNode : public Node
{
public:
DecimalNode(DataTypePtr data_type_) : data_type(data_type_) {}
bool insertResultToColumn(IColumn & dest, const Element & element) override
{
if (!element.isDouble())
return false;
const auto * type = assert_cast<const DataTypeDecimal<DecimalType> *>(data_type.get());
auto result = convertToDecimal<DataTypeNumber<Float64>, DataTypeDecimal<DecimalType>>(element.getDouble(), type->getScale());
assert_cast<ColumnDecimal<DecimalType> &>(dest).insert(result);
return true;
}
private:
DataTypePtr data_type;
};
class StringNode : public Node class StringNode : public Node
{ {
public: public:
@ -864,6 +925,17 @@ struct JSONExtractTree
case TypeIndex::Float64: return std::make_unique<NumericNode<Float64>>(); case TypeIndex::Float64: return std::make_unique<NumericNode<Float64>>();
case TypeIndex::String: return std::make_unique<StringNode>(); case TypeIndex::String: return std::make_unique<StringNode>();
case TypeIndex::FixedString: return std::make_unique<FixedStringNode>(); case TypeIndex::FixedString: return std::make_unique<FixedStringNode>();
case TypeIndex::UUID: return std::make_unique<UUIDNode>();
case TypeIndex::LowCardinality:
{
auto dictionary_type = typeid_cast<const DataTypeLowCardinality *>(type.get())->getDictionaryType();
auto impl = build(function_name, dictionary_type);
return std::make_unique<LowCardinalityNode>(dictionary_type, std::move(impl));
}
case TypeIndex::Decimal256: return std::make_unique<DecimalNode<Decimal256>>(type);
case TypeIndex::Decimal128: return std::make_unique<DecimalNode<Decimal128>>(type);
case TypeIndex::Decimal64: return std::make_unique<DecimalNode<Decimal64>>(type);
case TypeIndex::Decimal32: return std::make_unique<DecimalNode<Decimal32>>(type);
case TypeIndex::Enum8: case TypeIndex::Enum8:
return std::make_unique<EnumNode<Int8>>(static_cast<const DataTypeEnum8 &>(*type).getValues()); return std::make_unique<EnumNode<Int8>>(static_cast<const DataTypeEnum8 &>(*type).getValues());
case TypeIndex::Enum16: case TypeIndex::Enum16:

View File

@ -8,7 +8,6 @@
#include <Functions/materialize.h> #include <Functions/materialize.h>
#include <Functions/FunctionsLogical.h> #include <Functions/FunctionsLogical.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Interpreters/ExpressionJIT.h>
#include <IO/WriteBufferFromString.h> #include <IO/WriteBufferFromString.h>
#include <IO/Operators.h> #include <IO/Operators.h>

View File

@ -27,8 +27,6 @@ using FunctionOverloadResolverPtr = std::shared_ptr<IFunctionOverloadResolver>;
class IDataType; class IDataType;
using DataTypePtr = std::shared_ptr<const IDataType>; using DataTypePtr = std::shared_ptr<const IDataType>;
class CompiledExpressionCache;
namespace JSONBuilder namespace JSONBuilder
{ {
class JSONMap; class JSONMap;

View File

@ -1,5 +1,6 @@
#include <future> #include <future>
#include <Poco/Util/Application.h> #include <Poco/Util/Application.h>
#include <Common/Stopwatch.h> #include <Common/Stopwatch.h>
#include <Common/setThreadName.h> #include <Common/setThreadName.h>
#include <Common/formatReadable.h> #include <Common/formatReadable.h>
@ -21,6 +22,8 @@
#include <AggregateFunctions/AggregateFunctionArray.h> #include <AggregateFunctions/AggregateFunctionArray.h>
#include <AggregateFunctions/AggregateFunctionState.h> #include <AggregateFunctions/AggregateFunctionState.h>
#include <IO/Operators.h> #include <IO/Operators.h>
#include <Interpreters/JIT/compileFunction.h>
#include <Interpreters/JIT/CompiledExpressionCache.h>
namespace ProfileEvents namespace ProfileEvents
@ -211,6 +214,32 @@ void Aggregator::Params::explain(JSONBuilder::JSONMap & map) const
} }
} }
#if USE_EMBEDDED_COMPILER
static CHJIT & getJITInstance()
{
static CHJIT jit;
return jit;
}
class CompiledAggregateFunctionsHolder final : public CompiledExpressionCacheEntry
{
public:
explicit CompiledAggregateFunctionsHolder(CompiledAggregateFunctions compiled_function_)
: CompiledExpressionCacheEntry(compiled_function_.compiled_module.size)
, compiled_aggregate_functions(compiled_function_)
{}
~CompiledAggregateFunctionsHolder() override
{
getJITInstance().deleteCompiledModule(compiled_aggregate_functions.compiled_module);
}
CompiledAggregateFunctions compiled_aggregate_functions;
};
#endif
Aggregator::Aggregator(const Params & params_) Aggregator::Aggregator(const Params & params_)
: params(params_) : params(params_)
{ {
@ -262,8 +291,93 @@ Aggregator::Aggregator(const Params & params_)
HashMethodContext::Settings cache_settings; HashMethodContext::Settings cache_settings;
cache_settings.max_threads = params.max_threads; cache_settings.max_threads = params.max_threads;
aggregation_state_cache = AggregatedDataVariants::createCache(method_chosen, cache_settings); aggregation_state_cache = AggregatedDataVariants::createCache(method_chosen, cache_settings);
#if USE_EMBEDDED_COMPILER
compileAggregateFunctions();
#endif
} }
#if USE_EMBEDDED_COMPILER
void Aggregator::compileAggregateFunctions()
{
static std::unordered_map<UInt128, UInt64, UInt128Hash> aggregate_functions_description_to_count;
static std::mutex mtx;
if (!params.compile_aggregate_expressions)
return;
std::vector<AggregateFunctionWithOffset> functions_to_compile;
size_t aggregate_instructions_size = 0;
String functions_description;
is_aggregate_function_compiled.resize(aggregate_functions.size());
/// Add values to the aggregate functions.
for (size_t i = 0; i < aggregate_functions.size(); ++i)
{
const auto * function = aggregate_functions[i];
size_t offset_of_aggregate_function = offsets_of_aggregate_states[i];
if (function->isCompilable())
{
AggregateFunctionWithOffset function_to_compile
{
.function = function,
.aggregate_data_offset = offset_of_aggregate_function
};
functions_to_compile.emplace_back(std::move(function_to_compile));
functions_description += function->getDescription();
functions_description += ' ';
functions_description += std::to_string(offset_of_aggregate_function);
functions_description += ' ';
}
++aggregate_instructions_size;
is_aggregate_function_compiled[i] = function->isCompilable();
}
if (functions_to_compile.empty())
return;
SipHash aggregate_functions_description_hash;
aggregate_functions_description_hash.update(functions_description);
UInt128 aggregate_functions_description_hash_key;
aggregate_functions_description_hash.get128(aggregate_functions_description_hash_key);
{
std::lock_guard<std::mutex> lock(mtx);
if (aggregate_functions_description_to_count[aggregate_functions_description_hash_key]++ < params.min_count_to_compile_aggregate_expression)
return;
if (auto * compilation_cache = CompiledExpressionCacheFactory::instance().tryGetCache())
{
auto [compiled_function_cache_entry, _] = compilation_cache->getOrSet(aggregate_functions_description_hash_key, [&] ()
{
LOG_TRACE(log, "Compile expression {}", functions_description);
auto compiled_aggregate_functions = compileAggregateFunctons(getJITInstance(), functions_to_compile, functions_description);
return std::make_shared<CompiledAggregateFunctionsHolder>(std::move(compiled_aggregate_functions));
});
compiled_aggregate_functions_holder = std::static_pointer_cast<CompiledAggregateFunctionsHolder>(compiled_function_cache_entry);
}
else
{
LOG_TRACE(log, "Compile expression {}", functions_description);
auto compiled_aggregate_functions = compileAggregateFunctons(getJITInstance(), functions_to_compile, functions_description);
compiled_aggregate_functions_holder = std::make_shared<CompiledAggregateFunctionsHolder>(std::move(compiled_aggregate_functions));
}
}
}
#endif
AggregatedDataVariants::Type Aggregator::chooseAggregationMethod() AggregatedDataVariants::Type Aggregator::chooseAggregationMethod()
{ {
@ -431,11 +545,15 @@ AggregatedDataVariants::Type Aggregator::chooseAggregationMethod()
return AggregatedDataVariants::Type::serialized; return AggregatedDataVariants::Type::serialized;
} }
template <bool skip_compiled_aggregate_functions>
void Aggregator::createAggregateStates(AggregateDataPtr & aggregate_data) const void Aggregator::createAggregateStates(AggregateDataPtr & aggregate_data) const
{ {
for (size_t j = 0; j < params.aggregates_size; ++j) for (size_t j = 0; j < params.aggregates_size; ++j)
{ {
if constexpr (skip_compiled_aggregate_functions)
if (is_aggregate_function_compiled[j])
continue;
try try
{ {
/** An exception may occur if there is a shortage of memory. /** An exception may occur if there is a shortage of memory.
@ -447,14 +565,19 @@ void Aggregator::createAggregateStates(AggregateDataPtr & aggregate_data) const
catch (...) catch (...)
{ {
for (size_t rollback_j = 0; rollback_j < j; ++rollback_j) for (size_t rollback_j = 0; rollback_j < j; ++rollback_j)
{
if constexpr (skip_compiled_aggregate_functions)
if (is_aggregate_function_compiled[j])
continue;
aggregate_functions[rollback_j]->destroy(aggregate_data + offsets_of_aggregate_states[rollback_j]); aggregate_functions[rollback_j]->destroy(aggregate_data + offsets_of_aggregate_states[rollback_j]);
}
throw; throw;
} }
} }
} }
/** It's interesting - if you remove `noinline`, then gcc for some reason will inline this function, and the performance decreases (~ 10%). /** It's interesting - if you remove `noinline`, then gcc for some reason will inline this function, and the performance decreases (~ 10%).
* (Probably because after the inline of this function, more internal functions no longer be inlined.) * (Probably because after the inline of this function, more internal functions no longer be inlined.)
* Inline does not make sense, since the inner loop is entirely inside this function. * Inline does not make sense, since the inner loop is entirely inside this function.
@ -472,13 +595,25 @@ void NO_INLINE Aggregator::executeImpl(
typename Method::State state(key_columns, key_sizes, aggregation_state_cache); typename Method::State state(key_columns, key_sizes, aggregation_state_cache);
if (!no_more_keys) if (!no_more_keys)
executeImplBatch<false>(method, state, aggregates_pool, rows, aggregate_instructions, overflow_row); {
#if USE_EMBEDDED_COMPILER
if (compiled_aggregate_functions_holder)
{
executeImplBatch<false, true>(method, state, aggregates_pool, rows, aggregate_instructions, overflow_row);
}
else
#endif
{
executeImplBatch<false, false>(method, state, aggregates_pool, rows, aggregate_instructions, overflow_row);
}
}
else else
executeImplBatch<true>(method, state, aggregates_pool, rows, aggregate_instructions, overflow_row); {
executeImplBatch<true, false>(method, state, aggregates_pool, rows, aggregate_instructions, overflow_row);
}
} }
template <bool no_more_keys, bool use_compiled_functions, typename Method>
template <bool no_more_keys, typename Method>
void NO_INLINE Aggregator::executeImplBatch( void NO_INLINE Aggregator::executeImplBatch(
Method & method, Method & method,
typename Method::State & state, typename Method::State & state,
@ -535,8 +670,6 @@ void NO_INLINE Aggregator::executeImplBatch(
} }
} }
/// Generic case.
std::unique_ptr<AggregateDataPtr[]> places(new AggregateDataPtr[rows]); std::unique_ptr<AggregateDataPtr[]> places(new AggregateDataPtr[rows]);
/// For all rows. /// For all rows.
@ -555,7 +688,37 @@ void NO_INLINE Aggregator::executeImplBatch(
emplace_result.setMapped(nullptr); emplace_result.setMapped(nullptr);
aggregate_data = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states); aggregate_data = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states);
createAggregateStates(aggregate_data);
#if USE_EMBEDDED_COMPILER
if constexpr (use_compiled_functions)
{
const auto & compiled_aggregate_functions = compiled_aggregate_functions_holder->compiled_aggregate_functions;
compiled_aggregate_functions.create_aggregate_states_function(aggregate_data);
if (compiled_aggregate_functions.functions_count != aggregate_functions.size())
{
static constexpr bool skip_compiled_aggregate_functions = true;
createAggregateStates<skip_compiled_aggregate_functions>(aggregate_data);
}
#if defined(MEMORY_SANITIZER)
/// We compile only functions that do not allocate some data in Arena. Only store necessary state in AggregateData place.
for (size_t aggregate_function_index = 0; aggregate_function_index < aggregate_functions.size(); ++aggregate_function_index)
{
if (!is_aggregate_function_compiled[aggregate_function_index])
continue;
auto aggregate_data_with_offset = aggregate_data + offsets_of_aggregate_states[aggregate_function_index];
auto data_size = params.aggregates[aggregate_function_index].function->sizeOfData();
__msan_unpoison(aggregate_data_with_offset, data_size);
}
#endif
}
else
#endif
{
createAggregateStates(aggregate_data);
}
emplace_result.setMapped(aggregate_data); emplace_result.setMapped(aggregate_data);
} }
@ -577,9 +740,39 @@ void NO_INLINE Aggregator::executeImplBatch(
places[i] = aggregate_data; places[i] = aggregate_data;
} }
/// Add values to the aggregate functions. #if USE_EMBEDDED_COMPILER
for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst) if constexpr (use_compiled_functions)
{ {
std::vector<ColumnData> columns_data;
for (size_t i = 0; i < aggregate_functions.size(); ++i)
{
if (!is_aggregate_function_compiled[i])
continue;
AggregateFunctionInstruction * inst = aggregate_instructions + i;
size_t arguments_size = inst->that->getArgumentTypes().size();
for (size_t argument_index = 0; argument_index < arguments_size; ++argument_index)
columns_data.emplace_back(getColumnData(inst->batch_arguments[argument_index]));
}
auto add_into_aggregate_states_function = compiled_aggregate_functions_holder->compiled_aggregate_functions.add_into_aggregate_states_function;
add_into_aggregate_states_function(rows, columns_data.data(), places.get());
}
#endif
/// Add values to the aggregate functions.
for (size_t i = 0; i < aggregate_functions.size(); ++i)
{
#if USE_EMBEDDED_COMPILER
if constexpr (use_compiled_functions)
if (is_aggregate_function_compiled[i])
continue;
#endif
AggregateFunctionInstruction * inst = aggregate_instructions + i;
if (inst->offsets) if (inst->offsets)
inst->batch_that->addBatchArray(rows, places.get(), inst->state_offset, inst->batch_arguments, inst->offsets, aggregates_pool); inst->batch_that->addBatchArray(rows, places.get(), inst->state_offset, inst->batch_arguments, inst->offsets, aggregates_pool);
else else
@ -720,6 +913,7 @@ bool Aggregator::executeOnBlock(Columns columns, UInt64 num_rows, AggregatedData
} }
} }
} }
NestedColumnsHolder nested_columns_holder; NestedColumnsHolder nested_columns_holder;
AggregateFunctionInstructions aggregate_functions_instructions; AggregateFunctionInstructions aggregate_functions_instructions;
prepareAggregateInstructions(columns, aggregate_columns, materialized_columns, aggregate_functions_instructions, nested_columns_holder); prepareAggregateInstructions(columns, aggregate_columns, materialized_columns, aggregate_functions_instructions, nested_columns_holder);
@ -1025,9 +1219,23 @@ void Aggregator::convertToBlockImpl(
raw_key_columns.push_back(column.get()); raw_key_columns.push_back(column.get());
if (final) if (final)
convertToBlockImplFinal(method, data, std::move(raw_key_columns), final_aggregate_columns, arena); {
#if USE_EMBEDDED_COMPILER
if (compiled_aggregate_functions_holder)
{
static constexpr bool use_compiled_functions = !Method::low_cardinality_optimization;
convertToBlockImplFinal<Method, use_compiled_functions>(method, data, std::move(raw_key_columns), final_aggregate_columns, arena);
}
else
#endif
{
convertToBlockImplFinal<Method, false>(method, data, std::move(raw_key_columns), final_aggregate_columns, arena);
}
}
else else
{
convertToBlockImplNotFinal(method, data, std::move(raw_key_columns), aggregate_columns); convertToBlockImplNotFinal(method, data, std::move(raw_key_columns), aggregate_columns);
}
/// In order to release memory early. /// In order to release memory early.
data.clearAndShrink(); data.clearAndShrink();
} }
@ -1101,7 +1309,7 @@ inline void Aggregator::insertAggregatesIntoColumns(
} }
template <typename Method, typename Table> template <typename Method, bool use_compiled_functions, typename Table>
void NO_INLINE Aggregator::convertToBlockImplFinal( void NO_INLINE Aggregator::convertToBlockImplFinal(
Method & method, Method & method,
Table & data, Table & data,
@ -1121,11 +1329,98 @@ void NO_INLINE Aggregator::convertToBlockImplFinal(
auto shuffled_key_sizes = method.shuffleKeyColumns(key_columns, key_sizes); auto shuffled_key_sizes = method.shuffleKeyColumns(key_columns, key_sizes);
const auto & key_sizes_ref = shuffled_key_sizes ? *shuffled_key_sizes : key_sizes; const auto & key_sizes_ref = shuffled_key_sizes ? *shuffled_key_sizes : key_sizes;
PaddedPODArray<AggregateDataPtr> places;
places.reserve(data.size());
data.forEachValue([&](const auto & key, auto & mapped) data.forEachValue([&](const auto & key, auto & mapped)
{ {
method.insertKeyIntoColumns(key, key_columns, key_sizes_ref); method.insertKeyIntoColumns(key, key_columns, key_sizes_ref);
insertAggregatesIntoColumns(mapped, final_aggregate_columns, arena); places.emplace_back(mapped);
/// Mark the cell as destroyed so it will not be destroyed in destructor.
mapped = nullptr;
}); });
std::exception_ptr exception;
size_t aggregate_functions_destroy_index = 0;
try
{
#if USE_EMBEDDED_COMPILER
if constexpr (use_compiled_functions)
{
/** For JIT compiled functions we need to resize columns before pass them into compiled code.
* insert_aggregates_into_columns_function function does not throw exception.
*/
std::vector<ColumnData> columns_data;
auto compiled_functions = compiled_aggregate_functions_holder->compiled_aggregate_functions;
for (size_t i = 0; i < params.aggregates_size; ++i)
{
if (!is_aggregate_function_compiled[i])
continue;
auto & final_aggregate_column = final_aggregate_columns[i];
final_aggregate_column = final_aggregate_column->cloneResized(places.size());
columns_data.emplace_back(getColumnData(final_aggregate_column.get()));
}
auto insert_aggregates_into_columns_function = compiled_functions.insert_aggregates_into_columns_function;
insert_aggregates_into_columns_function(places.size(), columns_data.data(), places.data());
}
#endif
for (; aggregate_functions_destroy_index < params.aggregates_size;)
{
if constexpr (use_compiled_functions)
{
if (is_aggregate_function_compiled[aggregate_functions_destroy_index])
{
++aggregate_functions_destroy_index;
continue;
}
}
auto & final_aggregate_column = final_aggregate_columns[aggregate_functions_destroy_index];
size_t offset = offsets_of_aggregate_states[aggregate_functions_destroy_index];
/** We increase aggregate_functions_destroy_index because by function contract if insertResultIntoBatch
* throws exception, it also must destroy all necessary states.
* Then code need to continue to destroy other aggregate function states with next function index.
*/
size_t destroy_index = aggregate_functions_destroy_index;
++aggregate_functions_destroy_index;
/// For State AggregateFunction ownership of aggregate place is passed to result column after insert
bool is_state = aggregate_functions[destroy_index]->isState();
bool destroy_place_after_insert = !is_state;
aggregate_functions[destroy_index]->insertResultIntoBatch(places.size(), places.data(), offset, *final_aggregate_column, arena, destroy_place_after_insert);
}
}
catch (...)
{
exception = std::current_exception();
}
for (; aggregate_functions_destroy_index < params.aggregates_size; ++aggregate_functions_destroy_index)
{
if constexpr (use_compiled_functions)
{
if (is_aggregate_function_compiled[aggregate_functions_destroy_index])
{
++aggregate_functions_destroy_index;
continue;
}
}
size_t offset = offsets_of_aggregate_states[aggregate_functions_destroy_index];
aggregate_functions[aggregate_functions_destroy_index]->destroyBatch(places.size(), places.data(), offset);
}
if (exception)
std::rethrow_exception(exception);
} }
template <typename Method, typename Table> template <typename Method, typename Table>
@ -1545,7 +1840,7 @@ void NO_INLINE Aggregator::mergeDataNullKey(
} }
template <typename Method, typename Table> template <typename Method, bool use_compiled_functions, typename Table>
void NO_INLINE Aggregator::mergeDataImpl( void NO_INLINE Aggregator::mergeDataImpl(
Table & table_dst, Table & table_dst,
Table & table_src, Table & table_src,
@ -1554,19 +1849,40 @@ void NO_INLINE Aggregator::mergeDataImpl(
if constexpr (Method::low_cardinality_optimization) if constexpr (Method::low_cardinality_optimization)
mergeDataNullKey<Method, Table>(table_dst, table_src, arena); mergeDataNullKey<Method, Table>(table_dst, table_src, arena);
table_src.mergeToViaEmplace(table_dst, table_src.mergeToViaEmplace(table_dst, [&](AggregateDataPtr & __restrict dst, AggregateDataPtr & __restrict src, bool inserted)
[&](AggregateDataPtr & __restrict dst, AggregateDataPtr & __restrict src, bool inserted)
{ {
if (!inserted) if (!inserted)
{ {
for (size_t i = 0; i < params.aggregates_size; ++i) #if USE_EMBEDDED_COMPILER
aggregate_functions[i]->merge( if constexpr (use_compiled_functions)
dst + offsets_of_aggregate_states[i], {
src + offsets_of_aggregate_states[i], const auto & compiled_functions = compiled_aggregate_functions_holder->compiled_aggregate_functions;
arena); compiled_functions.merge_aggregate_states_function(dst, src);
for (size_t i = 0; i < params.aggregates_size; ++i) if (compiled_aggregate_functions_holder->compiled_aggregate_functions.functions_count != params.aggregates_size)
aggregate_functions[i]->destroy(src + offsets_of_aggregate_states[i]); {
for (size_t i = 0; i < params.aggregates_size; ++i)
{
if (!is_aggregate_function_compiled[i])
aggregate_functions[i]->merge(dst + offsets_of_aggregate_states[i], src + offsets_of_aggregate_states[i], arena);
}
for (size_t i = 0; i < params.aggregates_size; ++i)
{
if (!is_aggregate_function_compiled[i])
aggregate_functions[i]->destroy(src + offsets_of_aggregate_states[i]);
}
}
}
else
#endif
{
for (size_t i = 0; i < params.aggregates_size; ++i)
aggregate_functions[i]->merge(dst + offsets_of_aggregate_states[i], src + offsets_of_aggregate_states[i], arena);
for (size_t i = 0; i < params.aggregates_size; ++i)
aggregate_functions[i]->destroy(src + offsets_of_aggregate_states[i]);
}
} }
else else
{ {
@ -1575,6 +1891,7 @@ void NO_INLINE Aggregator::mergeDataImpl(
src = nullptr; src = nullptr;
}); });
table_src.clearAndShrink(); table_src.clearAndShrink();
} }
@ -1677,21 +1994,39 @@ void NO_INLINE Aggregator::mergeSingleLevelDataImpl(
AggregatedDataVariants & current = *non_empty_data[result_num]; AggregatedDataVariants & current = *non_empty_data[result_num];
if (!no_more_keys) if (!no_more_keys)
mergeDataImpl<Method>( {
getDataVariant<Method>(*res).data, #if USE_EMBEDDED_COMPILER
getDataVariant<Method>(current).data, if (compiled_aggregate_functions_holder)
res->aggregates_pool); {
mergeDataImpl<Method, true>(
getDataVariant<Method>(*res).data,
getDataVariant<Method>(current).data,
res->aggregates_pool);
}
else
#endif
{
mergeDataImpl<Method, false>(
getDataVariant<Method>(*res).data,
getDataVariant<Method>(current).data,
res->aggregates_pool);
}
}
else if (res->without_key) else if (res->without_key)
{
mergeDataNoMoreKeysImpl<Method>( mergeDataNoMoreKeysImpl<Method>(
getDataVariant<Method>(*res).data, getDataVariant<Method>(*res).data,
res->without_key, res->without_key,
getDataVariant<Method>(current).data, getDataVariant<Method>(current).data,
res->aggregates_pool); res->aggregates_pool);
}
else else
{
mergeDataOnlyExistingKeysImpl<Method>( mergeDataOnlyExistingKeysImpl<Method>(
getDataVariant<Method>(*res).data, getDataVariant<Method>(*res).data,
getDataVariant<Method>(current).data, getDataVariant<Method>(current).data,
res->aggregates_pool); res->aggregates_pool);
}
/// `current` will not destroy the states of aggregate functions in the destructor /// `current` will not destroy the states of aggregate functions in the destructor
current.aggregator = nullptr; current.aggregator = nullptr;
@ -1716,11 +2051,22 @@ void NO_INLINE Aggregator::mergeBucketImpl(
return; return;
AggregatedDataVariants & current = *data[result_num]; AggregatedDataVariants & current = *data[result_num];
#if USE_EMBEDDED_COMPILER
mergeDataImpl<Method>( if (compiled_aggregate_functions_holder)
getDataVariant<Method>(*res).data.impls[bucket], {
getDataVariant<Method>(current).data.impls[bucket], mergeDataImpl<Method, true>(
arena); getDataVariant<Method>(*res).data.impls[bucket],
getDataVariant<Method>(current).data.impls[bucket],
arena);
}
else
#endif
{
mergeDataImpl<Method, false>(
getDataVariant<Method>(*res).data.impls[bucket],
getDataVariant<Method>(current).data.impls[bucket],
arena);
}
} }
} }

View File

@ -26,6 +26,7 @@
#include <Interpreters/AggregateDescription.h> #include <Interpreters/AggregateDescription.h>
#include <Interpreters/AggregationCommon.h> #include <Interpreters/AggregationCommon.h>
#include <Interpreters/JIT/compileFunction.h>
#include <Columns/ColumnString.h> #include <Columns/ColumnString.h>
#include <Columns/ColumnFixedString.h> #include <Columns/ColumnFixedString.h>
@ -851,6 +852,8 @@ using AggregatedDataVariantsPtr = std::shared_ptr<AggregatedDataVariants>;
using ManyAggregatedDataVariants = std::vector<AggregatedDataVariantsPtr>; using ManyAggregatedDataVariants = std::vector<AggregatedDataVariantsPtr>;
using ManyAggregatedDataVariantsPtr = std::shared_ptr<ManyAggregatedDataVariants>; using ManyAggregatedDataVariantsPtr = std::shared_ptr<ManyAggregatedDataVariants>;
class CompiledAggregateFunctionsHolder;
/** How are "total" values calculated with WITH TOTALS? /** How are "total" values calculated with WITH TOTALS?
* (For more details, see TotalsHavingTransform.) * (For more details, see TotalsHavingTransform.)
* *
@ -907,6 +910,10 @@ public:
size_t max_threads; size_t max_threads;
const size_t min_free_disk_space; const size_t min_free_disk_space;
bool compile_aggregate_expressions;
size_t min_count_to_compile_aggregate_expression;
Params( Params(
const Block & src_header_, const Block & src_header_,
const ColumnNumbers & keys_, const AggregateDescriptions & aggregates_, const ColumnNumbers & keys_, const AggregateDescriptions & aggregates_,
@ -916,6 +923,8 @@ public:
bool empty_result_for_aggregation_by_empty_set_, bool empty_result_for_aggregation_by_empty_set_,
VolumePtr tmp_volume_, size_t max_threads_, VolumePtr tmp_volume_, size_t max_threads_,
size_t min_free_disk_space_, size_t min_free_disk_space_,
bool compile_aggregate_expressions_,
size_t min_count_to_compile_aggregate_expression_,
const Block & intermediate_header_ = {}) const Block & intermediate_header_ = {})
: src_header(src_header_), : src_header(src_header_),
intermediate_header(intermediate_header_), intermediate_header(intermediate_header_),
@ -925,14 +934,16 @@ public:
max_bytes_before_external_group_by(max_bytes_before_external_group_by_), max_bytes_before_external_group_by(max_bytes_before_external_group_by_),
empty_result_for_aggregation_by_empty_set(empty_result_for_aggregation_by_empty_set_), empty_result_for_aggregation_by_empty_set(empty_result_for_aggregation_by_empty_set_),
tmp_volume(tmp_volume_), max_threads(max_threads_), tmp_volume(tmp_volume_), max_threads(max_threads_),
min_free_disk_space(min_free_disk_space_) min_free_disk_space(min_free_disk_space_),
compile_aggregate_expressions(compile_aggregate_expressions_),
min_count_to_compile_aggregate_expression(min_count_to_compile_aggregate_expression_)
{ {
} }
/// Only parameters that matter during merge. /// Only parameters that matter during merge.
Params(const Block & intermediate_header_, Params(const Block & intermediate_header_,
const ColumnNumbers & keys_, const AggregateDescriptions & aggregates_, bool overflow_row_, size_t max_threads_) const ColumnNumbers & keys_, const AggregateDescriptions & aggregates_, bool overflow_row_, size_t max_threads_)
: Params(Block(), keys_, aggregates_, overflow_row_, 0, OverflowMode::THROW, 0, 0, 0, false, nullptr, max_threads_, 0) : Params(Block(), keys_, aggregates_, overflow_row_, 0, OverflowMode::THROW, 0, 0, 0, false, nullptr, max_threads_, 0, false, 0)
{ {
intermediate_header = intermediate_header_; intermediate_header = intermediate_header_;
} }
@ -1074,11 +1085,22 @@ private:
/// For external aggregation. /// For external aggregation.
TemporaryFiles temporary_files; TemporaryFiles temporary_files;
#if USE_EMBEDDED_COMPILER
std::shared_ptr<CompiledAggregateFunctionsHolder> compiled_aggregate_functions_holder;
#endif
std::vector<bool> is_aggregate_function_compiled;
/** Try to compile aggregate functions.
*/
void compileAggregateFunctions();
/** Select the aggregation method based on the number and types of keys. */ /** Select the aggregation method based on the number and types of keys. */
AggregatedDataVariants::Type chooseAggregationMethod(); AggregatedDataVariants::Type chooseAggregationMethod();
/** Create states of aggregate functions for one key. /** Create states of aggregate functions for one key.
*/ */
template <bool skip_compiled_aggregate_functions = false>
void createAggregateStates(AggregateDataPtr & aggregate_data) const; void createAggregateStates(AggregateDataPtr & aggregate_data) const;
/** Call `destroy` methods for states of aggregate functions. /** Call `destroy` methods for states of aggregate functions.
@ -1099,7 +1121,7 @@ private:
AggregateDataPtr overflow_row) const; AggregateDataPtr overflow_row) const;
/// Specialization for a particular value no_more_keys. /// Specialization for a particular value no_more_keys.
template <bool no_more_keys, typename Method> template <bool no_more_keys, bool use_compiled_expressions, typename Method>
void executeImplBatch( void executeImplBatch(
Method & method, Method & method,
typename Method::State & state, typename Method::State & state,
@ -1136,7 +1158,7 @@ private:
Arena * arena) const; Arena * arena) const;
/// Merge data from hash table `src` into `dst`. /// Merge data from hash table `src` into `dst`.
template <typename Method, typename Table> template <typename Method, bool use_compiled_functions, typename Table>
void mergeDataImpl( void mergeDataImpl(
Table & table_dst, Table & table_dst,
Table & table_src, Table & table_src,
@ -1180,7 +1202,7 @@ private:
MutableColumns & final_aggregate_columns, MutableColumns & final_aggregate_columns,
Arena * arena) const; Arena * arena) const;
template <typename Method, typename Table> template <typename Method, bool use_compiled_functions, typename Table>
void convertToBlockImplFinal( void convertToBlockImplFinal(
Method & method, Method & method,
Table & data, Table & data,

View File

@ -1,6 +1,6 @@
#include <Interpreters/AsynchronousMetrics.h> #include <Interpreters/AsynchronousMetrics.h>
#include <Interpreters/AsynchronousMetricLog.h> #include <Interpreters/AsynchronousMetricLog.h>
#include <Interpreters/ExpressionJIT.h> #include <Interpreters/JIT/CompiledExpressionCache.h>
#include <Interpreters/DatabaseCatalog.h> #include <Interpreters/DatabaseCatalog.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Common/Exception.h> #include <Common/Exception.h>

View File

@ -45,7 +45,6 @@
#include <Access/SettingsConstraints.h> #include <Access/SettingsConstraints.h>
#include <Access/ExternalAuthenticators.h> #include <Access/ExternalAuthenticators.h>
#include <Access/GSSAcceptor.h> #include <Access/GSSAcceptor.h>
#include <Interpreters/ExpressionJIT.h>
#include <Dictionaries/Embedded/GeoDictionariesLoader.h> #include <Dictionaries/Embedded/GeoDictionariesLoader.h>
#include <Interpreters/EmbeddedDictionaries.h> #include <Interpreters/EmbeddedDictionaries.h>
#include <Interpreters/ExternalDictionariesLoader.h> #include <Interpreters/ExternalDictionariesLoader.h>
@ -74,6 +73,7 @@
#include <common/logger_useful.h> #include <common/logger_useful.h>
#include <Common/RemoteHostFilter.h> #include <Common/RemoteHostFilter.h>
#include <Interpreters/DatabaseCatalog.h> #include <Interpreters/DatabaseCatalog.h>
#include <Interpreters/JIT/CompiledExpressionCache.h>
#include <Storages/MergeTree/BackgroundJobsExecutor.h> #include <Storages/MergeTree/BackgroundJobsExecutor.h>
#include <Storages/MergeTree/MergeTreeDataPartUUID.h> #include <Storages/MergeTree/MergeTreeDataPartUUID.h>
#include <filesystem> #include <filesystem>

View File

@ -267,6 +267,9 @@ private:
/// XXX: move this stuff to shared part instead. /// XXX: move this stuff to shared part instead.
ContextMutablePtr buffer_context; /// Buffer context. Could be equal to this. ContextMutablePtr buffer_context; /// Buffer context. Could be equal to this.
/// A flag, used to distinguish between user query and internal query to a database engine (MaterializePostgreSQL).
bool is_internal_query = false;
public: public:
// Top-level OpenTelemetry trace context for the query. Makes sense only for a query context. // Top-level OpenTelemetry trace context for the query. Makes sense only for a query context.
OpenTelemetryTraceContext query_trace_context; OpenTelemetryTraceContext query_trace_context;
@ -742,6 +745,9 @@ public:
void shutdown(); void shutdown();
bool isInternalQuery() const { return is_internal_query; }
void setInternalQuery(bool internal) { is_internal_query = internal; }
ActionLocksManagerPtr getActionLocksManager(); ActionLocksManagerPtr getActionLocksManager();
enum class ApplicationType enum class ApplicationType

View File

@ -377,8 +377,8 @@ void DDLWorker::scheduleTasks(bool reinitialized)
/// The following message is too verbose, but it can be useful too debug mysterious test failures in CI /// The following message is too verbose, but it can be useful too debug mysterious test failures in CI
LOG_TRACE(log, "scheduleTasks: initialized={}, size_before_filtering={}, queue_size={}, " LOG_TRACE(log, "scheduleTasks: initialized={}, size_before_filtering={}, queue_size={}, "
"entries={}..{}, " "entries={}..{}, "
"first_failed_task_name={}, current_tasks_size={}," "first_failed_task_name={}, current_tasks_size={}, "
"last_current_task={}," "last_current_task={}, "
"last_skipped_entry_name={}", "last_skipped_entry_name={}",
initialized, size_before_filtering, queue_nodes.size(), initialized, size_before_filtering, queue_nodes.size(),
queue_nodes.empty() ? "none" : queue_nodes.front(), queue_nodes.empty() ? "none" : queue_nodes.back(), queue_nodes.empty() ? "none" : queue_nodes.front(), queue_nodes.empty() ? "none" : queue_nodes.back(),

View File

@ -28,6 +28,10 @@
# include <Storages/StorageMaterializeMySQL.h> # include <Storages/StorageMaterializeMySQL.h>
#endif #endif
#if USE_LIBPQXX
# include <Storages/PostgreSQL/StorageMaterializedPostgreSQL.h>
#endif
namespace fs = std::filesystem; namespace fs = std::filesystem;
namespace CurrentMetrics namespace CurrentMetrics
@ -234,6 +238,13 @@ DatabaseAndTable DatabaseCatalog::getTableImpl(
return {}; return {};
} }
#if USE_LIBPQXX
if (!context_->isInternalQuery() && (db_and_table.first->getEngineName() == "MaterializedPostgreSQL"))
{
db_and_table.second = std::make_shared<StorageMaterializedPostgreSQL>(std::move(db_and_table.second), getContext());
}
#endif
#if USE_MYSQL #if USE_MYSQL
/// It's definitely not the best place for this logic, but behaviour must be consistent with DatabaseMaterializeMySQL::tryGetTable(...) /// It's definitely not the best place for this logic, but behaviour must be consistent with DatabaseMaterializeMySQL::tryGetTable(...)
if (db_and_table.first->getEngineName() == "MaterializeMySQL") if (db_and_table.first->getEngineName() == "MaterializeMySQL")
@ -245,6 +256,7 @@ DatabaseAndTable DatabaseCatalog::getTableImpl(
return db_and_table; return db_and_table;
} }
if (table_id.database_name == TEMPORARY_DATABASE) if (table_id.database_name == TEMPORARY_DATABASE)
{ {
/// For temporary tables UUIDs are set in Context::resolveStorageID(...). /// For temporary tables UUIDs are set in Context::resolveStorageID(...).

View File

@ -1,4 +1,6 @@
#include <Interpreters/ExpressionJIT.h> #if !defined(ARCADIA_BUILD)
# include "config_core.h"
#endif
#if USE_EMBEDDED_COMPILER #if USE_EMBEDDED_COMPILER
@ -20,6 +22,7 @@
#include <Interpreters/JIT/CHJIT.h> #include <Interpreters/JIT/CHJIT.h>
#include <Interpreters/JIT/CompileDAG.h> #include <Interpreters/JIT/CompileDAG.h>
#include <Interpreters/JIT/compileFunction.h> #include <Interpreters/JIT/compileFunction.h>
#include <Interpreters/JIT/CompiledExpressionCache.h>
#include <Interpreters/ActionsDAG.h> #include <Interpreters/ActionsDAG.h>
namespace DB namespace DB
@ -42,36 +45,30 @@ static Poco::Logger * getLogger()
return &logger; return &logger;
} }
class CompiledFunction class CompiledFunctionHolder : public CompiledExpressionCacheEntry
{ {
public: public:
CompiledFunction(void * compiled_function_, CHJIT::CompiledModuleInfo module_info_) explicit CompiledFunctionHolder(CompiledFunction compiled_function_)
: compiled_function(compiled_function_) : CompiledExpressionCacheEntry(compiled_function_.compiled_module.size)
, module_info(std::move(module_info_)) , compiled_function(compiled_function_)
{} {}
void * getCompiledFunction() const { return compiled_function; } ~CompiledFunctionHolder() override
~CompiledFunction()
{ {
getJITInstance().deleteCompiledModule(module_info); getJITInstance().deleteCompiledModule(compiled_function.compiled_module);
} }
private: CompiledFunction compiled_function;
void * compiled_function;
CHJIT::CompiledModuleInfo module_info;
}; };
class LLVMExecutableFunction : public IExecutableFunction class LLVMExecutableFunction : public IExecutableFunction
{ {
public: public:
explicit LLVMExecutableFunction(const std::string & name_, std::shared_ptr<CompiledFunction> compiled_function_) explicit LLVMExecutableFunction(const std::string & name_, std::shared_ptr<CompiledFunctionHolder> compiled_function_holder_)
: name(name_) : name(name_)
, compiled_function(compiled_function_) , compiled_function_holder(compiled_function_holder_)
{ {
} }
@ -104,8 +101,8 @@ public:
columns[arguments.size()] = getColumnData(result_column.get()); columns[arguments.size()] = getColumnData(result_column.get());
JITCompiledFunction jit_compiled_function_typed = reinterpret_cast<JITCompiledFunction>(compiled_function->getCompiledFunction()); auto jit_compiled_function = compiled_function_holder->compiled_function.compiled_function;
jit_compiled_function_typed(input_rows_count, columns.data()); jit_compiled_function(input_rows_count, columns.data());
#if defined(MEMORY_SANITIZER) #if defined(MEMORY_SANITIZER)
/// Memory sanitizer don't know about stores from JIT-ed code. /// Memory sanitizer don't know about stores from JIT-ed code.
@ -135,7 +132,7 @@ public:
private: private:
std::string name; std::string name;
std::shared_ptr<CompiledFunction> compiled_function; std::shared_ptr<CompiledFunctionHolder> compiled_function_holder;
}; };
class LLVMFunction : public IFunctionBase class LLVMFunction : public IFunctionBase
@ -157,9 +154,9 @@ public:
} }
} }
void setCompiledFunction(std::shared_ptr<CompiledFunction> compiled_function_) void setCompiledFunction(std::shared_ptr<CompiledFunctionHolder> compiled_function_holder_)
{ {
compiled_function = compiled_function_; compiled_function_holder = compiled_function_holder_;
} }
bool isCompilable() const override { return true; } bool isCompilable() const override { return true; }
@ -177,10 +174,10 @@ public:
ExecutableFunctionPtr prepare(const ColumnsWithTypeAndName &) const override ExecutableFunctionPtr prepare(const ColumnsWithTypeAndName &) const override
{ {
if (!compiled_function) if (!compiled_function_holder)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Compiled function was not initialized {}", name); throw Exception(ErrorCodes::LOGICAL_ERROR, "Compiled function was not initialized {}", name);
return std::make_unique<LLVMExecutableFunction>(name, compiled_function); return std::make_unique<LLVMExecutableFunction>(name, compiled_function_holder);
} }
bool isDeterministic() const override bool isDeterministic() const override
@ -269,7 +266,7 @@ private:
CompileDAG dag; CompileDAG dag;
DataTypes argument_types; DataTypes argument_types;
std::vector<FunctionBasePtr> nested_functions; std::vector<FunctionBasePtr> nested_functions;
std::shared_ptr<CompiledFunction> compiled_function; std::shared_ptr<CompiledFunctionHolder> compiled_function_holder;
}; };
static FunctionBasePtr compile( static FunctionBasePtr compile(
@ -293,22 +290,19 @@ static FunctionBasePtr compile(
auto [compiled_function_cache_entry, _] = compilation_cache->getOrSet(hash_key, [&] () auto [compiled_function_cache_entry, _] = compilation_cache->getOrSet(hash_key, [&] ()
{ {
LOG_TRACE(getLogger(), "Compile expression {}", llvm_function->getName()); LOG_TRACE(getLogger(), "Compile expression {}", llvm_function->getName());
CHJIT::CompiledModuleInfo compiled_module_info = compileFunction(getJITInstance(), *llvm_function); auto compiled_function = compileFunction(getJITInstance(), *llvm_function);
auto * compiled_jit_function = getJITInstance().findCompiledFunction(compiled_module_info, llvm_function->getName()); return std::make_shared<CompiledFunctionHolder>(compiled_function);
auto compiled_function = std::make_shared<CompiledFunction>(compiled_jit_function, compiled_module_info);
return std::make_shared<CompiledFunctionCacheEntry>(std::move(compiled_function), compiled_module_info.size);
}); });
llvm_function->setCompiledFunction(compiled_function_cache_entry->getCompiledFunction()); std::shared_ptr<CompiledFunctionHolder> compiled_function_holder = std::static_pointer_cast<CompiledFunctionHolder>(compiled_function_cache_entry);
llvm_function->setCompiledFunction(std::move(compiled_function_holder));
} }
else else
{ {
LOG_TRACE(getLogger(), "Compile expression {}", llvm_function->getName()); auto compiled_function = compileFunction(getJITInstance(), *llvm_function);
CHJIT::CompiledModuleInfo compiled_module_info = compileFunction(getJITInstance(), *llvm_function); auto compiled_function_holder = std::make_shared<CompiledFunctionHolder>(compiled_function);
auto * compiled_jit_function = getJITInstance().findCompiledFunction(compiled_module_info, llvm_function->getName());
auto compiled_function = std::make_shared<CompiledFunction>(compiled_jit_function, compiled_module_info); llvm_function->setCompiledFunction(std::move(compiled_function_holder));
llvm_function->setCompiledFunction(compiled_function);
} }
return llvm_function; return llvm_function;
@ -577,25 +571,6 @@ void ActionsDAG::compileFunctions(size_t min_count_to_compile_expression)
} }
} }
CompiledExpressionCacheFactory & CompiledExpressionCacheFactory::instance()
{
static CompiledExpressionCacheFactory factory;
return factory;
}
void CompiledExpressionCacheFactory::init(size_t cache_size)
{
if (cache)
throw Exception(ErrorCodes::LOGICAL_ERROR, "CompiledExpressionCache was already initialized");
cache = std::make_unique<CompiledExpressionCache>(cache_size);
}
CompiledExpressionCache * CompiledExpressionCacheFactory::tryGetCache()
{
return cache.get();
}
} }
#endif #endif

View File

@ -1,66 +0,0 @@
#pragma once
#if !defined(ARCADIA_BUILD)
# include "config_core.h"
#endif
#if USE_EMBEDDED_COMPILER
# include <Common/LRUCache.h>
# include <Common/HashTable/Hash.h>
namespace DB
{
class CompiledFunction;
class CompiledFunctionCacheEntry
{
public:
CompiledFunctionCacheEntry(std::shared_ptr<CompiledFunction> compiled_function_, size_t compiled_function_size_)
: compiled_function(std::move(compiled_function_))
, compiled_function_size(compiled_function_size_)
{}
std::shared_ptr<CompiledFunction> getCompiledFunction() const { return compiled_function; }
size_t getCompiledFunctionSize() const { return compiled_function_size; }
private:
std::shared_ptr<CompiledFunction> compiled_function;
size_t compiled_function_size;
};
struct CompiledFunctionWeightFunction
{
size_t operator()(const CompiledFunctionCacheEntry & compiled_function) const
{
return compiled_function.getCompiledFunctionSize();
}
};
/** This child of LRUCache breaks one of it's invariants: total weight may be changed after insertion.
* We have to do so, because we don't known real memory consumption of generated LLVM code for every function.
*/
class CompiledExpressionCache : public LRUCache<UInt128, CompiledFunctionCacheEntry, UInt128Hash, CompiledFunctionWeightFunction>
{
public:
using Base = LRUCache<UInt128, CompiledFunctionCacheEntry, UInt128Hash, CompiledFunctionWeightFunction>;
using Base::Base;
};
class CompiledExpressionCacheFactory
{
private:
std::unique_ptr<CompiledExpressionCache> cache;
public:
static CompiledExpressionCacheFactory & instance();
void init(size_t cache_size);
CompiledExpressionCache * tryGetCache();
};
}
#endif

View File

@ -151,7 +151,7 @@ BlockIO InterpreterCreateQuery::createDatabase(ASTCreateQuery & create)
throw Exception(ErrorCodes::UNKNOWN_DATABASE_ENGINE, "Unknown database engine: {}", serializeAST(*create.storage)); throw Exception(ErrorCodes::UNKNOWN_DATABASE_ENGINE, "Unknown database engine: {}", serializeAST(*create.storage));
} }
if (create.storage->engine->name == "Atomic" || create.storage->engine->name == "Replicated") if (create.storage->engine->name == "Atomic" || create.storage->engine->name == "Replicated" || create.storage->engine->name == "MaterializedPostgreSQL")
{ {
if (create.attach && create.uuid == UUIDHelpers::Nil) if (create.attach && create.uuid == UUIDHelpers::Nil)
throw Exception(ErrorCodes::INCORRECT_QUERY, "UUID must be specified for ATTACH. " throw Exception(ErrorCodes::INCORRECT_QUERY, "UUID must be specified for ATTACH. "
@ -217,6 +217,12 @@ BlockIO InterpreterCreateQuery::createDatabase(ASTCreateQuery & create)
"Enable allow_experimental_database_replicated to use it.", ErrorCodes::UNKNOWN_DATABASE_ENGINE); "Enable allow_experimental_database_replicated to use it.", ErrorCodes::UNKNOWN_DATABASE_ENGINE);
} }
if (create.storage->engine->name == "MaterializedPostgreSQL" && !getContext()->getSettingsRef().allow_experimental_database_materialized_postgresql && !internal)
{
throw Exception("MaterializedPostgreSQL is an experimental database engine. "
"Enable allow_experimental_database_postgresql_replica to use it.", ErrorCodes::UNKNOWN_DATABASE_ENGINE);
}
DatabasePtr database = DatabaseFactory::get(create, metadata_path / "", getContext()); DatabasePtr database = DatabaseFactory::get(create, metadata_path / "", getContext());
if (create.uuid != UUIDHelpers::Nil) if (create.uuid != UUIDHelpers::Nil)

View File

@ -20,6 +20,9 @@
# include <Databases/MySQL/DatabaseMaterializeMySQL.h> # include <Databases/MySQL/DatabaseMaterializeMySQL.h>
#endif #endif
#if USE_LIBPQXX
# include <Databases/PostgreSQL/DatabaseMaterializedPostgreSQL.h>
#endif
namespace DB namespace DB
{ {
@ -317,6 +320,10 @@ BlockIO InterpreterDropQuery::executeToDatabaseImpl(const ASTDropQuery & query,
#endif #endif
if (auto * replicated = typeid_cast<DatabaseReplicated *>(database.get())) if (auto * replicated = typeid_cast<DatabaseReplicated *>(database.get()))
replicated->stopReplication(); replicated->stopReplication();
#if USE_LIBPQXX
if (auto * materialize_postgresql = typeid_cast<DatabaseMaterializedPostgreSQL *>(database.get()))
materialize_postgresql->stopReplication();
#endif
if (database->shouldBeEmptyOnDetach()) if (database->shouldBeEmptyOnDetach())
{ {
@ -398,4 +405,33 @@ void InterpreterDropQuery::extendQueryLogElemImpl(QueryLogElement & elem, const
elem.query_kind = "Drop"; elem.query_kind = "Drop";
} }
void InterpreterDropQuery::executeDropQuery(ASTDropQuery::Kind kind, ContextPtr global_context, ContextPtr current_context, const StorageID & target_table_id, bool no_delay)
{
if (DatabaseCatalog::instance().tryGetTable(target_table_id, current_context))
{
/// We create and execute `drop` query for internal table.
auto drop_query = std::make_shared<ASTDropQuery>();
drop_query->database = target_table_id.database_name;
drop_query->table = target_table_id.table_name;
drop_query->kind = kind;
drop_query->no_delay = no_delay;
drop_query->if_exists = true;
ASTPtr ast_drop_query = drop_query;
/// FIXME We have to use global context to execute DROP query for inner table
/// to avoid "Not enough privileges" error if current user has only DROP VIEW ON mat_view_name privilege
/// and not allowed to drop inner table explicitly. Allowing to drop inner table without explicit grant
/// looks like expected behaviour and we have tests for it.
auto drop_context = Context::createCopy(global_context);
drop_context->getClientInfo().query_kind = ClientInfo::QueryKind::SECONDARY_QUERY;
if (auto txn = current_context->getZooKeeperMetadataTransaction())
{
/// For Replicated database
drop_context->setQueryContext(std::const_pointer_cast<Context>(current_context));
drop_context->initZooKeeperMetadataTransaction(txn, true);
}
InterpreterDropQuery drop_interpreter(ast_drop_query, drop_context);
drop_interpreter.execute();
}
}
} }

View File

@ -26,6 +26,8 @@ public:
void extendQueryLogElemImpl(QueryLogElement & elem, const ASTPtr &, ContextPtr) const override; void extendQueryLogElemImpl(QueryLogElement & elem, const ASTPtr &, ContextPtr) const override;
static void executeDropQuery(ASTDropQuery::Kind kind, ContextPtr global_context, ContextPtr current_context, const StorageID & target_table_id, bool no_delay);
private: private:
AccessRightsElements getRequiredAccessForDDLOnCluster() const; AccessRightsElements getRequiredAccessForDDLOnCluster() const;
ASTPtr query_ptr; ASTPtr query_ptr;

View File

@ -506,7 +506,7 @@ InterpreterSelectQuery::InterpreterSelectQuery(
result_header = getSampleBlockImpl(); result_header = getSampleBlockImpl();
}; };
analyze(settings.optimize_move_to_prewhere); analyze(shouldMoveToPrewhere());
bool need_analyze_again = false; bool need_analyze_again = false;
if (analysis_result.prewhere_constant_filter_description.always_false || analysis_result.prewhere_constant_filter_description.always_true) if (analysis_result.prewhere_constant_filter_description.always_false || analysis_result.prewhere_constant_filter_description.always_true)
@ -1532,16 +1532,22 @@ void InterpreterSelectQuery::addEmptySourceToQueryPlan(
} }
} }
void InterpreterSelectQuery::addPrewhereAliasActions() bool InterpreterSelectQuery::shouldMoveToPrewhere()
{ {
const Settings & settings = context->getSettingsRef(); const Settings & settings = context->getSettingsRef();
const ASTSelectQuery & query = getSelectQuery();
return settings.optimize_move_to_prewhere && (!query.final() || settings.optimize_move_to_prewhere_if_final);
}
void InterpreterSelectQuery::addPrewhereAliasActions()
{
auto & expressions = analysis_result; auto & expressions = analysis_result;
if (expressions.filter_info) if (expressions.filter_info)
{ {
if (!expressions.prewhere_info) if (!expressions.prewhere_info)
{ {
const bool does_storage_support_prewhere = !input && !input_pipe && storage && storage->supportsPrewhere(); const bool does_storage_support_prewhere = !input && !input_pipe && storage && storage->supportsPrewhere();
if (does_storage_support_prewhere && settings.optimize_move_to_prewhere) if (does_storage_support_prewhere && shouldMoveToPrewhere())
{ {
/// Execute row level filter in prewhere as a part of "move to prewhere" optimization. /// Execute row level filter in prewhere as a part of "move to prewhere" optimization.
expressions.prewhere_info = std::make_shared<PrewhereInfo>( expressions.prewhere_info = std::make_shared<PrewhereInfo>(
@ -2038,7 +2044,9 @@ void InterpreterSelectQuery::executeAggregation(QueryPlan & query_plan, const Ac
settings.empty_result_for_aggregation_by_empty_set, settings.empty_result_for_aggregation_by_empty_set,
context->getTemporaryVolume(), context->getTemporaryVolume(),
settings.max_threads, settings.max_threads,
settings.min_free_disk_space_for_temporary_data); settings.min_free_disk_space_for_temporary_data,
settings.compile_aggregate_expressions,
settings.min_count_to_compile_aggregate_expression);
SortDescription group_by_sort_description; SortDescription group_by_sort_description;
@ -2140,7 +2148,9 @@ void InterpreterSelectQuery::executeRollupOrCube(QueryPlan & query_plan, Modific
settings.empty_result_for_aggregation_by_empty_set, settings.empty_result_for_aggregation_by_empty_set,
context->getTemporaryVolume(), context->getTemporaryVolume(),
settings.max_threads, settings.max_threads,
settings.min_free_disk_space_for_temporary_data); settings.min_free_disk_space_for_temporary_data,
settings.compile_aggregate_expressions,
settings.min_count_to_compile_aggregate_expression);
auto transform_params = std::make_shared<AggregatingTransformParams>(params, true); auto transform_params = std::make_shared<AggregatingTransformParams>(params, true);

View File

@ -118,6 +118,7 @@ private:
ASTSelectQuery & getSelectQuery() { return query_ptr->as<ASTSelectQuery &>(); } ASTSelectQuery & getSelectQuery() { return query_ptr->as<ASTSelectQuery &>(); }
void addPrewhereAliasActions(); void addPrewhereAliasActions();
bool shouldMoveToPrewhere();
Block getSampleBlockImpl(); Block getSampleBlockImpl();

View File

@ -25,7 +25,7 @@
#include <Interpreters/MetricLog.h> #include <Interpreters/MetricLog.h>
#include <Interpreters/AsynchronousMetricLog.h> #include <Interpreters/AsynchronousMetricLog.h>
#include <Interpreters/OpenTelemetrySpanLog.h> #include <Interpreters/OpenTelemetrySpanLog.h>
#include <Interpreters/ExpressionJIT.h> #include <Interpreters/JIT/CompiledExpressionCache.h>
#include <Access/ContextAccess.h> #include <Access/ContextAccess.h>
#include <Access/AllowedClientHosts.h> #include <Access/AllowedClientHosts.h>
#include <Databases/IDatabase.h> #include <Databases/IDatabase.h>

View File

@ -80,6 +80,28 @@ private:
llvm::TargetMachine & target_machine; llvm::TargetMachine & target_machine;
}; };
// class AssemblyPrinter
// {
// public:
// explicit AssemblyPrinter(llvm::TargetMachine &target_machine_)
// : target_machine(target_machine_)
// {
// }
// void print(llvm::Module & module)
// {
// llvm::legacy::PassManager pass_manager;
// target_machine.Options.MCOptions.AsmVerbose = true;
// if (target_machine.addPassesToEmitFile(pass_manager, llvm::errs(), nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile))
// throw Exception(ErrorCodes::CANNOT_COMPILE_CODE, "MachineCode cannot be printed");
// pass_manager.run(module);
// }
// private:
// llvm::TargetMachine & target_machine;
// };
/** MemoryManager for module. /** MemoryManager for module.
* Keep total allocated size during RuntimeDyld linker execution. * Keep total allocated size during RuntimeDyld linker execution.
* Actual compiled code memory is stored in llvm::SectionMemoryManager member, we cannot use ZeroBase optimization here * Actual compiled code memory is stored in llvm::SectionMemoryManager member, we cannot use ZeroBase optimization here
@ -189,7 +211,7 @@ CHJIT::CHJIT()
CHJIT::~CHJIT() = default; CHJIT::~CHJIT() = default;
CHJIT::CompiledModuleInfo CHJIT::compileModule(std::function<void (llvm::Module &)> compile_function) CHJIT::CompiledModule CHJIT::compileModule(std::function<void (llvm::Module &)> compile_function)
{ {
std::lock_guard<std::mutex> lock(jit_lock); std::lock_guard<std::mutex> lock(jit_lock);
@ -210,7 +232,7 @@ std::unique_ptr<llvm::Module> CHJIT::createModuleForCompilation()
return module; return module;
} }
CHJIT::CompiledModuleInfo CHJIT::compileModule(std::unique_ptr<llvm::Module> module) CHJIT::CompiledModule CHJIT::compileModule(std::unique_ptr<llvm::Module> module)
{ {
runOptimizationPassesOnModule(*module); runOptimizationPassesOnModule(*module);
@ -234,7 +256,7 @@ CHJIT::CompiledModuleInfo CHJIT::compileModule(std::unique_ptr<llvm::Module> mod
dynamic_linker.resolveRelocations(); dynamic_linker.resolveRelocations();
module_memory_manager->getManager().finalizeMemory(); module_memory_manager->getManager().finalizeMemory();
CompiledModuleInfo module_info; CompiledModule compiled_module;
for (const auto & function : *module) for (const auto & function : *module)
{ {
@ -250,47 +272,29 @@ CHJIT::CompiledModuleInfo CHJIT::compileModule(std::unique_ptr<llvm::Module> mod
throw Exception(ErrorCodes::CANNOT_COMPILE_CODE, "DynamicLinker could not found symbol {} after compilation", function_name); throw Exception(ErrorCodes::CANNOT_COMPILE_CODE, "DynamicLinker could not found symbol {} after compilation", function_name);
auto * jit_symbol_address = reinterpret_cast<void *>(jit_symbol.getAddress()); auto * jit_symbol_address = reinterpret_cast<void *>(jit_symbol.getAddress());
compiled_module.function_name_to_symbol.emplace(std::move(function_name), jit_symbol_address);
std::string symbol_name = std::to_string(current_module_key) + '_' + function_name;
name_to_symbol[symbol_name] = jit_symbol_address;
module_info.compiled_functions.emplace_back(std::move(function_name));
} }
module_info.size = module_memory_manager->getAllocatedSize(); compiled_module.size = module_memory_manager->getAllocatedSize();
module_info.identifier = current_module_key; compiled_module.identifier = current_module_key;
module_identifier_to_memory_manager[current_module_key] = std::move(module_memory_manager); module_identifier_to_memory_manager[current_module_key] = std::move(module_memory_manager);
compiled_code_size.fetch_add(module_info.size, std::memory_order_relaxed); compiled_code_size.fetch_add(compiled_module.size, std::memory_order_relaxed);
return module_info; return compiled_module;
} }
void CHJIT::deleteCompiledModule(const CHJIT::CompiledModuleInfo & module_info) void CHJIT::deleteCompiledModule(const CHJIT::CompiledModule & module)
{ {
std::lock_guard<std::mutex> lock(jit_lock); std::lock_guard<std::mutex> lock(jit_lock);
auto module_it = module_identifier_to_memory_manager.find(module_info.identifier); auto module_it = module_identifier_to_memory_manager.find(module.identifier);
if (module_it == module_identifier_to_memory_manager.end()) if (module_it == module_identifier_to_memory_manager.end())
throw Exception(ErrorCodes::LOGICAL_ERROR, "There is no compiled module with identifier {}", module_info.identifier); throw Exception(ErrorCodes::LOGICAL_ERROR, "There is no compiled module with identifier {}", module.identifier);
for (const auto & function : module_info.compiled_functions)
name_to_symbol.erase(function);
module_identifier_to_memory_manager.erase(module_it); module_identifier_to_memory_manager.erase(module_it);
compiled_code_size.fetch_sub(module_info.size, std::memory_order_relaxed); compiled_code_size.fetch_sub(module.size, std::memory_order_relaxed);
}
void * CHJIT::findCompiledFunction(const CompiledModuleInfo & module_info, const std::string & function_name) const
{
std::lock_guard<std::mutex> lock(jit_lock);
std::string symbol_name = std::to_string(module_info.identifier) + '_' + function_name;
auto it = name_to_symbol.find(symbol_name);
if (it != name_to_symbol.end())
return it->second;
return nullptr;
} }
void CHJIT::registerExternalSymbol(const std::string & symbol_name, void * address) void CHJIT::registerExternalSymbol(const std::string & symbol_name, void * address)

View File

@ -9,9 +9,9 @@
#include <unordered_map> #include <unordered_map>
#include <atomic> #include <atomic>
#include <llvm/IR/LLVMContext.h> #include <llvm/IR/LLVMContext.h> // Y_IGNORE
#include <llvm/IR/Module.h> #include <llvm/IR/Module.h> // Y_IGNORE
#include <llvm/Target/TargetMachine.h> #include <llvm/Target/TargetMachine.h> // Y_IGNORE
namespace DB namespace DB
{ {
@ -52,32 +52,31 @@ public:
~CHJIT(); ~CHJIT();
struct CompiledModuleInfo struct CompiledModule
{ {
/// Size of compiled module code in bytes /// Size of compiled module code in bytes
size_t size; size_t size;
/// Module identifier. Should not be changed by client /// Module identifier. Should not be changed by client
uint64_t identifier; uint64_t identifier;
/// Vector of compiled function nameds. Should not be changed by client
std::vector<std::string> compiled_functions; /// Vector of compiled functions. Should not be changed by client.
/// It is client responsibility to cast result function to right signature.
/// After call to deleteCompiledModule compiled functions from module become invalid.
std::unordered_map<std::string, void *> function_name_to_symbol;
}; };
/** Compile module. In compile function client responsibility is to fill module with necessary /** Compile module. In compile function client responsibility is to fill module with necessary
* IR code, then it will be compiled by CHJIT instance. * IR code, then it will be compiled by CHJIT instance.
* Return compiled module info. * Return compiled module.
*/ */
CompiledModuleInfo compileModule(std::function<void (llvm::Module &)> compile_function); CompiledModule compileModule(std::function<void (llvm::Module &)> compile_function);
/** Delete compiled module. Pointers to functions from module become invalid after this call. /** Delete compiled module. Pointers to functions from module become invalid after this call.
* It is client responsibility to be sure that there are no pointers to compiled module code. * It is client responsibility to be sure that there are no pointers to compiled module code.
*/ */
void deleteCompiledModule(const CompiledModuleInfo & module_info); void deleteCompiledModule(const CompiledModule & module_info);
/** Find compiled function using module_info, and function_name.
* It is client responsibility to case result function to right signature.
* After call to deleteCompiledModule compiled functions from module become invalid.
*/
void * findCompiledFunction(const CompiledModuleInfo & module_info, const std::string & function_name) const;
/** Register external symbol for CHJIT instance to use, during linking. /** Register external symbol for CHJIT instance to use, during linking.
* It can be function, or global constant. * It can be function, or global constant.
@ -93,7 +92,7 @@ private:
std::unique_ptr<llvm::Module> createModuleForCompilation(); std::unique_ptr<llvm::Module> createModuleForCompilation();
CompiledModuleInfo compileModule(std::unique_ptr<llvm::Module> module); CompiledModule compileModule(std::unique_ptr<llvm::Module> module);
std::string getMangledName(const std::string & name_to_mangle) const; std::string getMangledName(const std::string & name_to_mangle) const;
@ -107,7 +106,6 @@ private:
std::unique_ptr<JITCompiler> compiler; std::unique_ptr<JITCompiler> compiler;
std::unique_ptr<JITSymbolResolver> symbol_resolver; std::unique_ptr<JITSymbolResolver> symbol_resolver;
std::unordered_map<std::string, void *> name_to_symbol;
std::unordered_map<uint64_t, std::unique_ptr<JITModuleMemoryManager>> module_identifier_to_memory_manager; std::unordered_map<uint64_t, std::unique_ptr<JITModuleMemoryManager>> module_identifier_to_memory_manager;
uint64_t current_module_key = 0; uint64_t current_module_key = 0;
std::atomic<size_t> compiled_code_size = 0; std::atomic<size_t> compiled_code_size = 0;

View File

@ -0,0 +1,34 @@
#include "CompiledExpressionCache.h"
#if USE_EMBEDDED_COMPILER
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
CompiledExpressionCacheFactory & CompiledExpressionCacheFactory::instance()
{
static CompiledExpressionCacheFactory factory;
return factory;
}
void CompiledExpressionCacheFactory::init(size_t cache_size)
{
if (cache)
throw Exception(ErrorCodes::LOGICAL_ERROR, "CompiledExpressionCache was already initialized");
cache = std::make_unique<CompiledExpressionCache>(cache_size);
}
CompiledExpressionCache * CompiledExpressionCacheFactory::tryGetCache()
{
return cache.get();
}
}
#endif

View File

@ -0,0 +1,61 @@
#pragma once
#if !defined(ARCADIA_BUILD)
# include "config_core.h"
#endif
#if USE_EMBEDDED_COMPILER
# include <Common/LRUCache.h>
# include <Common/HashTable/Hash.h>
# include <Interpreters/JIT/CHJIT.h>
namespace DB
{
class CompiledExpressionCacheEntry
{
public:
explicit CompiledExpressionCacheEntry(size_t compiled_expression_size_)
: compiled_expression_size(compiled_expression_size_)
{}
size_t getCompiledExpressionSize() const { return compiled_expression_size; }
virtual ~CompiledExpressionCacheEntry() {}
private:
size_t compiled_expression_size = 0;
};
struct CompiledFunctionWeightFunction
{
size_t operator()(const CompiledExpressionCacheEntry & compiled_function) const
{
return compiled_function.getCompiledExpressionSize();
}
};
class CompiledExpressionCache : public LRUCache<UInt128, CompiledExpressionCacheEntry, UInt128Hash, CompiledFunctionWeightFunction>
{
public:
using Base = LRUCache<UInt128, CompiledExpressionCacheEntry, UInt128Hash, CompiledFunctionWeightFunction>;
using Base::Base;
};
class CompiledExpressionCacheFactory
{
private:
std::unique_ptr<CompiledExpressionCache> cache;
public:
static CompiledExpressionCacheFactory & instance();
void init(size_t cache_size);
CompiledExpressionCache * tryGetCache();
};
}
#endif

View File

@ -250,20 +250,356 @@ static void compileFunction(llvm::Module & module, const IFunctionBase & functio
b.CreateRetVoid(); b.CreateRetVoid();
} }
CHJIT::CompiledModuleInfo compileFunction(CHJIT & jit, const IFunctionBase & function) CompiledFunction compileFunction(CHJIT & jit, const IFunctionBase & function)
{ {
Stopwatch watch; Stopwatch watch;
auto compiled_module_info = jit.compileModule([&](llvm::Module & module) auto compiled_module = jit.compileModule([&](llvm::Module & module)
{ {
compileFunction(module, function); compileFunction(module, function);
}); });
ProfileEvents::increment(ProfileEvents::CompileExpressionsMicroseconds, watch.elapsedMicroseconds()); ProfileEvents::increment(ProfileEvents::CompileExpressionsMicroseconds, watch.elapsedMicroseconds());
ProfileEvents::increment(ProfileEvents::CompileExpressionsBytes, compiled_module_info.size); ProfileEvents::increment(ProfileEvents::CompileExpressionsBytes, compiled_module.size);
ProfileEvents::increment(ProfileEvents::CompileFunction); ProfileEvents::increment(ProfileEvents::CompileFunction);
return compiled_module_info; auto compiled_function_ptr = reinterpret_cast<JITCompiledFunction>(compiled_module.function_name_to_symbol[function.getName()]);
assert(compiled_function_ptr);
CompiledFunction result_compiled_function
{
.compiled_function = compiled_function_ptr,
.compiled_module = compiled_module
};
return result_compiled_function;
}
static void compileCreateAggregateStatesFunctions(llvm::Module & module, const std::vector<AggregateFunctionWithOffset> & functions, const std::string & name)
{
auto & context = module.getContext();
llvm::IRBuilder<> b(context);
auto * aggregate_data_places_type = b.getInt8Ty()->getPointerTo();
auto * create_aggregate_states_function_type = llvm::FunctionType::get(b.getVoidTy(), { aggregate_data_places_type }, false);
auto * create_aggregate_states_function = llvm::Function::Create(create_aggregate_states_function_type, llvm::Function::ExternalLinkage, name, module);
auto * arguments = create_aggregate_states_function->args().begin();
llvm::Value * aggregate_data_place_arg = arguments++;
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", create_aggregate_states_function);
b.SetInsertPoint(entry);
std::vector<ColumnDataPlaceholder> columns(functions.size());
for (const auto & function_to_compile : functions)
{
size_t aggregate_function_offset = function_to_compile.aggregate_data_offset;
const auto * aggregate_function = function_to_compile.function;
auto * aggregation_place_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place_arg, aggregate_function_offset);
aggregate_function->compileCreate(b, aggregation_place_with_offset);
}
b.CreateRetVoid();
}
static void compileAddIntoAggregateStatesFunctions(llvm::Module & module, const std::vector<AggregateFunctionWithOffset> & functions, const std::string & name)
{
auto & context = module.getContext();
llvm::IRBuilder<> b(context);
auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
auto * places_type = b.getInt8Ty()->getPointerTo()->getPointerTo();
auto * column_data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy());
auto * aggregate_loop_func_declaration = llvm::FunctionType::get(b.getVoidTy(), { size_type, column_data_type->getPointerTo(), places_type }, false);
auto * aggregate_loop_func_definition = llvm::Function::Create(aggregate_loop_func_declaration, llvm::Function::ExternalLinkage, name, module);
auto * arguments = aggregate_loop_func_definition->args().begin();
llvm::Value * rows_count_arg = arguments++;
llvm::Value * columns_arg = arguments++;
llvm::Value * places_arg = arguments++;
/// Initialize ColumnDataPlaceholder llvm representation of ColumnData
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", aggregate_loop_func_definition);
b.SetInsertPoint(entry);
std::vector<ColumnDataPlaceholder> columns;
size_t previous_columns_size = 0;
for (const auto & function : functions)
{
auto argument_types = function.function->getArgumentTypes();
ColumnDataPlaceholder data_placeholder;
size_t function_arguments_size = argument_types.size();
for (size_t column_argument_index = 0; column_argument_index < function_arguments_size; ++column_argument_index)
{
const auto & argument_type = argument_types[column_argument_index];
auto * data = b.CreateLoad(column_data_type, b.CreateConstInBoundsGEP1_32(column_data_type, columns_arg, previous_columns_size + column_argument_index));
data_placeholder.data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(argument_type))->getPointerTo());
data_placeholder.null_init = argument_type->isNullable() ? b.CreateExtractValue(data, {1}) : nullptr;
columns.emplace_back(data_placeholder);
}
previous_columns_size += function_arguments_size;
}
/// Initialize loop
auto * end = llvm::BasicBlock::Create(b.getContext(), "end", aggregate_loop_func_definition);
auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop", aggregate_loop_func_definition);
b.CreateCondBr(b.CreateICmpEQ(rows_count_arg, llvm::ConstantInt::get(size_type, 0)), end, loop);
b.SetInsertPoint(loop);
auto * counter_phi = b.CreatePHI(rows_count_arg->getType(), 2);
counter_phi->addIncoming(llvm::ConstantInt::get(size_type, 0), entry);
auto * places_phi = b.CreatePHI(places_arg->getType(), 2);
places_phi->addIncoming(places_arg, entry);
for (auto & col : columns)
{
col.data = b.CreatePHI(col.data_init->getType(), 2);
col.data->addIncoming(col.data_init, entry);
if (col.null_init)
{
col.null = b.CreatePHI(col.null_init->getType(), 2);
col.null->addIncoming(col.null_init, entry);
}
}
auto * aggregation_place = b.CreateLoad(b.getInt8Ty()->getPointerTo(), places_phi);
previous_columns_size = 0;
for (const auto & function : functions)
{
size_t aggregate_function_offset = function.aggregate_data_offset;
const auto * aggregate_function_ptr = function.function;
auto arguments_types = function.function->getArgumentTypes();
std::vector<llvm::Value *> arguments_values;
size_t function_arguments_size = arguments_types.size();
arguments_values.resize(function_arguments_size);
for (size_t column_argument_index = 0; column_argument_index < function_arguments_size; ++column_argument_index)
{
auto * column_argument_data = columns[previous_columns_size + column_argument_index].data;
auto * column_argument_null_data = columns[previous_columns_size + column_argument_index].null;
auto & argument_type = arguments_types[column_argument_index];
auto * value = b.CreateLoad(toNativeType(b, removeNullable(argument_type)), column_argument_data);
if (!argument_type->isNullable())
{
arguments_values[column_argument_index] = value;
continue;
}
auto * is_null = b.CreateICmpNE(b.CreateLoad(b.getInt8Ty(), column_argument_null_data), b.getInt8(0));
auto * nullable_unitilized = llvm::Constant::getNullValue(toNativeType(b, argument_type));
auto * nullable_value = b.CreateInsertValue(b.CreateInsertValue(nullable_unitilized, value, {0}), is_null, {1});
arguments_values[column_argument_index] = nullable_value;
}
auto * aggregation_place_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregation_place, aggregate_function_offset);
aggregate_function_ptr->compileAdd(b, aggregation_place_with_offset, arguments_types, arguments_values);
previous_columns_size += function_arguments_size;
}
/// End of loop
auto * cur_block = b.GetInsertBlock();
for (auto & col : columns)
{
col.data->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.data, 1), cur_block);
if (col.null)
col.null->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.null, 1), cur_block);
}
places_phi->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, places_phi, 1), cur_block);
auto * value = b.CreateAdd(counter_phi, llvm::ConstantInt::get(size_type, 1));
counter_phi->addIncoming(value, cur_block);
b.CreateCondBr(b.CreateICmpEQ(value, rows_count_arg), end, loop);
b.SetInsertPoint(end);
b.CreateRetVoid();
}
static void compileMergeAggregatesStates(llvm::Module & module, const std::vector<AggregateFunctionWithOffset> & functions, const std::string & name)
{
auto & context = module.getContext();
llvm::IRBuilder<> b(context);
auto * aggregate_data_places_type = b.getInt8Ty()->getPointerTo();
auto * aggregate_loop_func_declaration = llvm::FunctionType::get(b.getVoidTy(), { aggregate_data_places_type, aggregate_data_places_type }, false);
auto * aggregate_loop_func = llvm::Function::Create(aggregate_loop_func_declaration, llvm::Function::ExternalLinkage, name, module);
auto * arguments = aggregate_loop_func->args().begin();
llvm::Value * aggregate_data_place_dst_arg = arguments++;
llvm::Value * aggregate_data_place_src_arg = arguments++;
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", aggregate_loop_func);
b.SetInsertPoint(entry);
for (const auto & function_to_compile : functions)
{
size_t aggregate_function_offset = function_to_compile.aggregate_data_offset;
const auto * aggregate_function_ptr = function_to_compile.function;
auto * aggregate_data_place_merge_dst_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place_dst_arg, aggregate_function_offset);
auto * aggregate_data_place_merge_src_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place_src_arg, aggregate_function_offset);
aggregate_function_ptr->compileMerge(b, aggregate_data_place_merge_dst_with_offset, aggregate_data_place_merge_src_with_offset);
}
b.CreateRetVoid();
}
static void compileInsertAggregatesIntoResultColumns(llvm::Module & module, const std::vector<AggregateFunctionWithOffset> & functions, const std::string & name)
{
auto & context = module.getContext();
llvm::IRBuilder<> b(context);
auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
auto * column_data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy());
auto * aggregate_data_places_type = b.getInt8Ty()->getPointerTo()->getPointerTo();
auto * aggregate_loop_func_declaration = llvm::FunctionType::get(b.getVoidTy(), { size_type, column_data_type->getPointerTo(), aggregate_data_places_type }, false);
auto * aggregate_loop_func = llvm::Function::Create(aggregate_loop_func_declaration, llvm::Function::ExternalLinkage, name, module);
auto * arguments = aggregate_loop_func->args().begin();
llvm::Value * rows_count_arg = &*arguments++;
llvm::Value * columns_arg = &*arguments++;
llvm::Value * aggregate_data_places_arg = &*arguments++;
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", aggregate_loop_func);
b.SetInsertPoint(entry);
std::vector<ColumnDataPlaceholder> columns(functions.size());
for (size_t i = 0; i < functions.size(); ++i)
{
auto return_type = functions[i].function->getReturnType();
auto * data = b.CreateLoad(column_data_type, b.CreateConstInBoundsGEP1_32(column_data_type, columns_arg, i));
columns[i].data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(return_type))->getPointerTo());
columns[i].null_init = return_type->isNullable() ? b.CreateExtractValue(data, {1}) : nullptr;
}
auto * end = llvm::BasicBlock::Create(b.getContext(), "end", aggregate_loop_func);
auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop", aggregate_loop_func);
b.CreateCondBr(b.CreateICmpEQ(rows_count_arg, llvm::ConstantInt::get(size_type, 0)), end, loop);
b.SetInsertPoint(loop);
auto * counter_phi = b.CreatePHI(rows_count_arg->getType(), 2);
counter_phi->addIncoming(llvm::ConstantInt::get(size_type, 0), entry);
auto * aggregate_data_place_phi = b.CreatePHI(aggregate_data_places_type, 2);
aggregate_data_place_phi->addIncoming(aggregate_data_places_arg, entry);
for (auto & col : columns)
{
col.data = b.CreatePHI(col.data_init->getType(), 2);
col.data->addIncoming(col.data_init, entry);
if (col.null_init)
{
col.null = b.CreatePHI(col.null_init->getType(), 2);
col.null->addIncoming(col.null_init, entry);
}
}
for (size_t i = 0; i < functions.size(); ++i)
{
size_t aggregate_function_offset = functions[i].aggregate_data_offset;
const auto * aggregate_function_ptr = functions[i].function;
auto * aggregate_data_place = b.CreateLoad(b.getInt8Ty()->getPointerTo(), aggregate_data_place_phi);
auto * aggregation_place_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place, aggregate_function_offset);
auto * final_value = aggregate_function_ptr->compileGetResult(b, aggregation_place_with_offset);
if (columns[i].null_init)
{
b.CreateStore(b.CreateExtractValue(final_value, {0}), columns[i].data);
b.CreateStore(b.CreateSelect(b.CreateExtractValue(final_value, {1}), b.getInt8(1), b.getInt8(0)), columns[i].null);
}
else
{
b.CreateStore(final_value, columns[i].data);
}
}
/// End of loop
auto * cur_block = b.GetInsertBlock();
for (auto & col : columns)
{
col.data->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.data, 1), cur_block);
if (col.null)
col.null->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.null, 1), cur_block);
}
auto * value = b.CreateAdd(counter_phi, llvm::ConstantInt::get(size_type, 1), "", true, true);
counter_phi->addIncoming(value, cur_block);
aggregate_data_place_phi->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place_phi, 1), cur_block);
b.CreateCondBr(b.CreateICmpEQ(value, rows_count_arg), end, loop);
b.SetInsertPoint(end);
b.CreateRetVoid();
}
CompiledAggregateFunctions compileAggregateFunctons(CHJIT & jit, const std::vector<AggregateFunctionWithOffset> & functions, std::string functions_dump_name)
{
std::string create_aggregate_states_functions_name = functions_dump_name + "_create";
std::string add_aggregate_states_functions_name = functions_dump_name + "_add";
std::string merge_aggregate_states_functions_name = functions_dump_name + "_merge";
std::string insert_aggregate_states_functions_name = functions_dump_name + "_insert";
auto compiled_module = jit.compileModule([&](llvm::Module & module)
{
compileCreateAggregateStatesFunctions(module, functions, create_aggregate_states_functions_name);
compileAddIntoAggregateStatesFunctions(module, functions, add_aggregate_states_functions_name);
compileMergeAggregatesStates(module, functions, merge_aggregate_states_functions_name);
compileInsertAggregatesIntoResultColumns(module, functions, insert_aggregate_states_functions_name);
});
auto create_aggregate_states_function = reinterpret_cast<JITCreateAggregateStatesFunction>(compiled_module.function_name_to_symbol[create_aggregate_states_functions_name]);
auto add_into_aggregate_states_function = reinterpret_cast<JITAddIntoAggregateStatesFunction>(compiled_module.function_name_to_symbol[add_aggregate_states_functions_name]);
auto merge_aggregate_states_function = reinterpret_cast<JITMergeAggregateStatesFunction>(compiled_module.function_name_to_symbol[merge_aggregate_states_functions_name]);
auto insert_aggregate_states_function = reinterpret_cast<JITInsertAggregateStatesIntoColumnsFunction>(compiled_module.function_name_to_symbol[insert_aggregate_states_functions_name]);
assert(create_aggregate_states_function);
assert(add_into_aggregate_states_function);
assert(merge_aggregate_states_function);
assert(insert_aggregate_states_function);
CompiledAggregateFunctions compiled_aggregate_functions
{
.create_aggregate_states_function = create_aggregate_states_function,
.add_into_aggregate_states_function = add_into_aggregate_states_function,
.merge_aggregate_states_function = merge_aggregate_states_function,
.insert_aggregates_into_columns_function = insert_aggregate_states_function,
.functions_count = functions.size(),
.compiled_module = std::move(compiled_module)
};
return compiled_aggregate_functions;
} }
} }

View File

@ -7,6 +7,7 @@
#if USE_EMBEDDED_COMPILER #if USE_EMBEDDED_COMPILER
#include <Functions/IFunction.h> #include <Functions/IFunction.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Interpreters/JIT/CHJIT.h> #include <Interpreters/JIT/CHJIT.h>
namespace DB namespace DB
@ -28,18 +29,56 @@ struct ColumnData
ColumnData getColumnData(const IColumn * column); ColumnData getColumnData(const IColumn * column);
using ColumnDataRowsSize = size_t; using ColumnDataRowsSize = size_t;
using JITCompiledFunction = void (*)(ColumnDataRowsSize, ColumnData *); using JITCompiledFunction = void (*)(ColumnDataRowsSize, ColumnData *);
struct CompiledFunction
{
JITCompiledFunction compiled_function;
CHJIT::CompiledModule compiled_module;
};
/** Compile function to native jit code using CHJIT instance. /** Compile function to native jit code using CHJIT instance.
* Function is compiled as single module. * It is client responsibility to match ColumnData arguments size with
* After this function execution, code for function will be compiled and can be queried using * function arguments size and additional ColumnData for result.
* findCompiledFunction with function name.
* Compiled function can be safely casted to JITCompiledFunction type and must be called with
* valid ColumnData and ColumnDataRowsSize.
* It is important that ColumnData parameter of JITCompiledFunction is result column,
* and will be filled by compiled function.
*/ */
CHJIT::CompiledModuleInfo compileFunction(CHJIT & jit, const IFunctionBase & function); CompiledFunction compileFunction(CHJIT & jit, const IFunctionBase & function);
struct AggregateFunctionWithOffset
{
const IAggregateFunction * function;
size_t aggregate_data_offset;
};
using JITCreateAggregateStatesFunction = void (*)(AggregateDataPtr);
using JITAddIntoAggregateStatesFunction = void (*)(ColumnDataRowsSize, ColumnData *, AggregateDataPtr *);
using JITMergeAggregateStatesFunction = void (*)(AggregateDataPtr, AggregateDataPtr);
using JITInsertAggregateStatesIntoColumnsFunction = void (*)(ColumnDataRowsSize, ColumnData *, AggregateDataPtr *);
struct CompiledAggregateFunctions
{
JITCreateAggregateStatesFunction create_aggregate_states_function;
JITAddIntoAggregateStatesFunction add_into_aggregate_states_function;
JITMergeAggregateStatesFunction merge_aggregate_states_function;
JITInsertAggregateStatesIntoColumnsFunction insert_aggregates_into_columns_function;
/// Count of functions that were compiled
size_t functions_count;
/// Compiled module. It is client responsibility to destroy it after functions are no longer required.
CHJIT::CompiledModule compiled_module;
};
/** Compile aggregate function to native jit code using CHJIT instance.
*
* JITCreateAggregateStatesFunction will initialize aggregate data ptr with initial aggregate states values.
* JITAddIntoAggregateStatesFunction will update aggregate states for aggregate functions with specified ColumnData.
* JITMergeAggregateStatesFunction will merge aggregate states for aggregate functions.
* JITInsertAggregateStatesIntoColumnsFunction will insert aggregate states for aggregate functions into result columns.
*/
CompiledAggregateFunctions compileAggregateFunctons(CHJIT & jit, const std::vector<AggregateFunctionWithOffset> & functions, std::string functions_dump_name);
} }

View File

@ -1,5 +1,11 @@
#include <iostream> #include <iostream>
#if !defined(ARCADIA_BUILD)
# include "config_core.h"
#endif
#if USE_EMBEDDED_COMPILER
#include <llvm/IR/IRBuilder.h> #include <llvm/IR/IRBuilder.h>
#include <Interpreters/JIT/CHJIT.h> #include <Interpreters/JIT/CHJIT.h>
@ -18,7 +24,7 @@ int main(int argc, char **argv)
jit.registerExternalSymbol("test_function", reinterpret_cast<void *>(&test_function)); jit.registerExternalSymbol("test_function", reinterpret_cast<void *>(&test_function));
auto compiled_module_info = jit.compileModule([](llvm::Module & module) auto compiled_module = jit.compileModule([](llvm::Module & module)
{ {
auto & context = module.getContext(); auto & context = module.getContext();
llvm::IRBuilder<> b (context); llvm::IRBuilder<> b (context);
@ -43,15 +49,27 @@ int main(int argc, char **argv)
b.CreateRet(value); b.CreateRet(value);
}); });
for (const auto & compiled_function_name : compiled_module_info.compiled_functions) for (const auto & [compiled_function_name, _] : compiled_module.function_name_to_symbol)
{ {
std::cerr << compiled_function_name << std::endl; std::cerr << compiled_function_name << std::endl;
} }
int64_t value = 5; int64_t value = 5;
auto * test_name_function = reinterpret_cast<int64_t (*)(int64_t *)>(jit.findCompiledFunction(compiled_module_info, "test_name")); auto * symbol = compiled_module.function_name_to_symbol["test_name"];
auto * test_name_function = reinterpret_cast<int64_t (*)(int64_t *)>(symbol);
auto result = test_name_function(&value); auto result = test_name_function(&value);
std::cerr << "Result " << result << std::endl; std::cerr << "Result " << result << std::endl;
return 0; return 0;
} }
#else
int main(int argc, char **argv)
{
(void)(argc);
(void)(argv);
return 0;
}
#endif

View File

@ -66,6 +66,7 @@ static const size_t SSL_REQUEST_PAYLOAD_SIZE = 32;
static String selectEmptyReplacementQuery(const String & query); static String selectEmptyReplacementQuery(const String & query);
static String showTableStatusReplacementQuery(const String & query); static String showTableStatusReplacementQuery(const String & query);
static String killConnectionIdReplacementQuery(const String & query); static String killConnectionIdReplacementQuery(const String & query);
static String selectLimitReplacementQuery(const String & query);
MySQLHandler::MySQLHandler(IServer & server_, const Poco::Net::StreamSocket & socket_, MySQLHandler::MySQLHandler(IServer & server_, const Poco::Net::StreamSocket & socket_,
bool ssl_enabled, size_t connection_id_) bool ssl_enabled, size_t connection_id_)
@ -83,6 +84,7 @@ MySQLHandler::MySQLHandler(IServer & server_, const Poco::Net::StreamSocket & so
replacements.emplace("KILL QUERY", killConnectionIdReplacementQuery); replacements.emplace("KILL QUERY", killConnectionIdReplacementQuery);
replacements.emplace("SHOW TABLE STATUS LIKE", showTableStatusReplacementQuery); replacements.emplace("SHOW TABLE STATUS LIKE", showTableStatusReplacementQuery);
replacements.emplace("SHOW VARIABLES", selectEmptyReplacementQuery); replacements.emplace("SHOW VARIABLES", selectEmptyReplacementQuery);
replacements.emplace("SET SQL_SELECT_LIMIT", selectLimitReplacementQuery);
} }
void MySQLHandler::run() void MySQLHandler::run()
@ -461,6 +463,14 @@ static String showTableStatusReplacementQuery(const String & query)
return query; return query;
} }
static String selectLimitReplacementQuery(const String & query)
{
const String prefix = "SET SQL_SELECT_LIMIT";
if (query.starts_with(prefix))
return "SET limit" + std::string(query.data() + prefix.length());
return query;
}
/// Replace "KILL QUERY [connection_id]" into "KILL QUERY WHERE query_id = 'mysql:[connection_id]'". /// Replace "KILL QUERY [connection_id]" into "KILL QUERY WHERE query_id = 'mysql:[connection_id]'".
static String killConnectionIdReplacementQuery(const String & query) static String killConnectionIdReplacementQuery(const String & query)
{ {

View File

@ -51,7 +51,10 @@ TableLockHolder IStorage::lockForShare(const String & query_id, const std::chron
TableLockHolder result = tryLockTimed(drop_lock, RWLockImpl::Read, query_id, acquire_timeout); TableLockHolder result = tryLockTimed(drop_lock, RWLockImpl::Read, query_id, acquire_timeout);
if (is_dropped) if (is_dropped)
throw Exception("Table is dropped", ErrorCodes::TABLE_IS_DROPPED); {
auto table_id = getStorageID();
throw Exception(ErrorCodes::TABLE_IS_DROPPED, "Table {}.{} is dropped", table_id.database_name, table_id.table_name);
}
return result; return result;
} }

View File

@ -351,6 +351,8 @@ public:
*/ */
virtual void drop() {} virtual void drop() {}
virtual void dropInnerTableIfAny(bool /* no_delay */, ContextPtr /* context */) {}
/** Clear the table data and leave it empty. /** Clear the table data and leave it empty.
* Must be called under exclusive lock (lockExclusively). * Must be called under exclusive lock (lockExclusively).
*/ */

View File

@ -1013,7 +1013,7 @@ void MergeTreeData::loadDataParts(bool skip_sanity_checks)
ErrorCodes::TOO_MANY_UNEXPECTED_DATA_PARTS); ErrorCodes::TOO_MANY_UNEXPECTED_DATA_PARTS);
for (auto & part : broken_parts_to_detach) for (auto & part : broken_parts_to_detach)
part->renameToDetached("broken_on_start"); part->renameToDetached("broken-on-start"); /// detached parts must not have '_' in prefixes
/// Delete from the set of current parts those parts that are covered by another part (those parts that /// Delete from the set of current parts those parts that are covered by another part (those parts that

View File

@ -301,6 +301,8 @@ QueryPlanPtr MergeTreeDataSelectExecutor::read(
context->getTemporaryVolume(), context->getTemporaryVolume(),
settings.max_threads, settings.max_threads,
settings.min_free_disk_space_for_temporary_data, settings.min_free_disk_space_for_temporary_data,
settings.compile_expressions,
settings.min_count_to_compile_aggregate_expression,
header_before_aggregation); // The source header is also an intermediate header header_before_aggregation); // The source header is also an intermediate header
transform_params = std::make_shared<AggregatingTransformParams>(std::move(params), query_info.projection->aggregate_final); transform_params = std::make_shared<AggregatingTransformParams>(std::move(params), query_info.projection->aggregate_final);
@ -329,7 +331,9 @@ QueryPlanPtr MergeTreeDataSelectExecutor::read(
settings.empty_result_for_aggregation_by_empty_set, settings.empty_result_for_aggregation_by_empty_set,
context->getTemporaryVolume(), context->getTemporaryVolume(),
settings.max_threads, settings.max_threads,
settings.min_free_disk_space_for_temporary_data); settings.min_free_disk_space_for_temporary_data,
settings.compile_aggregate_expressions,
settings.min_count_to_compile_aggregate_expression);
transform_params = std::make_shared<AggregatingTransformParams>(std::move(params), query_info.projection->aggregate_final); transform_params = std::make_shared<AggregatingTransformParams>(std::move(params), query_info.projection->aggregate_final);
} }

View File

@ -190,18 +190,36 @@ void ReplicatedMergeTreePartCheckThread::searchForMissingPartAndFetchIfPossible(
if (missing_part_search_result == MissingPartSearchResult::LostForever) if (missing_part_search_result == MissingPartSearchResult::LostForever)
{ {
/// Is it in the replication queue? If there is - delete, because the task can not be processed. auto lost_part_info = MergeTreePartInfo::fromPartName(part_name, storage.format_version);
if (!storage.queue.remove(zookeeper, part_name)) if (lost_part_info.level != 0)
{ {
/// The part was not in our queue. Strings source_parts;
LOG_WARNING(log, "Missing part {} is not in our queue, this can happen rarely.", part_name); bool part_in_queue = storage.queue.checkPartInQueueAndGetSourceParts(part_name, source_parts);
/// If it's MERGE/MUTATION etc. we shouldn't replace result part with empty part
/// because some source parts can be lost, but some of them can exist.
if (part_in_queue && !source_parts.empty())
{
LOG_ERROR(log, "Part {} found in queue and some source parts for it was lost. Will check all source parts.", part_name);
for (const String & source_part_name : source_parts)
enqueuePart(source_part_name);
return;
}
} }
/** This situation is possible if on all the replicas where the part was, it deteriorated. if (storage.createEmptyPartInsteadOfLost(zookeeper, part_name))
* For example, a replica that has just written it has power turned off and the data has not been written from cache to disk. {
*/ /** This situation is possible if on all the replicas where the part was, it deteriorated.
LOG_ERROR(log, "Part {} is lost forever.", part_name); * For example, a replica that has just written it has power turned off and the data has not been written from cache to disk.
ProfileEvents::increment(ProfileEvents::ReplicatedDataLoss); */
LOG_ERROR(log, "Part {} is lost forever.", part_name);
ProfileEvents::increment(ProfileEvents::ReplicatedDataLoss);
}
else
{
LOG_WARNING(log, "Cannot create empty part {} instead of lost. Will retry later", part_name);
}
} }
} }
@ -307,11 +325,12 @@ CheckResult ReplicatedMergeTreePartCheckThread::checkPart(const String & part_na
String message = "Part " + part_name + " looks broken. Removing it and will try to fetch."; String message = "Part " + part_name + " looks broken. Removing it and will try to fetch.";
LOG_ERROR(log, message); LOG_ERROR(log, message);
/// Delete part locally.
storage.forgetPartAndMoveToDetached(part, "broken");
/// Part is broken, let's try to find it and fetch. /// Part is broken, let's try to find it and fetch.
searchForMissingPartAndFetchIfPossible(part_name, exists_in_zookeeper); searchForMissingPartAndFetchIfPossible(part_name, exists_in_zookeeper);
/// Delete part locally.
storage.forgetPartAndMoveToDetached(part, "broken");
return {part_name, false, message}; return {part_name, false, message};
} }
} }

View File

@ -64,6 +64,24 @@ bool ReplicatedMergeTreeQueue::isVirtualPart(const MergeTreeData::DataPartPtr &
return !virtual_part_name.empty() && virtual_part_name != data_part->name; return !virtual_part_name.empty() && virtual_part_name != data_part->name;
} }
bool ReplicatedMergeTreeQueue::checkPartInQueueAndGetSourceParts(const String & part_name, Strings & source_parts) const
{
std::lock_guard lock(state_mutex);
bool found = false;
for (const auto & entry : queue)
{
if (entry->new_part_name == part_name && entry->source_parts.size() > source_parts.size())
{
source_parts.clear();
source_parts.insert(source_parts.end(), entry->source_parts.begin(), entry->source_parts.end());
found = true;
}
}
return found;
}
bool ReplicatedMergeTreeQueue::load(zkutil::ZooKeeperPtr zookeeper) bool ReplicatedMergeTreeQueue::load(zkutil::ZooKeeperPtr zookeeper)
{ {
@ -410,62 +428,6 @@ void ReplicatedMergeTreeQueue::removeProcessedEntry(zkutil::ZooKeeperPtr zookeep
updateTimesInZooKeeper(zookeeper, min_unprocessed_insert_time_changed, max_processed_insert_time_changed); updateTimesInZooKeeper(zookeeper, min_unprocessed_insert_time_changed, max_processed_insert_time_changed);
} }
bool ReplicatedMergeTreeQueue::remove(zkutil::ZooKeeperPtr zookeeper, const String & part_name)
{
LogEntryPtr found;
size_t queue_size = 0;
std::optional<time_t> min_unprocessed_insert_time_changed;
std::optional<time_t> max_processed_insert_time_changed;
{
std::unique_lock lock(state_mutex);
bool removed = virtual_parts.remove(part_name);
for (Queue::iterator it = queue.begin(); it != queue.end();)
{
if ((*it)->new_part_name == part_name)
{
found = *it;
if (removed)
{
/// Preserve invariant `virtual_parts` = `current_parts` + `queue`.
/// We remove new_part from virtual parts and add all source parts
/// which present in current_parts.
for (const auto & source_part : found->source_parts)
{
auto part_in_current_parts = current_parts.getContainingPart(source_part);
if (part_in_current_parts == source_part)
virtual_parts.add(source_part, nullptr, log);
}
}
updateStateOnQueueEntryRemoval(
found, /* is_successful = */ false,
min_unprocessed_insert_time_changed, max_processed_insert_time_changed, lock);
queue.erase(it++);
queue_size = queue.size();
break;
}
else
++it;
}
}
if (!found)
return false;
notifySubscribers(queue_size);
zookeeper->tryRemove(fs::path(replica_path) / "queue" / found->znode_name);
updateTimesInZooKeeper(zookeeper, min_unprocessed_insert_time_changed, max_processed_insert_time_changed);
return true;
}
bool ReplicatedMergeTreeQueue::removeFailedQuorumPart(const MergeTreePartInfo & part_info) bool ReplicatedMergeTreeQueue::removeFailedQuorumPart(const MergeTreePartInfo & part_info)
{ {
assert(part_info.level == 0); assert(part_info.level == 0);

View File

@ -281,11 +281,6 @@ public:
*/ */
void insert(zkutil::ZooKeeperPtr zookeeper, LogEntryPtr & entry); void insert(zkutil::ZooKeeperPtr zookeeper, LogEntryPtr & entry);
/** Delete the action with the specified part (as new_part_name) from the queue.
* Called for unreachable actions in the queue - old lost parts.
*/
bool remove(zkutil::ZooKeeperPtr zookeeper, const String & part_name);
/** Load (initialize) a queue from ZooKeeper (/replicas/me/queue/). /** Load (initialize) a queue from ZooKeeper (/replicas/me/queue/).
* If queue was not empty load() would not load duplicate records. * If queue was not empty load() would not load duplicate records.
* return true, if we update queue. * return true, if we update queue.
@ -378,6 +373,11 @@ public:
/// Checks that part is already in virtual parts /// Checks that part is already in virtual parts
bool isVirtualPart(const MergeTreeData::DataPartPtr & data_part) const; bool isVirtualPart(const MergeTreeData::DataPartPtr & data_part) const;
/// Check that part produced by some entry in queue and get source parts for it.
/// If there are several entries return largest source_parts set. This rarely possible
/// for example after replica clone.
bool checkPartInQueueAndGetSourceParts(const String & part_name, Strings & source_parts) const;
/// Check that part isn't in currently generating parts and isn't covered by them and add it to future_parts. /// Check that part isn't in currently generating parts and isn't covered by them and add it to future_parts.
/// Locks queue's mutex. /// Locks queue's mutex.
bool addFuturePartIfNotCoveredByThem(const String & part_name, LogEntry & entry, String & reject_reason); bool addFuturePartIfNotCoveredByThem(const String & part_name, LogEntry & entry, String & reject_reason);

View File

@ -0,0 +1,720 @@
#include "MaterializedPostgreSQLConsumer.h"
#include "StorageMaterializedPostgreSQL.h"
#include <Columns/ColumnNullable.h>
#include <Common/hex.h>
#include <DataStreams/copyData.h>
#include <DataStreams/OneBlockInputStream.h>
#include <DataTypes/DataTypeNullable.h>
#include <Interpreters/Context.h>
#include <Interpreters/InterpreterInsertQuery.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
MaterializedPostgreSQLConsumer::MaterializedPostgreSQLConsumer(
ContextPtr context_,
std::shared_ptr<postgres::Connection> connection_,
const std::string & replication_slot_name_,
const std::string & publication_name_,
const std::string & start_lsn,
const size_t max_block_size_,
bool allow_automatic_update_,
Storages storages_)
: log(&Poco::Logger::get("PostgreSQLReaplicaConsumer"))
, context(context_)
, replication_slot_name(replication_slot_name_)
, publication_name(publication_name_)
, connection(connection_)
, current_lsn(start_lsn)
, lsn_value(getLSNValue(start_lsn))
, max_block_size(max_block_size_)
, allow_automatic_update(allow_automatic_update_)
, storages(storages_)
{
final_lsn = start_lsn;
auto tx = std::make_shared<pqxx::nontransaction>(connection->getRef());
current_lsn = advanceLSN(tx);
LOG_TRACE(log, "Starting replication. LSN: {} (last: {})", getLSNValue(current_lsn), getLSNValue(final_lsn));
tx->commit();
for (const auto & [table_name, storage] : storages)
{
buffers.emplace(table_name, Buffer(storage));
}
}
void MaterializedPostgreSQLConsumer::Buffer::createEmptyBuffer(StoragePtr storage)
{
const auto storage_metadata = storage->getInMemoryMetadataPtr();
const Block sample_block = storage_metadata->getSampleBlock();
/// Need to clear type, because in description.init() the types are appended (emplace_back)
description.types.clear();
description.init(sample_block);
columns = description.sample_block.cloneEmptyColumns();
const auto & storage_columns = storage_metadata->getColumns().getAllPhysical();
auto insert_columns = std::make_shared<ASTExpressionList>();
auto table_id = storage->getStorageID();
LOG_TRACE(&Poco::Logger::get("MaterializedPostgreSQLBuffer"), "New buffer for table {}.{} ({}), structure: {}",
table_id.database_name, table_id.table_name, toString(table_id.uuid), sample_block.dumpStructure());
assert(description.sample_block.columns() == storage_columns.size());
size_t idx = 0;
for (const auto & column : storage_columns)
{
if (description.types[idx].first == ExternalResultDescription::ValueType::vtArray)
preparePostgreSQLArrayInfo(array_info, idx, description.sample_block.getByPosition(idx).type);
idx++;
insert_columns->children.emplace_back(std::make_shared<ASTIdentifier>(column.name));
}
columnsAST = std::move(insert_columns);
}
void MaterializedPostgreSQLConsumer::insertValue(Buffer & buffer, const std::string & value, size_t column_idx)
{
const auto & sample = buffer.description.sample_block.getByPosition(column_idx);
bool is_nullable = buffer.description.types[column_idx].second;
if (is_nullable)
{
ColumnNullable & column_nullable = assert_cast<ColumnNullable &>(*buffer.columns[column_idx]);
const auto & data_type = assert_cast<const DataTypeNullable &>(*sample.type);
insertPostgreSQLValue(
column_nullable.getNestedColumn(), value,
buffer.description.types[column_idx].first, data_type.getNestedType(), buffer.array_info, column_idx);
column_nullable.getNullMapData().emplace_back(0);
}
else
{
insertPostgreSQLValue(
*buffer.columns[column_idx], value,
buffer.description.types[column_idx].first, sample.type,
buffer.array_info, column_idx);
}
}
void MaterializedPostgreSQLConsumer::insertDefaultValue(Buffer & buffer, size_t column_idx)
{
const auto & sample = buffer.description.sample_block.getByPosition(column_idx);
insertDefaultPostgreSQLValue(*buffer.columns[column_idx], *sample.column);
}
void MaterializedPostgreSQLConsumer::readString(const char * message, size_t & pos, size_t size, String & result)
{
assert(size > pos + 2);
char current = unhex2(message + pos);
pos += 2;
while (pos < size && current != '\0')
{
result += current;
current = unhex2(message + pos);
pos += 2;
}
}
template<typename T>
T MaterializedPostgreSQLConsumer::unhexN(const char * message, size_t pos, size_t n)
{
T result = 0;
for (size_t i = 0; i < n; ++i)
{
if (i) result <<= 8;
result |= UInt32(unhex2(message + pos + 2 * i));
}
return result;
}
Int64 MaterializedPostgreSQLConsumer::readInt64(const char * message, size_t & pos, [[maybe_unused]] size_t size)
{
assert(size >= pos + 16);
Int64 result = unhexN<Int64>(message, pos, 8);
pos += 16;
return result;
}
Int32 MaterializedPostgreSQLConsumer::readInt32(const char * message, size_t & pos, [[maybe_unused]] size_t size)
{
assert(size >= pos + 8);
Int32 result = unhexN<Int32>(message, pos, 4);
pos += 8;
return result;
}
Int16 MaterializedPostgreSQLConsumer::readInt16(const char * message, size_t & pos, [[maybe_unused]] size_t size)
{
assert(size >= pos + 4);
Int16 result = unhexN<Int16>(message, pos, 2);
pos += 4;
return result;
}
Int8 MaterializedPostgreSQLConsumer::readInt8(const char * message, size_t & pos, [[maybe_unused]] size_t size)
{
assert(size >= pos + 2);
Int8 result = unhex2(message + pos);
pos += 2;
return result;
}
void MaterializedPostgreSQLConsumer::readTupleData(
Buffer & buffer, const char * message, size_t & pos, [[maybe_unused]] size_t size, PostgreSQLQuery type, bool old_value)
{
Int16 num_columns = readInt16(message, pos, size);
auto proccess_column_value = [&](Int8 identifier, Int16 column_idx)
{
switch (identifier)
{
case 'n': /// NULL
{
insertDefaultValue(buffer, column_idx);
break;
}
case 't': /// Text formatted value
{
Int32 col_len = readInt32(message, pos, size);
String value;
for (Int32 i = 0; i < col_len; ++i)
value += readInt8(message, pos, size);
insertValue(buffer, value, column_idx);
break;
}
case 'u': /// TOAST value && unchanged at the same time. Actual value is not sent.
{
/// TOAST values are not supported. (TOAST values are values that are considered in postgres
/// to be too large to be stored directly)
LOG_WARNING(log, "Got TOAST value, which is not supported, default value will be used instead.");
insertDefaultValue(buffer, column_idx);
break;
}
}
};
for (int column_idx = 0; column_idx < num_columns; ++column_idx)
proccess_column_value(readInt8(message, pos, size), column_idx);
switch (type)
{
case PostgreSQLQuery::INSERT:
{
buffer.columns[num_columns]->insert(Int8(1));
buffer.columns[num_columns + 1]->insert(lsn_value);
break;
}
case PostgreSQLQuery::DELETE:
{
buffer.columns[num_columns]->insert(Int8(-1));
buffer.columns[num_columns + 1]->insert(lsn_value);
break;
}
case PostgreSQLQuery::UPDATE:
{
/// Process old value in case changed value is a primary key.
if (old_value)
buffer.columns[num_columns]->insert(Int8(-1));
else
buffer.columns[num_columns]->insert(Int8(1));
buffer.columns[num_columns + 1]->insert(lsn_value);
break;
}
}
}
/// https://www.postgresql.org/docs/13/protocol-logicalrep-message-formats.html
void MaterializedPostgreSQLConsumer::processReplicationMessage(const char * replication_message, size_t size)
{
/// Skip '\x'
size_t pos = 2;
char type = readInt8(replication_message, pos, size);
// LOG_DEBUG(log, "Message type: {}, lsn string: {}, lsn value {}", type, current_lsn, lsn_value);
switch (type)
{
case 'B': // Begin
{
readInt64(replication_message, pos, size); /// Int64 transaction end lsn
readInt64(replication_message, pos, size); /// Int64 transaction commit timestamp
break;
}
case 'I': // Insert
{
Int32 relation_id = readInt32(replication_message, pos, size);
if (!isSyncAllowed(relation_id))
return;
Int8 new_tuple = readInt8(replication_message, pos, size);
const auto & table_name = relation_id_to_name[relation_id];
auto buffer = buffers.find(table_name);
assert(buffer != buffers.end());
if (new_tuple)
readTupleData(buffer->second, replication_message, pos, size, PostgreSQLQuery::INSERT);
break;
}
case 'U': // Update
{
Int32 relation_id = readInt32(replication_message, pos, size);
if (!isSyncAllowed(relation_id))
return;
const auto & table_name = relation_id_to_name[relation_id];
auto buffer = buffers.find(table_name);
assert(buffer != buffers.end());
auto proccess_identifier = [&](Int8 identifier) -> bool
{
bool read_next = true;
switch (identifier)
{
/// Only if changed column(s) are part of replica identity index (or primary keys if they are used instead).
/// In this case, first comes a tuple with old replica identity indexes and all other values will come as
/// nulls. Then comes a full new row.
case 'K': [[fallthrough]];
/// Old row. Only if replica identity is set to full. Does not really make sense to use it as
/// it is much more efficient to use replica identity index, but support all possible cases.
case 'O':
{
readTupleData(buffer->second, replication_message, pos, size, PostgreSQLQuery::UPDATE, true);
break;
}
case 'N':
{
/// New row.
readTupleData(buffer->second, replication_message, pos, size, PostgreSQLQuery::UPDATE);
read_next = false;
break;
}
}
return read_next;
};
/// Read either 'K' or 'O'. Never both of them. Also possible not to get both of them.
bool read_next = proccess_identifier(readInt8(replication_message, pos, size));
/// 'N'. Always present, but could come in place of 'K' and 'O'.
if (read_next)
proccess_identifier(readInt8(replication_message, pos, size));
break;
}
case 'D': // Delete
{
Int32 relation_id = readInt32(replication_message, pos, size);
if (!isSyncAllowed(relation_id))
return;
/// 0 or 1 if replica identity is set to full. For now only default replica identity is supported (with primary keys).
readInt8(replication_message, pos, size);
const auto & table_name = relation_id_to_name[relation_id];
auto buffer = buffers.find(table_name);
assert(buffer != buffers.end());
readTupleData(buffer->second, replication_message, pos, size, PostgreSQLQuery::DELETE);
break;
}
case 'C': // Commit
{
constexpr size_t unused_flags_len = 1;
constexpr size_t commit_lsn_len = 8;
constexpr size_t transaction_end_lsn_len = 8;
constexpr size_t transaction_commit_timestamp_len = 8;
pos += unused_flags_len + commit_lsn_len + transaction_end_lsn_len + transaction_commit_timestamp_len;
LOG_DEBUG(log, "Current lsn: {} = {}", current_lsn, getLSNValue(current_lsn)); /// Will be removed
final_lsn = current_lsn;
break;
}
case 'R': // Relation
{
Int32 relation_id = readInt32(replication_message, pos, size);
String relation_namespace, relation_name;
readString(replication_message, pos, size, relation_namespace);
readString(replication_message, pos, size, relation_name);
if (!isSyncAllowed(relation_id))
return;
if (storages.find(relation_name) == storages.end())
{
markTableAsSkipped(relation_id, relation_name);
LOG_ERROR(log,
"Storage for table {} does not exist, but is included in replication stream. (Storages number: {})",
relation_name, storages.size());
return;
}
assert(buffers.count(relation_name));
/// 'd' - default (primary key if any)
/// 'n' - nothing
/// 'f' - all columns (set replica identity full)
/// 'i' - user defined index with indisreplident set
/// Only 'd' and 'i' - are supported.
char replica_identity = readInt8(replication_message, pos, size);
if (replica_identity != 'd' && replica_identity != 'i')
{
LOG_WARNING(log,
"Table has replica identity {} - not supported. A table must have a primary key or a replica identity index");
markTableAsSkipped(relation_id, relation_name);
return;
}
Int16 num_columns = readInt16(replication_message, pos, size);
Int32 data_type_id;
Int32 type_modifier; /// For example, n in varchar(n)
bool new_relation_definition = false;
if (schema_data.find(relation_id) == schema_data.end())
{
relation_id_to_name[relation_id] = relation_name;
schema_data.emplace(relation_id, SchemaData(num_columns));
new_relation_definition = true;
}
auto & current_schema_data = schema_data.find(relation_id)->second;
if (current_schema_data.number_of_columns != num_columns)
{
markTableAsSkipped(relation_id, relation_name);
return;
}
for (uint16_t i = 0; i < num_columns; ++i)
{
String column_name;
readInt8(replication_message, pos, size); /// Marks column as part of replica identity index
readString(replication_message, pos, size, column_name);
data_type_id = readInt32(replication_message, pos, size);
type_modifier = readInt32(replication_message, pos, size);
if (new_relation_definition)
{
current_schema_data.column_identifiers.emplace_back(std::make_tuple(data_type_id, type_modifier));
}
else
{
if (current_schema_data.column_identifiers[i].first != data_type_id
|| current_schema_data.column_identifiers[i].second != type_modifier)
{
markTableAsSkipped(relation_id, relation_name);
return;
}
}
}
tables_to_sync.insert(relation_name);
break;
}
case 'O': // Origin
break;
case 'Y': // Type
break;
case 'T': // Truncate
break;
default:
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Unexpected byte1 value {} while parsing replication message", type);
}
}
void MaterializedPostgreSQLConsumer::syncTables()
{
try
{
for (const auto & table_name : tables_to_sync)
{
auto & buffer = buffers.find(table_name)->second;
Block result_rows = buffer.description.sample_block.cloneWithColumns(std::move(buffer.columns));
if (result_rows.rows())
{
auto storage = storages[table_name];
auto insert_context = Context::createCopy(context);
insert_context->setInternalQuery(true);
auto insert = std::make_shared<ASTInsertQuery>();
insert->table_id = storage->getStorageID();
insert->columns = buffer.columnsAST;
InterpreterInsertQuery interpreter(insert, insert_context, true);
auto block_io = interpreter.execute();
OneBlockInputStream input(result_rows);
assertBlocksHaveEqualStructure(input.getHeader(), block_io.out->getHeader(), "postgresql replica table sync");
copyData(input, *block_io.out);
buffer.columns = buffer.description.sample_block.cloneEmptyColumns();
}
}
LOG_DEBUG(log, "Table sync end for {} tables, last lsn: {} = {}, (attempted lsn {})", tables_to_sync.size(), current_lsn, getLSNValue(current_lsn), getLSNValue(final_lsn));
auto tx = std::make_shared<pqxx::nontransaction>(connection->getRef());
current_lsn = advanceLSN(tx);
tables_to_sync.clear();
tx->commit();
}
catch (...)
{
tryLogCurrentException(__PRETTY_FUNCTION__);
}
}
String MaterializedPostgreSQLConsumer::advanceLSN(std::shared_ptr<pqxx::nontransaction> tx)
{
std::string query_str = fmt::format("SELECT end_lsn FROM pg_replication_slot_advance('{}', '{}')", replication_slot_name, final_lsn);
pqxx::result result{tx->exec(query_str)};
final_lsn = result[0][0].as<std::string>();
LOG_TRACE(log, "Advanced LSN up to: {}", getLSNValue(final_lsn));
return final_lsn;
}
/// Sync for some table might not be allowed if:
/// 1. Table schema changed and might break synchronization.
/// 2. There is no storage for this table. (As a result of some exception or incorrect pg_publication)
bool MaterializedPostgreSQLConsumer::isSyncAllowed(Int32 relation_id)
{
auto table_with_lsn = skip_list.find(relation_id);
/// Table is not present in a skip list - allow synchronization.
if (table_with_lsn == skip_list.end())
return true;
const auto & table_start_lsn = table_with_lsn->second;
/// Table is in a skip list and has not yet received a valid lsn == it has not been reloaded.
if (table_start_lsn.empty())
return false;
/// Table has received a valid lsn, but it is not yet at a position, from which synchronization is
/// allowed. It is allowed only after lsn position, returned with snapshot, from which
/// table was reloaded.
if (getLSNValue(current_lsn) >= getLSNValue(table_start_lsn))
{
LOG_TRACE(log, "Synchronization is resumed for table: {} (start_lsn: {})",
relation_id_to_name[relation_id], table_start_lsn);
skip_list.erase(table_with_lsn);
return true;
}
return false;
}
void MaterializedPostgreSQLConsumer::markTableAsSkipped(Int32 relation_id, const String & relation_name)
{
/// Empty lsn string means - continue waiting for valid lsn.
skip_list.insert({relation_id, ""});
if (storages.count(relation_name))
{
/// Erase cached schema identifiers. It will be updated again once table is allowed back into replication stream
/// and it receives first data after update.
schema_data.erase(relation_id);
/// Clear table buffer.
auto & buffer = buffers.find(relation_name)->second;
buffer.columns = buffer.description.sample_block.cloneEmptyColumns();
if (allow_automatic_update)
LOG_TRACE(log, "Table {} (relation_id: {}) is skipped temporarily. It will be reloaded in the background", relation_name, relation_id);
else
LOG_WARNING(log, "Table {} (relation_id: {}) is skipped, because table schema has changed", relation_name, relation_id);
}
}
/// Read binary changes from replication slot via COPY command (starting from current lsn in a slot).
bool MaterializedPostgreSQLConsumer::readFromReplicationSlot()
{
bool slot_empty = true;
try
{
auto tx = std::make_shared<pqxx::nontransaction>(connection->getRef());
/// Read up to max_block_size rows changes (upto_n_changes parameter). It might return larger number as the limit
/// is checked only after each transaction block.
/// Returns less than max_block_changes, if reached end of wal. Sync to table in this case.
std::string query_str = fmt::format(
"select lsn, data FROM pg_logical_slot_peek_binary_changes("
"'{}', NULL, {}, 'publication_names', '{}', 'proto_version', '1')",
replication_slot_name, max_block_size, publication_name);
auto stream{pqxx::stream_from::query(*tx, query_str)};
while (true)
{
const std::vector<pqxx::zview> * row{stream.read_row()};
if (!row)
{
stream.complete();
if (slot_empty)
{
tx->commit();
return false;
}
break;
}
slot_empty = false;
current_lsn = (*row)[0];
lsn_value = getLSNValue(current_lsn);
processReplicationMessage((*row)[1].c_str(), (*row)[1].size());
}
}
catch (const Exception &)
{
tryLogCurrentException(__PRETTY_FUNCTION__);
return false;
}
catch (const pqxx::broken_connection & e)
{
LOG_ERROR(log, "Connection error: {}", e.what());
connection->tryUpdateConnection();
return false;
}
catch (const pqxx::sql_error & e)
{
/// For now sql replication interface is used and it has the problem that it registers relcache
/// callbacks on each pg_logical_slot_get_changes and there is no way to invalidate them:
/// https://github.com/postgres/postgres/blob/master/src/backend/replication/pgoutput/pgoutput.c#L1128
/// So at some point will get out of limit and then they will be cleaned.
std::string error_message = e.what();
if (error_message.find("out of relcache_callback_list slots") == std::string::npos)
tryLogCurrentException(__PRETTY_FUNCTION__);
return false;
}
catch (const pqxx::conversion_error & e)
{
LOG_ERROR(log, "Conversion error: {}", e.what());
return false;
}
catch (const pqxx::in_doubt_error & e)
{
LOG_ERROR(log, "PostgreSQL library has some doubts: {}", e.what());
return false;
}
catch (const pqxx::internal_error & e)
{
LOG_ERROR(log, "PostgreSQL library internal error: {}", e.what());
return false;
}
catch (...)
{
/// Since reading is done from a background task, it is important to catch any possible error
/// in order to understand why something does not work.
try
{
std::rethrow_exception(std::current_exception());
}
catch (const std::exception& e)
{
LOG_ERROR(log, "Unexpected error: {}", e.what());
}
}
if (!tables_to_sync.empty())
syncTables();
return true;
}
bool MaterializedPostgreSQLConsumer::consume(std::vector<std::pair<Int32, String>> & skipped_tables)
{
/// Check if there are tables, which are skipped from being updated by changes from replication stream,
/// because schema changes were detected. Update them, if it is allowed.
if (allow_automatic_update && !skip_list.empty())
{
for (const auto & [relation_id, lsn] : skip_list)
{
/// Non-empty lsn in this place means that table was already updated, but no changes for that table were
/// received in a previous stream. A table is removed from skip list only when there came
/// changes for table with lsn higher than lsn of snapshot, from which table was reloaded. Since table
/// reaload and reading from replication stream are done in the same thread, no lsn will be skipped
/// between these two events.
if (lsn.empty())
skipped_tables.emplace_back(std::make_pair(relation_id, relation_id_to_name[relation_id]));
}
}
/// Read up to max_block_size changed (approximately - in same cases might be more).
/// false: no data was read, reschedule.
/// true: some data was read, schedule as soon as possible.
return readFromReplicationSlot();
}
void MaterializedPostgreSQLConsumer::updateNested(const String & table_name, StoragePtr nested_storage, Int32 table_id, const String & table_start_lsn)
{
/// Cache new pointer to replacingMergeTree table.
storages[table_name] = nested_storage;
/// Create a new empty buffer (with updated metadata), where data is first loaded before syncing into actual table.
auto & buffer = buffers.find(table_name)->second;
buffer.createEmptyBuffer(nested_storage);
/// Set start position to valid lsn. Before it was an empty string. Further read for table allowed, if it has a valid lsn.
skip_list[table_id] = table_start_lsn;
}
}

View File

@ -0,0 +1,146 @@
#pragma once
#include <Core/PostgreSQL/Connection.h>
#include <Core/PostgreSQL/insertPostgreSQLValue.h>
#include <Core/BackgroundSchedulePool.h>
#include <Core/Names.h>
#include <common/logger_useful.h>
#include <Storages/IStorage.h>
#include <DataStreams/OneBlockInputStream.h>
#include <Parsers/ASTExpressionList.h>
namespace DB
{
class MaterializedPostgreSQLConsumer
{
public:
using Storages = std::unordered_map<String, StoragePtr>;
MaterializedPostgreSQLConsumer(
ContextPtr context_,
std::shared_ptr<postgres::Connection> connection_,
const String & replication_slot_name_,
const String & publication_name_,
const String & start_lsn,
const size_t max_block_size_,
bool allow_automatic_update_,
Storages storages_);
bool consume(std::vector<std::pair<Int32, String>> & skipped_tables);
/// Called from reloadFromSnapshot by replication handler. This method is needed to move a table back into synchronization
/// process if it was skipped due to schema changes.
void updateNested(const String & table_name, StoragePtr nested_storage, Int32 table_id, const String & table_start_lsn);
private:
/// Read approximarely up to max_block_size changes from WAL.
bool readFromReplicationSlot();
void syncTables();
String advanceLSN(std::shared_ptr<pqxx::nontransaction> ntx);
void processReplicationMessage(const char * replication_message, size_t size);
bool isSyncAllowed(Int32 relation_id);
struct Buffer
{
ExternalResultDescription description;
MutableColumns columns;
/// Needed to pass to insert query columns list in syncTables().
std::shared_ptr<ASTExpressionList> columnsAST;
/// Needed for insertPostgreSQLValue() method to parse array
std::unordered_map<size_t, PostgreSQLArrayInfo> array_info;
Buffer(StoragePtr storage) { createEmptyBuffer(storage); }
void createEmptyBuffer(StoragePtr storage);
};
using Buffers = std::unordered_map<String, Buffer>;
static void insertDefaultValue(Buffer & buffer, size_t column_idx);
static void insertValue(Buffer & buffer, const std::string & value, size_t column_idx);
enum class PostgreSQLQuery
{
INSERT,
UPDATE,
DELETE
};
void readTupleData(Buffer & buffer, const char * message, size_t & pos, size_t size, PostgreSQLQuery type, bool old_value = false);
template<typename T>
static T unhexN(const char * message, size_t pos, size_t n);
static void readString(const char * message, size_t & pos, size_t size, String & result);
static Int64 readInt64(const char * message, size_t & pos, size_t size);
static Int32 readInt32(const char * message, size_t & pos, size_t size);
static Int16 readInt16(const char * message, size_t & pos, size_t size);
static Int8 readInt8(const char * message, size_t & pos, size_t size);
void markTableAsSkipped(Int32 relation_id, const String & relation_name);
/// lsn - log sequnce nuumber, like wal offset (64 bit).
Int64 getLSNValue(const std::string & lsn)
{
UInt32 upper_half, lower_half;
std::sscanf(lsn.data(), "%X/%X", &upper_half, &lower_half);
return (static_cast<Int64>(upper_half) << 32) + lower_half;
}
Poco::Logger * log;
ContextPtr context;
const std::string replication_slot_name, publication_name;
std::shared_ptr<postgres::Connection> connection;
std::string current_lsn, final_lsn;
/// current_lsn converted from String to Int64 via getLSNValue().
UInt64 lsn_value;
const size_t max_block_size;
bool allow_automatic_update;
String table_to_insert;
/// List of tables which need to be synced after last replication stream.
std::unordered_set<std::string> tables_to_sync;
Storages storages;
Buffers buffers;
std::unordered_map<Int32, String> relation_id_to_name;
struct SchemaData
{
Int16 number_of_columns;
/// data_type_id and type_modifier
std::vector<std::pair<Int32, Int32>> column_identifiers;
SchemaData(Int16 number_of_columns_) : number_of_columns(number_of_columns_) {}
};
/// Cache for table schema data to be able to detect schema changes, because ddl is not
/// replicated with postgresql logical replication protocol, but some table schema info
/// is received if it is the first time we received dml message for given relation in current session or
/// if relation definition has changed since the last relation definition message.
std::unordered_map<Int32, SchemaData> schema_data;
/// skip_list contains relation ids for tables on which ddl was performed, which can break synchronization.
/// This breaking changes are detected in replication stream in according replication message and table is added to skip list.
/// After it is finished, a temporary replication slot is created with 'export snapshot' option, and start_lsn is returned.
/// Skipped tables are reloaded from snapshot (nested tables are also updated). Afterwards, if a replication message is
/// related to a table in a skip_list, we compare current lsn with start_lsn, which was returned with according snapshot.
/// If current_lsn >= table_start_lsn, we can safely remove table from skip list and continue its synchronization.
/// No needed message, related to reloaded table will be missed, because messages are not consumed in the meantime,
/// i.e. we will not miss the first start_lsn position for reloaded table.
std::unordered_map<Int32, String> skip_list;
};
}

View File

@ -0,0 +1,45 @@
#include "MaterializedPostgreSQLSettings.h"
#if USE_LIBPQXX
#include <Parsers/ASTCreateQuery.h>
#include <Parsers/ASTSetQuery.h>
#include <Parsers/ASTFunction.h>
#include <Common/Exception.h>
namespace DB
{
namespace ErrorCodes
{
extern const int UNKNOWN_SETTING;
}
IMPLEMENT_SETTINGS_TRAITS(MaterializedPostgreSQLSettingsTraits, LIST_OF_MATERIALIZED_POSTGRESQL_SETTINGS)
void MaterializedPostgreSQLSettings::loadFromQuery(ASTStorage & storage_def)
{
if (storage_def.settings)
{
try
{
applyChanges(storage_def.settings->changes);
}
catch (Exception & e)
{
if (e.code() == ErrorCodes::UNKNOWN_SETTING)
e.addMessage("for storage " + storage_def.engine->name);
throw;
}
}
else
{
auto settings_ast = std::make_shared<ASTSetQuery>();
settings_ast->is_standalone = false;
storage_def.set(storage_def.settings, settings_ast);
}
}
}
#endif

View File

@ -0,0 +1,30 @@
#pragma once
#if !defined(ARCADIA_BUILD)
#include "config_core.h"
#endif
#if USE_LIBPQXX
#include <Core/BaseSettings.h>
namespace DB
{
class ASTStorage;
#define LIST_OF_MATERIALIZED_POSTGRESQL_SETTINGS(M) \
M(UInt64, materialized_postgresql_max_block_size, 65536, "Number of row collected before flushing data into table.", 0) \
M(String, materialized_postgresql_tables_list, "", "List of tables for MaterializedPostgreSQL database engine", 0) \
M(Bool, materialized_postgresql_allow_automatic_update, 0, "Allow to reload table in the background, when schema changes are detected", 0) \
DECLARE_SETTINGS_TRAITS(MaterializedPostgreSQLSettingsTraits, LIST_OF_MATERIALIZED_POSTGRESQL_SETTINGS)
struct MaterializedPostgreSQLSettings : public BaseSettings<MaterializedPostgreSQLSettingsTraits>
{
void loadFromQuery(ASTStorage & storage_def);
};
}
#endif

View File

@ -0,0 +1,629 @@
#include "PostgreSQLReplicationHandler.h"
#include <DataStreams/PostgreSQLBlockInputStream.h>
#include <Databases/PostgreSQL/fetchPostgreSQLTableStructure.h>
#include <Storages/PostgreSQL/StorageMaterializedPostgreSQL.h>
#include <Interpreters/InterpreterDropQuery.h>
#include <Interpreters/InterpreterInsertQuery.h>
#include <Interpreters/InterpreterRenameQuery.h>
#include <Common/setThreadName.h>
#include <Interpreters/Context.h>
#include <DataStreams/copyData.h>
namespace DB
{
static const auto RESCHEDULE_MS = 500;
static const auto BACKOFF_TRESHOLD_MS = 10000;
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int BAD_ARGUMENTS;
}
PostgreSQLReplicationHandler::PostgreSQLReplicationHandler(
const String & replication_identifier,
const String & remote_database_name_,
const String & current_database_name_,
const postgres::ConnectionInfo & connection_info_,
ContextPtr context_,
const size_t max_block_size_,
bool allow_automatic_update_,
bool is_materialized_postgresql_database_,
const String tables_list_)
: log(&Poco::Logger::get("PostgreSQLReplicationHandler"))
, context(context_)
, remote_database_name(remote_database_name_)
, current_database_name(current_database_name_)
, connection_info(connection_info_)
, max_block_size(max_block_size_)
, allow_automatic_update(allow_automatic_update_)
, is_materialized_postgresql_database(is_materialized_postgresql_database_)
, tables_list(tables_list_)
, connection(std::make_shared<postgres::Connection>(connection_info_))
, milliseconds_to_wait(RESCHEDULE_MS)
{
replication_slot = fmt::format("{}_ch_replication_slot", replication_identifier);
publication_name = fmt::format("{}_ch_publication", replication_identifier);
startup_task = context->getSchedulePool().createTask("PostgreSQLReplicaStartup", [this]{ waitConnectionAndStart(); });
consumer_task = context->getSchedulePool().createTask("PostgreSQLReplicaStartup", [this]{ consumerFunc(); });
}
void PostgreSQLReplicationHandler::addStorage(const std::string & table_name, StorageMaterializedPostgreSQL * storage)
{
materialized_storages[table_name] = storage;
}
void PostgreSQLReplicationHandler::startup()
{
startup_task->activateAndSchedule();
}
void PostgreSQLReplicationHandler::waitConnectionAndStart()
{
try
{
connection->connect(); /// Will throw pqxx::broken_connection if no connection at the moment
startSynchronization(false);
}
catch (const pqxx::broken_connection & pqxx_error)
{
LOG_ERROR(log, "Unable to set up connection. Reconnection attempt will continue. Error message: {}", pqxx_error.what());
startup_task->scheduleAfter(RESCHEDULE_MS);
}
catch (...)
{
tryLogCurrentException(__PRETTY_FUNCTION__);
}
}
void PostgreSQLReplicationHandler::shutdown()
{
stop_synchronization.store(true);
startup_task->deactivate();
consumer_task->deactivate();
}
void PostgreSQLReplicationHandler::startSynchronization(bool throw_on_error)
{
{
pqxx::work tx(connection->getRef());
createPublicationIfNeeded(tx);
tx.commit();
}
postgres::Connection replication_connection(connection_info, /* replication */true);
pqxx::nontransaction tx(replication_connection.getRef());
/// List of nested tables (table_name -> nested_storage), which is passed to replication consumer.
std::unordered_map<String, StoragePtr> nested_storages;
/// snapshot_name is initialized only if a new replication slot is created.
/// start_lsn is initialized in two places:
/// 1. if replication slot does not exist, start_lsn will be returned with its creation return parameters;
/// 2. if replication slot already exist, start_lsn is read from pg_replication_slots as
/// `confirmed_flush_lsn` - the address (LSN) up to which the logical slot's consumer has confirmed receiving data.
/// Data older than this is not available anymore.
/// TODO: more tests
String snapshot_name, start_lsn;
auto initial_sync = [&]()
{
createReplicationSlot(tx, start_lsn, snapshot_name);
for (const auto & [table_name, storage] : materialized_storages)
{
try
{
nested_storages[table_name] = loadFromSnapshot(snapshot_name, table_name, storage->as <StorageMaterializedPostgreSQL>());
}
catch (Exception & e)
{
e.addMessage("while loading table {}.{}", remote_database_name, table_name);
tryLogCurrentException(__PRETTY_FUNCTION__);
/// Throw in case of single MaterializedPostgreSQL storage, because initial setup is done immediately
/// (unlike database engine where it is done in a separate thread).
if (throw_on_error)
throw;
}
}
};
/// There is one replication slot for each replication handler. In case of MaterializedPostgreSQL database engine,
/// there is one replication slot per database. Its lifetime must be equal to the lifetime of replication handler.
/// Recreation of a replication slot imposes reloading of all tables.
if (!isReplicationSlotExist(tx, start_lsn, /* temporary */false))
{
initial_sync();
}
/// Replication slot depends on publication, so if replication slot exists and new
/// publication was just created - drop that replication slot and start from scratch.
/// TODO: tests
else if (new_publication_created)
{
dropReplicationSlot(tx);
initial_sync();
}
/// Synchronization and initial load already took place - do not create any new tables, just fetch StoragePtr's
/// and pass them to replication consumer.
else
{
for (const auto & [table_name, storage] : materialized_storages)
{
auto * materialized_storage = storage->as <StorageMaterializedPostgreSQL>();
try
{
/// Try load nested table, set materialized table metadata.
nested_storages[table_name] = materialized_storage->prepare();
}
catch (Exception & e)
{
e.addMessage("while loading table {}.{}", remote_database_name, table_name);
tryLogCurrentException(__PRETTY_FUNCTION__);
if (throw_on_error)
throw;
}
}
LOG_TRACE(log, "Loaded {} tables", nested_storages.size());
}
tx.commit();
/// Pass current connection to consumer. It is not std::moved implicitly, but a shared_ptr is passed.
/// Consumer and replication handler are always executed one after another (not concurrently) and share the same connection.
/// (Apart from the case, when shutdownFinal is called).
/// Handler uses it only for loadFromSnapshot and shutdown methods.
consumer = std::make_shared<MaterializedPostgreSQLConsumer>(
context,
connection,
replication_slot,
publication_name,
start_lsn,
max_block_size,
allow_automatic_update,
nested_storages);
consumer_task->activateAndSchedule();
/// Do not rely anymore on saved storage pointers.
materialized_storages.clear();
}
StoragePtr PostgreSQLReplicationHandler::loadFromSnapshot(String & snapshot_name, const String & table_name,
StorageMaterializedPostgreSQL * materialized_storage)
{
auto tx = std::make_shared<pqxx::ReplicationTransaction>(connection->getRef());
std::string query_str = fmt::format("SET TRANSACTION SNAPSHOT '{}'", snapshot_name);
tx->exec(query_str);
/// Load from snapshot, which will show table state before creation of replication slot.
/// Already connected to needed database, no need to add it to query.
query_str = fmt::format("SELECT * FROM {}", table_name);
materialized_storage->createNestedIfNeeded(fetchTableStructure(*tx, table_name));
auto nested_storage = materialized_storage->getNested();
auto insert = std::make_shared<ASTInsertQuery>();
insert->table_id = nested_storage->getStorageID();
auto insert_context = materialized_storage->getNestedTableContext();
InterpreterInsertQuery interpreter(insert, insert_context);
auto block_io = interpreter.execute();
const StorageInMemoryMetadata & storage_metadata = nested_storage->getInMemoryMetadata();
auto sample_block = storage_metadata.getSampleBlockNonMaterialized();
PostgreSQLTransactionBlockInputStream<pqxx::ReplicationTransaction> input(tx, query_str, sample_block, DEFAULT_BLOCK_SIZE);
assertBlocksHaveEqualStructure(input.getHeader(), block_io.out->getHeader(), "postgresql replica load from snapshot");
copyData(input, *block_io.out);
nested_storage = materialized_storage->prepare();
auto nested_table_id = nested_storage->getStorageID();
LOG_TRACE(log, "Loaded table {}.{} (uuid: {})", nested_table_id.database_name, nested_table_id.table_name, toString(nested_table_id.uuid));
return nested_storage;
}
void PostgreSQLReplicationHandler::consumerFunc()
{
std::vector<std::pair<Int32, String>> skipped_tables;
bool schedule_now = consumer->consume(skipped_tables);
if (!skipped_tables.empty())
{
try
{
reloadFromSnapshot(skipped_tables);
}
catch (...)
{
tryLogCurrentException(__PRETTY_FUNCTION__);
}
}
if (stop_synchronization)
{
LOG_TRACE(log, "Replication thread is stopped");
return;
}
if (schedule_now)
{
milliseconds_to_wait = RESCHEDULE_MS;
consumer_task->schedule();
LOG_DEBUG(log, "Scheduling replication thread: now");
}
else
{
consumer_task->scheduleAfter(milliseconds_to_wait);
if (milliseconds_to_wait < BACKOFF_TRESHOLD_MS)
milliseconds_to_wait *= 2;
LOG_TRACE(log, "Scheduling replication thread: after {} ms", milliseconds_to_wait);
}
}
bool PostgreSQLReplicationHandler::isPublicationExist(pqxx::work & tx)
{
std::string query_str = fmt::format("SELECT exists (SELECT 1 FROM pg_publication WHERE pubname = '{}')", publication_name);
pqxx::result result{tx.exec(query_str)};
assert(!result.empty());
bool publication_exists = (result[0][0].as<std::string>() == "t");
if (publication_exists)
LOG_INFO(log, "Publication {} already exists. Using existing version", publication_name);
return publication_exists;
}
void PostgreSQLReplicationHandler::createPublicationIfNeeded(pqxx::work & tx, bool create_without_check)
{
/// For database engine a publication can be created earlier than in startReplication().
if (new_publication_created)
return;
if (create_without_check || !isPublicationExist(tx))
{
if (tables_list.empty())
{
for (const auto & storage_data : materialized_storages)
{
if (!tables_list.empty())
tables_list += ", ";
tables_list += storage_data.first;
}
}
if (tables_list.empty())
throw Exception(ErrorCodes::LOGICAL_ERROR, "No table found to be replicated");
/// 'ONLY' means just a table, without descendants.
std::string query_str = fmt::format("CREATE PUBLICATION {} FOR TABLE ONLY {}", publication_name, tables_list);
try
{
tx.exec(query_str);
new_publication_created = true;
LOG_TRACE(log, "Created publication {} with tables list: {}", publication_name, tables_list);
}
catch (Exception & e)
{
e.addMessage("while creating pg_publication");
throw;
}
}
}
bool PostgreSQLReplicationHandler::isReplicationSlotExist(pqxx::nontransaction & tx, String & start_lsn, bool temporary)
{
String slot_name;
if (temporary)
slot_name = replication_slot + "_tmp";
else
slot_name = replication_slot;
String query_str = fmt::format("SELECT active, restart_lsn, confirmed_flush_lsn FROM pg_replication_slots WHERE slot_name = '{}'", slot_name);
pqxx::result result{tx.exec(query_str)};
/// Replication slot does not exist
if (result.empty())
return false;
start_lsn = result[0][2].as<std::string>();
LOG_TRACE(log, "Replication slot {} already exists (active: {}). Restart lsn position: {}, confirmed flush lsn: {}",
slot_name, result[0][0].as<bool>(), result[0][1].as<std::string>(), start_lsn);
return true;
}
void PostgreSQLReplicationHandler::createReplicationSlot(
pqxx::nontransaction & tx, String & start_lsn, String & snapshot_name, bool temporary)
{
String query_str, slot_name;
if (temporary)
slot_name = replication_slot + "_tmp";
else
slot_name = replication_slot;
query_str = fmt::format("CREATE_REPLICATION_SLOT {} LOGICAL pgoutput EXPORT_SNAPSHOT", slot_name);
try
{
pqxx::result result{tx.exec(query_str)};
start_lsn = result[0][1].as<std::string>();
snapshot_name = result[0][2].as<std::string>();
LOG_TRACE(log, "Created replication slot: {}, start lsn: {}", replication_slot, start_lsn);
}
catch (Exception & e)
{
e.addMessage("while creating PostgreSQL replication slot {}", slot_name);
throw;
}
}
void PostgreSQLReplicationHandler::dropReplicationSlot(pqxx::nontransaction & tx, bool temporary)
{
std::string slot_name;
if (temporary)
slot_name = replication_slot + "_tmp";
else
slot_name = replication_slot;
std::string query_str = fmt::format("SELECT pg_drop_replication_slot('{}')", slot_name);
tx.exec(query_str);
LOG_TRACE(log, "Dropped replication slot: {}", slot_name);
}
void PostgreSQLReplicationHandler::dropPublication(pqxx::nontransaction & tx)
{
std::string query_str = fmt::format("DROP PUBLICATION IF EXISTS {}", publication_name);
tx.exec(query_str);
}
void PostgreSQLReplicationHandler::shutdownFinal()
{
try
{
shutdown();
connection->execWithRetry([&](pqxx::nontransaction & tx){ dropPublication(tx); });
String last_committed_lsn;
connection->execWithRetry([&](pqxx::nontransaction & tx)
{
if (isReplicationSlotExist(tx, last_committed_lsn, /* temporary */false))
dropReplicationSlot(tx, /* temporary */false);
});
connection->execWithRetry([&](pqxx::nontransaction & tx)
{
if (isReplicationSlotExist(tx, last_committed_lsn, /* temporary */true))
dropReplicationSlot(tx, /* temporary */true);
});
}
catch (Exception & e)
{
e.addMessage("while dropping replication slot: {}", replication_slot);
LOG_ERROR(log, "Failed to drop replication slot: {}. It must be dropped manually.", replication_slot);
throw;
}
}
/// Used by MaterializedPostgreSQL database engine.
NameSet PostgreSQLReplicationHandler::fetchRequiredTables(postgres::Connection & connection_)
{
pqxx::work tx(connection_.getRef());
bool publication_exists_before_startup = isPublicationExist(tx);
NameSet result_tables;
Strings expected_tables;
if (!tables_list.empty())
{
splitInto<','>(expected_tables, tables_list);
if (expected_tables.empty())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Cannot parse tables list: {}", tables_list);
for (auto & table_name : expected_tables)
boost::trim(table_name);
}
if (publication_exists_before_startup)
{
if (tables_list.empty())
{
/// There is no tables list, but publication already exists, then the expected behaviour
/// is to replicate the whole database. But it could be a server restart, so we can't drop it.
LOG_WARNING(log,
"Publication {} already exists and tables list is empty. Assuming publication is correct",
publication_name);
result_tables = fetchPostgreSQLTablesList(tx);
}
/// Check tables list from publication is the same as expected tables list.
/// If not - drop publication and return expected tables list.
else
{
result_tables = fetchTablesFromPublication(tx);
NameSet diff;
std::set_symmetric_difference(expected_tables.begin(), expected_tables.end(),
result_tables.begin(), result_tables.end(),
std::inserter(diff, diff.begin()));
if (!diff.empty())
{
String diff_tables;
for (const auto & table_name : diff)
{
if (!diff_tables.empty())
diff_tables += ", ";
diff_tables += table_name;
}
LOG_WARNING(log,
"Publication {} already exists, but specified tables list differs from publication tables list in tables: {}",
publication_name, diff_tables);
connection->execWithRetry([&](pqxx::nontransaction & tx_){ dropPublication(tx_); });
}
}
}
else
{
if (!tables_list.empty())
{
tx.commit();
return NameSet(expected_tables.begin(), expected_tables.end());
}
else
{
/// Fetch all tables list from database. Publication does not exist yet, which means
/// that no replication took place. Publication will be created in
/// startSynchronization method.
result_tables = fetchPostgreSQLTablesList(tx);
}
}
tx.commit();
return result_tables;
}
NameSet PostgreSQLReplicationHandler::fetchTablesFromPublication(pqxx::work & tx)
{
std::string query = fmt::format("SELECT tablename FROM pg_publication_tables WHERE pubname = '{}'", publication_name);
std::unordered_set<std::string> tables;
for (auto table_name : tx.stream<std::string>(query))
tables.insert(std::get<0>(table_name));
return tables;
}
PostgreSQLTableStructurePtr PostgreSQLReplicationHandler::fetchTableStructure(
pqxx::ReplicationTransaction & tx, const std::string & table_name) const
{
if (!is_materialized_postgresql_database)
return nullptr;
return std::make_unique<PostgreSQLTableStructure>(fetchPostgreSQLTableStructure(tx, table_name, true, true, true));
}
void PostgreSQLReplicationHandler::reloadFromSnapshot(const std::vector<std::pair<Int32, String>> & relation_data)
{
/// If table schema has changed, the table stops consuming changes from replication stream.
/// If `allow_automatic_update` is true, create a new table in the background, load new table schema
/// and all data from scratch. Then execute REPLACE query.
/// This is only allowed for MaterializedPostgreSQL database engine.
try
{
postgres::Connection replication_connection(connection_info, /* replication */true);
pqxx::nontransaction tx(replication_connection.getRef());
String snapshot_name, start_lsn;
if (isReplicationSlotExist(tx, start_lsn, /* temporary */true))
dropReplicationSlot(tx, /* temporary */true);
createReplicationSlot(tx, start_lsn, snapshot_name, /* temporary */true);
for (const auto & [relation_id, table_name] : relation_data)
{
auto storage = DatabaseCatalog::instance().getTable(StorageID(current_database_name, table_name), context);
auto * materialized_storage = storage->as <StorageMaterializedPostgreSQL>();
/// If for some reason this temporary table already exists - also drop it.
auto temp_materialized_storage = materialized_storage->createTemporary();
/// This snapshot is valid up to the end of the transaction, which exported it.
StoragePtr temp_nested_storage = loadFromSnapshot(snapshot_name, table_name,
temp_materialized_storage->as <StorageMaterializedPostgreSQL>());
auto table_id = materialized_storage->getNestedStorageID();
auto temp_table_id = temp_nested_storage->getStorageID();
LOG_TRACE(log, "Starting background update of table {} with table {}",
table_id.getNameForLogs(), temp_table_id.getNameForLogs());
auto ast_rename = std::make_shared<ASTRenameQuery>();
ASTRenameQuery::Element elem
{
ASTRenameQuery::Table{table_id.database_name, table_id.table_name},
ASTRenameQuery::Table{temp_table_id.database_name, temp_table_id.table_name}
};
ast_rename->elements.push_back(std::move(elem));
ast_rename->exchange = true;
auto nested_context = materialized_storage->getNestedTableContext();
try
{
auto materialized_table_lock = materialized_storage->lockForShare(String(), context->getSettingsRef().lock_acquire_timeout);
InterpreterRenameQuery(ast_rename, nested_context).execute();
{
auto nested_storage = DatabaseCatalog::instance().getTable(StorageID(table_id.database_name, table_id.table_name),
nested_context);
auto nested_table_lock = nested_storage->lockForShare(String(), context->getSettingsRef().lock_acquire_timeout);
auto nested_table_id = nested_storage->getStorageID();
materialized_storage->setNestedStorageID(nested_table_id);
nested_storage = materialized_storage->prepare();
auto nested_storage_metadata = nested_storage->getInMemoryMetadataPtr();
auto nested_sample_block = nested_storage_metadata->getSampleBlock();
LOG_TRACE(log, "Updated table {}. New structure: {}",
nested_table_id.getNameForLogs(), nested_sample_block.dumpStructure());
auto materialized_storage_metadata = nested_storage->getInMemoryMetadataPtr();
auto materialized_sample_block = materialized_storage_metadata->getSampleBlock();
assertBlocksHaveEqualStructure(nested_sample_block, materialized_sample_block, "while reloading table in the background");
/// Pass pointer to new nested table into replication consumer, remove current table from skip list and set start lsn position.
consumer->updateNested(table_name, nested_storage, relation_id, start_lsn);
}
LOG_DEBUG(log, "Dropping table {}", temp_table_id.getNameForLogs());
InterpreterDropQuery::executeDropQuery(ASTDropQuery::Kind::Drop, nested_context, nested_context, temp_table_id, true);
}
catch (...)
{
tryLogCurrentException(__PRETTY_FUNCTION__);
}
}
dropReplicationSlot(tx, /* temporary */true);
tx.commit();
}
catch (...)
{
tryLogCurrentException(__PRETTY_FUNCTION__);
}
}
}

View File

@ -0,0 +1,127 @@
#pragma once
#include "MaterializedPostgreSQLConsumer.h"
#include <Databases/PostgreSQL/fetchPostgreSQLTableStructure.h>
#include <Core/PostgreSQL/Utils.h>
namespace DB
{
/// IDEA: There is ALTER PUBLICATION command to dynamically add and remove tables for replicating (the command is transactional).
/// (Probably, if in a replication stream comes a relation name, which does not currently
/// exist in CH, it can be loaded via snapshot while stream is stopped and then comparing wal positions with
/// current lsn and table start lsn.
class StorageMaterializedPostgreSQL;
class PostgreSQLReplicationHandler
{
public:
PostgreSQLReplicationHandler(
const String & replication_identifier,
const String & remote_database_name_,
const String & current_database_name_,
const postgres::ConnectionInfo & connection_info_,
ContextPtr context_,
const size_t max_block_size_,
bool allow_automatic_update_,
bool is_materialized_postgresql_database_,
const String tables_list = "");
/// Activate task to be run from a separate thread: wait until connection is available and call startReplication().
void startup();
/// Stop replication without cleanup.
void shutdown();
/// Clean up replication: remove publication and replication slots.
void shutdownFinal();
/// Add storage pointer to let handler know which tables it needs to keep in sync.
void addStorage(const std::string & table_name, StorageMaterializedPostgreSQL * storage);
/// Fetch list of tables which are going to be replicated. Used for database engine.
NameSet fetchRequiredTables(postgres::Connection & connection_);
/// Start replication setup immediately.
void startSynchronization(bool throw_on_error);
private:
using MaterializedStorages = std::unordered_map<String, StorageMaterializedPostgreSQL *>;
/// Methods to manage Publication.
bool isPublicationExist(pqxx::work & tx);
void createPublicationIfNeeded(pqxx::work & tx, bool create_without_check = false);
NameSet fetchTablesFromPublication(pqxx::work & tx);
void dropPublication(pqxx::nontransaction & ntx);
/// Methods to manage Replication Slots.
bool isReplicationSlotExist(pqxx::nontransaction & tx, String & start_lsn, bool temporary = false);
void createReplicationSlot(pqxx::nontransaction & tx, String & start_lsn, String & snapshot_name, bool temporary = false);
void dropReplicationSlot(pqxx::nontransaction & tx, bool temporary = false);
/// Methods to manage replication.
void waitConnectionAndStart();
void consumerFunc();
StoragePtr loadFromSnapshot(std::string & snapshot_name, const String & table_name, StorageMaterializedPostgreSQL * materialized_storage);
void reloadFromSnapshot(const std::vector<std::pair<Int32, String>> & relation_data);
PostgreSQLTableStructurePtr fetchTableStructure(pqxx::ReplicationTransaction & tx, const String & table_name) const;
Poco::Logger * log;
ContextPtr context;
const String remote_database_name, current_database_name;
/// Connection string and address for logs.
postgres::ConnectionInfo connection_info;
/// max_block_size for replication stream.
const size_t max_block_size;
/// Table structure changes are always tracked. By default, table with changed schema will get into a skip list.
/// This setting allows to reloas table in the background.
bool allow_automatic_update = false;
/// To distinguish whether current replication handler belongs to a MaterializedPostgreSQL database engine or single storage.
bool is_materialized_postgresql_database;
/// A coma-separated list of tables, which are going to be replicated for database engine. By default, a whole database is replicated.
String tables_list;
String replication_slot, publication_name;
/// Shared between replication_consumer and replication_handler, but never accessed concurrently.
std::shared_ptr<postgres::Connection> connection;
/// Replication consumer. Manages decoding of replication stream and syncing into tables.
std::shared_ptr<MaterializedPostgreSQLConsumer> consumer;
BackgroundSchedulePool::TaskHolder startup_task, consumer_task;
std::atomic<bool> stop_synchronization = false;
/// For database engine there are 2 places where it is checked for publication:
/// 1. to fetch tables list from already created publication when database is loaded
/// 2. at replication startup
bool new_publication_created = false;
/// MaterializedPostgreSQL tables. Used for managing all operations with its internal nested tables.
MaterializedStorages materialized_storages;
UInt64 milliseconds_to_wait;
};
}

View File

@ -0,0 +1,501 @@
#include "StorageMaterializedPostgreSQL.h"
#if USE_LIBPQXX
#include <Common/Macros.h>
#include <Core/Settings.h>
#include <Common/parseAddress.h>
#include <Common/assert_cast.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataStreams/ConvertingBlockInputStream.h>
#include <Formats/FormatFactory.h>
#include <Formats/FormatSettings.h>
#include <Processors/Transforms/FilterTransform.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Processors/Sources/SourceFromInputStream.h>
#include <Processors/Pipe.h>
#include <Interpreters/executeQuery.h>
#include <Interpreters/InterpreterSelectQuery.h>
#include <Interpreters/InterpreterDropQuery.h>
#include <Storages/StorageFactory.h>
#include <common/logger_useful.h>
#include <Storages/ReadFinalForExternalReplicaStorage.h>
#include <Core/PostgreSQL/Connection.h>
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int LOGICAL_ERROR;
extern const int BAD_ARGUMENTS;
}
static const auto NESTED_TABLE_SUFFIX = "_nested";
static const auto TMP_SUFFIX = "_tmp";
/// For the case of single storage.
StorageMaterializedPostgreSQL::StorageMaterializedPostgreSQL(
const StorageID & table_id_,
bool is_attach_,
const String & remote_database_name,
const String & remote_table_name_,
const postgres::ConnectionInfo & connection_info,
const StorageInMemoryMetadata & storage_metadata,
ContextPtr context_,
std::unique_ptr<MaterializedPostgreSQLSettings> replication_settings)
: IStorage(table_id_)
, WithContext(context_->getGlobalContext())
, is_materialized_postgresql_database(false)
, has_nested(false)
, nested_context(makeNestedTableContext(context_->getGlobalContext()))
, nested_table_id(StorageID(table_id_.database_name, getNestedTableName()))
, remote_table_name(remote_table_name_)
, is_attach(is_attach_)
{
if (table_id_.uuid == UUIDHelpers::Nil)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Storage MaterializedPostgreSQL is allowed only for Atomic database");
setInMemoryMetadata(storage_metadata);
String replication_identifier = remote_database_name + "_" + remote_table_name_;
replication_handler = std::make_unique<PostgreSQLReplicationHandler>(
replication_identifier,
remote_database_name,
table_id_.database_name,
connection_info,
getContext(),
replication_settings->materialized_postgresql_max_block_size.value,
/* allow_automatic_update */ false, /* is_materialized_postgresql_database */false);
}
/// For the case of MaterializePosgreSQL database engine.
/// It is used when nested ReplacingMergeeTree table has not yet be created by replication thread.
/// In this case this storage can't be used for read queries.
StorageMaterializedPostgreSQL::StorageMaterializedPostgreSQL(const StorageID & table_id_, ContextPtr context_)
: IStorage(table_id_)
, WithContext(context_->getGlobalContext())
, is_materialized_postgresql_database(true)
, has_nested(false)
, nested_context(makeNestedTableContext(context_->getGlobalContext()))
{
}
/// Constructor for MaterializedPostgreSQL table engine - for the case of MaterializePosgreSQL database engine.
/// It is used when nested ReplacingMergeeTree table has already been created by replication thread.
/// This storage is ready to handle read queries.
StorageMaterializedPostgreSQL::StorageMaterializedPostgreSQL(StoragePtr nested_storage_, ContextPtr context_)
: IStorage(nested_storage_->getStorageID())
, WithContext(context_->getGlobalContext())
, is_materialized_postgresql_database(true)
, has_nested(true)
, nested_context(makeNestedTableContext(context_->getGlobalContext()))
, nested_table_id(nested_storage_->getStorageID())
{
setInMemoryMetadata(nested_storage_->getInMemoryMetadata());
}
/// A temporary clone table might be created for current table in order to update its schema and reload
/// all data in the background while current table will still handle read requests.
StoragePtr StorageMaterializedPostgreSQL::createTemporary() const
{
auto table_id = getStorageID();
auto tmp_table_id = StorageID(table_id.database_name, table_id.table_name + TMP_SUFFIX);
/// If for some reason it already exists - drop it.
auto tmp_storage = DatabaseCatalog::instance().tryGetTable(tmp_table_id, nested_context);
if (tmp_storage)
{
LOG_TRACE(&Poco::Logger::get("MaterializedPostgreSQLStorage"), "Temporary table {} already exists, dropping", tmp_table_id.getNameForLogs());
InterpreterDropQuery::executeDropQuery(ASTDropQuery::Kind::Drop, getContext(), getContext(), tmp_table_id, /* no delay */true);
}
auto new_context = Context::createCopy(context);
return StorageMaterializedPostgreSQL::create(tmp_table_id, new_context);
}
StoragePtr StorageMaterializedPostgreSQL::getNested() const
{
return DatabaseCatalog::instance().getTable(getNestedStorageID(), nested_context);
}
StoragePtr StorageMaterializedPostgreSQL::tryGetNested() const
{
return DatabaseCatalog::instance().tryGetTable(getNestedStorageID(), nested_context);
}
String StorageMaterializedPostgreSQL::getNestedTableName() const
{
auto table_id = getStorageID();
if (is_materialized_postgresql_database)
return table_id.table_name;
return toString(table_id.uuid) + NESTED_TABLE_SUFFIX;
}
StorageID StorageMaterializedPostgreSQL::getNestedStorageID() const
{
if (nested_table_id.has_value())
return nested_table_id.value();
auto table_id = getStorageID();
throw Exception(ErrorCodes::LOGICAL_ERROR,
"No storageID found for inner table. ({})", table_id.getNameForLogs());
}
void StorageMaterializedPostgreSQL::createNestedIfNeeded(PostgreSQLTableStructurePtr table_structure)
{
const auto ast_create = getCreateNestedTableQuery(std::move(table_structure));
auto table_id = getStorageID();
auto tmp_nested_table_id = StorageID(table_id.database_name, getNestedTableName());
try
{
InterpreterCreateQuery interpreter(ast_create, nested_context);
interpreter.execute();
auto nested_storage = DatabaseCatalog::instance().getTable(tmp_nested_table_id, nested_context);
/// Save storage_id with correct uuid.
nested_table_id = nested_storage->getStorageID();
}
catch (Exception & e)
{
e.addMessage("while creating nested table: {}", tmp_nested_table_id.getNameForLogs());
tryLogCurrentException(__PRETTY_FUNCTION__);
}
}
std::shared_ptr<Context> StorageMaterializedPostgreSQL::makeNestedTableContext(ContextPtr from_context)
{
auto new_context = Context::createCopy(from_context);
new_context->setInternalQuery(true);
return new_context;
}
StoragePtr StorageMaterializedPostgreSQL::prepare()
{
auto nested_table = getNested();
setInMemoryMetadata(nested_table->getInMemoryMetadata());
has_nested.store(true);
return nested_table;
}
void StorageMaterializedPostgreSQL::startup()
{
/// replication_handler != nullptr only in case of single table engine MaterializedPostgreSQL.
if (replication_handler)
{
replication_handler->addStorage(remote_table_name, this);
if (is_attach)
{
/// In case of attach table use background startup in a separate thread. First wait until connection is reachable,
/// then check for nested table -- it should already be created.
replication_handler->startup();
}
else
{
/// Start synchronization preliminary setup immediately and throw in case of failure.
/// It should be guaranteed that if MaterializedPostgreSQL table was created successfully, then
/// its nested table was also created.
replication_handler->startSynchronization(/* throw_on_error */ true);
}
}
}
void StorageMaterializedPostgreSQL::shutdown()
{
if (replication_handler)
replication_handler->shutdown();
auto nested = getNested();
if (nested)
nested->shutdown();
}
void StorageMaterializedPostgreSQL::dropInnerTableIfAny(bool no_delay, ContextPtr local_context)
{
/// If it is a table with database engine MaterializedPostgreSQL - return, because delition of
/// internal tables is managed there.
if (is_materialized_postgresql_database)
return;
replication_handler->shutdownFinal();
auto nested_table = getNested();
if (nested_table)
InterpreterDropQuery::executeDropQuery(ASTDropQuery::Kind::Drop, getContext(), local_context, getNestedStorageID(), no_delay);
}
NamesAndTypesList StorageMaterializedPostgreSQL::getVirtuals() const
{
return NamesAndTypesList{
{"_sign", std::make_shared<DataTypeInt8>()},
{"_version", std::make_shared<DataTypeUInt64>()}
};
}
Pipe StorageMaterializedPostgreSQL::read(
const Names & column_names,
const StorageMetadataPtr & metadata_snapshot,
SelectQueryInfo & query_info,
ContextPtr context_,
QueryProcessingStage::Enum processed_stage,
size_t max_block_size,
unsigned num_streams)
{
auto materialized_table_lock = lockForShare(String(), context_->getSettingsRef().lock_acquire_timeout);
auto nested_table = getNested();
return readFinalFromNestedStorage(nested_table, column_names, metadata_snapshot,
query_info, context_, processed_stage, max_block_size, num_streams);
}
std::shared_ptr<ASTColumnDeclaration> StorageMaterializedPostgreSQL::getMaterializedColumnsDeclaration(
const String name, const String type, UInt64 default_value)
{
auto column_declaration = std::make_shared<ASTColumnDeclaration>();
column_declaration->name = name;
column_declaration->type = makeASTFunction(type);
column_declaration->default_specifier = "MATERIALIZED";
column_declaration->default_expression = std::make_shared<ASTLiteral>(default_value);
column_declaration->children.emplace_back(column_declaration->type);
column_declaration->children.emplace_back(column_declaration->default_expression);
return column_declaration;
}
ASTPtr StorageMaterializedPostgreSQL::getColumnDeclaration(const DataTypePtr & data_type) const
{
WhichDataType which(data_type);
if (which.isNullable())
return makeASTFunction("Nullable", getColumnDeclaration(typeid_cast<const DataTypeNullable *>(data_type.get())->getNestedType()));
if (which.isArray())
return makeASTFunction("Array", getColumnDeclaration(typeid_cast<const DataTypeArray *>(data_type.get())->getNestedType()));
/// getName() for decimal returns 'Decimal(precision, scale)', will get an error with it
if (which.isDecimal())
{
auto make_decimal_expression = [&](std::string type_name)
{
auto ast_expression = std::make_shared<ASTFunction>();
ast_expression->name = type_name;
ast_expression->arguments = std::make_shared<ASTExpressionList>();
ast_expression->arguments->children.emplace_back(std::make_shared<ASTLiteral>(getDecimalScale(*data_type)));
return ast_expression;
};
if (which.isDecimal32())
return make_decimal_expression("Decimal32");
if (which.isDecimal64())
return make_decimal_expression("Decimal64");
if (which.isDecimal128())
return make_decimal_expression("Decimal128");
if (which.isDecimal256())
return make_decimal_expression("Decimal256");
}
return std::make_shared<ASTIdentifier>(data_type->getName());
}
/// For single storage MaterializedPostgreSQL get columns and primary key columns from storage definition.
/// For database engine MaterializedPostgreSQL get columns and primary key columns by fetching from PostgreSQL, also using the same
/// transaction with snapshot, which is used for initial tables dump.
ASTPtr StorageMaterializedPostgreSQL::getCreateNestedTableQuery(PostgreSQLTableStructurePtr table_structure)
{
auto create_table_query = std::make_shared<ASTCreateQuery>();
auto table_id = getStorageID();
create_table_query->table = getNestedTableName();
create_table_query->database = table_id.database_name;
if (is_materialized_postgresql_database)
create_table_query->uuid = table_id.uuid;
auto columns_declare_list = std::make_shared<ASTColumns>();
auto columns_expression_list = std::make_shared<ASTExpressionList>();
auto order_by_expression = std::make_shared<ASTFunction>();
auto metadata_snapshot = getInMemoryMetadataPtr();
const auto & columns = metadata_snapshot->getColumns();
NamesAndTypesList ordinary_columns_and_types;
if (!is_materialized_postgresql_database)
{
ordinary_columns_and_types = columns.getOrdinary();
}
else
{
if (!table_structure)
{
throw Exception(ErrorCodes::LOGICAL_ERROR,
"No table structure returned for table {}.{}", table_id.database_name, table_id.table_name);
}
if (!table_structure->columns)
{
throw Exception(ErrorCodes::LOGICAL_ERROR,
"No columns returned for table {}.{}", table_id.database_name, table_id.table_name);
}
ordinary_columns_and_types = *table_structure->columns;
if (!table_structure->primary_key_columns && !table_structure->replica_identity_columns)
{
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"Table {}.{} has no primary key and no replica identity index", table_id.database_name, table_id.table_name);
}
NamesAndTypesList merging_columns;
if (table_structure->primary_key_columns)
merging_columns = *table_structure->primary_key_columns;
else
merging_columns = *table_structure->replica_identity_columns;
order_by_expression->name = "tuple";
order_by_expression->arguments = std::make_shared<ASTExpressionList>();
for (const auto & column : merging_columns)
order_by_expression->arguments->children.emplace_back(std::make_shared<ASTIdentifier>(column.name));
}
for (const auto & [name, type] : ordinary_columns_and_types)
{
const auto & column_declaration = std::make_shared<ASTColumnDeclaration>();
column_declaration->name = name;
column_declaration->type = getColumnDeclaration(type);
columns_expression_list->children.emplace_back(column_declaration);
}
columns_declare_list->set(columns_declare_list->columns, columns_expression_list);
columns_declare_list->columns->children.emplace_back(getMaterializedColumnsDeclaration("_sign", "Int8", 1));
columns_declare_list->columns->children.emplace_back(getMaterializedColumnsDeclaration("_version", "UInt64", 1));
create_table_query->set(create_table_query->columns_list, columns_declare_list);
/// Not nullptr for single storage (because throws exception if not specified), nullptr otherwise.
auto primary_key_ast = getInMemoryMetadataPtr()->getPrimaryKeyAST();
auto storage = std::make_shared<ASTStorage>();
storage->set(storage->engine, makeASTFunction("ReplacingMergeTree", std::make_shared<ASTIdentifier>("_version")));
if (primary_key_ast)
storage->set(storage->order_by, primary_key_ast);
else
storage->set(storage->order_by, order_by_expression);
create_table_query->set(create_table_query->storage, storage);
/// Add columns _sign and _version, so that they can be accessed from nested ReplacingMergeTree table if needed.
ordinary_columns_and_types.push_back({"_sign", std::make_shared<DataTypeInt8>()});
ordinary_columns_and_types.push_back({"_version", std::make_shared<DataTypeUInt64>()});
StorageInMemoryMetadata storage_metadata;
storage_metadata.setColumns(ColumnsDescription(ordinary_columns_and_types));
storage_metadata.setConstraints(metadata_snapshot->getConstraints());
setInMemoryMetadata(storage_metadata);
return create_table_query;
}
void registerStorageMaterializedPostgreSQL(StorageFactory & factory)
{
auto creator_fn = [](const StorageFactory::Arguments & args)
{
ASTs & engine_args = args.engine_args;
bool has_settings = args.storage_def->settings;
auto postgresql_replication_settings = std::make_unique<MaterializedPostgreSQLSettings>();
if (has_settings)
postgresql_replication_settings->loadFromQuery(*args.storage_def);
if (engine_args.size() != 5)
throw Exception("Storage MaterializedPostgreSQL requires 5 parameters: "
"PostgreSQL('host:port', 'database', 'table', 'username', 'password'",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
for (auto & engine_arg : engine_args)
engine_arg = evaluateConstantExpressionOrIdentifierAsLiteral(engine_arg, args.getContext());
StorageInMemoryMetadata metadata;
metadata.setColumns(args.columns);
metadata.setConstraints(args.constraints);
if (!args.storage_def->order_by && args.storage_def->primary_key)
args.storage_def->set(args.storage_def->order_by, args.storage_def->primary_key->clone());
if (!args.storage_def->order_by)
throw Exception("Storage MaterializedPostgreSQL needs order by key or primary key", ErrorCodes::BAD_ARGUMENTS);
if (args.storage_def->primary_key)
metadata.primary_key = KeyDescription::getKeyFromAST(args.storage_def->primary_key->ptr(), metadata.columns, args.getContext());
else
metadata.primary_key = KeyDescription::getKeyFromAST(args.storage_def->order_by->ptr(), metadata.columns, args.getContext());
auto parsed_host_port = parseAddress(engine_args[0]->as<ASTLiteral &>().value.safeGet<String>(), 5432);
const String & remote_table = engine_args[2]->as<ASTLiteral &>().value.safeGet<String>();
const String & remote_database = engine_args[1]->as<ASTLiteral &>().value.safeGet<String>();
/// No connection is made here, see Storages/PostgreSQL/PostgreSQLConnection.cpp
auto connection_info = postgres::formatConnectionString(
remote_database,
parsed_host_port.first,
parsed_host_port.second,
engine_args[3]->as<ASTLiteral &>().value.safeGet<String>(),
engine_args[4]->as<ASTLiteral &>().value.safeGet<String>());
return StorageMaterializedPostgreSQL::create(
args.table_id, args.attach, remote_database, remote_table, connection_info,
metadata, args.getContext(),
std::move(postgresql_replication_settings));
};
factory.registerStorage(
"MaterializedPostgreSQL",
creator_fn,
StorageFactory::StorageFeatures{
.supports_settings = true,
.supports_sort_order = true,
.source_access_type = AccessType::POSTGRES,
});
}
}
#endif

View File

@ -0,0 +1,180 @@
#pragma once
#if !defined(ARCADIA_BUILD)
#include "config_core.h"
#endif
#if USE_LIBPQXX
#include "PostgreSQLReplicationHandler.h"
#include "MaterializedPostgreSQLSettings.h"
#include <Parsers/IAST.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTCreateQuery.h>
#include <Parsers/ASTColumnDeclaration.h>
#include <Interpreters/evaluateConstantExpression.h>
#include <Interpreters/InterpreterCreateQuery.h>
#include <Interpreters/ExpressionAnalyzer.h>
#include <common/shared_ptr_helper.h>
#include <memory>
namespace DB
{
/** Case of single MaterializedPostgreSQL table engine.
*
* A user creates a table with engine MaterializedPostgreSQL. Order by expression must be specified (needed for
* nested ReplacingMergeTree table). This storage owns its own replication handler, which loads table data
* from PostgreSQL into nested ReplacingMergeTree table. If table is not created, but attached, replication handler
* will not start loading-from-snapshot procedure, instead it will continue from last committed lsn.
*
* Main point: Both tables exist on disk; database engine interacts only with the main table and main table takes
* total ownershot over nested table. Nested table has name `main_table_uuid` + NESTED_SUFFIX.
*
**/
/** Case of MaterializedPostgreSQL database engine.
*
* MaterializedPostgreSQL table exists only in memory and acts as a wrapper for nested table, i.e. only provides an
* interface to work with nested table. Both tables share the same StorageID.
*
* Main table is never created or dropped via database method. The only way database engine interacts with
* MaterializedPostgreSQL table - in tryGetTable() method, a MaterializedPostgreSQL table is returned in order to wrap
* and redirect read requests. Set of such wrapper-tables is cached inside database engine. All other methods in
* regard to materializePostgreSQL table are handled by replication handler.
*
* All database methods, apart from tryGetTable(), are devoted only to nested table.
* NOTE: It makes sense to allow rename method for MaterializedPostgreSQL table via database method.
* TODO: Make sure replication-to-table data channel is done only by relation_id.
*
* Also main table has the same InMemoryMetadata as its nested table, so if metadata of nested table changes - main table also has
* to update its metadata, because all read requests are passed to MaterializedPostgreSQL table and then it redirects read
* into nested table.
*
* When there is a need to update table structure, there will be created a new MaterializedPostgreSQL table with its own nested table,
* it will have updated table schema and all data will be loaded from scratch in the background, while previous table with outadted table
* structure will still serve read requests. When data is loaded, nested tables will be swapped, metadata of metarialzied table will be
* updated according to nested table.
*
**/
class StorageMaterializedPostgreSQL final : public shared_ptr_helper<StorageMaterializedPostgreSQL>, public IStorage, WithContext
{
friend struct shared_ptr_helper<StorageMaterializedPostgreSQL>;
public:
StorageMaterializedPostgreSQL(const StorageID & table_id_, ContextPtr context_);
StorageMaterializedPostgreSQL(StoragePtr nested_storage_, ContextPtr context_);
String getName() const override { return "MaterializedPostgreSQL"; }
void startup() override;
void shutdown() override;
/// Used only for single MaterializedPostgreSQL storage.
void dropInnerTableIfAny(bool no_delay, ContextPtr local_context) override;
NamesAndTypesList getVirtuals() const override;
Pipe read(
const Names & column_names,
const StorageMetadataPtr & metadata_snapshot,
SelectQueryInfo & query_info,
ContextPtr context_,
QueryProcessingStage::Enum processed_stage,
size_t max_block_size,
unsigned num_streams) override;
/// This method is called only from MateriaizePostgreSQL database engine, because it needs to maintain
/// an invariant: a table exists only if its nested table exists. This atomic variable is set to _true_
/// only once - when nested table is successfully created and is never changed afterwards.
bool hasNested() { return has_nested.load(); }
void createNestedIfNeeded(PostgreSQLTableStructurePtr table_structure);
StoragePtr getNested() const;
StoragePtr tryGetNested() const;
/// Create a temporary MaterializedPostgreSQL table with current_table_name + TMP_SUFFIX.
/// An empty wrapper is returned - it does not have inMemory metadata, just acts as an empty wrapper over
/// temporary nested, which will be created shortly after.
StoragePtr createTemporary() const;
ContextPtr getNestedTableContext() const { return nested_context; }
StorageID getNestedStorageID() const;
void setNestedStorageID(const StorageID & id) { nested_table_id.emplace(id); }
static std::shared_ptr<Context> makeNestedTableContext(ContextPtr from_context);
/// Get nested table (or throw if it does not exist), set in-memory metadata (taken from nested table)
/// for current table, set has_nested = true.
StoragePtr prepare();
protected:
StorageMaterializedPostgreSQL(
const StorageID & table_id_,
bool is_attach_,
const String & remote_database_name,
const String & remote_table_name,
const postgres::ConnectionInfo & connection_info,
const StorageInMemoryMetadata & storage_metadata,
ContextPtr context_,
std::unique_ptr<MaterializedPostgreSQLSettings> replication_settings);
private:
static std::shared_ptr<ASTColumnDeclaration> getMaterializedColumnsDeclaration(
const String name, const String type, UInt64 default_value);
ASTPtr getColumnDeclaration(const DataTypePtr & data_type) const;
ASTPtr getCreateNestedTableQuery(PostgreSQLTableStructurePtr table_structure);
String getNestedTableName() const;
/// Not nullptr only for single MaterializedPostgreSQL storage, because for MaterializedPostgreSQL
/// database engine there is one replication handler for all tables.
std::unique_ptr<PostgreSQLReplicationHandler> replication_handler;
/// Distinguish between single MaterilizePostgreSQL table engine and MaterializedPostgreSQL database engine,
/// because table with engine MaterilizePostgreSQL acts differently in each case.
bool is_materialized_postgresql_database = false;
/// Will be set to `true` only once - when nested table was loaded by replication thread.
/// After that, it will never be changed. Needed for MaterializedPostgreSQL database engine
/// because there is an invariant - table exists only if its nested table exists, but nested
/// table is not loaded immediately. It is made atomic, because it is accessed only by database engine,
/// and updated by replication handler (only once).
std::atomic<bool> has_nested = false;
/// Nested table context is a copy of global context, but modified to answer isInternalQuery() == true.
/// This is needed to let database engine know whether to access nested table or a wrapper over nested (materialized table).
ContextMutablePtr nested_context;
/// Save nested storageID to be able to fetch it. It is set once nested is created and will be
/// updated only when nested is reloaded or renamed.
std::optional<StorageID> nested_table_id;
/// Needed only for the case of single MaterializedPostgreSQL storage - in order to make
/// delayed storage forwarding into replication handler.
String remote_table_name;
/// Needed only for the case of single MaterializedPostgreSQL storage, because in case of create
/// query (not attach) initial setup will be done immediately and error message is thrown at once.
/// It results in the fact: single MaterializedPostgreSQL storage is created only if its nested table is created.
/// In case of attach - this setup will be done in a separate thread in the background. It will also
/// be checked for nested table and attempted to load it if it does not exist for some reason.
bool is_attach = true;
};
}
#endif

View File

@ -0,0 +1,86 @@
#include <Storages/ReadFinalForExternalReplicaStorage.h>
#if USE_MYSQL || USE_LIBPQXX
#include <Interpreters/ExpressionAnalyzer.h>
#include <Interpreters/TreeRewriter.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTIdentifier.h>
#include <Processors/Transforms/FilterTransform.h>
#include <Interpreters/Context.h>
namespace DB
{
Pipe readFinalFromNestedStorage(
StoragePtr nested_storage,
const Names & column_names,
const StorageMetadataPtr & /*metadata_snapshot*/,
SelectQueryInfo & query_info,
ContextPtr context,
QueryProcessingStage::Enum processed_stage,
size_t max_block_size,
unsigned int num_streams)
{
NameSet column_names_set = NameSet(column_names.begin(), column_names.end());
auto lock = nested_storage->lockForShare(context->getCurrentQueryId(), context->getSettingsRef().lock_acquire_timeout);
const StorageMetadataPtr & nested_metadata = nested_storage->getInMemoryMetadataPtr();
Block nested_header = nested_metadata->getSampleBlock();
ColumnWithTypeAndName & sign_column = nested_header.getByPosition(nested_header.columns() - 2);
ColumnWithTypeAndName & version_column = nested_header.getByPosition(nested_header.columns() - 1);
if (ASTSelectQuery * select_query = query_info.query->as<ASTSelectQuery>(); select_query && !column_names_set.count(version_column.name))
{
auto & tables_in_select_query = select_query->tables()->as<ASTTablesInSelectQuery &>();
if (!tables_in_select_query.children.empty())
{
auto & tables_element = tables_in_select_query.children[0]->as<ASTTablesInSelectQueryElement &>();
if (tables_element.table_expression)
tables_element.table_expression->as<ASTTableExpression &>().final = true;
}
}
String filter_column_name;
Names require_columns_name = column_names;
ASTPtr expressions = std::make_shared<ASTExpressionList>();
if (column_names_set.empty() || !column_names_set.count(sign_column.name))
{
require_columns_name.emplace_back(sign_column.name);
const auto & sign_column_name = std::make_shared<ASTIdentifier>(sign_column.name);
const auto & fetch_sign_value = std::make_shared<ASTLiteral>(Field(Int8(1)));
expressions->children.emplace_back(makeASTFunction("equals", sign_column_name, fetch_sign_value));
filter_column_name = expressions->children.back()->getColumnName();
for (const auto & column_name : column_names)
expressions->children.emplace_back(std::make_shared<ASTIdentifier>(column_name));
}
Pipe pipe = nested_storage->read(require_columns_name, nested_metadata, query_info, context, processed_stage, max_block_size, num_streams);
pipe.addTableLock(lock);
if (!expressions->children.empty() && !pipe.empty())
{
Block pipe_header = pipe.getHeader();
auto syntax = TreeRewriter(context).analyze(expressions, pipe_header.getNamesAndTypesList());
ExpressionActionsPtr expression_actions = ExpressionAnalyzer(expressions, syntax, context).getActions(true /* add_aliases */, false /* project_result */);
pipe.addSimpleTransform([&](const Block & header)
{
return std::make_shared<FilterTransform>(header, expression_actions, filter_column_name, false);
});
}
return pipe;
}
}
#endif

View File

@ -0,0 +1,28 @@
#pragma once
#if !defined(ARCADIA_BUILD)
# include "config_core.h"
#endif
#if USE_MYSQL || USE_LIBPQXX
#include <Storages/StorageProxy.h>
#include <Processors/Pipe.h>
namespace DB
{
Pipe readFinalFromNestedStorage(
StoragePtr nested_storage,
const Names & column_names,
const StorageMetadataPtr & /*metadata_snapshot*/,
SelectQueryInfo & query_info,
ContextPtr context,
QueryProcessingStage::Enum processed_stage,
size_t max_block_size,
unsigned int num_streams);
}
#endif

View File

@ -22,6 +22,7 @@
#include <Processors/Transforms/FilterTransform.h> #include <Processors/Transforms/FilterTransform.h>
#include <Databases/MySQL/DatabaseMaterializeMySQL.h> #include <Databases/MySQL/DatabaseMaterializeMySQL.h>
#include <Storages/ReadFinalForExternalReplicaStorage.h>
#include <Storages/SelectQueryInfo.h> #include <Storages/SelectQueryInfo.h>
namespace DB namespace DB
@ -37,7 +38,7 @@ StorageMaterializeMySQL::StorageMaterializeMySQL(const StoragePtr & nested_stora
Pipe StorageMaterializeMySQL::read( Pipe StorageMaterializeMySQL::read(
const Names & column_names, const Names & column_names,
const StorageMetadataPtr & /*metadata_snapshot*/, const StorageMetadataPtr & metadata_snapshot,
SelectQueryInfo & query_info, SelectQueryInfo & query_info,
ContextPtr context, ContextPtr context,
QueryProcessingStage::Enum processed_stage, QueryProcessingStage::Enum processed_stage,
@ -46,61 +47,8 @@ Pipe StorageMaterializeMySQL::read(
{ {
/// If the background synchronization thread has exception. /// If the background synchronization thread has exception.
rethrowSyncExceptionIfNeed(database); rethrowSyncExceptionIfNeed(database);
return readFinalFromNestedStorage(nested_storage, column_names, metadata_snapshot,
NameSet column_names_set = NameSet(column_names.begin(), column_names.end()); query_info, context, processed_stage, max_block_size, num_streams);
auto lock = nested_storage->lockForShare(context->getCurrentQueryId(), context->getSettingsRef().lock_acquire_timeout);
const StorageMetadataPtr & nested_metadata = nested_storage->getInMemoryMetadataPtr();
Block nested_header = nested_metadata->getSampleBlock();
ColumnWithTypeAndName & sign_column = nested_header.getByPosition(nested_header.columns() - 2);
ColumnWithTypeAndName & version_column = nested_header.getByPosition(nested_header.columns() - 1);
if (ASTSelectQuery * select_query = query_info.query->as<ASTSelectQuery>(); select_query && !column_names_set.count(version_column.name))
{
auto & tables_in_select_query = select_query->tables()->as<ASTTablesInSelectQuery &>();
if (!tables_in_select_query.children.empty())
{
auto & tables_element = tables_in_select_query.children[0]->as<ASTTablesInSelectQueryElement &>();
if (tables_element.table_expression)
tables_element.table_expression->as<ASTTableExpression &>().final = true;
}
}
String filter_column_name;
Names require_columns_name = column_names;
ASTPtr expressions = std::make_shared<ASTExpressionList>();
if (column_names_set.empty() || !column_names_set.count(sign_column.name))
{
require_columns_name.emplace_back(sign_column.name);
const auto & sign_column_name = std::make_shared<ASTIdentifier>(sign_column.name);
const auto & fetch_sign_value = std::make_shared<ASTLiteral>(Field(Int8(1)));
expressions->children.emplace_back(makeASTFunction("equals", sign_column_name, fetch_sign_value));
filter_column_name = expressions->children.back()->getColumnName();
for (const auto & column_name : column_names)
expressions->children.emplace_back(std::make_shared<ASTIdentifier>(column_name));
}
Pipe pipe = nested_storage->read(require_columns_name, nested_metadata, query_info, context, processed_stage, max_block_size, num_streams);
pipe.addTableLock(lock);
if (!expressions->children.empty() && !pipe.empty())
{
Block pipe_header = pipe.getHeader();
auto syntax = TreeRewriter(context).analyze(expressions, pipe_header.getNamesAndTypesList());
ExpressionActionsPtr expression_actions = ExpressionAnalyzer(expressions, syntax, context).getActions(true /* add_aliases */, false /* project_result */);
pipe.addSimpleTransform([&](const Block & header)
{
return std::make_shared<FilterTransform>(header, expression_actions, filter_column_name, false);
});
}
return pipe;
} }
NamesAndTypesList StorageMaterializeMySQL::getVirtuals() const NamesAndTypesList StorageMaterializeMySQL::getVirtuals() const

View File

@ -3,7 +3,6 @@
#include <Parsers/ASTSelectQuery.h> #include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTSelectWithUnionQuery.h> #include <Parsers/ASTSelectWithUnionQuery.h>
#include <Parsers/ASTCreateQuery.h> #include <Parsers/ASTCreateQuery.h>
#include <Parsers/ASTDropQuery.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Interpreters/InterpreterCreateQuery.h> #include <Interpreters/InterpreterCreateQuery.h>
@ -229,36 +228,6 @@ BlockOutputStreamPtr StorageMaterializedView::write(const ASTPtr & query, const
} }
static void executeDropQuery(ASTDropQuery::Kind kind, ContextPtr global_context, ContextPtr current_context, const StorageID & target_table_id, bool no_delay)
{
if (DatabaseCatalog::instance().tryGetTable(target_table_id, current_context))
{
/// We create and execute `drop` query for internal table.
auto drop_query = std::make_shared<ASTDropQuery>();
drop_query->database = target_table_id.database_name;
drop_query->table = target_table_id.table_name;
drop_query->kind = kind;
drop_query->no_delay = no_delay;
drop_query->if_exists = true;
ASTPtr ast_drop_query = drop_query;
/// FIXME We have to use global context to execute DROP query for inner table
/// to avoid "Not enough privileges" error if current user has only DROP VIEW ON mat_view_name privilege
/// and not allowed to drop inner table explicitly. Allowing to drop inner table without explicit grant
/// looks like expected behaviour and we have tests for it.
auto drop_context = Context::createCopy(global_context);
drop_context->getClientInfo().query_kind = ClientInfo::QueryKind::SECONDARY_QUERY;
if (auto txn = current_context->getZooKeeperMetadataTransaction())
{
/// For Replicated database
drop_context->setQueryContext(std::const_pointer_cast<Context>(current_context));
drop_context->initZooKeeperMetadataTransaction(txn, true);
}
InterpreterDropQuery drop_interpreter(ast_drop_query, drop_context);
drop_interpreter.execute();
}
}
void StorageMaterializedView::drop() void StorageMaterializedView::drop()
{ {
auto table_id = getStorageID(); auto table_id = getStorageID();
@ -266,19 +235,19 @@ void StorageMaterializedView::drop()
if (!select_query.select_table_id.empty()) if (!select_query.select_table_id.empty())
DatabaseCatalog::instance().removeDependency(select_query.select_table_id, table_id); DatabaseCatalog::instance().removeDependency(select_query.select_table_id, table_id);
dropInnerTable(true, getContext()); dropInnerTableIfAny(true, getContext());
} }
void StorageMaterializedView::dropInnerTable(bool no_delay, ContextPtr local_context) void StorageMaterializedView::dropInnerTableIfAny(bool no_delay, ContextPtr local_context)
{ {
if (has_inner_table && tryGetTargetTable()) if (has_inner_table && tryGetTargetTable())
executeDropQuery(ASTDropQuery::Kind::Drop, getContext(), local_context, target_table_id, no_delay); InterpreterDropQuery::executeDropQuery(ASTDropQuery::Kind::Drop, getContext(), local_context, target_table_id, no_delay);
} }
void StorageMaterializedView::truncate(const ASTPtr &, const StorageMetadataPtr &, ContextPtr local_context, TableExclusiveLockHolder &) void StorageMaterializedView::truncate(const ASTPtr &, const StorageMetadataPtr &, ContextPtr local_context, TableExclusiveLockHolder &)
{ {
if (has_inner_table) if (has_inner_table)
executeDropQuery(ASTDropQuery::Kind::Truncate, getContext(), local_context, target_table_id, true); InterpreterDropQuery::executeDropQuery(ASTDropQuery::Kind::Truncate, getContext(), local_context, target_table_id, true);
} }
void StorageMaterializedView::checkStatementCanBeForwarded() const void StorageMaterializedView::checkStatementCanBeForwarded() const

View File

@ -37,7 +37,7 @@ public:
BlockOutputStreamPtr write(const ASTPtr & query, const StorageMetadataPtr & /*metadata_snapshot*/, ContextPtr context) override; BlockOutputStreamPtr write(const ASTPtr & query, const StorageMetadataPtr & /*metadata_snapshot*/, ContextPtr context) override;
void drop() override; void drop() override;
void dropInnerTable(bool no_delay, ContextPtr context); void dropInnerTableIfAny(bool no_delay, ContextPtr local_context) override;
void truncate(const ASTPtr &, const StorageMetadataPtr &, ContextPtr, TableExclusiveLockHolder &) override; void truncate(const ASTPtr &, const StorageMetadataPtr &, ContextPtr, TableExclusiveLockHolder &) override;

View File

@ -47,6 +47,7 @@ namespace ErrorCodes
extern const int TIMEOUT_EXCEEDED; extern const int TIMEOUT_EXCEEDED;
extern const int UNKNOWN_POLICY; extern const int UNKNOWN_POLICY;
extern const int NO_SUCH_DATA_PART; extern const int NO_SUCH_DATA_PART;
extern const int ABORTED;
} }
namespace ActionLocks namespace ActionLocks
@ -676,9 +677,16 @@ void StorageMergeTree::loadMutations()
} }
std::shared_ptr<StorageMergeTree::MergeMutateSelectedEntry> StorageMergeTree::selectPartsToMerge( std::shared_ptr<StorageMergeTree::MergeMutateSelectedEntry> StorageMergeTree::selectPartsToMerge(
const StorageMetadataPtr & metadata_snapshot, bool aggressive, const String & partition_id, bool final, String * out_disable_reason, TableLockHolder & /* table_lock_holder */, bool optimize_skip_merged_partitions, SelectPartsDecision * select_decision_out) const StorageMetadataPtr & metadata_snapshot,
bool aggressive,
const String & partition_id,
bool final,
String * out_disable_reason,
TableLockHolder & /* table_lock_holder */,
std::unique_lock<std::mutex> & lock,
bool optimize_skip_merged_partitions,
SelectPartsDecision * select_decision_out)
{ {
std::unique_lock lock(currently_processing_in_background_mutex);
auto data_settings = getSettings(); auto data_settings = getSettings();
FutureMergedMutatedPart future_part; FutureMergedMutatedPart future_part;
@ -795,7 +803,24 @@ bool StorageMergeTree::merge(
SelectPartsDecision select_decision; SelectPartsDecision select_decision;
auto merge_mutate_entry = selectPartsToMerge(metadata_snapshot, aggressive, partition_id, final, out_disable_reason, table_lock_holder, optimize_skip_merged_partitions, &select_decision); std::shared_ptr<MergeMutateSelectedEntry> merge_mutate_entry;
{
std::unique_lock lock(currently_processing_in_background_mutex);
if (merger_mutator.merges_blocker.isCancelled())
throw Exception("Cancelled merging parts", ErrorCodes::ABORTED);
merge_mutate_entry = selectPartsToMerge(
metadata_snapshot,
aggressive,
partition_id,
final,
out_disable_reason,
table_lock_holder,
lock,
optimize_skip_merged_partitions,
&select_decision);
}
/// If there is nothing to merge then we treat this merge as successful (needed for optimize final optimization) /// If there is nothing to merge then we treat this merge as successful (needed for optimize final optimization)
if (select_decision == SelectPartsDecision::NOTHING_TO_MERGE) if (select_decision == SelectPartsDecision::NOTHING_TO_MERGE)
@ -867,7 +892,6 @@ bool StorageMergeTree::partIsAssignedToBackgroundOperation(const DataPartPtr & p
std::shared_ptr<StorageMergeTree::MergeMutateSelectedEntry> StorageMergeTree::selectPartsToMutate( std::shared_ptr<StorageMergeTree::MergeMutateSelectedEntry> StorageMergeTree::selectPartsToMutate(
const StorageMetadataPtr & metadata_snapshot, String * /* disable_reason */, TableLockHolder & /* table_lock_holder */) const StorageMetadataPtr & metadata_snapshot, String * /* disable_reason */, TableLockHolder & /* table_lock_holder */)
{ {
std::lock_guard lock(currently_processing_in_background_mutex);
size_t max_ast_elements = getContext()->getSettingsRef().max_expanded_ast_elements; size_t max_ast_elements = getContext()->getSettingsRef().max_expanded_ast_elements;
FutureMergedMutatedPart future_part; FutureMergedMutatedPart future_part;
@ -1006,16 +1030,20 @@ bool StorageMergeTree::scheduleDataProcessingJob(IBackgroundJobExecutor & execut
if (shutdown_called) if (shutdown_called)
return false; return false;
if (merger_mutator.merges_blocker.isCancelled())
return false;
auto metadata_snapshot = getInMemoryMetadataPtr(); auto metadata_snapshot = getInMemoryMetadataPtr();
std::shared_ptr<MergeMutateSelectedEntry> merge_entry, mutate_entry; std::shared_ptr<MergeMutateSelectedEntry> merge_entry, mutate_entry;
auto share_lock = lockForShare(RWLockImpl::NO_QUERY, getSettings()->lock_acquire_timeout_for_background_operations); auto share_lock = lockForShare(RWLockImpl::NO_QUERY, getSettings()->lock_acquire_timeout_for_background_operations);
merge_entry = selectPartsToMerge(metadata_snapshot, false, {}, false, nullptr, share_lock);
if (!merge_entry) {
mutate_entry = selectPartsToMutate(metadata_snapshot, nullptr, share_lock); std::unique_lock lock(currently_processing_in_background_mutex);
if (merger_mutator.merges_blocker.isCancelled())
return false;
merge_entry = selectPartsToMerge(metadata_snapshot, false, {}, false, nullptr, share_lock, lock);
if (!merge_entry)
mutate_entry = selectPartsToMutate(metadata_snapshot, nullptr, share_lock);
}
if (merge_entry) if (merge_entry)
{ {
@ -1033,7 +1061,7 @@ bool StorageMergeTree::scheduleDataProcessingJob(IBackgroundJobExecutor & execut
}, PoolType::MERGE_MUTATE}); }, PoolType::MERGE_MUTATE});
return true; return true;
} }
else if (auto lock = time_after_previous_cleanup.compareAndRestartDeferred(1)) else if (auto cmp_lock = time_after_previous_cleanup.compareAndRestartDeferred(1))
{ {
executor.execute({[this, share_lock] () executor.execute({[this, share_lock] ()
{ {
@ -1186,22 +1214,21 @@ bool StorageMergeTree::optimize(
ActionLock StorageMergeTree::stopMergesAndWait() ActionLock StorageMergeTree::stopMergesAndWait()
{ {
std::unique_lock lock(currently_processing_in_background_mutex);
/// Asks to complete merges and does not allow them to start. /// Asks to complete merges and does not allow them to start.
/// This protects against "revival" of data for a removed partition after completion of merge. /// This protects against "revival" of data for a removed partition after completion of merge.
auto merge_blocker = merger_mutator.merges_blocker.cancel(); auto merge_blocker = merger_mutator.merges_blocker.cancel();
while (!currently_merging_mutating_parts.empty())
{ {
std::unique_lock lock(currently_processing_in_background_mutex); LOG_DEBUG(log, "Waiting for currently running merges ({} parts are merging right now)",
while (!currently_merging_mutating_parts.empty()) currently_merging_mutating_parts.size());
{
LOG_DEBUG(log, "Waiting for currently running merges ({} parts are merging right now)",
currently_merging_mutating_parts.size());
if (std::cv_status::timeout == currently_processing_in_background_condition.wait_for( if (std::cv_status::timeout == currently_processing_in_background_condition.wait_for(
lock, std::chrono::seconds(DBMS_DEFAULT_LOCK_ACQUIRE_TIMEOUT_SEC))) lock, std::chrono::seconds(DBMS_DEFAULT_LOCK_ACQUIRE_TIMEOUT_SEC)))
{ {
throw Exception("Timeout while waiting for already running merges", ErrorCodes::TIMEOUT_EXCEEDED); throw Exception("Timeout while waiting for already running merges", ErrorCodes::TIMEOUT_EXCEEDED);
}
} }
} }

View File

@ -196,6 +196,7 @@ private:
bool final, bool final,
String * disable_reason, String * disable_reason,
TableLockHolder & table_lock_holder, TableLockHolder & table_lock_holder,
std::unique_lock<std::mutex> & lock,
bool optimize_skip_merged_partitions = false, bool optimize_skip_merged_partitions = false,
SelectPartsDecision * select_decision_out = nullptr); SelectPartsDecision * select_decision_out = nullptr);

View File

@ -1,6 +1,7 @@
#include "StoragePostgreSQL.h" #include "StoragePostgreSQL.h"
#if USE_LIBPQXX #if USE_LIBPQXX
#include <DataStreams/PostgreSQLBlockInputStream.h>
#include <Storages/StorageFactory.h> #include <Storages/StorageFactory.h>
#include <Storages/transformQueryForExternalDatabase.h> #include <Storages/transformQueryForExternalDatabase.h>
@ -16,7 +17,6 @@
#include <Columns/ColumnArray.h> #include <Columns/ColumnArray.h>
#include <Columns/ColumnsNumber.h> #include <Columns/ColumnsNumber.h>
#include <Columns/ColumnDecimal.h> #include <Columns/ColumnDecimal.h>
#include <DataStreams/PostgreSQLBlockInputStream.h>
#include <Core/Settings.h> #include <Core/Settings.h>
#include <Common/parseAddress.h> #include <Common/parseAddress.h>
#include <Common/assert_cast.h> #include <Common/assert_cast.h>
@ -90,7 +90,7 @@ Pipe StoragePostgreSQL::read(
} }
return Pipe(std::make_shared<SourceFromInputStream>( return Pipe(std::make_shared<SourceFromInputStream>(
std::make_shared<PostgreSQLBlockInputStream>(pool->get(), query, sample_block, max_block_size_))); std::make_shared<PostgreSQLBlockInputStream<>>(pool->get(), query, sample_block, max_block_size_)));
} }

View File

@ -9,7 +9,7 @@
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Storages/IStorage.h> #include <Storages/IStorage.h>
#include <DataStreams/IBlockOutputStream.h> #include <DataStreams/IBlockOutputStream.h>
#include <Storages/PostgreSQL/PoolWithFailover.h> #include <Core/PostgreSQL/PoolWithFailover.h>
namespace DB namespace DB

View File

@ -17,6 +17,7 @@
#include <Storages/StorageReplicatedMergeTree.h> #include <Storages/StorageReplicatedMergeTree.h>
#include <Storages/MergeTree/IMergeTreeDataPart.h> #include <Storages/MergeTree/IMergeTreeDataPart.h>
#include <Storages/MergeTree/MergeList.h> #include <Storages/MergeTree/MergeList.h>
#include <Storages/MergeTree/MergedBlockOutputStream.h>
#include <Storages/MergeTree/PinnedPartUUIDs.h> #include <Storages/MergeTree/PinnedPartUUIDs.h>
#include <Storages/MergeTree/PartitionPruner.h> #include <Storages/MergeTree/PartitionPruner.h>
#include <Storages/MergeTree/ReplicatedMergeTreeTableMetadata.h> #include <Storages/MergeTree/ReplicatedMergeTreeTableMetadata.h>
@ -1215,29 +1216,37 @@ void StorageReplicatedMergeTree::checkParts(bool skip_sanity_checks)
for (size_t i = 0; i < parts_to_fetch.size(); ++i) for (size_t i = 0; i < parts_to_fetch.size(); ++i)
{ {
const String & part_name = parts_to_fetch[i]; const String & part_name = parts_to_fetch[i];
LOG_ERROR(log, "Removing locally missing part from ZooKeeper and queueing a fetch: {}", part_name);
Coordination::Requests ops; Coordination::Requests ops;
time_t part_create_time = 0; String has_replica = findReplicaHavingPart(part_name, true);
Coordination::ExistsResponse exists_resp = exists_futures[i].get(); if (!has_replica.empty())
if (exists_resp.error == Coordination::Error::ZOK)
{ {
part_create_time = exists_resp.stat.ctime / 1000; LOG_ERROR(log, "Removing locally missing part from ZooKeeper and queueing a fetch: {}", part_name);
removePartFromZooKeeper(part_name, ops, exists_resp.stat.numChildren > 0); time_t part_create_time = 0;
Coordination::ExistsResponse exists_resp = exists_futures[i].get();
if (exists_resp.error == Coordination::Error::ZOK)
{
part_create_time = exists_resp.stat.ctime / 1000;
removePartFromZooKeeper(part_name, ops, exists_resp.stat.numChildren > 0);
}
LogEntry log_entry;
log_entry.type = LogEntry::GET_PART;
log_entry.source_replica = "";
log_entry.new_part_name = part_name;
log_entry.create_time = part_create_time;
/// We assume that this occurs before the queue is loaded (queue.initialize).
ops.emplace_back(zkutil::makeCreateRequest(
fs::path(replica_path) / "queue/queue-", log_entry.toString(), zkutil::CreateMode::PersistentSequential));
enqueue_futures.emplace_back(zookeeper->asyncMulti(ops));
}
else
{
LOG_ERROR(log, "Not found active replica having part {}", part_name);
enqueuePartForCheck(part_name);
} }
LogEntry log_entry;
log_entry.type = LogEntry::GET_PART;
log_entry.source_replica = "";
log_entry.new_part_name = part_name;
log_entry.create_time = part_create_time;
/// We assume that this occurs before the queue is loaded (queue.initialize).
ops.emplace_back(zkutil::makeCreateRequest(
fs::path(replica_path) / "queue/queue-", log_entry.toString(), zkutil::CreateMode::PersistentSequential));
enqueue_futures.emplace_back(zookeeper->asyncMulti(ops));
} }
for (auto & future : enqueue_futures) for (auto & future : enqueue_futures)
@ -1272,7 +1281,6 @@ void StorageReplicatedMergeTree::syncPinnedPartUUIDs()
} }
} }
void StorageReplicatedMergeTree::checkPartChecksumsAndAddCommitOps(const zkutil::ZooKeeperPtr & zookeeper, void StorageReplicatedMergeTree::checkPartChecksumsAndAddCommitOps(const zkutil::ZooKeeperPtr & zookeeper,
const DataPartPtr & part, Coordination::Requests & ops, String part_name, NameSet * absent_replicas_paths) const DataPartPtr & part, Coordination::Requests & ops, String part_name, NameSet * absent_replicas_paths)
{ {
@ -7393,4 +7401,164 @@ bool StorageReplicatedMergeTree::checkIfDetachedPartitionExists(const String & p
} }
return false; return false;
} }
bool StorageReplicatedMergeTree::createEmptyPartInsteadOfLost(zkutil::ZooKeeperPtr zookeeper, const String & lost_part_name)
{
LOG_INFO(log, "Going to replace lost part {} with empty part", lost_part_name);
auto metadata_snapshot = getInMemoryMetadataPtr();
auto settings = getSettings();
constexpr static auto TMP_PREFIX = "tmp_empty_";
auto new_part_info = MergeTreePartInfo::fromPartName(lost_part_name, format_version);
auto block = metadata_snapshot->getSampleBlock();
DB::IMergeTreeDataPart::TTLInfos move_ttl_infos;
NamesAndTypesList columns = metadata_snapshot->getColumns().getAllPhysical().filter(block.getNames());
ReservationPtr reservation = reserveSpacePreferringTTLRules(metadata_snapshot, 0, move_ttl_infos, time(nullptr), 0, true);
VolumePtr volume = getStoragePolicy()->getVolume(0);
IMergeTreeDataPart::MinMaxIndex minmax_idx;
minmax_idx.update(block, getMinMaxColumnsNames(metadata_snapshot->getPartitionKey()));
auto new_data_part = createPart(
lost_part_name,
choosePartType(0, block.rows()),
new_part_info,
createVolumeFromReservation(reservation, volume),
TMP_PREFIX + lost_part_name);
if (settings->assign_part_uuids)
new_data_part->uuid = UUIDHelpers::generateV4();
new_data_part->setColumns(columns);
new_data_part->rows_count = block.rows();
{
auto lock = lockParts();
auto parts_in_partition = getDataPartsPartitionRange(new_part_info.partition_id);
if (parts_in_partition.empty())
{
LOG_WARNING(log, "Empty part {} is not created instead of lost part because there are no parts in partition {} (it's empty), resolve this manually using DROP PARTITION.", lost_part_name, new_part_info.partition_id);
return false;
}
new_data_part->partition = (*parts_in_partition.begin())->partition;
}
new_data_part->minmax_idx = std::move(minmax_idx);
new_data_part->is_temp = true;
SyncGuardPtr sync_guard;
if (new_data_part->isStoredOnDisk())
{
/// The name could be non-unique in case of stale files from previous runs.
String full_path = new_data_part->getFullRelativePath();
if (new_data_part->volume->getDisk()->exists(full_path))
{
LOG_WARNING(log, "Removing old temporary directory {}", fullPath(new_data_part->volume->getDisk(), full_path));
new_data_part->volume->getDisk()->removeRecursive(full_path);
}
const auto disk = new_data_part->volume->getDisk();
disk->createDirectories(full_path);
if (getSettings()->fsync_part_directory)
sync_guard = disk->getDirectorySyncGuard(full_path);
}
/// This effectively chooses minimal compression method:
/// either default lz4 or compression method with zero thresholds on absolute and relative part size.
auto compression_codec = getContext()->chooseCompressionCodec(0, 0);
const auto & index_factory = MergeTreeIndexFactory::instance();
MergedBlockOutputStream out(new_data_part, metadata_snapshot, columns, index_factory.getMany(metadata_snapshot->getSecondaryIndices()), compression_codec);
bool sync_on_insert = settings->fsync_after_insert;
out.writePrefix();
out.write(block);
out.writeSuffixAndFinalizePart(new_data_part, sync_on_insert);
try
{
MergeTreeData::Transaction transaction(*this);
auto replaced_parts = renameTempPartAndReplace(new_data_part, nullptr, &transaction);
if (!replaced_parts.empty())
{
Strings part_names;
for (const auto & part : replaced_parts)
part_names.emplace_back(part->name);
/// Why this exception is not a LOGICAL_ERROR? Because it's possible
/// to have some source parts for the lost part if replica currently
/// cloning from another replica, but source replica lost covering
/// part and finished MERGE_PARTS before clone. It's an extremely
/// rare case and it's unclear how to resolve it better. Eventually
/// source replica will replace lost part with empty part and we
/// will fetch this empty part instead of our source parts. This
/// will make replicas consistent, but some data will be lost.
throw Exception(ErrorCodes::INCORRECT_DATA, "Tried to create empty part {}, but it replaces existing parts {}.", lost_part_name, fmt::join(part_names, ", "));
}
while (true)
{
Coordination::Requests ops;
Coordination::Stat replicas_stat;
auto replicas_path = fs::path(zookeeper_path) / "replicas";
Strings replicas = zookeeper->getChildren(replicas_path, &replicas_stat);
/// In rare cases new replica can appear during check
ops.emplace_back(zkutil::makeCheckRequest(replicas_path, replicas_stat.version));
for (const String & replica : replicas)
{
String current_part_path = fs::path(zookeeper_path) / "replicas" / replica / "parts" / lost_part_name;
/// We must be sure that this part doesn't exist on other replicas
if (!zookeeper->exists(current_part_path))
{
ops.emplace_back(zkutil::makeCreateRequest(current_part_path, "", zkutil::CreateMode::Persistent));
ops.emplace_back(zkutil::makeRemoveRequest(current_part_path, -1));
}
else
{
throw Exception(ErrorCodes::DUPLICATE_DATA_PART, "Part {} already exists on replica {} on path {}", lost_part_name, replica, current_part_path);
}
}
getCommitPartOps(ops, new_data_part);
Coordination::Responses responses;
if (auto code = zookeeper->tryMulti(ops, responses); code == Coordination::Error::ZOK)
{
transaction.commit();
break;
}
else if (code == Coordination::Error::ZBADVERSION)
{
LOG_INFO(log, "Looks like new replica appearead while creating new empty part, will retry");
}
else
{
zkutil::KeeperMultiException::check(code, ops, responses);
}
}
}
catch (const Exception & ex)
{
LOG_WARNING(log, "Cannot commit empty part {} with error {}", lost_part_name, ex.displayText());
return false;
}
LOG_INFO(log, "Created empty part {} instead of lost part", lost_part_name);
return true;
}
} }

View File

@ -258,6 +258,8 @@ public:
return replicated_sends_throttler; return replicated_sends_throttler;
} }
bool createEmptyPartInsteadOfLost(zkutil::ZooKeeperPtr zookeeper, const String & lost_part_name);
private: private:
std::atomic_bool are_restoring_replica {false}; std::atomic_bool are_restoring_replica {false};

View File

@ -60,6 +60,7 @@ void registerStorageEmbeddedRocksDB(StorageFactory & factory);
#if USE_LIBPQXX #if USE_LIBPQXX
void registerStoragePostgreSQL(StorageFactory & factory); void registerStoragePostgreSQL(StorageFactory & factory);
void registerStorageMaterializedPostgreSQL(StorageFactory & factory);
#endif #endif
#if USE_MYSQL || USE_LIBPQXX #if USE_MYSQL || USE_LIBPQXX
@ -121,6 +122,7 @@ void registerStorages()
#if USE_LIBPQXX #if USE_LIBPQXX
registerStoragePostgreSQL(factory); registerStoragePostgreSQL(factory);
registerStorageMaterializedPostgreSQL(factory);
#endif #endif
#if USE_MYSQL || USE_LIBPQXX #if USE_MYSQL || USE_LIBPQXX

View File

@ -1,6 +1,9 @@
#include <TableFunctions/TableFunctionPostgreSQL.h> #include <TableFunctions/TableFunctionPostgreSQL.h>
#if USE_LIBPQXX #if USE_LIBPQXX
#include <Databases/PostgreSQL/fetchPostgreSQLTableStructure.h>
#include <Storages/StoragePostgreSQL.h>
#include <Interpreters/evaluateConstantExpression.h> #include <Interpreters/evaluateConstantExpression.h>
#include <Parsers/ASTFunction.h> #include <Parsers/ASTFunction.h>
#include <Parsers/ASTLiteral.h> #include <Parsers/ASTLiteral.h>
@ -9,10 +12,8 @@
#include <Common/Exception.h> #include <Common/Exception.h>
#include <Common/parseAddress.h> #include <Common/parseAddress.h>
#include "registerTableFunctions.h" #include "registerTableFunctions.h"
#include <Databases/PostgreSQL/fetchPostgreSQLTableStructure.h>
#include <Common/quoteString.h> #include <Common/quoteString.h>
#include <Common/parseRemoteDescription.h> #include <Common/parseRemoteDescription.h>
#include <Storages/StoragePostgreSQL.h>
namespace DB namespace DB
@ -47,11 +48,12 @@ StoragePtr TableFunctionPostgreSQL::executeImpl(const ASTPtr & /*ast_function*/,
ColumnsDescription TableFunctionPostgreSQL::getActualTableStructure(ContextPtr context) const ColumnsDescription TableFunctionPostgreSQL::getActualTableStructure(ContextPtr context) const
{ {
const bool use_nulls = context->getSettingsRef().external_table_functions_use_nulls; const bool use_nulls = context->getSettingsRef().external_table_functions_use_nulls;
auto connection_holder = connection_pool->get();
auto columns = fetchPostgreSQLTableStructure( auto columns = fetchPostgreSQLTableStructure(
connection_pool->get(), connection_holder->get(),
remote_table_schema.empty() ? doubleQuoteString(remote_table_name) remote_table_schema.empty() ? doubleQuoteString(remote_table_name)
: doubleQuoteString(remote_table_schema) + '.' + doubleQuoteString(remote_table_name), : doubleQuoteString(remote_table_schema) + '.' + doubleQuoteString(remote_table_name),
use_nulls); use_nulls).columns;
return ColumnsDescription{*columns}; return ColumnsDescription{*columns};
} }

Some files were not shown because too many files have changed in this diff Show More