Merge branch 'master' into fix-idiotic-code

This commit is contained in:
Alexey Milovidov 2024-07-20 02:29:59 +02:00
commit 75b042e9bd
173 changed files with 3318 additions and 5220 deletions

View File

@ -36,10 +36,6 @@ jobs:
cd "$GITHUB_WORKSPACE/tests/ci" cd "$GITHUB_WORKSPACE/tests/ci"
echo "Testing the main ci directory" echo "Testing the main ci directory"
python3 -m unittest discover -s . -p 'test_*.py' python3 -m unittest discover -s . -p 'test_*.py'
for dir in *_lambda/; do
echo "Testing $dir"
python3 -m unittest discover -s "$dir" -p 'test_*.py'
done
- name: PrepareRunConfig - name: PrepareRunConfig
id: runconfig id: runconfig
run: | run: |

View File

@ -33,10 +33,6 @@ jobs:
# cd "$GITHUB_WORKSPACE/tests/ci" # cd "$GITHUB_WORKSPACE/tests/ci"
# echo "Testing the main ci directory" # echo "Testing the main ci directory"
# python3 -m unittest discover -s . -p 'test_*.py' # python3 -m unittest discover -s . -p 'test_*.py'
# for dir in *_lambda/; do
# echo "Testing $dir"
# python3 -m unittest discover -s "$dir" -p 'test_*.py'
# done
- name: PrepareRunConfig - name: PrepareRunConfig
id: runconfig id: runconfig
run: | run: |

View File

@ -30,10 +30,6 @@ jobs:
cd "$GITHUB_WORKSPACE/tests/ci" cd "$GITHUB_WORKSPACE/tests/ci"
echo "Testing the main ci directory" echo "Testing the main ci directory"
python3 -m unittest discover -s . -p 'test_*.py' python3 -m unittest discover -s . -p 'test_*.py'
for dir in *_lambda/; do
echo "Testing $dir"
python3 -m unittest discover -s "$dir" -p 'test_*.py'
done
- name: PrepareRunConfig - name: PrepareRunConfig
id: runconfig id: runconfig
run: | run: |

View File

@ -48,10 +48,6 @@ jobs:
cd "$GITHUB_WORKSPACE/tests/ci" cd "$GITHUB_WORKSPACE/tests/ci"
echo "Testing the main ci directory" echo "Testing the main ci directory"
python3 -m unittest discover -s . -p 'test_*.py' python3 -m unittest discover -s . -p 'test_*.py'
for dir in *_lambda/; do
echo "Testing $dir"
python3 -m unittest discover -s "$dir" -p 'test_*.py'
done
- name: PrepareRunConfig - name: PrepareRunConfig
id: runconfig id: runconfig
run: | run: |

View File

@ -33,10 +33,6 @@ jobs:
cd "$GITHUB_WORKSPACE/tests/ci" cd "$GITHUB_WORKSPACE/tests/ci"
echo "Testing the main ci directory" echo "Testing the main ci directory"
python3 -m unittest discover -s . -p 'test_*.py' python3 -m unittest discover -s . -p 'test_*.py'
for dir in *_lambda/; do
echo "Testing $dir"
python3 -m unittest discover -s "$dir" -p 'test_*.py'
done
- name: PrepareRunConfig - name: PrepareRunConfig
id: runconfig id: runconfig
run: | run: |

View File

@ -1298,7 +1298,6 @@ elseif(ARCH_PPC64LE)
${OPENSSL_SOURCE_DIR}/crypto/camellia/camellia.c ${OPENSSL_SOURCE_DIR}/crypto/camellia/camellia.c
${OPENSSL_SOURCE_DIR}/crypto/camellia/cmll_cbc.c ${OPENSSL_SOURCE_DIR}/crypto/camellia/cmll_cbc.c
${OPENSSL_SOURCE_DIR}/crypto/chacha/chacha_enc.c ${OPENSSL_SOURCE_DIR}/crypto/chacha/chacha_enc.c
${OPENSSL_SOURCE_DIR}/crypto/mem_clr.c
${OPENSSL_SOURCE_DIR}/crypto/rc4/rc4_enc.c ${OPENSSL_SOURCE_DIR}/crypto/rc4/rc4_enc.c
${OPENSSL_SOURCE_DIR}/crypto/rc4/rc4_skey.c ${OPENSSL_SOURCE_DIR}/crypto/rc4/rc4_skey.c
${OPENSSL_SOURCE_DIR}/crypto/sha/keccak1600.c ${OPENSSL_SOURCE_DIR}/crypto/sha/keccak1600.c

View File

@ -4,6 +4,9 @@
source /setup_export_logs.sh source /setup_export_logs.sh
set -e -x set -e -x
MAX_RUN_TIME=${MAX_RUN_TIME:-3600}
MAX_RUN_TIME=$((MAX_RUN_TIME == 0 ? 3600 : MAX_RUN_TIME))
# Choose random timezone for this test run # Choose random timezone for this test run
TZ="$(rg -v '#' /usr/share/zoneinfo/zone.tab | awk '{print $3}' | shuf | head -n1)" TZ="$(rg -v '#' /usr/share/zoneinfo/zone.tab | awk '{print $3}' | shuf | head -n1)"
echo "Choosen random timezone $TZ" echo "Choosen random timezone $TZ"
@ -242,7 +245,22 @@ function run_tests()
} }
export -f run_tests export -f run_tests
timeout "$MAX_RUN_TIME" bash -c run_tests ||:
function timeout_with_logging() {
local exit_code=0
timeout -s TERM --preserve-status "${@}" || exit_code="${?}"
if [[ "${exit_code}" -eq "124" ]]
then
echo "The command 'timeout ${*}' has been killed by timeout"
fi
return $exit_code
}
TIMEOUT=$((MAX_RUN_TIME - 700))
timeout_with_logging "$TIMEOUT" bash -c run_tests ||:
echo "Files in current directory" echo "Files in current directory"
ls -la ./ ls -la ./

View File

@ -12,12 +12,6 @@ MAX_RUN_TIME=$((MAX_RUN_TIME == 0 ? 7200 : MAX_RUN_TIME))
USE_DATABASE_REPLICATED=${USE_DATABASE_REPLICATED:=0} USE_DATABASE_REPLICATED=${USE_DATABASE_REPLICATED:=0}
USE_SHARED_CATALOG=${USE_SHARED_CATALOG:=0} USE_SHARED_CATALOG=${USE_SHARED_CATALOG:=0}
RUN_SEQUENTIAL_TESTS_IN_PARALLEL=0
if [[ "$USE_DATABASE_REPLICATED" -eq 1 ]] || [[ "$USE_SHARED_CATALOG" -eq 1 ]]; then
RUN_SEQUENTIAL_TESTS_IN_PARALLEL=0
fi
# Choose random timezone for this test run. # Choose random timezone for this test run.
# #
# NOTE: that clickhouse-test will randomize session_timezone by itself as well # NOTE: that clickhouse-test will randomize session_timezone by itself as well
@ -101,53 +95,6 @@ if [ "$NUM_TRIES" -gt "1" ]; then
mkdir -p /var/run/clickhouse-server mkdir -p /var/run/clickhouse-server
fi fi
# Run a CH instance to execute sequential tests on it in parallel with all other tests.
if [[ "$RUN_SEQUENTIAL_TESTS_IN_PARALLEL" -eq 1 ]]; then
mkdir -p /var/run/clickhouse-server3 /etc/clickhouse-server3 /var/lib/clickhouse3
cp -r -L /etc/clickhouse-server/* /etc/clickhouse-server3/
sudo chown clickhouse:clickhouse /var/run/clickhouse-server3 /var/lib/clickhouse3 /etc/clickhouse-server3/
sudo chown -R clickhouse:clickhouse /etc/clickhouse-server3/*
function replace(){
sudo find /etc/clickhouse-server3/ -type f -name '*.xml' -exec sed -i "$1" {} \;
}
replace "s|<port>9000</port>|<port>19000</port>|g"
replace "s|<port>9440</port>|<port>19440</port>|g"
replace "s|<port>9988</port>|<port>19988</port>|g"
replace "s|<port>9234</port>|<port>19234</port>|g"
replace "s|<port>9181</port>|<port>19181</port>|g"
replace "s|<https_port>8443</https_port>|<https_port>18443</https_port>|g"
replace "s|<tcp_port>9000</tcp_port>|<tcp_port>19000</tcp_port>|g"
replace "s|<tcp_port>9181</tcp_port>|<tcp_port>19181</tcp_port>|g"
replace "s|<tcp_port_secure>9440</tcp_port_secure>|<tcp_port_secure>19440</tcp_port_secure>|g"
replace "s|<tcp_with_proxy_port>9010</tcp_with_proxy_port>|<tcp_with_proxy_port>19010</tcp_with_proxy_port>|g"
replace "s|<mysql_port>9004</mysql_port>|<mysql_port>19004</mysql_port>|g"
replace "s|<postgresql_port>9005</postgresql_port>|<postgresql_port>19005</postgresql_port>|g"
replace "s|<interserver_http_port>9009</interserver_http_port>|<interserver_http_port>19009</interserver_http_port>|g"
replace "s|8123|18123|g"
replace "s|/var/lib/clickhouse/|/var/lib/clickhouse3/|g"
replace "s|/etc/clickhouse-server/|/etc/clickhouse-server3/|g"
# distributed cache
replace "s|<tcp_port>10001</tcp_port>|<tcp_port>10003</tcp_port>|g"
replace "s|<tcp_port>10002</tcp_port>|<tcp_port>10004</tcp_port>|g"
sudo -E -u clickhouse /usr/bin/clickhouse server --daemon --config /etc/clickhouse-server3/config.xml \
--pid-file /var/run/clickhouse-server3/clickhouse-server.pid \
-- --path /var/lib/clickhouse3/ --logger.stderr /var/log/clickhouse-server/stderr3.log \
--logger.log /var/log/clickhouse-server/clickhouse-server3.log --logger.errorlog /var/log/clickhouse-server/clickhouse-server3.err.log \
--tcp_port 19000 --tcp_port_secure 19440 --http_port 18123 --https_port 18443 --interserver_http_port 19009 --tcp_with_proxy_port 19010 \
--prometheus.port 19988 --keeper_server.raft_configuration.server.port 19234 --keeper_server.tcp_port 19181 \
--mysql_port 19004 --postgresql_port 19005
for _ in {1..100}
do
clickhouse-client --port 19000 --query "SELECT 1" && break
sleep 1
done
fi
# simplest way to forward env variables to server # simplest way to forward env variables to server
sudo -E -u clickhouse /usr/bin/clickhouse-server --config /etc/clickhouse-server/config.xml --daemon --pid-file /var/run/clickhouse-server/clickhouse-server.pid sudo -E -u clickhouse /usr/bin/clickhouse-server --config /etc/clickhouse-server/config.xml --daemon --pid-file /var/run/clickhouse-server/clickhouse-server.pid
@ -183,9 +130,6 @@ if [[ "$USE_DATABASE_REPLICATED" -eq 1 ]]; then
--keeper_server.tcp_port 29181 --keeper_server.server_id 3 \ --keeper_server.tcp_port 29181 --keeper_server.server_id 3 \
--prometheus.port 29988 \ --prometheus.port 29988 \
--macros.shard s2 # It doesn't work :( --macros.shard s2 # It doesn't work :(
MAX_RUN_TIME=$((MAX_RUN_TIME < 9000 ? MAX_RUN_TIME : 9000)) # min(MAX_RUN_TIME, 2.5 hours)
MAX_RUN_TIME=$((MAX_RUN_TIME != 0 ? MAX_RUN_TIME : 9000)) # set to 2.5 hours if 0 (unlimited)
fi fi
if [[ "$USE_SHARED_CATALOG" -eq 1 ]]; then if [[ "$USE_SHARED_CATALOG" -eq 1 ]]; then
@ -210,9 +154,6 @@ if [[ "$USE_SHARED_CATALOG" -eq 1 ]]; then
--keeper_server.tcp_port 19181 --keeper_server.server_id 2 \ --keeper_server.tcp_port 19181 --keeper_server.server_id 2 \
--prometheus.port 19988 \ --prometheus.port 19988 \
--macros.replica r2 # It doesn't work :( --macros.replica r2 # It doesn't work :(
MAX_RUN_TIME=$((MAX_RUN_TIME < 9000 ? MAX_RUN_TIME : 9000)) # min(MAX_RUN_TIME, 2.5 hours)
MAX_RUN_TIME=$((MAX_RUN_TIME != 0 ? MAX_RUN_TIME : 9000)) # set to 2.5 hours if 0 (unlimited)
fi fi
# Wait for the server to start, but not for too long. # Wait for the server to start, but not for too long.
@ -223,7 +164,6 @@ do
done done
setup_logs_replication setup_logs_replication
attach_gdb_to_clickhouse || true # FIXME: to not break old builds, clean on 2023-09-01 attach_gdb_to_clickhouse || true # FIXME: to not break old builds, clean on 2023-09-01
function fn_exists() { function fn_exists() {
@ -284,11 +224,7 @@ function run_tests()
else else
# All other configurations are OK. # All other configurations are OK.
ADDITIONAL_OPTIONS+=('--jobs') ADDITIONAL_OPTIONS+=('--jobs')
ADDITIONAL_OPTIONS+=('5') ADDITIONAL_OPTIONS+=('7')
fi
if [[ "$RUN_SEQUENTIAL_TESTS_IN_PARALLEL" -eq 1 ]]; then
ADDITIONAL_OPTIONS+=('--run-sequential-tests-in-parallel')
fi fi
if [[ -n "$RUN_BY_HASH_NUM" ]] && [[ -n "$RUN_BY_HASH_TOTAL" ]]; then if [[ -n "$RUN_BY_HASH_NUM" ]] && [[ -n "$RUN_BY_HASH_TOTAL" ]]; then
@ -373,9 +309,6 @@ done
# Because it's the simplest way to read it when server has crashed. # Because it's the simplest way to read it when server has crashed.
sudo clickhouse stop ||: sudo clickhouse stop ||:
if [[ "$RUN_SEQUENTIAL_TESTS_IN_PARALLEL" -eq 1 ]]; then
sudo clickhouse stop --pid-path /var/run/clickhouse-server3 ||:
fi
if [[ "$USE_DATABASE_REPLICATED" -eq 1 ]]; then if [[ "$USE_DATABASE_REPLICATED" -eq 1 ]]; then
sudo clickhouse stop --pid-path /var/run/clickhouse-server1 ||: sudo clickhouse stop --pid-path /var/run/clickhouse-server1 ||:
@ -393,12 +326,6 @@ rg -Fa "<Fatal>" /var/log/clickhouse-server/clickhouse-server.log ||:
rg -A50 -Fa "============" /var/log/clickhouse-server/stderr.log ||: rg -A50 -Fa "============" /var/log/clickhouse-server/stderr.log ||:
zstd --threads=0 < /var/log/clickhouse-server/clickhouse-server.log > /test_output/clickhouse-server.log.zst & zstd --threads=0 < /var/log/clickhouse-server/clickhouse-server.log > /test_output/clickhouse-server.log.zst &
if [[ "$RUN_SEQUENTIAL_TESTS_IN_PARALLEL" -eq 1 ]]; then
rg -Fa "<Fatal>" /var/log/clickhouse-server3/clickhouse-server.log ||:
rg -A50 -Fa "============" /var/log/clickhouse-server3/stderr.log ||:
zstd --threads=0 < /var/log/clickhouse-server3/clickhouse-server.log > /test_output/clickhouse-server3.log.zst &
fi
data_path_config="--path=/var/lib/clickhouse/" data_path_config="--path=/var/lib/clickhouse/"
if [[ -n "$USE_S3_STORAGE_FOR_MERGE_TREE" ]] && [[ "$USE_S3_STORAGE_FOR_MERGE_TREE" -eq 1 ]]; then if [[ -n "$USE_S3_STORAGE_FOR_MERGE_TREE" ]] && [[ "$USE_S3_STORAGE_FOR_MERGE_TREE" -eq 1 ]]; then
# We need s3 storage configuration (but it's more likely that clickhouse-local will fail for some reason) # We need s3 storage configuration (but it's more likely that clickhouse-local will fail for some reason)
@ -419,10 +346,6 @@ if [ $failed_to_save_logs -ne 0 ]; then
do do
clickhouse-local "$data_path_config" --only-system-tables --stacktrace -q "select * from system.$table format TSVWithNamesAndTypes" | zstd --threads=0 > /test_output/$table.tsv.zst ||: clickhouse-local "$data_path_config" --only-system-tables --stacktrace -q "select * from system.$table format TSVWithNamesAndTypes" | zstd --threads=0 > /test_output/$table.tsv.zst ||:
if [[ "$RUN_SEQUENTIAL_TESTS_IN_PARALLEL" -eq 1 ]]; then
clickhouse-local --path /var/lib/clickhouse3/ --only-system-tables --stacktrace -q "select * from system.$table format TSVWithNamesAndTypes" | zstd --threads=0 > /test_output/$table.3.tsv.zst ||:
fi
if [[ "$USE_DATABASE_REPLICATED" -eq 1 ]]; then if [[ "$USE_DATABASE_REPLICATED" -eq 1 ]]; then
clickhouse-local --path /var/lib/clickhouse1/ --only-system-tables --stacktrace -q "select * from system.$table format TSVWithNamesAndTypes" | zstd --threads=0 > /test_output/$table.1.tsv.zst ||: clickhouse-local --path /var/lib/clickhouse1/ --only-system-tables --stacktrace -q "select * from system.$table format TSVWithNamesAndTypes" | zstd --threads=0 > /test_output/$table.1.tsv.zst ||:
clickhouse-local --path /var/lib/clickhouse2/ --only-system-tables --stacktrace -q "select * from system.$table format TSVWithNamesAndTypes" | zstd --threads=0 > /test_output/$table.2.tsv.zst ||: clickhouse-local --path /var/lib/clickhouse2/ --only-system-tables --stacktrace -q "select * from system.$table format TSVWithNamesAndTypes" | zstd --threads=0 > /test_output/$table.2.tsv.zst ||:
@ -464,12 +387,6 @@ rm -rf /var/lib/clickhouse/data/system/*/
tar -chf /test_output/store.tar /var/lib/clickhouse/store ||: tar -chf /test_output/store.tar /var/lib/clickhouse/store ||:
tar -chf /test_output/metadata.tar /var/lib/clickhouse/metadata/*.sql ||: tar -chf /test_output/metadata.tar /var/lib/clickhouse/metadata/*.sql ||:
if [[ "$RUN_SEQUENTIAL_TESTS_IN_PARALLEL" -eq 1 ]]; then
rm -rf /var/lib/clickhouse3/data/system/*/
tar -chf /test_output/store.tar /var/lib/clickhouse3/store ||:
tar -chf /test_output/metadata.tar /var/lib/clickhouse3/metadata/*.sql ||:
fi
if [[ "$USE_DATABASE_REPLICATED" -eq 1 ]]; then if [[ "$USE_DATABASE_REPLICATED" -eq 1 ]]; then
rg -Fa "<Fatal>" /var/log/clickhouse-server/clickhouse-server1.log ||: rg -Fa "<Fatal>" /var/log/clickhouse-server/clickhouse-server1.log ||:

View File

@ -124,7 +124,7 @@ which is equal to
#### Default values for from_env and from_zk attributes #### Default values for from_env and from_zk attributes
It's possible to set the default value and substitute it only if the environment variable or zookeeper node is set using `replace="1"`. It's possible to set the default value and substitute it only if the environment variable or zookeeper node is set using `replace="1"` (must be declared before from_env).
With previous example, but `MAX_QUERY_SIZE` is unset: With previous example, but `MAX_QUERY_SIZE` is unset:
@ -132,7 +132,7 @@ With previous example, but `MAX_QUERY_SIZE` is unset:
<clickhouse> <clickhouse>
<profiles> <profiles>
<default> <default>
<max_query_size from_env="MAX_QUERY_SIZE" replace="1">150000</max_query_size> <max_query_size replace="1" from_env="MAX_QUERY_SIZE">150000</max_query_size>
</default> </default>
</profiles> </profiles>
</clickhouse> </clickhouse>

View File

@ -9,7 +9,6 @@ Columns:
- `name` ([String](../../sql-reference/data-types/string.md)) The name of the function. - `name` ([String](../../sql-reference/data-types/string.md)) The name of the function.
- `is_aggregate` ([UInt8](../../sql-reference/data-types/int-uint.md)) — Whether the function is an aggregate function. - `is_aggregate` ([UInt8](../../sql-reference/data-types/int-uint.md)) — Whether the function is an aggregate function.
- `is_deterministic` ([Nullable](../../sql-reference/data-types/nullable.md)([UInt8](../../sql-reference/data-types/int-uint.md))) - Whether the function is deterministic.
- `case_insensitive`, ([UInt8](../../sql-reference/data-types/int-uint.md)) - Whether the function name can be used case-insensitively. - `case_insensitive`, ([UInt8](../../sql-reference/data-types/int-uint.md)) - Whether the function name can be used case-insensitively.
- `alias_to`, ([String](../../sql-reference/data-types/string.md)) - The original function name, if the function name is an alias. - `alias_to`, ([String](../../sql-reference/data-types/string.md)) - The original function name, if the function name is an alias.
- `create_query`, ([String](../../sql-reference/data-types/enum.md)) - Unused. - `create_query`, ([String](../../sql-reference/data-types/enum.md)) - Unused.

View File

@ -1,6 +1,6 @@
--- ---
slug: /en/sql-reference/table-functions/azureBlobStorageCluster slug: /en/sql-reference/table-functions/azureBlobStorageCluster
sidebar_position: 55 sidebar_position: 15
sidebar_label: azureBlobStorageCluster sidebar_label: azureBlobStorageCluster
title: "azureBlobStorageCluster Table Function" title: "azureBlobStorageCluster Table Function"
--- ---

View File

@ -45,16 +45,17 @@ int mainEntryClickHouseKeeperConverter(int argc, char ** argv)
keeper_context->setDigestEnabled(true); keeper_context->setDigestEnabled(true);
keeper_context->setSnapshotDisk(std::make_shared<DiskLocal>("Keeper-snapshots", options["output-dir"].as<std::string>())); keeper_context->setSnapshotDisk(std::make_shared<DiskLocal>("Keeper-snapshots", options["output-dir"].as<std::string>()));
DB::KeeperStorage storage(/* tick_time_ms */ 500, /* superdigest */ "", keeper_context, /* initialize_system_nodes */ false); /// TODO(hanfei): support rocksdb here
DB::KeeperMemoryStorage storage(/* tick_time_ms */ 500, /* superdigest */ "", keeper_context, /* initialize_system_nodes */ false);
DB::deserializeKeeperStorageFromSnapshotsDir(storage, options["zookeeper-snapshots-dir"].as<std::string>(), logger); DB::deserializeKeeperStorageFromSnapshotsDir(storage, options["zookeeper-snapshots-dir"].as<std::string>(), logger);
storage.initializeSystemNodes(); storage.initializeSystemNodes();
DB::deserializeLogsAndApplyToStorage(storage, options["zookeeper-logs-dir"].as<std::string>(), logger); DB::deserializeLogsAndApplyToStorage(storage, options["zookeeper-logs-dir"].as<std::string>(), logger);
DB::SnapshotMetadataPtr snapshot_meta = std::make_shared<DB::SnapshotMetadata>(storage.getZXID(), 1, std::make_shared<nuraft::cluster_config>()); DB::SnapshotMetadataPtr snapshot_meta = std::make_shared<DB::SnapshotMetadata>(storage.getZXID(), 1, std::make_shared<nuraft::cluster_config>());
DB::KeeperStorageSnapshot snapshot(&storage, snapshot_meta); DB::KeeperStorageSnapshot<DB::KeeperMemoryStorage> snapshot(&storage, snapshot_meta);
DB::KeeperSnapshotManager manager(1, keeper_context); DB::KeeperSnapshotManager<DB::KeeperMemoryStorage> manager(1, keeper_context);
auto snp = manager.serializeSnapshotToBuffer(snapshot); auto snp = manager.serializeSnapshotToBuffer(snapshot);
auto file_info = manager.serializeSnapshotBufferToDisk(*snp, storage.getZXID()); auto file_info = manager.serializeSnapshotBufferToDisk(*snp, storage.getZXID());
std::cout << "Snapshot serialized to path:" << fs::path(file_info->disk->getPath()) / file_info->path << std::endl; std::cout << "Snapshot serialized to path:" << fs::path(file_info->disk->getPath()) / file_info->path << std::endl;

View File

@ -2919,6 +2919,17 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
resolveExpressionNode(in_second_argument, scope, false /*allow_lambda_expression*/, true /*allow_table_expression*/); resolveExpressionNode(in_second_argument, scope, false /*allow_lambda_expression*/, true /*allow_table_expression*/);
} }
/// Edge case when the first argument of IN is scalar subquery.
auto & in_first_argument = function_in_arguments_nodes[0];
auto first_argument_type = in_first_argument->getNodeType();
if (first_argument_type == QueryTreeNodeType::QUERY || first_argument_type == QueryTreeNodeType::UNION)
{
IdentifierResolveScope subquery_scope(in_first_argument, &scope /*parent_scope*/);
subquery_scope.subquery_depth = scope.subquery_depth + 1;
evaluateScalarSubqueryIfNeeded(in_first_argument, subquery_scope);
}
} }
/// Initialize function argument columns /// Initialize function argument columns

View File

@ -84,12 +84,18 @@ public:
return result; return result;
} }
void append(Self && other) // append items for other inscnace only if there is no such item in current instance
void appendIfUniq(Self && other)
{ {
auto middle_idx = records.size(); auto middle_idx = records.size();
std::move(other.records.begin(), other.records.end(), std::back_inserter(records)); std::move(other.records.begin(), other.records.end(), std::back_inserter(records));
// merge is stable
std::inplace_merge(records.begin(), records.begin() + middle_idx, records.end()); std::inplace_merge(records.begin(), records.begin() + middle_idx, records.end());
chassert(isUniqTypes()); // remove duplicates
records.erase(std::unique(records.begin(), records.end()), records.end());
assert(std::is_sorted(records.begin(), records.end()));
assert(isUniqTypes());
} }
template <class T> template <class T>
@ -142,7 +148,6 @@ private:
bool isUniqTypes() const bool isUniqTypes() const
{ {
auto uniq_it = std::adjacent_find(records.begin(), records.end()); auto uniq_it = std::adjacent_find(records.begin(), records.end());
return uniq_it == records.end(); return uniq_it == records.end();
} }
@ -161,8 +166,6 @@ private:
records.emplace(it, type_idx, item); records.emplace(it, type_idx, item);
chassert(isUniqTypes());
} }
Records::const_iterator getImpl(std::type_index type_idx) const Records::const_iterator getImpl(std::type_index type_idx) const

View File

@ -442,8 +442,6 @@ The server successfully detected this situation and will download merged part fr
M(ReadBufferFromS3InitMicroseconds, "Time spent initializing connection to S3.") \ M(ReadBufferFromS3InitMicroseconds, "Time spent initializing connection to S3.") \
M(ReadBufferFromS3Bytes, "Bytes read from S3.") \ M(ReadBufferFromS3Bytes, "Bytes read from S3.") \
M(ReadBufferFromS3RequestsErrors, "Number of exceptions while reading from S3.") \ M(ReadBufferFromS3RequestsErrors, "Number of exceptions while reading from S3.") \
M(ReadBufferFromS3ResetSessions, "Number of HTTP sessions that were reset in ReadBufferFromS3.") \
M(ReadBufferFromS3PreservedSessions, "Number of HTTP sessions that were preserved in ReadBufferFromS3.") \
\ \
M(WriteBufferFromS3Microseconds, "Time spent on writing to S3.") \ M(WriteBufferFromS3Microseconds, "Time spent on writing to S3.") \
M(WriteBufferFromS3Bytes, "Bytes written to S3.") \ M(WriteBufferFromS3Bytes, "Bytes written to S3.") \

View File

@ -13,14 +13,14 @@
#include <Common/ZooKeeper/Types.h> #include <Common/ZooKeeper/Types.h>
#include <Common/ZooKeeper/ZooKeeperCommon.h> #include <Common/ZooKeeper/ZooKeeperCommon.h>
#include <Common/randomSeed.h> #include <Common/randomSeed.h>
#include <base/find_symbols.h>
#include <base/sort.h> #include <base/sort.h>
#include <base/map.h>
#include <base/getFQDNOrHostName.h> #include <base/getFQDNOrHostName.h>
#include <Core/ServerUUID.h> #include <Core/ServerUUID.h>
#include <Core/BackgroundSchedulePool.h> #include <Core/BackgroundSchedulePool.h>
#include "Common/ZooKeeper/IKeeper.h" #include <Common/ZooKeeper/IKeeper.h>
#include <Common/DNSResolver.h>
#include <Common/StringUtils.h> #include <Common/StringUtils.h>
#include <Common/quoteString.h>
#include <Common/Exception.h> #include <Common/Exception.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
@ -114,7 +114,11 @@ void ZooKeeper::init(ZooKeeperArgs args_, std::unique_ptr<Coordination::IKeeper>
/// availability_zones is empty on server startup or after config reloading /// availability_zones is empty on server startup or after config reloading
/// We will keep the az info when starting new sessions /// We will keep the az info when starting new sessions
availability_zones = args.availability_zones; availability_zones = args.availability_zones;
LOG_TEST(log, "Availability zones from config: [{}], client: {}", fmt::join(availability_zones, ", "), args.client_availability_zone);
LOG_TEST(log, "Availability zones from config: [{}], client: {}",
fmt::join(collections::map(availability_zones, [](auto s){ return DB::quoteString(s); }), ", "),
DB::quoteString(args.client_availability_zone));
if (args.availability_zone_autodetect) if (args.availability_zone_autodetect)
updateAvailabilityZones(); updateAvailabilityZones();
} }

View File

@ -55,6 +55,7 @@ struct Settings;
M(UInt64, min_request_size_for_cache, 50 * 1024, "Minimal size of the request to cache the deserialization result. Caching can have negative effect on latency for smaller requests, set to 0 to disable", 0) \ M(UInt64, min_request_size_for_cache, 50 * 1024, "Minimal size of the request to cache the deserialization result. Caching can have negative effect on latency for smaller requests, set to 0 to disable", 0) \
M(UInt64, raft_limits_reconnect_limit, 50, "If connection to a peer is silent longer than this limit * (multiplied by heartbeat interval), we re-establish the connection.", 0) \ M(UInt64, raft_limits_reconnect_limit, 50, "If connection to a peer is silent longer than this limit * (multiplied by heartbeat interval), we re-establish the connection.", 0) \
M(Bool, async_replication, false, "Enable async replication. All write and read guarantees are preserved while better performance is achieved. Settings is disabled by default to not break backwards compatibility.", 0) \ M(Bool, async_replication, false, "Enable async replication. All write and read guarantees are preserved while better performance is achieved. Settings is disabled by default to not break backwards compatibility.", 0) \
M(Bool, experimental_use_rocksdb, false, "Use rocksdb as backend storage", 0) \
M(UInt64, latest_logs_cache_size_threshold, 1 * 1024 * 1024 * 1024, "Maximum total size of in-memory cache of latest log entries.", 0) \ M(UInt64, latest_logs_cache_size_threshold, 1 * 1024 * 1024 * 1024, "Maximum total size of in-memory cache of latest log entries.", 0) \
M(UInt64, commit_logs_cache_size_threshold, 500 * 1024 * 1024, "Maximum total size of in-memory cache of log entries needed next for commit.", 0) \ M(UInt64, commit_logs_cache_size_threshold, 500 * 1024 * 1024, "Maximum total size of in-memory cache of log entries needed next for commit.", 0) \
M(UInt64, disk_move_retries_wait_ms, 1000, "How long to wait between retries after a failure which happened while a file was being moved between disks.", 0) \ M(UInt64, disk_move_retries_wait_ms, 1000, "How long to wait between retries after a failure which happened while a file was being moved between disks.", 0) \

View File

@ -183,8 +183,6 @@
M(ReadBufferFromS3InitMicroseconds) \ M(ReadBufferFromS3InitMicroseconds) \
M(ReadBufferFromS3Bytes) \ M(ReadBufferFromS3Bytes) \
M(ReadBufferFromS3RequestsErrors) \ M(ReadBufferFromS3RequestsErrors) \
M(ReadBufferFromS3ResetSessions) \
M(ReadBufferFromS3PreservedSessions) \
\ \
M(WriteBufferFromS3Microseconds) \ M(WriteBufferFromS3Microseconds) \
M(WriteBufferFromS3Bytes) \ M(WriteBufferFromS3Bytes) \

View File

