This commit is contained in:
Danila Puzov 2024-05-22 18:59:39 +03:00
parent e8d66bf4d7
commit 332f449a0c
7 changed files with 252 additions and 157 deletions

View File

@ -5,6 +5,7 @@
#include <Core/ServerUUID.h>
#include <Common/Logger.h>
#include <Common/logger_useful.h>
#include "base/types.h"
namespace DB
@ -34,43 +35,153 @@ namespace
- The last 12 bits are a counter to disambiguate multiple snowflakeIDs generated within the same millisecond by differen processes
*/
/// bit counts
constexpr auto timestamp_bits_count = 41;
constexpr auto machine_id_bits_count = 10;
constexpr auto machine_seq_num_bits_count = 12;
constexpr int64_t timestamp_mask = ((1LL << timestamp_bits_count) - 1) << (machine_id_bits_count + machine_seq_num_bits_count);
constexpr int64_t machine_id_mask = ((1LL << machine_id_bits_count) - 1) << machine_seq_num_bits_count;
constexpr int64_t machine_seq_num_mask = (1LL << machine_seq_num_bits_count) - 1;
constexpr int64_t max_machine_seq_num = machine_seq_num_mask;
/// bits masks for Snowflake ID components
// constexpr uint64_t timestamp_mask = ((1ULL << timestamp_bits_count) - 1) << (machine_id_bits_count + machine_seq_num_bits_count); // unused
constexpr uint64_t machine_id_mask = ((1ULL << machine_id_bits_count) - 1) << machine_seq_num_bits_count;
constexpr uint64_t machine_seq_num_mask = (1ULL << machine_seq_num_bits_count) - 1;
Int64 getMachineID()
/// max values
constexpr uint64_t max_machine_seq_num = machine_seq_num_mask;
uint64_t getMachineID()
{
UUID server_uuid = ServerUUID::get();
/// hash into 64 bits
UInt64 hi = UUIDHelpers::getHighBytes(server_uuid);
UInt64 lo = UUIDHelpers::getLowBytes(server_uuid);
return ((hi * 11) ^ (lo * 17)) & machine_id_mask;
uint64_t hi = UUIDHelpers::getHighBytes(server_uuid);
uint64_t lo = UUIDHelpers::getLowBytes(server_uuid);
/// return only 10 bits
return (((hi * 11) ^ (lo * 17)) & machine_id_mask) >> machine_seq_num_bits_count;
}
Int64 getTimestamp()
uint64_t getTimestamp()
{
auto now = std::chrono::system_clock::now();
auto ticks_since_epoch = std::chrono::duration_cast<std::chrono::milliseconds>(now.time_since_epoch()).count();
return ticks_since_epoch & ((1LL << timestamp_bits_count) - 1);
return static_cast<uint64_t>(ticks_since_epoch) & ((1ULL << timestamp_bits_count) - 1);
}
struct SnowflakeComponents {
uint64_t timestamp;
uint64_t machind_id;
uint64_t machine_seq_num;
};
SnowflakeComponents toComponents(uint64_t snowflake) {
return {
.timestamp = (snowflake >> (machine_id_bits_count + machine_seq_num_bits_count)),
.machind_id = ((snowflake & machine_id_mask) >> machine_seq_num_bits_count),
.machine_seq_num = (snowflake & machine_seq_num_mask)
};
}
class FunctionSnowflakeID : public IFunction
uint64_t toSnowflakeID(SnowflakeComponents components) {
return (components.timestamp << (machine_id_bits_count + machine_seq_num_bits_count) |
components.machind_id << (machine_seq_num_bits_count) |
components.machine_seq_num);
}
struct RangeOfSnowflakeIDs {
/// [begin, end)
SnowflakeComponents begin, end;
};
/* Get range of `input_rows_count` Snowflake IDs from `max(available, now)`
1. Calculate Snowflake ID by current timestamp (`now`)
2. `begin = max(available, now)`
3. Calculate `end = begin + input_rows_count` handling `machine_seq_num` overflow
*/
RangeOfSnowflakeIDs getRangeOfAvailableIDs(const SnowflakeComponents& available, size_t input_rows_count)
{
private:
mutable std::atomic<Int64> lowest_available_snowflake_id = 0; /// atomic to avoid a mutex
/// 1. `now`
SnowflakeComponents begin = {
.timestamp = getTimestamp(),
.machind_id = getMachineID(),
.machine_seq_num = 0
};
public:
/// 2. `begin`
if (begin.timestamp <= available.timestamp)
{
begin.timestamp = available.timestamp;
begin.machine_seq_num = available.machine_seq_num;
}
/// 3. `end = begin + input_rows_count`
SnowflakeComponents end;
const uint64_t seq_nums_in_current_timestamp_left = (max_machine_seq_num - begin.machine_seq_num + 1);
if (input_rows_count >= seq_nums_in_current_timestamp_left)
/// if sequence numbers in current timestamp is not enough for rows => update timestamp
end.timestamp = begin.timestamp + 1 + (input_rows_count - seq_nums_in_current_timestamp_left) / (max_machine_seq_num + 1);
else
end.timestamp = begin.timestamp;
end.machind_id = begin.machind_id;
end.machine_seq_num = (begin.machine_seq_num + input_rows_count) & machine_seq_num_mask;
return {begin, end};
}
struct GlobalCounterPolicy
{
static constexpr auto name = "generateSnowflakeID";
static FunctionPtr create(ContextPtr /*context*/) { return std::make_shared<FunctionSnowflakeID>(); }
static constexpr auto doc_description = R"(Generates a Snowflake ID. The generated Snowflake ID contains the current Unix timestamp in milliseconds 41 (+ 1 top zero bit) bits, followed by machine id (10 bits), a counter (12 bits) to distinguish IDs within a millisecond. For any given timestamp (unix_ts_ms), the counter starts at 0 and is incremented by 1 for each new Snowflake ID until the timestamp changes. In case the counter overflows, the timestamp field is incremented by 1 and the counter is reset to 0. Function generateSnowflakeID guarantees that the counter field within a timestamp increments monotonically across all function invocations in concurrently running threads and queries.)";
String getName() const override { return name; }
/// Guarantee counter monotonicity within one timestamp across all threads generating Snowflake IDs simultaneously.
struct Data
{
static inline std::atomic<uint64_t> lowest_available_snowflake_id = 0;
SnowflakeComponents reserveRange(size_t input_rows_count)
{
uint64_t available_snowflake_id = lowest_available_snowflake_id.load();
RangeOfSnowflakeIDs range;
do
{
range = getRangeOfAvailableIDs(toComponents(available_snowflake_id), input_rows_count);
}
while (!lowest_available_snowflake_id.compare_exchange_weak(available_snowflake_id, toSnowflakeID(range.end)));
/// if `compare_exhange` failed => another thread updated `lowest_available_snowflake_id` and we should try again
/// completed => range of IDs [begin, end) is reserved, can return the beginning of the range
return range.begin;
}
};
};
struct ThreadLocalCounterPolicy
{
static constexpr auto name = "generateSnowflakeIDThreadMonotonic";
static constexpr auto doc_description = R"(Generates a Snowflake ID. The generated Snowflake ID contains the current Unix timestamp in milliseconds 41 (+ 1 top zero bit) bits, followed by machine id (10 bits), a counter (12 bits) to distinguish IDs within a millisecond. For any given timestamp (unix_ts_ms), the counter starts at 0 and is incremented by 1 for each new Snowflake ID until the timestamp changes. In case the counter overflows, the timestamp field is incremented by 1 and the counter is reset to 0. This function behaves like generateSnowflakeID but gives no guarantee on counter monotony across different simultaneous requests. Monotonicity within one timestamp is guaranteed only within the same thread calling this function to generate Snowflake IDs.)";
/// Guarantee counter monotonicity within one timestamp within the same thread. Faster than GlobalCounterPolicy if a query uses multiple threads.
struct Data
{
static inline thread_local uint64_t lowest_available_snowflake_id = 0;
SnowflakeComponents reserveRange(size_t input_rows_count)
{
RangeOfSnowflakeIDs range = getRangeOfAvailableIDs(toComponents(lowest_available_snowflake_id), input_rows_count);
lowest_available_snowflake_id = toSnowflakeID(range.end);
return range.begin;
}
};
};
}
template <typename FillPolicy>
class FunctionGenerateSnowflakeID : public IFunction, public FillPolicy
{
public:
static FunctionPtr create(ContextPtr /*context*/) { return std::make_shared<FunctionGenerateSnowflakeID>(); }
String getName() const override { return FillPolicy::name; }
size_t getNumberOfArguments() const override { return 0; }
bool isDeterministic() const override { return false; }
bool isDeterministicInScopeOfQuery() const override { return false; }
@ -80,71 +191,36 @@ public:
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if (!arguments.empty()) {
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Number of arguments for function {} doesn't match: passed {}, should be 0.",
getName(), arguments.size());
}
return std::make_shared<DataTypeInt64>();
FunctionArgumentDescriptors mandatory_args;
FunctionArgumentDescriptors optional_args{
{"expr", nullptr, nullptr, "Arbitrary Expression"}
};
validateFunctionArgumentTypes(*this, arguments, mandatory_args, optional_args);
return std::make_shared<DataTypeUInt64>();
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & /*arguments*/, const DataTypePtr &, size_t input_rows_count) const override
{
auto col_res = ColumnVector<Int64>::create();
typename ColumnVector<Int64>::Container & vec_to = col_res->getData();
auto col_res = ColumnVector<UInt64>::create();
typename ColumnVector<UInt64>::Container & vec_to = col_res->getData();
vec_to.resize(input_rows_count);
if (input_rows_count == 0) {
return col_res;
}
const Int64 machine_id = getMachineID();
Int64 current_timestamp = getTimestamp();
Int64 current_machine_seq_num;
Int64 available_snowflake_id, next_available_snowflake_id;
const Int64 input_rows_count_signed = static_cast<Int64>(input_rows_count);
do
if (input_rows_count != 0)
{
available_snowflake_id = lowest_available_snowflake_id.load();
const Int64 available_timestamp = (available_snowflake_id & timestamp_mask) >> (machine_id_bits_count + machine_seq_num_bits_count);
const Int64 available_machine_seq_num = available_snowflake_id & machine_seq_num_mask;
typename FillPolicy::Data data;
/// get the begin of available snowflake ids range
SnowflakeComponents snowflake_id = data.reserveRange(input_rows_count);
if (current_timestamp > available_timestamp)
for (UInt64 & to_row : vec_to)
{
/// handle overflow
current_machine_seq_num = 0;
}
else
{
current_timestamp = available_timestamp;
current_machine_seq_num = available_machine_seq_num;
}
/// calculate new lowest_available_snowflake_id
const Int64 seq_nums_in_current_timestamp_left = (max_machine_seq_num - current_machine_seq_num + 1);
Int64 new_timestamp;
if (input_rows_count_signed >= seq_nums_in_current_timestamp_left)
new_timestamp = current_timestamp + 1 + (input_rows_count_signed - seq_nums_in_current_timestamp_left) / max_machine_seq_num;
else
new_timestamp = current_timestamp;
const Int64 new_machine_seq_num = (current_machine_seq_num + input_rows_count_signed) & machine_seq_num_mask;
next_available_snowflake_id = (new_timestamp << (machine_id_bits_count + machine_seq_num_bits_count)) | machine_id | new_machine_seq_num;
}
while (!lowest_available_snowflake_id.compare_exchange_strong(available_snowflake_id, next_available_snowflake_id));
/// failed CAS => another thread updated `lowest_available_snowflake_id`
/// successful CAS => we have our range of exclusive values
for (Int64 & to_row : vec_to)
{
to_row = (current_timestamp << (machine_id_bits_count + machine_seq_num_bits_count)) | machine_id | current_machine_seq_num;
if (current_machine_seq_num++ == max_machine_seq_num)
{
current_machine_seq_num = 0;
++current_timestamp;
to_row = toSnowflakeID(snowflake_id);
if (snowflake_id.machine_seq_num++ == max_machine_seq_num)
{
snowflake_id.machine_seq_num = 0;
++snowflake_id.timestamp;
}
}
}
@ -153,43 +229,27 @@ public:
};
template<typename FillPolicy>
void registerSnowflakeIDGenerator(auto& factory)
{
static constexpr auto doc_syntax_format = "{}([expression])";
static constexpr auto example_format = "SELECT {}()";
static constexpr auto multiple_example_format = "SELECT {f}(1), {f}(2)";
FunctionDocumentation::Description doc_description = FillPolicy::doc_description;
FunctionDocumentation::Syntax doc_syntax = fmt::format(doc_syntax_format, FillPolicy::name);
FunctionDocumentation::Arguments doc_arguments = {{"expression", "The expression is used to bypass common subexpression elimination if the function is called multiple times in a query but otherwise ignored. Optional."}};
FunctionDocumentation::ReturnedValue doc_returned_value = "A value of type UInt64";
FunctionDocumentation::Examples doc_examples = {{"uuid", fmt::format(example_format, FillPolicy::name), ""}, {"multiple", fmt::format(multiple_example_format, fmt::arg("f", FillPolicy::name)), ""}};
FunctionDocumentation::Categories doc_categories = {"Snowflake ID"};
factory.template registerFunction<FunctionGenerateSnowflakeID<FillPolicy>>({doc_description, doc_syntax, doc_arguments, doc_returned_value, doc_examples, doc_categories}, FunctionFactory::CaseInsensitive);
}
REGISTER_FUNCTION(GenerateSnowflakeID)
{
factory.registerFunction<FunctionSnowflakeID>(FunctionDocumentation
{
.description=R"(
Generates a SnowflakeID -- unique identificators contains:
- The first 41 (+ 1 top zero bit) bits is the timestamp (millisecond since Unix epoch 1 Jan 1970)
- The middle 10 bits are the machine ID
- The last 12 bits are a counter to disambiguate multiple snowflakeIDs generated within the same millisecond by differen processes
In case the number of ids processed overflows, the timestamp field is incremented by 1 and the counter is reset to 0.
This function guarantees strict monotony on 1 machine and differences in values obtained on different machines.
)",
.syntax = "generateSnowflakeID()",
.arguments{},
.returned_value = "Column of Int64",
.examples{
{"single call", "SELECT generateSnowflakeID();", R"(
generateSnowflakeID()
7195510166884597760
)"},
{"column call", "SELECT generateSnowflakeID() FROM numbers(10);", R"(
generateSnowflakeID()
7195516038159417344
7195516038159417345
7195516038159417346
7195516038159417347
7195516038159417348
7195516038159417349
7195516038159417350
7195516038159417351
7195516038159417352
7195516038159417353
)"},
},
.categories{"Unique identifiers", "Snowflake ID"}
});
registerSnowflakeIDGenerator<GlobalCounterPolicy>(factory);
registerSnowflakeIDGenerator<ThreadLocalCounterPolicy>(factory);
}
}

