Merge remote-tracking branch 'blessed/master' into parallel_replicas_cte_2

This commit is contained in:
Raúl Marín 2024-03-01 12:48:23 +01:00
commit 77752a63c7
189 changed files with 5627 additions and 3072 deletions

View File

@ -157,7 +157,7 @@ if (TARGET ch_contrib::zlib)
endif()
if (TARGET ch_contrib::zstd)
target_compile_definitions(_libarchive PUBLIC HAVE_ZSTD_H=1 HAVE_LIBZSTD=1)
target_compile_definitions(_libarchive PUBLIC HAVE_ZSTD_H=1 HAVE_LIBZSTD=1 HAVE_LIBZSTD_COMPRESSOR=1)
target_link_libraries(_libarchive PRIVATE ch_contrib::zstd)
endif()

View File

@ -25,21 +25,21 @@ public:
static const uint32_t bits = 128;
// Constructor initializes the same as Initialize()
MetroHash128(const uint64_t seed=0);
explicit MetroHash128(const uint64_t seed=0);
// Initializes internal state for new hash with optional seed
void Initialize(const uint64_t seed=0);
// Update the hash state with a string of bytes. If the length
// is sufficiently long, the implementation switches to a bulk
// hashing algorithm directly on the argument buffer for speed.
void Update(const uint8_t * buffer, const uint64_t length);
// Constructs the final hash and writes it to the argument buffer.
// After a hash is finalized, this instance must be Initialized()-ed
// again or the behavior of Update() and Finalize() is undefined.
void Finalize(uint8_t * const hash);
// A non-incremental function implementation. This can be significantly
// faster than the incremental implementation for some usage patterns.
static void Hash(const uint8_t * buffer, const uint64_t length, uint8_t * const hash, const uint64_t seed=0);
@ -57,7 +57,7 @@ private:
static const uint64_t k1 = 0x8648DBDB;
static const uint64_t k2 = 0x7BDEC03B;
static const uint64_t k3 = 0x2F5870A5;
struct { uint64_t v[4]; } state;
struct { uint8_t b[32]; } input;
uint64_t bytes;

View File

@ -25,21 +25,21 @@ public:
static const uint32_t bits = 64;
// Constructor initializes the same as Initialize()
MetroHash64(const uint64_t seed=0);
explicit MetroHash64(const uint64_t seed=0);
// Initializes internal state for new hash with optional seed
void Initialize(const uint64_t seed=0);
// Update the hash state with a string of bytes. If the length
// is sufficiently long, the implementation switches to a bulk
// hashing algorithm directly on the argument buffer for speed.
void Update(const uint8_t * buffer, const uint64_t length);
// Constructs the final hash and writes it to the argument buffer.
// After a hash is finalized, this instance must be Initialized()-ed
// again or the behavior of Update() and Finalize() is undefined.
void Finalize(uint8_t * const hash);
// A non-incremental function implementation. This can be significantly
// faster than the incremental implementation for some usage patterns.
static void Hash(const uint8_t * buffer, const uint64_t length, uint8_t * const hash, const uint64_t seed=0);
@ -57,7 +57,7 @@ private:
static const uint64_t k1 = 0xA2AA033B;
static const uint64_t k2 = 0x62992FC1;
static const uint64_t k3 = 0x30BC5B29;
struct { uint64_t v[4]; } state;
struct { uint8_t b[32]; } input;
uint64_t bytes;

View File

@ -19,6 +19,8 @@ CREATE TABLE azure_blob_storage_table (name String, value UInt32)
### Engine parameters
- `endpoint` — AzureBlobStorage endpoint URL with container & prefix. May optionally contain account_name if the authentication method in use requires it (`http://azurite1:{port}/[account_name]{container_name}/{data_prefix}`). Alternatively, these parameters can be provided separately using storage_account_url, account_name & container. To specify a prefix, endpoint should be used. A usage sketch follows this list.
- `endpoint_contains_account_name` - This flag specifies whether the endpoint contains account_name; it is only needed for certain authentication methods. (Default: true)
- `connection_string|storage_account_url` — connection_string includes the account name & key ([Create connection string](https://learn.microsoft.com/en-us/azure/storage/common/storage-configure-connection-string?toc=%2Fazure%2Fstorage%2Fblobs%2Ftoc.json&bc=%2Fazure%2Fstorage%2Fblobs%2Fbreadcrumb%2Ftoc.json#configure-a-connection-string-for-an-azure-storage-account)). Alternatively, you can provide the storage account URL here and the account name & account key as separate parameters (see parameters account_name & account_key).
- `container_name` - Container name
- `blobpath` - file path. Supports the following wildcards in readonly mode: `*`, `**`, `?`, `{abc,def}` and `{N..M}` where `N`, `M` — numbers, `'abc'`, `'def'` — strings.
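A minimal sketch of the separate `storage_account_url` / `account_name` / `account_key` form described above; the Azurite URL, container name, blob path and `<account_key>` are placeholders, not values taken from this page:

```sql
-- Hypothetical example: all connection values below are placeholders.
CREATE TABLE azure_blob_storage_table (name String, value UInt32)
    ENGINE = AzureBlobStorage('http://azurite1:10000/devstoreaccount1',
                              'test-container', 'data.csv',
                              'devstoreaccount1', '<account_key>', 'CSV');
```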

View File

@ -1242,7 +1242,9 @@ Configuration markup:
```
Connection parameters:
* `storage_account_url` - **Required**, Azure Blob Storage account URL, like `http://account.blob.core.windows.net` or `http://azurite1:10000/devstoreaccount1`.
* `endpoint` — AzureBlobStorage endpoint URL with container & prefix. May optionally contain account_name if the authentication method in use requires it (`http://account.blob.core.windows.net:{port}/[account_name]{container_name}/{data_prefix}`). Alternatively, these parameters can be provided separately using storage_account_url, account_name & container. To specify a prefix, endpoint should be used.
* `endpoint_contains_account_name` - This flag specifies whether the endpoint contains account_name; it is only needed for certain authentication methods. (Default: true)
* `storage_account_url` - Required if endpoint is not specified, Azure Blob Storage account URL, like `http://account.blob.core.windows.net` or `http://azurite1:10000/devstoreaccount1`.
* `container_name` - Target container name, defaults to `default-container`.
* `container_already_exists` - If set to `false`, a new container `container_name` is created in the storage account; if set to `true`, the disk connects to the container directly; and if left unset, the disk connects to the account, checks whether the container `container_name` exists, and creates it if it does not exist yet.

View File

@ -168,6 +168,28 @@ RESTORE TABLE test.table PARTITIONS '2', '3'
FROM Disk('backups', 'filename.zip')
```
### Backups as tar archives
Backups can also be stored as tar archives. The functionality is the same as for zip, except that a password is not supported.
Write a backup as a tar:
```
BACKUP TABLE test.table TO Disk('backups', '1.tar')
```
Corresponding restore:
```
RESTORE TABLE test.table FROM Disk('backups', '1.tar')
```
To change the compression method, the correct file suffix should be appended to the backup name. For example, to compress the tar archive using gzip:
```
BACKUP TABLE test.table TO Disk('backups', '1.tar.gz')
```
The supported compression file suffixes are `.tar.gz`, `.tgz`, `.tar.bz2`, `.tar.lzma`, `.tar.zst`, `.tzst` and `.tar.xz`.
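For instance, the same backup written as a zstd-compressed tar, reusing the table and disk names from the examples above:
```
BACKUP TABLE test.table TO Disk('backups', '1.tar.zst')
```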
### Check the status of backups
The backup command returns an `id` and `status`, and that `id` can be used to get the status of the backup. This is very useful to check the progress of long ASYNC backups. The example below shows a failure that happened when trying to overwrite an existing backup file:
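A sketch of such a status check, assuming the `id` returned by the backup command and the `system.backups` system table; `<backup id>` is a placeholder:
```sql
SELECT id, name, status, error
FROM system.backups
WHERE id = '<backup id>'
```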

View File

@ -200,17 +200,13 @@ Type: Bool
Default: 0
## dns_cache_max_size
## dns_cache_max_entries
Internal DNS cache max size in bytes.
:::note
ClickHouse also has a reverse cache, so the actual memory usage could be twice as much.
:::
Internal DNS cache max entries.
Type: UInt64
Default: 1024
Default: 10000
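A quick way to inspect the effective value at runtime — a sketch, assuming the setting is exposed through the `system.server_settings` table:
```sql
SELECT name, value, changed
FROM system.server_settings
WHERE name = 'dns_cache_max_entries'
```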
## dns_cache_update_period

View File

@ -33,6 +33,6 @@ Result:
**See also**
- [disable_internal_dns_cache setting](../../operations/server-configuration-parameters/settings.md#disable_internal_dns_cache)
- [dns_cache_max_size setting](../../operations/server-configuration-parameters/settings.md#dns_cache_max_size)
- [dns_cache_max_entries setting](../../operations/server-configuration-parameters/settings.md#dns_cache_max_entries)
- [dns_cache_update_period setting](../../operations/server-configuration-parameters/settings.md#dns_cache_update_period)
- [dns_max_consecutive_failures setting](../../operations/server-configuration-parameters/settings.md#dns_max_consecutive_failures)

View File

@ -167,6 +167,10 @@ Result:
└──────────────────────────────────────────┴───────────────────────────────┘
```
## byteSlice(s, offset, length)
See function [substring](string-functions.md#substring).
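A minimal sketch, assuming `byteSlice` follows the `substring(s, offset, length)` semantics it aliases (1-based offset):
```sql
SELECT byteSlice('ClickHouse', 6, 5) -- expected to return 'House'
```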
## bitTest
Takes any integer and converts it into [binary form](https://en.wikipedia.org/wiki/Binary_number), then returns the value of the bit at the specified position. Positions are counted from 0, from right to left.
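For example, `43` is `101011` in binary, so a sketch of reading individual bits looks like this:
```sql
SELECT bitTest(43, 1), bitTest(43, 2) -- bit 1 is set (1), bit 2 is not (0)
```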

View File

@ -394,8 +394,7 @@ Result:
## toYear
Converts a date or date with time to the year number (AD) as `UInt16` value.
Returns the year component (AD) of a date or date with time.
**Syntax**
@ -431,7 +430,7 @@ Result:
## toQuarter
Converts a date or date with time to the quarter number (1-4) as `UInt8` value.
Returns the quarter (1-4) of a date or date with time.
**Syntax**
@ -465,10 +464,9 @@ Result:
└──────────────────────────────────────────────┘
```
## toMonth
Converts a date or date with time to the month number (1-12) as `UInt8` value.
Returns the month component (1-12) of a date or date with time.
**Syntax**
@ -504,7 +502,7 @@ Result:
## toDayOfYear
Converts a date or date with time to the number of the day of the year (1-366) as `UInt16` value.
Returns the number of the day within the year (1-366) of a date or date with time.
**Syntax**
@ -540,7 +538,7 @@ Result:
## toDayOfMonth
Converts a date or date with time to the number of the day in the month (1-31) as `UInt8` value.
Returns the number of the day within the month (1-31) of a date or date with time.
**Syntax**
@ -576,7 +574,7 @@ Result:
## toDayOfWeek
Converts a date or date with time to the number of the day in the week as `UInt8` value.
Returns the number of the day within the week of a date or date with time.
The two-argument form of `toDayOfWeek()` enables you to specify whether the week starts on Monday or Sunday, and whether the return value should be in the range from 0 to 6 or 1 to 7. If the mode argument is omitted, the default mode is 0. The time zone of the date can be specified as the third argument.
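A brief usage sketch of the one-, two- and three-argument forms; the sample date, mode value and time zone are arbitrary placeholders:
```sql
SELECT
    toDayOfWeek(toDateTime('2023-04-21 10:20:30')),                     -- default mode 0
    toDayOfWeek(toDateTime('2023-04-21 10:20:30'), 1),                  -- explicit mode argument
    toDayOfWeek(toDateTime('2023-04-21 10:20:30'), 1, 'Europe/Madrid')  -- mode plus time zone
```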
@ -627,7 +625,7 @@ Result:
## toHour
Converts a date with time to the number of the hour in 24-hour time (0-23) as `UInt8` value.
Returns the hour component (0-23) of a date with time.
Assumes that if clocks are moved ahead, it is by one hour and occurs at 2 a.m., and if clocks are moved back, it is by one hour and occurs at 3 a.m. (which is not always exactly when it occurs - it depends on the timezone).
@ -641,7 +639,7 @@ Alias: `HOUR`
**Arguments**
- `value` - a [Date](../data-types/date.md), [Date32](../data-types/date32.md), [DateTime](../data-types/datetime.md) or [DateTime64](../data-types/datetime64.md)
- `value` - a [DateTime](../data-types/datetime.md) or [DateTime64](../data-types/datetime64.md)
**Returned value**
@ -665,7 +663,7 @@ Result:
## toMinute
Converts a date with time to the number of the minute of the hour (0-59) as `UInt8` value.
Returns the minute component (0-59) of a date with time.
**Syntax**
@ -677,7 +675,7 @@ Alias: `MINUTE`
**Arguments**
- `value` - a [Date](../data-types/date.md), [Date32](../data-types/date32.md), [DateTime](../data-types/datetime.md) or [DateTime64](../data-types/datetime64.md)
- `value` - a [DateTime](../data-types/datetime.md) or [DateTime64](../data-types/datetime64.md)
**Returned value**
@ -701,7 +699,7 @@ Result:
## toSecond
Converts a date with time to the second in the minute (0-59) as `UInt8` value. Leap seconds are not considered.
Returns the second component (0-59) of a date with time. Leap seconds are not considered.
**Syntax**
@ -713,7 +711,7 @@ Alias: `SECOND`
**Arguments**
- `value` - a [Date](../data-types/date.md), [Date32](../data-types/date32.md), [DateTime](../data-types/datetime.md) or [DateTime64](../data-types/datetime64.md)
- `value` - a [DateTime](../data-types/datetime.md) or [DateTime64](../data-types/datetime64.md)
**Returned value**
@ -735,6 +733,40 @@ Result:
└─────────────────────────────────────────────┘
```
## toMillisecond
Returns the millisecond component (0-999) of a date with time.
**Syntax**
```sql
toMillisecond(value)
```
Alias: `MILLISECOND`
**Arguments**
- `value` - [DateTime](../data-types/datetime.md) or [DateTime64](../data-types/datetime64.md)
**Returned value**
- The millisecond component (0 - 999) of the given date/time
Type: `UInt16`
```sql
SELECT toMillisecond(toDateTime64('2023-04-21 10:20:30.456', 3))
```
Result:
```response
┌──toMillisecond(toDateTime64('2023-04-21 10:20:30.456', 3))─┐
│ 456 │
└────────────────────────────────────────────────────────────┘
```
## toUnixTimestamp
Converts a string, a date or a date with time to the [Unix Timestamp](https://en.wikipedia.org/wiki/Unix_time) in `UInt32` representation.
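A minimal usage sketch; the sample timestamp string and time zone are arbitrary placeholders:
```sql
SELECT toUnixTimestamp('2023-04-21 10:20:30', 'UTC')
```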

View File

@ -558,6 +558,7 @@ substring(s, offset[, length])
Alias:
- `substr`
- `mid`
- `byteSlice`
**Arguments**

View File

@ -68,7 +68,7 @@ RELOAD FUNCTION [ON CLUSTER cluster_name] function_name
Clears ClickHouse's internal DNS cache. Sometimes (for old ClickHouse versions) it is necessary to use this command when changing the infrastructure (changing the IP address of another ClickHouse server or the server used by dictionaries).
For more convenient (automatic) cache management, see disable_internal_dns_cache, dns_cache_max_size, dns_cache_update_period parameters.
For more convenient (automatic) cache management, see disable_internal_dns_cache, dns_cache_max_entries, dns_cache_update_period parameters.
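A minimal example of invoking it:
```sql
SYSTEM DROP DNS CACHE
```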
## DROP MARK CACHE

View File

@ -1774,7 +1774,7 @@ try
}
else
{
DNSResolver::instance().setCacheMaxSize(server_settings.dns_cache_max_size);
DNSResolver::instance().setCacheMaxEntries(server_settings.dns_cache_max_entries);
/// Initialize a watcher periodically updating DNS cache
dns_cache_updater = std::make_unique<DNSCacheUpdater>(

View File

@ -24,7 +24,7 @@ class HTTPAuthClient
public:
using Result = TResponseParser::Result;
HTTPAuthClient(const HTTPAuthClientParams & params, const TResponseParser & parser_ = TResponseParser{})
explicit HTTPAuthClient(const HTTPAuthClientParams & params, const TResponseParser & parser_ = TResponseParser{})
: timeouts{params.timeouts}
, max_tries{params.max_tries}
, retry_initial_backoff_ms{params.retry_initial_backoff_ms}

View File

@ -1,5 +1,5 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/HelpersMinMaxAny.h>
#include <AggregateFunctions/SingleValueData.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <base/defines.h>
@ -11,219 +11,347 @@ struct Settings;
namespace ErrorCodes
{
extern const int INCORRECT_DATA;
extern const int LOGICAL_ERROR;
extern const int NOT_IMPLEMENTED;
}
namespace
{
struct AggregateFunctionAnyRespectNullsData
template <typename Data>
class AggregateFunctionAny final : public IAggregateFunctionDataHelper<Data, AggregateFunctionAny<Data>>
{
enum Status : UInt8
{
NotSet = 1,
SetNull = 2,
SetOther = 3
};
Status status = Status::NotSet;
Field value;
bool isSet() const { return status != Status::NotSet; }
void setNull() { status = Status::SetNull; }
void setOther() { status = Status::SetOther; }
};
template <bool First>
class AggregateFunctionAnyRespectNulls final
: public IAggregateFunctionDataHelper<AggregateFunctionAnyRespectNullsData, AggregateFunctionAnyRespectNulls<First>>
{
public:
using Data = AggregateFunctionAnyRespectNullsData;
private:
SerializationPtr serialization;
const bool returns_nullable_type = false;
explicit AggregateFunctionAnyRespectNulls(const DataTypePtr & type)
: IAggregateFunctionDataHelper<Data, AggregateFunctionAnyRespectNulls<First>>({type}, {}, type)
, serialization(type->getDefaultSerialization())
, returns_nullable_type(type->isNullable())
public:
explicit AggregateFunctionAny(const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionAny<Data>>(argument_types_, {}, argument_types_[0])
, serialization(this->result_type->getDefaultSerialization())
{
}
String getName() const override
{
if constexpr (First)
return "any_respect_nulls";
else
return "anyLast_respect_nulls";
}
String getName() const override { return "any"; }
bool allocatesMemoryInArena() const override { return false; }
void addNull(AggregateDataPtr __restrict place) const
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
chassert(returns_nullable_type);
auto & d = this->data(place);
if (First && d.isSet())
return;
d.setNull();
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{
if (columns[0]->isNullable())
{
if (columns[0]->isNullAt(row_num))
return addNull(place);
}
auto & d = this->data(place);
if (First && d.isSet())
return;
d.setOther();
columns[0]->get(row_num, d.value);
}
void addManyDefaults(AggregateDataPtr __restrict place, const IColumn ** columns, size_t, Arena * arena) const override
{
if (columns[0]->isNullable())
addNull(place);
else
add(place, columns, 0, arena);
if (!this->data(place).has())
this->data(place).set(*columns[0], row_num, arena);
}
void addBatchSinglePlace(
size_t row_begin, size_t row_end, AggregateDataPtr place, const IColumn ** columns, Arena * arena, ssize_t if_argument_pos)
const override
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
Arena * arena,
ssize_t if_argument_pos) const override
{
if (this->data(place).has() || row_begin >= row_end)
return;
if (if_argument_pos >= 0)
{
const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
size_t size = row_end - row_begin;
for (size_t i = 0; i < size; ++i)
const auto & if_map = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
for (size_t i = row_begin; i < row_end; i++)
{
size_t pos = First ? row_begin + i : row_end - 1 - i;
if (flags[pos])
if (if_map.data()[i] != 0)
{
add(place, columns, pos, arena);
break;
this->data(place).set(*columns[0], i, arena);
return;
}
}
}
else if (row_begin < row_end)
else
{
size_t pos = First ? row_begin : row_end - 1;
add(place, columns, pos, arena);
this->data(place).set(*columns[0], row_begin, arena);
}
}
void addBatchSinglePlaceNotNull(
size_t, size_t, AggregateDataPtr __restrict, const IColumn **, const UInt8 *, Arena *, ssize_t) const override
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
const UInt8 * __restrict null_map,
Arena * arena,
ssize_t if_argument_pos) const override
{
/// This should not happen since it means somebody else has preprocessed the data (NULLs or IFs) and might
/// have discarded values that we need (NULLs)
throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "AggregateFunctionAnyRespectNulls::addBatchSinglePlaceNotNull called");
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
{
auto & d = this->data(place);
if (First && d.isSet())
if (this->data(place).has() || row_begin >= row_end)
return;
auto & other = this->data(rhs);
if (other.isSet())
if (if_argument_pos >= 0)
{
d.status = other.status;
d.value = other.value;
const auto & if_map = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
for (size_t i = row_begin; i < row_end; i++)
{
if (if_map.data()[i] != 0 && null_map[i] == 0)
{
this->data(place).set(*columns[0], i, arena);
return;
}
}
}
else
{
for (size_t i = row_begin; i < row_end; i++)
{
if (null_map[i] == 0)
{
this->data(place).set(*columns[0], i, arena);
return;
}
}
}
}
void addManyDefaults(AggregateDataPtr __restrict place, const IColumn ** columns, size_t, Arena * arena) const override
{
if (!this->data(place).has())
this->data(place).set(*columns[0], 0, arena);
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
if (!this->data(place).has())
this->data(place).set(this->data(rhs), arena);
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
auto & d = this->data(place);
UInt8 k = d.status;
writeBinaryLittleEndian<UInt8>(k, buf);
if (k == Data::Status::SetOther)
serialization->serializeBinary(d.value, buf, {});
this->data(place).write(buf, *serialization);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
auto & d = this->data(place);
UInt8 k = Data::Status::NotSet;
readBinaryLittleEndian<UInt8>(k, buf);
d.status = static_cast<Data::Status>(k);
if (d.status == Data::Status::NotSet)
return;
else if (d.status == Data::Status::SetNull)
{
if (!returns_nullable_type)
throw Exception(ErrorCodes::INCORRECT_DATA, "Incorrect type (NULL) in non-nullable {}State", getName());
return;
}
else if (d.status == Data::Status::SetOther)
serialization->deserializeBinary(d.value, buf, {});
else
throw Exception(ErrorCodes::INCORRECT_DATA, "Incorrect type ({}) in {}State", static_cast<Int8>(k), getName());
this->data(place).read(buf, *serialization, arena);
}
bool allocatesMemoryInArena() const override { return Data::allocatesMemoryInArena(); }
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
auto & d = this->data(place);
if (d.status == Data::Status::SetOther)
to.insert(d.value);
else
to.insertDefault();
this->data(place).insertResultInto(to);
}
AggregateFunctionPtr getOwnNullAdapter(
const AggregateFunctionPtr & original_function,
const DataTypes & /*arguments*/,
const Array & /*params*/,
const AggregateFunctionProperties & /*properties*/) const override
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
return original_function;
if constexpr (!Data::is_compilable)
return false;
else
return Data::isCompilable(*this->argument_types[0]);
}
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
if constexpr (Data::is_compilable)
Data::compileCreate(builder, aggregate_data_ptr);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override
{
if constexpr (Data::is_compilable)
Data::compileAny(builder, aggregate_data_ptr, arguments[0].value);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
void
compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
{
if constexpr (Data::is_compilable)
Data::compileAnyMerge(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
if constexpr (Data::is_compilable)
return Data::compileGetResult(builder, aggregate_data_ptr);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
#endif
};
template <bool First>
IAggregateFunction * createAggregateFunctionSingleValueRespectNulls(
const String & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
AggregateFunctionPtr
createAggregateFunctionAny(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
return new AggregateFunctionAnyRespectNulls<First>(argument_types[0]);
return AggregateFunctionPtr(
createAggregateFunctionSingleValue<AggregateFunctionAny, /* unary */ true>(name, argument_types, parameters, settings));
}
AggregateFunctionPtr createAggregateFunctionAny(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionSingleValue<AggregateFunctionsSingleValue, AggregateFunctionAnyData>(name, argument_types, parameters, settings));
}
AggregateFunctionPtr createAggregateFunctionAnyRespectNulls(
template <typename Data>
class AggregateFunctionAnyLast final : public IAggregateFunctionDataHelper<Data, AggregateFunctionAnyLast<Data>>
{
private:
SerializationPtr serialization;
public:
explicit AggregateFunctionAnyLast(const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionAnyLast<Data>>(argument_types_, {}, argument_types_[0])
, serialization(this->result_type->getDefaultSerialization())
{
}
String getName() const override { return "anyLast"; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
this->data(place).set(*columns[0], row_num, arena);
}
void addBatchSinglePlace(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
Arena * arena,
ssize_t if_argument_pos) const override
{
if (row_begin >= row_end)
return;
size_t batch_size = row_end - row_begin;
if (if_argument_pos >= 0)
{
const auto & if_map = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
for (size_t i = 0; i < batch_size; i++)
{
size_t pos = (row_end - 1) - i;
if (if_map.data()[pos] != 0)
{
this->data(place).set(*columns[0], pos, arena);
return;
}
}
}
else
{
this->data(place).set(*columns[0], row_end - 1, arena);
}
}
void addBatchSinglePlaceNotNull(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
const UInt8 * __restrict null_map,
Arena * arena,
ssize_t if_argument_pos) const override
{
if (row_begin >= row_end)
return;
size_t batch_size = row_end - row_begin;
if (if_argument_pos >= 0)
{
const auto & if_map = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
for (size_t i = 0; i < batch_size; i++)
{
size_t pos = (row_end - 1) - i;
if (if_map.data()[pos] != 0 && null_map[pos] == 0)
{
this->data(place).set(*columns[0], pos, arena);
return;
}
}
}
else
{
for (size_t i = 0; i < batch_size; i++)
{
size_t pos = (row_end - 1) - i;
if (null_map[pos] == 0)
{
this->data(place).set(*columns[0], pos, arena);
return;
}
}
}
}
void addManyDefaults(AggregateDataPtr __restrict place, const IColumn ** columns, size_t, Arena * arena) const override
{
this->data(place).set(*columns[0], 0, arena);
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
this->data(place).set(this->data(rhs), arena);
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
this->data(place).write(buf, *serialization);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
this->data(place).read(buf, *serialization, arena);
}
bool allocatesMemoryInArena() const override { return Data::allocatesMemoryInArena(); }
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
this->data(place).insertResultInto(to);
}
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
if constexpr (!Data::is_compilable)
return false;
else
return Data::isCompilable(*this->argument_types[0]);
}
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
if constexpr (Data::is_compilable)
Data::compileCreate(builder, aggregate_data_ptr);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override
{
if constexpr (Data::is_compilable)
Data::compileAnyLast(builder, aggregate_data_ptr, arguments[0].value);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
void
compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
{
if constexpr (Data::is_compilable)
Data::compileAnyLastMerge(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
if constexpr (Data::is_compilable)
return Data::compileGetResult(builder, aggregate_data_ptr);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
#endif
};
AggregateFunctionPtr createAggregateFunctionAnyLast(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionSingleValueRespectNulls<true>(name, argument_types, parameters, settings));
}
AggregateFunctionPtr createAggregateFunctionAnyLast(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionSingleValue<AggregateFunctionsSingleValue, AggregateFunctionAnyLastData>(name, argument_types, parameters, settings));
}
AggregateFunctionPtr createAggregateFunctionAnyLastRespectNulls(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionSingleValueRespectNulls<false>(name, argument_types, parameters, settings));
}
AggregateFunctionPtr createAggregateFunctionAnyHeavy(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionSingleValue<AggregateFunctionsSingleValue, AggregateFunctionAnyHeavyData>(name, argument_types, parameters, settings));
return AggregateFunctionPtr(
createAggregateFunctionSingleValue<AggregateFunctionAnyLast, /* unary */ true>(name, argument_types, parameters, settings));
}
}
@ -231,27 +359,11 @@ AggregateFunctionPtr createAggregateFunctionAnyHeavy(const std::string & name, c
void registerAggregateFunctionsAny(AggregateFunctionFactory & factory)
{
AggregateFunctionProperties default_properties = {.returns_default_when_only_null = false, .is_order_dependent = true};
AggregateFunctionProperties default_properties_for_respect_nulls
= {.returns_default_when_only_null = false, .is_order_dependent = true, .is_window_function = true};
factory.registerFunction("any", {createAggregateFunctionAny, default_properties});
factory.registerAlias("any_value", "any", AggregateFunctionFactory::CaseInsensitive);
factory.registerAlias("first_value", "any", AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("any_respect_nulls", {createAggregateFunctionAnyRespectNulls, default_properties_for_respect_nulls});
factory.registerAlias("any_value_respect_nulls", "any_respect_nulls", AggregateFunctionFactory::CaseInsensitive);
factory.registerAlias("first_value_respect_nulls", "any_respect_nulls", AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("anyLast", {createAggregateFunctionAnyLast, default_properties});
factory.registerAlias("last_value", "anyLast", AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("anyLast_respect_nulls", {createAggregateFunctionAnyLastRespectNulls, default_properties_for_respect_nulls});
factory.registerAlias("last_value_respect_nulls", "anyLast_respect_nulls", AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("anyHeavy", {createAggregateFunctionAnyHeavy, default_properties});
factory.registerNullsActionTransformation("any", "any_respect_nulls");
factory.registerNullsActionTransformation("anyLast", "anyLast_respect_nulls");
}
}

View File

@ -0,0 +1,168 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/SingleValueData.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <base/defines.h>
namespace DB
{
struct Settings;
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
namespace
{
/** Implement 'heavy hitters' algorithm.
* Selects most frequent value if its frequency is more than 50% in each thread of execution.
* Otherwise, selects some arbitrary value.
* http://www.cs.umd.edu/~samir/498/karp.pdf
*/
struct AggregateFunctionAnyHeavyData
{
using Self = AggregateFunctionAnyHeavyData;
private:
SingleValueDataBaseMemoryBlock v_data;
UInt64 counter = 0;
public:
[[noreturn]] explicit AggregateFunctionAnyHeavyData()
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "AggregateFunctionAnyHeavyData initialized empty");
}
explicit AggregateFunctionAnyHeavyData(TypeIndex value_type) { generateSingleValueFromTypeIndex(value_type, v_data); }
~AggregateFunctionAnyHeavyData() { data().~SingleValueDataBase(); }
SingleValueDataBase & data() { return v_data.get(); }
const SingleValueDataBase & data() const { return v_data.get(); }
void add(const IColumn & column, size_t row_num, Arena * arena)
{
if (data().isEqualTo(column, row_num))
{
++counter;
}
else if (counter == 0)
{
data().set(column, row_num, arena);
++counter;
}
else
{
--counter;
}
}
void add(const Self & to, Arena * arena)
{
if (!to.data().has())
return;
if (data().isEqualTo(to.data()))
counter += to.counter;
else if (!data().has() || counter < to.counter)
data().set(to.data(), arena);
else
counter -= to.counter;
}
void addManyDefaults(const IColumn & column, size_t length, Arena * arena)
{
for (size_t i = 0; i < length; ++i)
add(column, 0, arena);
}
void write(WriteBuffer & buf, const ISerialization & serialization) const
{
data().write(buf, serialization);
writeBinaryLittleEndian(counter, buf);
}
void read(ReadBuffer & buf, const ISerialization & serialization, Arena * arena)
{
data().read(buf, serialization, arena);
readBinaryLittleEndian(counter, buf);
}
void insertResultInto(IColumn & to) const { data().insertResultInto(to); }
};
class AggregateFunctionAnyHeavy final : public IAggregateFunctionDataHelper<AggregateFunctionAnyHeavyData, AggregateFunctionAnyHeavy>
{
private:
SerializationPtr serialization;
const TypeIndex value_type_index;
public:
explicit AggregateFunctionAnyHeavy(const DataTypePtr & type)
: IAggregateFunctionDataHelper<AggregateFunctionAnyHeavyData, AggregateFunctionAnyHeavy>({type}, {}, type)
, serialization(type->getDefaultSerialization())
, value_type_index(WhichDataType(type).idx)
{
}
void create(AggregateDataPtr __restrict place) const override { new (place) AggregateFunctionAnyHeavyData(value_type_index); }
String getName() const override { return "anyHeavy"; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
this->data(place).add(*columns[0], row_num, arena);
}
void addManyDefaults(AggregateDataPtr __restrict place, const IColumn ** columns, size_t, Arena * arena) const override
{
this->data(place).addManyDefaults(*columns[0], 0, arena);
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
this->data(place).add(this->data(rhs), arena);
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
this->data(place).write(buf, *serialization);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
this->data(place).read(buf, *serialization, arena);
}
bool allocatesMemoryInArena() const override { return singleValueTypeAllocatesMemoryInArena(value_type_index); }
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
this->data(place).insertResultInto(to);
}
};
AggregateFunctionPtr
createAggregateFunctionAnyHeavy(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
const DataTypePtr & res_type = argument_types[0];
return AggregateFunctionPtr(new AggregateFunctionAnyHeavy(res_type));
}
}
void registerAggregateFunctionAnyHeavy(AggregateFunctionFactory & factory)
{
AggregateFunctionProperties default_properties = {.returns_default_when_only_null = false, .is_order_dependent = true};
factory.registerFunction("anyHeavy", {createAggregateFunctionAnyHeavy, default_properties});
}
}

View File

@ -0,0 +1,235 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/SingleValueData.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <base/defines.h>
namespace DB
{
struct Settings;
namespace ErrorCodes
{
extern const int INCORRECT_DATA;
extern const int LOGICAL_ERROR;
}
namespace
{
struct AggregateFunctionAnyRespectNullsData
{
enum class Status : UInt8
{
NotSet = 1,
SetNull = 2,
SetOther = 3
};
Status status = Status::NotSet;
Field value;
bool isSet() const { return status != Status::NotSet; }
void setNull() { status = Status::SetNull; }
void setOther() { status = Status::SetOther; }
};
template <bool First>
class AggregateFunctionAnyRespectNulls final
: public IAggregateFunctionDataHelper<AggregateFunctionAnyRespectNullsData, AggregateFunctionAnyRespectNulls<First>>
{
public:
using Data = AggregateFunctionAnyRespectNullsData;
SerializationPtr serialization;
const bool returns_nullable_type = false;
explicit AggregateFunctionAnyRespectNulls(const DataTypePtr & type)
: IAggregateFunctionDataHelper<Data, AggregateFunctionAnyRespectNulls<First>>({type}, {}, type)
, serialization(type->getDefaultSerialization())
, returns_nullable_type(type->isNullable())
{
}
String getName() const override
{
if constexpr (First)
return "any_respect_nulls";
else
return "anyLast_respect_nulls";
}
bool allocatesMemoryInArena() const override { return false; }
void addNull(AggregateDataPtr __restrict place) const
{
chassert(returns_nullable_type);
auto & d = this->data(place);
if (First && d.isSet())
return;
d.setNull();
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{
if (columns[0]->isNullable())
{
if (columns[0]->isNullAt(row_num))
return addNull(place);
}
auto & d = this->data(place);
if (First && d.isSet())
return;
d.setOther();
columns[0]->get(row_num, d.value);
}
void addManyDefaults(AggregateDataPtr __restrict place, const IColumn ** columns, size_t, Arena * arena) const override
{
if (columns[0]->isNullable())
addNull(place);
else
add(place, columns, 0, arena);
}
void addBatchSinglePlace(
size_t row_begin, size_t row_end, AggregateDataPtr place, const IColumn ** columns, Arena * arena, ssize_t if_argument_pos)
const override
{
if (if_argument_pos >= 0)
{
const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
size_t size = row_end - row_begin;
for (size_t i = 0; i < size; ++i)
{
size_t pos = First ? row_begin + i : row_end - 1 - i;
if (flags[pos])
{
add(place, columns, pos, arena);
break;
}
}
}
else if (row_begin < row_end)
{
size_t pos = First ? row_begin : row_end - 1;
add(place, columns, pos, arena);
}
}
void addBatchSinglePlaceNotNull(
size_t, size_t, AggregateDataPtr __restrict, const IColumn **, const UInt8 *, Arena *, ssize_t) const override
{
/// This should not happen since it means somebody else has preprocessed the data (NULLs or IFs) and might
/// have discarded values that we need (NULLs)
throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "AggregateFunctionAnyRespectNulls::addBatchSinglePlaceNotNull called");
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
{
auto & d = this->data(place);
if (First && d.isSet())
return;
auto & other = this->data(rhs);
if (other.isSet())
{
d.status = other.status;
d.value = other.value;
}
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
auto & d = this->data(place);
UInt8 k = static_cast<UInt8>(d.status);
writeBinaryLittleEndian<UInt8>(k, buf);
if (d.status == Data::Status::SetOther)
serialization->serializeBinary(d.value, buf, {});
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
{
auto & d = this->data(place);
UInt8 k = 0;
readBinaryLittleEndian<UInt8>(k, buf);
d.status = static_cast<Data::Status>(k);
if (d.status == Data::Status::NotSet)
return;
else if (d.status == Data::Status::SetNull)
{
if (!returns_nullable_type)
throw Exception(ErrorCodes::INCORRECT_DATA, "Incorrect type (NULL) in non-nullable {}State", getName());
return;
}
else if (d.status == Data::Status::SetOther)
{
serialization->deserializeBinary(d.value, buf, {});
return;
}
throw Exception(ErrorCodes::INCORRECT_DATA, "Incorrect type ({}) in {}State", static_cast<Int8>(k), getName());
}
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
auto & d = this->data(place);
if (d.status == Data::Status::SetOther)
to.insert(d.value);
else
to.insertDefault();
}
AggregateFunctionPtr getOwnNullAdapter(
const AggregateFunctionPtr & original_function,
const DataTypes & /*arguments*/,
const Array & /*params*/,
const AggregateFunctionProperties & /*properties*/) const override
{
return original_function;
}
};
template <bool First>
IAggregateFunction * createAggregateFunctionSingleValueRespectNulls(
const String & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
return new AggregateFunctionAnyRespectNulls<First>(argument_types[0]);
}
AggregateFunctionPtr createAggregateFunctionAnyRespectNulls(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionSingleValueRespectNulls<true>(name, argument_types, parameters, settings));
}
AggregateFunctionPtr createAggregateFunctionAnyLastRespectNulls(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionSingleValueRespectNulls<false>(name, argument_types, parameters, settings));
}
}
void registerAggregateFunctionsAnyRespectNulls(AggregateFunctionFactory & factory)
{
AggregateFunctionProperties default_properties_for_respect_nulls
= {.returns_default_when_only_null = false, .is_order_dependent = true, .is_window_function = true};
factory.registerFunction("any_respect_nulls", {createAggregateFunctionAnyRespectNulls, default_properties_for_respect_nulls});
factory.registerAlias("any_value_respect_nulls", "any_respect_nulls", AggregateFunctionFactory::CaseInsensitive);
factory.registerAlias("first_value_respect_nulls", "any_respect_nulls", AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("anyLast_respect_nulls", {createAggregateFunctionAnyLastRespectNulls, default_properties_for_respect_nulls});
factory.registerAlias("last_value_respect_nulls", "anyLast_respect_nulls", AggregateFunctionFactory::CaseInsensitive);
/// Must happen after registering any and anyLast
factory.registerNullsActionTransformation("any", "any_respect_nulls");
factory.registerNullsActionTransformation("anyLast", "anyLast_respect_nulls");
}
}

View File

@ -1,107 +0,0 @@
#pragma once
#include <base/StringRef.h>
#include <DataTypes/IDataType.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/AggregateFunctionMinMaxAny.h> // SingleValueDataString used in embedded compiler
namespace DB
{
struct Settings;
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int CORRUPTED_DATA;
}
/// For possible values for template parameters, see 'AggregateFunctionMinMaxAny.h'.
template <typename ResultData, typename ValueData>
struct AggregateFunctionArgMinMaxData
{
using ResultData_t = ResultData;
using ValueData_t = ValueData;
ResultData result; // the argument at which the minimum/maximum value is reached.
ValueData value; // value for which the minimum/maximum is calculated.
static bool allocatesMemoryInArena()
{
return ResultData::allocatesMemoryInArena() || ValueData::allocatesMemoryInArena();
}
};
/// Returns the first arg value found for the minimum/maximum value. Example: argMax(arg, value).
template <typename Data>
class AggregateFunctionArgMinMax final : public IAggregateFunctionDataHelper<Data, AggregateFunctionArgMinMax<Data>>
{
private:
const DataTypePtr & type_val;
const SerializationPtr serialization_res;
const SerializationPtr serialization_val;
using Base = IAggregateFunctionDataHelper<Data, AggregateFunctionArgMinMax<Data>>;
public:
AggregateFunctionArgMinMax(const DataTypePtr & type_res_, const DataTypePtr & type_val_)
: Base({type_res_, type_val_}, {}, type_res_)
, type_val(this->argument_types[1])
, serialization_res(type_res_->getDefaultSerialization())
, serialization_val(type_val->getDefaultSerialization())
{
if (!type_val->isComparable())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of second argument of "
"aggregate function {} because the values of that data type are not comparable",
type_val->getName(), getName());
}
String getName() const override
{
return StringRef(Data::ValueData_t::name()) == StringRef("min") ? "argMin" : "argMax";
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
if (this->data(place).value.changeIfBetter(*columns[1], row_num, arena))
this->data(place).result.change(*columns[0], row_num, arena);
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
if (this->data(place).value.changeIfBetter(this->data(rhs).value, arena))
this->data(place).result.change(this->data(rhs).result, arena);
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
this->data(place).result.write(buf, *serialization_res);
this->data(place).value.write(buf, *serialization_val);
}
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
this->data(place).result.read(buf, *serialization_res, arena);
this->data(place).value.read(buf, *serialization_val, arena);
if (unlikely(this->data(place).value.has() != this->data(place).result.has()))
throw Exception(
ErrorCodes::CORRUPTED_DATA,
"Invalid state of the aggregate function {}: has_value ({}) != has_result ({})",
getName(),
this->data(place).value.has(),
this->data(place).result.has());
}
bool allocatesMemoryInArena() const override
{
return Data::allocatesMemoryInArena();
}
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
this->data(place).result.insertResultInto(to);
}
};
}

View File

@ -204,7 +204,7 @@ private:
class Adam : public IWeightsUpdater
{
public:
Adam(size_t num_params)
explicit Adam(size_t num_params)
{
beta1_powered = beta1;
beta2_powered = beta2;

View File

@ -1,238 +0,0 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/HelpersMinMaxAny.h>
#include <Common/Concepts.h>
#include <Common/findExtreme.h>
namespace DB
{
struct Settings;
namespace
{
template <typename Data>
class AggregateFunctionsSingleValueMax final : public AggregateFunctionsSingleValue<Data>
{
using Parent = AggregateFunctionsSingleValue<Data>;
public:
explicit AggregateFunctionsSingleValueMax(const DataTypePtr & type) : Parent(type) { }
/// Specializations for native numeric types
void addBatchSinglePlace(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
Arena * arena,
ssize_t if_argument_pos) const override;
void addBatchSinglePlaceNotNull(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
const UInt8 * __restrict null_map,
Arena * arena,
ssize_t if_argument_pos) const override;
};
// NOLINTBEGIN(bugprone-macro-parentheses)
#define SPECIALIZE(TYPE) \
template <> \
void AggregateFunctionsSingleValueMax<typename DB::AggregateFunctionMaxData<SingleValueDataFixed<TYPE>>>::addBatchSinglePlace( \
size_t row_begin, \
size_t row_end, \
AggregateDataPtr __restrict place, \
const IColumn ** __restrict columns, \
Arena *, \
ssize_t if_argument_pos) const \
{ \
const auto & column = assert_cast<const DB::AggregateFunctionMaxData<SingleValueDataFixed<TYPE>>::ColVecType &>(*columns[0]); \
std::optional<TYPE> opt; \
if (if_argument_pos >= 0) \
{ \
const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData(); \
opt = findExtremeMaxIf(column.getData().data(), flags.data(), row_begin, row_end); \
} \
else \
opt = findExtremeMax(column.getData().data(), row_begin, row_end); \
if (opt.has_value()) \
this->data(place).changeIfGreater(opt.value()); \
}
// NOLINTEND(bugprone-macro-parentheses)
FOR_BASIC_NUMERIC_TYPES(SPECIALIZE)
#undef SPECIALIZE
template <typename Data>
void AggregateFunctionsSingleValueMax<Data>::addBatchSinglePlace(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
Arena * arena,
ssize_t if_argument_pos) const
{
if constexpr (!is_any_of<typename Data::Impl, SingleValueDataString, SingleValueDataGeneric>)
{
/// Leave other numeric types (large integers, decimals, etc) to keep doing the comparison as it's
/// faster than doing a permutation
return Parent::addBatchSinglePlace(row_begin, row_end, place, columns, arena, if_argument_pos);
}
constexpr int nan_null_direction_hint = -1;
auto const & column = *columns[0];
if (if_argument_pos >= 0)
{
size_t index = row_begin;
const auto & if_flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
while (if_flags[index] == 0 && index < row_end)
index++;
if (index >= row_end)
return;
for (size_t i = index + 1; i < row_end; i++)
{
if ((if_flags[i] != 0) && (column.compareAt(i, index, column, nan_null_direction_hint) > 0))
index = i;
}
this->data(place).changeIfGreater(column, index, arena);
}
else
{
if (row_begin >= row_end)
return;
/// TODO: Introduce row_begin and row_end to getPermutation
if (row_begin != 0 || row_end != column.size())
{
size_t index = row_begin;
for (size_t i = index + 1; i < row_end; i++)
{
if (column.compareAt(i, index, column, nan_null_direction_hint) > 0)
index = i;
}
this->data(place).changeIfGreater(column, index, arena);
}
else
{
constexpr IColumn::PermutationSortDirection direction = IColumn::PermutationSortDirection::Descending;
constexpr IColumn::PermutationSortStability stability = IColumn::PermutationSortStability::Unstable;
IColumn::Permutation permutation;
constexpr UInt64 limit = 1;
column.getPermutation(direction, stability, limit, nan_null_direction_hint, permutation);
this->data(place).changeIfGreater(column, permutation[0], arena);
}
}
}
// NOLINTBEGIN(bugprone-macro-parentheses)
#define SPECIALIZE(TYPE) \
template <> \
void AggregateFunctionsSingleValueMax<typename DB::AggregateFunctionMaxData<SingleValueDataFixed<TYPE>>>::addBatchSinglePlaceNotNull( \
size_t row_begin, \
size_t row_end, \
AggregateDataPtr __restrict place, \
const IColumn ** __restrict columns, \
const UInt8 * __restrict null_map, \
Arena *, \
ssize_t if_argument_pos) const \
{ \
const auto & column = assert_cast<const DB::AggregateFunctionMaxData<SingleValueDataFixed<TYPE>>::ColVecType &>(*columns[0]); \
std::optional<TYPE> opt; \
if (if_argument_pos >= 0) \
{ \
const auto * if_flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData().data(); \
auto final_flags = std::make_unique<UInt8[]>(row_end); \
for (size_t i = row_begin; i < row_end; ++i) \
final_flags[i] = (!null_map[i]) & !!if_flags[i]; \
opt = findExtremeMaxIf(column.getData().data(), final_flags.get(), row_begin, row_end); \
} \
else \
opt = findExtremeMaxNotNull(column.getData().data(), null_map, row_begin, row_end); \
if (opt.has_value()) \
this->data(place).changeIfGreater(opt.value()); \
}
// NOLINTEND(bugprone-macro-parentheses)
FOR_BASIC_NUMERIC_TYPES(SPECIALIZE)
#undef SPECIALIZE
template <typename Data>
void AggregateFunctionsSingleValueMax<Data>::addBatchSinglePlaceNotNull(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
const UInt8 * __restrict null_map,
Arena * arena,
ssize_t if_argument_pos) const
{
if constexpr (!is_any_of<typename Data::Impl, SingleValueDataString, SingleValueDataGeneric>)
{
/// Leave other numeric types (large integers, decimals, etc) to keep doing the comparison as it's
/// faster than doing a permutation
return Parent::addBatchSinglePlaceNotNull(row_begin, row_end, place, columns, null_map, arena, if_argument_pos);
}
constexpr int nan_null_direction_hint = -1;
auto const & column = *columns[0];
if (if_argument_pos >= 0)
{
size_t index = row_begin;
const auto & if_flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
while ((if_flags[index] == 0 || null_map[index] != 0) && (index < row_end))
index++;
if (index >= row_end)
return;
for (size_t i = index + 1; i < row_end; i++)
{
if ((if_flags[i] != 0) && (null_map[i] == 0) && (column.compareAt(i, index, column, nan_null_direction_hint) > 0))
index = i;
}
this->data(place).changeIfGreater(column, index, arena);
}
else
{
size_t index = row_begin;
while ((null_map[index] != 0) && (index < row_end))
index++;
if (index >= row_end)
return;
for (size_t i = index + 1; i < row_end; i++)
{
if ((null_map[i] == 0) && (column.compareAt(i, index, column, nan_null_direction_hint) > 0))
index = i;
}
this->data(place).changeIfGreater(column, index, arena);
}
}
AggregateFunctionPtr createAggregateFunctionMax(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionSingleValue<AggregateFunctionsSingleValueMax, AggregateFunctionMaxData>(name, argument_types, parameters, settings));
}
AggregateFunctionPtr createAggregateFunctionArgMax(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionArgMinMax<AggregateFunctionMaxData>(name, argument_types, parameters, settings));
}
}
void registerAggregateFunctionsMax(AggregateFunctionFactory & factory)
{
factory.registerFunction("max", createAggregateFunctionMax, AggregateFunctionFactory::CaseInsensitive);
/// The functions below depend on the order of data.
AggregateFunctionProperties properties = { .returns_default_when_only_null = false, .is_order_dependent = true };
factory.registerFunction("argMax", { createAggregateFunctionArgMax, properties });
}
}

View File

@ -1,240 +0,0 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/HelpersMinMaxAny.h>
#include <Common/Concepts.h>
#include <Common/findExtreme.h>
namespace DB
{
struct Settings;
namespace
{
template <typename Data>
class AggregateFunctionsSingleValueMin final : public AggregateFunctionsSingleValue<Data>
{
using Parent = AggregateFunctionsSingleValue<Data>;
public:
explicit AggregateFunctionsSingleValueMin(const DataTypePtr & type) : Parent(type) { }
/// Specializations for native numeric types
void addBatchSinglePlace(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
Arena * arena,
ssize_t if_argument_pos) const override;
void addBatchSinglePlaceNotNull(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
const UInt8 * __restrict null_map,
Arena * arena,
ssize_t if_argument_pos) const override;
};
// NOLINTBEGIN(bugprone-macro-parentheses)
#define SPECIALIZE(TYPE) \
template <> \
void AggregateFunctionsSingleValueMin<typename DB::AggregateFunctionMinData<SingleValueDataFixed<TYPE>>>::addBatchSinglePlace( \
size_t row_begin, \
size_t row_end, \
AggregateDataPtr __restrict place, \
const IColumn ** __restrict columns, \
Arena *, \
ssize_t if_argument_pos) const \
{ \
const auto & column = assert_cast<const DB::AggregateFunctionMinData<SingleValueDataFixed<TYPE>>::ColVecType &>(*columns[0]); \
std::optional<TYPE> opt; \
if (if_argument_pos >= 0) \
{ \
const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData(); \
opt = findExtremeMinIf(column.getData().data(), flags.data(), row_begin, row_end); \
} \
else \
opt = findExtremeMin(column.getData().data(), row_begin, row_end); \
if (opt.has_value()) \
this->data(place).changeIfLess(opt.value()); \
}
// NOLINTEND(bugprone-macro-parentheses)
FOR_BASIC_NUMERIC_TYPES(SPECIALIZE)
#undef SPECIALIZE
template <typename Data>
void AggregateFunctionsSingleValueMin<Data>::addBatchSinglePlace(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
Arena * arena,
ssize_t if_argument_pos) const
{
if constexpr (!is_any_of<typename Data::Impl, SingleValueDataString, SingleValueDataGeneric>)
{
/// Leave other numeric types (large integers, decimals, etc.) to keep doing the comparison, as it's
/// faster than doing a permutation
return Parent::addBatchSinglePlace(row_begin, row_end, place, columns, arena, if_argument_pos);
}
constexpr int nan_null_direction_hint = 1;
auto const & column = *columns[0];
if (if_argument_pos >= 0)
{
size_t index = row_begin;
const auto & if_flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
while (index < row_end && if_flags[index] == 0)
index++;
if (index >= row_end)
return;
for (size_t i = index + 1; i < row_end; i++)
{
if ((if_flags[i] != 0) && (column.compareAt(i, index, column, nan_null_direction_hint) < 0))
index = i;
}
this->data(place).changeIfLess(column, index, arena);
}
else
{
if (row_begin >= row_end)
return;
/// TODO: Introduce row_begin and row_end to getPermutation
if (row_begin != 0 || row_end != column.size())
{
size_t index = row_begin;
for (size_t i = index + 1; i < row_end; i++)
{
if (column.compareAt(i, index, column, nan_null_direction_hint) < 0)
index = i;
}
this->data(place).changeIfLess(column, index, arena);
}
else
{
constexpr IColumn::PermutationSortDirection direction = IColumn::PermutationSortDirection::Ascending;
constexpr IColumn::PermutationSortStability stability = IColumn::PermutationSortStability::Unstable;
IColumn::Permutation permutation;
constexpr UInt64 limit = 1;
column.getPermutation(direction, stability, limit, nan_null_direction_hint, permutation);
this->data(place).changeIfLess(column, permutation[0], arena);
}
}
}
// NOLINTBEGIN(bugprone-macro-parentheses)
#define SPECIALIZE(TYPE) \
template <> \
void AggregateFunctionsSingleValueMin<typename DB::AggregateFunctionMinData<SingleValueDataFixed<TYPE>>>::addBatchSinglePlaceNotNull( \
size_t row_begin, \
size_t row_end, \
AggregateDataPtr __restrict place, \
const IColumn ** __restrict columns, \
const UInt8 * __restrict null_map, \
Arena *, \
ssize_t if_argument_pos) const \
{ \
const auto & column = assert_cast<const DB::AggregateFunctionMinData<SingleValueDataFixed<TYPE>>::ColVecType &>(*columns[0]); \
std::optional<TYPE> opt; \
if (if_argument_pos >= 0) \
{ \
const auto * if_flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData().data(); \
auto final_flags = std::make_unique<UInt8[]>(row_end); \
for (size_t i = row_begin; i < row_end; ++i) \
final_flags[i] = (!null_map[i]) & !!if_flags[i]; \
opt = findExtremeMinIf(column.getData().data(), final_flags.get(), row_begin, row_end); \
} \
else \
opt = findExtremeMinNotNull(column.getData().data(), null_map, row_begin, row_end); \
if (opt.has_value()) \
this->data(place).changeIfLess(opt.value()); \
}
// NOLINTEND(bugprone-macro-parentheses)
FOR_BASIC_NUMERIC_TYPES(SPECIALIZE)
#undef SPECIALIZE
template <typename Data>
void AggregateFunctionsSingleValueMin<Data>::addBatchSinglePlaceNotNull(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
const UInt8 * __restrict null_map,
Arena * arena,
ssize_t if_argument_pos) const
{
if constexpr (!is_any_of<typename Data::Impl, SingleValueDataString, SingleValueDataGeneric>)
{
/// Leave other numeric types (large integers, decimals, etc.) to keep doing the comparison, as it's
/// faster than doing a permutation
return Parent::addBatchSinglePlaceNotNull(row_begin, row_end, place, columns, null_map, arena, if_argument_pos);
}
constexpr int nan_null_direction_hint = 1;
auto const & column = *columns[0];
if (if_argument_pos >= 0)
{
size_t index = row_begin;
const auto & if_flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
while ((index < row_end) && (if_flags[index] == 0 || null_map[index] != 0))
index++;
if (index >= row_end)
return;
for (size_t i = index + 1; i < row_end; i++)
{
if ((if_flags[i] != 0) && (null_map[i] == 0) && (column.compareAt(i, index, column, nan_null_direction_hint) < 0))
index = i;
}
this->data(place).changeIfLess(column, index, arena);
}
else
{
size_t index = row_begin;
while ((index < row_end) && (null_map[index] != 0))
index++;
if (index >= row_end)
return;
for (size_t i = index + 1; i < row_end; i++)
{
if ((null_map[i] == 0) && (column.compareAt(i, index, column, nan_null_direction_hint) < 0))
index = i;
}
this->data(place).changeIfLess(column, index, arena);
}
}
AggregateFunctionPtr createAggregateFunctionMin(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionSingleValue<AggregateFunctionsSingleValueMin, AggregateFunctionMinData>(
name, argument_types, parameters, settings));
}
AggregateFunctionPtr createAggregateFunctionArgMin(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionArgMinMax<AggregateFunctionMinData>(name, argument_types, parameters, settings));
}
}
void registerAggregateFunctionsMin(AggregateFunctionFactory & factory)
{
factory.registerFunction("min", createAggregateFunctionMin, AggregateFunctionFactory::CaseInsensitive);
/// The functions below depend on the order of data.
AggregateFunctionProperties properties = { .returns_default_when_only_null = false, .is_order_dependent = true };
factory.registerFunction("argMin", { createAggregateFunctionArgMin, properties });
}
}
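For native numeric columns the SPECIALIZE macros above delegate to the findExtremeMin* helpers from Common/findExtreme.h, which scan a contiguous buffer under an optional condition mask; when both a null map and an -If mask are present they are first folded into a single flags array (final_flags). A rough standalone model of that masked scan, with maskedMin and the plain arrays as illustrative placeholders rather than the real helpers:
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>
/// Sketch of a findExtremeMinIf-style scan: visit only rows whose flag is set
/// and return the smallest value, or nothing if no row qualifies.
template <typename T>
std::optional<T> maskedMin(const T * data, const uint8_t * flags, size_t row_begin, size_t row_end)
{
    std::optional<T> result;
    for (size_t i = row_begin; i < row_end; ++i)
        if (flags[i] && (!result || data[i] < *result))
            result = data[i];
    return result;
}
int main()
{
    std::vector<int64_t> data = {42, -3, 17, -8, 5};
    std::vector<uint8_t> null_map = {0, 0, 1, 0, 0}; /// 1 means NULL
    std::vector<uint8_t> if_flags = {1, 1, 1, 0, 1}; /// rows selected by the -If condition
    /// Fold the two masks into one flags array, as the NotNull specialization does.
    std::vector<uint8_t> final_flags(data.size());
    for (size_t i = 0; i < data.size(); ++i)
        final_flags[i] = static_cast<uint8_t>(!null_map[i] && if_flags[i]);
    if (auto m = maskedMin(data.data(), final_flags.data(), 0, data.size()))
        std::cout << "min = " << *m << '\n'; /// -3 (the -8 at row 3 is excluded by if_flags)
}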

File diff suppressed because it is too large

View File

@ -1,19 +1,193 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/HelpersMinMaxAny.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include "registerAggregateFunctions.h"
#include <AggregateFunctions/SingleValueData.h>
#include <Columns/ColumnNullable.h>
#include <DataTypes/DataTypeNullable.h>
namespace DB
{
struct Settings;
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
namespace
{
/** The aggregate function 'singleValueOrNull' is used to implement subquery operators,
* such as x = ALL (SELECT ...)
* It checks if there is only one unique non-NULL value in the data.
* If there is only one unique value - returns it.
* If there are zero or at least two distinct values - returns NULL.
*/
AggregateFunctionPtr createAggregateFunctionSingleValueOrNull(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
struct AggregateFunctionSingleValueOrNullData
{
return AggregateFunctionPtr(createAggregateFunctionSingleValue<AggregateFunctionsSingleValue, AggregateFunctionSingleValueOrNullData>(name, argument_types, parameters, settings));
using Self = AggregateFunctionSingleValueOrNullData;
private:
SingleValueDataBaseMemoryBlock v_data;
bool first_value = true;
bool is_null = false;
public:
[[noreturn]] explicit AggregateFunctionSingleValueOrNullData()
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "AggregateFunctionSingleValueOrNullData initialized empty");
}
explicit AggregateFunctionSingleValueOrNullData(TypeIndex value_type) { generateSingleValueFromTypeIndex(value_type, v_data); }
~AggregateFunctionSingleValueOrNullData() { data().~SingleValueDataBase(); }
SingleValueDataBase & data() { return v_data.get(); }
const SingleValueDataBase & data() const { return v_data.get(); }
bool isNull() const { return is_null; }
void add(const IColumn & column, size_t row_num, Arena * arena)
{
if (first_value)
{
first_value = false;
data().set(column, row_num, arena);
}
else if (!data().isEqualTo(column, row_num))
{
is_null = true;
}
}
void add(const Self & to, Arena * arena)
{
if (!to.data().has())
return;
if (first_value && !to.first_value)
{
first_value = false;
data().set(to.data(), arena);
}
else if (!data().isEqualTo(to.data()))
{
is_null = true;
}
}
/// TODO: Methods write and read lose data (first_value and is_null)
/// Fixing it requires a breaking change (but it's probably necessary)
void write(WriteBuffer & buf, const ISerialization & serialization) const { data().write(buf, serialization); }
void read(ReadBuffer & buf, const ISerialization & serialization, Arena * arena) { data().read(buf, serialization, arena); }
void insertResultInto(IColumn & to) const
{
if (is_null || first_value)
{
to.insertDefault();
}
else
{
ColumnNullable & col = typeid_cast<ColumnNullable &>(to);
col.getNullMapColumn().insertDefault();
data().insertResultInto(col.getNestedColumn());
}
}
};
class AggregateFunctionSingleValueOrNull final
: public IAggregateFunctionDataHelper<AggregateFunctionSingleValueOrNullData, AggregateFunctionSingleValueOrNull>
{
private:
SerializationPtr serialization;
const TypeIndex value_type_index;
public:
explicit AggregateFunctionSingleValueOrNull(const DataTypePtr & type)
: IAggregateFunctionDataHelper<AggregateFunctionSingleValueOrNullData, AggregateFunctionSingleValueOrNull>(
{type}, {}, makeNullable(type))
, serialization(type->getDefaultSerialization())
, value_type_index(WhichDataType(type).idx)
{
}
void create(AggregateDataPtr __restrict place) const override { new (place) AggregateFunctionSingleValueOrNullData(value_type_index); }
String getName() const override { return "singleValueOrNull"; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
this->data(place).add(*columns[0], row_num, arena);
}
void addBatchSinglePlace(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
Arena * arena,
ssize_t if_argument_pos) const override
{
if (this->data(place).isNull())
return;
IAggregateFunctionDataHelper<Data, AggregateFunctionSingleValueOrNull>::addBatchSinglePlace(
row_begin, row_end, place, columns, arena, if_argument_pos);
}
void addBatchSinglePlaceNotNull(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
const UInt8 * __restrict null_map,
Arena * arena,
ssize_t if_argument_pos) const override
{
if (this->data(place).isNull())
return;
IAggregateFunctionDataHelper<Data, AggregateFunctionSingleValueOrNull>::addBatchSinglePlaceNotNull(
row_begin, row_end, place, columns, null_map, arena, if_argument_pos);
}
void addManyDefaults(AggregateDataPtr __restrict place, const IColumn ** columns, size_t, Arena * arena) const override
{
this->data(place).add(*columns[0], 0, arena);
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
this->data(place).add(this->data(rhs), arena);
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
this->data(place).write(buf, *serialization);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
this->data(place).read(buf, *serialization, arena);
}
bool allocatesMemoryInArena() const override { return singleValueTypeAllocatesMemoryInArena(value_type_index); }
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
this->data(place).insertResultInto(to);
}
};
AggregateFunctionPtr createAggregateFunctionSingleValueOrNull(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
const DataTypePtr & res_type = argument_types[0];
return AggregateFunctionPtr(new AggregateFunctionSingleValueOrNull(res_type));
}
}
@ -22,6 +196,4 @@ void registerAggregateFunctionSingleValueOrNull(AggregateFunctionFactory & facto
{
factory.registerFunction("singleValueOrNull", createAggregateFunctionSingleValueOrNull);
}
}
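The data struct above behaves as a three-state machine: empty, holding exactly one distinct value, or marked is_null once a second distinct value (or a conflicting merge) is seen, after which insertResultInto emits NULL. A small standalone model of those transitions; SingleValueOrNull below is a simplified stand-in, not the ClickHouse class:
#include <iostream>
#include <optional>
#include <string>
#include <vector>
/// Sketch of the singleValueOrNull state machine: returns the value if all added
/// rows are equal, and nothing (NULL) if zero or at least two distinct values were seen.
class SingleValueOrNull
{
    std::optional<std::string> value;
    bool is_null = false;
public:
    void add(const std::string & v)
    {
        if (is_null)
            return;
        if (!value)
            value = v;      /// first value: remember it
        else if (*value != v)
            is_null = true; /// second distinct value: the result becomes NULL
    }
    std::optional<std::string> result() const
    {
        if (is_null || !value)
            return std::nullopt;
        return value;
    }
};
int main()
{
    SingleValueOrNull agg;
    for (const auto & v : std::vector<std::string>{"a", "a", "a"})
        agg.add(v);
    std::cout << (agg.result() ? *agg.result() : "NULL") << '\n'; /// "a"
    agg.add("b"); /// a second distinct value
    std::cout << (agg.result() ? *agg.result() : "NULL") << '\n'; /// "NULL"
}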

View File

@ -0,0 +1,236 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/SingleValueData.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/IDataType.h>
namespace DB
{
struct Settings;
namespace ErrorCodes
{
extern const int CORRUPTED_DATA;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int LOGICAL_ERROR;
}
namespace
{
template <class ValueType>
struct AggregateFunctionArgMinMaxData
{
private:
SingleValueDataBaseMemoryBlock result_data;
ValueType value_data;
public:
SingleValueDataBase & result() { return result_data.get(); }
const SingleValueDataBase & result() const { return result_data.get(); }
ValueType & value() { return value_data; }
const ValueType & value() const { return value_data; }
[[noreturn]] explicit AggregateFunctionArgMinMaxData()
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "AggregateFunctionArgMinMaxData initialized empty");
}
explicit AggregateFunctionArgMinMaxData(TypeIndex result_type) : value_data()
{
generateSingleValueFromTypeIndex(result_type, result_data);
}
~AggregateFunctionArgMinMaxData() { result().~SingleValueDataBase(); }
};
static_assert(
sizeof(AggregateFunctionArgMinMaxData<Int8>) <= 2 * SingleValueDataBase::MAX_STORAGE_SIZE,
"Incorrect size of AggregateFunctionArgMinMaxData struct");
/// Returns the first arg value found for the minimum/maximum value. Example: argMin(arg, value).
template <typename ValueData, bool isMin>
class AggregateFunctionArgMinMax final
: public IAggregateFunctionDataHelper<AggregateFunctionArgMinMaxData<ValueData>, AggregateFunctionArgMinMax<ValueData, isMin>>
{
private:
const DataTypePtr & type_val;
const SerializationPtr serialization_res;
const SerializationPtr serialization_val;
const TypeIndex result_type_index;
using Base = IAggregateFunctionDataHelper<AggregateFunctionArgMinMaxData<ValueData>, AggregateFunctionArgMinMax<ValueData, isMin>>;
public:
explicit AggregateFunctionArgMinMax(const DataTypes & argument_types_)
: Base(argument_types_, {}, argument_types_[0])
, type_val(this->argument_types[1])
, serialization_res(this->argument_types[0]->getDefaultSerialization())
, serialization_val(this->argument_types[1]->getDefaultSerialization())
, result_type_index(WhichDataType(this->argument_types[0]).idx)
{
if (!type_val->isComparable())
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of second argument of aggregate function {} because the values of that data type are not comparable",
type_val->getName(),
getName());
}
void create(AggregateDataPtr __restrict place) const override /// NOLINT
{
new (place) AggregateFunctionArgMinMaxData<ValueData>(result_type_index);
}
String getName() const override
{
if constexpr (isMin)
return "argMin";
else
return "argMax";
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
if constexpr (isMin)
{
if (this->data(place).value().setIfSmaller(*columns[1], row_num, arena))
this->data(place).result().set(*columns[0], row_num, arena);
}
else
{
if (this->data(place).value().setIfGreater(*columns[1], row_num, arena))
this->data(place).result().set(*columns[0], row_num, arena);
}
}
void addManyDefaults(AggregateDataPtr __restrict place, const IColumn ** columns, size_t, Arena * arena) const override
{
add(place, columns, 0, arena);
}
void addBatchSinglePlace(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
Arena * arena,
ssize_t if_argument_pos) const override
{
std::optional<size_t> idx;
if (if_argument_pos >= 0)
{
const auto & if_map = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
if constexpr (isMin)
idx = this->data(place).value().getSmallestIndexNotNullIf(*columns[1], nullptr, if_map.data(), row_begin, row_end);
else
idx = this->data(place).value().getGreatestIndexNotNullIf(*columns[1], nullptr, if_map.data(), row_begin, row_end);
}
else
{
if constexpr (isMin)
idx = this->data(place).value().getSmallestIndex(*columns[1], row_begin, row_end);
else
idx = this->data(place).value().getGreatestIndex(*columns[1], row_begin, row_end);
}
if (idx)
add(place, columns, *idx, arena);
}
void addBatchSinglePlaceNotNull(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
const UInt8 * __restrict null_map,
Arena * arena,
ssize_t if_argument_pos) const override
{
std::optional<size_t> idx;
if (if_argument_pos >= 0)
{
const auto & if_map = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
if constexpr (isMin)
idx = this->data(place).value().getSmallestIndexNotNullIf(*columns[1], null_map, if_map.data(), row_begin, row_end);
else
idx = this->data(place).value().getGreatestIndexNotNullIf(*columns[1], null_map, if_map.data(), row_begin, row_end);
}
else
{
if constexpr (isMin)
idx = this->data(place).value().getSmallestIndexNotNullIf(*columns[1], null_map, nullptr, row_begin, row_end);
else
idx = this->data(place).value().getGreatestIndexNotNullIf(*columns[1], null_map, nullptr, row_begin, row_end);
}
if (idx)
add(place, columns, *idx, arena);
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
if constexpr (isMin)
{
if (this->data(place).value().setIfSmaller(this->data(rhs).value(), arena))
this->data(place).result().set(this->data(rhs).result(), arena);
}
else
{
if (this->data(place).value().setIfGreater(this->data(rhs).value(), arena))
this->data(place).result().set(this->data(rhs).result(), arena);
}
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
this->data(place).result().write(buf, *serialization_res);
this->data(place).value().write(buf, *serialization_val);
}
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
this->data(place).result().read(buf, *serialization_res, arena);
this->data(place).value().read(buf, *serialization_val, arena);
if (unlikely(this->data(place).value().has() != this->data(place).result().has()))
throw Exception(
ErrorCodes::CORRUPTED_DATA,
"Invalid state of the aggregate function {}: has_value ({}) != has_result ({})",
getName(),
this->data(place).value().has(),
this->data(place).result().has());
}
bool allocatesMemoryInArena() const override
{
return singleValueTypeAllocatesMemoryInArena(result_type_index) || ValueData::allocatesMemoryInArena();
}
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
this->data(place).result().insertResultInto(to);
}
};
template <bool isMin>
AggregateFunctionPtr createAggregateFunctionArgMinMax(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionSingleValue<AggregateFunctionArgMinMax, /* unary */ false, isMin>(
name, argument_types, parameters, settings));
}
}
void registerAggregateFunctionsArgMinArgMax(AggregateFunctionFactory & factory)
{
AggregateFunctionProperties properties = {.returns_default_when_only_null = false, .is_order_dependent = true};
factory.registerFunction("argMin", {createAggregateFunctionArgMinMax<true>, properties});
factory.registerFunction("argMax", {createAggregateFunctionArgMinMax<false>, properties});
}
}
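argMin/argMax above keep two pieces of state, the best value seen so far and the arg that came with it, and overwrite the arg only when the value strictly improves, so ties keep the first row encountered. A compact standalone sketch of that update rule; ArgMin and its members are illustrative, not the ClickHouse types:
#include <iostream>
#include <optional>
#include <string>
#include <utility>
#include <vector>
/// Sketch of argMin(arg, value): remember the arg of the smallest value.
/// The arg is replaced only when the value strictly decreases, so ties keep the first arg.
struct ArgMin
{
    std::optional<double> best_value;
    std::string best_arg;
    void add(const std::string & arg, double value)
    {
        if (!best_value || value < *best_value) /// mirrors value().setIfSmaller(...) returning true
        {
            best_value = value;
            best_arg = arg;
        }
    }
};
int main()
{
    ArgMin agg;
    std::vector<std::pair<std::string, double>> rows = {{"a", 3.0}, {"b", 1.0}, {"c", 1.0}, {"d", 2.0}};
    for (const auto & [arg, value] : rows)
        agg.add(arg, value);
    std::cout << agg.best_arg << '\n'; /// "b": the minimum 1.0 is first reached at row "b"
}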

View File

@ -0,0 +1,202 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/SingleValueData.h>
#include <Common/Concepts.h>
#include <Common/findExtreme.h>
namespace DB
{
struct Settings;
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NOT_IMPLEMENTED;
}
namespace
{
template <typename Data, bool isMin>
class AggregateFunctionMinMax final : public IAggregateFunctionDataHelper<Data, AggregateFunctionMinMax<Data, isMin>>
{
private:
SerializationPtr serialization;
public:
explicit AggregateFunctionMinMax(const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionMinMax<Data, isMin>>(argument_types_, {}, argument_types_[0])
, serialization(this->result_type->getDefaultSerialization())
{
if (!this->result_type->isComparable())
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of argument of aggregate function {} because the values of that data type are not comparable",
this->result_type->getName(),
getName());
}
String getName() const override
{
if constexpr (isMin)
return "min";
else
return "max";
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
if constexpr (isMin)
this->data(place).setIfSmaller(*columns[0], row_num, arena);
else
this->data(place).setIfGreater(*columns[0], row_num, arena);
}
void addManyDefaults(AggregateDataPtr __restrict place, const IColumn ** columns, size_t, Arena * arena) const override
{
add(place, columns, 0, arena);
}
void addBatchSinglePlace(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
Arena * arena,
ssize_t if_argument_pos) const override
{
if (if_argument_pos >= 0)
{
const auto & if_map = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
if constexpr (isMin)
this->data(place).setSmallestNotNullIf(*columns[0], nullptr, if_map.data(), row_begin, row_end, arena);
else
this->data(place).setGreatestNotNullIf(*columns[0], nullptr, if_map.data(), row_begin, row_end, arena);
}
else
{
if constexpr (isMin)
this->data(place).setSmallest(*columns[0], row_begin, row_end, arena);
else
this->data(place).setGreatest(*columns[0], row_begin, row_end, arena);
}
}
void addBatchSinglePlaceNotNull(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
const UInt8 * __restrict null_map,
Arena * arena,
ssize_t if_argument_pos) const override
{
if (if_argument_pos >= 0)
{
const auto & if_map = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
if constexpr (isMin)
this->data(place).setSmallestNotNullIf(*columns[0], null_map, if_map.data(), row_begin, row_end, arena);
else
this->data(place).setGreatestNotNullIf(*columns[0], null_map, if_map.data(), row_begin, row_end, arena);
}
else
{
if constexpr (isMin)
this->data(place).setSmallestNotNullIf(*columns[0], null_map, nullptr, row_begin, row_end, arena);
else
this->data(place).setGreatestNotNullIf(*columns[0], null_map, nullptr, row_begin, row_end, arena);
}
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
if constexpr (isMin)
this->data(place).setIfSmaller(this->data(rhs), arena);
else
this->data(place).setIfGreater(this->data(rhs), arena);
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
this->data(place).write(buf, *serialization);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
this->data(place).read(buf, *serialization, arena);
}
bool allocatesMemoryInArena() const override { return Data::allocatesMemoryInArena(); }
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
this->data(place).insertResultInto(to);
}
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
if constexpr (!Data::is_compilable)
return false;
else
return Data::isCompilable(*this->argument_types[0]);
}
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
if constexpr (Data::is_compilable)
Data::compileCreate(builder, aggregate_data_ptr);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override
{
if constexpr (Data::is_compilable)
if constexpr (isMin)
Data::compileMin(builder, aggregate_data_ptr, arguments[0].value);
else
Data::compileMax(builder, aggregate_data_ptr, arguments[0].value);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
void
compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
{
if constexpr (Data::is_compilable)
if constexpr (isMin)
Data::compileMinMerge(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
else
Data::compileMaxMerge(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
if constexpr (Data::is_compilable)
return Data::compileGetResult(builder, aggregate_data_ptr);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
#endif
};
template <bool isMin>
AggregateFunctionPtr createAggregateFunctionMinMax(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(
createAggregateFunctionSingleValue<AggregateFunctionMinMax, /* unary */ true, isMin>(name, argument_types, parameters, settings));
}
}
void registerAggregateFunctionsMinMax(AggregateFunctionFactory & factory)
{
factory.registerFunction("min", createAggregateFunctionMinMax<true>, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("max", createAggregateFunctionMinMax<false>, AggregateFunctionFactory::CaseInsensitive);
}
}

View File

@ -1,93 +0,0 @@
#include "AggregateFunctionArgMinMax.h"
#include "AggregateFunctionCombinatorFactory.h"
#include <AggregateFunctions/AggregateFunctionMinMaxAny.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeString.h>
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
namespace
{
template <template <typename> class Data>
class AggregateFunctionCombinatorArgMinMax final : public IAggregateFunctionCombinator
{
public:
String getName() const override { return Data<SingleValueDataGeneric>::name(); }
DataTypes transformArguments(const DataTypes & arguments) const override
{
if (arguments.empty())
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Incorrect number of arguments for aggregate function with {} suffix",
getName());
return DataTypes(arguments.begin(), arguments.end() - 1);
}
AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array & params) const override
{
const DataTypePtr & argument_type = arguments.back();
WhichDataType which(argument_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return std::make_shared<AggregateFunctionArgMinMax<Data<SingleValueDataFixed<TYPE>>>>(nested_function, arguments, params); /// NOLINT
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Date)
return std::make_shared<AggregateFunctionArgMinMax<Data<SingleValueDataFixed<DataTypeDate::FieldType>>>>(
nested_function, arguments, params);
if (which.idx == TypeIndex::DateTime)
return std::make_shared<AggregateFunctionArgMinMax<Data<SingleValueDataFixed<DataTypeDateTime::FieldType>>>>(
nested_function, arguments, params);
if (which.idx == TypeIndex::DateTime64)
return std::make_shared<AggregateFunctionArgMinMax<Data<SingleValueDataFixed<DateTime64>>>>(nested_function, arguments, params);
if (which.idx == TypeIndex::Decimal32)
return std::make_shared<AggregateFunctionArgMinMax<Data<SingleValueDataFixed<Decimal32>>>>(nested_function, arguments, params);
if (which.idx == TypeIndex::Decimal64)
return std::make_shared<AggregateFunctionArgMinMax<Data<SingleValueDataFixed<Decimal64>>>>(nested_function, arguments, params);
if (which.idx == TypeIndex::Decimal128)
return std::make_shared<AggregateFunctionArgMinMax<Data<SingleValueDataFixed<Decimal128>>>>(nested_function, arguments, params);
if (which.idx == TypeIndex::Decimal256)
return std::make_shared<AggregateFunctionArgMinMax<Data<SingleValueDataFixed<Decimal256>>>>(nested_function, arguments, params);
if (which.idx == TypeIndex::String)
return std::make_shared<AggregateFunctionArgMinMax<Data<SingleValueDataString>>>(nested_function, arguments, params);
return std::make_shared<AggregateFunctionArgMinMax<Data<SingleValueDataGeneric>>>(nested_function, arguments, params);
}
};
template <typename Data>
struct AggregateFunctionArgMinDataCapitalized : AggregateFunctionMinData<Data>
{
static const char * name() { return "ArgMin"; }
};
template <typename Data>
struct AggregateFunctionArgMaxDataCapitalized : AggregateFunctionMaxData<Data>
{
static const char * name() { return "ArgMax"; }
};
}
void registerAggregateFunctionCombinatorMinMax(AggregateFunctionCombinatorFactory & factory)
{
factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorArgMinMax<AggregateFunctionArgMinDataCapitalized>>());
factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorArgMinMax<AggregateFunctionArgMaxDataCapitalized>>());
}
}

View File

@ -1,111 +0,0 @@
#pragma once
#include <AggregateFunctions/IAggregateFunction.h>
namespace DB
{
template <typename Key>
class AggregateFunctionArgMinMax final : public IAggregateFunctionHelper<AggregateFunctionArgMinMax<Key>>
{
private:
AggregateFunctionPtr nested_function;
SerializationPtr serialization;
size_t key_col;
size_t key_offset;
Key & key(AggregateDataPtr __restrict place) const { return *reinterpret_cast<Key *>(place + key_offset); }
const Key & key(ConstAggregateDataPtr __restrict place) const { return *reinterpret_cast<const Key *>(place + key_offset); }
public:
AggregateFunctionArgMinMax(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: IAggregateFunctionHelper<AggregateFunctionArgMinMax<Key>>{arguments, params, nested_function_->getResultType()}
, nested_function{nested_function_}
, serialization(arguments.back()->getDefaultSerialization())
, key_col{arguments.size() - 1}
, key_offset{(nested_function->sizeOfData() + alignof(Key) - 1) / alignof(Key) * alignof(Key)}
{
}
String getName() const override { return nested_function->getName() + Key::name(); }
bool isState() const override { return nested_function->isState(); }
bool isVersioned() const override { return nested_function->isVersioned(); }
size_t getVersionFromRevision(size_t revision) const override { return nested_function->getVersionFromRevision(revision); }
size_t getDefaultVersion() const override { return nested_function->getDefaultVersion(); }
bool allocatesMemoryInArena() const override { return nested_function->allocatesMemoryInArena() || Key::allocatesMemoryInArena(); }
bool hasTrivialDestructor() const override { return nested_function->hasTrivialDestructor(); }
size_t sizeOfData() const override { return key_offset + sizeof(Key); }
size_t alignOfData() const override { return nested_function->alignOfData(); }
void create(AggregateDataPtr __restrict place) const override
{
nested_function->create(place);
new (place + key_offset) Key;
}
void destroy(AggregateDataPtr __restrict place) const noexcept override { nested_function->destroy(place); }
void destroyUpToState(AggregateDataPtr __restrict place) const noexcept override { nested_function->destroyUpToState(place); }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
if (key(place).changeIfBetter(*columns[key_col], row_num, arena))
{
nested_function->destroy(place);
nested_function->create(place);
nested_function->add(place, columns, row_num, arena);
}
else if (key(place).isEqualTo(*columns[key_col], row_num))
{
nested_function->add(place, columns, row_num, arena);
}
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
if (key(place).changeIfBetter(key(rhs), arena))
{
nested_function->destroy(place);
nested_function->create(place);
nested_function->merge(place, rhs, arena);
}
else if (key(place).isEqualTo(key(rhs)))
{
nested_function->merge(place, rhs, arena);
}
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version) const override
{
nested_function->serialize(place, buf, version);
key(place).write(buf, *serialization);
}
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> version, Arena * arena) const override
{
nested_function->deserialize(place, buf, version, arena);
key(place).read(buf, *serialization, arena);
}
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
{
nested_function->insertResultInto(place, to, arena);
}
void insertMergeResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
{
nested_function->insertMergeResultInto(place, to, arena);
}
AggregateFunctionPtr getNestedFunction() const override { return nested_function; }
};
}

View File

@ -0,0 +1,212 @@
#include <AggregateFunctions/Combinators/AggregateFunctionCombinatorFactory.h>
#include <AggregateFunctions/SingleValueData.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
namespace
{
struct AggregateFunctionCombinatorArgMinArgMaxData
{
private:
SingleValueDataBaseMemoryBlock v_data;
public:
explicit AggregateFunctionCombinatorArgMinArgMaxData(TypeIndex value_type) { generateSingleValueFromTypeIndex(value_type, v_data); }
~AggregateFunctionCombinatorArgMinArgMaxData() { data().~SingleValueDataBase(); }
SingleValueDataBase & data() { return v_data.get(); }
const SingleValueDataBase & data() const { return v_data.get(); }
};
template <bool isMin>
class AggregateFunctionCombinatorArgMinArgMax final : public IAggregateFunctionHelper<AggregateFunctionCombinatorArgMinArgMax<isMin>>
{
using Key = AggregateFunctionCombinatorArgMinArgMaxData;
private:
AggregateFunctionPtr nested_function;
SerializationPtr serialization;
const size_t key_col;
const size_t key_offset;
const TypeIndex key_type_index;
AggregateFunctionCombinatorArgMinArgMaxData & data(AggregateDataPtr __restrict place) const /// NOLINT
{
return *reinterpret_cast<Key *>(place + key_offset);
}
const AggregateFunctionCombinatorArgMinArgMaxData & data(ConstAggregateDataPtr __restrict place) const
{
return *reinterpret_cast<const Key *>(place + key_offset);
}
public:
AggregateFunctionCombinatorArgMinArgMax(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: IAggregateFunctionHelper<AggregateFunctionCombinatorArgMinArgMax<isMin>>{arguments, params, nested_function_->getResultType()}
, nested_function{nested_function_}
, serialization(arguments.back()->getDefaultSerialization())
, key_col{arguments.size() - 1}
, key_offset{((nested_function->sizeOfData() + alignof(Key) - 1) / alignof(Key)) * alignof(Key)}
, key_type_index(WhichDataType(arguments[key_col]).idx)
{
if (!arguments[key_col]->isComparable())
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} for combinator {} because the values of that data type are not comparable",
arguments[key_col]->getName(),
getName());
}
String getName() const override
{
if constexpr (isMin)
return "ArgMin";
else
return "ArgMax";
}
bool isState() const override { return nested_function->isState(); }
bool isVersioned() const override { return nested_function->isVersioned(); }
size_t getVersionFromRevision(size_t revision) const override { return nested_function->getVersionFromRevision(revision); }
size_t getDefaultVersion() const override { return nested_function->getDefaultVersion(); }
bool allocatesMemoryInArena() const override
{
return nested_function->allocatesMemoryInArena() || singleValueTypeAllocatesMemoryInArena(key_type_index);
}
bool hasTrivialDestructor() const override
{
return nested_function->hasTrivialDestructor() && /*false*/ std::is_trivially_destructible_v<SingleValueDataBase>;
}
size_t sizeOfData() const override { return key_offset + sizeof(Key); }
size_t alignOfData() const override { return std::max(nested_function->alignOfData(), alignof(SingleValueDataBaseMemoryBlock)); }
void create(AggregateDataPtr __restrict place) const override
{
nested_function->create(place);
new (place + key_offset) Key(key_type_index);
}
void destroy(AggregateDataPtr __restrict place) const noexcept override
{
data(place).~Key();
nested_function->destroy(place);
}
void destroyUpToState(AggregateDataPtr __restrict place) const noexcept override
{
data(place).~Key();
nested_function->destroyUpToState(place);
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
if ((isMin && data(place).data().setIfSmaller(*columns[key_col], row_num, arena))
|| (!isMin && data(place).data().setIfGreater(*columns[key_col], row_num, arena)))
{
nested_function->destroy(place);
nested_function->create(place);
nested_function->add(place, columns, row_num, arena);
}
else if (data(place).data().isEqualTo(*columns[key_col], row_num))
{
nested_function->add(place, columns, row_num, arena);
}
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
if ((isMin && data(place).data().setIfSmaller(data(rhs).data(), arena))
|| (!isMin && data(place).data().setIfGreater(data(rhs).data(), arena)))
{
nested_function->destroy(place);
nested_function->create(place);
nested_function->merge(place, rhs, arena);
}
else if (data(place).data().isEqualTo(data(rhs).data()))
{
nested_function->merge(place, rhs, arena);
}
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version) const override
{
nested_function->serialize(place, buf, version);
data(place).data().write(buf, *serialization);
}
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> version, Arena * arena) const override
{
nested_function->deserialize(place, buf, version, arena);
data(place).data().read(buf, *serialization, arena);
}
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
{
nested_function->insertResultInto(place, to, arena);
}
void insertMergeResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
{
nested_function->insertMergeResultInto(place, to, arena);
}
AggregateFunctionPtr getNestedFunction() const override { return nested_function; }
};
template <bool isMin>
class CombinatorArgMinArgMax final : public IAggregateFunctionCombinator
{
public:
String getName() const override
{
if constexpr (isMin)
return "ArgMin";
else
return "ArgMax";
}
DataTypes transformArguments(const DataTypes & arguments) const override
{
if (arguments.empty())
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Incorrect number of arguments for aggregate function with {} suffix",
getName());
return DataTypes(arguments.begin(), arguments.end() - 1);
}
AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array & params) const override
{
return std::make_shared<AggregateFunctionCombinatorArgMinArgMax<isMin>>(nested_function, arguments, params);
}
};
}
void registerAggregateFunctionCombinatorsArgMinArgMax(AggregateFunctionCombinatorFactory & factory)
{
factory.registerCombinator(std::make_shared<CombinatorArgMinArgMax<true>>());
factory.registerCombinator(std::make_shared<CombinatorArgMinArgMax<false>>());
}
}
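The combinator stores its key right behind the nested function's state in the same allocation: key_offset rounds the nested state size up to the key's alignment, and sizeOfData() is that offset plus sizeof(Key). A standalone sketch of the align-up arithmetic and the placement of the key; alignUp, NestedState and Key are illustrative placeholders:
#include <cstddef>
#include <iostream>
#include <new>
#include <vector>
/// Round `size` up to the next multiple of `alignment` (a power of two),
/// the same arithmetic as key_offset above.
constexpr size_t alignUp(size_t size, size_t alignment)
{
    return (size + alignment - 1) / alignment * alignment;
}
struct NestedState { char payload[20]; };       /// stand-in for the nested function's state
struct Key { double value; bool has = false; }; /// stand-in for the combinator's key
int main()
{
    constexpr size_t key_offset = alignUp(sizeof(NestedState), alignof(Key));
    constexpr size_t total_size = key_offset + sizeof(Key);
    std::cout << "nested: " << sizeof(NestedState) << ", key_offset: " << key_offset
              << ", total: " << total_size << '\n'; /// e.g. nested: 20, key_offset: 24
    /// One flat allocation holding both parts, as the aggregate state does.
    std::vector<char> place(total_size);
    new (place.data()) NestedState{};
    Key * key = new (place.data() + key_offset) Key{};
    key->value = 42;
    std::cout << "key stored at offset " << key_offset << ", value " << key->value << '\n';
}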

View File

@ -43,8 +43,8 @@ template <bool result_is_nullable, bool serialize_flag, typename Derived>
class AggregateFunctionNullBase : public IAggregateFunctionHelper<Derived>
{
protected:
AggregateFunctionPtr nested_function;
size_t prefix_size;
const AggregateFunctionPtr nested_function;
const size_t prefix_size;
/** In addition to data for nested aggregate function, we keep a flag
* indicating - was there at least one non-NULL value accumulated.
@ -55,12 +55,18 @@ protected:
AggregateDataPtr nestedPlace(AggregateDataPtr __restrict place) const noexcept
{
return place + prefix_size;
if constexpr (result_is_nullable)
return place + prefix_size;
else
return place;
}
ConstAggregateDataPtr nestedPlace(ConstAggregateDataPtr __restrict place) const noexcept
{
return place + prefix_size;
if constexpr (result_is_nullable)
return place + prefix_size;
else
return place;
}
static void initFlag(AggregateDataPtr __restrict place) noexcept
@ -87,11 +93,8 @@ public:
AggregateFunctionNullBase(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: IAggregateFunctionHelper<Derived>(arguments, params, createResultType(nested_function_))
, nested_function{nested_function_}
, prefix_size(result_is_nullable ? nested_function->alignOfData() : 0)
{
if constexpr (result_is_nullable)
prefix_size = nested_function->alignOfData();
else
prefix_size = 0;
}
String getName() const override
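The hunk above makes nested_function and prefix_size const and has nestedPlace() branch explicitly on result_is_nullable: the one-byte "saw a non-NULL value" flag, padded to the nested state's alignment, only takes space when the result is Nullable; otherwise the nested state starts at offset 0. A tiny standalone model of that layout decision; NullWrapperLayout and its members are illustrative:
#include <cstddef>
#include <iostream>
/// Sketch of the -Null wrapper state layout:
/// [ flag byte + padding up to nested alignment ][ nested state ]  when the result is Nullable
/// [ nested state ]                                                otherwise (no prefix at all)
struct NullWrapperLayout
{
    size_t nested_size;
    size_t nested_align;
    bool result_is_nullable;
    size_t prefixSize() const { return result_is_nullable ? nested_align : 0; }
    size_t nestedOffset() const { return prefixSize(); }
    size_t totalSize() const { return prefixSize() + nested_size; }
};
int main()
{
    NullWrapperLayout nullable{.nested_size = 16, .nested_align = 8, .result_is_nullable = true};
    NullWrapperLayout plain{.nested_size = 16, .nested_align = 8, .result_is_nullable = false};
    std::cout << "nullable: nested at " << nullable.nestedOffset() << ", total " << nullable.totalSize() << '\n'; /// 8, 24
    std::cout << "plain:    nested at " << plain.nestedOffset() << ", total " << plain.totalSize() << '\n';       /// 0, 16
}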

View File

@ -1,119 +0,0 @@
#pragma once
#include <AggregateFunctions/AggregateFunctionMinMaxAny.h>
#include <AggregateFunctions/AggregateFunctionArgMinMax.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/Helpers.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeString.h>
namespace DB
{
struct Settings;
/// min, max, any, anyLast, anyHeavy, etc...
template <template <typename> class AggregateFunctionTemplate, template <typename, bool...> class Data>
static IAggregateFunction *
createAggregateFunctionSingleValue(const String & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
const DataTypePtr & argument_type = argument_types[0];
WhichDataType which(argument_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<TYPE>>>(argument_type); /// NOLINT
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Date)
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<DataTypeDate::FieldType>>>(argument_type);
if (which.idx == TypeIndex::DateTime)
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<DataTypeDateTime::FieldType>>>(argument_type);
if (which.idx == TypeIndex::DateTime64)
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<DateTime64>>>(argument_type);
if (which.idx == TypeIndex::Decimal32)
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Decimal32>>>(argument_type);
if (which.idx == TypeIndex::Decimal64)
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Decimal64>>>(argument_type);
if (which.idx == TypeIndex::Decimal128)
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Decimal128>>>(argument_type);
if (which.idx == TypeIndex::Decimal256)
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Decimal256>>>(argument_type);
if (which.idx == TypeIndex::String)
return new AggregateFunctionTemplate<Data<SingleValueDataString>>(argument_type);
return new AggregateFunctionTemplate<Data<SingleValueDataGeneric>>(argument_type);
}
/// argMin, argMax
template <template <typename> class MinMaxData, typename ResData>
static IAggregateFunction * createAggregateFunctionArgMinMaxSecond(const DataTypePtr & res_type, const DataTypePtr & val_type)
{
WhichDataType which(val_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<TYPE>>>>(res_type, val_type); /// NOLINT
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Date)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<DataTypeDate::FieldType>>>>(res_type, val_type);
if (which.idx == TypeIndex::DateTime)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<DataTypeDateTime::FieldType>>>>(res_type, val_type);
if (which.idx == TypeIndex::DateTime64)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<DateTime64>>>>(res_type, val_type);
if (which.idx == TypeIndex::Decimal32)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<Decimal32>>>>(res_type, val_type);
if (which.idx == TypeIndex::Decimal64)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<Decimal64>>>>(res_type, val_type);
if (which.idx == TypeIndex::Decimal128)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<Decimal128>>>>(res_type, val_type);
if (which.idx == TypeIndex::Decimal256)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<Decimal256>>>>(res_type, val_type);
if (which.idx == TypeIndex::String)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataString>>>(res_type, val_type);
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataGeneric>>>(res_type, val_type);
}
template <template <typename> class MinMaxData>
static IAggregateFunction * createAggregateFunctionArgMinMax(const String & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
assertNoParameters(name, parameters);
assertBinary(name, argument_types);
const DataTypePtr & res_type = argument_types[0];
const DataTypePtr & val_type = argument_types[1];
WhichDataType which(res_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<TYPE>>(res_type, val_type); /// NOLINT
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Date)
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<DataTypeDate::FieldType>>(res_type, val_type);
if (which.idx == TypeIndex::DateTime)
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<DataTypeDateTime::FieldType>>(res_type, val_type);
if (which.idx == TypeIndex::DateTime64)
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<DateTime64>>(res_type, val_type);
if (which.idx == TypeIndex::Decimal32)
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<Decimal32>>(res_type, val_type);
if (which.idx == TypeIndex::Decimal64)
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<Decimal64>>(res_type, val_type);
if (which.idx == TypeIndex::Decimal128)
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<Decimal128>>(res_type, val_type);
if (which.idx == TypeIndex::Decimal256)
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<Decimal256>>(res_type, val_type);
if (which.idx == TypeIndex::String)
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataString>(res_type, val_type);
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataGeneric>(res_type, val_type);
}
}

File diff suppressed because it is too large

View File

@ -0,0 +1,394 @@
#pragma once
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnDecimal.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <base/StringRef.h>
namespace DB
{
class Arena;
class ReadBuffer;
struct Settings;
class WriteBuffer;
/// Base class for Aggregation data that stores one of passed values: min, any, argMax...
/// It's setup as a virtual class so we can avoid templates when we need to extend them (argMax, SingleValueOrNull)
struct SingleValueDataBase
{
/// Any subclass (numeric, string, generic) must be smaller than MAX_STORAGE_SIZE
/// We use this knowledge to create composite data classes that use them directly by reserving a 'memory_block'
/// For example argMin holds 1 of these (for the result), while keeping a template for the value
static constexpr UInt32 MAX_STORAGE_SIZE = 64;
virtual ~SingleValueDataBase() { }
virtual bool has() const = 0;
virtual void insertResultInto(IColumn &) const = 0;
virtual void write(WriteBuffer &, const ISerialization &) const = 0;
virtual void read(ReadBuffer &, const ISerialization &, Arena *) = 0;
virtual bool isEqualTo(const IColumn & column, size_t row_num) const = 0;
virtual bool isEqualTo(const SingleValueDataBase &) const = 0;
virtual void set(const IColumn &, size_t row_num, Arena *) = 0;
virtual void set(const SingleValueDataBase &, Arena *) = 0;
virtual bool setIfSmaller(const IColumn &, size_t row_num, Arena *) = 0;
virtual bool setIfSmaller(const SingleValueDataBase &, Arena *) = 0;
virtual bool setIfGreater(const IColumn &, size_t row_num, Arena *) = 0;
virtual bool setIfGreater(const SingleValueDataBase &, Arena *) = 0;
/// Given a column, sets the internal value to the smallest or greatest value from the column
/// Used to implement batch min/max
virtual void setSmallest(const IColumn & column, size_t row_begin, size_t row_end, Arena * arena);
virtual void setGreatest(const IColumn & column, size_t row_begin, size_t row_end, Arena * arena);
virtual void setSmallestNotNullIf(const IColumn &, const UInt8 * __restrict, const UInt8 * __restrict, size_t, size_t, Arena *);
virtual void setGreatestNotNullIf(const IColumn &, const UInt8 * __restrict, const UInt8 * __restrict, size_t, size_t, Arena *);
/// Given a column returns the index of the smallest or greatest value in it
/// Doesn't return anything if the column is empty
/// These are used to implement argMin / argMax
virtual std::optional<size_t> getSmallestIndex(const IColumn & column, size_t row_begin, size_t row_end) const;
virtual std::optional<size_t> getGreatestIndex(const IColumn & column, size_t row_begin, size_t row_end) const;
virtual std::optional<size_t> getSmallestIndexNotNullIf(
const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end) const;
virtual std::optional<size_t> getGreatestIndexNotNullIf(
const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end) const;
};
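MAX_STORAGE_SIZE is what lets composite states such as argMin or singleValueOrNull embed one of the SingleValueData* subclasses without a heap allocation: per the comment above, they reserve an aligned memory block of that size and placement-new the concrete subclass into it, using it afterwards only through the base class. A reduced standalone sketch of the pattern; ValueBase, FixedInt64 and MemoryBlock are illustrative stand-ins, not the ClickHouse types:
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <new>
struct ValueBase
{
    static constexpr size_t MAX_STORAGE_SIZE = 64; /// every subclass must fit in this many bytes
    virtual ~ValueBase() = default;
    virtual void set(int64_t v) = 0;
    virtual int64_t get() const = 0;
};
struct FixedInt64 final : ValueBase
{
    int64_t value = 0;
    void set(int64_t v) override { value = v; }
    int64_t get() const override { return value; }
};
static_assert(sizeof(FixedInt64) <= ValueBase::MAX_STORAGE_SIZE);
/// Aligned raw storage into which one concrete subclass is placement-new'ed;
/// owners call the destructor explicitly, as the composite data structs do.
struct alignas(ValueBase) MemoryBlock
{
    char memory[ValueBase::MAX_STORAGE_SIZE];
    ValueBase & get() { return *reinterpret_cast<ValueBase *>(memory); }
};
int main()
{
    MemoryBlock block;
    new (block.memory) FixedInt64(); /// choose the concrete type at runtime, no heap allocation
    block.get().set(123);
    std::cout << block.get().get() << '\n'; /// 123
    block.get().~ValueBase();
}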
#define FOR_SINGLE_VALUE_NUMERIC_TYPES(M) \
M(UInt8) \
M(UInt16) \
M(UInt32) \
M(UInt64) \
M(UInt128) \
M(UInt256) \
M(Int8) \
M(Int16) \
M(Int32) \
M(Int64) \
M(Int128) \
M(Int256) \
M(Float32) \
M(Float64) \
M(Decimal32) \
M(Decimal64) \
M(Decimal128) \
M(Decimal256) \
M(DateTime64)
/// For numeric values (without inheritance, for performance sensitive functions and JIT)
template <typename T>
struct SingleValueDataFixed
{
static constexpr bool is_compilable = true;
using Self = SingleValueDataFixed;
using ColVecType = ColumnVectorOrDecimal<T>;
T value = T{};
/// We need to remember if at least one value has been passed.
/// This is necessary for AggregateFunctionIf, merging states, JIT (where simple add is used), etc
bool has_value = false;
bool has() const { return has_value; }
void insertResultInto(IColumn & to) const;
void write(WriteBuffer & buf, const ISerialization &) const;
void read(ReadBuffer & buf, const ISerialization &, Arena *);
bool isEqualTo(const IColumn & column, size_t index) const;
bool isEqualTo(const Self & to) const;
void set(const IColumn & column, size_t row_num, Arena *);
void set(const Self & to, Arena *);
bool setIfSmaller(const T & to);
bool setIfGreater(const T & to);
bool setIfSmaller(const Self & to, Arena * arena);
bool setIfGreater(const Self & to, Arena * arena);
bool setIfSmaller(const IColumn & column, size_t row_num, Arena * arena);
bool setIfGreater(const IColumn & column, size_t row_num, Arena * arena);
void setSmallest(const IColumn & column, size_t row_begin, size_t row_end, Arena *);
void setGreatest(const IColumn & column, size_t row_begin, size_t row_end, Arena *);
void setSmallestNotNullIf(
const IColumn & column,
const UInt8 * __restrict null_map,
const UInt8 * __restrict if_map,
size_t row_begin,
size_t row_end,
Arena *);
void setGreatestNotNullIf(
const IColumn & column,
const UInt8 * __restrict null_map,
const UInt8 * __restrict if_map,
size_t row_begin,
size_t row_end,
Arena *);
std::optional<size_t> getSmallestIndex(const IColumn & column, size_t row_begin, size_t row_end) const;
std::optional<size_t> getGreatestIndex(const IColumn & column, size_t row_begin, size_t row_end) const;
std::optional<size_t> getSmallestIndexNotNullIf(
const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end) const;
std::optional<size_t> getGreatestIndexNotNullIf(
const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end) const;
static bool allocatesMemoryInArena() { return false; }
#if USE_EMBEDDED_COMPILER
static constexpr size_t has_value_offset = offsetof(Self, has_value);
static constexpr size_t value_offset = offsetof(Self, value);
static bool isCompilable(const IDataType & type);
static llvm::Value * getValuePtrFromAggregateDataPtr(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr);
static llvm::Value * getValueFromAggregateDataPtr(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr);
static llvm::Value * getHasValuePtrFromAggregateDataPtr(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr);
static llvm::Value * getHasValueFromAggregateDataPtr(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr);
static void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr);
static llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr);
static void compileSetValueFromNumber(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check);
static void
compileSetValueFromAggregation(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * aggregate_data_src_ptr);
static void compileAny(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check);
static void compileAnyMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr);
static void compileAnyLast(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check);
static void
compileAnyLastMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr);
template <bool isMin>
static void compileMinMax(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check);
template <bool isMin>
static void
compileMinMaxMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr);
static void compileMin(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check);
static void compileMinMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr);
static void compileMax(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check);
static void compileMaxMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr);
#endif
};
#define DISPATCH(TYPE) \
extern template struct SingleValueDataFixed<TYPE>; \
static_assert( \
sizeof(SingleValueDataFixed<TYPE>) <= SingleValueDataBase::MAX_STORAGE_SIZE, "Incorrect size of SingleValueDataFixed struct");
FOR_SINGLE_VALUE_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
/// For numeric values inheriting from SingleValueDataBase
template <typename T>
struct SingleValueDataNumeric final : public SingleValueDataBase
{
using Self = SingleValueDataNumeric<T>;
using Base = SingleValueDataFixed<T>;
private:
/// 32 bytes for types of 256 bits, + 8 bytes for the virtual table pointer.
static constexpr size_t base_memory_reserved_size = 40;
struct alignas(alignof(Base)) PrivateMemory
{
char memory[base_memory_reserved_size];
Base & get() { return *reinterpret_cast<Base *>(memory); }
const Base & get() const { return *reinterpret_cast<const Base *>(memory); }
};
static_assert(sizeof(Base) <= base_memory_reserved_size);
PrivateMemory memory;
public:
static constexpr bool is_compilable = false;
SingleValueDataNumeric();
~SingleValueDataNumeric() override;
bool has() const override;
void insertResultInto(IColumn & to) const override;
void write(WriteBuffer & buf, const ISerialization & serialization) const override;
void read(ReadBuffer & buf, const ISerialization & serialization, Arena * arena) override;
bool isEqualTo(const IColumn & column, size_t index) const override;
bool isEqualTo(const SingleValueDataBase & to) const override;
void set(const IColumn & column, size_t row_num, Arena * arena) override;
void set(const SingleValueDataBase & to, Arena * arena) override;
bool setIfSmaller(const SingleValueDataBase & to, Arena * arena) override;
bool setIfGreater(const SingleValueDataBase & to, Arena * arena) override;
bool setIfSmaller(const IColumn & column, size_t row_num, Arena * arena) override;
bool setIfGreater(const IColumn & column, size_t row_num, Arena * arena) override;
void setSmallest(const IColumn & column, size_t row_begin, size_t row_end, Arena * arena) override;
void setGreatest(const IColumn & column, size_t row_begin, size_t row_end, Arena * arena) override;
void setSmallestNotNullIf(
const IColumn & column,
const UInt8 * __restrict null_map,
const UInt8 * __restrict if_map,
size_t row_begin,
size_t row_end,
Arena * arena) override;
void setGreatestNotNullIf(
const IColumn & column,
const UInt8 * __restrict null_map,
const UInt8 * __restrict if_map,
size_t row_begin,
size_t row_end,
Arena * arena) override;
std::optional<size_t> getSmallestIndex(const IColumn & column, size_t row_begin, size_t row_end) const override;
std::optional<size_t> getGreatestIndex(const IColumn & column, size_t row_begin, size_t row_end) const override;
std::optional<size_t> getSmallestIndexNotNullIf(
const IColumn & column,
const UInt8 * __restrict null_map,
const UInt8 * __restrict if_map,
size_t row_begin,
size_t row_end) const override;
std::optional<size_t> getGreatestIndexNotNullIf(
const IColumn & column,
const UInt8 * __restrict null_map,
const UInt8 * __restrict if_map,
size_t row_begin,
size_t row_end) const override;
static bool allocatesMemoryInArena() { return false; }
};
#define DISPATCH(TYPE) \
extern template struct SingleValueDataNumeric<TYPE>; \
static_assert( \
sizeof(SingleValueDataNumeric<TYPE>) <= SingleValueDataBase::MAX_STORAGE_SIZE, "Incorrect size of SingleValueDataNumeric struct");
FOR_SINGLE_VALUE_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
/** For strings. Short strings are stored in the object itself, and long strings are allocated separately.
* NOTE It could also be suitable for arrays of numbers.
*/
struct SingleValueDataString final : public SingleValueDataBase
{
static constexpr bool is_compilable = false;
using Self = SingleValueDataString;
/// A size of 0 indicates that there is no value. An empty string must have a terminating '\0', so the size of an empty string is 1
UInt32 size = 0;
UInt32 capacity = 0; /// power of two or zero
char * large_data; /// Always allocated in an arena
/// TODO: Maybe instead of a virtual class we should use a std::variant of the 3 implementations to avoid reserving space for the vtable
static constexpr UInt32 MAX_SMALL_STRING_SIZE
= SingleValueDataBase::MAX_STORAGE_SIZE - sizeof(size) - sizeof(capacity) - sizeof(large_data) - sizeof(SingleValueDataBase);
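/// Illustrative arithmetic (assuming MAX_STORAGE_SIZE is 64 bytes and SingleValueDataBase only carries a vtable pointer):
/// MAX_SMALL_STRING_SIZE = 64 - 4 (size) - 4 (capacity) - 8 (large_data) - 8 (vptr) = 40, so strings of up to 39 characters
/// plus the terminating '\0' stay inline in small_data, while anything longer is stored through large_data.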
static constexpr UInt32 MAX_STRING_SIZE = std::numeric_limits<Int32>::max();
private:
char small_data[MAX_SMALL_STRING_SIZE]; /// Including the terminating zero.
char * getDataMutable();
const char * getData() const;
StringRef getStringRef() const;
void allocateLargeDataIfNeeded(UInt32 size_to_reserve, Arena * arena);
void changeImpl(StringRef value, Arena * arena);
public:
bool has() const override { return size != 0; }
void insertResultInto(IColumn & to) const override;
void write(WriteBuffer & buf, const ISerialization & /*serialization*/) const override;
void read(ReadBuffer & buf, const ISerialization & /*serialization*/, Arena * arena) override;
bool isEqualTo(const IColumn & column, size_t row_num) const override;
bool isEqualTo(const SingleValueDataBase &) const override;
void set(const IColumn & column, size_t row_num, Arena * arena) override;
void set(const SingleValueDataBase &, Arena * arena) override;
bool setIfSmaller(const IColumn & column, size_t row_num, Arena * arena) override;
bool setIfSmaller(const SingleValueDataBase &, Arena * arena) override;
bool setIfGreater(const IColumn & column, size_t row_num, Arena * arena) override;
bool setIfGreater(const SingleValueDataBase &, Arena * arena) override;
static bool allocatesMemoryInArena() { return true; }
};
static_assert(sizeof(SingleValueDataString) == SingleValueDataBase::MAX_STORAGE_SIZE, "Incorrect size of SingleValueDataString struct");
/// For any other value types.
struct SingleValueDataGeneric final : public SingleValueDataBase
{
static constexpr bool is_compilable = false;
private:
using Self = SingleValueDataGeneric;
Field value;
public:
bool has() const override { return !value.isNull(); }
void insertResultInto(IColumn & to) const override;
void write(WriteBuffer & buf, const ISerialization & serialization) const override;
void read(ReadBuffer & buf, const ISerialization & serialization, Arena *) override;
bool isEqualTo(const IColumn & column, size_t row_num) const override;
bool isEqualTo(const SingleValueDataBase & other) const override;
void set(const IColumn & column, size_t row_num, Arena *) override;
void set(const SingleValueDataBase & other, Arena *) override;
bool setIfSmaller(const IColumn & column, size_t row_num, Arena * arena) override;
bool setIfSmaller(const SingleValueDataBase & other, Arena *) override;
bool setIfGreater(const IColumn & column, size_t row_num, Arena * arena) override;
bool setIfGreater(const SingleValueDataBase & other, Arena *) override;
static bool allocatesMemoryInArena() { return false; }
};
static_assert(sizeof(SingleValueDataGeneric) <= SingleValueDataBase::MAX_STORAGE_SIZE, "Incorrect size of SingleValueDataGeneric struct");
/// min, max, any, anyLast, anyHeavy, etc...
template <template <typename, bool...> class AggregateFunctionTemplate, bool unary, bool... isMin>
static IAggregateFunction *
createAggregateFunctionSingleValue(const String & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
assertNoParameters(name, parameters);
if constexpr (unary)
assertUnary(name, argument_types);
else
assertBinary(name, argument_types);
const DataTypePtr & value_type = unary ? argument_types[0] : argument_types[1];
WhichDataType which(value_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionTemplate<SingleValueDataFixed<TYPE>, isMin...>(argument_types); /// NOLINT
FOR_SINGLE_VALUE_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Date)
return new AggregateFunctionTemplate<SingleValueDataFixed<DataTypeDate::FieldType>, isMin...>(argument_types);
if (which.idx == TypeIndex::DateTime)
return new AggregateFunctionTemplate<SingleValueDataFixed<DataTypeDateTime::FieldType>, isMin...>(argument_types);
if (which.idx == TypeIndex::String)
return new AggregateFunctionTemplate<SingleValueDataString, isMin...>(argument_types);
return new AggregateFunctionTemplate<SingleValueDataGeneric, isMin...>(argument_types);
}
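/// Illustrative dispatch (the template arguments below are assumptions, not part of this change):
/// for `min` over a single Int32 argument the factory above would return
/// new AggregateFunctionTemplate<SingleValueDataFixed<Int32>, /*isMin*/ true>(argument_types),
/// while a type outside the list (e.g. Array(String)) falls through to SingleValueDataGeneric.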
/// Helper to allocate enough memory to store any derived class
struct SingleValueDataBaseMemoryBlock
{
std::aligned_union_t<
SingleValueDataBase::MAX_STORAGE_SIZE,
SingleValueDataNumeric<Decimal256>, /// We check all types in generateSingleValueFromTypeIndex
SingleValueDataString,
SingleValueDataGeneric>
memory;
SingleValueDataBase & get() { return *reinterpret_cast<SingleValueDataBase *>(&memory); }
const SingleValueDataBase & get() const { return *reinterpret_cast<const SingleValueDataBase *>(&memory); }
};
static_assert(alignof(SingleValueDataBaseMemoryBlock) == 8);
/// For Data classes that want to compose on top of SingleValueDataBase values, like argMax or singleValueOrNull
/// It builds the object for the given type idx in the provided memory block
void generateSingleValueFromTypeIndex(TypeIndex idx, SingleValueDataBaseMemoryBlock & data);
bool singleValueTypeAllocatesMemoryInArena(TypeIndex idx);
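/// Illustrative usage sketch (the locals column, row_num and arena are assumed, not part of this change):
///     SingleValueDataBaseMemoryBlock block;
///     generateSingleValueFromTypeIndex(TypeIndex::Int64, block);   /// constructs the matching derived class in place
///     block.get().set(column, row_num, arena);                     /// dispatches through SingleValueDataBase's virtual interface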
}

View File

@ -39,9 +39,11 @@ void registerAggregateFunctionsQuantileApprox(AggregateFunctionFactory &);
void registerAggregateFunctionsSequenceMatch(AggregateFunctionFactory &);
void registerAggregateFunctionWindowFunnel(AggregateFunctionFactory &);
void registerAggregateFunctionRate(AggregateFunctionFactory &);
void registerAggregateFunctionsMin(AggregateFunctionFactory &);
void registerAggregateFunctionsMax(AggregateFunctionFactory &);
void registerAggregateFunctionsMinMax(AggregateFunctionFactory &);
void registerAggregateFunctionsArgMinArgMax(AggregateFunctionFactory &);
void registerAggregateFunctionsAny(AggregateFunctionFactory &);
void registerAggregateFunctionAnyHeavy(AggregateFunctionFactory &);
void registerAggregateFunctionsAnyRespectNulls(AggregateFunctionFactory &);
void registerAggregateFunctionsStatisticsStable(AggregateFunctionFactory &);
void registerAggregateFunctionsStatisticsSecondMoment(AggregateFunctionFactory &);
void registerAggregateFunctionsStatisticsThirdMoment(AggregateFunctionFactory &);
@ -99,7 +101,7 @@ void registerAggregateFunctionCombinatorOrFill(AggregateFunctionCombinatorFactor
void registerAggregateFunctionCombinatorResample(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorDistinct(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorMap(AggregateFunctionCombinatorFactory & factory);
void registerAggregateFunctionCombinatorMinMax(AggregateFunctionCombinatorFactory & factory);
void registerAggregateFunctionCombinatorsArgMinArgMax(AggregateFunctionCombinatorFactory & factory);
void registerWindowFunctions(AggregateFunctionFactory & factory);
@ -138,9 +140,11 @@ void registerAggregateFunctions()
registerAggregateFunctionsSequenceMatch(factory);
registerAggregateFunctionWindowFunnel(factory);
registerAggregateFunctionRate(factory);
registerAggregateFunctionsMin(factory);
registerAggregateFunctionsMax(factory);
registerAggregateFunctionsMinMax(factory);
registerAggregateFunctionsArgMinArgMax(factory);
registerAggregateFunctionsAny(factory);
registerAggregateFunctionAnyHeavy(factory);
registerAggregateFunctionsAnyRespectNulls(factory);
registerAggregateFunctionsStatisticsStable(factory);
registerAggregateFunctionsStatisticsSecondMoment(factory);
registerAggregateFunctionsStatisticsThirdMoment(factory);
@ -203,7 +207,7 @@ void registerAggregateFunctions()
registerAggregateFunctionCombinatorResample(factory);
registerAggregateFunctionCombinatorDistinct(factory);
registerAggregateFunctionCombinatorMap(factory);
registerAggregateFunctionCombinatorMinMax(factory);
registerAggregateFunctionCombinatorsArgMinArgMax(factory);
}
}

View File

@ -19,7 +19,7 @@ class BackupCoordinationFileInfos
public:
/// plain_backup indicates that we're writing a plain backup, which means duplicates and empty files are written as is.
/// (For normal backups only the first file amongst duplicates is actually stored, and empty files are not stored).
BackupCoordinationFileInfos(bool plain_backup_) : plain_backup(plain_backup_) {}
explicit BackupCoordinationFileInfos(bool plain_backup_) : plain_backup(plain_backup_) {}
/// Adds file infos for the specified host.
void addFileInfos(BackupFileInfos && file_infos, const String & host_id);

View File

@ -21,7 +21,7 @@ namespace DB
class BackupCoordinationLocal : public IBackupCoordination
{
public:
BackupCoordinationLocal(bool plain_backup_);
explicit BackupCoordinationLocal(bool plain_backup_);
~BackupCoordinationLocal() override;
void setStage(const String & new_stage, const String & message) override;

View File

@ -927,7 +927,7 @@ void BackupImpl::writeFile(const BackupFileInfo & info, BackupEntryPtr entry)
const auto write_info_to_archive = [&](const auto & file_name)
{
auto out = archive_writer->writeFile(file_name);
auto out = archive_writer->writeFile(file_name, info.size);
auto read_buffer = entry->getReadBuffer(writer->getReadSettings());
if (info.base_size != 0)
read_buffer->seek(info.base_size, SEEK_SET);

View File

@ -52,7 +52,7 @@ private:
struct Task : public AsyncTask
{
Task(PacketReceiver & receiver_) : receiver(receiver_) {}
explicit Task(PacketReceiver & receiver_) : receiver(receiver_) {}
PacketReceiver & receiver;

View File

@ -53,7 +53,7 @@ class TestHint
{
public:
using ErrorVector = std::vector<int>;
TestHint(const String & query_);
explicit TestHint(const String & query_);
const auto & serverErrors() const { return server_errors; }
const auto & clientErrors() const { return client_errors; }

View File

@ -20,6 +20,7 @@
#include <Common/TargetSpecific.h>
#include <Common/WeakHash.h>
#include <Common/assert_cast.h>
#include <Common/findExtreme.h>
#include <Common/iota.h>
#include <bit>
@ -248,6 +249,26 @@ void ColumnVector<T>::getPermutation(IColumn::PermutationSortDirection direction
iota(res.data(), data_size, IColumn::Permutation::value_type(0));
if constexpr (has_find_extreme_implementation<T> && !std::is_floating_point_v<T>)
{
/// Disabled for:
/// * floating point: We don't deal with nan_direction_hint
/// * stability::Stable: We might return any value, not the first
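/// Illustrative case (query shape assumed): `ORDER BY x LIMIT 1` with an unstable sort only needs
/// the position of one extreme element, so the full permutation sort can be skipped.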
if ((limit == 1) && (stability == IColumn::PermutationSortStability::Unstable))
{
std::optional<size_t> index;
if (direction == IColumn::PermutationSortDirection::Ascending)
index = findExtremeMinIndex(data.data(), 0, data.size());
else
index = findExtremeMaxIndex(data.data(), 0, data.size());
if (index)
{
res.data()[0] = *index;
return;
}
}
}
if constexpr (is_arithmetic_v<T> && !is_big_int_v<T>)
{
if (!limit)

View File

@ -627,7 +627,7 @@ struct IsMutableColumns;
template <typename Arg, typename ... Args>
struct IsMutableColumns<Arg, Args ...>
{
static const bool value = std::is_assignable<MutableColumnPtr &&, Arg>::value && IsMutableColumns<Args ...>::value;
static const bool value = std::is_assignable_v<MutableColumnPtr &&, Arg> && IsMutableColumns<Args ...>::value;
};
template <>

View File

@ -70,7 +70,7 @@ private:
return *this;
}
MemoryChunk(size_t size_)
explicit MemoryChunk(size_t size_)
{
ProfileEvents::increment(ProfileEvents::ArenaAllocChunks);
ProfileEvents::increment(ProfileEvents::ArenaAllocBytes, size_);

View File

@ -46,7 +46,7 @@ public:
class AsyncTaskExecutor
{
public:
AsyncTaskExecutor(std::unique_ptr<AsyncTask> task_);
explicit AsyncTaskExecutor(std::unique_ptr<AsyncTask> task_);
/// Resume task execution. This method returns when the task is completed or suspended.
void resume();

View File

@ -289,10 +289,10 @@ void DNSResolver::setDisableCacheFlag(bool is_disabled)
impl->disable_cache = is_disabled;
}
void DNSResolver::setCacheMaxSize(const UInt64 cache_max_size)
void DNSResolver::setCacheMaxEntries(const UInt64 cache_max_entries)
{
impl->cache_address.setMaxSizeInBytes(cache_max_size);
impl->cache_host.setMaxSizeInBytes(cache_max_size);
impl->cache_address.setMaxSizeInBytes(cache_max_entries);
impl->cache_host.setMaxSizeInBytes(cache_max_entries);
}
String DNSResolver::getHostName()

View File

@ -52,8 +52,8 @@ public:
/// Disables caching
void setDisableCacheFlag(bool is_disabled = true);
/// Set a limit of cache size in bytes
void setCacheMaxSize(const UInt64 cache_max_size);
/// Set a limit of entries in cache
void setCacheMaxEntries(const UInt64 cache_max_entries);
/// Drops all caches
void dropCache();

View File

@ -3,13 +3,13 @@
#include <base/DayNum.h>
#include <base/defines.h>
#include <base/types.h>
#include <Core/DecimalFunctions.h>
#include <ctime>
#include <cassert>
#include <string>
#include <type_traits>
#define DATE_SECONDS_PER_DAY 86400 /// Number of seconds in a day, 60 * 60 * 24
#define DATE_LUT_MIN_YEAR 1900 /// 1900 since majority of financial organizations consider 1900 as an initial year.
@ -280,9 +280,9 @@ private:
static_assert(std::is_integral_v<DateOrTime> && std::is_integral_v<Divisor>);
assert(divisor > 0);
if (likely(offset_is_whole_number_of_hours_during_epoch))
if (offset_is_whole_number_of_hours_during_epoch) [[likely]]
{
if (likely(x >= 0))
if (x >= 0) [[likely]]
return static_cast<DateOrTime>(x / divisor * divisor);
/// Integer division for negative numbers rounds them towards zero (up).
@ -576,10 +576,10 @@ public:
unsigned toSecond(Time t) const
{
if (likely(offset_is_whole_number_of_minutes_during_epoch))
if (offset_is_whole_number_of_minutes_during_epoch) [[likely]]
{
Time res = t % 60;
if (likely(res >= 0))
if (res >= 0) [[likely]]
return static_cast<unsigned>(res);
return static_cast<unsigned>(res) + 60;
}
@ -593,6 +593,30 @@ public:
return time % 60;
}
template <typename DateOrTime>
unsigned toMillisecond(const DateOrTime & datetime, Int64 scale_multiplier) const
{
constexpr Int64 millisecond_multiplier = 1'000;
constexpr Int64 microsecond_multiplier = 1'000 * millisecond_multiplier;
constexpr Int64 divider = microsecond_multiplier / millisecond_multiplier;
auto components = DB::DecimalUtils::splitWithScaleMultiplier(datetime, scale_multiplier);
if (datetime.value < 0 && components.fractional)
{
components.fractional = scale_multiplier + (components.whole ? Int64(-1) : Int64(1)) * components.fractional;
--components.whole;
}
Int64 fractional = components.fractional;
if (scale_multiplier > microsecond_multiplier)
fractional = fractional / (scale_multiplier / microsecond_multiplier);
else if (scale_multiplier < microsecond_multiplier)
fractional = fractional * (microsecond_multiplier / scale_multiplier);
UInt16 millisecond = static_cast<UInt16>(fractional / divider);
return millisecond;
}
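/// Illustrative example (values assumed): for a DateTime64(6) with fractional part 123456 (i.e. .123456 s),
/// scale_multiplier == microsecond_multiplier, so fractional stays 123456 and millisecond = 123456 / 1000 = 123.
/// For a DateTime64(3) with fractional part 123, the value is first scaled up to 123000 microseconds and yields the same 123 ms.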
unsigned toMinute(Time t) const
{
if (t >= 0 && offset_is_whole_number_of_hours_during_epoch)
@ -1122,9 +1146,9 @@ public:
DateOrTime toStartOfMinuteInterval(DateOrTime t, UInt64 minutes) const
{
Int64 divisor = 60 * minutes;
if (likely(offset_is_whole_number_of_minutes_during_epoch))
if (offset_is_whole_number_of_minutes_during_epoch) [[likely]]
{
if (likely(t >= 0))
if (t >= 0) [[likely]]
return static_cast<DateOrTime>(t / divisor * divisor);
return static_cast<DateOrTime>((t + 1 - divisor) / divisor * divisor);
}
@ -1339,7 +1363,7 @@ public:
UInt8 saturateDayOfMonth(Int16 year, UInt8 month, UInt8 day_of_month) const
{
if (likely(day_of_month <= 28))
if (day_of_month <= 28) [[likely]]
return day_of_month;
UInt8 days_in_month = daysInMonth(year, month);

View File

@ -11,7 +11,7 @@ namespace DB
class EnvironmentProxyConfigurationResolver : public ProxyConfigurationResolver
{
public:
EnvironmentProxyConfigurationResolver(Protocol request_protocol, bool disable_tunneling_for_https_requests_over_http_proxy_ = false);
explicit EnvironmentProxyConfigurationResolver(Protocol request_protocol, bool disable_tunneling_for_https_requests_over_http_proxy_ = false);
ProxyConfiguration resolve() override;
void errorReport(const ProxyConfiguration &) override {}

View File

@ -203,7 +203,7 @@ public:
{
auto e = ErrnoException(fmt::format(fmt.fmt_str, std::forward<Args>(args)...), code, with_errno);
e.message_format_string = fmt.message_format_string;
throw e;
throw e; /// NOLINT
}
template <typename... Args>
@ -212,7 +212,7 @@ public:
auto e = ErrnoException(fmt::format(fmt.fmt_str, std::forward<Args>(args)...), code, errno);
e.message_format_string = fmt.message_format_string;
e.path = path;
throw e;
throw e; /// NOLINT
}
template <typename... Args>
@ -222,7 +222,7 @@ public:
auto e = ErrnoException(fmt::format(fmt.fmt_str, std::forward<Args>(args)...), code, with_errno);
e.message_format_string = fmt.message_format_string;
e.path = path;
throw e;
throw e; /// NOLINT
}
ErrnoException * clone() const override { return new ErrnoException(*this); }

View File

@ -19,7 +19,7 @@ using DiskPtr = std::shared_ptr<IDisk>;
class FileChecker
{
public:
FileChecker(const String & file_info_path_);
explicit FileChecker(const String & file_info_path_);
FileChecker(DiskPtr disk_, const String & file_info_path_);
void setPath(const String & file_info_path_);
@ -50,7 +50,7 @@ public:
struct DataValidationTasks
{
DataValidationTasks(const std::map<String, size_t> & map_)
explicit DataValidationTasks(const std::map<String, size_t> & map_)
: map(map_), it(map.begin())
{}

View File

@ -25,7 +25,7 @@ class FileRenamer
public:
FileRenamer();
FileRenamer(const String & renaming_rule);
explicit FileRenamer(const String & renaming_rule);
String generateNewFilename(const String & filename) const;

View File

@ -21,12 +21,12 @@ struct FormatStringHelperImpl
std::string_view message_format_string;
fmt::format_string<Args...> fmt_str;
template<typename T>
consteval FormatStringHelperImpl(T && str) : message_format_string(tryGetStaticFormatString(str)), fmt_str(std::forward<T>(str))
consteval FormatStringHelperImpl(T && str) : message_format_string(tryGetStaticFormatString(str)), fmt_str(std::forward<T>(str)) /// NOLINT
{
formatStringCheckArgsNumImpl(message_format_string, sizeof...(Args));
}
template<typename T>
FormatStringHelperImpl(fmt::basic_runtime<T> && str) : message_format_string(), fmt_str(std::forward<fmt::basic_runtime<T>>(str)) {}
FormatStringHelperImpl(fmt::basic_runtime<T> && str) : fmt_str(std::forward<fmt::basic_runtime<T>>(str)) {} /// NOLINT
PreformattedMessage format(Args && ...args) const;
};
@ -43,9 +43,9 @@ struct PreformattedMessage
template <typename... Args>
static PreformattedMessage create(FormatStringHelper<Args...> fmt, Args &&... args);
operator const std::string & () const { return text; }
operator std::string () && { return std::move(text); }
operator fmt::format_string<> () const { UNREACHABLE(); }
operator const std::string & () const { return text; } /// NOLINT
operator std::string () && { return std::move(text); } /// NOLINT
operator fmt::format_string<> () const { UNREACHABLE(); } /// NOLINT
void apply(std::string & out_text, std::string_view & out_format_string) const &
{

View File

@ -181,7 +181,7 @@ using TracingContextHolderPtr = std::unique_ptr<TracingContextHolder>;
/// Once it's constructed or destroyed, it automatically maintains the tracing context on the thread it lives on.
struct SpanHolder : public Span
{
SpanHolder(std::string_view, SpanKind _kind = INTERNAL);
explicit SpanHolder(std::string_view, SpanKind _kind = INTERNAL);
~SpanHolder();
/// Finish a span explicitly if needed.

View File

@ -21,7 +21,7 @@ namespace DB
class IClassifier : private boost::noncopyable
{
public:
virtual ~IClassifier() {}
virtual ~IClassifier() = default;
/// Returns the ResourceLink that should be used to access the resource.
/// The returned link is valid until the classifier is destroyed.
@ -38,7 +38,7 @@ using ClassifierPtr = std::shared_ptr<IClassifier>;
class IResourceManager : private boost::noncopyable
{
public:
virtual ~IResourceManager() {}
virtual ~IResourceManager() = default;
/// Initialize or reconfigure manager.
virtual void updateConfiguration(const Poco::Util::AbstractConfiguration & config) = 0;

View File

@ -35,7 +35,7 @@ class PriorityPolicy : public ISchedulerNode
};
public:
PriorityPolicy(EventQueue * event_queue_, const Poco::Util::AbstractConfiguration & config = emptyConfig(), const String & config_prefix = {})
explicit PriorityPolicy(EventQueue * event_queue_, const Poco::Util::AbstractConfiguration & config = emptyConfig(), const String & config_prefix = {})
: ISchedulerNode(event_queue_, config, config_prefix)
{}

View File

@ -18,7 +18,7 @@ class SemaphoreConstraint : public ISchedulerConstraint
static constexpr Int64 default_max_requests = std::numeric_limits<Int64>::max();
static constexpr Int64 default_max_cost = std::numeric_limits<Int64>::max();
public:
SemaphoreConstraint(EventQueue * event_queue_, const Poco::Util::AbstractConfiguration & config = emptyConfig(), const String & config_prefix = {})
explicit SemaphoreConstraint(EventQueue * event_queue_, const Poco::Util::AbstractConfiguration & config = emptyConfig(), const String & config_prefix = {})
: ISchedulerConstraint(event_queue_, config, config_prefix)
, max_requests(config.getInt64(config_prefix + ".max_requests", default_max_requests))
, max_cost(config.getInt64(config_prefix + ".max_cost", config.getInt64(config_prefix + ".max_bytes", default_max_cost)))

View File

@ -20,7 +20,7 @@ class ThrottlerConstraint : public ISchedulerConstraint
public:
static constexpr double default_burst_seconds = 1.0;
ThrottlerConstraint(EventQueue * event_queue_, const Poco::Util::AbstractConfiguration & config = emptyConfig(), const String & config_prefix = {})
explicit ThrottlerConstraint(EventQueue * event_queue_, const Poco::Util::AbstractConfiguration & config = emptyConfig(), const String & config_prefix = {})
: ISchedulerConstraint(event_queue_, config, config_prefix)
, max_speed(config.getDouble(config_prefix + ".max_speed", 0))
, max_burst(config.getDouble(config_prefix + ".max_burst", default_burst_seconds * max_speed))

View File

@ -75,7 +75,7 @@ struct ResourceTestBase
struct ConstraintTest : public SemaphoreConstraint
{
ConstraintTest(EventQueue * event_queue_, const Poco::Util::AbstractConfiguration & config = emptyConfig(), const String & config_prefix = {})
explicit ConstraintTest(EventQueue * event_queue_, const Poco::Util::AbstractConfiguration & config = emptyConfig(), const String & config_prefix = {})
: SemaphoreConstraint(event_queue_, config, config_prefix)
{}

View File

@ -101,7 +101,7 @@ class SystemLogQueue
using Index = uint64_t;
public:
SystemLogQueue(const SystemLogQueueSettings & settings_);
explicit SystemLogQueue(const SystemLogQueueSettings & settings_);
void shutdown();
@ -153,7 +153,7 @@ class SystemLogBase : public ISystemLog
public:
using Self = SystemLogBase;
SystemLogBase(
explicit SystemLogBase(
const SystemLogQueueSettings & settings_,
std::shared_ptr<SystemLogQueue<LogElement>> queue_ = nullptr);

View File

@ -66,7 +66,7 @@ class ThreadGroup
public:
ThreadGroup();
using FatalErrorCallback = std::function<void()>;
ThreadGroup(ContextPtr query_context_, FatalErrorCallback fatal_error_callback_ = {});
explicit ThreadGroup(ContextPtr query_context_, FatalErrorCallback fatal_error_callback_ = {});
/// The id of the first thread that created this thread group
const UInt64 master_thread_id;

View File

@ -476,7 +476,7 @@ private:
incrementErrorMetrics(code);
}
static void incrementErrorMetrics(const Error code_);
static void incrementErrorMetrics(Error code_);
public:
explicit Exception(const Error code_); /// NOLINT

View File

@ -152,7 +152,7 @@ private:
struct ResponsesWithFutures
{
ResponsesWithFutures(FutureResponses future_responses_) : future_responses(std::move(future_responses_))
ResponsesWithFutures(FutureResponses future_responses_) : future_responses(std::move(future_responses_)) /// NOLINT(google-explicit-constructor)
{
cached_responses.resize(future_responses.size());
}

View File

@ -25,7 +25,7 @@ struct ZooKeeperArgs
ZooKeeperArgs(const Poco::Util::AbstractConfiguration & config, const String & config_name);
/// hosts_string -- comma separated [secure://]host:port list
ZooKeeperArgs(const String & hosts_string);
ZooKeeperArgs(const String & hosts_string); /// NOLINT(google-explicit-constructor)
ZooKeeperArgs() = default;
bool operator == (const ZooKeeperArgs &) const = default;

View File

@ -2,23 +2,28 @@
#include <Common/TargetSpecific.h>
#include <Common/findExtreme.h>
#include <limits>
#include <type_traits>
namespace DB
{
template <is_any_native_number T>
template <has_find_extreme_implementation T>
struct MinComparator
{
static ALWAYS_INLINE inline const T & cmp(const T & a, const T & b) { return std::min(a, b); }
};
template <is_any_native_number T>
template <has_find_extreme_implementation T>
struct MaxComparator
{
static ALWAYS_INLINE inline const T & cmp(const T & a, const T & b) { return std::max(a, b); }
};
MULTITARGET_FUNCTION_AVX2_SSE42(
MULTITARGET_FUNCTION_HEADER(template <is_any_native_number T, typename ComparatorClass, bool add_all_elements, bool add_if_cond_zero> static std::optional<T> NO_INLINE),
MULTITARGET_FUNCTION_HEADER(
template <has_find_extreme_implementation T, typename ComparatorClass, bool add_all_elements, bool add_if_cond_zero>
static std::optional<T> NO_INLINE),
findExtremeImpl,
MULTITARGET_FUNCTION_BODY((const T * __restrict ptr, const UInt8 * __restrict condition_map [[maybe_unused]], size_t row_begin, size_t row_end) /// NOLINT
{
@ -65,24 +70,57 @@ MULTITARGET_FUNCTION_AVX2_SSE42(
for (size_t unroll_it = 0; unroll_it < unroll_block; unroll_it++)
ret = ComparatorClass::cmp(ret, partial_min[unroll_it]);
}
}
for (; i < count; i++)
for (; i < count; i++)
{
if (add_all_elements || !condition_map[i] == add_if_cond_zero)
ret = ComparatorClass::cmp(ret, ptr[i]);
}
return ret;
}
else
{
if (add_all_elements || !condition_map[i] == add_if_cond_zero)
ret = ComparatorClass::cmp(ret, ptr[i]);
/// Only native integers
for (; i < count; i++)
{
constexpr bool is_min = std::same_as<ComparatorClass, MinComparator<T>>;
if constexpr (add_all_elements)
{
ret = ComparatorClass::cmp(ret, ptr[i]);
}
else if constexpr (is_min)
{
/// keep_number will be 0 or 1
bool keep_number = !condition_map[i] == add_if_cond_zero;
/// If keep_number:     final = ptr[i] * 1 + 0 * max = ptr[i]
/// If not keep_number: final = ptr[i] * 0 + 1 * max = max
T final = ptr[i] * T{keep_number} + T{!keep_number} * std::numeric_limits<T>::max();
ret = ComparatorClass::cmp(ret, final);
}
else
{
static_assert(std::same_as<ComparatorClass, MaxComparator<T>>);
/// keep_number will be 0 or 1
bool keep_number = !condition_map[i] == add_if_cond_zero;
/// If keep_number:     final = ptr[i] * 1 + 0 * lowest = ptr[i]
/// If not keep_number: final = ptr[i] * 0 + 1 * lowest = lowest
T final = ptr[i] * T{keep_number} + T{!keep_number} * std::numeric_limits<T>::lowest();
ret = ComparatorClass::cmp(ret, final);
}
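/// Illustrative numbers (not part of this change): in the min branch, for ptr[i] = 7 and keep_number = 0
/// the blended value is 7 * 0 + 1 * std::numeric_limits<T>::max() = max(), so a filtered-out element can
/// never win the comparison; with keep_number = 1 it is 7 * 1 + 0 * max() = 7 and competes normally.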
}
return ret;
}
return ret;
}
))
/// Given a vector of T finds the extreme (MIN or MAX) value
template <is_any_native_number T, class ComparatorClass, bool add_all_elements, bool add_if_cond_zero>
template <has_find_extreme_implementation T, class ComparatorClass, bool add_all_elements, bool add_if_cond_zero>
static std::optional<T>
findExtreme(const T * __restrict ptr, const UInt8 * __restrict condition_map [[maybe_unused]], size_t start, size_t end)
{
#if USE_MULTITARGET_CODE
/// In some cases the compiler is able to apply the condition and still generate SIMD, so we still build both
/// conditional and unconditional functions with multiple architectures
/// We see no benefit from using AVX512BW or AVX512F (over AVX2), so we only declare SSE and AVX2
if (isArchSupported(TargetArch::AVX2))
return findExtremeImplAVX2<T, ComparatorClass, add_all_elements, add_if_cond_zero>(ptr, condition_map, start, end);
@ -93,50 +131,90 @@ findExtreme(const T * __restrict ptr, const UInt8 * __restrict condition_map [[m
return findExtremeImpl<T, ComparatorClass, add_all_elements, add_if_cond_zero>(ptr, condition_map, start, end);
}
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMin(const T * __restrict ptr, size_t start, size_t end)
{
return findExtreme<T, MinComparator<T>, true, false>(ptr, nullptr, start, end);
}
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMinNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end)
{
return findExtreme<T, MinComparator<T>, false, true>(ptr, condition_map, start, end);
}
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMinIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end)
{
return findExtreme<T, MinComparator<T>, false, false>(ptr, condition_map, start, end);
}
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMax(const T * __restrict ptr, size_t start, size_t end)
{
return findExtreme<T, MaxComparator<T>, true, false>(ptr, nullptr, start, end);
}
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMaxNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end)
{
return findExtreme<T, MaxComparator<T>, false, true>(ptr, condition_map, start, end);
}
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMaxIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end)
{
return findExtreme<T, MaxComparator<T>, false, false>(ptr, condition_map, start, end);
}
template <has_find_extreme_implementation T>
std::optional<size_t> findExtremeMinIndex(const T * __restrict ptr, size_t start, size_t end)
{
/// This is implemented based on findNumericExtreme and not the other way around (or independently) because getting
/// the MIN or MAX value of an array is possible with SIMD, but getting the index isn't.
/// So what we do is use SIMD to find the lowest value and then iterate again over the array to find its position
std::optional<T> opt = findExtremeMin(ptr, start, end);
if (!opt)
return std::nullopt;
/// Some minimal heuristics for the case where the input is sorted
if (*opt == ptr[start])
return {start};
for (size_t i = end - 1; i > start; i--)
if (ptr[i] == *opt)
return {i};
return std::nullopt;
}
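/// Illustrative run (data assumed): for ptr = {5, 3, 9, 3} over [0, 4), findExtremeMin returns 3;
/// ptr[0] != 3, so the backwards scan returns index 3 (the last occurrence), which is acceptable here
/// because the caller only needs some index of the extreme value, not the first one.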
template <has_find_extreme_implementation T>
std::optional<size_t> findExtremeMaxIndex(const T * __restrict ptr, size_t start, size_t end)
{
std::optional<T> opt = findExtremeMax(ptr, start, end);
if (!opt)
return std::nullopt;
/// Some minimal heuristics for the case where the input is sorted
if (*opt == ptr[start])
return {start};
for (size_t i = end - 1; i > start; i--)
if (ptr[i] == *opt)
return {i};
return std::nullopt;
}
#define INSTANTIATION(T) \
template std::optional<T> findExtremeMin(const T * __restrict ptr, size_t start, size_t end); \
template std::optional<T> findExtremeMinNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<T> findExtremeMinIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<T> findExtremeMinNotNull( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<T> findExtremeMinIf( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<T> findExtremeMax(const T * __restrict ptr, size_t start, size_t end); \
template std::optional<T> findExtremeMaxNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<T> findExtremeMaxIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end);
template std::optional<T> findExtremeMaxNotNull( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<T> findExtremeMaxIf( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<size_t> findExtremeMinIndex(const T * __restrict ptr, size_t start, size_t end); \
template std::optional<size_t> findExtremeMaxIndex(const T * __restrict ptr, size_t start, size_t end);
FOR_BASIC_NUMERIC_TYPES(INSTANTIATION)
#undef INSTANTIATION

View File

@ -11,35 +11,47 @@
namespace DB
{
template <typename T>
concept is_any_native_number = (is_any_of<T, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64>);
concept has_find_extreme_implementation = (is_any_of<T, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64>);
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMin(const T * __restrict ptr, size_t start, size_t end);
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMinNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end);
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMinIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end);
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMax(const T * __restrict ptr, size_t start, size_t end);
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMaxNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end);
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMaxIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end);
template <has_find_extreme_implementation T>
std::optional<size_t> findExtremeMinIndex(const T * __restrict ptr, size_t start, size_t end);
template <has_find_extreme_implementation T>
std::optional<size_t> findExtremeMaxIndex(const T * __restrict ptr, size_t start, size_t end);
#define EXTERN_INSTANTIATION(T) \
extern template std::optional<T> findExtremeMin(const T * __restrict ptr, size_t start, size_t end); \
extern template std::optional<T> findExtremeMinNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<T> findExtremeMinIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<T> findExtremeMinNotNull( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<T> findExtremeMinIf( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<T> findExtremeMax(const T * __restrict ptr, size_t start, size_t end); \
extern template std::optional<T> findExtremeMaxNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<T> findExtremeMaxIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end);
extern template std::optional<T> findExtremeMaxNotNull( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<T> findExtremeMaxIf( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<size_t> findExtremeMinIndex(const T * __restrict ptr, size_t start, size_t end); \
extern template std::optional<size_t> findExtremeMaxIndex(const T * __restrict ptr, size_t start, size_t end);
FOR_BASIC_NUMERIC_TYPES(EXTERN_INSTANTIATION)
FOR_BASIC_NUMERIC_TYPES(EXTERN_INSTANTIATION)
#undef EXTERN_INSTANTIATION
}

View File

@ -131,7 +131,7 @@ inline bool parseIPv4whole(const char * src, unsigned char * dst)
* @return - true if parsed successfully, false otherwise.
*/
template <typename T, typename EOFfunction>
requires (std::is_same<typename std::remove_cv<T>::type, char>::value)
requires (std::is_same_v<typename std::remove_cv_t<T>, char>)
inline bool parseIPv6(T * &src, EOFfunction eof, unsigned char * dst, int32_t first_block = -1)
{
const auto clear_dst = [dst]()
@ -305,7 +305,7 @@ inline bool parseIPv6whole(const char * src, unsigned char * dst)
* @return - true if parsed successfully, false otherwise.
*/
template <typename T, typename EOFfunction>
requires (std::is_same<typename std::remove_cv<T>::type, char>::value)
requires (std::is_same_v<typename std::remove_cv_t<T>, char>)
inline bool parseIPv6orIPv4(T * &src, EOFfunction eof, unsigned char * dst)
{
const auto clear_dst = [dst]()

View File

@ -66,7 +66,7 @@ public:
/// RET_ERROR stands for a hardware codec failure that requires a fallback to the software codec.
static constexpr Int32 RET_ERROR = -1;
HardwareCodecDeflateQpl(SoftwareCodecDeflateQpl & sw_codec_);
explicit HardwareCodecDeflateQpl(SoftwareCodecDeflateQpl & sw_codec_);
~HardwareCodecDeflateQpl();
Int32 doCompressData(const char * source, UInt32 source_size, char * dest, UInt32 dest_size) const;

View File

@ -210,7 +210,7 @@ namespace MySQLReplication
public:
EventHeader header;
EventBase(EventHeader && header_) : header(std::move(header_)) {}
explicit EventBase(EventHeader && header_) : header(std::move(header_)) {}
virtual ~EventBase() = default;
virtual void dump(WriteBuffer & out) const = 0;
@ -224,7 +224,7 @@ namespace MySQLReplication
class FormatDescriptionEvent : public EventBase
{
public:
FormatDescriptionEvent(EventHeader && header_)
explicit FormatDescriptionEvent(EventHeader && header_)
: EventBase(std::move(header_)), binlog_version(0), create_timestamp(0), event_header_length(0)
{
}
@ -249,7 +249,7 @@ namespace MySQLReplication
UInt64 position;
String next_binlog;
RotateEvent(EventHeader && header_) : EventBase(std::move(header_)), position(0) {}
explicit RotateEvent(EventHeader && header_) : EventBase(std::move(header_)), position(0) {}
void dump(WriteBuffer & out) const override;
protected:
@ -280,7 +280,7 @@ namespace MySQLReplication
QueryType typ = QUERY_EVENT_DDL;
bool transaction_complete = true;
QueryEvent(EventHeader && header_)
explicit QueryEvent(EventHeader && header_)
: EventBase(std::move(header_)), thread_id(0), exec_time(0), schema_len(0), error_code(0), status_len(0)
{
}
@ -295,7 +295,7 @@ namespace MySQLReplication
class XIDEvent : public EventBase
{
public:
XIDEvent(EventHeader && header_) : EventBase(std::move(header_)), xid(0) {}
explicit XIDEvent(EventHeader && header_) : EventBase(std::move(header_)), xid(0) {}
protected:
UInt64 xid;
@ -417,7 +417,7 @@ namespace MySQLReplication
UInt64 table_id;
UInt16 flags;
RowsEventHeader(EventType type_) : type(type_), table_id(0), flags(0) {}
explicit RowsEventHeader(EventType type_) : type(type_), table_id(0), flags(0) {}
void parse(ReadBuffer & payload);
};
@ -482,7 +482,7 @@ namespace MySQLReplication
UInt8 commit_flag;
GTID gtid;
GTIDEvent(EventHeader && header_) : EventBase(std::move(header_)), commit_flag(0) {}
explicit GTIDEvent(EventHeader && header_) : EventBase(std::move(header_)), commit_flag(0) {}
void dump(WriteBuffer & out) const override;
protected:
@ -492,7 +492,7 @@ namespace MySQLReplication
class DryRunEvent : public EventBase
{
public:
DryRunEvent(EventHeader && header_) : EventBase(std::move(header_)) {}
explicit DryRunEvent(EventHeader && header_) : EventBase(std::move(header_)) {}
void dump(WriteBuffer & out) const override;
protected:

View File

@ -93,7 +93,7 @@ protected:
void writePayloadImpl(WriteBuffer & buffer) const override;
public:
OKPacket(uint32_t capabilities_);
explicit OKPacket(uint32_t capabilities_);
OKPacket(uint8_t header_, uint32_t capabilities_, uint64_t affected_rows_,
uint32_t status_flags_, int16_t warnings_, String session_state_changes_ = "", String info_ = "");
@ -180,7 +180,7 @@ protected:
void readPayloadImpl(ReadBuffer & payload) override;
public:
ResponsePacket(UInt32 server_capability_flags_);
explicit ResponsePacket(UInt32 server_capability_flags_);
ResponsePacket(UInt32 server_capability_flags_, bool is_handshake_);
};

View File

@ -34,7 +34,7 @@ protected:
void writePayloadImpl(WriteBuffer & buffer) const override;
public:
RegisterSlave(UInt32 server_id_);
explicit RegisterSlave(UInt32 server_id_);
};
/// https://dev.mysql.com/doc/internals/en/com-binlog-dump-gtid.html

View File

@ -81,7 +81,7 @@ namespace DB
M(UInt64, mmap_cache_size, DEFAULT_MMAP_CACHE_MAX_SIZE, "A cache for mmapped files.", 0) \
\
M(Bool, disable_internal_dns_cache, false, "Disable internal DNS caching at all.", 0) \
M(UInt64, dns_cache_max_size, 1024, "Internal DNS cache max size in bytes.", 0) \
M(UInt64, dns_cache_max_entries, 10000, "Internal DNS cache max entries.", 0) \
M(Int32, dns_cache_update_period, 15, "Internal DNS cache update period in seconds.", 0) \
M(UInt32, dns_max_consecutive_failures, 10, "Max DNS resolve failures of a hostname before dropping the hostname from ClickHouse DNS cache.", 0) \
\

View File

@ -19,7 +19,7 @@ namespace ErrorCodes
class ClickHouseVersion
{
public:
ClickHouseVersion(const String & version)
ClickHouseVersion(const String & version) /// NOLINT(google-explicit-constructor)
{
Strings split;
boost::split(split, version, [](char c){ return c == '.'; });
@ -37,7 +37,7 @@ public:
}
}
ClickHouseVersion(const char * version) : ClickHouseVersion(String(version)) {}
ClickHouseVersion(const char * version) : ClickHouseVersion(String(version)) {} /// NOLINT(google-explicit-constructor)
String toString() const
{

View File

@ -15,7 +15,7 @@
class GraphiteWriter
{
public:
GraphiteWriter(const std::string & config_name, const std::string & sub_path = "");
explicit GraphiteWriter(const std::string & config_name, const std::string & sub_path = "");
template <typename T> using KeyValuePair = std::pair<std::string, T>;
template <typename T> using KeyValueVector = std::vector<KeyValuePair<T>>;

View File

@ -102,7 +102,7 @@ public:
struct SubstreamData
{
SubstreamData() = default;
SubstreamData(SerializationPtr serialization_)
explicit SubstreamData(SerializationPtr serialization_)
: serialization(std::move(serialization_))
{
}

View File

@ -94,7 +94,7 @@ using BinlogFactoryPtr = std::shared_ptr<IBinlogFactory>;
class BinlogFromFileFactory : public IBinlogFactory
{
public:
BinlogFromFileFactory(const String & filename_);
explicit BinlogFromFileFactory(const String & filename_);
BinlogPtr createBinlog(const String & executed_gtid_set) override;
private:

View File

@ -14,7 +14,7 @@ namespace DB::MySQLReplication
class BinlogClient
{
public:
BinlogClient(const BinlogFactoryPtr & factory,
explicit BinlogClient(const BinlogFactoryPtr & factory,
const String & name = {},
UInt64 max_bytes_in_buffer_ = DBMS_DEFAULT_BUFFER_SIZE,
UInt64 max_flush_ms_ = 1000);

View File

@ -18,7 +18,7 @@ class BinlogFromDispatcher;
class BinlogEventsDispatcher final : boost::noncopyable
{
public:
BinlogEventsDispatcher(const String & logger_name_ = "BinlogDispatcher", size_t max_bytes_in_buffer_ = 1_MiB, UInt64 max_flush_ms_ = 1000);
explicit BinlogEventsDispatcher(const String & logger_name_ = "BinlogDispatcher", size_t max_bytes_in_buffer_ = 1_MiB, UInt64 max_flush_ms_ = 1000);
~BinlogEventsDispatcher();
/// Moves all IBinlog objects to \a to if it has the same position

View File

@ -14,7 +14,7 @@ namespace ErrorCodes
class NullDictionarySource final : public IDictionarySource
{
public:
NullDictionarySource(Block & sample_block_);
explicit NullDictionarySource(Block & sample_block_);
NullDictionarySource(const NullDictionarySource & other);

View File

@ -76,7 +76,7 @@ private:
const LoggerPtr log;
public:
IOUringReader(uint32_t entries_);
explicit IOUringReader(uint32_t entries_);
inline bool isSupported() { return is_supported; }
std::future<Result> submit(Request request) override;

View File

@ -20,13 +20,6 @@ namespace ErrorCodes
extern const int BAD_ARGUMENTS;
}
struct AzureBlobStorageEndpoint
{
const String storage_account_url;
const String container_name;
const std::optional<bool> container_already_exists;
};
void validateStorageAccountUrl(const String & storage_account_url)
{
@ -58,28 +51,89 @@ void validateContainerName(const String & container_name)
AzureBlobStorageEndpoint processAzureBlobStorageEndpoint(const Poco::Util::AbstractConfiguration & config, const String & config_prefix)
{
std::string storage_url;
if (config.has(config_prefix + ".storage_account_url"))
String storage_url;
String account_name;
String container_name;
String prefix;
if (config.has(config_prefix + ".endpoint"))
{
String endpoint = config.getString(config_prefix + ".endpoint");
/// For some authentication methods the account name is not present in the endpoint.
/// The 'endpoint_contains_account_name' bool is used to understand how to split the endpoint (default: true).
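/// Illustrative split (endpoint value assumed): "http://azurite1:10000/devstoreaccount1/cont/data"
/// with endpoint_contains_account_name = true yields storage_account_url = "http://azurite1:10000",
/// account_name = "devstoreaccount1", container_name = "cont" and prefix = "data".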
bool endpoint_contains_account_name = config.getBool(config_prefix + ".endpoint_contains_account_name", true);
size_t pos = endpoint.find("//");
if (pos == std::string::npos)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Expected '//' in endpoint");
if (endpoint_contains_account_name)
{
size_t acc_pos_begin = endpoint.find('/', pos+2);
if (acc_pos_begin == std::string::npos)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Expected account_name in endpoint");
storage_url = endpoint.substr(0,acc_pos_begin);
size_t acc_pos_end = endpoint.find('/',acc_pos_begin+1);
if (acc_pos_end == std::string::npos)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Expected container_name in endpoint");
account_name = endpoint.substr(acc_pos_begin+1,(acc_pos_end-acc_pos_begin)-1);
size_t cont_pos_end = endpoint.find('/', acc_pos_end+1);
if (cont_pos_end != std::string::npos)
{
container_name = endpoint.substr(acc_pos_end+1,(cont_pos_end-acc_pos_end)-1);
prefix = endpoint.substr(cont_pos_end+1);
}
else
{
container_name = endpoint.substr(acc_pos_end+1);
}
}
else
{
size_t cont_pos_begin = endpoint.find('/', pos+2);
if (cont_pos_begin == std::string::npos)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Expected container_name in endpoint");
storage_url = endpoint.substr(0,cont_pos_begin);
size_t cont_pos_end = endpoint.find('/',cont_pos_begin+1);
if (cont_pos_end != std::string::npos)
{
container_name = endpoint.substr(cont_pos_begin+1,(cont_pos_end-cont_pos_begin)-1);
prefix = endpoint.substr(cont_pos_end+1);
}
else
{
container_name = endpoint.substr(cont_pos_begin+1);
}
}
}
else if (config.has(config_prefix + ".connection_string"))
{
storage_url = config.getString(config_prefix + ".connection_string");
container_name = config.getString(config_prefix + ".container_name");
}
else if (config.has(config_prefix + ".storage_account_url"))
{
storage_url = config.getString(config_prefix + ".storage_account_url");
validateStorageAccountUrl(storage_url);
container_name = config.getString(config_prefix + ".container_name");
}
else
{
if (config.has(config_prefix + ".connection_string"))
storage_url = config.getString(config_prefix + ".connection_string");
else if (config.has(config_prefix + ".endpoint"))
storage_url = config.getString(config_prefix + ".endpoint");
else
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Expected either `connection_string` or `endpoint` in config");
}
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Expected either `storage_account_url` or `connection_string` or `endpoint` in config");
String container_name = config.getString(config_prefix + ".container_name", "default-container");
validateContainerName(container_name);
if (!container_name.empty())
validateContainerName(container_name);
std::optional<bool> container_already_exists {};
if (config.has(config_prefix + ".container_already_exists"))
container_already_exists = {config.getBool(config_prefix + ".container_already_exists")};
return {storage_url, container_name, container_already_exists};
return {storage_url, account_name, container_name, prefix, container_already_exists};
}
@ -133,15 +187,13 @@ std::unique_ptr<BlobContainerClient> getAzureBlobContainerClient(
{
auto endpoint = processAzureBlobStorageEndpoint(config, config_prefix);
auto container_name = endpoint.container_name;
auto final_url = container_name.empty()
? endpoint.storage_account_url
: (std::filesystem::path(endpoint.storage_account_url) / container_name).string();
auto final_url = endpoint.getEndpoint();
if (endpoint.container_already_exists.value_or(false))
return getAzureBlobStorageClientWithAuth<BlobContainerClient>(final_url, container_name, config, config_prefix);
auto blob_service_client = getAzureBlobStorageClientWithAuth<BlobServiceClient>(
endpoint.storage_account_url, container_name, config, config_prefix);
endpoint.getEndpointWithoutContainer(), container_name, config, config_prefix);
try
{

View File

@ -10,9 +10,46 @@
namespace DB
{
struct AzureBlobStorageEndpoint
{
const String storage_account_url;
const String account_name;
const String container_name;
const String prefix;
const std::optional<bool> container_already_exists;
String getEndpoint()
{
String url = storage_account_url;
if (!account_name.empty())
url += "/" + account_name;
if (!container_name.empty())
url += "/" + container_name;
if (!prefix.empty())
url += "/" + prefix;
return url;
}
String getEndpointWithoutContainer()
{
String url = storage_account_url;
if (!account_name.empty())
url += "/" + account_name;
return url;
}
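/// Illustrative values (assumed): with storage_account_url = "http://azurite1:10000",
/// account_name = "devstoreaccount1", container_name = "cont" and prefix = "data",
/// getEndpoint() returns "http://azurite1:10000/devstoreaccount1/cont/data" and
/// getEndpointWithoutContainer() returns "http://azurite1:10000/devstoreaccount1".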
};
std::unique_ptr<Azure::Storage::Blobs::BlobContainerClient> getAzureBlobContainerClient(
const Poco::Util::AbstractConfiguration & config, const String & config_prefix);
AzureBlobStorageEndpoint processAzureBlobStorageEndpoint(const Poco::Util::AbstractConfiguration & config, const String & config_prefix);
std::unique_ptr<AzureObjectStorageSettings> getAzureBlobStorageSettings(const Poco::Util::AbstractConfiguration & config, const String & config_prefix, ContextPtr /*context*/);
}

View File

@ -93,11 +93,11 @@ AzureObjectStorage::AzureObjectStorage(
const String & name_,
AzureClientPtr && client_,
SettingsPtr && settings_,
const String & container_)
const String & object_namespace_)
: name(name_)
, client(std::move(client_))
, settings(std::move(settings_))
, container(container_)
, object_namespace(object_namespace_)
, log(getLogger("AzureObjectStorage"))
{
}
@ -379,7 +379,7 @@ std::unique_ptr<IObjectStorage> AzureObjectStorage::cloneObjectStorage(const std
name,
getAzureBlobContainerClient(config, config_prefix),
getAzureBlobStorageSettings(config, config_prefix, context),
container
object_namespace
);
}

View File

@ -67,7 +67,7 @@ public:
const String & name_,
AzureClientPtr && client_,
SettingsPtr && settings_,
const String & container_);
const String & object_namespace_);
void listObjects(const std::string & path, RelativePathsWithMetadata & children, int max_keys) const override;
@ -130,7 +130,7 @@ public:
const std::string & config_prefix,
ContextPtr context) override;
String getObjectsNamespace() const override { return container ; }
String getObjectsNamespace() const override { return object_namespace ; }
std::unique_ptr<IObjectStorage> cloneObjectStorage(
const std::string & new_namespace,
@ -154,7 +154,7 @@ private:
/// client used to access the files in the Blob Storage cloud
MultiVersion<Azure::Storage::Blobs::BlobContainerClient> client;
MultiVersion<AzureObjectStorageSettings> settings;
const String container;
const String object_namespace; /// container + prefix
LoggerPtr log;
};

View File

@ -78,7 +78,7 @@ private:
std::vector<MetadataOperationPtr> operations;
public:
MetadataStorageFromPlainObjectStorageTransaction(const MetadataStorageFromPlainObjectStorage & metadata_storage_)
explicit MetadataStorageFromPlainObjectStorageTransaction(const MetadataStorageFromPlainObjectStorage & metadata_storage_)
: metadata_storage(metadata_storage_)
{}

View File

@ -213,12 +213,12 @@ void registerAzureObjectStorage(ObjectStorageFactory & factory)
const ContextPtr & context,
bool /* skip_access_check */) -> ObjectStoragePtr
{
String container_name = config.getString(config_prefix + ".container_name", "default-container");
AzureBlobStorageEndpoint endpoint = processAzureBlobStorageEndpoint(config, config_prefix);
return std::make_unique<AzureObjectStorage>(
name,
getAzureBlobContainerClient(config, config_prefix),
getAzureBlobStorageSettings(config, config_prefix, context),
container_name);
endpoint.prefix.empty() ? endpoint.container_name : endpoint.container_name + "/" + endpoint.prefix);
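/// Illustrative (values assumed): container_name = "cont" with prefix = "data" yields the object namespace "cont/data"; with an empty prefix it stays just "cont".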
});
}

View File

@ -23,8 +23,8 @@ struct StoredObject
const String & remote_path_ = "",
const String & local_path_ = "",
uint64_t bytes_size_ = 0)
: remote_path(std::move(remote_path_))
, local_path(std::move(local_path_))
: remote_path(remote_path_)
, local_path(local_path_)
, bytes_size(bytes_size_)
{}
};

View File

@ -63,7 +63,7 @@ class MarksInCompressedFile
public:
using PlainArray = PODArray<MarkInCompressedFile>;
MarksInCompressedFile(const PlainArray & marks);
explicit MarksInCompressedFile(const PlainArray & marks);
MarkInCompressedFile get(size_t idx) const;

View File

@ -10,16 +10,17 @@ namespace ErrorCodes
void throwDateIsNotSupported(const char * name)
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type Date of argument for function {}", name);
}
void throwDateTimeIsNotSupported(const char * name)
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type DateTime of argument for function {}", name);
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal argument of type Date for function {}", name);
}
void throwDate32IsNotSupported(const char * name)
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type Date32 of argument for function {}", name);
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal argument of type Date32 for function {}", name);
}
void throwDateTimeIsNotSupported(const char * name)
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal argument of type DateTime for function {}", name);
}
}

View File

@ -6,6 +6,7 @@
#include <Common/DateLUTImpl.h>
#include <Common/DateLUT.h>
#include <Common/IntervalKind.h>
#include "base/Decimal.h"
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnVector.h>
@ -54,8 +55,8 @@ constexpr time_t MAX_DATE_TIMESTAMP = 5662310399; // 2149-06-06 23:59:59 U
constexpr time_t MAX_DATETIME_DAY_NUM = 49710; // 2106-02-07
[[noreturn]] void throwDateIsNotSupported(const char * name);
[[noreturn]] void throwDateTimeIsNotSupported(const char * name);
[[noreturn]] void throwDate32IsNotSupported(const char * name);
[[noreturn]] void throwDateTimeIsNotSupported(const char * name);
/// This factor transformation will say that the function is monotone everywhere.
struct ZeroTransform
@ -481,7 +482,7 @@ struct ToStartOfInterval<IntervalKind::Kind::Nanosecond>
}
static UInt32 execute(Int32, Int64, const DateLUTImpl &, Int64)
{
throwDateIsNotSupported(TO_START_OF_INTERVAL_NAME);
throwDate32IsNotSupported(TO_START_OF_INTERVAL_NAME);
}
static UInt32 execute(UInt32, Int64, const DateLUTImpl &, Int64)
{
@ -516,7 +517,7 @@ struct ToStartOfInterval<IntervalKind::Kind::Microsecond>
}
static UInt32 execute(Int32, Int64, const DateLUTImpl &, Int64)
{
throwDateIsNotSupported(TO_START_OF_INTERVAL_NAME);
throwDate32IsNotSupported(TO_START_OF_INTERVAL_NAME);
}
static UInt32 execute(UInt32, Int64, const DateLUTImpl &, Int64)
{
@ -559,7 +560,7 @@ struct ToStartOfInterval<IntervalKind::Kind::Millisecond>
}
static UInt32 execute(Int32, Int64, const DateLUTImpl &, Int64)
{
throwDateIsNotSupported(TO_START_OF_INTERVAL_NAME);
throwDate32IsNotSupported(TO_START_OF_INTERVAL_NAME);
}
static UInt32 execute(UInt32, Int64, const DateLUTImpl &, Int64)
{
@ -602,7 +603,7 @@ struct ToStartOfInterval<IntervalKind::Kind::Second>
}
static UInt32 execute(Int32, Int64, const DateLUTImpl &, Int64)
{
throwDateIsNotSupported(TO_START_OF_INTERVAL_NAME);
throwDate32IsNotSupported(TO_START_OF_INTERVAL_NAME);
}
static UInt32 execute(UInt32 t, Int64 seconds, const DateLUTImpl & time_zone, Int64)
{
@ -623,7 +624,7 @@ struct ToStartOfInterval<IntervalKind::Kind::Minute>
}
static UInt32 execute(Int32, Int64, const DateLUTImpl &, Int64)
{
throwDateIsNotSupported(TO_START_OF_INTERVAL_NAME);
throwDate32IsNotSupported(TO_START_OF_INTERVAL_NAME);
}
static UInt32 execute(UInt32 t, Int64 minutes, const DateLUTImpl & time_zone, Int64)
{
@ -644,7 +645,7 @@ struct ToStartOfInterval<IntervalKind::Kind::Hour>
}
static UInt32 execute(Int32, Int64, const DateLUTImpl &, Int64)
{
throwDateIsNotSupported(TO_START_OF_INTERVAL_NAME);
throwDate32IsNotSupported(TO_START_OF_INTERVAL_NAME);
}
static UInt32 execute(UInt32 t, Int64 hours, const DateLUTImpl & time_zone, Int64)
{
@ -777,7 +778,7 @@ struct ToTimeImpl
}
static UInt32 execute(Int32, const DateLUTImpl &)
{
throwDateIsNotSupported(name);
throwDate32IsNotSupported(name);
}
static UInt32 execute(UInt16, const DateLUTImpl &)
{
@ -802,7 +803,7 @@ struct ToStartOfMinuteImpl
}
static UInt32 execute(Int32, const DateLUTImpl &)
{
throwDateIsNotSupported(name);
throwDate32IsNotSupported(name);
}
static UInt32 execute(UInt16, const DateLUTImpl &)
{
@ -849,7 +850,7 @@ struct ToStartOfSecondImpl
}
static UInt32 execute(Int32, const DateLUTImpl &)
{
throwDateIsNotSupported(name);
throwDate32IsNotSupported(name);
}
static UInt32 execute(UInt16, const DateLUTImpl &)
{
@ -897,7 +898,7 @@ struct ToStartOfMillisecondImpl
}
static UInt32 execute(Int32, const DateLUTImpl &)
{
throwDateIsNotSupported(name);
throwDate32IsNotSupported(name);
}
static UInt32 execute(UInt16, const DateLUTImpl &)
{
@ -941,7 +942,7 @@ struct ToStartOfMicrosecondImpl
}
static UInt32 execute(Int32, const DateLUTImpl &)
{
throwDateIsNotSupported(name);
throwDate32IsNotSupported(name);
}
static UInt32 execute(UInt16, const DateLUTImpl &)
{
@ -979,7 +980,7 @@ struct ToStartOfNanosecondImpl
}
static UInt32 execute(Int32, const DateLUTImpl &)
{
throwDateIsNotSupported(name);
throwDate32IsNotSupported(name);
}
static UInt32 execute(UInt16, const DateLUTImpl &)
{
@ -1004,7 +1005,7 @@ struct ToStartOfFiveMinutesImpl
}
static UInt32 execute(Int32, const DateLUTImpl &)
{
throwDateIsNotSupported(name);
throwDate32IsNotSupported(name);
}
static UInt32 execute(UInt16, const DateLUTImpl &)
{
@ -1036,7 +1037,7 @@ struct ToStartOfTenMinutesImpl
}
static UInt32 execute(Int32, const DateLUTImpl &)
{
throwDateIsNotSupported(name);
throwDate32IsNotSupported(name);
}
static UInt32 execute(UInt16, const DateLUTImpl &)
{
@ -1068,7 +1069,7 @@ struct ToStartOfFifteenMinutesImpl
}
static UInt32 execute(Int32, const DateLUTImpl &)
{
throwDateIsNotSupported(name);
throwDate32IsNotSupported(name);
}
static UInt32 execute(UInt16, const DateLUTImpl &)
{
@ -1103,7 +1104,7 @@ struct TimeSlotImpl
static UInt32 execute(Int32, const DateLUTImpl &)
{
throwDateIsNotSupported(name);
throwDate32IsNotSupported(name);
}
static UInt32 execute(UInt16, const DateLUTImpl &)
@ -1142,7 +1143,7 @@ struct ToStartOfHourImpl
static UInt32 execute(Int32, const DateLUTImpl &)
{
throwDateIsNotSupported(name);
throwDate32IsNotSupported(name);
}
static UInt32 execute(UInt16, const DateLUTImpl &)
@ -1429,7 +1430,7 @@ struct ToHourImpl
}
static UInt8 execute(Int32, const DateLUTImpl &)
{
throwDateIsNotSupported(name);
throwDate32IsNotSupported(name);
}
static UInt8 execute(UInt16, const DateLUTImpl &)
{
@ -1456,7 +1457,7 @@ struct TimezoneOffsetImpl
static time_t execute(Int32, const DateLUTImpl &)
{
throwDateIsNotSupported(name);
throwDate32IsNotSupported(name);
}
static time_t execute(UInt16, const DateLUTImpl &)
@ -1482,7 +1483,7 @@ struct ToMinuteImpl
}
static UInt8 execute(Int32, const DateLUTImpl &)
{
throwDateIsNotSupported(name);
throwDate32IsNotSupported(name);
}
static UInt8 execute(UInt16, const DateLUTImpl &)
{
@ -1507,7 +1508,7 @@ struct ToSecondImpl
}
static UInt8 execute(Int32, const DateLUTImpl &)
{
throwDateIsNotSupported(name);
throwDate32IsNotSupported(name);
}
static UInt8 execute(UInt16, const DateLUTImpl &)
{
@ -1518,6 +1519,32 @@ struct ToSecondImpl
using FactorTransform = ToStartOfMinuteImpl;
};
struct ToMillisecondImpl
{
static constexpr auto name = "toMillisecond";
static UInt16 execute(const DateTime64 & datetime64, Int64 scale_multiplier, const DateLUTImpl & time_zone)
{
return time_zone.toMillisecond<DateTime64>(datetime64, scale_multiplier);
}
static UInt16 execute(UInt32, const DateLUTImpl &)
{
return 0;
}
static UInt16 execute(Int32, const DateLUTImpl &)
{
throwDate32IsNotSupported(name);
}
static UInt16 execute(UInt16, const DateLUTImpl &)
{
throwDateIsNotSupported(name);
}
static constexpr bool hasPreimage() { return false; }
using FactorTransform = ZeroTransform;
};
struct ToISOYearImpl
{
static constexpr auto name = "toISOYear";

View File

@ -141,7 +141,7 @@ private:
const std::shared_ptr<typename DictGetter::Src> owned_dict;
public:
FunctionTransformWithDictionary(const std::shared_ptr<typename DictGetter::Src> & owned_dict_)
explicit FunctionTransformWithDictionary(const std::shared_ptr<typename DictGetter::Src> & owned_dict_)
: owned_dict(owned_dict_)
{
if (!owned_dict)
@ -232,7 +232,7 @@ private:
const std::shared_ptr<typename DictGetter::Src> owned_dict;
public:
FunctionIsInWithDictionary(const std::shared_ptr<typename DictGetter::Src> & owned_dict_)
explicit FunctionIsInWithDictionary(const std::shared_ptr<typename DictGetter::Src> & owned_dict_)
: owned_dict(owned_dict_)
{
if (!owned_dict)
@ -365,7 +365,7 @@ private:
const std::shared_ptr<typename DictGetter::Src> owned_dict;
public:
FunctionHierarchyWithDictionary(const std::shared_ptr<typename DictGetter::Src> & owned_dict_)
explicit FunctionHierarchyWithDictionary(const std::shared_ptr<typename DictGetter::Src> & owned_dict_)
: owned_dict(owned_dict_)
{
if (!owned_dict)
@ -563,7 +563,7 @@ private:
const MultiVersion<RegionsNames>::Version owned_dict;
public:
FunctionRegionToName(const MultiVersion<RegionsNames>::Version & owned_dict_)
explicit FunctionRegionToName(const MultiVersion<RegionsNames>::Version & owned_dict_)
: owned_dict(owned_dict_)
{
if (!owned_dict)

View File

@ -403,7 +403,7 @@ struct NoEscapingStateHandler : public StateHandlerImpl<false>
};
template <typename ... Args>
NoEscapingStateHandler(Args && ... args)
explicit NoEscapingStateHandler(Args && ... args)
: StateHandlerImpl<false>(std::forward<Args>(args)...) {}
};
@ -465,7 +465,7 @@ struct InlineEscapingStateHandler : public StateHandlerImpl<true>
};
template <typename ... Args>
InlineEscapingStateHandler(Args && ... args)
explicit InlineEscapingStateHandler(Args && ... args)
: StateHandlerImpl<true>(std::forward<Args>(args)...) {}
};

View File

@ -189,6 +189,7 @@ REGISTER_FUNCTION(Substring)
factory.registerFunction<FunctionSubstring<false>>({}, FunctionFactory::CaseInsensitive);
factory.registerAlias("substr", "substring", FunctionFactory::CaseInsensitive); // MySQL alias
factory.registerAlias("mid", "substring", FunctionFactory::CaseInsensitive); /// MySQL alias
factory.registerAlias("byteSlice", "substring", FunctionFactory::CaseInsensitive); /// resembles PostgreSQL's get_byte function, similar to ClickHouse's bitSlice
factory.registerFunction<FunctionSubstring<true>>({}, FunctionFactory::CaseSensitive);
}

View File

@ -0,0 +1,18 @@
#include <Functions/FunctionFactory.h>
#include <Functions/DateTimeTransforms.h>
#include <Functions/FunctionDateOrDateTimeToSomething.h>
namespace DB
{
using FunctionToMillisecond = FunctionDateOrDateTimeToSomething<DataTypeUInt16, ToMillisecondImpl>;
REGISTER_FUNCTION(ToMillisecond)
{
factory.registerFunction<FunctionToMillisecond>();
/// MySQL compatibility alias.
factory.registerAlias("MILLISECOND", "toMillisecond", FunctionFactory::CaseInsensitive);
}
}

View File

@ -22,6 +22,8 @@ public:
/// of the function `writeFile()` should be destroyed before next call of `writeFile()`.
virtual std::unique_ptr<WriteBufferFromFileBase> writeFile(const String & filename) = 0;
virtual std::unique_ptr<WriteBufferFromFileBase> writeFile(const String & filename, size_t size) = 0;
/// Returns true if there is an active instance of WriteBuffer returned by writeFile().
/// This function should be used mostly for debugging purposes.
virtual bool isWritingFile() const = 0;
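A minimal caller-side sketch of this contract (hypothetical code, not part of the change; `writer` is any IArchiveWriter and `data` a std::string): only one WriteBuffer returned by writeFile() may be alive at a time, and isWritingFile() is meant as a debugging check rather than a synchronization primitive.
{
    auto out = writer->writeFile("a.txt", data.size());
    out->write(data.data(), data.size());
    out->finalize();
}   // the buffer must be destroyed before the next writeFile() call
chassert(!writer->isWritingFile());   // debugging aid: no active WriteBuffer remains
auto out_next = writer->writeFile("b.txt");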

View File

@ -1,11 +1,9 @@
#include <IO/Archives/ArchiveUtils.h>
#include <IO/Archives/LibArchiveReader.h>
#include <IO/ReadBufferFromFileBase.h>
#include <Common/quoteString.h>
#include <Common/scope_guard_safe.h>
#include <IO/Archives/ArchiveUtils.h>
#include <mutex>
namespace DB
{
@ -14,35 +12,58 @@ namespace DB
namespace ErrorCodes
{
extern const int CANNOT_UNPACK_ARCHIVE;
extern const int LOGICAL_ERROR;
extern const int CANNOT_READ_ALL_DATA;
extern const int UNSUPPORTED_METHOD;
extern const int CANNOT_UNPACK_ARCHIVE;
extern const int LOGICAL_ERROR;
extern const int CANNOT_READ_ALL_DATA;
extern const int UNSUPPORTED_METHOD;
}
class LibArchiveReader::StreamInfo
{
public:
explicit StreamInfo(std::unique_ptr<SeekableReadBuffer> read_buffer_) : read_buffer(std::move(read_buffer_)) { }
static ssize_t read(struct archive *, void * client_data, const void ** buff)
{
auto * read_stream = reinterpret_cast<StreamInfo *>(client_data);
*buff = reinterpret_cast<void *>(read_stream->buf);
return read_stream->read_buffer->read(read_stream->buf, DBMS_DEFAULT_BUFFER_SIZE);
}
std::unique_ptr<SeekableReadBuffer> read_buffer;
char buf[DBMS_DEFAULT_BUFFER_SIZE];
};
class LibArchiveReader::Handle
{
public:
explicit Handle(std::string path_to_archive_, bool lock_on_reading_)
: path_to_archive(path_to_archive_), lock_on_reading(lock_on_reading_)
: path_to_archive(std::move(path_to_archive_)), lock_on_reading(lock_on_reading_)
{
current_archive = open(path_to_archive);
current_archive = openWithPath(path_to_archive);
}
explicit Handle(std::string path_to_archive_, bool lock_on_reading_, const ReadArchiveFunction & archive_read_function_)
: path_to_archive(std::move(path_to_archive_)), archive_read_function(archive_read_function_), lock_on_reading(lock_on_reading_)
{
read_stream = std::make_unique<StreamInfo>(archive_read_function());
current_archive = openWithReader(read_stream.get());
}
Handle(const Handle &) = delete;
Handle(Handle && other) noexcept
: current_archive(other.current_archive)
: read_stream(std::move(other.read_stream))
, current_archive(other.current_archive)
, current_entry(other.current_entry)
, archive_read_function(std::move(other.archive_read_function))
, lock_on_reading(other.lock_on_reading)
{
other.current_archive = nullptr;
other.current_entry = nullptr;
}
~Handle()
{
close(current_archive);
}
~Handle() { close(current_archive); }
bool locateFile(const std::string & filename)
{
@ -64,10 +85,14 @@ public:
break;
if (filter(archive_entry_pathname(current_entry)))
{
valid = true;
return true;
}
}
checkError(err);
valid = false;
return false;
}
@ -81,17 +106,19 @@ public:
} while (err == ARCHIVE_RETRY);
checkError(err);
return err == ARCHIVE_OK;
valid = err == ARCHIVE_OK;
return valid;
}
std::vector<std::string> getAllFiles(NameFilter filter)
{
auto * archive = open(path_to_archive);
SCOPE_EXIT(
close(archive);
);
std::unique_ptr<LibArchiveReader::StreamInfo> rs
= archive_read_function ? std::make_unique<StreamInfo>(archive_read_function()) : nullptr;
auto * archive = rs ? openWithReader(rs.get()) : openWithPath(path_to_archive);
struct archive_entry * entry = nullptr;
SCOPE_EXIT(close(archive););
Entry entry = nullptr;
std::vector<std::string> files;
int error = readNextHeader(archive, &entry);
@ -112,6 +139,8 @@ public:
const String & getFileName() const
{
chassert(current_entry);
if (!valid)
throw Exception(ErrorCodes::CANNOT_UNPACK_ARCHIVE, "No current file");
if (!file_name)
file_name.emplace(archive_entry_pathname(current_entry));
@ -121,6 +150,8 @@ public:
const FileInfo & getFileInfo() const
{
chassert(current_entry);
if (!valid)
throw Exception(ErrorCodes::CANNOT_UNPACK_ARCHIVE, "No current file");
if (!file_info)
{
file_info.emplace();
@ -132,13 +163,21 @@ public:
return *file_info;
}
struct archive * current_archive;
struct archive_entry * current_entry = nullptr;
la_ssize_t readData(void * buf, size_t len) { return archive_read_data(current_archive, buf, len); }
const char * getArchiveError() { return archive_error_string(current_archive); }
private:
using Archive = struct archive *;
using Entry = struct archive_entry *;
void checkError(int error) const
{
if (error == ARCHIVE_FATAL)
throw Exception(ErrorCodes::CANNOT_UNPACK_ARCHIVE, "Failed to read archive while fetching all files: {}", archive_error_string(current_archive));
throw Exception(
ErrorCodes::CANNOT_UNPACK_ARCHIVE,
"Failed to read archive while fetching all files: {}",
archive_error_string(current_archive));
}
void resetFileInfo()
@ -147,7 +186,7 @@ private:
file_info.reset();
}
static struct archive * open(const String & path_to_archive)
Archive openWithReader(StreamInfo * read_stream_)
{
auto * archive = archive_read_new();
try
@ -158,13 +197,18 @@ private:
archive_read_support_filter_xz(archive);
archive_read_support_filter_lz4(archive);
archive_read_support_filter_zstd(archive);
archive_read_support_filter_lzma(archive);
// Support tar, 7zip and zip
archive_read_support_format_tar(archive);
archive_read_support_format_7zip(archive);
archive_read_support_format_zip(archive);
if (archive_read_open_filename(archive, path_to_archive.c_str(), 10240) != ARCHIVE_OK)
throw Exception(ErrorCodes::CANNOT_UNPACK_ARCHIVE, "Couldn't open archive {}: {}", quoteString(path_to_archive), archive_error_string(archive));
if (archive_read_open(archive, read_stream_, nullptr, StreamInfo::read, nullptr) != ARCHIVE_OK)
throw Exception(
ErrorCodes::CANNOT_UNPACK_ARCHIVE,
"Couldn't open archive {}: {}",
quoteString(path_to_archive),
archive_error_string(archive));
}
catch (...)
{
@ -175,7 +219,39 @@ private:
return archive;
}
static void close(struct archive * archive)
Archive openWithPath(const String & path_to_archive_)
{
auto * archive = archive_read_new();
try
{
// Support for bzip2, gzip, lzip, xz, zstd and lz4
archive_read_support_filter_bzip2(archive);
archive_read_support_filter_gzip(archive);
archive_read_support_filter_xz(archive);
archive_read_support_filter_lz4(archive);
archive_read_support_filter_zstd(archive);
archive_read_support_filter_lzma(archive);
// Support tar, 7zip and zip
archive_read_support_format_tar(archive);
archive_read_support_format_7zip(archive);
archive_read_support_format_zip(archive);
if (archive_read_open_filename(archive, path_to_archive_.c_str(), 10240) != ARCHIVE_OK)
throw Exception(
ErrorCodes::CANNOT_UNPACK_ARCHIVE,
"Couldn't open archive {}: {}",
quoteString(path_to_archive),
archive_error_string(archive));
}
catch (...)
{
close(archive);
throw;
}
return archive;
}
static void close(Archive archive)
{
if (archive)
{
@ -193,7 +269,12 @@ private:
return archive_read_next_header(archive, entry);
}
const String path_to_archive;
String path_to_archive;
std::unique_ptr<StreamInfo> read_stream;
Archive current_archive;
Entry current_entry = nullptr;
bool valid = true;
IArchiveReader::ReadArchiveFunction archive_read_function;
/// for some archive types when we are reading headers static variables are used
/// which are not thread-safe
@ -207,7 +288,7 @@ private:
class LibArchiveReader::FileEnumeratorImpl : public FileEnumerator
{
public:
explicit FileEnumeratorImpl(Handle handle_) : handle(std::move(handle_)) {}
explicit FileEnumeratorImpl(Handle handle_) : handle(std::move(handle_)) { }
const String & getFileName() const override { return handle.getFileName(); }
const FileInfo & getFileInfo() const override { return handle.getFileInfo(); }
@ -215,6 +296,7 @@ public:
/// Releases owned handle to pass it to a read buffer.
Handle releaseHandle() && { return std::move(handle); }
private:
Handle handle;
};
@ -226,36 +308,33 @@ public:
: ReadBufferFromFileBase(DBMS_DEFAULT_BUFFER_SIZE, nullptr, 0)
, handle(std::move(handle_))
, path_to_archive(std::move(path_to_archive_))
{}
{
}
off_t seek(off_t /* off */, int /* whence */) override
{
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Seek is not supported when reading from archive");
}
bool checkIfActuallySeekable() override { return false; }
off_t getPosition() override
{
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "getPosition not supported when reading from archive");
}
off_t getPosition() override { throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "getPosition not supported when reading from archive"); }
String getFileName() const override { return handle.getFileName(); }
size_t getFileSize() override { return handle.getFileInfo().uncompressed_size; }
Handle releaseHandle() &&
{
return std::move(handle);
}
Handle releaseHandle() && { return std::move(handle); }
private:
bool nextImpl() override
{
auto bytes_read = archive_read_data(handle.current_archive, internal_buffer.begin(), static_cast<int>(internal_buffer.size()));
auto bytes_read = handle.readData(internal_buffer.begin(), internal_buffer.size());
if (bytes_read < 0)
throw Exception(ErrorCodes::CANNOT_READ_ALL_DATA, "Failed to read file {} from {}: {}", handle.getFileName(), path_to_archive, archive_error_string(handle.current_archive));
throw Exception(
ErrorCodes::CANNOT_READ_ALL_DATA,
"Failed to read file {} from {}: {}",
handle.getFileName(),
path_to_archive,
handle.getArchiveError());
if (!bytes_read)
return false;
@ -274,7 +353,17 @@ private:
LibArchiveReader::LibArchiveReader(std::string archive_name_, bool lock_on_reading_, std::string path_to_archive_)
: archive_name(std::move(archive_name_)), lock_on_reading(lock_on_reading_), path_to_archive(std::move(path_to_archive_))
{}
{
}
LibArchiveReader::LibArchiveReader(
std::string archive_name_, bool lock_on_reading_, std::string path_to_archive_, const ReadArchiveFunction & archive_read_function_)
: archive_name(std::move(archive_name_))
, lock_on_reading(lock_on_reading_)
, path_to_archive(std::move(path_to_archive_))
, archive_read_function(archive_read_function_)
{
}
LibArchiveReader::~LibArchiveReader() = default;
@ -285,21 +374,25 @@ const std::string & LibArchiveReader::getPath() const
bool LibArchiveReader::fileExists(const String & filename)
{
Handle handle(path_to_archive, lock_on_reading);
Handle handle = acquireHandle();
return handle.locateFile(filename);
}
LibArchiveReader::FileInfo LibArchiveReader::getFileInfo(const String & filename)
{
Handle handle(path_to_archive, lock_on_reading);
Handle handle = acquireHandle();
if (!handle.locateFile(filename))
throw Exception(ErrorCodes::CANNOT_UNPACK_ARCHIVE, "Couldn't unpack archive {}: file not found", path_to_archive);
throw Exception(
ErrorCodes::CANNOT_UNPACK_ARCHIVE,
"Couldn't unpack archive {}: File {} was not found in archive",
path_to_archive,
quoteString(filename));
return handle.getFileInfo();
}
std::unique_ptr<LibArchiveReader::FileEnumerator> LibArchiveReader::firstFile()
{
Handle handle(path_to_archive, lock_on_reading);
Handle handle = acquireHandle();
if (!handle.nextFile())
return nullptr;
@ -308,17 +401,28 @@ std::unique_ptr<LibArchiveReader::FileEnumerator> LibArchiveReader::firstFile()
std::unique_ptr<ReadBufferFromFileBase> LibArchiveReader::readFile(const String & filename, bool throw_on_not_found)
{
return readFile([&](const std::string & file) { return file == filename; }, throw_on_not_found);
Handle handle = acquireHandle();
if (!handle.locateFile(filename))
{
if (throw_on_not_found)
throw Exception(
ErrorCodes::CANNOT_UNPACK_ARCHIVE,
"Couldn't unpack archive {}: File {} was not found in archive",
path_to_archive,
quoteString(filename));
return nullptr;
}
return std::make_unique<ReadBufferFromLibArchive>(std::move(handle), path_to_archive);
}
std::unique_ptr<ReadBufferFromFileBase> LibArchiveReader::readFile(NameFilter filter, bool throw_on_not_found)
{
Handle handle(path_to_archive, lock_on_reading);
Handle handle = acquireHandle();
if (!handle.locateFile(filter))
{
if (throw_on_not_found)
throw Exception(
ErrorCodes::CANNOT_UNPACK_ARCHIVE, "Couldn't unpack archive {}: no file found satisfying the filter", path_to_archive);
ErrorCodes::CANNOT_UNPACK_ARCHIVE, "Couldn't unpack archive {}: No file satisfying filter in archive", path_to_archive);
return nullptr;
}
return std::make_unique<ReadBufferFromLibArchive>(std::move(handle), path_to_archive);
@ -337,7 +441,8 @@ std::unique_ptr<LibArchiveReader::FileEnumerator> LibArchiveReader::nextFile(std
{
if (!dynamic_cast<ReadBufferFromLibArchive *>(read_buffer.get()))
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong ReadBuffer passed to nextFile()");
auto read_buffer_from_libarchive = std::unique_ptr<ReadBufferFromLibArchive>(static_cast<ReadBufferFromLibArchive *>(read_buffer.release()));
auto read_buffer_from_libarchive
= std::unique_ptr<ReadBufferFromLibArchive>(static_cast<ReadBufferFromLibArchive *>(read_buffer.release()));
auto handle = std::move(*read_buffer_from_libarchive).releaseHandle();
if (!handle.nextFile())
return nullptr;
@ -348,7 +453,8 @@ std::unique_ptr<LibArchiveReader::FileEnumerator> LibArchiveReader::currentFile(
{
if (!dynamic_cast<ReadBufferFromLibArchive *>(read_buffer.get()))
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong ReadBuffer passed to nextFile()");
auto read_buffer_from_libarchive = std::unique_ptr<ReadBufferFromLibArchive>(static_cast<ReadBufferFromLibArchive *>(read_buffer.release()));
auto read_buffer_from_libarchive
= std::unique_ptr<ReadBufferFromLibArchive>(static_cast<ReadBufferFromLibArchive *>(read_buffer.release()));
auto handle = std::move(*read_buffer_from_libarchive).releaseHandle();
return std::make_unique<FileEnumeratorImpl>(std::move(handle));
}
@ -360,13 +466,22 @@ std::vector<std::string> LibArchiveReader::getAllFiles()
std::vector<std::string> LibArchiveReader::getAllFiles(NameFilter filter)
{
Handle handle(path_to_archive, lock_on_reading);
Handle handle = acquireHandle();
return handle.getAllFiles(filter);
}
void LibArchiveReader::setPassword(const String & /*password_*/)
void LibArchiveReader::setPassword(const String & password_)
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "Can not set password to {} archive", archive_name);
if (password_.empty())
return;
throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot set password to {} archive", archive_name);
}
LibArchiveReader::Handle LibArchiveReader::acquireHandle()
{
std::lock_guard lock{mutex};
return archive_read_function ? Handle{path_to_archive, lock_on_reading, archive_read_function}
: Handle{path_to_archive, lock_on_reading};
}
#endif

View File

@ -1,8 +1,9 @@
#pragma once
#include "config.h"
#include <mutex>
#include <IO/Archives/IArchiveReader.h>
#include <IO/Archives/LibArchiveReader.h>
#include "config.h"
namespace DB
@ -52,26 +53,44 @@ protected:
/// Constructs an archive's reader that will read from a file in the local filesystem.
LibArchiveReader(std::string archive_name_, bool lock_on_reading_, std::string path_to_archive_);
LibArchiveReader(
std::string archive_name_, bool lock_on_reading_, std::string path_to_archive_, const ReadArchiveFunction & archive_read_function_);
private:
class ReadBufferFromLibArchive;
class Handle;
class FileEnumeratorImpl;
class StreamInfo;
Handle acquireHandle();
const std::string archive_name;
const bool lock_on_reading;
const String path_to_archive;
const ReadArchiveFunction archive_read_function;
mutable std::mutex mutex;
};
class TarArchiveReader : public LibArchiveReader
{
public:
explicit TarArchiveReader(std::string path_to_archive) : LibArchiveReader("tar", /*lock_on_reading_=*/ true, std::move(path_to_archive)) { }
explicit TarArchiveReader(std::string path_to_archive) : LibArchiveReader("tar", /*lock_on_reading_=*/true, std::move(path_to_archive))
{
}
explicit TarArchiveReader(std::string path_to_archive, const ReadArchiveFunction & archive_read_function)
: LibArchiveReader("tar", /*lock_on_reading_=*/true, std::move(path_to_archive), archive_read_function)
{
}
};
class SevenZipArchiveReader : public LibArchiveReader
{
public:
explicit SevenZipArchiveReader(std::string path_to_archive) : LibArchiveReader("7z", /*lock_on_reading_=*/ false, std::move(path_to_archive)) { }
explicit SevenZipArchiveReader(std::string path_to_archive)
: LibArchiveReader("7z", /*lock_on_reading_=*/false, std::move(path_to_archive))
{
}
};
#endif
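A rough usage sketch of the new streaming constructor above (hypothetical caller code; it assumes ReadArchiveFunction returns a std::unique_ptr<SeekableReadBuffer> and uses ReadBufferFromFile purely as an illustrative source):
// Read a tar archive through a caller-supplied buffer factory
// instead of letting libarchive open the path itself.
auto reader = std::make_shared<TarArchiveReader>(
    "backup.tar",
    [] { return std::make_unique<ReadBufferFromFile>("backup.tar"); });
auto in = reader->readFile("part.bin", /*throw_on_not_found=*/ true);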

View File

@ -0,0 +1,248 @@
#include <IO/Archives/LibArchiveWriter.h>
#include <filesystem>
#include <IO/WriteBufferFromFileBase.h>
#include <Common/quoteString.h>
#include <Common/scope_guard_safe.h>
#include <mutex>
#if USE_LIBARCHIVE
// This implementation follows the ZipArchiveWriter implementation as closely as possible.
namespace DB
{
namespace ErrorCodes
{
extern const int CANNOT_PACK_ARCHIVE;
extern const int NOT_IMPLEMENTED;
}
namespace
{
void checkResultCodeImpl(int code, const String & filename)
{
if (code == ARCHIVE_OK)
return;
throw Exception(
ErrorCodes::CANNOT_PACK_ARCHIVE, "Couldn't pack archive: LibArchive Code = {}, filename={}", code, quoteString(filename));
}
}
// This is a thin wrapper that allows libarchive to write the archive to a WriteBuffer.
class LibArchiveWriter::StreamInfo
{
public:
explicit StreamInfo(std::unique_ptr<WriteBuffer> archive_write_buffer_) : archive_write_buffer(std::move(archive_write_buffer_)) { }
static ssize_t memory_write(struct archive *, void * client_data, const void * buff, size_t length)
{
auto * stream_info = reinterpret_cast<StreamInfo *>(client_data);
stream_info->archive_write_buffer->write(reinterpret_cast<const char *>(buff), length);
return length;
}
std::unique_ptr<WriteBuffer> archive_write_buffer;
};
class LibArchiveWriter::WriteBufferFromLibArchive : public WriteBufferFromFileBase
{
public:
WriteBufferFromLibArchive(std::shared_ptr<LibArchiveWriter> archive_writer_, const String & filename_, const size_t & size_)
: WriteBufferFromFileBase(DBMS_DEFAULT_BUFFER_SIZE, nullptr, 0), archive_writer(archive_writer_), filename(filename_), size(size_)
{
startWritingFile();
archive = archive_writer_->getArchive();
entry = nullptr;
}
~WriteBufferFromLibArchive() override
{
try
{
closeFile(/* throw_if_error= */ false);
endWritingFile();
}
catch (...)
{
tryLogCurrentException("WriteBufferFromLibArchive");
}
}
void finalizeImpl() override
{
next();
closeFile(/* throw_if_error=*/true);
endWritingFile();
}
void sync() override { next(); }
std::string getFileName() const override { return filename; }
private:
void nextImpl() override
{
if (!offset())
return;
if (entry == nullptr)
writeEntry();
ssize_t to_write = offset();
ssize_t written = archive_write_data(archive, working_buffer.begin(), offset());
if (written != to_write)
{
throw Exception(
ErrorCodes::CANNOT_PACK_ARCHIVE,
"Couldn't pack tar archive: Failed to write all bytes, {} of {}, filename={}",
written,
to_write,
quoteString(filename));
}
}
void writeEntry()
{
expected_size = getSize();
entry = archive_entry_new();
archive_entry_set_pathname(entry, filename.c_str());
archive_entry_set_size(entry, expected_size);
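// 0100000 is the regular-file type bit (S_IFREG, libarchive's AE_IFREG)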
archive_entry_set_filetype(entry, static_cast<__LA_MODE_T>(0100000));
archive_entry_set_perm(entry, 0644);
checkResult(archive_write_header(archive, entry));
}
size_t getSize() const
{
if (size)
return size;
else
return offset();
}
void closeFile(bool throw_if_error)
{
if (entry)
{
archive_entry_free(entry);
entry = nullptr;
}
if (throw_if_error and bytes != expected_size)
{
throw Exception(
ErrorCodes::CANNOT_PACK_ARCHIVE,
"Couldn't pack tar archive: Wrote {} of expected {} , filename={}",
bytes,
expected_size,
quoteString(filename));
}
}
void endWritingFile()
{
if (auto archive_writer_ptr = archive_writer.lock())
archive_writer_ptr->endWritingFile();
}
void startWritingFile()
{
if (auto archive_writer_ptr = archive_writer.lock())
archive_writer_ptr->startWritingFile();
}
void checkResult(int code) { checkResultCodeImpl(code, filename); }
std::weak_ptr<LibArchiveWriter> archive_writer;
const String filename;
Entry entry;
Archive archive;
size_t size;
size_t expected_size;
};
LibArchiveWriter::LibArchiveWriter(const String & path_to_archive_, std::unique_ptr<WriteBuffer> archive_write_buffer_)
: path_to_archive(path_to_archive_)
{
if (archive_write_buffer_)
stream_info = std::make_unique<StreamInfo>(std::move(archive_write_buffer_));
}
void LibArchiveWriter::createArchive()
{
std::lock_guard lock{mutex};
archive = archive_write_new();
setFormatAndSettings();
if (stream_info)
{
// This allows us to write directly to a WriteBuffer rather than to an intermediate buffer inside libarchive.
// This has to be set, otherwise zstd breaks due to extra bytes being written at the end of the archive.
archive_write_set_bytes_per_block(archive, 0);
archive_write_open2(archive, stream_info.get(), nullptr, &StreamInfo::memory_write, nullptr, nullptr);
}
else
archive_write_open_filename(archive, path_to_archive.c_str());
}
LibArchiveWriter::~LibArchiveWriter()
{
chassert((finalized || std::uncaught_exceptions() || std::current_exception()) && "LibArchiveWriter is not finalized in destructor.");
if (archive)
archive_write_free(archive);
}
std::unique_ptr<WriteBufferFromFileBase> LibArchiveWriter::writeFile(const String & filename, size_t size)
{
return std::make_unique<WriteBufferFromLibArchive>(std::static_pointer_cast<LibArchiveWriter>(shared_from_this()), filename, size);
}
std::unique_ptr<WriteBufferFromFileBase> LibArchiveWriter::writeFile(const String & filename)
{
return std::make_unique<WriteBufferFromLibArchive>(std::static_pointer_cast<LibArchiveWriter>(shared_from_this()), filename, 0);
}
bool LibArchiveWriter::isWritingFile() const
{
std::lock_guard lock{mutex};
return is_writing_file;
}
void LibArchiveWriter::endWritingFile()
{
std::lock_guard lock{mutex};
is_writing_file = false;
}
void LibArchiveWriter::startWritingFile()
{
std::lock_guard lock{mutex};
if (std::exchange(is_writing_file, true))
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Cannot write two files to a tar archive in parallel");
}
void LibArchiveWriter::finalize()
{
std::lock_guard lock{mutex};
if (finalized)
return;
if (archive)
archive_write_close(archive);
if (stream_info)
{
stream_info->archive_write_buffer->finalize();
stream_info.reset();
}
finalized = true;
}
void LibArchiveWriter::setPassword(const String & password_)
{
if (password_.empty())
return;
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Setting a password is not currently supported for libarchive");
}
LibArchiveWriter::Archive LibArchiveWriter::getArchive()
{
std::lock_guard lock{mutex};
return archive;
}
}
#endif

View File

@ -0,0 +1,77 @@
#pragma once
#include "config.h"
#if USE_LIBARCHIVE
# include <IO/Archives/ArchiveUtils.h>
# include <IO/Archives/IArchiveWriter.h>
# include <IO/WriteBufferFromFileBase.h>
# include <base/defines.h>
namespace DB
{
class WriteBufferFromFileBase;
/// Interface for writing an archive.
class LibArchiveWriter : public IArchiveWriter
{
public:
/// Constructs an archive that will be written as a file in the local filesystem.
explicit LibArchiveWriter(const String & path_to_archive_, std::unique_ptr<WriteBuffer> archive_write_buffer_);
/// Call finalize() before destructing IArchiveWriter.
~LibArchiveWriter() override;
/// Starts writing a file to the archive. The function returns a write buffer,
/// any data written to that buffer will be compressed and then put to the archive.
/// You can keep only one such buffer at a time, a buffer returned by previous call
/// of the function `writeFile()` should be destroyed before next call of `writeFile()`.
std::unique_ptr<WriteBufferFromFileBase> writeFile(const String & filename) override;
/// LibArchive needs to know the size of the file being written. If the file size is not
/// passed in, the archive writer tries to infer it from the data available in the buffer;
/// if next() is called before all data has been written to the buffer, an exception
/// is thrown (see the sketch at the end of this header).
std::unique_ptr<WriteBufferFromFileBase> writeFile(const String & filename, size_t size) override;
/// Returns true if there is an active instance of WriteBuffer returned by writeFile().
/// This function should be used mostly for debugging purposes.
bool isWritingFile() const override;
/// Finalizes writing of the archive. This function must be always called at the end of writing.
/// (Unless an error appeared and the archive is in fact no longer needed.)
void finalize() override;
/// Sets compression method and level.
/// Changing them will affect next file in the archive.
//void setCompression(const String & compression_method_, int compression_level_) override;
/// Sets password. If the password is not empty it will enable encryption in the archive.
void setPassword(const String & password) override;
protected:
using Archive = struct archive *;
using Entry = struct archive_entry *;
/// Derived classes must call createArchive(); createArchive() calls setFormatAndSettings().
void createArchive();
virtual void setFormatAndSettings() = 0;
Archive archive = nullptr;
String path_to_archive;
private:
class WriteBufferFromLibArchive;
class StreamInfo;
Archive getArchive();
void startWritingFile();
void endWritingFile();
std::unique_ptr<StreamInfo> stream_info TSA_GUARDED_BY(mutex) = nullptr;
bool is_writing_file TSA_GUARDED_BY(mutex) = false;
bool finalized TSA_GUARDED_BY(mutex) = false;
mutable std::mutex mutex;
};
}
#endif
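A rough sketch of the writeFile() size requirement noted above (hypothetical caller code; TarArchiveWriter is the concrete subclass added later in this diff, `payload` is a std::string):
// When the size is known it should be passed up front; otherwise the writer
// can only infer it from whatever was buffered before the first next().
auto writer = std::make_shared<TarArchiveWriter>("backup.tar", /*archive_write_buffer_=*/ nullptr);
{
    auto out = writer->writeFile("part.bin", payload.size());
    out->write(payload.data(), payload.size());
    out->finalize();
}
writer->finalize();   // must always be called at the end of writing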

View File

@ -0,0 +1,42 @@
#include <IO/Archives/TarArchiveWriter.h>
#if USE_LIBARCHIVE
namespace DB
{
namespace ErrorCodes
{
extern const int NOT_IMPLEMENTED;
extern const int CANNOT_PACK_ARCHIVE;
}
void TarArchiveWriter::setCompression(const String & compression_method_, int compression_level_)
{
// throw an error unless setCompression is passed the default value
if (compression_method_.empty() && compression_level_ == -1)
return;
throw Exception(
ErrorCodes::NOT_IMPLEMENTED, "Using the compression_method and compression_level options is not supported for tar archives");
}
void TarArchiveWriter::setFormatAndSettings()
{
archive_write_set_format_pax_restricted(archive);
inferCompressionFromPath();
}
void TarArchiveWriter::inferCompressionFromPath()
{
if (path_to_archive.ends_with(".tar.gz") || path_to_archive.ends_with(".tgz"))
archive_write_add_filter_gzip(archive);
else if (path_to_archive.ends_with(".tar.bz2"))
archive_write_add_filter_bzip2(archive);
else if (path_to_archive.ends_with(".tar.lzma"))
archive_write_add_filter_lzma(archive);
else if (path_to_archive.ends_with(".tar.zst") || path_to_archive.ends_with(".tzst"))
archive_write_add_filter_zstd(archive);
else if (path_to_archive.ends_with(".tar.xz"))
archive_write_add_filter_xz(archive);
else if (!path_to_archive.ends_with(".tar"))
throw Exception(ErrorCodes::CANNOT_PACK_ARCHIVE, "Unknown compression format");
}
}
#endif

View File

@ -0,0 +1,26 @@
#pragma once
#include "config.h"
#if USE_LIBARCHIVE
# include <IO/Archives/LibArchiveWriter.h>
namespace DB
{
using namespace std::literals;
class TarArchiveWriter : public LibArchiveWriter
{
public:
explicit TarArchiveWriter(const String & path_to_archive_, std::unique_ptr<WriteBuffer> archive_write_buffer_)
: LibArchiveWriter(path_to_archive_, std::move(archive_write_buffer_))
{
createArchive();
}
void setCompression(const String & compression_method_, int compression_level_) override;
void setFormatAndSettings() override;
void inferCompressionFromPath();
};
}
#endif

View File

@ -274,6 +274,11 @@ std::unique_ptr<WriteBufferFromFileBase> ZipArchiveWriter::writeFile(const Strin
return std::make_unique<WriteBufferFromZipArchive>(std::static_pointer_cast<ZipArchiveWriter>(shared_from_this()), filename);
}
std::unique_ptr<WriteBufferFromFileBase> ZipArchiveWriter::writeFile(const String & filename, [[maybe_unused]] size_t size)
{
return ZipArchiveWriter::writeFile(filename);
}
bool ZipArchiveWriter::isWritingFile() const
{
std::lock_guard lock{mutex};

View File

@ -32,6 +32,9 @@ public:
/// of the function `writeFile()` should be destroyed before next call of `writeFile()`.
std::unique_ptr<WriteBufferFromFileBase> writeFile(const String & filename) override;
std::unique_ptr<WriteBufferFromFileBase> writeFile(const String & filename, size_t size) override;
/// Returns true if there is an active instance of WriteBuffer returned by writeFile().
/// This function should be used mostly for debugging purposes.
bool isWritingFile() const override;

Some files were not shown because too many files have changed in this diff