@ -5,18 +5,27 @@
#include <Coordination/CoordinationSettings.h> #include <Coordination/CoordinationSettings.h>
#include <Coordination/Defines.h> #include <Coordination/Defines.h>
#include <Disks/DiskLocal.h>
#include <Interpreters/Context.h>
#include <IO/S3/Credentials.h>
#include <IO/WriteHelpers.h>
#include <Poco/Util/AbstractConfiguration.h>
#include <Poco/Util/JSONConfiguration.h>
#include <Coordination/KeeperConstants.h> #include <Coordination/KeeperConstants.h>
#include <Server/CloudPlacementInfo.h> #include <Server/CloudPlacementInfo.h>
#include <Coordination/KeeperFeatureFlags.h> #include <Coordination/KeeperFeatureFlags.h>
#include <Disks/DiskLocal.h>
#include <Disks/DiskSelector.h> #include <Disks/DiskSelector.h>
#include <IO/S3/Credentials.h>
#include <Interpreters/Context.h>
#include <Poco/Util/AbstractConfiguration.h>
#include <Common/logger_useful.h> #include <Common/logger_useful.h>
#include <boost/algorithm/string.hpp> #include <boost/algorithm/string.hpp>
#include "config.h"
#if USE_ROCKSDB
#include <rocksdb/table.h>
#include <rocksdb/convenience.h>
#include <rocksdb/utilities/db_ttl.h>
#endif
namespace DB namespace DB
{ {
@ -24,6 +33,8 @@ namespace ErrorCodes
{ {
extern const int BAD_ARGUMENTS; extern const int BAD_ARGUMENTS;
extern const int LOGICAL_ERROR;
extern const int ROCKSDB_ERROR;
} }
@ -41,6 +52,95 @@ KeeperContext::KeeperContext(bool standalone_keeper_, CoordinationSettingsPtr co
system_nodes_with_data[keeper_api_version_path] = toString(static_cast<uint8_t>(KeeperApiVersion::WITH_MULTI_READ)); system_nodes_with_data[keeper_api_version_path] = toString(static_cast<uint8_t>(KeeperApiVersion::WITH_MULTI_READ));
} }
#if USE_ROCKSDB
using RocksDBOptions = std::unordered_map<std::string, std::string>;
static RocksDBOptions getOptionsFromConfig(const Poco::Util::AbstractConfiguration & config, const std::string & path)
{
RocksDBOptions options;
Poco::Util::AbstractConfiguration::Keys keys;
config.keys(path, keys);
for (const auto & key : keys)
{
const String key_path = path + "." + key;
options[key] = config.getString(key_path);
}
return options;
}
static rocksdb::Options getRocksDBOptionsFromConfig(const Poco::Util::AbstractConfiguration & config)
{
rocksdb::Status status;
rocksdb::Options base;
base.create_if_missing = true;
base.compression = rocksdb::CompressionType::kZSTD;
base.statistics = rocksdb::CreateDBStatistics();
/// It is too verbose by default, and in fact we don't care about rocksdb logs at all.
base.info_log_level = rocksdb::ERROR_LEVEL;
rocksdb::Options merged = base;
rocksdb::BlockBasedTableOptions table_options;
if (config.has("keeper_server.rocksdb.options"))
{
auto config_options = getOptionsFromConfig(config, "keeper_server.rocksdb.options");
status = rocksdb::GetDBOptionsFromMap(merged, config_options, &merged);
if (!status.ok())
{
throw Exception(ErrorCodes::ROCKSDB_ERROR, "Fail to merge rocksdb options from 'rocksdb.options' : {}",
status.ToString());
}
}
if (config.has("rocksdb.column_family_options"))
{
auto column_family_options = getOptionsFromConfig(config, "rocksdb.column_family_options");
status = rocksdb::GetColumnFamilyOptionsFromMap(merged, column_family_options, &merged);
if (!status.ok())
{
throw Exception(ErrorCodes::ROCKSDB_ERROR, "Fail to merge rocksdb options from 'rocksdb.column_family_options' at: {}", status.ToString());
}
}
if (config.has("rocksdb.block_based_table_options"))
{
auto block_based_table_options = getOptionsFromConfig(config, "rocksdb.block_based_table_options");
status = rocksdb::GetBlockBasedTableOptionsFromMap(table_options, block_based_table_options, &table_options);
if (!status.ok())
{
throw Exception(ErrorCodes::ROCKSDB_ERROR, "Fail to merge rocksdb options from 'rocksdb.block_based_table_options' at: {}", status.ToString());
}
}
merged.table_factory.reset(rocksdb::NewBlockBasedTableFactory(table_options));
return merged;
}
#endif
KeeperContext::Storage KeeperContext::getRocksDBPathFromConfig(const Poco::Util::AbstractConfiguration & config) const
{
const auto create_local_disk = [](const auto & path)
{
if (fs::exists(path))
fs::remove_all(path);
fs::create_directories(path);
return std::make_shared<DiskLocal>("LocalRocksDBDisk", path);
};
if (config.has("keeper_server.rocksdb_path"))
return create_local_disk(config.getString("keeper_server.rocksdb_path"));
if (config.has("keeper_server.storage_path"))
return create_local_disk(std::filesystem::path{config.getString("keeper_server.storage_path")} / "rocksdb");
if (standalone_keeper)
return create_local_disk(std::filesystem::path{config.getString("path", KEEPER_DEFAULT_PATH)} / "rocksdb");
else
return create_local_disk(std::filesystem::path{config.getString("path", DBMS_DEFAULT_PATH)} / "coordination/rocksdb");
}
void KeeperContext::initialize(const Poco::Util::AbstractConfiguration & config, KeeperDispatcher * dispatcher_) void KeeperContext::initialize(const Poco::Util::AbstractConfiguration & config, KeeperDispatcher * dispatcher_)
{ {
dispatcher = dispatcher_; dispatcher = dispatcher_;
@ -59,6 +159,14 @@ void KeeperContext::initialize(const Poco::Util::AbstractConfiguration & config,
initializeFeatureFlags(config); initializeFeatureFlags(config);
initializeDisks(config); initializeDisks(config);
#if USE_ROCKSDB
if (config.getBool("keeper_server.coordination_settings.experimental_use_rocksdb", false))
{
rocksdb_options = std::make_shared<rocksdb::Options>(getRocksDBOptionsFromConfig(config));
digest_enabled = false; /// TODO: support digest
}
#endif
} }
namespace namespace
@ -94,6 +202,8 @@ void KeeperContext::initializeDisks(const Poco::Util::AbstractConfiguration & co
{ {
disk_selector->initialize(config, "storage_configuration.disks", Context::getGlobalContextInstance(), diskValidator); disk_selector->initialize(config, "storage_configuration.disks", Context::getGlobalContextInstance(), diskValidator);
rocksdb_storage = getRocksDBPathFromConfig(config);
log_storage = getLogsPathFromConfig(config); log_storage = getLogsPathFromConfig(config);
if (config.has("keeper_server.latest_log_storage_disk")) if (config.has("keeper_server.latest_log_storage_disk"))
@ -262,6 +372,37 @@ void KeeperContext::dumpConfiguration(WriteBufferFromOwnString & buf) const
} }
} }
void KeeperContext::setRocksDBDisk(DiskPtr disk)
{
rocksdb_storage = std::move(disk);
}
DiskPtr KeeperContext::getTemporaryRocksDBDisk() const
{
DiskPtr rocksdb_disk = getDisk(rocksdb_storage);
if (!rocksdb_disk)
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "rocksdb storage is not initialized");
}
auto uuid_str = formatUUID(UUIDHelpers::generateV4());
String path_to_create = "rocks_" + std::string(uuid_str.data(), uuid_str.size());
rocksdb_disk->createDirectory(path_to_create);
return std::make_shared<DiskLocal>("LocalTmpRocksDBDisk", fullPath(rocksdb_disk, path_to_create));
}
void KeeperContext::setRocksDBOptions(std::shared_ptr<rocksdb::Options> rocksdb_options_)
{
if (rocksdb_options_ != nullptr)
rocksdb_options = rocksdb_options_;
else
{
#if USE_ROCKSDB
rocksdb_options = std::make_shared<rocksdb::Options>(getRocksDBOptionsFromConfig(Poco::Util::JSONConfiguration()));
#endif
}
}
KeeperContext::Storage KeeperContext::getLogsPathFromConfig(const Poco::Util::AbstractConfiguration & config) const KeeperContext::Storage KeeperContext::getLogsPathFromConfig(const Poco::Util::AbstractConfiguration & config) const
{ {
const auto create_local_disk = [](const auto & path) const auto create_local_disk = [](const auto & path)

View File

@ -6,6 +6,11 @@
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
namespace rocksdb
{
struct Options;
}
namespace DB namespace DB
{ {
@ -62,6 +67,12 @@ public:
constexpr KeeperDispatcher * getDispatcher() const { return dispatcher; } constexpr KeeperDispatcher * getDispatcher() const { return dispatcher; }
void setRocksDBDisk(DiskPtr disk);
DiskPtr getTemporaryRocksDBDisk() const;
void setRocksDBOptions(std::shared_ptr<rocksdb::Options> rocksdb_options_ = nullptr);
std::shared_ptr<rocksdb::Options> getRocksDBOptions() const { return rocksdb_options; }
UInt64 getKeeperMemorySoftLimit() const { return memory_soft_limit; } UInt64 getKeeperMemorySoftLimit() const { return memory_soft_limit; }
void updateKeeperMemorySoftLimit(const Poco::Util::AbstractConfiguration & config); void updateKeeperMemorySoftLimit(const Poco::Util::AbstractConfiguration & config);
@ -90,6 +101,7 @@ private:
void initializeFeatureFlags(const Poco::Util::AbstractConfiguration & config); void initializeFeatureFlags(const Poco::Util::AbstractConfiguration & config);
void initializeDisks(const Poco::Util::AbstractConfiguration & config); void initializeDisks(const Poco::Util::AbstractConfiguration & config);
Storage getRocksDBPathFromConfig(const Poco::Util::AbstractConfiguration & config) const;
Storage getLogsPathFromConfig(const Poco::Util::AbstractConfiguration & config) const; Storage getLogsPathFromConfig(const Poco::Util::AbstractConfiguration & config) const;
Storage getSnapshotsPathFromConfig(const Poco::Util::AbstractConfiguration & config) const; Storage getSnapshotsPathFromConfig(const Poco::Util::AbstractConfiguration & config) const;
Storage getStatePathFromConfig(const Poco::Util::AbstractConfiguration & config) const; Storage getStatePathFromConfig(const Poco::Util::AbstractConfiguration & config) const;
@ -111,12 +123,15 @@ private:
std::shared_ptr<DiskSelector> disk_selector; std::shared_ptr<DiskSelector> disk_selector;
Storage rocksdb_storage;
Storage log_storage; Storage log_storage;
Storage latest_log_storage; Storage latest_log_storage;
Storage snapshot_storage; Storage snapshot_storage;
Storage latest_snapshot_storage; Storage latest_snapshot_storage;
Storage state_file_storage; Storage state_file_storage;
std::shared_ptr<rocksdb::Options> rocksdb_options;
std::vector<std::string> old_log_disk_names; std::vector<std::string> old_log_disk_names;
std::vector<std::string> old_snapshot_disk_names; std::vector<std::string> old_snapshot_disk_names;

View File

@ -117,13 +117,13 @@ void KeeperDispatcher::requestThread()
RaftAppendResult prev_result = nullptr; RaftAppendResult prev_result = nullptr;
/// Requests from previous iteration. We store them to be able /// Requests from previous iteration. We store them to be able
/// to send errors to the client. /// to send errors to the client.
KeeperStorage::RequestsForSessions prev_batch; KeeperStorageBase::RequestsForSessions prev_batch;
const auto & shutdown_called = keeper_context->isShutdownCalled(); const auto & shutdown_called = keeper_context->isShutdownCalled();
while (!shutdown_called) while (!shutdown_called)
{ {
KeeperStorage::RequestForSession request; KeeperStorageBase::RequestForSession request;
auto coordination_settings = configuration_and_settings->coordination_settings; auto coordination_settings = configuration_and_settings->coordination_settings;
uint64_t max_wait = coordination_settings->operation_timeout_ms.totalMilliseconds(); uint64_t max_wait = coordination_settings->operation_timeout_ms.totalMilliseconds();
@ -153,7 +153,7 @@ void KeeperDispatcher::requestThread()
continue; continue;
} }
KeeperStorage::RequestsForSessions current_batch; KeeperStorageBase::RequestsForSessions current_batch;
size_t current_batch_bytes_size = 0; size_t current_batch_bytes_size = 0;
bool has_read_request = false; bool has_read_request = false;
@ -311,7 +311,7 @@ void KeeperDispatcher::responseThread()
const auto & shutdown_called = keeper_context->isShutdownCalled(); const auto & shutdown_called = keeper_context->isShutdownCalled();
while (!shutdown_called) while (!shutdown_called)
{ {
KeeperStorage::ResponseForSession response_for_session; KeeperStorageBase::ResponseForSession response_for_session;
uint64_t max_wait = configuration_and_settings->coordination_settings->operation_timeout_ms.totalMilliseconds(); uint64_t max_wait = configuration_and_settings->coordination_settings->operation_timeout_ms.totalMilliseconds();
@ -402,7 +402,7 @@ bool KeeperDispatcher::putRequest(const Coordination::ZooKeeperRequestPtr & requ
return false; return false;
} }
KeeperStorage::RequestForSession request_info; KeeperStorageBase::RequestForSession request_info;
request_info.request = request; request_info.request = request;
using namespace std::chrono; using namespace std::chrono;
request_info.time = duration_cast<milliseconds>(system_clock::now().time_since_epoch()).count(); request_info.time = duration_cast<milliseconds>(system_clock::now().time_since_epoch()).count();
@ -448,7 +448,7 @@ void KeeperDispatcher::initialize(const Poco::Util::AbstractConfiguration & conf
snapshots_queue, snapshots_queue,
keeper_context, keeper_context,
snapshot_s3, snapshot_s3,
[this](uint64_t /*log_idx*/, const KeeperStorage::RequestForSession & request_for_session) [this](uint64_t /*log_idx*/, const KeeperStorageBase::RequestForSession & request_for_session)
{ {
{ {
/// check if we have queue of read requests depending on this request to be committed /// check if we have queue of read requests depending on this request to be committed
@ -540,7 +540,7 @@ void KeeperDispatcher::shutdown()
update_configuration_thread.join(); update_configuration_thread.join();
} }
KeeperStorage::RequestForSession request_for_session; KeeperStorageBase::RequestForSession request_for_session;
/// Set session expired for all pending requests /// Set session expired for all pending requests
while (requests_queue && requests_queue->tryPop(request_for_session)) while (requests_queue && requests_queue->tryPop(request_for_session))
@ -551,7 +551,7 @@ void KeeperDispatcher::shutdown()
setResponse(request_for_session.session_id, response); setResponse(request_for_session.session_id, response);
} }
KeeperStorage::RequestsForSessions close_requests; KeeperStorageBase::RequestsForSessions close_requests;
{ {
/// Clear all registered sessions /// Clear all registered sessions
std::lock_guard lock(session_to_response_callback_mutex); std::lock_guard lock(session_to_response_callback_mutex);
@ -565,7 +565,7 @@ void KeeperDispatcher::shutdown()
auto request = Coordination::ZooKeeperRequestFactory::instance().get(Coordination::OpNum::Close); auto request = Coordination::ZooKeeperRequestFactory::instance().get(Coordination::OpNum::Close);
request->xid = Coordination::CLOSE_XID; request->xid = Coordination::CLOSE_XID;
using namespace std::chrono; using namespace std::chrono;
KeeperStorage::RequestForSession request_info KeeperStorageBase::RequestForSession request_info
{ {
.session_id = session, .session_id = session,
.time = duration_cast<milliseconds>(system_clock::now().time_since_epoch()).count(), .time = duration_cast<milliseconds>(system_clock::now().time_since_epoch()).count(),
@ -663,7 +663,7 @@ void KeeperDispatcher::sessionCleanerTask()
auto request = Coordination::ZooKeeperRequestFactory::instance().get(Coordination::OpNum::Close); auto request = Coordination::ZooKeeperRequestFactory::instance().get(Coordination::OpNum::Close);
request->xid = Coordination::CLOSE_XID; request->xid = Coordination::CLOSE_XID;
using namespace std::chrono; using namespace std::chrono;
KeeperStorage::RequestForSession request_info KeeperStorageBase::RequestForSession request_info
{ {
.session_id = dead_session, .session_id = dead_session,
.time = duration_cast<milliseconds>(system_clock::now().time_since_epoch()).count(), .time = duration_cast<milliseconds>(system_clock::now().time_since_epoch()).count(),
@ -711,16 +711,16 @@ void KeeperDispatcher::finishSession(int64_t session_id)
} }
} }
void KeeperDispatcher::addErrorResponses(const KeeperStorage::RequestsForSessions & requests_for_sessions, Coordination::Error error) void KeeperDispatcher::addErrorResponses(const KeeperStorageBase::RequestsForSessions & requests_for_sessions, Coordination::Error error)
{ {
for (const auto & request_for_session : requests_for_sessions) for (const auto & request_for_session : requests_for_sessions)
{ {
KeeperStorage::ResponsesForSessions responses; KeeperStorageBase::ResponsesForSessions responses;
auto response = request_for_session.request->makeResponse(); auto response = request_for_session.request->makeResponse();
response->xid = request_for_session.request->xid; response->xid = request_for_session.request->xid;
response->zxid = 0; response->zxid = 0;
response->error = error; response->error = error;
if (!responses_queue.push(DB::KeeperStorage::ResponseForSession{request_for_session.session_id, response})) if (!responses_queue.push(DB::KeeperStorageBase::ResponseForSession{request_for_session.session_id, response}))
throw Exception(ErrorCodes::SYSTEM_ERROR, throw Exception(ErrorCodes::SYSTEM_ERROR,
"Could not push error response xid {} zxid {} error message {} to responses queue", "Could not push error response xid {} zxid {} error message {} to responses queue",
response->xid, response->xid,
@ -730,7 +730,7 @@ void KeeperDispatcher::addErrorResponses(const KeeperStorage::RequestsForSession
} }
nuraft::ptr<nuraft::buffer> KeeperDispatcher::forceWaitAndProcessResult( nuraft::ptr<nuraft::buffer> KeeperDispatcher::forceWaitAndProcessResult(
RaftAppendResult & result, KeeperStorage::RequestsForSessions & requests_for_sessions, bool clear_requests_on_success) RaftAppendResult & result, KeeperStorageBase::RequestsForSessions & requests_for_sessions, bool clear_requests_on_success)
{ {
if (!result->has_result()) if (!result->has_result())
result->get(); result->get();
@ -755,7 +755,7 @@ int64_t KeeperDispatcher::getSessionID(int64_t session_timeout_ms)
{ {
/// New session id allocation is a special request, because we cannot process it in normal /// New session id allocation is a special request, because we cannot process it in normal
/// way: get request -> put to raft -> set response for registered callback. /// way: get request -> put to raft -> set response for registered callback.
KeeperStorage::RequestForSession request_info; KeeperStorageBase::RequestForSession request_info;
std::shared_ptr<Coordination::ZooKeeperSessionIDRequest> request = std::make_shared<Coordination::ZooKeeperSessionIDRequest>(); std::shared_ptr<Coordination::ZooKeeperSessionIDRequest> request = std::make_shared<Coordination::ZooKeeperSessionIDRequest>();
/// Internal session id. It's a temporary number which is unique for each client on this server /// Internal session id. It's a temporary number which is unique for each client on this server
/// but can be same on different servers. /// but can be same on different servers.

View File

@ -26,7 +26,7 @@ using ZooKeeperResponseCallback = std::function<void(const Coordination::ZooKeep
class KeeperDispatcher class KeeperDispatcher
{ {
private: private:
using RequestsQueue = ConcurrentBoundedQueue<KeeperStorage::RequestForSession>; using RequestsQueue = ConcurrentBoundedQueue<KeeperStorageBase::RequestForSession>;
using SessionToResponseCallback = std::unordered_map<int64_t, ZooKeeperResponseCallback>; using SessionToResponseCallback = std::unordered_map<int64_t, ZooKeeperResponseCallback>;
using ClusterUpdateQueue = ConcurrentBoundedQueue<ClusterUpdateAction>; using ClusterUpdateQueue = ConcurrentBoundedQueue<ClusterUpdateAction>;
@ -95,18 +95,18 @@ private:
/// Add error responses for requests to responses queue. /// Add error responses for requests to responses queue.
/// Clears requests. /// Clears requests.
void addErrorResponses(const KeeperStorage::RequestsForSessions & requests_for_sessions, Coordination::Error error); void addErrorResponses(const KeeperStorageBase::RequestsForSessions & requests_for_sessions, Coordination::Error error);
/// Forcefully wait for result and sets errors if something when wrong. /// Forcefully wait for result and sets errors if something when wrong.
/// Clears both arguments /// Clears both arguments
nuraft::ptr<nuraft::buffer> forceWaitAndProcessResult( nuraft::ptr<nuraft::buffer> forceWaitAndProcessResult(
RaftAppendResult & result, KeeperStorage::RequestsForSessions & requests_for_sessions, bool clear_requests_on_success); RaftAppendResult & result, KeeperStorageBase::RequestsForSessions & requests_for_sessions, bool clear_requests_on_success);
public: public:
std::mutex read_request_queue_mutex; std::mutex read_request_queue_mutex;
/// queue of read requests that can be processed after a request with specific session ID and XID is committed /// queue of read requests that can be processed after a request with specific session ID and XID is committed
std::unordered_map<int64_t, std::unordered_map<Coordination::XID, KeeperStorage::RequestsForSessions>> read_request_queue; std::unordered_map<int64_t, std::unordered_map<Coordination::XID, KeeperStorageBase::RequestsForSessions>> read_request_queue;
/// Just allocate some objects, real initialization is done by `intialize method` /// Just allocate some objects, real initialization is done by `intialize method`
KeeperDispatcher(); KeeperDispatcher();
@ -192,7 +192,7 @@ public:
Keeper4LWInfo getKeeper4LWInfo() const; Keeper4LWInfo getKeeper4LWInfo() const;
const KeeperStateMachine & getStateMachine() const const IKeeperStateMachine & getStateMachine() const
{ {
return *server->getKeeperStateMachine(); return *server->getKeeperStateMachine();
} }

View File

@ -123,7 +123,7 @@ KeeperServer::KeeperServer(
SnapshotsQueue & snapshots_queue_, SnapshotsQueue & snapshots_queue_,
KeeperContextPtr keeper_context_, KeeperContextPtr keeper_context_,
KeeperSnapshotManagerS3 & snapshot_manager_s3, KeeperSnapshotManagerS3 & snapshot_manager_s3,
KeeperStateMachine::CommitCallback commit_callback) IKeeperStateMachine::CommitCallback commit_callback)
: server_id(configuration_and_settings_->server_id) : server_id(configuration_and_settings_->server_id)
, log(getLogger("KeeperServer")) , log(getLogger("KeeperServer"))
, is_recovering(config.getBool("keeper_server.force_recovery", false)) , is_recovering(config.getBool("keeper_server.force_recovery", false))
@ -134,13 +134,28 @@ KeeperServer::KeeperServer(
if (keeper_context->getCoordinationSettings()->quorum_reads) if (keeper_context->getCoordinationSettings()->quorum_reads)
LOG_WARNING(log, "Quorum reads enabled, Keeper will work slower."); LOG_WARNING(log, "Quorum reads enabled, Keeper will work slower.");
state_machine = nuraft::cs_new<KeeperStateMachine>( #if USE_ROCKSDB
responses_queue_, const auto & coordination_settings = keeper_context->getCoordinationSettings();
snapshots_queue_, if (coordination_settings->experimental_use_rocksdb)
keeper_context, {
config.getBool("keeper_server.upload_snapshot_on_exit", false) ? &snapshot_manager_s3 : nullptr, state_machine = nuraft::cs_new<KeeperStateMachine<KeeperRocksStorage>>(
commit_callback, responses_queue_,
checkAndGetSuperdigest(configuration_and_settings_->super_digest)); snapshots_queue_,
keeper_context,
config.getBool("keeper_server.upload_snapshot_on_exit", false) ? &snapshot_manager_s3 : nullptr,
commit_callback,
checkAndGetSuperdigest(configuration_and_settings_->super_digest));
LOG_WARNING(log, "Use RocksDB as Keeper backend storage.");
}
else
#endif
state_machine = nuraft::cs_new<KeeperStateMachine<KeeperMemoryStorage>>(
responses_queue_,
snapshots_queue_,
keeper_context,
config.getBool("keeper_server.upload_snapshot_on_exit", false) ? &snapshot_manager_s3 : nullptr,
commit_callback,
checkAndGetSuperdigest(configuration_and_settings_->super_digest));
state_manager = nuraft::cs_new<KeeperStateManager>( state_manager = nuraft::cs_new<KeeperStateManager>(
server_id, server_id,
@ -522,7 +537,7 @@ namespace
{ {
// Serialize the request for the log entry // Serialize the request for the log entry
nuraft::ptr<nuraft::buffer> getZooKeeperLogEntry(const KeeperStorage::RequestForSession & request_for_session) nuraft::ptr<nuraft::buffer> getZooKeeperLogEntry(const KeeperStorageBase::RequestForSession & request_for_session)
{ {
DB::WriteBufferFromNuraftBuffer write_buf; DB::WriteBufferFromNuraftBuffer write_buf;
DB::writeIntBinary(request_for_session.session_id, write_buf); DB::writeIntBinary(request_for_session.session_id, write_buf);
@ -530,7 +545,7 @@ nuraft::ptr<nuraft::buffer> getZooKeeperLogEntry(const KeeperStorage::RequestFor
DB::writeIntBinary(request_for_session.time, write_buf); DB::writeIntBinary(request_for_session.time, write_buf);
/// we fill with dummy values to eliminate unnecessary copy later on when we will write correct values /// we fill with dummy values to eliminate unnecessary copy later on when we will write correct values
DB::writeIntBinary(static_cast<int64_t>(0), write_buf); /// zxid DB::writeIntBinary(static_cast<int64_t>(0), write_buf); /// zxid
DB::writeIntBinary(KeeperStorage::DigestVersion::NO_DIGEST, write_buf); /// digest version or NO_DIGEST flag DB::writeIntBinary(KeeperStorageBase::DigestVersion::NO_DIGEST, write_buf); /// digest version or NO_DIGEST flag
DB::writeIntBinary(static_cast<uint64_t>(0), write_buf); /// digest value DB::writeIntBinary(static_cast<uint64_t>(0), write_buf); /// digest value
/// if new fields are added, update KeeperStateMachine::ZooKeeperLogSerializationVersion along with parseRequest function and PreAppendLog callback handler /// if new fields are added, update KeeperStateMachine::ZooKeeperLogSerializationVersion along with parseRequest function and PreAppendLog callback handler
return write_buf.getBuffer(); return write_buf.getBuffer();
@ -538,7 +553,7 @@ nuraft::ptr<nuraft::buffer> getZooKeeperLogEntry(const KeeperStorage::RequestFor
} }
void KeeperServer::putLocalReadRequest(const KeeperStorage::RequestForSession & request_for_session) void KeeperServer::putLocalReadRequest(const KeeperStorageBase::RequestForSession & request_for_session)
{ {
if (!request_for_session.request->isReadRequest()) if (!request_for_session.request->isReadRequest())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot process non-read request locally"); throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot process non-read request locally");
@ -546,7 +561,7 @@ void KeeperServer::putLocalReadRequest(const KeeperStorage::RequestForSession &
state_machine->processReadRequest(request_for_session); state_machine->processReadRequest(request_for_session);
} }
RaftAppendResult KeeperServer::putRequestBatch(const KeeperStorage::RequestsForSessions & requests_for_sessions) RaftAppendResult KeeperServer::putRequestBatch(const KeeperStorageBase::RequestsForSessions & requests_for_sessions)
{ {
std::vector<nuraft::ptr<nuraft::buffer>> entries; std::vector<nuraft::ptr<nuraft::buffer>> entries;
entries.reserve(requests_for_sessions.size()); entries.reserve(requests_for_sessions.size());
@ -789,7 +804,7 @@ nuraft::cb_func::ReturnCode KeeperServer::callbackFunc(nuraft::cb_func::Type typ
auto entry_buf = entry->get_buf_ptr(); auto entry_buf = entry->get_buf_ptr();
KeeperStateMachine::ZooKeeperLogSerializationVersion serialization_version; IKeeperStateMachine::ZooKeeperLogSerializationVersion serialization_version;
auto request_for_session = state_machine->parseRequest(*entry_buf, /*final=*/false, &serialization_version); auto request_for_session = state_machine->parseRequest(*entry_buf, /*final=*/false, &serialization_version);
request_for_session->zxid = next_zxid; request_for_session->zxid = next_zxid;
if (!state_machine->preprocess(*request_for_session)) if (!state_machine->preprocess(*request_for_session))
@ -799,10 +814,10 @@ nuraft::cb_func::ReturnCode KeeperServer::callbackFunc(nuraft::cb_func::Type typ
/// older versions of Keeper can send logs that are missing some fields /// older versions of Keeper can send logs that are missing some fields
size_t bytes_missing = 0; size_t bytes_missing = 0;
if (serialization_version < KeeperStateMachine::ZooKeeperLogSerializationVersion::WITH_TIME) if (serialization_version < IKeeperStateMachine::ZooKeeperLogSerializationVersion::WITH_TIME)
bytes_missing += sizeof(request_for_session->time); bytes_missing += sizeof(request_for_session->time);
if (serialization_version < KeeperStateMachine::ZooKeeperLogSerializationVersion::WITH_ZXID_DIGEST) if (serialization_version < IKeeperStateMachine::ZooKeeperLogSerializationVersion::WITH_ZXID_DIGEST)
bytes_missing += sizeof(request_for_session->zxid) + sizeof(request_for_session->digest->version) + sizeof(request_for_session->digest->value); bytes_missing += sizeof(request_for_session->zxid) + sizeof(request_for_session->digest->version) + sizeof(request_for_session->digest->value);
if (bytes_missing != 0) if (bytes_missing != 0)
@ -816,19 +831,19 @@ nuraft::cb_func::ReturnCode KeeperServer::callbackFunc(nuraft::cb_func::Type typ
size_t write_buffer_header_size size_t write_buffer_header_size
= sizeof(request_for_session->zxid) + sizeof(request_for_session->digest->version) + sizeof(request_for_session->digest->value); = sizeof(request_for_session->zxid) + sizeof(request_for_session->digest->version) + sizeof(request_for_session->digest->value);
if (serialization_version < KeeperStateMachine::ZooKeeperLogSerializationVersion::WITH_TIME) if (serialization_version < IKeeperStateMachine::ZooKeeperLogSerializationVersion::WITH_TIME)
write_buffer_header_size += sizeof(request_for_session->time); write_buffer_header_size += sizeof(request_for_session->time);
auto * buffer_start = reinterpret_cast<BufferBase::Position>(entry_buf->data_begin() + entry_buf->size() - write_buffer_header_size); auto * buffer_start = reinterpret_cast<BufferBase::Position>(entry_buf->data_begin() + entry_buf->size() - write_buffer_header_size);
WriteBufferFromPointer write_buf(buffer_start, write_buffer_header_size); WriteBufferFromPointer write_buf(buffer_start, write_buffer_header_size);
if (serialization_version < KeeperStateMachine::ZooKeeperLogSerializationVersion::WITH_TIME) if (serialization_version < IKeeperStateMachine::ZooKeeperLogSerializationVersion::WITH_TIME)
writeIntBinary(request_for_session->time, write_buf); writeIntBinary(request_for_session->time, write_buf);
writeIntBinary(request_for_session->zxid, write_buf); writeIntBinary(request_for_session->zxid, write_buf);
writeIntBinary(request_for_session->digest->version, write_buf); writeIntBinary(request_for_session->digest->version, write_buf);
if (request_for_session->digest->version != KeeperStorage::NO_DIGEST) if (request_for_session->digest->version != KeeperStorageBase::NO_DIGEST)
writeIntBinary(request_for_session->digest->value, write_buf); writeIntBinary(request_for_session->digest->value, write_buf);
write_buf.finalize(); write_buf.finalize();

View File

@ -24,7 +24,7 @@ class KeeperServer
private: private:
const int server_id; const int server_id;
nuraft::ptr<KeeperStateMachine> state_machine; nuraft::ptr<IKeeperStateMachine> state_machine;
nuraft::ptr<KeeperStateManager> state_manager; nuraft::ptr<KeeperStateManager> state_manager;
@ -79,26 +79,26 @@ public:
SnapshotsQueue & snapshots_queue_, SnapshotsQueue & snapshots_queue_,
KeeperContextPtr keeper_context_, KeeperContextPtr keeper_context_,
KeeperSnapshotManagerS3 & snapshot_manager_s3, KeeperSnapshotManagerS3 & snapshot_manager_s3,
KeeperStateMachine::CommitCallback commit_callback); IKeeperStateMachine::CommitCallback commit_callback);
/// Load state machine from the latest snapshot and load log storage. Start NuRaft with required settings. /// Load state machine from the latest snapshot and load log storage. Start NuRaft with required settings.
void startup(const Poco::Util::AbstractConfiguration & config, bool enable_ipv6 = true); void startup(const Poco::Util::AbstractConfiguration & config, bool enable_ipv6 = true);
/// Put local read request and execute in state machine directly and response into /// Put local read request and execute in state machine directly and response into
/// responses queue /// responses queue
void putLocalReadRequest(const KeeperStorage::RequestForSession & request); void putLocalReadRequest(const KeeperStorageBase::RequestForSession & request);
bool isRecovering() const { return is_recovering; } bool isRecovering() const { return is_recovering; }
bool reconfigEnabled() const { return enable_reconfiguration; } bool reconfigEnabled() const { return enable_reconfiguration; }
/// Put batch of requests into Raft and get result of put. Responses will be set separately into /// Put batch of requests into Raft and get result of put. Responses will be set separately into
/// responses_queue. /// responses_queue.
RaftAppendResult putRequestBatch(const KeeperStorage::RequestsForSessions & requests); RaftAppendResult putRequestBatch(const KeeperStorageBase::RequestsForSessions & requests);
/// Return set of the non-active sessions /// Return set of the non-active sessions
std::vector<int64_t> getDeadSessions(); std::vector<int64_t> getDeadSessions();
nuraft::ptr<KeeperStateMachine> getKeeperStateMachine() const { return state_machine; } nuraft::ptr<IKeeperStateMachine> getKeeperStateMachine() const { return state_machine; }
void forceRecovery(); void forceRecovery();

View File

@ -66,7 +66,8 @@ namespace
return base; return base;
} }
void writeNode(const KeeperStorage::Node & node, SnapshotVersion version, WriteBuffer & out) template<typename Node>
void writeNode(const Node & node, SnapshotVersion version, WriteBuffer & out)
{ {
writeBinary(node.getData(), out); writeBinary(node.getData(), out);
@ -86,7 +87,7 @@ namespace
writeBinary(node.aversion, out); writeBinary(node.aversion, out);
writeBinary(node.ephemeralOwner(), out); writeBinary(node.ephemeralOwner(), out);
if (version < SnapshotVersion::V6) if (version < SnapshotVersion::V6)
writeBinary(static_cast<int32_t>(node.data_size), out); writeBinary(static_cast<int32_t>(node.getData().size()), out);
writeBinary(node.numChildren(), out); writeBinary(node.numChildren(), out);
writeBinary(node.pzxid, out); writeBinary(node.pzxid, out);
@ -96,7 +97,8 @@ namespace
writeBinary(node.sizeInBytes(), out); writeBinary(node.sizeInBytes(), out);
} }
void readNode(KeeperStorage::Node & node, ReadBuffer & in, SnapshotVersion version, ACLMap & acl_map) template<typename Node>
void readNode(Node & node, ReadBuffer & in, SnapshotVersion version, ACLMap & acl_map)
{ {
readVarUInt(node.data_size, in); readVarUInt(node.data_size, in);
if (node.data_size != 0) if (node.data_size != 0)
@ -195,7 +197,8 @@ namespace
} }
} }
void KeeperStorageSnapshot::serialize(const KeeperStorageSnapshot & snapshot, WriteBuffer & out, KeeperContextPtr keeper_context) template<typename Storage>
void KeeperStorageSnapshot<Storage>::serialize(const KeeperStorageSnapshot<Storage> & snapshot, WriteBuffer & out, KeeperContextPtr keeper_context)
{ {
writeBinary(static_cast<uint8_t>(snapshot.version), out); writeBinary(static_cast<uint8_t>(snapshot.version), out);
serializeSnapshotMetadata(snapshot.snapshot_meta, out); serializeSnapshotMetadata(snapshot.snapshot_meta, out);
@ -205,11 +208,11 @@ void KeeperStorageSnapshot::serialize(const KeeperStorageSnapshot & snapshot, Wr
writeBinary(snapshot.zxid, out); writeBinary(snapshot.zxid, out);
if (keeper_context->digestEnabled()) if (keeper_context->digestEnabled())
{ {
writeBinary(static_cast<uint8_t>(KeeperStorage::CURRENT_DIGEST_VERSION), out); writeBinary(static_cast<uint8_t>(Storage::CURRENT_DIGEST_VERSION), out);
writeBinary(snapshot.nodes_digest, out); writeBinary(snapshot.nodes_digest, out);
} }
else else
writeBinary(static_cast<uint8_t>(KeeperStorage::NO_DIGEST), out); writeBinary(static_cast<uint8_t>(Storage::NO_DIGEST), out);
} }
writeBinary(snapshot.session_id, out); writeBinary(snapshot.session_id, out);
@ -255,7 +258,6 @@ void KeeperStorageSnapshot::serialize(const KeeperStorageSnapshot & snapshot, Wr
/// slightly bigger than required. /// slightly bigger than required.
if (node.mzxid > snapshot.zxid) if (node.mzxid > snapshot.zxid)
break; break;
writeBinary(path, out); writeBinary(path, out);
writeNode(node, snapshot.version, out); writeNode(node, snapshot.version, out);
@ -282,7 +284,7 @@ void KeeperStorageSnapshot::serialize(const KeeperStorageSnapshot & snapshot, Wr
writeBinary(session_id, out); writeBinary(session_id, out);
writeBinary(timeout, out); writeBinary(timeout, out);
KeeperStorage::AuthIDs ids; KeeperStorageBase::AuthIDs ids;
if (snapshot.session_and_auth.contains(session_id)) if (snapshot.session_and_auth.contains(session_id))
ids = snapshot.session_and_auth.at(session_id); ids = snapshot.session_and_auth.at(session_id);
@ -303,7 +305,8 @@ void KeeperStorageSnapshot::serialize(const KeeperStorageSnapshot & snapshot, Wr
} }
} }
void KeeperStorageSnapshot::deserialize(SnapshotDeserializationResult & deserialization_result, ReadBuffer & in, KeeperContextPtr keeper_context) template<typename Storage>
void KeeperStorageSnapshot<Storage>::deserialize(SnapshotDeserializationResult<Storage> & deserialization_result, ReadBuffer & in, KeeperContextPtr keeper_context)
{ {
uint8_t version; uint8_t version;
readBinary(version, in); readBinary(version, in);
@ -312,7 +315,7 @@ void KeeperStorageSnapshot::deserialize(SnapshotDeserializationResult & deserial
throw Exception(ErrorCodes::UNKNOWN_FORMAT_VERSION, "Unsupported snapshot version {}", version); throw Exception(ErrorCodes::UNKNOWN_FORMAT_VERSION, "Unsupported snapshot version {}", version);
deserialization_result.snapshot_meta = deserializeSnapshotMetadata(in); deserialization_result.snapshot_meta = deserializeSnapshotMetadata(in);
KeeperStorage & storage = *deserialization_result.storage; Storage & storage = *deserialization_result.storage;
bool recalculate_digest = keeper_context->digestEnabled(); bool recalculate_digest = keeper_context->digestEnabled();
if (version >= SnapshotVersion::V5) if (version >= SnapshotVersion::V5)
@ -320,11 +323,11 @@ void KeeperStorageSnapshot::deserialize(SnapshotDeserializationResult & deserial
readBinary(storage.zxid, in); readBinary(storage.zxid, in);
uint8_t digest_version; uint8_t digest_version;
readBinary(digest_version, in); readBinary(digest_version, in);
if (digest_version != KeeperStorage::DigestVersion::NO_DIGEST) if (digest_version != Storage::DigestVersion::NO_DIGEST)
{ {
uint64_t nodes_digest; uint64_t nodes_digest;
readBinary(nodes_digest, in); readBinary(nodes_digest, in);
if (digest_version == KeeperStorage::CURRENT_DIGEST_VERSION) if (digest_version == Storage::CURRENT_DIGEST_VERSION)
{ {
storage.nodes_digest = nodes_digest; storage.nodes_digest = nodes_digest;
recalculate_digest = false; recalculate_digest = false;
@ -374,8 +377,8 @@ void KeeperStorageSnapshot::deserialize(SnapshotDeserializationResult & deserial
size_t snapshot_container_size; size_t snapshot_container_size;
readBinary(snapshot_container_size, in); readBinary(snapshot_container_size, in);
if constexpr (!use_rocksdb)
storage.container.reserve(snapshot_container_size); storage.container.reserve(snapshot_container_size);
if (recalculate_digest) if (recalculate_digest)
storage.nodes_digest = 0; storage.nodes_digest = 0;
@ -389,7 +392,7 @@ void KeeperStorageSnapshot::deserialize(SnapshotDeserializationResult & deserial
in.readStrict(path_data.get(), path_size); in.readStrict(path_data.get(), path_size);
std::string_view path{path_data.get(), path_size}; std::string_view path{path_data.get(), path_size};
KeeperStorage::Node node{}; typename Storage::Node node{};
readNode(node, in, current_version, storage.acl_map); readNode(node, in, current_version, storage.acl_map);
using enum Coordination::PathMatchResult; using enum Coordination::PathMatchResult;
@ -421,7 +424,7 @@ void KeeperStorageSnapshot::deserialize(SnapshotDeserializationResult & deserial
if (keeper_context->ignoreSystemPathOnStartup() || keeper_context->getServerState() != KeeperContext::Phase::INIT) if (keeper_context->ignoreSystemPathOnStartup() || keeper_context->getServerState() != KeeperContext::Phase::INIT)
{ {
LOG_ERROR(getLogger("KeeperSnapshotManager"), "{}. Ignoring it", get_error_msg()); LOG_ERROR(getLogger("KeeperSnapshotManager"), "{}. Ignoring it", get_error_msg());
node = KeeperStorage::Node{}; node = typename Storage::Node{};
} }
else else
throw Exception( throw Exception(
@ -433,8 +436,9 @@ void KeeperStorageSnapshot::deserialize(SnapshotDeserializationResult & deserial
} }
auto ephemeral_owner = node.ephemeralOwner(); auto ephemeral_owner = node.ephemeralOwner();
if (!node.isEphemeral() && node.numChildren() > 0) if constexpr (!use_rocksdb)
node.getChildren().reserve(node.numChildren()); if (!node.isEphemeral() && node.numChildren() > 0)
node.getChildren().reserve(node.numChildren());
if (ephemeral_owner != 0) if (ephemeral_owner != 0)
storage.ephemerals[node.ephemeralOwner()].insert(std::string{path}); storage.ephemerals[node.ephemeralOwner()].insert(std::string{path});
@ -447,36 +451,38 @@ void KeeperStorageSnapshot::deserialize(SnapshotDeserializationResult & deserial
LOG_TRACE(getLogger("KeeperSnapshotManager"), "Building structure for children nodes"); LOG_TRACE(getLogger("KeeperSnapshotManager"), "Building structure for children nodes");
for (const auto & itr : storage.container) if constexpr (!use_rocksdb)
{ {
if (itr.key != "/") for (const auto & itr : storage.container)
{ {
auto parent_path = parentNodePath(itr.key); if (itr.key != "/")
storage.container.updateValue(
parent_path, [path = itr.key](KeeperStorage::Node & value) { value.addChild(getBaseNodeName(path)); });
}
}
for (const auto & itr : storage.container)
{
if (itr.key != "/")
{
if (itr.value.numChildren() != static_cast<int32_t>(itr.value.getChildren().size()))
{ {
auto parent_path = parentNodePath(itr.key);
storage.container.updateValue(
parent_path, [path = itr.key](typename Storage::Node & value) { value.addChild(getBaseNodeName(path)); });
}
}
for (const auto & itr : storage.container)
{
if (itr.key != "/")
{
if (itr.value.numChildren() != static_cast<int32_t>(itr.value.getChildren().size()))
{
#ifdef NDEBUG #ifdef NDEBUG
/// TODO (alesapin) remove this, it should be always CORRUPTED_DATA. /// TODO (alesapin) remove this, it should be always CORRUPTED_DATA.
LOG_ERROR(getLogger("KeeperSnapshotManager"), "Children counter in stat.numChildren {}" LOG_ERROR(getLogger("KeeperSnapshotManager"), "Children counter in stat.numChildren {}"
" is different from actual children size {} for node {}", itr.value.numChildren(), itr.value.getChildren().size(), itr.key); " is different from actual children size {} for node {}", itr.value.numChildren(), itr.value.getChildren().size(), itr.key);
#else #else
throw Exception(ErrorCodes::LOGICAL_ERROR, "Children counter in stat.numChildren {}" throw Exception(ErrorCodes::LOGICAL_ERROR, "Children counter in stat.numChildren {}"
" is different from actual children size {} for node {}", " is different from actual children size {} for node {}",
itr.value.numChildren(), itr.value.getChildren().size(), itr.key); itr.value.numChildren(), itr.value.getChildren().size(), itr.key);
#endif #endif
}
} }
} }
} }
size_t active_sessions_size; size_t active_sessions_size;
readBinary(active_sessions_size, in); readBinary(active_sessions_size, in);
@ -493,14 +499,14 @@ void KeeperStorageSnapshot::deserialize(SnapshotDeserializationResult & deserial
size_t session_auths_size; size_t session_auths_size;
readBinary(session_auths_size, in); readBinary(session_auths_size, in);
KeeperStorage::AuthIDs ids; typename Storage::AuthIDs ids;
size_t session_auth_counter = 0; size_t session_auth_counter = 0;
while (session_auth_counter < session_auths_size) while (session_auth_counter < session_auths_size)
{ {
String scheme, id; String scheme, id;
readBinary(scheme, in); readBinary(scheme, in);
readBinary(id, in); readBinary(id, in);
ids.emplace_back(KeeperStorage::AuthID{scheme, id}); ids.emplace_back(typename Storage::AuthID{scheme, id});
session_auth_counter++; session_auth_counter++;
} }
@ -523,7 +529,8 @@ void KeeperStorageSnapshot::deserialize(SnapshotDeserializationResult & deserial
} }
} }
KeeperStorageSnapshot::KeeperStorageSnapshot(KeeperStorage * storage_, uint64_t up_to_log_idx_, const ClusterConfigPtr & cluster_config_) template<typename Storage>
KeeperStorageSnapshot<Storage>::KeeperStorageSnapshot(Storage * storage_, uint64_t up_to_log_idx_, const ClusterConfigPtr & cluster_config_)
: storage(storage_) : storage(storage_)
, snapshot_meta(std::make_shared<SnapshotMetadata>(up_to_log_idx_, 0, std::make_shared<nuraft::cluster_config>())) , snapshot_meta(std::make_shared<SnapshotMetadata>(up_to_log_idx_, 0, std::make_shared<nuraft::cluster_config>()))
, session_id(storage->session_id_counter) , session_id(storage->session_id_counter)
@ -540,8 +547,9 @@ KeeperStorageSnapshot::KeeperStorageSnapshot(KeeperStorage * storage_, uint64_t
session_and_auth = storage->session_and_auth; session_and_auth = storage->session_and_auth;
} }
KeeperStorageSnapshot::KeeperStorageSnapshot( template<typename Storage>
KeeperStorage * storage_, const SnapshotMetadataPtr & snapshot_meta_, const ClusterConfigPtr & cluster_config_) KeeperStorageSnapshot<Storage>::KeeperStorageSnapshot(
Storage * storage_, const SnapshotMetadataPtr & snapshot_meta_, const ClusterConfigPtr & cluster_config_)
: storage(storage_) : storage(storage_)
, snapshot_meta(snapshot_meta_) , snapshot_meta(snapshot_meta_)
, session_id(storage->session_id_counter) , session_id(storage->session_id_counter)
@ -558,12 +566,14 @@ KeeperStorageSnapshot::KeeperStorageSnapshot(
session_and_auth = storage->session_and_auth; session_and_auth = storage->session_and_auth;
} }
KeeperStorageSnapshot::~KeeperStorageSnapshot() template<typename Storage>
KeeperStorageSnapshot<Storage>::~KeeperStorageSnapshot()
{ {
storage->disableSnapshotMode(); storage->disableSnapshotMode();
} }
KeeperSnapshotManager::KeeperSnapshotManager( template<typename Storage>
KeeperSnapshotManager<Storage>::KeeperSnapshotManager(
size_t snapshots_to_keep_, size_t snapshots_to_keep_,
const KeeperContextPtr & keeper_context_, const KeeperContextPtr & keeper_context_,
bool compress_snapshots_zstd_, bool compress_snapshots_zstd_,
@ -651,7 +661,8 @@ KeeperSnapshotManager::KeeperSnapshotManager(
moveSnapshotsIfNeeded(); moveSnapshotsIfNeeded();
} }
SnapshotFileInfoPtr KeeperSnapshotManager::serializeSnapshotBufferToDisk(nuraft::buffer & buffer, uint64_t up_to_log_idx) template<typename Storage>
SnapshotFileInfoPtr KeeperSnapshotManager<Storage>::serializeSnapshotBufferToDisk(nuraft::buffer & buffer, uint64_t up_to_log_idx)
{ {
ReadBufferFromNuraftBuffer reader(buffer); ReadBufferFromNuraftBuffer reader(buffer);
@ -680,7 +691,8 @@ SnapshotFileInfoPtr KeeperSnapshotManager::serializeSnapshotBufferToDisk(nuraft:
return snapshot_file_info; return snapshot_file_info;
} }
nuraft::ptr<nuraft::buffer> KeeperSnapshotManager::deserializeLatestSnapshotBufferFromDisk() template<typename Storage>
nuraft::ptr<nuraft::buffer> KeeperSnapshotManager<Storage>::deserializeLatestSnapshotBufferFromDisk()
{ {
while (!existing_snapshots.empty()) while (!existing_snapshots.empty())
{ {
@ -701,7 +713,8 @@ nuraft::ptr<nuraft::buffer> KeeperSnapshotManager::deserializeLatestSnapshotBuff
return nullptr; return nullptr;
} }
nuraft::ptr<nuraft::buffer> KeeperSnapshotManager::deserializeSnapshotBufferFromDisk(uint64_t up_to_log_idx) const template<typename Storage>
nuraft::ptr<nuraft::buffer> KeeperSnapshotManager<Storage>::deserializeSnapshotBufferFromDisk(uint64_t up_to_log_idx) const
{ {
const auto & [snapshot_path, snapshot_disk, size] = *existing_snapshots.at(up_to_log_idx); const auto & [snapshot_path, snapshot_disk, size] = *existing_snapshots.at(up_to_log_idx);
WriteBufferFromNuraftBuffer writer; WriteBufferFromNuraftBuffer writer;
@ -710,7 +723,8 @@ nuraft::ptr<nuraft::buffer> KeeperSnapshotManager::deserializeSnapshotBufferFrom
return writer.getBuffer(); return writer.getBuffer();
} }
nuraft::ptr<nuraft::buffer> KeeperSnapshotManager::serializeSnapshotToBuffer(const KeeperStorageSnapshot & snapshot) const template<typename Storage>
nuraft::ptr<nuraft::buffer> KeeperSnapshotManager<Storage>::serializeSnapshotToBuffer(const KeeperStorageSnapshot<Storage> & snapshot) const
{ {
std::unique_ptr<WriteBufferFromNuraftBuffer> writer = std::make_unique<WriteBufferFromNuraftBuffer>(); std::unique_ptr<WriteBufferFromNuraftBuffer> writer = std::make_unique<WriteBufferFromNuraftBuffer>();
auto * buffer_raw_ptr = writer.get(); auto * buffer_raw_ptr = writer.get();
@ -720,13 +734,13 @@ nuraft::ptr<nuraft::buffer> KeeperSnapshotManager::serializeSnapshotToBuffer(con
else else
compressed_writer = std::make_unique<CompressedWriteBuffer>(*writer); compressed_writer = std::make_unique<CompressedWriteBuffer>(*writer);
KeeperStorageSnapshot::serialize(snapshot, *compressed_writer, keeper_context); KeeperStorageSnapshot<Storage>::serialize(snapshot, *compressed_writer, keeper_context);
compressed_writer->finalize(); compressed_writer->finalize();
return buffer_raw_ptr->getBuffer(); return buffer_raw_ptr->getBuffer();
} }
template<typename Storage>
bool KeeperSnapshotManager::isZstdCompressed(nuraft::ptr<nuraft::buffer> buffer) bool KeeperSnapshotManager<Storage>::isZstdCompressed(nuraft::ptr<nuraft::buffer> buffer)
{ {
static constexpr unsigned char ZSTD_COMPRESSED_MAGIC[4] = {0x28, 0xB5, 0x2F, 0xFD}; static constexpr unsigned char ZSTD_COMPRESSED_MAGIC[4] = {0x28, 0xB5, 0x2F, 0xFD};
@ -737,7 +751,8 @@ bool KeeperSnapshotManager::isZstdCompressed(nuraft::ptr<nuraft::buffer> buffer)
return memcmp(magic_from_buffer, ZSTD_COMPRESSED_MAGIC, 4) == 0; return memcmp(magic_from_buffer, ZSTD_COMPRESSED_MAGIC, 4) == 0;
} }
SnapshotDeserializationResult KeeperSnapshotManager::deserializeSnapshotFromBuffer(nuraft::ptr<nuraft::buffer> buffer) const template<typename Storage>
SnapshotDeserializationResult<Storage> KeeperSnapshotManager<Storage>::deserializeSnapshotFromBuffer(nuraft::ptr<nuraft::buffer> buffer) const
{ {
bool is_zstd_compressed = isZstdCompressed(buffer); bool is_zstd_compressed = isZstdCompressed(buffer);
@ -749,14 +764,15 @@ SnapshotDeserializationResult KeeperSnapshotManager::deserializeSnapshotFromBuff
else else
compressed_reader = std::make_unique<CompressedReadBuffer>(*reader); compressed_reader = std::make_unique<CompressedReadBuffer>(*reader);
SnapshotDeserializationResult result; SnapshotDeserializationResult<Storage> result;
result.storage = std::make_unique<KeeperStorage>(storage_tick_time, superdigest, keeper_context, /* initialize_system_nodes */ false); result.storage = std::make_unique<Storage>(storage_tick_time, superdigest, keeper_context, /* initialize_system_nodes */ false);
KeeperStorageSnapshot::deserialize(result, *compressed_reader, keeper_context); KeeperStorageSnapshot<Storage>::deserialize(result, *compressed_reader, keeper_context);
result.storage->initializeSystemNodes(); result.storage->initializeSystemNodes();
return result; return result;
} }
SnapshotDeserializationResult KeeperSnapshotManager::restoreFromLatestSnapshot() template<typename Storage>
SnapshotDeserializationResult<Storage> KeeperSnapshotManager<Storage>::restoreFromLatestSnapshot()
{ {
if (existing_snapshots.empty()) if (existing_snapshots.empty())
return {}; return {};
@ -767,23 +783,27 @@ SnapshotDeserializationResult KeeperSnapshotManager::restoreFromLatestSnapshot()
return deserializeSnapshotFromBuffer(buffer); return deserializeSnapshotFromBuffer(buffer);
} }
DiskPtr KeeperSnapshotManager::getDisk() const template<typename Storage>
DiskPtr KeeperSnapshotManager<Storage>::getDisk() const
{ {
return keeper_context->getSnapshotDisk(); return keeper_context->getSnapshotDisk();
} }
DiskPtr KeeperSnapshotManager::getLatestSnapshotDisk() const template<typename Storage>
DiskPtr KeeperSnapshotManager<Storage>::getLatestSnapshotDisk() const
{ {
return keeper_context->getLatestSnapshotDisk(); return keeper_context->getLatestSnapshotDisk();
} }
void KeeperSnapshotManager::removeOutdatedSnapshotsIfNeeded() template<typename Storage>
void KeeperSnapshotManager<Storage>::removeOutdatedSnapshotsIfNeeded()
{ {
while (existing_snapshots.size() > snapshots_to_keep) while (existing_snapshots.size() > snapshots_to_keep)
removeSnapshot(existing_snapshots.begin()->first); removeSnapshot(existing_snapshots.begin()->first);
} }
void KeeperSnapshotManager::moveSnapshotsIfNeeded() template<typename Storage>
void KeeperSnapshotManager<Storage>::moveSnapshotsIfNeeded()
{ {
/// move snapshots to correct disks /// move snapshots to correct disks
@ -813,7 +833,8 @@ void KeeperSnapshotManager::moveSnapshotsIfNeeded()
} }
void KeeperSnapshotManager::removeSnapshot(uint64_t log_idx) template<typename Storage>
void KeeperSnapshotManager<Storage>::removeSnapshot(uint64_t log_idx)
{ {
auto itr = existing_snapshots.find(log_idx); auto itr = existing_snapshots.find(log_idx);
if (itr == existing_snapshots.end()) if (itr == existing_snapshots.end())
@ -823,7 +844,8 @@ void KeeperSnapshotManager::removeSnapshot(uint64_t log_idx)
existing_snapshots.erase(itr); existing_snapshots.erase(itr);
} }
SnapshotFileInfoPtr KeeperSnapshotManager::serializeSnapshotToDisk(const KeeperStorageSnapshot & snapshot) template<typename Storage>
SnapshotFileInfoPtr KeeperSnapshotManager<Storage>::serializeSnapshotToDisk(const KeeperStorageSnapshot<Storage> & snapshot)
{ {
auto up_to_log_idx = snapshot.snapshot_meta->get_last_log_idx(); auto up_to_log_idx = snapshot.snapshot_meta->get_last_log_idx();
auto snapshot_file_name = getSnapshotFileName(up_to_log_idx, compress_snapshots_zstd); auto snapshot_file_name = getSnapshotFileName(up_to_log_idx, compress_snapshots_zstd);
@ -842,7 +864,7 @@ SnapshotFileInfoPtr KeeperSnapshotManager::serializeSnapshotToDisk(const KeeperS
else else
compressed_writer = std::make_unique<CompressedWriteBuffer>(*writer); compressed_writer = std::make_unique<CompressedWriteBuffer>(*writer);
KeeperStorageSnapshot::serialize(snapshot, *compressed_writer, keeper_context); KeeperStorageSnapshot<Storage>::serialize(snapshot, *compressed_writer, keeper_context);
compressed_writer->finalize(); compressed_writer->finalize();
compressed_writer->sync(); compressed_writer->sync();
@ -864,14 +886,16 @@ SnapshotFileInfoPtr KeeperSnapshotManager::serializeSnapshotToDisk(const KeeperS
return snapshot_file_info; return snapshot_file_info;
} }
size_t KeeperSnapshotManager::getLatestSnapshotIndex() const template<typename Storage>
size_t KeeperSnapshotManager<Storage>::getLatestSnapshotIndex() const
{ {
if (!existing_snapshots.empty()) if (!existing_snapshots.empty())
return existing_snapshots.rbegin()->first; return existing_snapshots.rbegin()->first;
return 0; return 0;
} }
SnapshotFileInfoPtr KeeperSnapshotManager::getLatestSnapshotInfo() const template<typename Storage>
SnapshotFileInfoPtr KeeperSnapshotManager<Storage>::getLatestSnapshotInfo() const
{ {
if (!existing_snapshots.empty()) if (!existing_snapshots.empty())
{ {
@ -890,4 +914,10 @@ SnapshotFileInfoPtr KeeperSnapshotManager::getLatestSnapshotInfo() const
return nullptr; return nullptr;
} }
template struct KeeperStorageSnapshot<KeeperMemoryStorage>;
template class KeeperSnapshotManager<KeeperMemoryStorage>;
#if USE_ROCKSDB
template struct KeeperStorageSnapshot<KeeperRocksStorage>;
template class KeeperSnapshotManager<KeeperRocksStorage>;
#endif
} }

View File

@ -34,10 +34,11 @@ enum SnapshotVersion : uint8_t
static constexpr auto CURRENT_SNAPSHOT_VERSION = SnapshotVersion::V6; static constexpr auto CURRENT_SNAPSHOT_VERSION = SnapshotVersion::V6;
/// What is stored in binary snapshot /// What is stored in binary snapshot
template<typename Storage>
struct SnapshotDeserializationResult struct SnapshotDeserializationResult
{ {
/// Storage /// Storage
KeeperStoragePtr storage; std::unique_ptr<Storage> storage;
/// Snapshot metadata (up_to_log_idx and so on) /// Snapshot metadata (up_to_log_idx and so on)
SnapshotMetadataPtr snapshot_meta; SnapshotMetadataPtr snapshot_meta;
/// Cluster config /// Cluster config
@ -52,21 +53,31 @@ struct SnapshotDeserializationResult
/// ///
/// This representation of snapshot have to be serialized into NuRaft /// This representation of snapshot have to be serialized into NuRaft
/// buffer and send over network or saved to file. /// buffer and send over network or saved to file.
template<typename Storage>
struct KeeperStorageSnapshot struct KeeperStorageSnapshot
{ {
#if USE_ROCKSDB
static constexpr bool use_rocksdb = std::is_same_v<Storage, KeeperRocksStorage>;
#else
static constexpr bool use_rocksdb = false;
#endif
public: public:
KeeperStorageSnapshot(KeeperStorage * storage_, uint64_t up_to_log_idx_, const ClusterConfigPtr & cluster_config_ = nullptr); KeeperStorageSnapshot(Storage * storage_, uint64_t up_to_log_idx_, const ClusterConfigPtr & cluster_config_ = nullptr);
KeeperStorageSnapshot( KeeperStorageSnapshot(
KeeperStorage * storage_, const SnapshotMetadataPtr & snapshot_meta_, const ClusterConfigPtr & cluster_config_ = nullptr); Storage * storage_, const SnapshotMetadataPtr & snapshot_meta_, const ClusterConfigPtr & cluster_config_ = nullptr);
KeeperStorageSnapshot(const KeeperStorageSnapshot<Storage>&) = delete;
KeeperStorageSnapshot(KeeperStorageSnapshot<Storage>&&) = default;
~KeeperStorageSnapshot(); ~KeeperStorageSnapshot();
static void serialize(const KeeperStorageSnapshot & snapshot, WriteBuffer & out, KeeperContextPtr keeper_context); static void serialize(const KeeperStorageSnapshot<Storage> & snapshot, WriteBuffer & out, KeeperContextPtr keeper_context);
static void deserialize(SnapshotDeserializationResult & deserialization_result, ReadBuffer & in, KeeperContextPtr keeper_context); static void deserialize(SnapshotDeserializationResult<Storage> & deserialization_result, ReadBuffer & in, KeeperContextPtr keeper_context);
KeeperStorage * storage; Storage * storage;
SnapshotVersion version = CURRENT_SNAPSHOT_VERSION; SnapshotVersion version = CURRENT_SNAPSHOT_VERSION;
/// Snapshot metadata /// Snapshot metadata
@ -77,11 +88,11 @@ public:
/// so we have for loop for (i = 0; i < snapshot_container_size; ++i) { doSmth(begin + i); } /// so we have for loop for (i = 0; i < snapshot_container_size; ++i) { doSmth(begin + i); }
size_t snapshot_container_size; size_t snapshot_container_size;
/// Iterator to the start of the storage /// Iterator to the start of the storage
KeeperStorage::Container::const_iterator begin; Storage::Container::const_iterator begin;
/// Active sessions and their timeouts /// Active sessions and their timeouts
SessionAndTimeout session_and_timeout; SessionAndTimeout session_and_timeout;
/// Sessions credentials /// Sessions credentials
KeeperStorage::SessionAndAuth session_and_auth; Storage::SessionAndAuth session_and_auth;
/// ACLs cache for better performance. Without we cannot deserialize storage. /// ACLs cache for better performance. Without we cannot deserialize storage.
std::unordered_map<uint64_t, Coordination::ACLs> acl_map; std::unordered_map<uint64_t, Coordination::ACLs> acl_map;
/// Cluster config from snapshot, can be empty /// Cluster config from snapshot, can be empty
@ -105,14 +116,16 @@ struct SnapshotFileInfo
}; };
using SnapshotFileInfoPtr = std::shared_ptr<SnapshotFileInfo>; using SnapshotFileInfoPtr = std::shared_ptr<SnapshotFileInfo>;
#if USE_ROCKSDB
using KeeperStorageSnapshotPtr = std::shared_ptr<KeeperStorageSnapshot>; using KeeperStorageSnapshotPtr = std::variant<std::shared_ptr<KeeperStorageSnapshot<KeeperMemoryStorage>>, std::shared_ptr<KeeperStorageSnapshot<KeeperRocksStorage>>>;
using CreateSnapshotCallback = std::function<std::shared_ptr<SnapshotFileInfo>(KeeperStorageSnapshotPtr &&, bool)>; #else
using KeeperStorageSnapshotPtr = std::variant<std::shared_ptr<KeeperStorageSnapshot<KeeperMemoryStorage>>>;
using SnapshotMetaAndStorage = std::pair<SnapshotMetadataPtr, KeeperStoragePtr>; #endif
using CreateSnapshotCallback = std::function<SnapshotFileInfoPtr(KeeperStorageSnapshotPtr &&, bool)>;
/// Class responsible for snapshots serialization and deserialization. Each snapshot /// Class responsible for snapshots serialization and deserialization. Each snapshot
/// has it's path on disk and log index. /// has it's path on disk and log index.
template<typename Storage>
class KeeperSnapshotManager class KeeperSnapshotManager
{ {
public: public:
@ -124,18 +137,18 @@ public:
size_t storage_tick_time_ = 500); size_t storage_tick_time_ = 500);
/// Restore storage from latest available snapshot /// Restore storage from latest available snapshot
SnapshotDeserializationResult restoreFromLatestSnapshot(); SnapshotDeserializationResult<Storage> restoreFromLatestSnapshot();
/// Compress snapshot and serialize it to buffer /// Compress snapshot and serialize it to buffer
nuraft::ptr<nuraft::buffer> serializeSnapshotToBuffer(const KeeperStorageSnapshot & snapshot) const; nuraft::ptr<nuraft::buffer> serializeSnapshotToBuffer(const KeeperStorageSnapshot<Storage> & snapshot) const;
/// Serialize already compressed snapshot to disk (return path) /// Serialize already compressed snapshot to disk (return path)
SnapshotFileInfoPtr serializeSnapshotBufferToDisk(nuraft::buffer & buffer, uint64_t up_to_log_idx); SnapshotFileInfoPtr serializeSnapshotBufferToDisk(nuraft::buffer & buffer, uint64_t up_to_log_idx);
/// Serialize snapshot directly to disk /// Serialize snapshot directly to disk
SnapshotFileInfoPtr serializeSnapshotToDisk(const KeeperStorageSnapshot & snapshot); SnapshotFileInfoPtr serializeSnapshotToDisk(const KeeperStorageSnapshot<Storage> & snapshot);
SnapshotDeserializationResult deserializeSnapshotFromBuffer(nuraft::ptr<nuraft::buffer> buffer) const; SnapshotDeserializationResult<Storage> deserializeSnapshotFromBuffer(nuraft::ptr<nuraft::buffer> buffer) const;
/// Deserialize snapshot with log index up_to_log_idx from disk into compressed nuraft buffer. /// Deserialize snapshot with log index up_to_log_idx from disk into compressed nuraft buffer.
nuraft::ptr<nuraft::buffer> deserializeSnapshotBufferFromDisk(uint64_t up_to_log_idx) const; nuraft::ptr<nuraft::buffer> deserializeSnapshotBufferFromDisk(uint64_t up_to_log_idx) const;

View File

@ -44,7 +44,7 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR; extern const int LOGICAL_ERROR;
} }
KeeperStateMachine::KeeperStateMachine( IKeeperStateMachine::IKeeperStateMachine(
ResponsesQueue & responses_queue_, ResponsesQueue & responses_queue_,
SnapshotsQueue & snapshots_queue_, SnapshotsQueue & snapshots_queue_,
const KeeperContextPtr & keeper_context_, const KeeperContextPtr & keeper_context_,
@ -52,12 +52,6 @@ KeeperStateMachine::KeeperStateMachine(
CommitCallback commit_callback_, CommitCallback commit_callback_,
const std::string & superdigest_) const std::string & superdigest_)
: commit_callback(commit_callback_) : commit_callback(commit_callback_)
, snapshot_manager(
keeper_context_->getCoordinationSettings()->snapshots_to_keep,
keeper_context_,
keeper_context_->getCoordinationSettings()->compress_snapshots_with_zstd_format,
superdigest_,
keeper_context_->getCoordinationSettings()->dead_session_check_period_ms.totalMilliseconds())
, responses_queue(responses_queue_) , responses_queue(responses_queue_)
, snapshots_queue(snapshots_queue_) , snapshots_queue(snapshots_queue_)
, min_request_size_to_cache(keeper_context_->getCoordinationSettings()->min_request_size_for_cache) , min_request_size_to_cache(keeper_context_->getCoordinationSettings()->min_request_size_for_cache)
@ -68,6 +62,32 @@ KeeperStateMachine::KeeperStateMachine(
{ {
} }
template<typename Storage>
KeeperStateMachine<Storage>::KeeperStateMachine(
ResponsesQueue & responses_queue_,
SnapshotsQueue & snapshots_queue_,
// const CoordinationSettingsPtr & coordination_settings_,
const KeeperContextPtr & keeper_context_,
KeeperSnapshotManagerS3 * snapshot_manager_s3_,
IKeeperStateMachine::CommitCallback commit_callback_,
const std::string & superdigest_)
: IKeeperStateMachine(
responses_queue_,
snapshots_queue_,
/// coordination_settings_,
keeper_context_,
snapshot_manager_s3_,
commit_callback_,
superdigest_),
snapshot_manager(
keeper_context_->getCoordinationSettings()->snapshots_to_keep,
keeper_context_,
keeper_context_->getCoordinationSettings()->compress_snapshots_with_zstd_format,
superdigest_,
keeper_context_->getCoordinationSettings()->dead_session_check_period_ms.totalMilliseconds())
{
}
namespace namespace
{ {
@ -78,7 +98,8 @@ bool isLocalDisk(const IDisk & disk)
} }
void KeeperStateMachine::init() template<typename Storage>
void KeeperStateMachine<Storage>::init()
{ {
/// Do everything without mutexes, no other threads exist. /// Do everything without mutexes, no other threads exist.
LOG_DEBUG(log, "Totally have {} snapshots", snapshot_manager.totalSnapshots()); LOG_DEBUG(log, "Totally have {} snapshots", snapshot_manager.totalSnapshots());
@ -123,7 +144,7 @@ void KeeperStateMachine::init()
LOG_DEBUG(log, "No existing snapshots, last committed log index {}", last_committed_idx); LOG_DEBUG(log, "No existing snapshots, last committed log index {}", last_committed_idx);
if (!storage) if (!storage)
storage = std::make_unique<KeeperStorage>( storage = std::make_unique<Storage>(
keeper_context->getCoordinationSettings()->dead_session_check_period_ms.totalMilliseconds(), superdigest, keeper_context); keeper_context->getCoordinationSettings()->dead_session_check_period_ms.totalMilliseconds(), superdigest, keeper_context);
} }
@ -131,13 +152,13 @@ namespace
{ {
void assertDigest( void assertDigest(
const KeeperStorage::Digest & expected, const KeeperStorageBase::Digest & expected,
const KeeperStorage::Digest & actual, const KeeperStorageBase::Digest & actual,
const Coordination::ZooKeeperRequest & request, const Coordination::ZooKeeperRequest & request,
uint64_t log_idx, uint64_t log_idx,
bool committing) bool committing)
{ {
if (!KeeperStorage::checkDigest(expected, actual)) if (!KeeperStorageBase::checkDigest(expected, actual))
{ {
LOG_FATAL( LOG_FATAL(
getLogger("KeeperStateMachine"), getLogger("KeeperStateMachine"),
@ -170,7 +191,8 @@ struct TSA_SCOPED_LOCKABLE LockGuardWithStats final
} }
nuraft::ptr<nuraft::buffer> KeeperStateMachine::pre_commit(uint64_t log_idx, nuraft::buffer & data) template<typename Storage>
nuraft::ptr<nuraft::buffer> KeeperStateMachine<Storage>::pre_commit(uint64_t log_idx, nuraft::buffer & data)
{ {
auto result = nuraft::buffer::alloc(sizeof(log_idx)); auto result = nuraft::buffer::alloc(sizeof(log_idx));
nuraft::buffer_serializer ss(result); nuraft::buffer_serializer ss(result);
@ -191,10 +213,10 @@ nuraft::ptr<nuraft::buffer> KeeperStateMachine::pre_commit(uint64_t log_idx, nur
return result; return result;
} }
std::shared_ptr<KeeperStorage::RequestForSession> KeeperStateMachine::parseRequest(nuraft::buffer & data, bool final, ZooKeeperLogSerializationVersion * serialization_version) std::shared_ptr<KeeperStorageBase::RequestForSession> IKeeperStateMachine::parseRequest(nuraft::buffer & data, bool final, ZooKeeperLogSerializationVersion * serialization_version)
{ {
ReadBufferFromNuraftBuffer buffer(data); ReadBufferFromNuraftBuffer buffer(data);
auto request_for_session = std::make_shared<KeeperStorage::RequestForSession>(); auto request_for_session = std::make_shared<KeeperStorageBase::RequestForSession>();
readIntBinary(request_for_session->session_id, buffer); readIntBinary(request_for_session->session_id, buffer);
int32_t length; int32_t length;
@ -267,7 +289,7 @@ std::shared_ptr<KeeperStorage::RequestForSession> KeeperStateMachine::parseReque
request_for_session->digest.emplace(); request_for_session->digest.emplace();
readIntBinary(request_for_session->digest->version, buffer); readIntBinary(request_for_session->digest->version, buffer);
if (request_for_session->digest->version != KeeperStorage::DigestVersion::NO_DIGEST || !buffer.eof()) if (request_for_session->digest->version != KeeperStorageBase::DigestVersion::NO_DIGEST || !buffer.eof())
readIntBinary(request_for_session->digest->value, buffer); readIntBinary(request_for_session->digest->value, buffer);
} }
@ -283,7 +305,8 @@ std::shared_ptr<KeeperStorage::RequestForSession> KeeperStateMachine::parseReque
return request_for_session; return request_for_session;
} }
bool KeeperStateMachine::preprocess(const KeeperStorage::RequestForSession & request_for_session) template<typename Storage>
bool KeeperStateMachine<Storage>::preprocess(const KeeperStorageBase::RequestForSession & request_for_session)
{ {
const auto op_num = request_for_session.request->getOpNum(); const auto op_num = request_for_session.request->getOpNum();
if (op_num == Coordination::OpNum::SessionID || op_num == Coordination::OpNum::Reconfig) if (op_num == Coordination::OpNum::SessionID || op_num == Coordination::OpNum::Reconfig)
@ -317,10 +340,11 @@ bool KeeperStateMachine::preprocess(const KeeperStorage::RequestForSession & req
return true; return true;
} }
void KeeperStateMachine::reconfigure(const KeeperStorage::RequestForSession& request_for_session) template<typename Storage>
void KeeperStateMachine<Storage>::reconfigure(const KeeperStorageBase::RequestForSession& request_for_session)
{ {
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
KeeperStorage::ResponseForSession response = processReconfiguration(request_for_session); KeeperStorageBase::ResponseForSession response = processReconfiguration(request_for_session);
if (!responses_queue.push(response)) if (!responses_queue.push(response))
{ {
ProfileEvents::increment(ProfileEvents::KeeperCommitsFailed); ProfileEvents::increment(ProfileEvents::KeeperCommitsFailed);
@ -330,8 +354,9 @@ void KeeperStateMachine::reconfigure(const KeeperStorage::RequestForSession& req
} }
} }
KeeperStorage::ResponseForSession KeeperStateMachine::processReconfiguration( template<typename Storage>
const KeeperStorage::RequestForSession & request_for_session) KeeperStorageBase::ResponseForSession KeeperStateMachine<Storage>::processReconfiguration(
const KeeperStorageBase::RequestForSession & request_for_session)
{ {
ProfileEvents::increment(ProfileEvents::KeeperReconfigRequest); ProfileEvents::increment(ProfileEvents::KeeperReconfigRequest);
@ -340,7 +365,7 @@ KeeperStorage::ResponseForSession KeeperStateMachine::processReconfiguration(
const int64_t zxid = request_for_session.zxid; const int64_t zxid = request_for_session.zxid;
using enum Coordination::Error; using enum Coordination::Error;
auto bad_request = [&](Coordination::Error code = ZBADARGUMENTS) -> KeeperStorage::ResponseForSession auto bad_request = [&](Coordination::Error code = ZBADARGUMENTS) -> KeeperStorageBase::ResponseForSession
{ {
auto res = std::make_shared<Coordination::ZooKeeperReconfigResponse>(); auto res = std::make_shared<Coordination::ZooKeeperReconfigResponse>();
res->xid = request.xid; res->xid = request.xid;
@ -397,7 +422,8 @@ KeeperStorage::ResponseForSession KeeperStateMachine::processReconfiguration(
return { session_id, std::move(response) }; return { session_id, std::move(response) };
} }
nuraft::ptr<nuraft::buffer> KeeperStateMachine::commit(const uint64_t log_idx, nuraft::buffer & data) template<typename Storage>
nuraft::ptr<nuraft::buffer> KeeperStateMachine<Storage>::commit(const uint64_t log_idx, nuraft::buffer & data)
{ {
auto request_for_session = parseRequest(data, true); auto request_for_session = parseRequest(data, true);
if (!request_for_session->zxid) if (!request_for_session->zxid)
@ -408,7 +434,7 @@ nuraft::ptr<nuraft::buffer> KeeperStateMachine::commit(const uint64_t log_idx, n
if (!keeper_context->localLogsPreprocessed() && !preprocess(*request_for_session)) if (!keeper_context->localLogsPreprocessed() && !preprocess(*request_for_session))
return nullptr; return nullptr;
auto try_push = [&](const KeeperStorage::ResponseForSession & response) auto try_push = [&](const KeeperStorageBase::ResponseForSession & response)
{ {
if (!responses_queue.push(response)) if (!responses_queue.push(response))
{ {
@ -430,7 +456,7 @@ nuraft::ptr<nuraft::buffer> KeeperStateMachine::commit(const uint64_t log_idx, n
std::shared_ptr<Coordination::ZooKeeperSessionIDResponse> response = std::make_shared<Coordination::ZooKeeperSessionIDResponse>(); std::shared_ptr<Coordination::ZooKeeperSessionIDResponse> response = std::make_shared<Coordination::ZooKeeperSessionIDResponse>();
response->internal_id = session_id_request.internal_id; response->internal_id = session_id_request.internal_id;
response->server_id = session_id_request.server_id; response->server_id = session_id_request.server_id;
KeeperStorage::ResponseForSession response_for_session; KeeperStorageBase::ResponseForSession response_for_session;
response_for_session.session_id = -1; response_for_session.session_id = -1;
response_for_session.response = response; response_for_session.response = response;
response_for_session.request = request_for_session->request; response_for_session.request = request_for_session->request;
@ -451,7 +477,7 @@ nuraft::ptr<nuraft::buffer> KeeperStateMachine::commit(const uint64_t log_idx, n
} }
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
KeeperStorage::ResponsesForSessions responses_for_sessions KeeperStorageBase::ResponsesForSessions responses_for_sessions
= storage->processRequest(request_for_session->request, request_for_session->session_id, request_for_session->zxid); = storage->processRequest(request_for_session->request, request_for_session->session_id, request_for_session->zxid);
for (auto & response_for_session : responses_for_sessions) for (auto & response_for_session : responses_for_sessions)
@ -482,7 +508,8 @@ nuraft::ptr<nuraft::buffer> KeeperStateMachine::commit(const uint64_t log_idx, n
return nullptr; return nullptr;
} }
bool KeeperStateMachine::apply_snapshot(nuraft::snapshot & s) template<typename Storage>
bool KeeperStateMachine<Storage>::apply_snapshot(nuraft::snapshot & s)
{ {
LOG_DEBUG(log, "Applying snapshot {}", s.get_last_log_idx()); LOG_DEBUG(log, "Applying snapshot {}", s.get_last_log_idx());
nuraft::ptr<nuraft::buffer> latest_snapshot_ptr; nuraft::ptr<nuraft::buffer> latest_snapshot_ptr;
@ -509,7 +536,7 @@ bool KeeperStateMachine::apply_snapshot(nuraft::snapshot & s)
{ /// deserialize and apply snapshot to storage { /// deserialize and apply snapshot to storage
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
SnapshotDeserializationResult snapshot_deserialization_result; SnapshotDeserializationResult<Storage> snapshot_deserialization_result;
if (latest_snapshot_ptr) if (latest_snapshot_ptr)
snapshot_deserialization_result = snapshot_manager.deserializeSnapshotFromBuffer(latest_snapshot_ptr); snapshot_deserialization_result = snapshot_manager.deserializeSnapshotFromBuffer(latest_snapshot_ptr);
else else
@ -530,7 +557,7 @@ bool KeeperStateMachine::apply_snapshot(nuraft::snapshot & s)
} }
void KeeperStateMachine::commit_config(const uint64_t log_idx, nuraft::ptr<nuraft::cluster_config> & new_conf) void IKeeperStateMachine::commit_config(const uint64_t log_idx, nuraft::ptr<nuraft::cluster_config> & new_conf)
{ {
std::lock_guard lock(cluster_config_lock); std::lock_guard lock(cluster_config_lock);
auto tmp = new_conf->serialize(); auto tmp = new_conf->serialize();
@ -538,7 +565,7 @@ void KeeperStateMachine::commit_config(const uint64_t log_idx, nuraft::ptr<nuraf
keeper_context->setLastCommitIndex(log_idx); keeper_context->setLastCommitIndex(log_idx);
} }
void KeeperStateMachine::rollback(uint64_t log_idx, nuraft::buffer & data) void IKeeperStateMachine::rollback(uint64_t log_idx, nuraft::buffer & data)
{ {
/// Don't rollback anything until the first commit because nothing was preprocessed /// Don't rollback anything until the first commit because nothing was preprocessed
if (!keeper_context->localLogsPreprocessed()) if (!keeper_context->localLogsPreprocessed())
@ -554,7 +581,8 @@ void KeeperStateMachine::rollback(uint64_t log_idx, nuraft::buffer & data)
rollbackRequest(*request_for_session, false); rollbackRequest(*request_for_session, false);
} }
void KeeperStateMachine::rollbackRequest(const KeeperStorage::RequestForSession & request_for_session, bool allow_missing) template<typename Storage>
void KeeperStateMachine<Storage>::rollbackRequest(const KeeperStorageBase::RequestForSession & request_for_session, bool allow_missing)
{ {
if (request_for_session.request->getOpNum() == Coordination::OpNum::SessionID) if (request_for_session.request->getOpNum() == Coordination::OpNum::SessionID)
return; return;
@ -563,7 +591,8 @@ void KeeperStateMachine::rollbackRequest(const KeeperStorage::RequestForSession
storage->rollbackRequest(request_for_session.zxid, allow_missing); storage->rollbackRequest(request_for_session.zxid, allow_missing);
} }
void KeeperStateMachine::rollbackRequestNoLock(const KeeperStorage::RequestForSession & request_for_session, bool allow_missing) template<typename Storage>
void KeeperStateMachine<Storage>::rollbackRequestNoLock(const KeeperStorageBase::RequestForSession & request_for_session, bool allow_missing)
{ {
if (request_for_session.request->getOpNum() == Coordination::OpNum::SessionID) if (request_for_session.request->getOpNum() == Coordination::OpNum::SessionID)
return; return;
@ -571,14 +600,15 @@ void KeeperStateMachine::rollbackRequestNoLock(const KeeperStorage::RequestForSe
storage->rollbackRequest(request_for_session.zxid, allow_missing); storage->rollbackRequest(request_for_session.zxid, allow_missing);
} }
nuraft::ptr<nuraft::snapshot> KeeperStateMachine::last_snapshot() nuraft::ptr<nuraft::snapshot> IKeeperStateMachine::last_snapshot()
{ {
/// Just return the latest snapshot. /// Just return the latest snapshot.
std::lock_guard lock(snapshots_lock); std::lock_guard lock(snapshots_lock);
return latest_snapshot_meta; return latest_snapshot_meta;
} }
void KeeperStateMachine::create_snapshot(nuraft::snapshot & s, nuraft::async_result<bool>::handler_type & when_done) template<typename Storage>
void KeeperStateMachine<Storage>::create_snapshot(nuraft::snapshot & s, nuraft::async_result<bool>::handler_type & when_done)
{ {
LOG_DEBUG(log, "Creating snapshot {}", s.get_last_log_idx()); LOG_DEBUG(log, "Creating snapshot {}", s.get_last_log_idx());
@ -587,14 +617,15 @@ void KeeperStateMachine::create_snapshot(nuraft::snapshot & s, nuraft::async_res
CreateSnapshotTask snapshot_task; CreateSnapshotTask snapshot_task;
{ /// lock storage for a short period time to turn on "snapshot mode". After that we can read consistent storage state without locking. { /// lock storage for a short period time to turn on "snapshot mode". After that we can read consistent storage state without locking.
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
snapshot_task.snapshot = std::make_shared<KeeperStorageSnapshot>(storage.get(), snapshot_meta_copy, getClusterConfig()); snapshot_task.snapshot = std::make_shared<KeeperStorageSnapshot<Storage>>(storage.get(), snapshot_meta_copy, getClusterConfig());
} }
/// create snapshot task for background execution (in snapshot thread) /// create snapshot task for background execution (in snapshot thread)
snapshot_task.create_snapshot = [this, when_done](KeeperStorageSnapshotPtr && snapshot, bool execute_only_cleanup) snapshot_task.create_snapshot = [this, when_done](KeeperStorageSnapshotPtr && snapshot_, bool execute_only_cleanup)
{ {
nuraft::ptr<std::exception> exception(nullptr); nuraft::ptr<std::exception> exception(nullptr);
bool ret = false; bool ret = false;
auto && snapshot = std::get<std::shared_ptr<KeeperStorageSnapshot<Storage>>>(std::move(snapshot_));
if (!execute_only_cleanup) if (!execute_only_cleanup)
{ {
try try
@ -683,7 +714,8 @@ void KeeperStateMachine::create_snapshot(nuraft::snapshot & s, nuraft::async_res
LOG_WARNING(log, "Cannot push snapshot task into queue"); LOG_WARNING(log, "Cannot push snapshot task into queue");
} }
void KeeperStateMachine::save_logical_snp_obj( template<typename Storage>
void KeeperStateMachine<Storage>::save_logical_snp_obj(
nuraft::snapshot & s, uint64_t & obj_id, nuraft::buffer & data, bool /*is_first_obj*/, bool /*is_last_obj*/) nuraft::snapshot & s, uint64_t & obj_id, nuraft::buffer & data, bool /*is_first_obj*/, bool /*is_last_obj*/)
{ {
LOG_DEBUG(log, "Saving snapshot {} obj_id {}", s.get_last_log_idx(), obj_id); LOG_DEBUG(log, "Saving snapshot {} obj_id {}", s.get_last_log_idx(), obj_id);
@ -748,7 +780,7 @@ static int bufferFromFile(LoggerPtr log, const std::string & path, nuraft::ptr<n
return 0; return 0;
} }
int KeeperStateMachine::read_logical_snp_obj( int IKeeperStateMachine::read_logical_snp_obj(
nuraft::snapshot & s, void *& /*user_snp_ctx*/, uint64_t obj_id, nuraft::ptr<nuraft::buffer> & data_out, bool & is_last_obj) nuraft::snapshot & s, void *& /*user_snp_ctx*/, uint64_t obj_id, nuraft::ptr<nuraft::buffer> & data_out, bool & is_last_obj)
{ {
LOG_DEBUG(log, "Reading snapshot {} obj_id {}", s.get_last_log_idx(), obj_id); LOG_DEBUG(log, "Reading snapshot {} obj_id {}", s.get_last_log_idx(), obj_id);
@ -788,7 +820,8 @@ int KeeperStateMachine::read_logical_snp_obj(
return 1; return 1;
} }
void KeeperStateMachine::processReadRequest(const KeeperStorage::RequestForSession & request_for_session) template<typename Storage>
void KeeperStateMachine<Storage>::processReadRequest(const KeeperStorageBase::RequestForSession & request_for_session)
{ {
/// Pure local request, just process it with storage /// Pure local request, just process it with storage
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
@ -804,103 +837,120 @@ void KeeperStateMachine::processReadRequest(const KeeperStorage::RequestForSessi
} }
} }
void KeeperStateMachine::shutdownStorage() template<typename Storage>
void KeeperStateMachine<Storage>::shutdownStorage()
{ {
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
storage->finalize(); storage->finalize();
} }
std::vector<int64_t> KeeperStateMachine::getDeadSessions() template<typename Storage>
std::vector<int64_t> KeeperStateMachine<Storage>::getDeadSessions()
{ {
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
return storage->getDeadSessions(); return storage->getDeadSessions();
} }
int64_t KeeperStateMachine::getNextZxid() const template<typename Storage>
int64_t KeeperStateMachine<Storage>::getNextZxid() const
{ {
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
return storage->getNextZXID(); return storage->getNextZXID();
} }
KeeperStorage::Digest KeeperStateMachine::getNodesDigest() const template<typename Storage>
KeeperStorageBase::Digest KeeperStateMachine<Storage>::getNodesDigest() const
{ {
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
return storage->getNodesDigest(false); return storage->getNodesDigest(false);
} }
uint64_t KeeperStateMachine::getLastProcessedZxid() const template<typename Storage>
uint64_t KeeperStateMachine<Storage>::getLastProcessedZxid() const
{ {
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
return storage->getZXID(); return storage->getZXID();
} }
uint64_t KeeperStateMachine::getNodesCount() const template<typename Storage>
uint64_t KeeperStateMachine<Storage>::getNodesCount() const
{ {
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
return storage->getNodesCount(); return storage->getNodesCount();
} }
uint64_t KeeperStateMachine::getTotalWatchesCount() const template<typename Storage>
uint64_t KeeperStateMachine<Storage>::getTotalWatchesCount() const
{ {
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
return storage->getTotalWatchesCount(); return storage->getTotalWatchesCount();
} }
uint64_t KeeperStateMachine::getWatchedPathsCount() const template<typename Storage>
uint64_t KeeperStateMachine<Storage>::getWatchedPathsCount() const
{ {
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
return storage->getWatchedPathsCount(); return storage->getWatchedPathsCount();
} }
uint64_t KeeperStateMachine::getSessionsWithWatchesCount() const template<typename Storage>
uint64_t KeeperStateMachine<Storage>::getSessionsWithWatchesCount() const
{ {
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
return storage->getSessionsWithWatchesCount(); return storage->getSessionsWithWatchesCount();
} }
uint64_t KeeperStateMachine::getTotalEphemeralNodesCount() const template<typename Storage>
uint64_t KeeperStateMachine<Storage>::getTotalEphemeralNodesCount() const
{ {
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
return storage->getTotalEphemeralNodesCount(); return storage->getTotalEphemeralNodesCount();
} }
uint64_t KeeperStateMachine::getSessionWithEphemeralNodesCount() const template<typename Storage>
uint64_t KeeperStateMachine<Storage>::getSessionWithEphemeralNodesCount() const
{ {
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
return storage->getSessionWithEphemeralNodesCount(); return storage->getSessionWithEphemeralNodesCount();
} }
void KeeperStateMachine::dumpWatches(WriteBufferFromOwnString & buf) const template<typename Storage>
void KeeperStateMachine<Storage>::dumpWatches(WriteBufferFromOwnString & buf) const
{ {
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
storage->dumpWatches(buf); storage->dumpWatches(buf);
} }
void KeeperStateMachine::dumpWatchesByPath(WriteBufferFromOwnString & buf) const template<typename Storage>
void KeeperStateMachine<Storage>::dumpWatchesByPath(WriteBufferFromOwnString & buf) const
{ {
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
storage->dumpWatchesByPath(buf); storage->dumpWatchesByPath(buf);
} }
void KeeperStateMachine::dumpSessionsAndEphemerals(WriteBufferFromOwnString & buf) const template<typename Storage>
void KeeperStateMachine<Storage>::dumpSessionsAndEphemerals(WriteBufferFromOwnString & buf) const
{ {
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
storage->dumpSessionsAndEphemerals(buf); storage->dumpSessionsAndEphemerals(buf);
} }
uint64_t KeeperStateMachine::getApproximateDataSize() const template<typename Storage>
uint64_t KeeperStateMachine<Storage>::getApproximateDataSize() const
{ {
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
return storage->getApproximateDataSize(); return storage->getApproximateDataSize();
} }
uint64_t KeeperStateMachine::getKeyArenaSize() const template<typename Storage>
uint64_t KeeperStateMachine<Storage>::getKeyArenaSize() const
{ {
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
return storage->getArenaDataSize(); return storage->getArenaDataSize();
} }
uint64_t KeeperStateMachine::getLatestSnapshotSize() const template<typename Storage>
uint64_t KeeperStateMachine<Storage>::getLatestSnapshotSize() const
{ {
auto snapshot_info = [&] auto snapshot_info = [&]
{ {
@ -923,7 +973,7 @@ uint64_t KeeperStateMachine::getLatestSnapshotSize() const
return size; return size;
} }
ClusterConfigPtr KeeperStateMachine::getClusterConfig() const ClusterConfigPtr IKeeperStateMachine::getClusterConfig() const
{ {
std::lock_guard lock(cluster_config_lock); std::lock_guard lock(cluster_config_lock);
if (cluster_config) if (cluster_config)
@ -935,11 +985,18 @@ ClusterConfigPtr KeeperStateMachine::getClusterConfig() const
return nullptr; return nullptr;
} }
void KeeperStateMachine::recalculateStorageStats() template<typename Storage>
void KeeperStateMachine<Storage>::recalculateStorageStats()
{ {
LockGuardWithStats lock(storage_and_responses_lock); LockGuardWithStats lock(storage_and_responses_lock);
LOG_INFO(log, "Recalculating storage stats"); LOG_INFO(log, "Recalculating storage stats");
storage->recalculateStats(); storage->recalculateStats();
LOG_INFO(log, "Done recalculating storage stats"); LOG_INFO(log, "Done recalculating storage stats");
} }
template class KeeperStateMachine<KeeperMemoryStorage>;
#if USE_ROCKSDB
template class KeeperStateMachine<KeeperRocksStorage>;
#endif
} }

View File

@ -11,26 +11,24 @@
namespace DB namespace DB
{ {
using ResponsesQueue = ConcurrentBoundedQueue<KeeperStorage::ResponseForSession>; using ResponsesQueue = ConcurrentBoundedQueue<KeeperStorageBase::ResponseForSession>;
using SnapshotsQueue = ConcurrentBoundedQueue<CreateSnapshotTask>; using SnapshotsQueue = ConcurrentBoundedQueue<CreateSnapshotTask>;
/// ClickHouse Keeper state machine. Wrapper for KeeperStorage. class IKeeperStateMachine : public nuraft::state_machine
/// Responsible for entries commit, snapshots creation and so on.
class KeeperStateMachine : public nuraft::state_machine
{ {
public: public:
using CommitCallback = std::function<void(uint64_t, const KeeperStorage::RequestForSession &)>; using CommitCallback = std::function<void(uint64_t, const KeeperStorageBase::RequestForSession &)>;
KeeperStateMachine( IKeeperStateMachine(
ResponsesQueue & responses_queue_, ResponsesQueue & responses_queue_,
SnapshotsQueue & snapshots_queue_, SnapshotsQueue & snapshots_queue_,
const KeeperContextPtr & keeper_context_, const KeeperContextPtr & keeper_context_,
KeeperSnapshotManagerS3 * snapshot_manager_s3_, KeeperSnapshotManagerS3 * snapshot_manager_s3_,
CommitCallback commit_callback_ = {}, CommitCallback commit_callback_,
const std::string & superdigest_ = ""); const std::string & superdigest_);
/// Read state from the latest snapshot /// Read state from the latest snapshot
void init(); virtual void init() = 0;
enum ZooKeeperLogSerializationVersion enum ZooKeeperLogSerializationVersion
{ {
@ -47,89 +45,66 @@ public:
/// ///
/// final - whether it's the final time we will fetch the request so we can safely remove it from cache /// final - whether it's the final time we will fetch the request so we can safely remove it from cache
/// serialization_version - information about which fields were parsed from the buffer so we can modify the buffer accordingly /// serialization_version - information about which fields were parsed from the buffer so we can modify the buffer accordingly
std::shared_ptr<KeeperStorage::RequestForSession> parseRequest(nuraft::buffer & data, bool final, ZooKeeperLogSerializationVersion * serialization_version = nullptr); std::shared_ptr<KeeperStorageBase::RequestForSession> parseRequest(nuraft::buffer & data, bool final, ZooKeeperLogSerializationVersion * serialization_version = nullptr);
bool preprocess(const KeeperStorage::RequestForSession & request_for_session); virtual bool preprocess(const KeeperStorageBase::RequestForSession & request_for_session) = 0;
nuraft::ptr<nuraft::buffer> pre_commit(uint64_t log_idx, nuraft::buffer & data) override;
nuraft::ptr<nuraft::buffer> commit(const uint64_t log_idx, nuraft::buffer & data) override; /// NOLINT
/// Save new cluster config to our snapshot (copy of the config stored in StateManager)
void commit_config(const uint64_t log_idx, nuraft::ptr<nuraft::cluster_config> & new_conf) override; /// NOLINT void commit_config(const uint64_t log_idx, nuraft::ptr<nuraft::cluster_config> & new_conf) override; /// NOLINT
void rollback(uint64_t log_idx, nuraft::buffer & data) override; void rollback(uint64_t log_idx, nuraft::buffer & data) override;
// allow_missing - whether the transaction we want to rollback can be missing from storage // allow_missing - whether the transaction we want to rollback can be missing from storage
// (can happen in case of exception during preprocessing) // (can happen in case of exception during preprocessing)
void rollbackRequest(const KeeperStorage::RequestForSession & request_for_session, bool allow_missing); virtual void rollbackRequest(const KeeperStorageBase::RequestForSession & request_for_session, bool allow_missing) = 0;
void rollbackRequestNoLock(
const KeeperStorage::RequestForSession & request_for_session,
bool allow_missing) TSA_NO_THREAD_SAFETY_ANALYSIS;
uint64_t last_commit_index() override { return keeper_context->lastCommittedIndex(); } uint64_t last_commit_index() override { return keeper_context->lastCommittedIndex(); }
/// Apply preliminarily saved (save_logical_snp_obj) snapshot to our state.
bool apply_snapshot(nuraft::snapshot & s) override;
nuraft::ptr<nuraft::snapshot> last_snapshot() override; nuraft::ptr<nuraft::snapshot> last_snapshot() override;
/// Create new snapshot from current state. /// Create new snapshot from current state.
void create_snapshot(nuraft::snapshot & s, nuraft::async_result<bool>::handler_type & when_done) override; void create_snapshot(nuraft::snapshot & s, nuraft::async_result<bool>::handler_type & when_done) override = 0;
/// Save snapshot which was send by leader to us. After that we will apply it in apply_snapshot. /// Save snapshot which was send by leader to us. After that we will apply it in apply_snapshot.
void save_logical_snp_obj(nuraft::snapshot & s, uint64_t & obj_id, nuraft::buffer & data, bool is_first_obj, bool is_last_obj) override; void save_logical_snp_obj(nuraft::snapshot & s, uint64_t & obj_id, nuraft::buffer & data, bool is_first_obj, bool is_last_obj) override = 0;
/// Better name is `serialize snapshot` -- save existing snapshot (created by create_snapshot) into
/// in-memory buffer data_out.
int read_logical_snp_obj( int read_logical_snp_obj(
nuraft::snapshot & s, void *& user_snp_ctx, uint64_t obj_id, nuraft::ptr<nuraft::buffer> & data_out, bool & is_last_obj) override; nuraft::snapshot & s, void *& user_snp_ctx, uint64_t obj_id, nuraft::ptr<nuraft::buffer> & data_out, bool & is_last_obj) override;
// This should be used only for tests or keeper-data-dumper because it violates virtual void shutdownStorage() = 0;
// TSA -- we can't acquire the lock outside of this class or return a storage under lock
// in a reasonable way.
KeeperStorage & getStorageUnsafe() TSA_NO_THREAD_SAFETY_ANALYSIS
{
return *storage;
}
void shutdownStorage();
ClusterConfigPtr getClusterConfig() const; ClusterConfigPtr getClusterConfig() const;
/// Process local read request virtual void processReadRequest(const KeeperStorageBase::RequestForSession & request_for_session) = 0;
void processReadRequest(const KeeperStorage::RequestForSession & request_for_session);
std::vector<int64_t> getDeadSessions(); virtual std::vector<int64_t> getDeadSessions() = 0;
int64_t getNextZxid() const; virtual int64_t getNextZxid() const = 0;
KeeperStorage::Digest getNodesDigest() const; virtual KeeperStorageBase::Digest getNodesDigest() const = 0;
/// Introspection functions for 4lw commands /// Introspection functions for 4lw commands
uint64_t getLastProcessedZxid() const; virtual uint64_t getLastProcessedZxid() const = 0;
uint64_t getNodesCount() const; virtual uint64_t getNodesCount() const = 0;
uint64_t getTotalWatchesCount() const; virtual uint64_t getTotalWatchesCount() const = 0;
uint64_t getWatchedPathsCount() const; virtual uint64_t getWatchedPathsCount() const = 0;
uint64_t getSessionsWithWatchesCount() const; virtual uint64_t getSessionsWithWatchesCount() const = 0;
void dumpWatches(WriteBufferFromOwnString & buf) const; virtual void dumpWatches(WriteBufferFromOwnString & buf) const = 0;
void dumpWatchesByPath(WriteBufferFromOwnString & buf) const; virtual void dumpWatchesByPath(WriteBufferFromOwnString & buf) const = 0;
void dumpSessionsAndEphemerals(WriteBufferFromOwnString & buf) const; virtual void dumpSessionsAndEphemerals(WriteBufferFromOwnString & buf) const = 0;
uint64_t getSessionWithEphemeralNodesCount() const; virtual uint64_t getSessionWithEphemeralNodesCount() const = 0;
uint64_t getTotalEphemeralNodesCount() const; virtual uint64_t getTotalEphemeralNodesCount() const = 0;
uint64_t getApproximateDataSize() const; virtual uint64_t getApproximateDataSize() const = 0;
uint64_t getKeyArenaSize() const; virtual uint64_t getKeyArenaSize() const = 0;
uint64_t getLatestSnapshotSize() const; virtual uint64_t getLatestSnapshotSize() const = 0;
void recalculateStorageStats(); virtual void recalculateStorageStats() = 0;
void reconfigure(const KeeperStorage::RequestForSession& request_for_session); virtual void reconfigure(const KeeperStorageBase::RequestForSession& request_for_session) = 0;
private: protected:
CommitCallback commit_callback; CommitCallback commit_callback;
/// In our state machine we always have a single snapshot which is stored /// In our state machine we always have a single snapshot which is stored
/// in memory in compressed (serialized) format. /// in memory in compressed (serialized) format.
@ -137,12 +112,9 @@ private:
std::shared_ptr<SnapshotFileInfo> latest_snapshot_info; std::shared_ptr<SnapshotFileInfo> latest_snapshot_info;
nuraft::ptr<nuraft::buffer> latest_snapshot_buf = nullptr; nuraft::ptr<nuraft::buffer> latest_snapshot_buf = nullptr;
/// Main state machine logic CoordinationSettingsPtr coordination_settings;
KeeperStoragePtr storage TSA_PT_GUARDED_BY(storage_and_responses_lock);
/// Save/Load and Serialize/Deserialize logic for snapshots. /// Save/Load and Serialize/Deserialize logic for snapshots.
KeeperSnapshotManager snapshot_manager;
/// Put processed responses into this queue /// Put processed responses into this queue
ResponsesQueue & responses_queue; ResponsesQueue & responses_queue;
@ -159,7 +131,7 @@ private:
/// for request. /// for request.
mutable std::mutex storage_and_responses_lock; mutable std::mutex storage_and_responses_lock;
std::unordered_map<int64_t, std::unordered_map<Coordination::XID, std::shared_ptr<KeeperStorage::RequestForSession>>> parsed_request_cache; std::unordered_map<int64_t, std::unordered_map<Coordination::XID, std::shared_ptr<KeeperStorageBase::RequestForSession>>> parsed_request_cache;
uint64_t min_request_size_to_cache{0}; uint64_t min_request_size_to_cache{0};
/// we only need to protect the access to the map itself /// we only need to protect the access to the map itself
/// requests can be modified from anywhere without lock because a single request /// requests can be modified from anywhere without lock because a single request
@ -181,7 +153,104 @@ private:
KeeperSnapshotManagerS3 * snapshot_manager_s3; KeeperSnapshotManagerS3 * snapshot_manager_s3;
KeeperStorage::ResponseForSession processReconfiguration(const KeeperStorage::RequestForSession & request_for_session) virtual KeeperStorageBase::ResponseForSession processReconfiguration(
TSA_REQUIRES(storage_and_responses_lock); const KeeperStorageBase::RequestForSession& request_for_session)
TSA_REQUIRES(storage_and_responses_lock) = 0;
};
/// ClickHouse Keeper state machine. Wrapper for KeeperStorage.
/// Responsible for entries commit, snapshots creation and so on.
template<typename Storage>
class KeeperStateMachine : public IKeeperStateMachine
{
public:
/// using CommitCallback = std::function<void(uint64_t, const KeeperStorage::RequestForSession &)>;
KeeperStateMachine(
ResponsesQueue & responses_queue_,
SnapshotsQueue & snapshots_queue_,
/// const CoordinationSettingsPtr & coordination_settings_,
const KeeperContextPtr & keeper_context_,
KeeperSnapshotManagerS3 * snapshot_manager_s3_,
CommitCallback commit_callback_ = {},
const std::string & superdigest_ = "");
/// Read state from the latest snapshot
void init() override;
bool preprocess(const KeeperStorageBase::RequestForSession & request_for_session) override;
nuraft::ptr<nuraft::buffer> pre_commit(uint64_t log_idx, nuraft::buffer & data) override;
nuraft::ptr<nuraft::buffer> commit(const uint64_t log_idx, nuraft::buffer & data) override; /// NOLINT
// allow_missing - whether the transaction we want to rollback can be missing from storage
// (can happen in case of exception during preprocessing)
void rollbackRequest(const KeeperStorageBase::RequestForSession & request_for_session, bool allow_missing) override;
void rollbackRequestNoLock(
const KeeperStorageBase::RequestForSession & request_for_session,
bool allow_missing) TSA_NO_THREAD_SAFETY_ANALYSIS;
/// Apply preliminarily saved (save_logical_snp_obj) snapshot to our state.
bool apply_snapshot(nuraft::snapshot & s) override;
/// Create new snapshot from current state.
void create_snapshot(nuraft::snapshot & s, nuraft::async_result<bool>::handler_type & when_done) override;
/// Save snapshot which was send by leader to us. After that we will apply it in apply_snapshot.
void save_logical_snp_obj(nuraft::snapshot & s, uint64_t & obj_id, nuraft::buffer & data, bool is_first_obj, bool is_last_obj) override;
// This should be used only for tests or keeper-data-dumper because it violates
// TSA -- we can't acquire the lock outside of this class or return a storage under lock
// in a reasonable way.
Storage & getStorageUnsafe() TSA_NO_THREAD_SAFETY_ANALYSIS
{
return *storage;
}
void shutdownStorage() override;
/// Process local read request
void processReadRequest(const KeeperStorageBase::RequestForSession & request_for_session) override;
std::vector<int64_t> getDeadSessions() override;
int64_t getNextZxid() const override;
KeeperStorageBase::Digest getNodesDigest() const override;
/// Introspection functions for 4lw commands
uint64_t getLastProcessedZxid() const override;
uint64_t getNodesCount() const override;
uint64_t getTotalWatchesCount() const override;
uint64_t getWatchedPathsCount() const override;
uint64_t getSessionsWithWatchesCount() const override;
void dumpWatches(WriteBufferFromOwnString & buf) const override;
void dumpWatchesByPath(WriteBufferFromOwnString & buf) const override;
void dumpSessionsAndEphemerals(WriteBufferFromOwnString & buf) const override;
uint64_t getSessionWithEphemeralNodesCount() const override;
uint64_t getTotalEphemeralNodesCount() const override;
uint64_t getApproximateDataSize() const override;
uint64_t getKeyArenaSize() const override;
uint64_t getLatestSnapshotSize() const override;
void recalculateStorageStats() override;
void reconfigure(const KeeperStorageBase::RequestForSession& request_for_session) override;
private:
/// Main state machine logic
std::unique_ptr<Storage> storage; //TSA_PT_GUARDED_BY(storage_and_responses_lock);
/// Save/Load and Serialize/Deserialize logic for snapshots.
KeeperSnapshotManager<Storage> snapshot_manager;
KeeperStorageBase::ResponseForSession processReconfiguration(const KeeperStorageBase::RequestForSession & request_for_session)
TSA_REQUIRES(storage_and_responses_lock) override;
}; };
} }

File diff suppressed because it is too large Load Diff

View File

@ -8,188 +8,384 @@
#include <absl/container/flat_hash_set.h> #include <absl/container/flat_hash_set.h>
#include "config.h"
#if USE_ROCKSDB
#include <Coordination/RocksDBContainer.h>
#endif
namespace DB namespace DB
{ {
class KeeperContext; class KeeperContext;
using KeeperContextPtr = std::shared_ptr<KeeperContext>; using KeeperContextPtr = std::shared_ptr<KeeperContext>;
struct KeeperStorageRequestProcessor;
using KeeperStorageRequestProcessorPtr = std::shared_ptr<KeeperStorageRequestProcessor>;
using ResponseCallback = std::function<void(const Coordination::ZooKeeperResponsePtr &)>; using ResponseCallback = std::function<void(const Coordination::ZooKeeperResponsePtr &)>;
using ChildrenSet = absl::flat_hash_set<StringRef, StringRefHash>; using ChildrenSet = absl::flat_hash_set<StringRef, StringRefHash>;
using SessionAndTimeout = std::unordered_map<int64_t, int64_t>; using SessionAndTimeout = std::unordered_map<int64_t, int64_t>;
struct KeeperStorageSnapshot; /// KeeperRocksNodeInfo is used in RocksDB keeper.
/// It is serialized directly as POD to RocksDB.
/// Keeper state machine almost equal to the ZooKeeper's state machine. struct KeeperRocksNodeInfo
/// Implements all logic of operations, data changes, sessions allocation.
/// In-memory and not thread safe.
class KeeperStorage
{ {
public: int64_t czxid{0};
/// Node should have as minimal size as possible to reduce memory footprint int64_t mzxid{0};
/// of stored nodes int64_t pzxid{0};
/// New fields should be added to the struct only if it's really necessary uint64_t acl_id = 0; /// 0 -- no ACL by default
struct Node
int64_t mtime{0};
int32_t version{0};
int32_t cversion{0};
int32_t aversion{0};
int32_t seq_num = 0;
mutable UInt64 digest = 0; /// we cached digest for this node.
/// as ctime can't be negative because it stores the timestamp when the
/// node was created, we can use the MSB for a bool
struct
{ {
int64_t czxid{0}; bool is_ephemeral : 1;
int64_t mzxid{0}; int64_t ctime : 63;
int64_t pzxid{0}; } is_ephemeral_and_ctime{false, 0};
uint64_t acl_id = 0; /// 0 -- no ACL by default
int64_t mtime{0}; /// ephemeral notes cannot have children so a node can set either
/// ephemeral_owner OR seq_num + num_children
std::unique_ptr<char[]> data{nullptr}; union
uint32_t data_size{0}; {
int64_t ephemeral_owner;
int32_t version{0};
int32_t cversion{0};
int32_t aversion{0};
mutable uint64_t cached_digest = 0;
Node() = default;
Node & operator=(const Node & other);
Node(const Node & other);
Node & operator=(Node && other) noexcept;
Node(Node && other) noexcept;
bool empty() const;
bool isEphemeral() const
{
return is_ephemeral_and_ctime.is_ephemeral;
}
int64_t ephemeralOwner() const
{
if (isEphemeral())
return ephemeral_or_children_data.ephemeral_owner;
return 0;
}
void setEphemeralOwner(int64_t ephemeral_owner)
{
is_ephemeral_and_ctime.is_ephemeral = ephemeral_owner != 0;
ephemeral_or_children_data.ephemeral_owner = ephemeral_owner;
}
int32_t numChildren() const
{
if (isEphemeral())
return 0;
return ephemeral_or_children_data.children_info.num_children;
}
void setNumChildren(int32_t num_children)
{
ephemeral_or_children_data.children_info.num_children = num_children;
}
void increaseNumChildren()
{
chassert(!isEphemeral());
++ephemeral_or_children_data.children_info.num_children;
}
void decreaseNumChildren()
{
chassert(!isEphemeral());
--ephemeral_or_children_data.children_info.num_children;
}
int32_t seqNum() const
{
if (isEphemeral())
return 0;
return ephemeral_or_children_data.children_info.seq_num;
}
void setSeqNum(int32_t seq_num)
{
ephemeral_or_children_data.children_info.seq_num = seq_num;
}
void increaseSeqNum()
{
chassert(!isEphemeral());
++ephemeral_or_children_data.children_info.seq_num;
}
int64_t ctime() const
{
return is_ephemeral_and_ctime.ctime;
}
void setCtime(uint64_t ctime)
{
is_ephemeral_and_ctime.ctime = ctime;
}
void copyStats(const Coordination::Stat & stat);
void setResponseStat(Coordination::Stat & response_stat) const;
/// Object memory size
uint64_t sizeInBytes() const;
void setData(const String & new_data);
std::string_view getData() const noexcept { return {data.get(), data_size}; }
void addChild(StringRef child_path);
void removeChild(StringRef child_path);
const auto & getChildren() const noexcept { return children; }
auto & getChildren() { return children; }
// Invalidate the calculated digest so it's recalculated again on the next
// getDigest call
void invalidateDigestCache() const;
// get the calculated digest of the node
UInt64 getDigest(std::string_view path) const;
// copy only necessary information for preprocessing and digest calculation
// (e.g. we don't need to copy list of children)
void shallowCopy(const Node & other);
private:
/// as ctime can't be negative because it stores the timestamp when the
/// node was created, we can use the MSB for a bool
struct struct
{ {
bool is_ephemeral : 1; int32_t seq_num;
int64_t ctime : 63; int32_t num_children;
} is_ephemeral_and_ctime{false, 0}; } children_info;
} ephemeral_or_children_data{0};
/// ephemeral notes cannot have children so a node can set either bool isEphemeral() const
/// ephemeral_owner OR seq_num + num_children {
union return is_ephemeral_and_ctime.is_ephemeral;
{ }
int64_t ephemeral_owner;
struct
{
int32_t seq_num;
int32_t num_children;
} children_info;
} ephemeral_or_children_data{0};
ChildrenSet children{}; int64_t ephemeralOwner() const
}; {
if (isEphemeral())
return ephemeral_or_children_data.ephemeral_owner;
#if !defined(ADDRESS_SANITIZER) && !defined(MEMORY_SANITIZER) return 0;
static_assert( }
sizeof(ListNode<Node>) <= 144,
"std::list node containing ListNode<Node> is > 160 bytes (sizeof(ListNode<Node>) + 16 bytes for pointers) which will increase " void setEphemeralOwner(int64_t ephemeral_owner)
"memory consumption"); {
is_ephemeral_and_ctime.is_ephemeral = ephemeral_owner != 0;
ephemeral_or_children_data.ephemeral_owner = ephemeral_owner;
}
int32_t numChildren() const
{
if (isEphemeral())
return 0;
return ephemeral_or_children_data.children_info.num_children;
}
void setNumChildren(int32_t num_children)
{
ephemeral_or_children_data.children_info.num_children = num_children;
}
/// dummy interface for test
void addChild(StringRef) {}
auto getChildren() const
{
return std::vector<int>(numChildren());
}
void increaseNumChildren()
{
chassert(!isEphemeral());
++ephemeral_or_children_data.children_info.num_children;
}
void decreaseNumChildren()
{
chassert(!isEphemeral());
--ephemeral_or_children_data.children_info.num_children;
}
int32_t seqNum() const
{
if (isEphemeral())
return 0;
return ephemeral_or_children_data.children_info.seq_num;
}
void setSeqNum(int32_t seq_num_)
{
ephemeral_or_children_data.children_info.seq_num = seq_num_;
}
void increaseSeqNum()
{
chassert(!isEphemeral());
++ephemeral_or_children_data.children_info.seq_num;
}
int64_t ctime() const
{
return is_ephemeral_and_ctime.ctime;
}
void setCtime(uint64_t ctime)
{
is_ephemeral_and_ctime.ctime = ctime;
}
void copyStats(const Coordination::Stat & stat);
};
/// KeeperRocksNode is the memory structure used by RocksDB
struct KeeperRocksNode : public KeeperRocksNodeInfo
{
#if USE_ROCKSDB
friend struct RocksDBContainer<KeeperRocksNode>;
#endif #endif
using Meta = KeeperRocksNodeInfo;
uint64_t size_bytes = 0; // only for compatible, should be deprecated
uint64_t sizeInBytes() const { return data_size + sizeof(KeeperRocksNodeInfo); }
void setData(String new_data)
{
data_size = static_cast<uint32_t>(new_data.size());
if (data_size != 0)
{
data = std::unique_ptr<char[]>(new char[new_data.size()]);
memcpy(data.get(), new_data.data(), data_size);
}
}
void shallowCopy(const KeeperRocksNode & other)
{
czxid = other.czxid;
mzxid = other.mzxid;
pzxid = other.pzxid;
acl_id = other.acl_id; /// 0 -- no ACL by default
mtime = other.mtime;
is_ephemeral_and_ctime = other.is_ephemeral_and_ctime;
ephemeral_or_children_data = other.ephemeral_or_children_data;
data_size = other.data_size;
if (data_size != 0)
{
data = std::unique_ptr<char[]>(new char[data_size]);
memcpy(data.get(), other.data.get(), data_size);
}
version = other.version;
cversion = other.cversion;
aversion = other.aversion;
/// cached_digest = other.cached_digest;
}
void invalidateDigestCache() const;
UInt64 getDigest(std::string_view path) const;
String getEncodedString();
void decodeFromString(const String & buffer_str);
void recalculateSize() {}
std::string_view getData() const noexcept { return {data.get(), data_size}; }
void setResponseStat(Coordination::Stat & response_stat) const
{
response_stat.czxid = czxid;
response_stat.mzxid = mzxid;
response_stat.ctime = ctime();
response_stat.mtime = mtime;
response_stat.version = version;
response_stat.cversion = cversion;
response_stat.aversion = aversion;
response_stat.ephemeralOwner = ephemeralOwner();
response_stat.dataLength = static_cast<int32_t>(data_size);
response_stat.numChildren = numChildren();
response_stat.pzxid = pzxid;
}
void reset()
{
serialized = false;
}
bool empty() const
{
return data_size == 0 && mzxid == 0;
}
std::unique_ptr<char[]> data{nullptr};
uint32_t data_size{0};
private:
bool serialized = false;
};
/// KeeperMemNode should have as minimal size as possible to reduce memory footprint
/// of stored nodes
/// New fields should be added to the struct only if it's really necessary
struct KeeperMemNode
{
int64_t czxid{0};
int64_t mzxid{0};
int64_t pzxid{0};
uint64_t acl_id = 0; /// 0 -- no ACL by default
int64_t mtime{0};
std::unique_ptr<char[]> data{nullptr};
uint32_t data_size{0};
int32_t version{0};
int32_t cversion{0};
int32_t aversion{0};
mutable uint64_t cached_digest = 0;
KeeperMemNode() = default;
KeeperMemNode & operator=(const KeeperMemNode & other);
KeeperMemNode(const KeeperMemNode & other);
KeeperMemNode & operator=(KeeperMemNode && other) noexcept;
KeeperMemNode(KeeperMemNode && other) noexcept;
bool empty() const;
bool isEphemeral() const
{
return is_ephemeral_and_ctime.is_ephemeral;
}
int64_t ephemeralOwner() const
{
if (isEphemeral())
return ephemeral_or_children_data.ephemeral_owner;
return 0;
}
void setEphemeralOwner(int64_t ephemeral_owner)
{
is_ephemeral_and_ctime.is_ephemeral = ephemeral_owner != 0;
ephemeral_or_children_data.ephemeral_owner = ephemeral_owner;
}
int32_t numChildren() const
{
if (isEphemeral())
return 0;
return ephemeral_or_children_data.children_info.num_children;
}
void setNumChildren(int32_t num_children)
{
ephemeral_or_children_data.children_info.num_children = num_children;
}
void increaseNumChildren()
{
chassert(!isEphemeral());
++ephemeral_or_children_data.children_info.num_children;
}
void decreaseNumChildren()
{
chassert(!isEphemeral());
--ephemeral_or_children_data.children_info.num_children;
}
int32_t seqNum() const
{
if (isEphemeral())
return 0;
return ephemeral_or_children_data.children_info.seq_num;
}
void setSeqNum(int32_t seq_num)
{
ephemeral_or_children_data.children_info.seq_num = seq_num;
}
void increaseSeqNum()
{
chassert(!isEphemeral());
++ephemeral_or_children_data.children_info.seq_num;
}
int64_t ctime() const
{
return is_ephemeral_and_ctime.ctime;
}
void setCtime(uint64_t ctime)
{
is_ephemeral_and_ctime.ctime = ctime;
}
void copyStats(const Coordination::Stat & stat);
void setResponseStat(Coordination::Stat & response_stat) const;
/// Object memory size
uint64_t sizeInBytes() const;
void setData(const String & new_data);
std::string_view getData() const noexcept { return {data.get(), data_size}; }
void addChild(StringRef child_path);
void removeChild(StringRef child_path);
const auto & getChildren() const noexcept { return children; }
auto & getChildren() { return children; }
// Invalidate the calculated digest so it's recalculated again on the next
// getDigest call
void invalidateDigestCache() const;
// get the calculated digest of the node
UInt64 getDigest(std::string_view path) const;
// copy only necessary information for preprocessing and digest calculation
// (e.g. we don't need to copy list of children)
void shallowCopy(const KeeperMemNode & other);
private:
/// as ctime can't be negative because it stores the timestamp when the
/// node was created, we can use the MSB for a bool
struct
{
bool is_ephemeral : 1;
int64_t ctime : 63;
} is_ephemeral_and_ctime{false, 0};
/// ephemeral notes cannot have children so a node can set either
/// ephemeral_owner OR seq_num + num_children
union
{
int64_t ephemeral_owner;
struct
{
int32_t seq_num;
int32_t num_children;
} children_info;
} ephemeral_or_children_data{0};
ChildrenSet children{};
};
class KeeperStorageBase
{
public:
enum DigestVersion : uint8_t enum DigestVersion : uint8_t
{ {
@ -200,7 +396,11 @@ public:
V4 = 4 // 0 is not a valid digest value V4 = 4 // 0 is not a valid digest value
}; };
static constexpr auto CURRENT_DIGEST_VERSION = DigestVersion::V4; struct Digest
{
DigestVersion version{DigestVersion::NO_DIGEST};
uint64_t value{0};
};
struct ResponseForSession struct ResponseForSession
{ {
@ -210,16 +410,6 @@ public:
}; };
using ResponsesForSessions = std::vector<ResponseForSession>; using ResponsesForSessions = std::vector<ResponseForSession>;
struct Digest
{
DigestVersion version{DigestVersion::NO_DIGEST};
uint64_t value{0};
};
static bool checkDigest(const Digest & first, const Digest & second);
static String generateDigest(const String & userdata);
struct RequestForSession struct RequestForSession
{ {
int64_t session_id; int64_t session_id;
@ -229,6 +419,7 @@ public:
std::optional<Digest> digest; std::optional<Digest> digest;
int64_t log_idx{0}; int64_t log_idx{0};
}; };
using RequestsForSessions = std::vector<RequestForSession>;
struct AuthID struct AuthID
{ {
@ -238,9 +429,6 @@ public:
bool operator==(const AuthID & other) const { return scheme == other.scheme && id == other.id; } bool operator==(const AuthID & other) const { return scheme == other.scheme && id == other.id; }
}; };
using RequestsForSessions = std::vector<RequestForSession>;
using Container = SnapshotableHashTable<Node>;
using Ephemerals = std::unordered_map<int64_t, std::unordered_set<std::string>>; using Ephemerals = std::unordered_map<int64_t, std::unordered_set<std::string>>;
using SessionAndWatcher = std::unordered_map<int64_t, std::unordered_set<std::string>>; using SessionAndWatcher = std::unordered_map<int64_t, std::unordered_set<std::string>>;
using SessionIDs = std::unordered_set<int64_t>; using SessionIDs = std::unordered_set<int64_t>;
@ -250,6 +438,38 @@ public:
using SessionAndAuth = std::unordered_map<int64_t, AuthIDs>; using SessionAndAuth = std::unordered_map<int64_t, AuthIDs>;
using Watches = std::unordered_map<String /* path, relative of root_path */, SessionIDs>; using Watches = std::unordered_map<String /* path, relative of root_path */, SessionIDs>;
static bool checkDigest(const Digest & first, const Digest & second);
};
/// Keeper state machine almost equal to the ZooKeeper's state machine.
/// Implements all logic of operations, data changes, sessions allocation.
/// In-memory and not thread safe.
template<typename Container_>
class KeeperStorage : public KeeperStorageBase
{
public:
using Container = Container_;
using Node = Container::Node;
#if !defined(ADDRESS_SANITIZER) && !defined(MEMORY_SANITIZER)
static_assert(
sizeof(ListNode<Node>) <= 144,
"std::list node containing ListNode<Node> is > 160 bytes (sizeof(ListNode<Node>) + 16 bytes for pointers) which will increase "
"memory consumption");
#endif
#if USE_ROCKSDB
static constexpr bool use_rocksdb = std::is_same_v<Container_, RocksDBContainer<KeeperRocksNode>>;
#else
static constexpr bool use_rocksdb = false;
#endif
static constexpr auto CURRENT_DIGEST_VERSION = DigestVersion::V4;
static String generateDigest(const String & userdata);
int64_t session_id_counter{1}; int64_t session_id_counter{1};
SessionAndAuth session_and_auth; SessionAndAuth session_and_auth;
@ -393,7 +613,7 @@ public:
std::unordered_map<std::string, std::list<const Delta *>, Hash, Equal> deltas_for_path; std::unordered_map<std::string, std::list<const Delta *>, Hash, Equal> deltas_for_path;
std::list<Delta> deltas; std::list<Delta> deltas;
KeeperStorage & storage; KeeperStorage<Container> & storage;
}; };
UncommittedState uncommitted_state{*this}; UncommittedState uncommitted_state{*this};
@ -530,10 +750,16 @@ public:
/// Set of methods for creating snapshots /// Set of methods for creating snapshots
/// Turn on snapshot mode, so data inside Container is not deleted, but replaced with new version. /// Turn on snapshot mode, so data inside Container is not deleted, but replaced with new version.
void enableSnapshotMode(size_t up_to_version) { container.enableSnapshotMode(up_to_version); } void enableSnapshotMode(size_t up_to_version)
{
container.enableSnapshotMode(up_to_version);
}
/// Turn off snapshot mode. /// Turn off snapshot mode.
void disableSnapshotMode() { container.disableSnapshotMode(); } void disableSnapshotMode()
{
container.disableSnapshotMode();
}
Container::const_iterator getSnapshotIteratorBegin() const { return container.begin(); } Container::const_iterator getSnapshotIteratorBegin() const { return container.begin(); }
@ -572,6 +798,9 @@ private:
void addDigest(const Node & node, std::string_view path); void addDigest(const Node & node, std::string_view path);
}; };
using KeeperStoragePtr = std::unique_ptr<KeeperStorage>; using KeeperMemoryStorage = KeeperStorage<SnapshotableHashTable<KeeperMemNode>>;
#if USE_ROCKSDB
using KeeperRocksStorage = KeeperStorage<RocksDBContainer<KeeperRocksNode>>;
#endif
} }

View File

@ -0,0 +1,460 @@
#pragma once
#include <base/StringRef.h>
#include <Coordination/CoordinationSettings.h>
#include <Coordination/KeeperContext.h>
#include <Common/SipHash.h>
#include <Disks/DiskLocal.h>
#include <IO/WriteBufferFromString.h>
#include <IO/ReadBufferFromString.h>
#include <rocksdb/convenience.h>
#include <rocksdb/options.h>
#include <rocksdb/status.h>
#include <rocksdb/table.h>
#include <rocksdb/snapshot.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ROCKSDB_ERROR;
extern const int LOGICAL_ERROR;
}
/// The key-value format of rocks db will be
/// - key: Int8 (depth of the path) + String (path)
/// - value: SizeOf(keeperRocksNodeInfo) (meta of the node) + String (data)
template <class Node_>
struct RocksDBContainer
{
using Node = Node_;
private:
/// MockNode is only use in test to mock `getChildren()` and `getData()`
struct MockNode
{
std::vector<int> children;
std::string data;
MockNode(size_t children_num, std::string_view data_)
: children(std::vector<int>(children_num)),
data(data_)
{
}
std::vector<int> getChildren() { return children; }
std::string getData() { return data; }
};
UInt16 getKeyDepth(const std::string & key)
{
UInt16 depth = 0;
for (size_t i = 0; i < key.size(); i++)
{
if (key[i] == '/' && i + 1 != key.size())
depth ++;
}
return depth;
}
std::string getEncodedKey(const std::string & key, bool child_prefix = false)
{
WriteBufferFromOwnString key_buffer;
UInt16 depth = getKeyDepth(key) + (child_prefix ? 1 : 0);
writeIntBinary(depth, key_buffer);
writeString(key, key_buffer);
return key_buffer.str();
}
static std::string_view getDecodedKey(const std::string_view & key)
{
return std::string_view(key.begin() + 2, key.end());
}
struct KVPair
{
StringRef key;
Node value;
};
using ValueUpdater = std::function<void(Node & node)>;
public:
/// This is an iterator wrapping rocksdb iterator and the kv result.
struct const_iterator
{
std::shared_ptr<rocksdb::Iterator> iter;
std::shared_ptr<const KVPair> pair;
const_iterator() = default;
explicit const_iterator(std::shared_ptr<KVPair> pair_) : pair(std::move(pair_)) {}
explicit const_iterator(rocksdb::Iterator * iter_) : iter(iter_)
{
updatePairFromIter();
}
const KVPair & operator * () const
{
return *pair;
}
const KVPair * operator->() const
{
return pair.get();
}
bool operator != (const const_iterator & other) const
{
return !(*this == other);
}
bool operator == (const const_iterator & other) const
{
if (pair == nullptr && other == nullptr)
return true;
if (pair == nullptr || other == nullptr)
return false;
return pair->key.toView() == other->key.toView() && iter == other.iter;
}
bool operator == (std::nullptr_t) const
{
return iter == nullptr;
}
bool operator != (std::nullptr_t) const
{
return iter != nullptr;
}
explicit operator bool() const
{
return iter != nullptr;
}
const_iterator & operator ++()
{
iter->Next();
updatePairFromIter();
return *this;
}
private:
void updatePairFromIter()
{
if (iter && iter->Valid())
{
auto new_pair = std::make_shared<KVPair>();
new_pair->key = StringRef(getDecodedKey(iter->key().ToStringView()));
ReadBufferFromOwnString buffer(iter->value().ToStringView());
typename Node::Meta & meta = new_pair->value;
readPODBinary(meta, buffer);
readVarUInt(new_pair->value.data_size, buffer);
if (new_pair->value.data_size)
{
new_pair->value.data = std::unique_ptr<char[]>(new char[new_pair->value.data_size]);
buffer.readStrict(new_pair->value.data.get(), new_pair->value.data_size);
}
pair = new_pair;
}
else
{
pair = nullptr;
iter = nullptr;
}
}
};
bool initialized = false;
const const_iterator end_ptr;
void initialize(const KeeperContextPtr & context)
{
DiskPtr disk = context->getTemporaryRocksDBDisk();
if (disk == nullptr)
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot get rocksdb disk");
}
auto options = context->getRocksDBOptions();
if (options == nullptr)
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot get rocksdb options");
}
rocksdb_dir = disk->getPath();
rocksdb::DB * db;
auto status = rocksdb::DB::Open(*options, rocksdb_dir, &db);
if (!status.ok())
{
throw Exception(ErrorCodes::ROCKSDB_ERROR, "Failed to open rocksdb path at: {}: {}",
rocksdb_dir, status.ToString());
}
rocksdb_ptr = std::unique_ptr<rocksdb::DB>(db);
write_options.disableWAL = true;
initialized = true;
}
~RocksDBContainer()
{
if (initialized)
{
rocksdb_ptr->Close();
rocksdb_ptr = nullptr;
std::filesystem::remove_all(rocksdb_dir);
}
}
std::vector<std::pair<std::string, Node>> getChildren(const std::string & key_)
{
rocksdb::ReadOptions read_options;
read_options.total_order_seek = true;
std::string key = key_;
if (!key.ends_with('/'))
key += '/';
size_t len = key.size() + 2;
auto iter = std::unique_ptr<rocksdb::Iterator>(rocksdb_ptr->NewIterator(read_options));
std::string encoded_string = getEncodedKey(key, true);
rocksdb::Slice prefix(encoded_string);
std::vector<std::pair<std::string, Node>> result;
for (iter->Seek(prefix); iter->Valid() && iter->key().starts_with(prefix); iter->Next())
{
Node node;
ReadBufferFromOwnString buffer(iter->value().ToStringView());
typename Node::Meta & meta = node;
/// We do not read data here
readPODBinary(meta, buffer);
std::string real_key(iter->key().data() + len, iter->key().size() - len);
// std::cout << "real key: " << real_key << std::endl;
result.emplace_back(std::move(real_key), std::move(node));
}
return result;
}
bool contains(const std::string & path)
{
const std::string & encoded_key = getEncodedKey(path);
std::string buffer_str;
rocksdb::Status status = rocksdb_ptr->Get(rocksdb::ReadOptions(), encoded_key, &buffer_str);
if (status.IsNotFound())
return false;
if (!status.ok())
throw Exception(ErrorCodes::ROCKSDB_ERROR, "Got rocksdb error during executing contains. The error message is {}.", status.ToString());
return true;
}
const_iterator find(StringRef key_)
{
/// rocksdb::PinnableSlice slice;
const std::string & encoded_key = getEncodedKey(key_.toString());
std::string buffer_str;
rocksdb::Status status = rocksdb_ptr->Get(rocksdb::ReadOptions(), encoded_key, &buffer_str);
if (status.IsNotFound())
return end();
if (!status.ok())
throw Exception(ErrorCodes::ROCKSDB_ERROR, "Got rocksdb error during executing find. The error message is {}.", status.ToString());
ReadBufferFromOwnString buffer(buffer_str);
auto kv = std::make_shared<KVPair>();
kv->key = key_;
typename Node::Meta & meta = kv->value;
readPODBinary(meta, buffer);
/// TODO: Sometimes we don't need to load data.
readVarUInt(kv->value.data_size, buffer);
if (kv->value.data_size)
{
kv->value.data = std::unique_ptr<char[]>(new char[kv->value.data_size]);
buffer.readStrict(kv->value.data.get(), kv->value.data_size);
}
return const_iterator(kv);
}
MockNode getValue(StringRef key)
{
auto it = find(key);
chassert(it != end());
return MockNode(it->value.numChildren(), it->value.getData());
}
const_iterator updateValue(StringRef key_, ValueUpdater updater)
{
/// rocksdb::PinnableSlice slice;
const std::string & key = key_.toString();
const std::string & encoded_key = getEncodedKey(key);
std::string buffer_str;
rocksdb::Status status = rocksdb_ptr->Get(rocksdb::ReadOptions(), encoded_key, &buffer_str);
if (!status.ok())
throw Exception(ErrorCodes::ROCKSDB_ERROR, "Got rocksdb error during find. The error message is {}.", status.ToString());
auto kv = std::make_shared<KVPair>();
kv->key = key_;
kv->value.decodeFromString(buffer_str);
/// storage->removeDigest(node, key);
updater(kv->value);
insertOrReplace(key, kv->value);
return const_iterator(kv);
}
bool insert(const std::string & key, Node & value)
{
std::string value_str;
const std::string & encoded_key = getEncodedKey(key);
rocksdb::Status status = rocksdb_ptr->Get(rocksdb::ReadOptions(), encoded_key, &value_str);
if (status.ok())
{
return false;
}
else if (status.IsNotFound())
{
status = rocksdb_ptr->Put(write_options, encoded_key, value.getEncodedString());
if (status.ok())
{
counter++;
return true;
}
}
throw Exception(ErrorCodes::ROCKSDB_ERROR, "Got rocksdb error during insert. The error message is {}.", status.ToString());
}
void insertOrReplace(const std::string & key, Node & value)
{
const std::string & encoded_key = getEncodedKey(key);
/// storage->addDigest(value, key);
std::string value_str;
rocksdb::Status status = rocksdb_ptr->Get(rocksdb::ReadOptions(), encoded_key, &value_str);
bool increase_counter = false;
if (status.IsNotFound())
increase_counter = true;
else if (!status.ok())
throw Exception(ErrorCodes::ROCKSDB_ERROR, "Got rocksdb error during get. The error message is {}.", status.ToString());
status = rocksdb_ptr->Put(write_options, encoded_key, value.getEncodedString());
if (status.ok())
counter += increase_counter;
else
throw Exception(ErrorCodes::ROCKSDB_ERROR, "Got rocksdb error during insert. The error message is {}.", status.ToString());
}
using KeyPtr = std::unique_ptr<char[]>;
/// To be compatible with SnapshotableHashTable, will remove later;
KeyPtr allocateKey(size_t size)
{
return KeyPtr{new char[size]};
}
void insertOrReplace(KeyPtr key_data, size_t key_size, Node value)
{
std::string key(key_data.get(), key_size);
insertOrReplace(key, value);
}
bool erase(const std::string & key)
{
/// storage->removeDigest(value, key);
const std::string & encoded_key = getEncodedKey(key);
auto status = rocksdb_ptr->Delete(write_options, encoded_key);
if (status.IsNotFound())
return false;
if (status.ok())
{
counter--;
return true;
}
throw Exception(ErrorCodes::ROCKSDB_ERROR, "Got rocksdb error during erase. The error message is {}.", status.ToString());
}
void recalculateDataSize() {}
void reverse(size_t size_) {(void)size_;}
uint64_t getApproximateDataSize() const
{
/// use statistics from rocksdb
return counter * sizeof(Node);
}
void enableSnapshotMode(size_t version)
{
chassert(!snapshot_mode);
snapshot_mode = true;
snapshot_up_to_version = version;
snapshot_size = counter;
++current_version;
snapshot = rocksdb_ptr->GetSnapshot();
}
void disableSnapshotMode()
{
chassert(snapshot_mode);
snapshot_mode = false;
rocksdb_ptr->ReleaseSnapshot(snapshot);
}
void clearOutdatedNodes() {}
std::pair<size_t, size_t> snapshotSizeWithVersion() const
{
if (!snapshot_mode)
return std::make_pair(counter, current_version);
else
return std::make_pair(snapshot_size, current_version);
}
const_iterator begin() const
{
rocksdb::ReadOptions read_options;
read_options.total_order_seek = true;
if (snapshot_mode)
read_options.snapshot = snapshot;
auto * iter = rocksdb_ptr->NewIterator(read_options);
iter->SeekToFirst();
return const_iterator(iter);
}
const_iterator end() const
{
return end_ptr;
}
size_t size() const
{
return counter;
}
uint64_t getArenaDataSize() const
{
return 0;
}
uint64_t keyArenaSize() const
{
return 0;
}
private:
String rocksdb_dir;
std::unique_ptr<rocksdb::DB> rocksdb_ptr;
rocksdb::WriteOptions write_options;
const rocksdb::Snapshot * snapshot;
bool snapshot_mode{false};
size_t current_version{0};
size_t snapshot_up_to_version{0};
size_t snapshot_size{0};
size_t counter{0};
};
}

View File

@ -212,9 +212,9 @@ private:
updateDataSize(INSERT_OR_REPLACE, key.size, new_value_size, old_value_size, !snapshot_mode); updateDataSize(INSERT_OR_REPLACE, key.size, new_value_size, old_value_size, !snapshot_mode);
} }
public: public:
using Node = V;
using iterator = typename List::iterator; using iterator = typename List::iterator;
using const_iterator = typename List::const_iterator; using const_iterator = typename List::const_iterator;
using ValueUpdater = std::function<void(V & value)>; using ValueUpdater = std::function<void(V & value)>;
@ -364,6 +364,7 @@ public:
{ {
auto map_it = map.find(key); auto map_it = map.find(key);
if (map_it != map.end()) if (map_it != map.end())
/// return std::make_shared<KVPair>(KVPair{map_it->getMapped()->key, map_it->getMapped()->value});
return map_it->getMapped(); return map_it->getMapped();
return list.end(); return list.end();
} }

View File

@ -43,7 +43,8 @@ void deserializeSnapshotMagic(ReadBuffer & in)
throw Exception(ErrorCodes::CORRUPTED_DATA, "Incorrect magic header in file, expected {}, got {}", SNP_HEADER, magic_header); throw Exception(ErrorCodes::CORRUPTED_DATA, "Incorrect magic header in file, expected {}, got {}", SNP_HEADER, magic_header);
} }
int64_t deserializeSessionAndTimeout(KeeperStorage & storage, ReadBuffer & in) template<typename Storage>
int64_t deserializeSessionAndTimeout(Storage & storage, ReadBuffer & in)
{ {
int32_t count; int32_t count;
Coordination::read(count, in); Coordination::read(count, in);
@ -62,7 +63,8 @@ int64_t deserializeSessionAndTimeout(KeeperStorage & storage, ReadBuffer & in)
return max_session_id; return max_session_id;
} }
void deserializeACLMap(KeeperStorage & storage, ReadBuffer & in) template<typename Storage>
void deserializeACLMap(Storage & storage, ReadBuffer & in)
{ {
int32_t count; int32_t count;
Coordination::read(count, in); Coordination::read(count, in);
@ -90,7 +92,8 @@ void deserializeACLMap(KeeperStorage & storage, ReadBuffer & in)
} }
} }
int64_t deserializeStorageData(KeeperStorage & storage, ReadBuffer & in, LoggerPtr log) template<typename Storage>
int64_t deserializeStorageData(Storage & storage, ReadBuffer & in, LoggerPtr log)
{ {
int64_t max_zxid = 0; int64_t max_zxid = 0;
std::string path; std::string path;
@ -98,7 +101,7 @@ int64_t deserializeStorageData(KeeperStorage & storage, ReadBuffer & in, LoggerP
size_t count = 0; size_t count = 0;
while (path != "/") while (path != "/")
{ {
KeeperStorage::Node node{}; typename Storage::Node node{};
String data; String data;
Coordination::read(data, in); Coordination::read(data, in);
node.setData(data); node.setData(data);
@ -146,14 +149,15 @@ int64_t deserializeStorageData(KeeperStorage & storage, ReadBuffer & in, LoggerP
if (itr.key != "/") if (itr.key != "/")
{ {
auto parent_path = parentNodePath(itr.key); auto parent_path = parentNodePath(itr.key);
storage.container.updateValue(parent_path, [my_path = itr.key] (KeeperStorage::Node & value) { value.addChild(getBaseNodeName(my_path)); value.increaseNumChildren(); }); storage.container.updateValue(parent_path, [my_path = itr.key] (typename Storage::Node & value) { value.addChild(getBaseNodeName(my_path)); value.increaseNumChildren(); });
} }
} }
return max_zxid; return max_zxid;
} }
void deserializeKeeperStorageFromSnapshot(KeeperStorage & storage, const std::string & snapshot_path, LoggerPtr log) template<typename Storage>
void deserializeKeeperStorageFromSnapshot(Storage & storage, const std::string & snapshot_path, LoggerPtr log)
{ {
LOG_INFO(log, "Deserializing storage snapshot {}", snapshot_path); LOG_INFO(log, "Deserializing storage snapshot {}", snapshot_path);
int64_t zxid = getZxidFromName(snapshot_path); int64_t zxid = getZxidFromName(snapshot_path);
@ -192,9 +196,11 @@ void deserializeKeeperStorageFromSnapshot(KeeperStorage & storage, const std::st
LOG_INFO(log, "Finished, snapshot ZXID {}", storage.zxid); LOG_INFO(log, "Finished, snapshot ZXID {}", storage.zxid);
} }
void deserializeKeeperStorageFromSnapshotsDir(KeeperStorage & storage, const std::string & path, LoggerPtr log) namespace fs = std::filesystem;
template<typename Storage>
void deserializeKeeperStorageFromSnapshotsDir(Storage & storage, const std::string & path, LoggerPtr log)
{ {
namespace fs = std::filesystem;
std::map<int64_t, std::string> existing_snapshots; std::map<int64_t, std::string> existing_snapshots;
for (const auto & p : fs::directory_iterator(path)) for (const auto & p : fs::directory_iterator(path))
{ {
@ -480,7 +486,8 @@ bool hasErrorsInMultiRequest(Coordination::ZooKeeperRequestPtr request)
} }
bool deserializeTxn(KeeperStorage & storage, ReadBuffer & in, LoggerPtr /*log*/) template<typename Storage>
bool deserializeTxn(Storage & storage, ReadBuffer & in, LoggerPtr /*log*/)
{ {
int64_t checksum; int64_t checksum;
Coordination::read(checksum, in); Coordination::read(checksum, in);
@ -535,7 +542,8 @@ bool deserializeTxn(KeeperStorage & storage, ReadBuffer & in, LoggerPtr /*log*/)
return true; return true;
} }
void deserializeLogAndApplyToStorage(KeeperStorage & storage, const std::string & log_path, LoggerPtr log) template<typename Storage>
void deserializeLogAndApplyToStorage(Storage & storage, const std::string & log_path, LoggerPtr log)
{ {
ReadBufferFromFile reader(log_path); ReadBufferFromFile reader(log_path);
@ -559,9 +567,9 @@ void deserializeLogAndApplyToStorage(KeeperStorage & storage, const std::string
LOG_INFO(log, "Finished {} deserialization, totally read {} records", log_path, counter); LOG_INFO(log, "Finished {} deserialization, totally read {} records", log_path, counter);
} }
void deserializeLogsAndApplyToStorage(KeeperStorage & storage, const std::string & path, LoggerPtr log) template<typename Storage>
void deserializeLogsAndApplyToStorage(Storage & storage, const std::string & path, LoggerPtr log)
{ {
namespace fs = std::filesystem;
std::map<int64_t, std::string> existing_logs; std::map<int64_t, std::string> existing_logs;
for (const auto & p : fs::directory_iterator(path)) for (const auto & p : fs::directory_iterator(path))
{ {
@ -595,4 +603,9 @@ void deserializeLogsAndApplyToStorage(KeeperStorage & storage, const std::string
} }
} }
template void deserializeKeeperStorageFromSnapshot<KeeperMemoryStorage>(KeeperMemoryStorage & storage, const std::string & snapshot_path, LoggerPtr log);
template void deserializeKeeperStorageFromSnapshotsDir<KeeperMemoryStorage>(KeeperMemoryStorage & storage, const std::string & path, LoggerPtr log);
template void deserializeLogAndApplyToStorage<KeeperMemoryStorage>(KeeperMemoryStorage & storage, const std::string & log_path, LoggerPtr log);
template void deserializeLogsAndApplyToStorage<KeeperMemoryStorage>(KeeperMemoryStorage & storage, const std::string & path, LoggerPtr log);
} }

View File

@ -5,12 +5,16 @@
namespace DB namespace DB
{ {
void deserializeKeeperStorageFromSnapshot(KeeperStorage & storage, const std::string & snapshot_path, LoggerPtr log); template<typename Storage>
void deserializeKeeperStorageFromSnapshot(Storage & storage, const std::string & snapshot_path, LoggerPtr log);
void deserializeKeeperStorageFromSnapshotsDir(KeeperStorage & storage, const std::string & path, LoggerPtr log); template<typename Storage>
void deserializeKeeperStorageFromSnapshotsDir(Storage & storage, const std::string & path, LoggerPtr log);
void deserializeLogAndApplyToStorage(KeeperStorage & storage, const std::string & log_path, LoggerPtr log); template<typename Storage>
void deserializeLogAndApplyToStorage(Storage & storage, const std::string & log_path, LoggerPtr log);
void deserializeLogsAndApplyToStorage(KeeperStorage & storage, const std::string & path, LoggerPtr log); template<typename Storage>
void deserializeLogsAndApplyToStorage(Storage & storage, const std::string & path, LoggerPtr log);
} }

File diff suppressed because it is too large Load Diff

View File

@ -11,6 +11,7 @@ namespace DB
namespace ErrorCodes namespace ErrorCodes
{ {
extern const int NOT_IMPLEMENTED; extern const int NOT_IMPLEMENTED;
extern const int LOGICAL_ERROR;
} }
struct SerializationVariantElement::DeserializeBinaryBulkStateVariantElement : public ISerialization::DeserializeBinaryBulkState struct SerializationVariantElement::DeserializeBinaryBulkStateVariantElement : public ISerialization::DeserializeBinaryBulkState
@ -188,13 +189,6 @@ void SerializationVariantElement::deserializeBinaryBulkWithMultipleStreams(
assert_cast<ColumnLowCardinality &>(*variant_element_state->variant->assumeMutable()).nestedRemoveNullable(); assert_cast<ColumnLowCardinality &>(*variant_element_state->variant->assumeMutable()).nestedRemoveNullable();
} }
/// If nothing to deserialize, just insert defaults.
if (variant_limit == 0)
{
mutable_column->insertManyDefaults(num_new_discriminators);
return;
}
addVariantToPath(settings.path); addVariantToPath(settings.path);
nested_serialization->deserializeBinaryBulkWithMultipleStreams(variant_element_state->variant, *variant_limit, settings, variant_element_state->variant_element_state, cache); nested_serialization->deserializeBinaryBulkWithMultipleStreams(variant_element_state->variant, *variant_limit, settings, variant_element_state->variant_element_state, cache);
removeVariantFromPath(settings.path); removeVariantFromPath(settings.path);
@ -209,6 +203,17 @@ void SerializationVariantElement::deserializeBinaryBulkWithMultipleStreams(
return; return;
} }
/// If there was nothing to deserialize or nothing was actually deserialized when variant_limit > 0, just insert defaults.
/// The second case means that we don't have a stream for such sub-column. It may happen during ALTER MODIFY column with Variant extension.
if (variant_limit == 0 || variant_element_state->variant->empty())
{
mutable_column->insertManyDefaults(num_new_discriminators);
return;
}
if (variant_element_state->variant->size() < *variant_limit)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Size of deserialized variant column less than the limit: {} < {}", variant_element_state->variant->size(), *variant_limit);
size_t variant_offset = variant_element_state->variant->size() - *variant_limit; size_t variant_offset = variant_element_state->variant->size() - *variant_limit;
/// If we have only our discriminator in range, insert the whole range to result column. /// If we have only our discriminator in range, insert the whole range to result column.

View File

@ -22,7 +22,9 @@ class FunctionIsNotNull : public IFunction
public: public:
static constexpr auto name = "isNotNull"; static constexpr auto name = "isNotNull";
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionIsNotNull>(); } static FunctionPtr create(ContextPtr context) { return std::make_shared<FunctionIsNotNull>(context->getSettingsRef().allow_experimental_analyzer); }
explicit FunctionIsNotNull(bool use_analyzer_) : use_analyzer(use_analyzer_) {}
std::string getName() const override std::string getName() const override
{ {
@ -31,6 +33,10 @@ public:
ColumnPtr getConstantResultForNonConstArguments(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type) const override ColumnPtr getConstantResultForNonConstArguments(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type) const override
{ {
/// (column IS NULL) triggers a bug in old analyzer when it is replaced to constant.
if (!use_analyzer)
return nullptr;
const ColumnWithTypeAndName & elem = arguments[0]; const ColumnWithTypeAndName & elem = arguments[0];
if (elem.type->onlyNull()) if (elem.type->onlyNull())
return result_type->createColumnConst(1, UInt8(0)); return result_type->createColumnConst(1, UInt8(0));
@ -123,6 +129,8 @@ private:
#endif #endif
vectorImpl(null_map, res); vectorImpl(null_map, res);
} }
bool use_analyzer;
}; };
} }

View File

@ -7,6 +7,8 @@
#include <Columns/ColumnLowCardinality.h> #include <Columns/ColumnLowCardinality.h>
#include <Columns/ColumnVariant.h> #include <Columns/ColumnVariant.h>
#include <Columns/ColumnDynamic.h> #include <Columns/ColumnDynamic.h>
#include <Core/Settings.h>
#include <Interpreters/Context.h>
namespace DB namespace DB
@ -21,11 +23,13 @@ class FunctionIsNull : public IFunction
public: public:
static constexpr auto name = "isNull"; static constexpr auto name = "isNull";
static FunctionPtr create(ContextPtr) static FunctionPtr create(ContextPtr context)
{ {
return std::make_shared<FunctionIsNull>(); return std::make_shared<FunctionIsNull>(context->getSettingsRef().allow_experimental_analyzer);
} }
explicit FunctionIsNull(bool use_analyzer_) : use_analyzer(use_analyzer_) {}
std::string getName() const override std::string getName() const override
{ {
return name; return name;
@ -33,6 +37,10 @@ public:
ColumnPtr getConstantResultForNonConstArguments(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type) const override ColumnPtr getConstantResultForNonConstArguments(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type) const override
{ {
/// (column IS NULL) triggers a bug in old analyzer when it is replaced to constant.
if (!use_analyzer)
return nullptr;
const ColumnWithTypeAndName & elem = arguments[0]; const ColumnWithTypeAndName & elem = arguments[0];
if (elem.type->onlyNull()) if (elem.type->onlyNull())
return result_type->createColumnConst(1, UInt8(1)); return result_type->createColumnConst(1, UInt8(1));
@ -95,6 +103,9 @@ public:
return DataTypeUInt8().createColumnConst(elem.column->size(), 0u); return DataTypeUInt8().createColumnConst(elem.column->size(), 0u);
} }
} }
private:
bool use_analyzer;
}; };
} }

View File

@ -3,6 +3,8 @@
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnsNumber.h> #include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypeNullable.h> #include <DataTypes/DataTypeNullable.h>
#include <Core/Settings.h>
#include <Interpreters/Context.h>
namespace DB namespace DB
{ {
@ -14,11 +16,13 @@ class FunctionIsNullable : public IFunction
{ {
public: public:
static constexpr auto name = "isNullable"; static constexpr auto name = "isNullable";
static FunctionPtr create(ContextPtr) static FunctionPtr create(ContextPtr context)
{ {
return std::make_shared<FunctionIsNullable>(); return std::make_shared<FunctionIsNullable>(context->getSettingsRef().allow_experimental_analyzer);
} }
explicit FunctionIsNullable(bool use_analyzer_) : use_analyzer(use_analyzer_) {}
String getName() const override String getName() const override
{ {
return name; return name;
@ -26,6 +30,10 @@ public:
ColumnPtr getConstantResultForNonConstArguments(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type) const override ColumnPtr getConstantResultForNonConstArguments(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type) const override
{ {
/// isNullable(column) triggers a bug in old analyzer when it is replaced to constant.
if (!use_analyzer)
return nullptr;
const ColumnWithTypeAndName & elem = arguments[0]; const ColumnWithTypeAndName & elem = arguments[0];
if (elem.type->onlyNull() || canContainNull(*elem.type)) if (elem.type->onlyNull() || canContainNull(*elem.type))
return result_type->createColumnConst(1, UInt8(1)); return result_type->createColumnConst(1, UInt8(1));
@ -60,6 +68,9 @@ public:
const auto & elem = arguments[0]; const auto & elem = arguments[0];
return ColumnUInt8::create(input_rows_count, isColumnNullable(*elem.column) || elem.type->isLowCardinalityNullable()); return ColumnUInt8::create(input_rows_count, isColumnNullable(*elem.column) || elem.type->isLowCardinalityNullable());
} }
private:
bool use_analyzer;
}; };
} }

View File

@ -25,8 +25,6 @@ namespace ProfileEvents
extern const Event ReadBufferFromS3InitMicroseconds; extern const Event ReadBufferFromS3InitMicroseconds;
extern const Event ReadBufferFromS3Bytes; extern const Event ReadBufferFromS3Bytes;
extern const Event ReadBufferFromS3RequestsErrors; extern const Event ReadBufferFromS3RequestsErrors;
extern const Event ReadBufferFromS3ResetSessions;
extern const Event ReadBufferFromS3PreservedSessions;
extern const Event ReadBufferSeekCancelConnection; extern const Event ReadBufferSeekCancelConnection;
extern const Event S3GetObject; extern const Event S3GetObject;
extern const Event DiskS3GetObject; extern const Event DiskS3GetObject;

View File

@ -134,7 +134,7 @@ Chunk Squashing::squash(std::vector<Chunk> && input_chunks, Chunk::ChunkInfoColl
Chunk result; Chunk result;
result.setColumns(std::move(mutable_columns), rows); result.setColumns(std::move(mutable_columns), rows);
result.setChunkInfos(infos); result.setChunkInfos(infos);
result.getChunkInfos().append(std::move(input_chunks.back().getChunkInfos())); result.getChunkInfos().appendIfUniq(std::move(input_chunks.back().getChunkInfos()));
chassert(result); chassert(result);
return result; return result;

View File

@ -1,5 +1,6 @@
#include <Parsers/isUnquotedIdentifier.h> #include <Parsers/isUnquotedIdentifier.h>
#include <Parsers/CommonParsers.h>
#include <Parsers/Lexer.h> #include <Parsers/Lexer.h>
namespace DB namespace DB
@ -7,6 +8,18 @@ namespace DB
bool isUnquotedIdentifier(const String & name) bool isUnquotedIdentifier(const String & name)
{ {
auto is_keyword = [&name](Keyword keyword)
{
auto s = toStringView(keyword);
if (name.size() != s.size())
return false;
return strncasecmp(s.data(), name.data(), s.size()) == 0;
};
/// Special keywords are parsed as literals instead of identifiers.
if (is_keyword(Keyword::NULL_KEYWORD) || is_keyword(Keyword::TRUE_KEYWORD) || is_keyword(Keyword::FALSE_KEYWORD))
return false;
Lexer lexer(name.data(), name.data() + name.size()); Lexer lexer(name.data(), name.data() + name.size());
auto maybe_ident = lexer.nextToken(); auto maybe_ident = lexer.nextToken();

View File

@ -5,6 +5,14 @@
namespace DB namespace DB
{ {
/// Checks if the input string @name is a valid unquoted identifier.
///
/// Example Usage:
/// abc -> true (valid unquoted identifier)
/// 123 -> false (identifiers cannot start with digits)
/// `123` -> false (quoted identifiers are not considered)
/// `abc` -> false (quoted identifiers are not considered)
/// null -> false (reserved literal keyword)
bool isUnquotedIdentifier(const String & name); bool isUnquotedIdentifier(const String & name);
} }

View File

@ -20,7 +20,7 @@ namespace ErrorCodes
void RestoreChunkInfosTransform::transform(Chunk & chunk) void RestoreChunkInfosTransform::transform(Chunk & chunk)
{ {
chunk.getChunkInfos().append(chunk_infos.clone()); chunk.getChunkInfos().appendIfUniq(chunk_infos.clone());
} }
namespace DeduplicationToken namespace DeduplicationToken

View File

@ -652,15 +652,12 @@ size_t IMergeTreeDataPart::getFileSizeOrZero(const String & file_name) const
return checksum->second.file_size; return checksum->second.file_size;
} }
String IMergeTreeDataPart::getColumnNameWithMinimumCompressedSize(bool with_subcolumns) const String IMergeTreeDataPart::getColumnNameWithMinimumCompressedSize(const NamesAndTypesList & available_columns) const
{ {
auto options = GetColumnsOptions(GetColumnsOptions::AllPhysical).withSubcolumns(with_subcolumns);
auto columns_list = columns_description.get(options);
std::optional<std::string> minimum_size_column; std::optional<std::string> minimum_size_column;
UInt64 minimum_size = std::numeric_limits<UInt64>::max(); UInt64 minimum_size = std::numeric_limits<UInt64>::max();
for (const auto & column : columns_list) for (const auto & column : available_columns)
{ {
if (!hasColumnFiles(column)) if (!hasColumnFiles(column))
continue; continue;

View File

@ -196,7 +196,9 @@ public:
/// Returns the name of a column with minimum compressed size (as returned by getColumnSize()). /// Returns the name of a column with minimum compressed size (as returned by getColumnSize()).
/// If no checksums are present returns the name of the first physically existing column. /// If no checksums are present returns the name of the first physically existing column.
String getColumnNameWithMinimumCompressedSize(bool with_subcolumns) const; /// We pass a list of available columns since the ones available in the current storage snapshot might be smaller
/// than the one the table has (e.g a DROP COLUMN happened) and we don't want to get a column not in the snapshot
String getColumnNameWithMinimumCompressedSize(const NamesAndTypesList & available_columns) const;
bool contains(const IMergeTreeDataPart & other) const { return info.contains(other.info); } bool contains(const IMergeTreeDataPart & other) const { return info.contains(other.info); }

View File

@ -47,7 +47,7 @@ public:
virtual std::optional<size_t> getColumnPosition(const String & column_name) const = 0; virtual std::optional<size_t> getColumnPosition(const String & column_name) const = 0;
virtual String getColumnNameWithMinimumCompressedSize(bool with_subcolumns) const = 0; virtual String getColumnNameWithMinimumCompressedSize(const NamesAndTypesList & available_columns) const = 0;
virtual const MergeTreeDataPartChecksums & getChecksums() const = 0; virtual const MergeTreeDataPartChecksums & getChecksums() const = 0;

View File

@ -36,7 +36,10 @@ public:
AlterConversionsPtr getAlterConversions() const override { return alter_conversions; } AlterConversionsPtr getAlterConversions() const override { return alter_conversions; }
String getColumnNameWithMinimumCompressedSize(bool with_subcolumns) const override { return data_part->getColumnNameWithMinimumCompressedSize(with_subcolumns); } String getColumnNameWithMinimumCompressedSize(const NamesAndTypesList & available_columns) const override
{
return data_part->getColumnNameWithMinimumCompressedSize(available_columns);
}
const MergeTreeDataPartChecksums & getChecksums() const override { return data_part->checksums; } const MergeTreeDataPartChecksums & getChecksums() const override { return data_part->checksums; }

View File

@ -127,7 +127,8 @@ NameSet injectRequiredColumns(
*/ */
if (!have_at_least_one_physical_column) if (!have_at_least_one_physical_column)
{ {
const auto minimum_size_column_name = data_part_info_for_reader.getColumnNameWithMinimumCompressedSize(with_subcolumns); auto available_columns = storage_snapshot->metadata->getColumns().get(options);
const auto minimum_size_column_name = data_part_info_for_reader.getColumnNameWithMinimumCompressedSize(available_columns);
columns.push_back(minimum_size_column_name); columns.push_back(minimum_size_column_name);
/// correctly report added column /// correctly report added column
injected_columns.insert(columns.back()); injected_columns.insert(columns.back());

View File

@ -266,10 +266,13 @@ void MergeTreeDataPartWide::doCheckConsistency(bool require_part_metadata) const
bool MergeTreeDataPartWide::hasColumnFiles(const NameAndTypePair & column) const bool MergeTreeDataPartWide::hasColumnFiles(const NameAndTypePair & column) const
{ {
auto serialization = tryGetSerialization(column.name);
if (!serialization)
return false;
auto marks_file_extension = index_granularity_info.mark_type.getFileExtension(); auto marks_file_extension = index_granularity_info.mark_type.getFileExtension();
bool res = true; bool res = true;
getSerialization(column.name)->enumerateStreams([&](const auto & substream_path) serialization->enumerateStreams([&](const auto & substream_path)
{ {
auto stream_name = getStreamNameForColumn(column, substream_path, checksums); auto stream_name = getStreamNameForColumn(column, substream_path, checksums);
if (!stream_name || !checksums.files.contains(*stream_name + marks_file_extension)) if (!stream_name || !checksums.files.contains(*stream_name + marks_file_extension))

View File

@ -233,7 +233,7 @@ static bool isConditionGood(const RPNBuilderTreeNode & condition, const NameSet
else if (type == Field::Types::Float64) else if (type == Field::Types::Float64)
{ {
const auto value = output_value.get<Float64>(); const auto value = output_value.get<Float64>();
return value < threshold || threshold < value; return value < -threshold || threshold < value;
} }
return false; return false;

View File

@ -16,16 +16,6 @@
namespace DB namespace DB
{ {
namespace ErrorCodes
{
extern const int DICTIONARIES_WAS_NOT_LOADED;
extern const int FUNCTION_NOT_ALLOWED;
extern const int NOT_IMPLEMENTED;
extern const int SUPPORT_IS_DISABLED;
extern const int ACCESS_DENIED;
extern const int DEPRECATED_FUNCTION;
};
enum class FunctionOrigin : int8_t enum class FunctionOrigin : int8_t
{ {
SYSTEM = 0, SYSTEM = 0,
@ -40,7 +30,6 @@ namespace
MutableColumns & res_columns, MutableColumns & res_columns,
const String & name, const String & name,
UInt64 is_aggregate, UInt64 is_aggregate,
std::optional<UInt64> is_deterministic,
const String & create_query, const String & create_query,
FunctionOrigin function_origin, FunctionOrigin function_origin,
const Factory & factory) const Factory & factory)
@ -48,58 +37,53 @@ namespace
res_columns[0]->insert(name); res_columns[0]->insert(name);
res_columns[1]->insert(is_aggregate); res_columns[1]->insert(is_aggregate);
if (!is_deterministic.has_value())
res_columns[2]->insertDefault();
else
res_columns[2]->insert(*is_deterministic);
if constexpr (std::is_same_v<Factory, UserDefinedSQLFunctionFactory> || std::is_same_v<Factory, UserDefinedExecutableFunctionFactory>) if constexpr (std::is_same_v<Factory, UserDefinedSQLFunctionFactory> || std::is_same_v<Factory, UserDefinedExecutableFunctionFactory>)
{ {
res_columns[3]->insert(false); res_columns[2]->insert(false);
res_columns[4]->insertDefault(); res_columns[3]->insertDefault();
} }
else else
{ {
res_columns[3]->insert(factory.isCaseInsensitive(name)); res_columns[2]->insert(factory.isCaseInsensitive(name));
if (factory.isAlias(name)) if (factory.isAlias(name))
res_columns[4]->insert(factory.aliasTo(name)); res_columns[3]->insert(factory.aliasTo(name));
else else
res_columns[4]->insertDefault(); res_columns[3]->insertDefault();
} }
res_columns[5]->insert(create_query); res_columns[4]->insert(create_query);
res_columns[6]->insert(static_cast<Int8>(function_origin)); res_columns[5]->insert(static_cast<Int8>(function_origin));
if constexpr (std::is_same_v<Factory, FunctionFactory>) if constexpr (std::is_same_v<Factory, FunctionFactory>)
{ {
if (factory.isAlias(name)) if (factory.isAlias(name))
{ {
res_columns[6]->insertDefault();
res_columns[7]->insertDefault(); res_columns[7]->insertDefault();
res_columns[8]->insertDefault(); res_columns[8]->insertDefault();
res_columns[9]->insertDefault(); res_columns[9]->insertDefault();
res_columns[10]->insertDefault(); res_columns[10]->insertDefault();
res_columns[11]->insertDefault(); res_columns[11]->insertDefault();
res_columns[12]->insertDefault();
} }
else else
{ {
auto documentation = factory.getDocumentation(name); auto documentation = factory.getDocumentation(name);
res_columns[7]->insert(documentation.description); res_columns[6]->insert(documentation.description);
res_columns[8]->insert(documentation.syntax); res_columns[7]->insert(documentation.syntax);
res_columns[9]->insert(documentation.argumentsAsString()); res_columns[8]->insert(documentation.argumentsAsString());
res_columns[10]->insert(documentation.returned_value); res_columns[9]->insert(documentation.returned_value);
res_columns[11]->insert(documentation.examplesAsString()); res_columns[10]->insert(documentation.examplesAsString());
res_columns[12]->insert(documentation.categoriesAsString()); res_columns[11]->insert(documentation.categoriesAsString());
} }
} }
else else
{ {
res_columns[6]->insertDefault();
res_columns[7]->insertDefault(); res_columns[7]->insertDefault();
res_columns[8]->insertDefault(); res_columns[8]->insertDefault();
res_columns[9]->insertDefault(); res_columns[9]->insertDefault();
res_columns[10]->insertDefault(); res_columns[10]->insertDefault();
res_columns[11]->insertDefault(); res_columns[11]->insertDefault();
res_columns[12]->insertDefault();
} }
} }
} }
@ -120,7 +104,6 @@ ColumnsDescription StorageSystemFunctions::getColumnsDescription()
{ {
{"name", std::make_shared<DataTypeString>(), "The name of the function."}, {"name", std::make_shared<DataTypeString>(), "The name of the function."},
{"is_aggregate", std::make_shared<DataTypeUInt8>(), "Whether the function is an aggregate function."}, {"is_aggregate", std::make_shared<DataTypeUInt8>(), "Whether the function is an aggregate function."},
{"is_deterministic", std::make_shared<DataTypeNullable>(std::make_shared<DataTypeUInt8>()), "Whether the function is deterministic."},
{"case_insensitive", std::make_shared<DataTypeUInt8>(), "Whether the function name can be used case-insensitively."}, {"case_insensitive", std::make_shared<DataTypeUInt8>(), "Whether the function name can be used case-insensitively."},
{"alias_to", std::make_shared<DataTypeString>(), "The original function name, if the function name is an alias."}, {"alias_to", std::make_shared<DataTypeString>(), "The original function name, if the function name is an alias."},
{"create_query", std::make_shared<DataTypeString>(), "Obsolete."}, {"create_query", std::make_shared<DataTypeString>(), "Obsolete."},
@ -140,36 +123,14 @@ void StorageSystemFunctions::fillData(MutableColumns & res_columns, ContextPtr c
const auto & function_names = functions_factory.getAllRegisteredNames(); const auto & function_names = functions_factory.getAllRegisteredNames();
for (const auto & function_name : function_names) for (const auto & function_name : function_names)
{ {
std::optional<UInt64> is_deterministic; fillRow(res_columns, function_name, 0, "", FunctionOrigin::SYSTEM, functions_factory);
try
{
DO_NOT_UPDATE_ERROR_STATISTICS();
is_deterministic = functions_factory.tryGet(function_name, context)->isDeterministic();
}
catch (const Exception & e)
{
/// Some functions throw because they need special configuration or setup before use.
if (e.code() == ErrorCodes::DICTIONARIES_WAS_NOT_LOADED
|| e.code() == ErrorCodes::FUNCTION_NOT_ALLOWED
|| e.code() == ErrorCodes::NOT_IMPLEMENTED
|| e.code() == ErrorCodes::SUPPORT_IS_DISABLED
|| e.code() == ErrorCodes::ACCESS_DENIED
|| e.code() == ErrorCodes::DEPRECATED_FUNCTION)
{
/// Ignore exception, show is_deterministic = NULL.
}
else
throw;
}
fillRow(res_columns, function_name, 0, is_deterministic, "", FunctionOrigin::SYSTEM, functions_factory);
} }
const auto & aggregate_functions_factory = AggregateFunctionFactory::instance(); const auto & aggregate_functions_factory = AggregateFunctionFactory::instance();
const auto & aggregate_function_names = aggregate_functions_factory.getAllRegisteredNames(); const auto & aggregate_function_names = aggregate_functions_factory.getAllRegisteredNames();
for (const auto & function_name : aggregate_function_names) for (const auto & function_name : aggregate_function_names)
{ {
fillRow(res_columns, function_name, 1, {1}, "", FunctionOrigin::SYSTEM, aggregate_functions_factory); fillRow(res_columns, function_name, 1, "", FunctionOrigin::SYSTEM, aggregate_functions_factory);
} }
const auto & user_defined_sql_functions_factory = UserDefinedSQLFunctionFactory::instance(); const auto & user_defined_sql_functions_factory = UserDefinedSQLFunctionFactory::instance();
@ -177,14 +138,14 @@ void StorageSystemFunctions::fillData(MutableColumns & res_columns, ContextPtr c
for (const auto & function_name : user_defined_sql_functions_names) for (const auto & function_name : user_defined_sql_functions_names)
{ {
auto create_query = queryToString(user_defined_sql_functions_factory.get(function_name)); auto create_query = queryToString(user_defined_sql_functions_factory.get(function_name));
fillRow(res_columns, function_name, 0, {0}, create_query, FunctionOrigin::SQL_USER_DEFINED, user_defined_sql_functions_factory); fillRow(res_columns, function_name, 0, create_query, FunctionOrigin::SQL_USER_DEFINED, user_defined_sql_functions_factory);
} }
const auto & user_defined_executable_functions_factory = UserDefinedExecutableFunctionFactory::instance(); const auto & user_defined_executable_functions_factory = UserDefinedExecutableFunctionFactory::instance();
const auto & user_defined_executable_functions_names = user_defined_executable_functions_factory.getRegisteredNames(context); /// NOLINT(readability-static-accessed-through-instance) const auto & user_defined_executable_functions_names = user_defined_executable_functions_factory.getRegisteredNames(context); /// NOLINT(readability-static-accessed-through-instance)
for (const auto & function_name : user_defined_executable_functions_names) for (const auto & function_name : user_defined_executable_functions_names)
{ {
fillRow(res_columns, function_name, 0, {0}, "", FunctionOrigin::EXECUTABLE_USER_DEFINED, user_defined_executable_functions_factory); fillRow(res_columns, function_name, 0, "", FunctionOrigin::EXECUTABLE_USER_DEFINED, user_defined_executable_functions_factory);
} }
} }

3
tests/ci/.gitignore vendored
View File

@ -1,4 +1 @@
*_lambda/lambda-venv
*_lambda/lambda-package
*_lambda/lambda-package.zip
gh_cache gh_cache

View File

@ -1,235 +0,0 @@
#!/usr/bin/env python3
"""The lambda to decrease/increase ASG desired capacity based on current queue"""
import logging
from dataclasses import dataclass
from pprint import pformat
from typing import Any, List, Literal, Optional, Tuple
import boto3 # type: ignore
from lambda_shared import (
RUNNER_TYPE_LABELS,
CHException,
ClickHouseHelper,
get_parameter_from_ssm,
)
### Update comment on the change ###
# 4 HOUR - is a balance to get the most precise values
# - Our longest possible running check is around 5h on the worst scenario
# - The long queue won't be wiped out and replaced, so the measurmenet is fine
# - If the data is spoiled by something, we are from the bills perspective
# Changed it to 3 HOUR: in average we have 1h tasks, but p90 is around 2h.
# With 4h we have too much wasted computing time in case of issues with DB
QUEUE_QUERY = f"""SELECT
last_status AS status,
toUInt32(count()) AS length,
labels
FROM
(
SELECT
arraySort(groupArray(status))[-1] AS last_status,
labels,
id,
html_url
FROM default.workflow_jobs
WHERE has(labels, 'self-hosted')
AND hasAny({RUNNER_TYPE_LABELS}, labels)
AND started_at > now() - INTERVAL 3 HOUR
GROUP BY ALL
HAVING last_status IN ('in_progress', 'queued')
)
GROUP BY ALL
ORDER BY labels, last_status"""
@dataclass
class Queue:
status: Literal["in_progress", "queued"]
lentgh: int
label: str
def get_scales(runner_type: str) -> Tuple[int, int]:
"returns the multipliers for scaling down and up ASG by types"
# Scaling down is quicker on the lack of running jobs than scaling up on
# queue
# The ASG should deflate almost instantly
scale_down = 1
# the style checkers have so many noise, so it scales up too quickly
# The 5 was too quick, there are complainings regarding too slow with
# 10. I am trying 7 now.
# 7 still looks a bit slow, so I try 6
# Let's have it the same as the other ASG
#
# All type of style-checkers should be added very quickly to not block the workflows
# UPDATE THE COMMENT ON CHANGES
scale_up = 3
if "style" in runner_type:
scale_up = 1
return scale_down, scale_up
CH_CLIENT = None # type: Optional[ClickHouseHelper]
def set_capacity(
runner_type: str, queues: List[Queue], client: Any, dry_run: bool = True
) -> None:
assert len(queues) in (1, 2)
assert all(q.label == runner_type for q in queues)
as_groups = client.describe_auto_scaling_groups(
Filters=[
{"Name": "tag-key", "Values": ["github:runner-type"]},
{"Name": "tag-value", "Values": [runner_type]},
]
)["AutoScalingGroups"]
assert len(as_groups) == 1
asg = as_groups[0]
running = 0
queued = 0
for q in queues:
if q.status == "in_progress":
running = q.lentgh
continue
if q.status == "queued":
queued = q.lentgh
continue
raise ValueError("Queue status is not in ['in_progress', 'queued']")
# scale_down, scale_up = get_scales(runner_type)
_, scale_up = get_scales(runner_type)
# With lyfecycle hooks some instances are actually free because some of
# them are in 'Terminating:Wait' state
effective_capacity = max(
asg["DesiredCapacity"],
len([ins for ins in asg["Instances"] if ins["HealthStatus"] == "Healthy"]),
)
# How much nodes are free (positive) or need to be added (negative)
capacity_reserve = effective_capacity - running - queued
stop = False
if capacity_reserve <= 0:
# This part is about scaling up
capacity_deficit = -capacity_reserve
# It looks that we are still OK, since no queued jobs exist
stop = stop or queued == 0
# Are we already at the capacity limits
stop = stop or asg["MaxSize"] <= asg["DesiredCapacity"]
# Let's calculate a new desired capacity
# (capacity_deficit + scale_up - 1) // scale_up : will increase min by 1
# if there is any capacity_deficit
new_capacity = (
asg["DesiredCapacity"] + (capacity_deficit + scale_up - 1) // scale_up
)
new_capacity = max(new_capacity, asg["MinSize"])
new_capacity = min(new_capacity, asg["MaxSize"])
# Finally, should the capacity be even changed
stop = stop or asg["DesiredCapacity"] == new_capacity
if stop:
logging.info(
"Do not increase ASG %s capacity, current capacity=%s, effective "
"capacity=%s, maximum capacity=%s, running jobs=%s, queue size=%s",
asg["AutoScalingGroupName"],
asg["DesiredCapacity"],
effective_capacity,
asg["MaxSize"],
running,
queued,
)
return
logging.info(
"The ASG %s capacity will be increased to %s, current capacity=%s, "
"effective capacity=%s, maximum capacity=%s, running jobs=%s, queue size=%s",
asg["AutoScalingGroupName"],
new_capacity,
asg["DesiredCapacity"],
effective_capacity,
asg["MaxSize"],
running,
queued,
)
if not dry_run:
client.set_desired_capacity(
AutoScalingGroupName=asg["AutoScalingGroupName"],
DesiredCapacity=new_capacity,
)
return
# FIXME: try decreasing capacity from runners that finished their jobs and have no job assigned
# IMPORTANT: Runner init script must be of version that supports ASG decrease
# # Now we will calculate if we need to scale down
# stop = stop or asg["DesiredCapacity"] == asg["MinSize"]
# new_capacity = asg["DesiredCapacity"] - (capacity_reserve // scale_down)
# new_capacity = max(new_capacity, asg["MinSize"])
# new_capacity = min(new_capacity, asg["MaxSize"])
# stop = stop or asg["DesiredCapacity"] == new_capacity
# if stop:
# logging.info(
# "Do not decrease ASG %s capacity, current capacity=%s, effective "
# "capacity=%s, minimum capacity=%s, running jobs=%s, queue size=%s",
# asg["AutoScalingGroupName"],
# asg["DesiredCapacity"],
# effective_capacity,
# asg["MinSize"],
# running,
# queued,
# )
# return
#
# logging.info(
# "The ASG %s capacity will be decreased to %s, current capacity=%s, effective "
# "capacity=%s, minimum capacity=%s, running jobs=%s, queue size=%s",
# asg["AutoScalingGroupName"],
# new_capacity,
# asg["DesiredCapacity"],
# effective_capacity,
# asg["MinSize"],
# running,
# queued,
# )
# if not dry_run:
# client.set_desired_capacity(
# AutoScalingGroupName=asg["AutoScalingGroupName"],
# DesiredCapacity=new_capacity,
# )
def main(dry_run: bool = True) -> None:
logging.getLogger().setLevel(logging.INFO)
asg_client = boto3.client("autoscaling")
try:
global CH_CLIENT
CH_CLIENT = CH_CLIENT or ClickHouseHelper(
get_parameter_from_ssm("clickhouse-test-stat-url"), "play"
)
queues = CH_CLIENT.select_json_each_row("default", QUEUE_QUERY)
except CHException as ex:
logging.exception(
"Got an exception on insert, tryuing to update the client "
"credentials and repeat",
exc_info=ex,
)
CH_CLIENT = ClickHouseHelper(
get_parameter_from_ssm("clickhouse-test-stat-url"), "play"
)
queues = CH_CLIENT.select_json_each_row("default", QUEUE_QUERY)
logging.info("Received queue data:\n%s", pformat(queues, width=120))
for runner_type in RUNNER_TYPE_LABELS:
runner_queues = [
Queue(queue["status"], queue["length"], runner_type)
for queue in queues
if runner_type in queue["labels"]
]
runner_queues = runner_queues or [Queue("in_progress", 0, runner_type)]
set_capacity(runner_type, runner_queues, asg_client, dry_run)
def handler(event: dict, context: Any) -> None:
_ = event
_ = context
return main(False)

View File

@ -1 +0,0 @@
../team_keys_lambda/build_and_deploy_archive.sh

View File

@ -1 +0,0 @@
../lambda_shared_package/lambda_shared

View File

@ -1 +0,0 @@
../lambda_shared_package

View File

@ -1,196 +0,0 @@
#!/usr/bin/env python
import unittest
from dataclasses import dataclass
from typing import Any, List
from app import Queue, set_capacity
@dataclass
class TestCase:
name: str
min_size: int
desired_capacity: int
max_size: int
queues: List[Queue]
expected_capacity: int
class TestSetCapacity(unittest.TestCase):
class FakeClient:
def __init__(self):
self._expected_data = {} # type: dict
self._expected_capacity = -1
@property
def expected_data(self) -> dict:
"""a one-time property"""
data, self._expected_data = self._expected_data, {}
return data
@expected_data.setter
def expected_data(self, value: dict) -> None:
self._expected_data = value
@property
def expected_capacity(self) -> int:
"""a one-time property"""
capacity, self._expected_capacity = self._expected_capacity, -1
return capacity
def describe_auto_scaling_groups(self, **kwargs: Any) -> dict:
_ = kwargs
return self.expected_data
def set_desired_capacity(self, **kwargs: Any) -> None:
self._expected_capacity = kwargs["DesiredCapacity"]
def data_helper(
self, name: str, min_size: int, desired_capacity: int, max_size: int
) -> None:
self.expected_data = {
"AutoScalingGroups": [
{
"AutoScalingGroupName": name,
"DesiredCapacity": desired_capacity,
"MinSize": min_size,
"MaxSize": max_size,
"Instances": [], # necessary for ins["HealthStatus"] check
}
]
}
def setUp(self):
self.client = self.FakeClient()
def test_normal_cases(self):
test_cases = (
# Do not change capacity
TestCase("noqueue", 1, 13, 20, [Queue("in_progress", 155, "noqueue")], -1),
TestCase("reserve", 1, 13, 20, [Queue("queued", 13, "reserve")], -1),
# Increase capacity
TestCase(
"increase-always",
1,
13,
20,
[Queue("queued", 14, "increase-always")],
14,
),
TestCase("increase-1", 1, 13, 20, [Queue("queued", 23, "increase-1")], 17),
TestCase(
"style-checker", 1, 13, 20, [Queue("queued", 19, "style-checker")], 19
),
TestCase("increase-2", 1, 13, 20, [Queue("queued", 18, "increase-2")], 15),
TestCase("increase-3", 1, 13, 20, [Queue("queued", 183, "increase-3")], 20),
TestCase(
"increase-w/o reserve",
1,
13,
20,
[
Queue("in_progress", 11, "increase-w/o reserve"),
Queue("queued", 12, "increase-w/o reserve"),
],
17,
),
TestCase("lower-min", 10, 5, 20, [Queue("queued", 5, "lower-min")], 10),
# Decrease capacity
# FIXME: Tests changed for lambda that can only scale up
# TestCase("w/reserve", 1, 13, 20, [Queue("queued", 5, "w/reserve")], 5),
TestCase("w/reserve", 1, 13, 20, [Queue("queued", 5, "w/reserve")], -1),
# TestCase(
# "style-checker", 1, 13, 20, [Queue("queued", 5, "style-checker")], 5
# ),
TestCase(
"style-checker", 1, 13, 20, [Queue("queued", 5, "style-checker")], -1
),
# TestCase("w/reserve", 1, 23, 20, [Queue("queued", 17, "w/reserve")], 17),
TestCase("w/reserve", 1, 23, 20, [Queue("queued", 17, "w/reserve")], -1),
# TestCase("decrease", 1, 13, 20, [Queue("in_progress", 3, "decrease")], 3),
TestCase("decrease", 1, 13, 20, [Queue("in_progress", 3, "decrease")], -1),
# TestCase(
# "style-checker",
# 1,
# 13,
# 20,
# [Queue("in_progress", 5, "style-checker")],
# 5,
# ),
TestCase(
"style-checker",
1,
13,
20,
[Queue("in_progress", 5, "style-checker")],
-1,
),
)
for t in test_cases:
self.client.data_helper(t.name, t.min_size, t.desired_capacity, t.max_size)
set_capacity(t.name, t.queues, self.client, False)
self.assertEqual(t.expected_capacity, self.client.expected_capacity, t.name)
def test_effective_capacity(self):
"""Normal cases test increasing w/o considering
effective_capacity much lower than DesiredCapacity"""
test_cases = (
TestCase(
"desired-overwritten",
1,
20, # DesiredCapacity, overwritten by effective_capacity
50,
[
Queue("in_progress", 30, "desired-overwritten"),
Queue("queued", 60, "desired-overwritten"),
],
40,
),
)
for t in test_cases:
self.client.data_helper(t.name, t.min_size, t.desired_capacity, t.max_size)
# we test that effective_capacity is 30 (a half of 60)
data_with_instances = self.client.expected_data
data_with_instances["AutoScalingGroups"][0]["Instances"] = [
{"HealthStatus": "Healthy" if i % 2 else "Unhealthy"} for i in range(60)
]
self.client.expected_data = data_with_instances
set_capacity(t.name, t.queues, self.client, False)
self.assertEqual(t.expected_capacity, self.client.expected_capacity, t.name)
def test_exceptions(self):
test_cases = (
(
TestCase(
"different names",
1,
1,
1,
[Queue("queued", 5, "another name")],
-1,
),
AssertionError,
),
(TestCase("wrong queue len", 1, 1, 1, [], -1), AssertionError),
(
TestCase(
"wrong queue", 1, 1, 1, [Queue("wrong", 1, "wrong queue")], -1 # type: ignore
),
ValueError,
),
)
for t, error in test_cases:
with self.assertRaises(error):
self.client.data_helper(
t.name, t.min_size, t.desired_capacity, t.max_size
)
set_capacity(t.name, t.queues, self.client, False)
with self.assertRaises(AssertionError):
self.client.expected_data = {"AutoScalingGroups": [1, 2]}
set_capacity(
"wrong number of ASGs",
[Queue("queued", 1, "wrong number of ASGs")],
self.client,
)

View File

@ -12,7 +12,6 @@ import docker_images_helper
from ci_config import CI from ci_config import CI
from env_helper import REPO_COPY, S3_BUILDS_BUCKET, TEMP_PATH from env_helper import REPO_COPY, S3_BUILDS_BUCKET, TEMP_PATH
from git_helper import Git from git_helper import Git
from lambda_shared_package.lambda_shared.pr import Labels
from pr_info import PRInfo from pr_info import PRInfo
from report import FAILURE, SUCCESS, JobReport, StatusType from report import FAILURE, SUCCESS, JobReport, StatusType
from stopwatch import Stopwatch from stopwatch import Stopwatch
@ -108,7 +107,9 @@ def build_clickhouse(
def is_release_pr(pr_info: PRInfo) -> bool: def is_release_pr(pr_info: PRInfo) -> bool:
return Labels.RELEASE in pr_info.labels or Labels.RELEASE_LTS in pr_info.labels return (
CI.Labels.RELEASE in pr_info.labels or CI.Labels.RELEASE_LTS in pr_info.labels
)
def get_release_or_pr(pr_info: PRInfo, version: ClickHouseVersion) -> Tuple[str, str]: def get_release_or_pr(pr_info: PRInfo, version: ClickHouseVersion) -> Tuple[str, str]:

View File

@ -1,376 +0,0 @@
#!/usr/bin/env python3
import json
import time
from base64 import b64decode
from collections import namedtuple
from queue import Queue
from threading import Thread
from typing import Any, Dict, List, Optional
import requests
from lambda_shared.pr import Labels
from lambda_shared.token import get_cached_access_token
NEED_RERUN_OR_CANCELL_WORKFLOWS = {
"BackportPR",
"DocsCheck",
"MasterCI",
"PullRequestCI",
}
MAX_RETRY = 5
DEBUG_INFO = {} # type: Dict[str, Any]
class Worker(Thread):
def __init__(
self, request_queue: Queue, token: str, ignore_exception: bool = False
):
Thread.__init__(self)
self.queue = request_queue
self.token = token
self.ignore_exception = ignore_exception
self.response = {} # type: Dict
def run(self):
m = self.queue.get()
try:
self.response = _exec_get_with_retry(m, self.token)
except Exception as e:
if not self.ignore_exception:
raise
print(f"Exception occured, still continue: {e}")
self.queue.task_done()
def _exec_get_with_retry(url: str, token: str) -> dict:
headers = {"Authorization": f"token {token}"}
e = Exception()
for i in range(MAX_RETRY):
try:
response = requests.get(url, headers=headers, timeout=30)
response.raise_for_status()
return response.json() # type: ignore
except Exception as ex:
print("Got exception executing request", ex)
e = ex
time.sleep(i + 1)
raise requests.HTTPError("Cannot execute GET request with retries") from e
WorkflowDescription = namedtuple(
"WorkflowDescription",
[
"url",
"run_id",
"name",
"head_sha",
"status",
"rerun_url",
"cancel_url",
"conclusion",
],
)
def get_workflows_description_for_pull_request(
pull_request_event: dict, token: str
) -> List[WorkflowDescription]:
head_repo = pull_request_event["head"]["repo"]["full_name"]
head_branch = pull_request_event["head"]["ref"]
print("PR", pull_request_event["number"], "has head ref", head_branch)
workflows_data = []
repo_url = pull_request_event["base"]["repo"]["url"]
request_url = f"{repo_url}/actions/runs?per_page=100"
# Get all workflows for the current branch
for i in range(1, 11):
workflows = _exec_get_with_retry(
f"{request_url}&event=pull_request&branch={head_branch}&page={i}", token
)
if not workflows["workflow_runs"]:
break
workflows_data += workflows["workflow_runs"]
if i == 10:
print("Too many workflows found")
if not workflows_data:
print("No workflows found by filter")
return []
print(f"Total workflows for the branch {head_branch} found: {len(workflows_data)}")
DEBUG_INFO["workflows"] = []
workflow_descriptions = []
for workflow in workflows_data:
# Some time workflow["head_repository"]["full_name"] is None
if workflow["head_repository"] is None:
continue
DEBUG_INFO["workflows"].append(
{
"full_name": workflow["head_repository"]["full_name"],
"name": workflow["name"],
"branch": workflow["head_branch"],
}
)
# unfortunately we cannot filter workflows from forks in request to API
# so doing it manually
if (
workflow["head_repository"]["full_name"] == head_repo
and workflow["name"] in NEED_RERUN_OR_CANCELL_WORKFLOWS
):
workflow_descriptions.append(
WorkflowDescription(
url=workflow["url"],
run_id=workflow["id"],
name=workflow["name"],
head_sha=workflow["head_sha"],
status=workflow["status"],
rerun_url=workflow["rerun_url"],
cancel_url=workflow["cancel_url"],
conclusion=workflow["conclusion"],
)
)
return workflow_descriptions
def get_workflow_description_fallback(
pull_request_event: dict, token: str
) -> List[WorkflowDescription]:
head_repo = pull_request_event["head"]["repo"]["full_name"]
head_branch = pull_request_event["head"]["ref"]
print("Get last 500 workflows from API to search related there")
# Fallback for a case of an already deleted branch and no workflows received
repo_url = pull_request_event["base"]["repo"]["url"]
request_url = f"{repo_url}/actions/runs?per_page=100"
q = Queue() # type: Queue
workers = []
workflows_data = []
i = 1
for i in range(1, 6):
q.put(f"{request_url}&page={i}")
worker = Worker(q, token, True)
worker.start()
workers.append(worker)
for worker in workers:
worker.join()
if not worker.response:
# We ignore get errors, so response can be empty
continue
# Prefilter workflows
workflows_data += [
wf
for wf in worker.response["workflow_runs"]
if wf["head_repository"] is not None
and wf["head_repository"]["full_name"] == head_repo
and wf["head_branch"] == head_branch
and wf["name"] in NEED_RERUN_OR_CANCELL_WORKFLOWS
]
print(f"Total workflows in last 500 actions matches: {len(workflows_data)}")
DEBUG_INFO["workflows"] = [
{
"full_name": wf["head_repository"]["full_name"],
"name": wf["name"],
"branch": wf["head_branch"],
}
for wf in workflows_data
]
workflow_descriptions = [
WorkflowDescription(
url=wf["url"],
run_id=wf["id"],
name=wf["name"],
head_sha=wf["head_sha"],
status=wf["status"],
rerun_url=wf["rerun_url"],
cancel_url=wf["cancel_url"],
conclusion=wf["conclusion"],
)
for wf in workflows_data
]
return workflow_descriptions
def get_workflow_description(workflow_url: str, token: str) -> WorkflowDescription:
workflow = _exec_get_with_retry(workflow_url, token)
return WorkflowDescription(
url=workflow["url"],
run_id=workflow["id"],
name=workflow["name"],
head_sha=workflow["head_sha"],
status=workflow["status"],
rerun_url=workflow["rerun_url"],
cancel_url=workflow["cancel_url"],
conclusion=workflow["conclusion"],
)
def _exec_post_with_retry(url: str, token: str, json: Optional[Any] = None) -> Any:
headers = {"Authorization": f"token {token}"}
e = Exception()
for i in range(MAX_RETRY):
try:
response = requests.post(url, headers=headers, json=json, timeout=30)
response.raise_for_status()
return response.json()
except Exception as ex:
print("Got exception executing request", ex)
e = ex
time.sleep(i + 1)
raise requests.HTTPError("Cannot execute POST request with retry") from e
def exec_workflow_url(urls_to_post, token):
for url in urls_to_post:
print("Post for workflow workflow using url", url)
_exec_post_with_retry(url, token)
print("Workflow post finished")
def main(event):
token = get_cached_access_token()
DEBUG_INFO["event"] = event
if event["isBase64Encoded"]:
event_data = json.loads(b64decode(event["body"]))
else:
event_data = json.loads(event["body"])
print("Got event for PR", event_data["number"])
action = event_data["action"]
print("Got action", event_data["action"])
pull_request = event_data["pull_request"]
label = ""
if action == "labeled":
label = event_data["label"]["name"]
print("Added label:", label)
print("PR has labels", {label["name"] for label in pull_request["labels"]})
if action == "opened" or (
action == "labeled" and pull_request["created_at"] == pull_request["updated_at"]
):
print("Freshly opened PR, nothing to do")
return
if action == "closed" or label == Labels.DO_NOT_TEST:
print("PR merged/closed or manually labeled 'do not test', will kill workflows")
workflow_descriptions = get_workflows_description_for_pull_request(
pull_request, token
)
workflow_descriptions = (
workflow_descriptions
or get_workflow_description_fallback(pull_request, token)
)
urls_to_cancel = []
for workflow_description in workflow_descriptions:
if (
workflow_description.status != "completed"
and workflow_description.conclusion != "cancelled"
):
urls_to_cancel.append(workflow_description.cancel_url)
print(f"Found {len(urls_to_cancel)} workflows to cancel")
exec_workflow_url(urls_to_cancel, token)
return
if label == Labels.CAN_BE_TESTED:
print("PR marked with can be tested label, rerun workflow")
workflow_descriptions = get_workflows_description_for_pull_request(
pull_request, token
)
workflow_descriptions = (
workflow_descriptions
or get_workflow_description_fallback(pull_request, token)
)
if not workflow_descriptions:
print("Not found any workflows")
return
workflow_descriptions.sort(key=lambda x: x.run_id) # type: ignore
most_recent_workflow = workflow_descriptions[-1]
print("Latest workflow", most_recent_workflow)
if (
most_recent_workflow.status != "completed"
and most_recent_workflow.conclusion != "cancelled"
):
print("Latest workflow is not completed, cancelling")
exec_workflow_url([most_recent_workflow.cancel_url], token)
print("Cancelled")
for _ in range(45):
# If the number of retries is changed: tune the lambda limits accordingly
latest_workflow_desc = get_workflow_description(
most_recent_workflow.url, token
)
print("Checking latest workflow", latest_workflow_desc)
if latest_workflow_desc.status in ("completed", "cancelled"):
print("Finally latest workflow done, going to rerun")
exec_workflow_url([most_recent_workflow.rerun_url], token)
print("Rerun finished, exiting")
break
print("Still have strange status")
time.sleep(3)
return
if action == "edited":
print("PR is edited - do nothing")
# error, _ = check_pr_description(
# pull_request["body"], pull_request["base"]["repo"]["full_name"]
# )
# if error:
# print(
# f"The PR's body is wrong, is going to comment it. The error is: {error}"
# )
# post_json = {
# "body": "This is an automatic comment. The PR descriptions does not "
# f"match the [template]({pull_request['base']['repo']['html_url']}/"
# "blob/master/.github/PULL_REQUEST_TEMPLATE.md?plain=1).\n\n"
# f"Please, edit it accordingly.\n\nThe error is: {error}"
# }
# _exec_post_with_retry(pull_request["comments_url"], token, json=post_json)
return
if action == "synchronize":
print("PR is synchronized, going to stop old actions")
workflow_descriptions = get_workflows_description_for_pull_request(
pull_request, token
)
workflow_descriptions = (
workflow_descriptions
or get_workflow_description_fallback(pull_request, token)
)
urls_to_cancel = []
for workflow_description in workflow_descriptions:
if (
workflow_description.status != "completed"
and workflow_description.conclusion != "cancelled"
and workflow_description.head_sha != pull_request["head"]["sha"]
):
urls_to_cancel.append(workflow_description.cancel_url)
print(f"Found {len(urls_to_cancel)} workflows to cancel")
exec_workflow_url(urls_to_cancel, token)
return
print("Nothing to do")
def handler(event, _):
try:
main(event)
return {
"statusCode": 200,
"headers": {"Content-Type": "application/json"},
"body": '{"status": "OK"}',
}
finally:
for name, value in DEBUG_INFO.items():
print(f"Value of {name}: ", value)

View File

@ -1 +0,0 @@
../team_keys_lambda/build_and_deploy_archive.sh

View File

@ -1 +0,0 @@
../lambda_shared_package/lambda_shared

View File

@ -1 +0,0 @@
../lambda_shared_package[token]

View File

@ -38,7 +38,7 @@ from env_helper import TEMP_PATH
from get_robot_token import get_best_robot_token from get_robot_token import get_best_robot_token
from git_helper import GIT_PREFIX, git_runner, is_shallow from git_helper import GIT_PREFIX, git_runner, is_shallow
from github_helper import GitHub, PullRequest, PullRequests, Repository from github_helper import GitHub, PullRequest, PullRequests, Repository
from lambda_shared_package.lambda_shared.pr import Labels from ci_config import Labels
from ssh import SSHKey from ssh import SSHKey

View File

@ -32,6 +32,9 @@ class CI:
from ci_definitions import MQ_JOBS as MQ_JOBS from ci_definitions import MQ_JOBS as MQ_JOBS
from ci_definitions import WorkflowStages as WorkflowStages from ci_definitions import WorkflowStages as WorkflowStages
from ci_definitions import Runners as Runners from ci_definitions import Runners as Runners
from ci_definitions import Labels as Labels
from ci_definitions import TRUSTED_CONTRIBUTORS as TRUSTED_CONTRIBUTORS
from ci_utils import CATEGORY_TO_LABEL as CATEGORY_TO_LABEL
# Jobs that run for doc related updates # Jobs that run for doc related updates
_DOCS_CHECK_JOBS = [JobNames.DOCS_CHECK, JobNames.STYLE_CHECK] _DOCS_CHECK_JOBS = [JobNames.DOCS_CHECK, JobNames.STYLE_CHECK]
@ -45,24 +48,14 @@ class CI:
JobNames.INTEGRATION_TEST_ARM, JobNames.INTEGRATION_TEST_ARM,
] ]
), ),
Tags.CI_SET_REQUIRED: LabelConfig(run_jobs=REQUIRED_CHECKS), Tags.CI_SET_REQUIRED: LabelConfig(
run_jobs=REQUIRED_CHECKS
+ [build for build in BuildNames if build != BuildNames.FUZZERS]
),
Tags.CI_SET_BUILDS: LabelConfig( Tags.CI_SET_BUILDS: LabelConfig(
run_jobs=[JobNames.STYLE_CHECK, JobNames.BUILD_CHECK] run_jobs=[JobNames.STYLE_CHECK, JobNames.BUILD_CHECK]
+ [build for build in BuildNames if build != BuildNames.FUZZERS] + [build for build in BuildNames if build != BuildNames.FUZZERS]
), ),
Tags.CI_SET_NON_REQUIRED: LabelConfig(
run_jobs=[job for job in JobNames if job not in REQUIRED_CHECKS]
),
Tags.CI_SET_OLD_ANALYZER: LabelConfig(
run_jobs=[
JobNames.STYLE_CHECK,
JobNames.FAST_TEST,
BuildNames.PACKAGE_RELEASE,
BuildNames.PACKAGE_ASAN,
JobNames.STATELESS_TEST_OLD_ANALYZER_S3_REPLICATED_RELEASE,
JobNames.INTEGRATION_TEST_ASAN_OLD_ANALYZER,
]
),
Tags.CI_SET_SYNC: LabelConfig( Tags.CI_SET_SYNC: LabelConfig(
run_jobs=[ run_jobs=[
BuildNames.PACKAGE_ASAN, BuildNames.PACKAGE_ASAN,

View File

@ -7,6 +7,53 @@ from ci_utils import WithIter
from integration_test_images import IMAGES from integration_test_images import IMAGES
class Labels:
PR_BUGFIX = "pr-bugfix"
PR_CRITICAL_BUGFIX = "pr-critical-bugfix"
CAN_BE_TESTED = "can be tested"
DO_NOT_TEST = "do not test"
MUST_BACKPORT = "pr-must-backport"
MUST_BACKPORT_CLOUD = "pr-must-backport-cloud"
JEPSEN_TEST = "jepsen-test"
SKIP_MERGEABLE_CHECK = "skip mergeable check"
PR_BACKPORT = "pr-backport"
PR_BACKPORTS_CREATED = "pr-backports-created"
PR_BACKPORTS_CREATED_CLOUD = "pr-backports-created-cloud"
PR_CHERRYPICK = "pr-cherrypick"
PR_CI = "pr-ci"
PR_FEATURE = "pr-feature"
PR_SYNCED_TO_CLOUD = "pr-synced-to-cloud"
PR_SYNC_UPSTREAM = "pr-sync-upstream"
RELEASE = "release"
RELEASE_LTS = "release-lts"
SUBMODULE_CHANGED = "submodule changed"
# automatic backport for critical bug fixes
AUTO_BACKPORT = {"pr-critical-bugfix"}
TRUSTED_CONTRIBUTORS = {
e.lower()
for e in [
"amosbird",
"azat", # SEMRush
"bharatnc", # Many contributions.
"cwurm", # ClickHouse, Inc
"den-crane", # Documentation contributor
"ildus", # adjust, ex-pgpro
"nvartolomei", # Seasoned contributor, CloudFlare
"taiyang-li",
"ucasFL", # Amos Bird's friend
"thomoco", # ClickHouse, Inc
"tonickkozlov", # Cloudflare
"tylerhannan", # ClickHouse, Inc
"tsolodov", # ClickHouse, Inc
"justindeguzman", # ClickHouse, Inc
"XuJia0210", # ClickHouse, Inc
]
}
class WorkflowStages(metaclass=WithIter): class WorkflowStages(metaclass=WithIter):
""" """
Stages of GitHUb actions workflow Stages of GitHUb actions workflow
@ -55,8 +102,6 @@ class Tags(metaclass=WithIter):
CI_SET_ARM = "ci_set_arm" CI_SET_ARM = "ci_set_arm"
CI_SET_REQUIRED = "ci_set_required" CI_SET_REQUIRED = "ci_set_required"
CI_SET_BUILDS = "ci_set_builds" CI_SET_BUILDS = "ci_set_builds"
CI_SET_NON_REQUIRED = "ci_set_non_required"
CI_SET_OLD_ANALYZER = "ci_set_old_analyzer"
libFuzzer = "libFuzzer" libFuzzer = "libFuzzer"

View File

@ -1,164 +0,0 @@
#!/usr/bin/env python3
"""
Lambda function to:
- calculate number of running runners
- cleaning dead runners from GitHub
- terminating stale lost runners in EC2
"""
import argparse
import sys
from typing import Dict
import boto3 # type: ignore
from lambda_shared import RUNNER_TYPE_LABELS, RunnerDescriptions, list_runners
from lambda_shared.token import (
get_access_token_by_key_app,
get_cached_access_token,
get_key_and_app_from_aws,
)
UNIVERSAL_LABEL = "universal"
def handler(event, context):
_ = event
_ = context
main(get_cached_access_token(), True)
def group_runners_by_tag(
listed_runners: RunnerDescriptions,
) -> Dict[str, RunnerDescriptions]:
result = {} # type: Dict[str, RunnerDescriptions]
def add_to_result(tag, runner):
if tag not in result:
result[tag] = []
result[tag].append(runner)
for runner in listed_runners:
if UNIVERSAL_LABEL in runner.tags:
# Do not proceed other labels if UNIVERSAL_LABEL is included
add_to_result(UNIVERSAL_LABEL, runner)
continue
for tag in runner.tags:
if tag in RUNNER_TYPE_LABELS:
add_to_result(tag, runner)
break
else:
add_to_result("unlabeled", runner)
return result
def push_metrics_to_cloudwatch(
listed_runners: RunnerDescriptions, group_name: str
) -> None:
client = boto3.client("cloudwatch")
namespace = "RunnersMetrics"
metrics_data = []
busy_runners = sum(
1 for runner in listed_runners if runner.busy and not runner.offline
)
dimensions = [{"Name": "group", "Value": group_name}]
metrics_data.append(
{
"MetricName": "BusyRunners",
"Value": busy_runners,
"Unit": "Count",
"Dimensions": dimensions,
}
)
total_active_runners = sum(1 for runner in listed_runners if not runner.offline)
metrics_data.append(
{
"MetricName": "ActiveRunners",
"Value": total_active_runners,
"Unit": "Count",
"Dimensions": dimensions,
}
)
total_runners = len(listed_runners)
metrics_data.append(
{
"MetricName": "TotalRunners",
"Value": total_runners,
"Unit": "Count",
"Dimensions": dimensions,
}
)
if total_active_runners == 0:
busy_ratio = 100.0
else:
busy_ratio = busy_runners / total_active_runners * 100
metrics_data.append(
{
"MetricName": "BusyRunnersRatio",
"Value": busy_ratio,
"Unit": "Percent",
"Dimensions": dimensions,
}
)
client.put_metric_data(Namespace=namespace, MetricData=metrics_data)
def main(
access_token: str,
push_to_cloudwatch: bool,
) -> None:
gh_runners = list_runners(access_token)
grouped_runners = group_runners_by_tag(gh_runners)
for group, group_runners in grouped_runners.items():
if push_to_cloudwatch:
print(f"Pushing metrics for group '{group}'")
push_metrics_to_cloudwatch(group_runners, group)
else:
print(group, f"({len(group_runners)})")
for runner in group_runners:
print("\t", runner)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Get list of runners and their states")
parser.add_argument(
"-p", "--private-key-path", help="Path to file with private key"
)
parser.add_argument("-k", "--private-key", help="Private key")
parser.add_argument(
"-a", "--app-id", type=int, help="GitHub application ID", required=True
)
parser.add_argument(
"--push-to-cloudwatch",
action="store_true",
help="Push metrics for active and busy runners to cloudwatch",
)
args = parser.parse_args()
if not args.private_key_path and not args.private_key:
print(
"Either --private-key-path or --private-key must be specified",
file=sys.stderr,
)
if args.private_key_path and args.private_key:
print(
"Either --private-key-path or --private-key must be specified",
file=sys.stderr,
)
if args.private_key:
private_key = args.private_key
elif args.private_key_path:
with open(args.private_key_path, "r", encoding="utf-8") as key_file:
private_key = key_file.read()
else:
print("Attempt to get key and id from AWS secret manager")
private_key, args.app_id = get_key_and_app_from_aws()
token = get_access_token_by_key_app(private_key, args.app_id)
main(token, args.push_to_cloudwatch)

View File

@ -1 +0,0 @@
../team_keys_lambda/build_and_deploy_archive.sh

View File

@ -1 +0,0 @@
../lambda_shared_package/lambda_shared

View File

@ -1,2 +0,0 @@
../lambda_shared_package
../lambda_shared_package[token]

View File

@ -160,11 +160,8 @@ class CiSettings:
else: else:
return False return False
if CI.is_build_job(job): # do not exclude builds
print(f"Build job [{job}] - always run") if self.exclude_keywords and not CI.is_build_job(job):
return True
if self.exclude_keywords:
for keyword in self.exclude_keywords: for keyword in self.exclude_keywords:
if keyword in normalize_string(job): if keyword in normalize_string(job):
print(f"Job [{job}] matches Exclude keyword [{keyword}] - deny") print(f"Job [{job}] matches Exclude keyword [{keyword}] - deny")
@ -172,7 +169,8 @@ class CiSettings:
to_deny = False to_deny = False
if self.include_keywords: if self.include_keywords:
if job == CI.JobNames.STYLE_CHECK: # do not exclude builds
if job == CI.JobNames.STYLE_CHECK or CI.is_build_job(job):
# never exclude Style Check by include keywords # never exclude Style Check by include keywords
return True return True
for keyword in self.include_keywords: for keyword in self.include_keywords:

View File

@ -3,7 +3,42 @@ import re
import subprocess import subprocess
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Any, Iterator, List, Union, Optional from typing import Any, Iterator, List, Union, Optional, Tuple
LABEL_CATEGORIES = {
"pr-backward-incompatible": ["Backward Incompatible Change"],
"pr-bugfix": [
"Bug Fix",
"Bug Fix (user-visible misbehavior in an official stable release)",
"Bug Fix (user-visible misbehaviour in official stable or prestable release)",
"Bug Fix (user-visible misbehavior in official stable or prestable release)",
],
"pr-critical-bugfix": ["Critical Bug Fix (crash, LOGICAL_ERROR, data loss, RBAC)"],
"pr-build": [
"Build/Testing/Packaging Improvement",
"Build Improvement",
"Build/Testing Improvement",
"Build",
"Packaging Improvement",
],
"pr-documentation": [
"Documentation (changelog entry is not required)",
"Documentation",
],
"pr-feature": ["New Feature"],
"pr-improvement": ["Improvement"],
"pr-not-for-changelog": [
"Not for changelog (changelog entry is not required)",
"Not for changelog",
],
"pr-performance": ["Performance Improvement"],
"pr-ci": ["CI Fix or Improvement (changelog entry is not required)"],
}
CATEGORY_TO_LABEL = {
c: lb for lb, categories in LABEL_CATEGORIES.items() for c in categories
}
class WithIter(type): class WithIter(type):
@ -109,3 +144,81 @@ class Utils:
@staticmethod @staticmethod
def clear_dmesg(): def clear_dmesg():
Shell.run("sudo dmesg --clear ||:") Shell.run("sudo dmesg --clear ||:")
@staticmethod
def check_pr_description(pr_body: str, repo_name: str) -> Tuple[str, str]:
"""The function checks the body to being properly formatted according to
.github/PULL_REQUEST_TEMPLATE.md, if the first returned string is not empty,
then there is an error."""
lines = list(map(lambda x: x.strip(), pr_body.split("\n") if pr_body else []))
lines = [re.sub(r"\s+", " ", line) for line in lines]
# Check if body contains "Reverts ClickHouse/ClickHouse#36337"
if [
True for line in lines if re.match(rf"\AReverts {repo_name}#[\d]+\Z", line)
]:
return "", LABEL_CATEGORIES["pr-not-for-changelog"][0]
category = ""
entry = ""
description_error = ""
i = 0
while i < len(lines):
if re.match(r"(?i)^[#>*_ ]*change\s*log\s*category", lines[i]):
i += 1
if i >= len(lines):
break
# Can have one empty line between header and the category
# itself. Filter it out.
if not lines[i]:
i += 1
if i >= len(lines):
break
category = re.sub(r"^[-*\s]*", "", lines[i])
i += 1
# Should not have more than one category. Require empty line
# after the first found category.
if i >= len(lines):
break
if lines[i]:
second_category = re.sub(r"^[-*\s]*", "", lines[i])
description_error = (
"More than one changelog category specified: "
f"'{category}', '{second_category}'"
)
return description_error, category
elif re.match(
r"(?i)^[#>*_ ]*(short\s*description|change\s*log\s*entry)", lines[i]
):
i += 1
# Can have one empty line between header and the entry itself.
# Filter it out.
if i < len(lines) and not lines[i]:
i += 1
# All following lines until empty one are the changelog entry.
entry_lines = []
while i < len(lines) and lines[i]:
entry_lines.append(lines[i])
i += 1
entry = " ".join(entry_lines)
# Don't accept changelog entries like '...'.
entry = re.sub(r"[#>*_.\- ]", "", entry)
# Don't accept changelog entries like 'Close #12345'.
entry = re.sub(r"^[\w\-\s]{0,10}#?\d{5,6}\.?$", "", entry)
else:
i += 1
if not category:
description_error = "Changelog category is empty"
# Filter out the PR categories that are not for changelog.
elif "(changelog entry is not required)" in category:
pass # to not check the rest of the conditions
elif category not in CATEGORY_TO_LABEL:
description_error, category = f"Category '{category}' is not valid", ""
elif not entry:
description_error = f"Changelog entry required for category '{category}'"
return description_error, category

View File

@ -1,336 +0,0 @@
#!/usr/bin/env python3
"""
Lambda function to:
- calculate number of running runners
- cleaning dead runners from GitHub
- terminating stale lost runners in EC2
"""
import argparse
import sys
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List
import boto3 # type: ignore
import requests
from botocore.exceptions import ClientError # type: ignore
from lambda_shared import (
RUNNER_TYPE_LABELS,
RunnerDescription,
RunnerDescriptions,
list_runners,
)
from lambda_shared.token import (
get_access_token_by_key_app,
get_cached_access_token,
get_key_and_app_from_aws,
)
UNIVERSAL_LABEL = "universal"
@dataclass
class LostInstance:
counter: int
seen: datetime
def set_offline(self) -> None:
now = datetime.now()
if now.timestamp() <= self.seen.timestamp() + 120:
# the instance is offline for more than 2 minutes, so we increase
# the counter
self.counter += 1
else:
self.counter = 1
self.seen = now
@property
def recently_offline(self) -> bool:
"""Returns True if the instance has been seen less than 5 minutes ago"""
return datetime.now().timestamp() <= self.seen.timestamp() + 300
@property
def stable_offline(self) -> bool:
return self.counter >= 3
LOST_INSTANCES = {} # type: Dict["str", LostInstance]
def get_dead_runners_in_ec2(runners: RunnerDescriptions) -> RunnerDescriptions:
"""Returns instances that are offline/dead in EC2, or not found in EC2"""
ids = {
runner.name: runner
for runner in runners
# Only `i-deadbead123` are valid names for an instance ID
if runner.name.startswith("i-") and runner.offline and not runner.busy
}
if not ids:
return []
# Delete all offline runners with wrong name
result_to_delete = [
runner
for runner in runners
if not ids.get(runner.name) and runner.offline and not runner.busy
]
client = boto3.client("ec2")
i = 0
inc = 100
print("Checking ids: ", " ".join(ids.keys()))
instances_statuses = []
while i < len(ids.keys()):
try:
instances_statuses.append(
client.describe_instance_status(
InstanceIds=list(ids.keys())[i : i + inc]
)
)
# It applied only if all ids exist in EC2
i += inc
except ClientError as e:
# The list of non-existent instances is in the message:
# The instance IDs 'i-069b1c256c06cf4e3, i-0f26430432b044035,
# i-0faa2ff44edbc147e, i-0eccf2514585045ec, i-0ee4ee53e0daa7d4a,
# i-07928f15acd473bad, i-0eaddda81298f9a85' do not exist
message = e.response["Error"]["Message"]
if message.startswith("The instance IDs '") and message.endswith(
"' do not exist"
):
non_existent = message[18:-14].split(", ")
for n in non_existent:
result_to_delete.append(ids.pop(n))
else:
raise
found_instances = set([])
print("Response", instances_statuses)
for instances_status in instances_statuses:
for instance_status in instances_status["InstanceStatuses"]:
if instance_status["InstanceState"]["Name"] in ("pending", "running"):
found_instances.add(instance_status["InstanceId"])
print("Found instances", found_instances)
for runner in result_to_delete:
print("Instance", runner.name, "is not alive, going to remove it")
for instance_id, runner in ids.items():
if instance_id not in found_instances:
print("Instance", instance_id, "is not found in EC2, going to remove it")
result_to_delete.append(runner)
return result_to_delete
def handler(event, context):
_ = event
_ = context
main(get_cached_access_token(), True)
def delete_runner(access_token: str, runner: RunnerDescription) -> bool:
headers = {
"Authorization": f"token {access_token}",
"Accept": "application/vnd.github.v3+json",
}
response = requests.delete(
f"https://api.github.com/orgs/ClickHouse/actions/runners/{runner.id}",
headers=headers,
timeout=30,
)
response.raise_for_status()
print(f"Response code deleting {runner.name} is {response.status_code}")
return bool(response.status_code == 204)
def get_lost_ec2_instances(runners: RunnerDescriptions) -> List[str]:
global LOST_INSTANCES
now = datetime.now()
client = boto3.client("ec2")
reservations = client.describe_instances(
Filters=[
{"Name": "tag-key", "Values": ["github:runner-type"]},
{"Name": "instance-state-name", "Values": ["pending", "running"]},
],
)["Reservations"]
# flatten the reservation into instances
instances = [
instance
for reservation in reservations
for instance in reservation["Instances"]
]
offline_runner_names = {
runner.name for runner in runners if runner.offline and not runner.busy
}
runner_names = {runner.name for runner in runners}
def offline_instance(iid: str) -> None:
if iid in LOST_INSTANCES:
LOST_INSTANCES[iid].set_offline()
return
LOST_INSTANCES[iid] = LostInstance(1, now)
for instance in instances:
# Do not consider instances started 20 minutes ago as problematic
if now.timestamp() - instance["LaunchTime"].timestamp() < 1200:
continue
runner_type = [
tag["Value"]
for tag in instance["Tags"]
if tag["Key"] == "github:runner-type"
][0]
# If there's no necessary labels in runner type it's fine
if not (UNIVERSAL_LABEL in runner_type or runner_type in RUNNER_TYPE_LABELS):
continue
if instance["InstanceId"] in offline_runner_names:
offline_instance(instance["InstanceId"])
continue
if (
instance["State"]["Name"] == "running"
and not instance["InstanceId"] in runner_names
):
offline_instance(instance["InstanceId"])
instance_ids = [instance["InstanceId"] for instance in instances]
# clean out long unseen instances
LOST_INSTANCES = {
instance_id: stats
for instance_id, stats in LOST_INSTANCES.items()
if stats.recently_offline and instance_id in instance_ids
}
print("The remained LOST_INSTANCES: ", LOST_INSTANCES)
return [
instance_id
for instance_id, stats in LOST_INSTANCES.items()
if stats.stable_offline
]
def continue_lifecycle_hooks(delete_offline_runners: bool) -> None:
"""The function to trigger CONTINUE for instances' lifectycle hooks"""
client = boto3.client("ec2")
reservations = client.describe_instances(
Filters=[
{"Name": "tag-key", "Values": ["github:runner-type"]},
{"Name": "instance-state-name", "Values": ["shutting-down", "terminated"]},
],
)["Reservations"]
# flatten the reservation into instances
terminated_instances = [
instance["InstanceId"]
for reservation in reservations
for instance in reservation["Instances"]
]
asg_client = boto3.client("autoscaling")
as_groups = asg_client.describe_auto_scaling_groups(
Filters=[{"Name": "tag-key", "Values": ["github:runner-type"]}]
)["AutoScalingGroups"]
for asg in as_groups:
lifecycle_hooks = [
lch
for lch in asg_client.describe_lifecycle_hooks(
AutoScalingGroupName=asg["AutoScalingGroupName"]
)["LifecycleHooks"]
if lch["LifecycleTransition"] == "autoscaling:EC2_INSTANCE_TERMINATING"
]
if not lifecycle_hooks:
continue
for instance in asg["Instances"]:
continue_instance = False
if instance["LifecycleState"] == "Terminating:Wait":
if instance["HealthStatus"] == "Unhealthy":
print(f"The instance {instance['InstanceId']} is Unhealthy")
continue_instance = True
elif (
instance["HealthStatus"] == "Healthy"
and instance["InstanceId"] in terminated_instances
):
print(
f"The instance {instance['InstanceId']} is already terminated"
)
continue_instance = True
if continue_instance:
if delete_offline_runners:
for lch in lifecycle_hooks:
print(f"Continue lifecycle hook {lch['LifecycleHookName']}")
asg_client.complete_lifecycle_action(
LifecycleHookName=lch["LifecycleHookName"],
AutoScalingGroupName=asg["AutoScalingGroupName"],
LifecycleActionResult="CONTINUE",
InstanceId=instance["InstanceId"],
)
def main(
access_token: str,
delete_offline_runners: bool,
) -> None:
gh_runners = list_runners(access_token)
dead_runners = get_dead_runners_in_ec2(gh_runners)
print("Runners in GH API to terminate: ", [runner.name for runner in dead_runners])
if delete_offline_runners and dead_runners:
print("Going to delete offline runners")
for runner in dead_runners:
print("Deleting runner", runner)
delete_runner(access_token, runner)
elif dead_runners:
print("Would delete dead runners: ", dead_runners)
lost_instances = get_lost_ec2_instances(gh_runners)
print("Instances to terminate: ", lost_instances)
if delete_offline_runners:
if lost_instances:
print("Going to terminate lost instances")
boto3.client("ec2").terminate_instances(InstanceIds=lost_instances)
continue_lifecycle_hooks(delete_offline_runners)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Get list of runners and their states")
parser.add_argument(
"-p", "--private-key-path", help="Path to file with private key"
)
parser.add_argument("-k", "--private-key", help="Private key")
parser.add_argument(
"-a", "--app-id", type=int, help="GitHub application ID", required=True
)
parser.add_argument(
"--delete-offline", action="store_true", help="Remove offline runners"
)
args = parser.parse_args()
if not args.private_key_path and not args.private_key:
print(
"Either --private-key-path or --private-key must be specified",
file=sys.stderr,
)
if args.private_key_path and args.private_key:
print(
"Either --private-key-path or --private-key must be specified",
file=sys.stderr,
)
if args.private_key:
private_key = args.private_key
elif args.private_key_path:
with open(args.private_key_path, "r", encoding="utf-8") as key_file:
private_key = key_file.read()
else:
print("Attempt to get key and id from AWS secret manager")
private_key, args.app_id = get_key_and_app_from_aws()
token = get_access_token_by_key_app(private_key, args.app_id)
main(token, args.delete_offline)

View File

@ -1 +0,0 @@
../team_keys_lambda/build_and_deploy_archive.sh

View File

@ -1 +0,0 @@
../lambda_shared_package/lambda_shared

View File

@ -1,2 +0,0 @@
../lambda_shared_package
../lambda_shared_package[token]

View File

@ -13,7 +13,6 @@ from git_helper import Git, GIT_PREFIX
from ssh import SSHAgent from ssh import SSHAgent
from env_helper import GITHUB_REPOSITORY, S3_BUILDS_BUCKET from env_helper import GITHUB_REPOSITORY, S3_BUILDS_BUCKET
from s3_helper import S3Helper from s3_helper import S3Helper
from autoscale_runners_lambda.lambda_shared.pr import Labels
from ci_utils import Shell from ci_utils import Shell
from version_helper import ( from version_helper import (
FILE_WITH_VERSION_PATH, FILE_WITH_VERSION_PATH,
@ -220,9 +219,9 @@ class ReleaseInfo:
) )
with checkout(self.release_branch): with checkout(self.release_branch):
with checkout_new(new_release_branch): with checkout_new(new_release_branch):
pr_labels = f"--label {Labels.RELEASE}" pr_labels = f"--label {CI.Labels.RELEASE}"
if stable_release_type == VersionType.LTS: if stable_release_type == VersionType.LTS:
pr_labels += f" --label {Labels.RELEASE_LTS}" pr_labels += f" --label {CI.Labels.RELEASE_LTS}"
cmd_push_branch = ( cmd_push_branch = (
f"{GIT_PREFIX} push --set-upstream origin {new_release_branch}" f"{GIT_PREFIX} push --set-upstream origin {new_release_branch}"
) )

View File

@ -112,8 +112,8 @@ def get_run_command(
] ]
if flaky_check: if flaky_check:
envs.append("-e NUM_TRIES=100") envs.append("-e NUM_TRIES=50")
envs.append("-e MAX_RUN_TIME=1800") envs.append("-e MAX_RUN_TIME=2800")
envs += [f"-e {e}" for e in additional_envs] envs += [f"-e {e}" for e in additional_envs]

View File

@ -1,2 +0,0 @@
build
*.egg-info

View File

@ -1,237 +0,0 @@
"""The shared code and types for all our CI lambdas
It exists as __init__.py and lambda_shared/__init__.py to work both in local and venv"""
import json
import logging
import time
from collections import namedtuple
from typing import Any, Dict, Iterable, List, Optional
import boto3 # type: ignore
import requests
RUNNER_TYPE_LABELS = [
"builder",
"func-tester",
"func-tester-aarch64",
"fuzzer-unit-tester",
"limited-tester",
"stress-tester",
"style-checker",
"style-checker-aarch64",
# private runners
"private-builder",
"private-clickpipes",
"private-func-tester",
"private-fuzzer-unit-tester",
"private-stress-tester",
"private-style-checker",
]
### VENDORING
def get_parameter_from_ssm(
name: str, decrypt: bool = True, client: Optional[Any] = None
) -> str:
if not client:
client = boto3.client("ssm", region_name="us-east-1")
return client.get_parameter(Name=name, WithDecryption=decrypt)[ # type: ignore
"Parameter"
]["Value"]
class CHException(Exception):
pass
class InsertException(CHException):
pass
class ClickHouseHelper:
def __init__(
self,
url: str,
user: Optional[str] = None,
password: Optional[str] = None,
):
self.url = url
self.auth = {}
if user:
self.auth["X-ClickHouse-User"] = user
if password:
self.auth["X-ClickHouse-Key"] = password
@staticmethod
def _insert_json_str_info_impl(
url: str, auth: Dict[str, str], db: str, table: str, json_str: str
) -> None:
params = {
"database": db,
"query": f"INSERT INTO {table} FORMAT JSONEachRow",
"date_time_input_format": "best_effort",
"send_logs_level": "warning",
}
for i in range(5):
try:
response = requests.post(
url, params=params, data=json_str, headers=auth
)
except Exception as e:
error = f"Received exception while sending data to {url} on {i} attempt: {e}"
logging.warning(error)
continue
logging.info("Response content '%s'", response.content)
if response.ok:
break
error = (
"Cannot insert data into clickhouse at try "
+ str(i)
+ ": HTTP code "
+ str(response.status_code)
+ ": '"
+ str(response.text)
+ "'"
)
if response.status_code >= 500:
# A retriable error
time.sleep(1)
continue
logging.info(
"Request headers '%s', body '%s'",
response.request.headers,
response.request.body,
)
raise InsertException(error)
else:
raise InsertException(error)
def _insert_json_str_info(self, db: str, table: str, json_str: str) -> None:
self._insert_json_str_info_impl(self.url, self.auth, db, table, json_str)
def insert_event_into(
self, db: str, table: str, event: object, safe: bool = True
) -> None:
event_str = json.dumps(event)
try:
self._insert_json_str_info(db, table, event_str)
except InsertException as e:
logging.error(
"Exception happened during inserting data into clickhouse: %s", e
)
if not safe:
raise
def insert_events_into(
self, db: str, table: str, events: Iterable[object], safe: bool = True
) -> None:
jsons = []
for event in events:
jsons.append(json.dumps(event))
try:
self._insert_json_str_info(db, table, ",".join(jsons))
except InsertException as e:
logging.error(
"Exception happened during inserting data into clickhouse: %s", e
)
if not safe:
raise
def _select_and_get_json_each_row(self, db: str, query: str) -> str:
params = {
"database": db,
"query": query,
"default_format": "JSONEachRow",
}
for i in range(5):
response = None
try:
response = requests.get(self.url, params=params, headers=self.auth)
response.raise_for_status()
return response.text # type: ignore
except Exception as ex:
logging.warning("Cannot fetch data with exception %s", str(ex))
if response:
logging.warning("Reponse text %s", response.text)
time.sleep(0.1 * i)
raise CHException("Cannot fetch data from clickhouse")
def select_json_each_row(self, db: str, query: str) -> List[dict]:
text = self._select_and_get_json_each_row(db, query)
result = []
for line in text.split("\n"):
if line:
result.append(json.loads(line))
return result
### Runners
RunnerDescription = namedtuple(
"RunnerDescription", ["id", "name", "tags", "offline", "busy"]
)
RunnerDescriptions = List[RunnerDescription]
def list_runners(access_token: str) -> RunnerDescriptions:
headers = {
"Authorization": f"token {access_token}",
"Accept": "application/vnd.github.v3+json",
}
per_page = 100
response = requests.get(
f"https://api.github.com/orgs/ClickHouse/actions/runners?per_page={per_page}",
headers=headers,
)
response.raise_for_status()
data = response.json()
total_runners = data["total_count"]
print("Expected total runners", total_runners)
runners = data["runners"]
# round to 0 for 0, 1 for 1..100, but to 2 for 101..200
total_pages = (total_runners - 1) // per_page + 1
print("Total pages", total_pages)
for i in range(2, total_pages + 1):
response = requests.get(
"https://api.github.com/orgs/ClickHouse/actions/runners"
f"?page={i}&per_page={per_page}",
headers=headers,
)
response.raise_for_status()
data = response.json()
runners += data["runners"]
print("Total runners", len(runners))
result = []
for runner in runners:
tags = [tag["name"] for tag in runner["labels"]]
desc = RunnerDescription(
id=runner["id"],
name=runner["name"],
tags=tags,
offline=runner["status"] == "offline",
busy=runner["busy"],
)
result.append(desc)
return result
def cached_value_is_valid(updated_at: float, ttl: float) -> bool:
"a common function to identify if cachable value is still valid"
if updated_at == 0:
return False
if time.time() - ttl < updated_at:
return True
return False

View File

@ -1,168 +0,0 @@
#!/usr/bin/env python
import re
from typing import Tuple
# Individual trusted contributors who are not in any trusted organization.
# Can be changed in runtime: we will append users that we learned to be in
# a trusted org, to save GitHub API calls.
TRUSTED_CONTRIBUTORS = {
e.lower()
for e in [
"amosbird",
"azat", # SEMRush
"bharatnc", # Many contributions.
"cwurm", # ClickHouse, Inc
"den-crane", # Documentation contributor
"ildus", # adjust, ex-pgpro
"nvartolomei", # Seasoned contributor, CloudFlare
"taiyang-li",
"ucasFL", # Amos Bird's friend
"thomoco", # ClickHouse, Inc
"tonickkozlov", # Cloudflare
"tylerhannan", # ClickHouse, Inc
"tsolodov", # ClickHouse, Inc
"justindeguzman", # ClickHouse, Inc
"XuJia0210", # ClickHouse, Inc
]
}
class Labels:
PR_BUGFIX = "pr-bugfix"
PR_CRITICAL_BUGFIX = "pr-critical-bugfix"
CAN_BE_TESTED = "can be tested"
DO_NOT_TEST = "do not test"
MUST_BACKPORT = "pr-must-backport"
MUST_BACKPORT_CLOUD = "pr-must-backport-cloud"
JEPSEN_TEST = "jepsen-test"
SKIP_MERGEABLE_CHECK = "skip mergeable check"
PR_BACKPORT = "pr-backport"
PR_BACKPORTS_CREATED = "pr-backports-created"
PR_BACKPORTS_CREATED_CLOUD = "pr-backports-created-cloud"
PR_CHERRYPICK = "pr-cherrypick"
PR_CI = "pr-ci"
PR_FEATURE = "pr-feature"
PR_SYNCED_TO_CLOUD = "pr-synced-to-cloud"
PR_SYNC_UPSTREAM = "pr-sync-upstream"
RELEASE = "release"
RELEASE_LTS = "release-lts"
SUBMODULE_CHANGED = "submodule changed"
# automatic backport for critical bug fixes
AUTO_BACKPORT = {"pr-critical-bugfix"}
# Descriptions are used in .github/PULL_REQUEST_TEMPLATE.md, keep comments there
# updated accordingly
# The following lists are append only, try to avoid editing them
# They still could be cleaned out after the decent time though.
LABEL_CATEGORIES = {
"pr-backward-incompatible": ["Backward Incompatible Change"],
"pr-bugfix": [
"Bug Fix",
"Bug Fix (user-visible misbehavior in an official stable release)",
"Bug Fix (user-visible misbehaviour in official stable or prestable release)",
"Bug Fix (user-visible misbehavior in official stable or prestable release)",
],
"pr-critical-bugfix": ["Critical Bug Fix (crash, LOGICAL_ERROR, data loss, RBAC)"],
"pr-build": [
"Build/Testing/Packaging Improvement",
"Build Improvement",
"Build/Testing Improvement",
"Build",
"Packaging Improvement",
],
"pr-documentation": [
"Documentation (changelog entry is not required)",
"Documentation",
],
"pr-feature": ["New Feature"],
"pr-improvement": ["Improvement"],
"pr-not-for-changelog": [
"Not for changelog (changelog entry is not required)",
"Not for changelog",
],
"pr-performance": ["Performance Improvement"],
"pr-ci": ["CI Fix or Improvement (changelog entry is not required)"],
}
CATEGORY_TO_LABEL = {
c: lb for lb, categories in LABEL_CATEGORIES.items() for c in categories
}
def check_pr_description(pr_body: str, repo_name: str) -> Tuple[str, str]:
"""The function checks the body to being properly formatted according to
.github/PULL_REQUEST_TEMPLATE.md, if the first returned string is not empty,
then there is an error."""
lines = list(map(lambda x: x.strip(), pr_body.split("\n") if pr_body else []))
lines = [re.sub(r"\s+", " ", line) for line in lines]
# Check if body contains "Reverts ClickHouse/ClickHouse#36337"
if [True for line in lines if re.match(rf"\AReverts {repo_name}#[\d]+\Z", line)]:
return "", LABEL_CATEGORIES["pr-not-for-changelog"][0]
category = ""
entry = ""
description_error = ""
i = 0
while i < len(lines):
if re.match(r"(?i)^[#>*_ ]*change\s*log\s*category", lines[i]):
i += 1
if i >= len(lines):
break
# Can have one empty line between header and the category
# itself. Filter it out.
if not lines[i]:
i += 1
if i >= len(lines):
break
category = re.sub(r"^[-*\s]*", "", lines[i])
i += 1
# Should not have more than one category. Require empty line
# after the first found category.
if i >= len(lines):
break
if lines[i]:
second_category = re.sub(r"^[-*\s]*", "", lines[i])
description_error = (
"More than one changelog category specified: "
f"'{category}', '{second_category}'"
)
return description_error, category
elif re.match(
r"(?i)^[#>*_ ]*(short\s*description|change\s*log\s*entry)", lines[i]
):
i += 1
# Can have one empty line between header and the entry itself.
# Filter it out.
if i < len(lines) and not lines[i]:
i += 1
# All following lines until empty one are the changelog entry.
entry_lines = []
while i < len(lines) and lines[i]:
entry_lines.append(lines[i])
i += 1
entry = " ".join(entry_lines)
# Don't accept changelog entries like '...'.
entry = re.sub(r"[#>*_.\- ]", "", entry)
# Don't accept changelog entries like 'Close #12345'.
entry = re.sub(r"^[\w\-\s]{0,10}#?\d{5,6}\.?$", "", entry)
else:
i += 1
if not category:
description_error = "Changelog category is empty"
# Filter out the PR categories that are not for changelog.
elif "(changelog entry is not required)" in category:
pass # to not check the rest of the conditions
elif category not in CATEGORY_TO_LABEL:
description_error, category = f"Category '{category}' is not valid", ""
elif not entry:
description_error = f"Changelog entry required for category '{category}'"
return description_error, category

View File

@ -1,95 +0,0 @@
"""Module to get the token for GitHub"""
from dataclasses import dataclass
import json
import time
from typing import Tuple
import boto3 # type: ignore
import jwt
import requests
from . import cached_value_is_valid
def get_key_and_app_from_aws() -> Tuple[str, int]:
secret_name = "clickhouse_github_secret_key"
session = boto3.session.Session()
client = session.client(
service_name="secretsmanager",
)
get_secret_value_response = client.get_secret_value(SecretId=secret_name)
data = json.loads(get_secret_value_response["SecretString"])
return data["clickhouse-app-key"], int(data["clickhouse-app-id"])
def get_installation_id(jwt_token: str) -> int:
headers = {
"Authorization": f"Bearer {jwt_token}",
"Accept": "application/vnd.github.v3+json",
}
response = requests.get("https://api.github.com/app/installations", headers=headers)
response.raise_for_status()
data = response.json()
for installation in data:
if installation["account"]["login"] == "ClickHouse":
installation_id = installation["id"]
return installation_id # type: ignore
def get_access_token_by_jwt(jwt_token: str, installation_id: int) -> str:
headers = {
"Authorization": f"Bearer {jwt_token}",
"Accept": "application/vnd.github.v3+json",
}
response = requests.post(
f"https://api.github.com/app/installations/{installation_id}/access_tokens",
headers=headers,
)
response.raise_for_status()
data = response.json()
return data["token"] # type: ignore
def get_token_from_aws() -> str:
private_key, app_id = get_key_and_app_from_aws()
return get_access_token_by_key_app(private_key, app_id)
def get_access_token_by_key_app(private_key: str, app_id: int) -> str:
payload = {
"iat": int(time.time()) - 60,
"exp": int(time.time()) + (10 * 60),
"iss": app_id,
}
# FIXME: apparently should be switched to this so that mypy is happy
# jwt_instance = JWT()
# encoded_jwt = jwt_instance.encode(payload, private_key, algorithm="RS256")
encoded_jwt = jwt.encode(payload, private_key, algorithm="RS256") # type: ignore
installation_id = get_installation_id(encoded_jwt)
return get_access_token_by_jwt(encoded_jwt, installation_id)
@dataclass
class CachedToken:
time: float
value: str
updating: bool = False
_cached_token = CachedToken(0, "")
def get_cached_access_token() -> str:
if time.time() - 550 < _cached_token.time or _cached_token.updating:
return _cached_token.value
# Indicate that the value is updating now, so the cached value can be
# used. The first setting and close-to-ttl are not counted as update
_cached_token.updating = cached_value_is_valid(_cached_token.time, 590)
private_key, app_id = get_key_and_app_from_aws()
_cached_token.time = time.time()
_cached_token.value = get_access_token_by_key_app(private_key, app_id)
_cached_token.updating = False
return _cached_token.value

View File

@ -1,24 +0,0 @@
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
[project]
name = "lambda_shared"
version = "0.0.1"
dependencies = [
"requests",
"urllib3 < 2"
]
[project.optional-dependencies]
token = [
"PyJWT",
"cryptography",
]
dev = [
"boto3",
"lambda_shared[token]",
]
[tool.distutils.bdist_wheel]
universal = true

View File

@ -1,8 +0,0 @@
### This file exists for clear builds in docker ###
# without it the `build` directory wouldn't be #
# updated on the fly and will require manual clean #
[build]
build_base = /tmp/lambda_shared
[egg_info]
egg_base = /tmp/

View File

@ -15,7 +15,7 @@ from env_helper import (
GITHUB_SERVER_URL, GITHUB_SERVER_URL,
GITHUB_UPSTREAM_REPOSITORY, GITHUB_UPSTREAM_REPOSITORY,
) )
from lambda_shared_package.lambda_shared.pr import Labels from ci_config import Labels
from get_robot_token import get_best_robot_token from get_robot_token import get_best_robot_token
from github_helper import GitHub from github_helper import GitHub

View File

@ -25,7 +25,7 @@ from contextlib import contextmanager
from typing import Any, Final, Iterator, List, Optional, Tuple from typing import Any, Final, Iterator, List, Optional, Tuple
from git_helper import Git, commit, release_branch from git_helper import Git, commit, release_branch
from lambda_shared_package.lambda_shared.pr import Labels from ci_config import Labels
from report import SUCCESS from report import SUCCESS
from version_helper import ( from version_helper import (
FILE_WITH_VERSION_PATH, FILE_WITH_VERSION_PATH,

View File

@ -15,26 +15,22 @@ from commit_status_helper import (
) )
from env_helper import GITHUB_REPOSITORY, GITHUB_SERVER_URL from env_helper import GITHUB_REPOSITORY, GITHUB_SERVER_URL
from get_robot_token import get_best_robot_token from get_robot_token import get_best_robot_token
from lambda_shared_package.lambda_shared.pr import ( from ci_config import CI
CATEGORY_TO_LABEL, from ci_utils import Utils
TRUSTED_CONTRIBUTORS,
Labels,
check_pr_description,
)
from pr_info import PRInfo from pr_info import PRInfo
from report import FAILURE, PENDING, SUCCESS, StatusType from report import FAILURE, PENDING, SUCCESS, StatusType
from ci_config import CI
TRUSTED_ORG_IDS = { TRUSTED_ORG_IDS = {
54801242, # clickhouse 54801242, # clickhouse
} }
OK_SKIP_LABELS = {Labels.RELEASE, Labels.PR_BACKPORT, Labels.PR_CHERRYPICK} OK_SKIP_LABELS = {CI.Labels.RELEASE, CI.Labels.PR_BACKPORT, CI.Labels.PR_CHERRYPICK}
PR_CHECK = "PR Check" PR_CHECK = "PR Check"
def pr_is_by_trusted_user(pr_user_login, pr_user_orgs): def pr_is_by_trusted_user(pr_user_login, pr_user_orgs):
if pr_user_login.lower() in TRUSTED_CONTRIBUTORS: if pr_user_login.lower() in CI.TRUSTED_CONTRIBUTORS:
logging.info("User '%s' is trusted", pr_user_login) logging.info("User '%s' is trusted", pr_user_login)
return True return True
@ -63,13 +59,13 @@ def should_run_ci_for_pr(pr_info: PRInfo) -> Tuple[bool, str]:
if OK_SKIP_LABELS.intersection(pr_info.labels): if OK_SKIP_LABELS.intersection(pr_info.labels):
return True, "Don't try new checks for release/backports/cherry-picks" return True, "Don't try new checks for release/backports/cherry-picks"
if Labels.CAN_BE_TESTED not in pr_info.labels and not pr_is_by_trusted_user( if CI.Labels.CAN_BE_TESTED not in pr_info.labels and not pr_is_by_trusted_user(
pr_info.user_login, pr_info.user_orgs pr_info.user_login, pr_info.user_orgs
): ):
logging.info( logging.info(
"PRs by untrusted users need the '%s' label - " "PRs by untrusted users need the '%s' label - "
"please contact a member of the core team", "please contact a member of the core team",
Labels.CAN_BE_TESTED, CI.Labels.CAN_BE_TESTED,
) )
return False, "Needs 'can be tested' label" return False, "Needs 'can be tested' label"
@ -96,30 +92,32 @@ def main():
commit = get_commit(gh, pr_info.sha) commit = get_commit(gh, pr_info.sha)
status = SUCCESS # type: StatusType status = SUCCESS # type: StatusType
description_error, category = check_pr_description(pr_info.body, GITHUB_REPOSITORY) description_error, category = Utils.check_pr_description(
pr_info.body, GITHUB_REPOSITORY
)
pr_labels_to_add = [] pr_labels_to_add = []
pr_labels_to_remove = [] pr_labels_to_remove = []
if ( if (
category in CATEGORY_TO_LABEL category in CI.CATEGORY_TO_LABEL
and CATEGORY_TO_LABEL[category] not in pr_info.labels and CI.CATEGORY_TO_LABEL[category] not in pr_info.labels
): ):
pr_labels_to_add.append(CATEGORY_TO_LABEL[category]) pr_labels_to_add.append(CI.CATEGORY_TO_LABEL[category])
for label in pr_info.labels: for label in pr_info.labels:
if ( if (
label in CATEGORY_TO_LABEL.values() label in CI.CATEGORY_TO_LABEL.values()
and category in CATEGORY_TO_LABEL and category in CI.CATEGORY_TO_LABEL
and label != CATEGORY_TO_LABEL[category] and label != CI.CATEGORY_TO_LABEL[category]
): ):
pr_labels_to_remove.append(label) pr_labels_to_remove.append(label)
if pr_info.has_changes_in_submodules(): if pr_info.has_changes_in_submodules():
pr_labels_to_add.append(Labels.SUBMODULE_CHANGED) pr_labels_to_add.append(CI.Labels.SUBMODULE_CHANGED)
elif Labels.SUBMODULE_CHANGED in pr_info.labels: elif CI.Labels.SUBMODULE_CHANGED in pr_info.labels:
pr_labels_to_remove.append(Labels.SUBMODULE_CHANGED) pr_labels_to_remove.append(CI.Labels.SUBMODULE_CHANGED)
if any(label in Labels.AUTO_BACKPORT for label in pr_labels_to_add): if any(label in CI.Labels.AUTO_BACKPORT for label in pr_labels_to_add):
backport_labels = [Labels.MUST_BACKPORT, Labels.MUST_BACKPORT_CLOUD] backport_labels = [CI.Labels.MUST_BACKPORT, CI.Labels.MUST_BACKPORT_CLOUD]
pr_labels_to_add += [ pr_labels_to_add += [
label for label in backport_labels if label not in pr_info.labels label for label in backport_labels if label not in pr_info.labels
] ]
@ -164,15 +162,15 @@ def main():
# 2. Then we check if the documentation is not created to fail the Mergeable check # 2. Then we check if the documentation is not created to fail the Mergeable check
if ( if (
Labels.PR_FEATURE in pr_info.labels CI.Labels.PR_FEATURE in pr_info.labels
and not pr_info.has_changes_in_documentation() and not pr_info.has_changes_in_documentation()
): ):
print( print(
f"::error ::The '{Labels.PR_FEATURE}' in the labels, " f"::error ::The '{CI.Labels.PR_FEATURE}' in the labels, "
"but there's no changed documentation" "but there's no changed documentation"
) )
status = FAILURE status = FAILURE
description = f"expect adding docs for {Labels.PR_FEATURE}" description = f"expect adding docs for {CI.Labels.PR_FEATURE}"
# 3. But we allow the workflow to continue # 3. But we allow the workflow to continue
# 4. And post only a single commit status on a failure # 4. And post only a single commit status on a failure

View File

@ -1,93 +0,0 @@
#!/usr/bin/env python3
import argparse
import sys
import boto3 # type: ignore
import requests
from lambda_shared.token import get_access_token_by_key_app, get_cached_access_token
def get_runner_registration_token(access_token):
headers = {
"Authorization": f"token {access_token}",
"Accept": "application/vnd.github.v3+json",
}
response = requests.post(
"https://api.github.com/orgs/ClickHouse/actions/runners/registration-token",
headers=headers,
timeout=30,
)
response.raise_for_status()
data = response.json()
return data["token"]
def main(access_token, push_to_ssm, ssm_parameter_name):
runner_registration_token = get_runner_registration_token(access_token)
if push_to_ssm:
print("Trying to put params into ssm manager")
client = boto3.client("ssm")
client.put_parameter(
Name=ssm_parameter_name,
Value=runner_registration_token,
Type="SecureString",
Overwrite=True,
)
else:
print(
"Not push token to AWS Parameter Store, just print:",
runner_registration_token,
)
def handler(event, context):
_, _ = event, context
main(get_cached_access_token(), True, "github_runner_registration_token")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Get new token from github to add runners"
)
parser.add_argument(
"-p", "--private-key-path", help="Path to file with private key"
)
parser.add_argument("-k", "--private-key", help="Private key")
parser.add_argument(
"-a", "--app-id", type=int, help="GitHub application ID", required=True
)
parser.add_argument(
"--push-to-ssm",
action="store_true",
help="Store received token in parameter store",
)
parser.add_argument(
"--ssm-parameter-name",
default="github_runner_registration_token",
help="AWS paramater store parameter name",
)
args = parser.parse_args()
if not args.private_key_path and not args.private_key:
print(
"Either --private-key-path or --private-key must be specified",
file=sys.stderr,
)
if args.private_key_path and args.private_key:
print(
"Either --private-key-path or --private-key must be specified",
file=sys.stderr,
)
if args.private_key:
private_key = args.private_key
else:
with open(args.private_key_path, "r", encoding="utf-8") as key_file:
private_key = key_file.read()
token = get_access_token_by_key_app(private_key, args.app_id)
main(token, args.push_to_ssm, args.ssm_parameter_name)

View File

@ -1 +0,0 @@
../team_keys_lambda/build_and_deploy_archive.sh

View File

@ -1 +0,0 @@
../lambda_shared_package/lambda_shared

View File

@ -1 +0,0 @@
../lambda_shared_package[token]

View File

@ -1,323 +0,0 @@
#!/usr/bin/env python3
"""
A trivial stateless slack bot that notifies about new broken tests in ClickHouse CI.
It checks what happened to our CI during the last check_period hours (1 hour) and
notifies us in slack if necessary.
This script should be executed once each check_period hours (1 hour).
It will post duplicate messages if you run it more often; it will lose some messages
if you run it less often.
You can run it locally with no arguments, it will work in a dry-run mode.
Or you can set your own SLACK_URL_DEFAULT.
Feel free to add more checks, more details to messages, or better heuristics.
It's deployed to slack-bot-ci-lambda in CI/CD account
See also: https://aretestsgreenyet.com/
"""
import base64
import json
import os
import random
import requests
DRY_RUN_MARK = "<no url, dry run>"
MAX_FAILURES_DEFAULT = 30
SLACK_URL_DEFAULT = DRY_RUN_MARK
FLAKY_ALERT_PROBABILITY = 0.50
REPORT_NO_FAILURES_PROBABILITY = 0.99
MAX_TESTS_TO_REPORT = 4
# Slack has a stupid limitation on message size, it splits long messages into multiple,
# ones breaking formatting
MESSAGE_LENGTH_LIMIT = 4000
# Find tests that failed in master during the last check_period * 24 hours,
# but did not fail during the last 2 weeks. Assuming these tests were broken recently.
# Counts number of failures in check_period and check_period * 24 time windows
# to distinguish rare flaky tests from completely broken tests
NEW_BROKEN_TESTS_QUERY = """
WITH
1 AS check_period,
check_period * 24 AS extended_check_period,
now() as now
SELECT
test_name,
any(report_url),
countIf((check_start_time + check_duration_ms / 1000) < now - INTERVAL check_period HOUR) AS count_prev_periods,
countIf((check_start_time + check_duration_ms / 1000) >= now - INTERVAL check_period HOUR) AS count
FROM checks
WHERE 1
AND check_start_time BETWEEN now - INTERVAL 1 WEEK AND now
AND (check_start_time + check_duration_ms / 1000) >= now - INTERVAL extended_check_period HOUR
AND pull_request_number = 0
AND test_status LIKE 'F%'
AND check_status != 'success'
AND test_name NOT IN (
SELECT test_name FROM checks WHERE 1
AND check_start_time >= now - INTERVAL 1 MONTH
AND (check_start_time + check_duration_ms / 1000) BETWEEN now - INTERVAL 2 WEEK AND now - INTERVAL extended_check_period HOUR
AND pull_request_number = 0
AND check_status != 'success'
AND test_status LIKE 'F%')
AND test_context_raw NOT LIKE '%CannotSendRequest%' and test_context_raw NOT LIKE '%Server does not respond to health check%'
GROUP BY test_name
ORDER BY (count_prev_periods + count) DESC
"""
# Returns total number of failed checks during the last 24 hours
# and previous value of that metric (check_period hours ago)
COUNT_FAILURES_QUERY = """
WITH
1 AS check_period,
'%' AS check_name_pattern,
now() as now
SELECT
countIf((check_start_time + check_duration_ms / 1000) >= now - INTERVAL 24 HOUR) AS new_val,
countIf((check_start_time + check_duration_ms / 1000) <= now - INTERVAL check_period HOUR) AS prev_val
FROM checks
WHERE 1
AND check_start_time >= now - INTERVAL 1 WEEK
AND (check_start_time + check_duration_ms / 1000) >= now - INTERVAL 24 + check_period HOUR
AND pull_request_number = 0
AND test_status LIKE 'F%'
AND check_status != 'success'
AND check_name ILIKE check_name_pattern
"""
# Returns percentage of failed checks (once per day, at noon)
FAILED_CHECKS_PERCENTAGE_QUERY = """
SELECT if(toHour(now('Europe/Amsterdam')) = 12, v, 0)
FROM
(
SELECT
countDistinctIf((commit_sha, check_name), (test_status LIKE 'F%') AND (check_status != 'success'))
/ countDistinct((commit_sha, check_name)) AS v
FROM checks
WHERE 1
AND (pull_request_number = 0)
AND (test_status != 'SKIPPED')
AND (check_start_time > (now() - toIntervalDay(1)))
)
"""
# It shows all recent failures of the specified test (helps to find when it started)
ALL_RECENT_FAILURES_QUERY = """
WITH
'{}' AS name_substr,
90 AS interval_days,
('Stateless tests (asan)', 'Stateless tests (address)', 'Stateless tests (address, actions)', 'Integration tests (asan) [1/3]', 'Stateless tests (tsan) [1/3]') AS backport_and_release_specific_checks
SELECT
toStartOfDay(check_start_time) AS d,
count(),
groupUniqArray(pull_request_number) AS prs,
any(report_url)
FROM checks
WHERE ((now() - toIntervalDay(interval_days)) <= check_start_time) AND (pull_request_number NOT IN (
SELECT pull_request_number AS prn
FROM checks
WHERE (prn != 0) AND ((now() - toIntervalDay(interval_days)) <= check_start_time) AND (check_name IN (backport_and_release_specific_checks))
)) AND (position(test_name, name_substr) > 0) AND (test_status IN ('FAIL', 'ERROR', 'FLAKY'))
GROUP BY d
ORDER BY d DESC
"""
SLACK_MESSAGE_JSON = {"type": "mrkdwn", "text": None}
def get_play_url(query):
return (
"https://play.clickhouse.com/play?user=play#"
+ base64.b64encode(query.encode()).decode()
)
def run_clickhouse_query(query):
url = "https://play.clickhouse.com/?user=play&query=" + requests.compat.quote(query)
res = requests.get(url, timeout=30)
if res.status_code != 200:
print("Failed to execute query: ", res.status_code, res.content)
res.raise_for_status()
lines = res.text.strip().splitlines()
return [x.split("\t") for x in lines]
def split_broken_and_flaky_tests(failed_tests):
if not failed_tests:
return None
broken_tests = []
flaky_tests = []
for name, report, count_prev_str, count_str in failed_tests:
count_prev, count = int(count_prev_str), int(count_str)
if (count_prev < 2 <= count) or (count_prev == count == 1):
# It failed 2 times or more within extended time window, it's definitely broken.
# 2 <= count means that it was not reported as broken on previous runs
broken_tests.append([name, report])
elif 0 < count and count_prev == 0:
# It failed only once, can be a rare flaky test
flaky_tests.append([name, report])
return broken_tests, flaky_tests
def format_failed_tests_list(failed_tests, failure_type):
if len(failed_tests) == 1:
res = f"There is a new {failure_type} test:\n"
else:
res = f"There are {len(failed_tests)} new {failure_type} tests:\n"
for name, report in failed_tests[:MAX_TESTS_TO_REPORT]:
cidb_url = get_play_url(ALL_RECENT_FAILURES_QUERY.format(name))
res += f"- *{name}* - <{report}|Report> - <{cidb_url}|CI DB> \n"
if MAX_TESTS_TO_REPORT < len(failed_tests):
res += (
f"- and {len(failed_tests) - MAX_TESTS_TO_REPORT} other "
"tests... :this-is-fine-fire:"
)
return res
def get_new_broken_tests_message(failed_tests):
if not failed_tests:
return None
broken_tests, flaky_tests = split_broken_and_flaky_tests(failed_tests)
if len(broken_tests) == 0 and len(flaky_tests) == 0:
return None
msg = ""
if len(broken_tests) > 0:
msg += format_failed_tests_list(broken_tests, "*BROKEN*")
elif random.random() > FLAKY_ALERT_PROBABILITY:
looks_like_fuzzer = [x[0].count(" ") > 2 for x in flaky_tests]
if not any(looks_like_fuzzer):
print("Will not report flaky tests to avoid noise: ", flaky_tests)
return None
if len(flaky_tests) > 0:
if len(msg) > 0:
msg += "\n"
msg += format_failed_tests_list(flaky_tests, "flaky")
return msg
def get_too_many_failures_message_impl(failures_count):
MAX_FAILURES = int(os.environ.get("MAX_FAILURES", MAX_FAILURES_DEFAULT))
curr_failures = int(failures_count[0][0])
prev_failures = int(failures_count[0][1])
if curr_failures == 0 and prev_failures != 0:
if random.random() < REPORT_NO_FAILURES_PROBABILITY:
return None
return "Wow, there are *no failures* at all... 0_o"
return_none = (
curr_failures < MAX_FAILURES
or curr_failures < prev_failures
or (curr_failures - prev_failures) / prev_failures < 0.2
)
if return_none:
return None
if prev_failures < MAX_FAILURES:
return f":alert: *CI is broken: there are {curr_failures} failures during the last 24 hours*"
return "CI is broken and it's getting worse: there are {curr_failures} failures during the last 24 hours"
def get_too_many_failures_message(failures_count):
msg = get_too_many_failures_message_impl(failures_count)
if msg:
msg += "\nSee https://aretestsgreenyet.com/"
return msg
def get_failed_checks_percentage_message(percentage):
p = float(percentage[0][0]) * 100
# Always report more than 1% of failed checks
# For <= 1%: higher percentage of failures == higher probability
if p <= random.random():
return None
msg = ":alert: " if p > 1 else "Only " if p < 0.5 else ""
msg += f"*{p:.2f}%* of all checks in master have failed yesterday"
return msg
def split_slack_message(long_message):
lines = long_message.split("\n")
messages = []
curr_msg = ""
for line in lines:
if len(curr_msg) + len(line) < MESSAGE_LENGTH_LIMIT:
curr_msg += "\n"
curr_msg += line
else:
messages.append(curr_msg)
curr_msg = line
messages.append(curr_msg)
return messages
def send_to_slack_impl(message):
SLACK_URL = os.environ.get("SLACK_URL", SLACK_URL_DEFAULT)
if SLACK_URL == DRY_RUN_MARK:
return
payload = SLACK_MESSAGE_JSON.copy()
payload["text"] = message
res = requests.post(SLACK_URL, json.dumps(payload), timeout=30)
if res.status_code != 200:
print("Failed to send a message to Slack: ", res.status_code, res.content)
res.raise_for_status()
def send_to_slack(message):
messages = split_slack_message(message)
for msg in messages:
send_to_slack_impl(msg)
def query_and_alert_if_needed(query, get_message_func):
query_res = run_clickhouse_query(query)
print("Got result {} for query {}", query_res, query)
msg = get_message_func(query_res)
if msg is None:
return
msg += f"\nCI DB query: <{get_play_url(query)}|link>"
print("Sending message to slack:", msg)
send_to_slack(msg)
def check_and_alert():
query_and_alert_if_needed(NEW_BROKEN_TESTS_QUERY, get_new_broken_tests_message)
query_and_alert_if_needed(COUNT_FAILURES_QUERY, get_too_many_failures_message)
query_and_alert_if_needed(
FAILED_CHECKS_PERCENTAGE_QUERY, get_failed_checks_percentage_message
)
def handler(event, context):
_, _ = event, context
try:
check_and_alert()
return {"statusCode": 200, "body": "OK"}
except Exception as e:
send_to_slack(
"I failed, please help me "
f"(see ClickHouse/ClickHouse/tests/ci/slack_bot_ci_lambda/app.py): {e}"
)
return {"statusCode": 200, "body": "FAIL"}
if __name__ == "__main__":
check_and_alert()

View File

@ -1 +0,0 @@
../team_keys_lambda/build_and_deploy_archive.sh

View File

@ -1 +0,0 @@
../lambda_shared_package

View File

@ -1,136 +0,0 @@
#!/usr/bin/env python3
import argparse
import json
from datetime import datetime
from queue import Queue
from threading import Thread
import boto3 # type: ignore
import requests
class Keys(set):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.updated_at = 0.0
def update_now(self):
self.updated_at = datetime.now().timestamp()
keys = Keys()
class Worker(Thread):
def __init__(self, request_queue):
Thread.__init__(self)
self.queue = request_queue
self.results = set()
def run(self):
while True:
m = self.queue.get()
if m == "":
break
response = requests.get(f"https://github.com/{m}.keys", timeout=30)
self.results.add(f"# {m}\n{response.text}\n")
self.queue.task_done()
def get_org_team_members(token: str, org: str, team_slug: str) -> set:
headers = {
"Authorization": f"token {token}",
"Accept": "application/vnd.github.v3+json",
}
response = requests.get(
f"https://api.github.com/orgs/{org}/teams/{team_slug}/members",
headers=headers,
timeout=30,
)
response.raise_for_status()
data = response.json()
return set(m["login"] for m in data)
def get_cached_members_keys(members: set) -> Keys:
if (datetime.now().timestamp() - 3600) <= keys.updated_at:
return keys
q = Queue() # type: Queue
workers = []
for m in members:
q.put(m)
# Create workers and add to the queue
worker = Worker(q)
worker.start()
workers.append(worker)
# Workers keep working till they receive an empty string
for _ in workers:
q.put("")
# Join workers to wait till they finished
for worker in workers:
worker.join()
keys.clear()
for worker in workers:
keys.update(worker.results)
keys.update_now()
return keys
def get_token_from_aws() -> str:
# We need a separate token, since the clickhouse-ci app does not have
# access to the organization members' endpoint
secret_name = "clickhouse_robot_token"
session = boto3.session.Session()
client = session.client(
service_name="secretsmanager",
)
get_secret_value_response = client.get_secret_value(SecretId=secret_name)
data = json.loads(get_secret_value_response["SecretString"])
return data["clickhouse_robot_token"] # type: ignore
def main(token: str, org: str, team_slug: str) -> str:
members = get_org_team_members(token, org, team_slug)
keys = get_cached_members_keys(members)
return "".join(sorted(keys))
def handler(event, context):
_ = context
_ = event
if keys.updated_at < (datetime.now().timestamp() - 3600):
token = get_token_from_aws()
body = main(token, "ClickHouse", "core")
else:
body = "".join(sorted(keys))
result = {
"statusCode": 200,
"headers": {
"Content-Type": "text/html",
},
"body": body,
}
return result
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Get the public SSH keys for members of given org and team"
)
parser.add_argument("--token", required=True, help="Github PAT")
parser.add_argument(
"--organization", help="GitHub organization name", default="ClickHouse"
)
parser.add_argument("--team", help="GitHub team name", default="core")
args = parser.parse_args()
output = main(args.token, args.organization, args.team)
print(f"# Just showing off the keys:\n{output}")

View File

@ -1,76 +0,0 @@
#!/usr/bin/env bash
set -xeo pipefail
WORKDIR=$(dirname "$0")
WORKDIR=$(readlink -f "${WORKDIR}")
DIR_NAME=$(basename "$WORKDIR")
cd "$WORKDIR"
# Do not deploy the lambda to AWS
DRY_RUN=${DRY_RUN:-}
# Python runtime to install dependencies
PY_VERSION=${PY_VERSION:-3.10}
PY_EXEC="python${PY_VERSION}"
# Image to build the lambda zip package
DOCKER_IMAGE="public.ecr.aws/lambda/python:${PY_VERSION}"
# Rename the_lambda_name directory to the-lambda-name lambda in AWS
LAMBDA_NAME=${DIR_NAME//_/-}
# The name of directory with lambda code
PACKAGE=lambda-package
# Do not rebuild and deploy the archive if it's newer than sources
if [ -e "$PACKAGE.zip" ] && [ -z "$FORCE" ]; then
REBUILD=""
for src in app.py build_and_deploy_archive.sh requirements.txt lambda_shared/*; do
if [ "$src" -nt "$PACKAGE.zip" ]; then
REBUILD=1
fi
done
[ -n "$REBUILD" ] || exit 0
fi
docker_cmd=(
docker run -i --net=host --rm --user="${UID}" -e HOME=/tmp --entrypoint=/bin/bash
--volume="${WORKDIR}/..:/ci" --workdir="/ci/${DIR_NAME}" "${DOCKER_IMAGE}"
)
rm -rf "$PACKAGE" "$PACKAGE".zip
mkdir "$PACKAGE"
cp app.py "$PACKAGE"
if [ -f requirements.txt ]; then
VENV=lambda-venv
rm -rf "$VENV"
"${docker_cmd[@]}" -ex <<EOF
'$PY_EXEC' -m venv '$VENV' &&
source '$VENV/bin/activate' &&
pip install -r requirements.txt &&
# To have consistent pyc files
find '$VENV/lib' -name '*.pyc' -delete
cp -rT '$VENV/lib/$PY_EXEC/site-packages/' '$PACKAGE'
rm -r '$PACKAGE'/{pip,pip-*,setuptools,setuptools-*}
chmod 0777 -R '$PACKAGE'
EOF
fi
# Create zip archive via python zipfile to have it cross-platform
"${docker_cmd[@]}" -ex <<EOF
cd '$PACKAGE'
find ! -type d -exec touch -t 201212121212 {} +
python <<'EOP'
import zipfile
import os
files_path = []
for root, _, files in os.walk('.'):
files_path.extend(os.path.join(root, file) for file in files)
# persistent file order
files_path.sort()
with zipfile.ZipFile('../$PACKAGE.zip', 'w') as zf:
for file in files_path:
zf.write(file)
EOP
EOF
ECHO=()
if [ -n "$DRY_RUN" ]; then
ECHO=(echo Run the following command to push the changes:)
fi
"${ECHO[@]}" aws lambda update-function-code --function-name "$LAMBDA_NAME" --zip-file fileb://"$WORKDIR/$PACKAGE".zip

View File

@ -1 +0,0 @@
../lambda_shared_package/lambda_shared

View File

@ -1 +0,0 @@
../lambda_shared_package

View File

@ -1,278 +0,0 @@
#!/usr/bin/env python3
import argparse
import json
import sys
import time
from dataclasses import dataclass
from typing import Any, Dict, List
import boto3 # type: ignore
from lambda_shared import RunnerDescriptions, cached_value_is_valid, list_runners
from lambda_shared.token import get_access_token_by_key_app, get_cached_access_token
@dataclass
class CachedInstances:
time: float
value: dict
updating: bool = False
cached_instances = CachedInstances(0, {})
def get_cached_instances() -> dict:
"""return cached instances description with updating it once per five minutes"""
if time.time() - 250 < cached_instances.time or cached_instances.updating:
return cached_instances.value
cached_instances.updating = cached_value_is_valid(cached_instances.time, 300)
ec2_client = boto3.client("ec2")
instances_response = ec2_client.describe_instances(
Filters=[{"Name": "instance-state-name", "Values": ["running"]}]
)
cached_instances.time = time.time()
cached_instances.value = {
instance["InstanceId"]: instance
for reservation in instances_response["Reservations"]
for instance in reservation["Instances"]
}
cached_instances.updating = False
return cached_instances.value
@dataclass
class CachedRunners:
time: float
value: RunnerDescriptions
updating: bool = False
cached_runners = CachedRunners(0, [])
def get_cached_runners(access_token: str) -> RunnerDescriptions:
"""From time to time request to GH api costs up to 3 seconds, and
it's a disaster from the termination lambda perspective"""
if time.time() - 5 < cached_runners.time or cached_instances.updating:
return cached_runners.value
cached_runners.updating = cached_value_is_valid(cached_runners.time, 15)
cached_runners.value = list_runners(access_token)
cached_runners.time = time.time()
cached_runners.updating = False
return cached_runners.value
def how_many_instances_to_kill(event_data: dict) -> Dict[str, int]:
data_array = event_data["CapacityToTerminate"]
to_kill_by_zone = {} # type: Dict[str, int]
for av_zone in data_array:
zone_name = av_zone["AvailabilityZone"]
to_kill = av_zone["Capacity"]
if zone_name not in to_kill_by_zone:
to_kill_by_zone[zone_name] = 0
to_kill_by_zone[zone_name] += to_kill
return to_kill_by_zone
def get_candidates_to_be_killed(event_data: dict) -> Dict[str, List[str]]:
data_array = event_data["Instances"]
instances_by_zone = {} # type: Dict[str, List[str]]
for instance in data_array:
zone_name = instance["AvailabilityZone"]
instance_id = instance["InstanceId"] # type: str
if zone_name not in instances_by_zone:
instances_by_zone[zone_name] = []
instances_by_zone[zone_name].append(instance_id)
return instances_by_zone
def main(access_token: str, event: dict) -> Dict[str, List[str]]:
start = time.time()
print("Got event", json.dumps(event, sort_keys=True).replace("\n", ""))
to_kill_by_zone = how_many_instances_to_kill(event)
instances_by_zone = get_candidates_to_be_killed(event)
# Getting ASG and instances' descriptions from the API
# We don't kill instances that alive for less than 10 minutes, since they
# could be not in the GH active runners yet
print(f"Check other hosts from the same ASG {event['AutoScalingGroupName']}")
asg_client = boto3.client("autoscaling")
as_groups_response = asg_client.describe_auto_scaling_groups(
AutoScalingGroupNames=[event["AutoScalingGroupName"]]
)
assert len(as_groups_response["AutoScalingGroups"]) == 1
asg = as_groups_response["AutoScalingGroups"][0]
asg_instance_ids = [instance["InstanceId"] for instance in asg["Instances"]]
instance_descriptions = get_cached_instances()
# The instances launched less than 10 minutes ago
immune_ids = [
instance["InstanceId"]
for instance in instance_descriptions.values()
if start - instance["LaunchTime"].timestamp() < 600
]
# if the ASG's instance ID not in instance_descriptions, it's most probably
# is not cached yet, so we must mark it as immuned
immune_ids.extend(
iid for iid in asg_instance_ids if iid not in instance_descriptions
)
print("Time spent on the requests to AWS: ", time.time() - start)
runners = get_cached_runners(access_token)
runner_ids = set(runner.name for runner in runners)
# We used to delete potential hosts to terminate from GitHub runners pool,
# but the documentation states:
# --- Returning an instance first in the response data does not guarantee its termination
# so they will be cleaned out by ci_runners_metrics_lambda eventually
instances_to_kill = []
total_to_kill = 0
for zone, num_to_kill in to_kill_by_zone.items():
candidates = instances_by_zone[zone]
total_to_kill += num_to_kill
if num_to_kill > len(candidates):
raise RuntimeError(
f"Required to kill {num_to_kill}, but have only {len(candidates)}"
f" candidates in AV {zone}"
)
delete_for_av = [] # type: RunnerDescriptions
for candidate in candidates:
if candidate in immune_ids:
print(
f"Candidate {candidate} started less than 10 minutes ago, won't touch a child"
)
break
if candidate not in runner_ids:
print(
f"Candidate {candidate} was not in runners list, simply delete it"
)
instances_to_kill.append(candidate)
break
if len(delete_for_av) + len(instances_to_kill) == num_to_kill:
break
if candidate in instances_to_kill:
continue
for runner in runners:
if runner.name == candidate:
if not runner.busy:
print(
f"Runner {runner.name} is not busy and can be deleted from AV {zone}"
)
delete_for_av.append(runner)
else:
print(f"Runner {runner.name} is busy, not going to delete it")
break
if len(delete_for_av) < num_to_kill:
print(
f"Checked all candidates for av {zone}, get to delete "
f"{len(delete_for_av)}, but still cannot get required {num_to_kill}"
)
instances_to_kill += [runner.name for runner in delete_for_av]
if len(instances_to_kill) < total_to_kill:
for instance in asg_instance_ids:
if instance in immune_ids:
continue
for runner in runners:
if runner.name == instance and not runner.busy:
print(f"Runner {runner.name} is not busy and can be deleted")
instances_to_kill.append(runner.name)
if total_to_kill <= len(instances_to_kill):
print("Got enough instances to kill")
break
response = {"InstanceIDs": instances_to_kill}
print("Got instances to kill: ", response)
print("Time spent on the request: ", time.time() - start)
return response
def handler(event: dict, context: Any) -> Dict[str, List[str]]:
_ = context
return main(get_cached_access_token(), event)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Get list of runners and their states")
parser.add_argument(
"-p", "--private-key-path", help="Path to file with private key"
)
parser.add_argument("-k", "--private-key", help="Private key")
parser.add_argument(
"-a", "--app-id", type=int, help="GitHub application ID", required=True
)
args = parser.parse_args()
if not args.private_key_path and not args.private_key:
print(
"Either --private-key-path or --private-key must be specified",
file=sys.stderr,
)
if args.private_key_path and args.private_key:
print(
"Either --private-key-path or --private-key must be specified",
file=sys.stderr,
)
if args.private_key:
private_key = args.private_key
else:
with open(args.private_key_path, "r", encoding="utf-8") as key_file:
private_key = key_file.read()
token = get_access_token_by_key_app(private_key, args.app_id)
sample_event = {
"AutoScalingGroupARN": "arn:aws:autoscaling:us-east-1:<account-id>:autoScalingGroup:d4738357-2d40-4038-ae7e-b00ae0227003:autoScalingGroupName/my-asg",
"AutoScalingGroupName": "my-asg",
"CapacityToTerminate": [
{
"AvailabilityZone": "us-east-1b",
"Capacity": 1,
"InstanceMarketOption": "OnDemand",
},
{
"AvailabilityZone": "us-east-1c",
"Capacity": 2,
"InstanceMarketOption": "OnDemand",
},
],
"Instances": [
{
"AvailabilityZone": "us-east-1b",
"InstanceId": "i-08d0b3c1a137e02a5",
"InstanceType": "t2.nano",
"InstanceMarketOption": "OnDemand",
},
{
"AvailabilityZone": "us-east-1c",
"InstanceId": "ip-172-31-45-253.eu-west-1.compute.internal",
"InstanceType": "t2.nano",
"InstanceMarketOption": "OnDemand",
},
{
"AvailabilityZone": "us-east-1c",
"InstanceId": "ip-172-31-27-227.eu-west-1.compute.internal",
"InstanceType": "t2.nano",
"InstanceMarketOption": "OnDemand",
},
{
"AvailabilityZone": "us-east-1c",
"InstanceId": "ip-172-31-45-253.eu-west-1.compute.internal",
"InstanceType": "t2.nano",
"InstanceMarketOption": "OnDemand",
},
],
"Cause": "SCALE_IN",
}
main(token, sample_event)

View File

@ -1 +0,0 @@
../team_keys_lambda/build_and_deploy_archive.sh

Some files were not shown because too many files have changed in this diff Show More