View File

@ -1,9 +1,12 @@
#include "Common/Exception.h"
#include <Common/ZooKeeper/ZooKeeper.h>
#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Interpreters/Context.h>
namespace DB
{
@ -14,6 +17,9 @@ namespace ErrorCodes
extern const int KEEPER_EXCEPTION;
}
constexpr auto function_node_name = "/serial_ids/";
constexpr size_t MAX_SERIES_NUMBER = 1000; // ?
class FunctionSerial : public IFunction
{
private:
@ -21,7 +27,7 @@ private:
ContextPtr context;
public:
static constexpr auto name = "serial";
static constexpr auto name = "generateSerialID";
explicit FunctionSerial(ContextPtr context_) : context(context_)
{
@ -48,16 +54,12 @@ public:
bool hasInformationAboutMonotonicity() const override { return true; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if (arguments.size() != 1)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Number of arguments for function {} doesn't match: passed {}, should be 1.",
getName(), arguments.size());
if (!isStringOrFixedString(arguments[0]))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Type of argument for function {} doesn't match: passed {}, should be string",
getName(), arguments[0]->getName());
FunctionArgumentDescriptors mandatory_args{
{"series identifier", static_cast<FunctionArgumentDescriptor::TypeValidator>(&isStringOrFixedString), nullptr, "String or FixedString"}
};
validateFunctionArgumentTypes(*this, arguments, mandatory_args);
return std::make_shared<DataTypeInt64>();
}
@ -71,12 +73,19 @@ public:
if (zk->expired())
zk = context->getZooKeeper();
// slow?
if (zk->exists(function_node_name) && zk->getChildren(function_node_name).size() == MAX_SERIES_NUMBER) {
throw Exception(ErrorCodes::KEEPER_EXCEPTION,
"At most {} serial nodes can be created",
MAX_SERIES_NUMBER);
}
auto col_res = ColumnVector<Int64>::create();
typename ColumnVector<Int64>::Container & vec_to = col_res->getData();
vec_to.resize(input_rows_count);
const auto & serial_path = "/serials/" + arguments[0].column->getDataAt(0).toString();
const auto & serial_path = function_node_name + arguments[0].column->getDataAt(0).toString();
/// CAS in ZooKeeper
/// `get` value and version, `trySet` new with version check
@ -130,28 +139,28 @@ Generates and returns sequential numbers starting from the previous counter valu
This function takes a constant string argument - a series identifier.
The server should be configured with a ZooKeeper.
)",
.syntax = "serial(identifier)",
.syntax = "generateSerialID(identifier)",
.arguments{
{"series identifier", "Series identifier (String)"}
{"series identifier", "Series identifier (String or FixedString)"}
},
.returned_value = "Sequential numbers of type Int64 starting from the previous counter value",
.examples{
{"first call", "SELECT serial('id1')", R"(
serial('id1')
1
)"},
{"second call", "SELECT serial('id1')", R"(
serial('id1')
2
)"},
{"column call", "SELECT *, serial('id1') FROM test_table", R"(
CounterIDUserIDverserial('id1')
1 3 3 3
1 1 1 4
1 2 2 5
1 5 5 6
1 4 4 7
{"first call", "SELECT generateSerialID('id1')", R"(
generateSerialID('id1')
1
)"},
{"second call", "SELECT generateSerialID('id1')", R"(
generateSerialID('id1')
2
)"},
{"column call", "SELECT *, generateSerialID('id1') FROM test_table", R"(
CounterIDUserIDvergenerateSerialID('id1')
1 3 3 3
1 1 1 4
1 2 2 5
1 5 5 6
1 4 4 7
)"}},
.categories{"Unique identifiers"}
});

View File

@ -1,12 +1,12 @@
-- Tags: zookeeper
SELECT serial('x');
SELECT serial('x');
SELECT serial('y');
SELECT serial('x') FROM numbers(5);
SELECT generateSerialID('x');
SELECT generateSerialID('x');
SELECT generateSerialID('y');
SELECT generateSerialID('x') FROM numbers(5);
SELECT serial(); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT serial('x', 'y'); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT serial(1); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT generateSerialID(); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT generateSerialID('x', 'y'); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT generateSerialID(1); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT serial('z'), serial('z') FROM numbers(5);
SELECT generateSerialID('z'), generateSerialID('z') FROM numbers(5);

View File

@ -0,0 +1,11 @@
-- generateSnowflakeID --
1
1
0
0
1
100
-- generateSnowflakeIDThreadMonotonic --
1
1
100

View File

@ -0,0 +1,29 @@
SELECT '-- generateSnowflakeID --';
SELECT bitShiftLeft(toUInt64(generateSnowflakeID()), 52) = 0; -- check machine sequence number is zero
SELECT bitAnd(bitShiftRight(toUInt64(generateSnowflakeID()), 63), 1) = 0; -- check first bit is zero
SELECT generateSnowflakeID(1) = generateSnowflakeID(2);
SELECT generateSnowflakeID() = generateSnowflakeID(1);
SELECT generateSnowflakeID(1) = generateSnowflakeID(1);
SELECT generateSnowflakeID(1, 2); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT count(*)
FROM
(
SELECT DISTINCT generateSnowflakeID()
FROM numbers(100)
);
SELECT '-- generateSnowflakeIDThreadMonotonic --';
SELECT bitShiftLeft(toUInt64(generateSnowflakeIDThreadMonotonic()), 52) = 0; -- check machine sequence number is zero
SELECT bitAnd(bitShiftRight(toUInt64(generateSnowflakeIDThreadMonotonic()), 63), 1) = 0; -- check first bit is zero
SELECT generateSnowflakeIDThreadMonotonic(1, 2); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT count(*)
FROM
(
SELECT DISTINCT generateSnowflakeIDThreadMonotonic()
FROM numbers(100)
);

View File

@ -1,11 +0,0 @@
SELECT bitShiftLeft(toUInt64(generateSnowflakeID()), 52) = 0;
SELECT bitAnd(bitShiftRight(toUInt64(generateSnowflakeID()), 63), 1) = 0;
SELECT generateSnowflakeID(1); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT count(*)
FROM
(
SELECT DISTINCT generateSnowflakeID()
FROM numbers(10)
)