Merge remote-tracking branch 'upstream/master' into make-cache-composable

Commit 21ebf8874e by kssenii, 2022-07-19 12:32:59 +02:00
101 changed files with 947 additions and 1445 deletions

View File

@ -8,7 +8,7 @@ concurrency:
group: cherry-pick
on: # yamllint disable-line rule:truthy
schedule:
- cron: '0 */3 * * *'
- cron: '0 * * * *'
workflow_dispatch:
jobs:

View File

@ -93,7 +93,6 @@
# define NO_SANITIZE_ADDRESS __attribute__((__no_sanitize__("address")))
# define NO_SANITIZE_THREAD __attribute__((__no_sanitize__("thread")))
# define ALWAYS_INLINE_NO_SANITIZE_UNDEFINED __attribute__((__always_inline__, __no_sanitize__("undefined")))
# define DISABLE_SANITIZER_INSTRUMENTATION __attribute__((disable_sanitizer_instrumentation))
#else /// It does not work in GCC. GCC 7 cannot recognize this attribute and GCC 8 simply ignores it.
# define NO_SANITIZE_UNDEFINED
# define NO_SANITIZE_ADDRESS
@ -101,6 +100,13 @@
# define ALWAYS_INLINE_NO_SANITIZE_UNDEFINED ALWAYS_INLINE
#endif
#if defined(__clang__) && defined(__clang_major__) && __clang_major__ >= 14
# define DISABLE_SANITIZER_INSTRUMENTATION __attribute__((disable_sanitizer_instrumentation))
#else
# define DISABLE_SANITIZER_INSTRUMENTATION
#endif
#if !__has_include(<sanitizer/asan_interface.h>) || !defined(ADDRESS_SANITIZER)
# define ASAN_UNPOISON_MEMORY_REGION(a, b)
# define ASAN_POISON_MEMORY_REGION(a, b)
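
The reorganized block above means DISABLE_SANITIZER_INSTRUMENTATION is now always defined: it expands to the clang attribute only on clang 14 and newer (where `disable_sanitizer_instrumentation` exists) and to nothing otherwise. A minimal, self-contained sketch of how such a macro is typically applied; the function name is hypothetical and not part of this commit:

```cpp
#if defined(__clang__) && defined(__clang_major__) && __clang_major__ >= 14
#    define DISABLE_SANITIZER_INSTRUMENTATION __attribute__((disable_sanitizer_instrumentation))
#else
#    define DISABLE_SANITIZER_INSTRUMENTATION
#endif

/// On clang >= 14 the attribute removes all sanitizer instrumentation from this
/// function; on other compilers the macro expands to nothing and the function is
/// compiled normally, so call sites need no extra #ifdef.
DISABLE_SANITIZER_INSTRUMENTATION
void handleFatalSignal(int /*signum*/)
{
}
```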

View File

@ -132,7 +132,7 @@ public:
key_ref = assert_cast<const ColumnString &>(key_column).getDataAt(offset + i);
#ifdef __cpp_lib_generic_unordered_lookup
key = static_cast<std::string_view>(key_ref);
key = key_ref.toView();
#else
key = key_ref.toString();
#endif
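
This is the conversion pattern repeated throughout the merge: instead of `static_cast<std::string_view>(key_ref)` (or an allocating `toString()`), call sites use `StringRef::toView()`. A simplified stand-in for what such a helper amounts to, assuming StringRef is the non-owning (data, size) pair it appears to be elsewhere in this diff:

```cpp
#include <cstddef>
#include <string>
#include <string_view>

/// Simplified sketch, not the real ClickHouse StringRef: a non-owning pointer/length
/// pair with a cheap view conversion and an allocating string conversion.
struct StringRefSketch
{
    const char * data = nullptr;
    size_t size = 0;

    std::string_view toView() const { return {data, size}; }  /// no allocation
    std::string toString() const { return {data, size}; }     /// copies the bytes
};
```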

View File

@ -2,7 +2,6 @@
#include <Interpreters/Context_fwd.h>
#include <Common/ThreadStatus.h>
#include <base/StringRef.h>
#include <memory>
#include <string>
@ -76,7 +75,7 @@ public:
static void finalizePerformanceCounters();
/// Returns a non-empty string if the thread is attached to a query
static StringRef getQueryId()
static std::string_view getQueryId()
{
if (unlikely(!current_thread))
return {};

View File

@ -112,10 +112,10 @@ String FileSegment::getCallerId()
{
if (!CurrentThread::isInitialized()
|| !CurrentThread::get().getQueryContext()
|| CurrentThread::getQueryId().size == 0)
|| CurrentThread::getQueryId().empty())
return "None:" + toString(getThreadId());
return CurrentThread::getQueryId().toString() + ":" + toString(getThreadId());
return std::string(CurrentThread::getQueryId()) + ":" + toString(getThreadId());
}
String FileSegment::getOrSetDownloader()

View File

@ -58,7 +58,7 @@ static bool isQueryInitialized()
{
return CurrentThread::isInitialized()
&& CurrentThread::get().getQueryContext()
&& CurrentThread::getQueryId().size != 0;
&& !CurrentThread::getQueryId().empty();
}
bool IFileCache::isReadOnly()
@ -77,7 +77,7 @@ IFileCache::QueryContextPtr IFileCache::getCurrentQueryContext(std::lock_guard<s
if (!isQueryInitialized())
return nullptr;
return getQueryContext(CurrentThread::getQueryId().toString(), cache_lock);
return getQueryContext(std::string(CurrentThread::getQueryId()), cache_lock);
}
IFileCache::QueryContextPtr IFileCache::getQueryContext(const String & query_id, std::lock_guard<std::mutex> & /* cache_lock */)

View File

@ -210,7 +210,7 @@ public:
return thread_state.load(std::memory_order_relaxed);
}
StringRef getQueryId() const
std::string_view getQueryId() const
{
return query_id;
}

View File

@ -47,7 +47,7 @@ void TraceSender::send(TraceType trace_type, const StackTrace & stack_trace, Int
if (CurrentThread::isInitialized())
{
query_id = CurrentThread::getQueryId();
query_id = StringRef(CurrentThread::getQueryId());
query_id.size = std::min(query_id.size, QUERY_ID_MAX_LEN);
thread_id = CurrentThread::get().thread_id;

View File

@ -298,7 +298,7 @@ private:
/// It will allow client to see failure messages directly.
if (thread_ptr)
{
query_id = thread_ptr->getQueryId().toString();
query_id = std::string(thread_ptr->getQueryId());
if (auto thread_group = thread_ptr->getThreadGroup())
{

View File

@ -126,6 +126,20 @@ DatabasePtr DatabaseFactory::getImpl(const ASTCreateQuery & create, const String
if (!create.attach && !context->getSettingsRef().allow_deprecated_database_ordinary)
throw Exception(ErrorCodes::UNKNOWN_DATABASE_ENGINE,
"Ordinary database engine is deprecated (see also allow_deprecated_database_ordinary setting)");
/// Before 20.7 the metadata/db_name.sql file might be absent and an Ordinary database was attached if there's a metadata/db_name/ dir.
/// Between 20.7 and 22.7 metadata/db_name.sql was created in this case as well.
/// Since 20.7 the `default` database is created with the Atomic engine on the very first server run.
/// The problem is that if the server crashed during the very first run and a metadata/db_name/ -> store/whatever symlink was created,
/// then it's considered an Ordinary database. And it even works somehow,
/// until a background task tries to remove the unused dir from store/...
if (fs::is_symlink(metadata_path))
throw Exception(ErrorCodes::CANNOT_CREATE_DATABASE, "Metadata directory {} for Ordinary database {} is a symbolic link to {}. "
"It may be a result of manual intervention, crash on very first server start or a bug. "
"Database cannot be attached (it's kind of protection from potential data loss). "
"Metadata directory must not be a symlink and must contain tables metadata files itself. "
"You have to resolve this manually.",
metadata_path, database_name, fs::read_symlink(metadata_path).string());
return std::make_shared<DatabaseOrdinary>(database_name, metadata_path, context);
}

View File

@ -79,6 +79,7 @@ void DatabaseMaterializedPostgreSQL::startSynchronization()
}
catch (...)
{
tryLogCurrentException(__PRETTY_FUNCTION__);
LOG_ERROR(log, "Unable to load replicated tables list");
throw;
}
@ -111,7 +112,16 @@ void DatabaseMaterializedPostgreSQL::startSynchronization()
}
LOG_TRACE(log, "Loaded {} tables. Starting synchronization", materialized_tables.size());
try
{
replication_handler->startup(/* delayed */false);
}
catch (...)
{
tryLogCurrentException(__PRETTY_FUNCTION__);
throw;
}
}

View File

@ -387,7 +387,7 @@ void IPAddressDictionary::loadData()
setAttributeValue(attribute, attribute_column[row]);
}
const auto [addr, prefix] = parseIPFromString(std::string_view{key_column_ptr->getDataAt(row)});
const auto [addr, prefix] = parseIPFromString(key_column_ptr->getDataAt(row).toView());
has_ipv6 = has_ipv6 || (addr.family() == Poco::Net::IPAddress::IPv6);
size_t row_number = ip_records.size();

View File

@ -115,7 +115,7 @@ std::unique_ptr<ReadBufferFromFileBase> CachedObjectStorage::readObjects( /// NO
cache,
implementation_buffer_creator,
modified_read_settings,
CurrentThread::isInitialized() && CurrentThread::get().getQueryContext() ? CurrentThread::getQueryId().toString() : "",
CurrentThread::isInitialized() && CurrentThread::get().getQueryContext() ? std::string(CurrentThread::getQueryId()) : "",
file_size.value(),
/* allow_seeks */true,
/* use_external_buffer */false);
@ -155,7 +155,7 @@ std::unique_ptr<ReadBufferFromFileBase> CachedObjectStorage::readObject( /// NOL
cache,
implementation_buffer_creator,
read_settings,
CurrentThread::isInitialized() && CurrentThread::get().getQueryContext() ? CurrentThread::getQueryId().toString() : "",
CurrentThread::isInitialized() && CurrentThread::get().getQueryContext() ? std::string(CurrentThread::getQueryId()) : "",
file_size.value(),
/* allow_seeks */true,
/* use_external_buffer */false);
@ -193,7 +193,7 @@ std::unique_ptr<WriteBufferFromFileBase> CachedObjectStorage::writeObject( /// N
implementation_buffer->getFileName(),
key,
modified_write_settings.is_file_cache_persistent,
CurrentThread::isInitialized() && CurrentThread::get().getQueryContext() ? CurrentThread::getQueryId().toString() : "",
CurrentThread::isInitialized() && CurrentThread::get().getQueryContext() ? std::string(CurrentThread::getQueryId()) : "",
modified_write_settings);
}

View File

@ -52,7 +52,7 @@ capnp::StructSchema CapnProtoSchemaParser::getMessageSchema(const FormatSchemaIn
if (description.find("Parse error") != String::npos)
throw Exception(ErrorCodes::CANNOT_PARSE_CAPN_PROTO_SCHEMA, "Cannot parse CapnProto schema {}:{}", schema_info.schemaPath(), e.getLine());
throw Exception(ErrorCodes::UNKNOWN_EXCEPTION, "Unknown exception while parsing CapnProro schema: {}, schema dir and file: {}, {}", description, schema_info.schemaDirectory(), schema_info.schemaPath());
throw Exception(ErrorCodes::UNKNOWN_EXCEPTION, "Unknown exception while parsing CapnProto schema: {}, schema dir and file: {}, {}", description, schema_info.schemaDirectory(), schema_info.schemaPath());
}
auto message_maybe = schema.findNested(schema_info.messageName());

View File

@ -67,8 +67,8 @@ public:
for (size_t row = 0; row < input_rows_count; ++row)
{
StringRef filename = column_src->getDataAt(row);
fs::path file_path(filename.data, filename.data + filename.size);
std::string_view filename = column_src->getDataAt(row).toView();
fs::path file_path(filename.data(), filename.data() + filename.size());
if (file_path.is_relative())
file_path = user_files_absolute_path / file_path;

View File

@ -182,7 +182,7 @@ private:
const auto mode = arguments[0].column->getDataAt(0);
if (mode.size == 0 || !std::string_view(mode).starts_with("aes-"))
if (mode.size == 0 || !mode.toView().starts_with("aes-"))
throw Exception("Invalid mode: " + mode.toString(), ErrorCodes::BAD_ARGUMENTS);
const auto * evp_cipher = getCipherByName(mode);
@ -453,7 +453,7 @@ private:
using namespace OpenSSLDetails;
const auto mode = arguments[0].column->getDataAt(0);
if (mode.size == 0 || !std::string_view(mode).starts_with("aes-"))
if (mode.size == 0 || !mode.toView().starts_with("aes-"))
throw Exception("Invalid mode: " + mode.toString(), ErrorCodes::BAD_ARGUMENTS);
const auto * evp_cipher = getCipherByName(mode);

View File

@ -251,7 +251,7 @@ private:
}
case MoveType::Key:
{
key = std::string_view{(*arguments[j + 1].column).getDataAt(row)};
key = (*arguments[j + 1].column).getDataAt(row).toView();
if (!moveToElementByKey<JSONParser>(res_element, key))
return false;
break;

View File

@ -12,7 +12,6 @@
#include <Functions/IFunction.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/Context.h>
#include <base/StringRef.h>
namespace DB

View File

@ -8,22 +8,22 @@
namespace DB
{
inline StringRef checkAndReturnHost(const Pos & pos, const Pos & dot_pos, const Pos & start_of_host)
inline std::string_view checkAndReturnHost(const Pos & pos, const Pos & dot_pos, const Pos & start_of_host)
{
if (!dot_pos || start_of_host >= pos || pos - dot_pos == 1)
return StringRef{};
return std::string_view{};
auto after_dot = *(dot_pos + 1);
if (after_dot == ':' || after_dot == '/' || after_dot == '?' || after_dot == '#')
return StringRef{};
return std::string_view{};
return StringRef(start_of_host, pos - start_of_host);
return std::string_view(start_of_host, pos - start_of_host);
}
/// Extracts host from given url.
///
/// @return empty StringRef if the host is not valid (i.e. it does not have dot, or there no symbol after dot).
inline StringRef getURLHost(const char * data, size_t size)
/// @return an empty string view if the host is not valid (i.e. it does not have a dot, or there is no symbol after the dot).
inline std::string_view getURLHost(const char * data, size_t size)
{
Pos pos = data;
Pos end = data + size;
@ -61,7 +61,7 @@ inline StringRef getURLHost(const char * data, size_t size)
case ';':
case '=':
case '&':
return StringRef{};
return std::string_view{};
default:
goto exloop;
}
@ -106,7 +106,7 @@ exloop: if ((scheme_end - pos) > 2 && *pos == ':' && *(pos + 1) == '/' && *(pos
case ';':
case '=':
case '&':
return StringRef{};
return std::string_view{};
}
}
@ -120,20 +120,20 @@ struct ExtractDomain
static void execute(Pos data, size_t size, Pos & res_data, size_t & res_size)
{
StringRef host = getURLHost(data, size);
std::string_view host = getURLHost(data, size);
if (host.size == 0)
if (host.empty())
{
res_data = data;
res_size = 0;
}
else
{
if (without_www && host.size > 4 && !strncmp(host.data, "www.", 4))
host = { host.data + 4, host.size - 4 };
if (without_www && host.size() > 4 && !strncmp(host.data(), "www.", 4))
host = { host.data() + 4, host.size() - 4 };
res_data = host.data;
res_size = host.size;
res_data = host.data();
res_size = host.size();
}
}
};
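
With getURLHost() returning std::string_view, callers test validity with empty() and use data()/size() instead of the old StringRef fields, as the ExtractDomain change above shows. A hypothetical caller written against the updated signature (the concrete parsing result for a given URL is an assumption, not taken from this commit):

```cpp
#include <string>
#include <string_view>

/// Hypothetical wrapper over the helper declared above (DB::getURLHost).
std::string extractDomain(const std::string & url, bool without_www)
{
    std::string_view host = DB::getURLHost(url.data(), url.size());
    if (host.empty())
        return {};  /// invalid host: no dot, or nothing after the dot

    if (without_www && host.size() > 4 && host.substr(0, 4) == "www.")
        host.remove_prefix(4);

    return std::string(host);
}
```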

View File

@ -12,7 +12,7 @@ struct ExtractNetloc
/// We use the same as domain function
static size_t getReserveLengthForElement() { return 15; }
static inline StringRef getNetworkLocation(const char * data, size_t size)
static std::string_view getNetworkLocation(const char * data, size_t size)
{
Pos pos = data;
Pos end = data + size;
@ -51,7 +51,7 @@ struct ExtractNetloc
case ';':
case '=':
case '&':
return StringRef{};
return std::string_view();
default:
goto exloop;
}
@ -76,18 +76,18 @@ struct ExtractNetloc
{
case '/':
if (has_identification)
return StringRef(start_of_host, pos - start_of_host);
return std::string_view(start_of_host, pos - start_of_host);
else
slash_pos = pos;
break;
case '?':
if (has_identification)
return StringRef(start_of_host, pos - start_of_host);
return std::string_view(start_of_host, pos - start_of_host);
else
question_mark_pos = pos;
break;
case '#':
return StringRef(start_of_host, pos - start_of_host);
return std::string_view(start_of_host, pos - start_of_host);
case '@': /// foo:bar@example.ru
has_identification = true;
break;
@ -108,23 +108,23 @@ struct ExtractNetloc
case '=':
case '&':
return pos > start_of_host
? StringRef(start_of_host, std::min(std::min(pos - 1, question_mark_pos), slash_pos) - start_of_host)
: StringRef{};
? std::string_view(start_of_host, std::min(std::min(pos - 1, question_mark_pos), slash_pos) - start_of_host)
: std::string_view();
}
}
if (has_identification)
return StringRef(start_of_host, pos - start_of_host);
return std::string_view(start_of_host, pos - start_of_host);
else
return StringRef(start_of_host, std::min(std::min(pos, question_mark_pos), slash_pos) - start_of_host);
return std::string_view(start_of_host, std::min(std::min(pos, question_mark_pos), slash_pos) - start_of_host);
}
static void execute(Pos data, size_t size, Pos & res_data, size_t & res_size)
{
StringRef host = getNetworkLocation(data, size);
std::string_view host = getNetworkLocation(data, size);
res_data = host.data;
res_size = host.size;
res_data = host.data();
res_size = host.size();
}
};

View File

@ -94,13 +94,13 @@ private:
const char * p = reinterpret_cast<const char *>(buf.data()) + offset;
const char * end = p + size;
StringRef host = getURLHost(p, size);
if (!host.size)
std::string_view host = getURLHost(p, size);
if (host.empty())
return default_port;
if (host.size == size)
if (host.size() == size)
return default_port;
p = host.data + host.size;
p = host.data() + host.size();
if (*p++ != ':')
return default_port;

View File

@ -8,7 +8,7 @@ namespace DB
{
/// Extracts scheme from given url.
inline StringRef getURLScheme(const char * data, size_t size)
inline std::string_view getURLScheme(const char * data, size_t size)
{
// scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )
const char * pos = data;
@ -24,7 +24,7 @@ inline StringRef getURLScheme(const char * data, size_t size)
}
}
return StringRef(data, pos - data);
return std::string_view(data, pos - data);
}
return {};
@ -42,10 +42,10 @@ struct ExtractProtocol
res_data = data;
res_size = 0;
StringRef scheme = getURLScheme(data, size);
Pos pos = data + scheme.size;
std::string_view scheme = getURLScheme(data, size);
Pos pos = data + scheme.size();
if (scheme.size == 0 || (data + size) - pos < 4)
if (scheme.empty() || (data + size) - pos < 4)
return;
if (pos[0] == ':')

View File

@ -11,7 +11,7 @@ struct ExtractTopLevelDomain
static void execute(Pos data, size_t size, Pos & res_data, size_t & res_size)
{
StringRef host = getURLHost(data, size);
StringRef host = StringRef(getURLHost(data, size));
res_data = data;
res_size = 0;

View File

@ -79,7 +79,7 @@ public:
current_src_offset = src_offsets[i];
Pos end = reinterpret_cast<Pos>(&src_chars[current_src_offset]) - 1;
StringRef str(pos, end - pos);
std::string_view str(pos, end - pos);
vec_res[i] = countMatches(str, re, matches);
}
@ -87,7 +87,7 @@ public:
}
else if (const ColumnConst * col_const_str = checkAndGetColumnConstStringOrFixedString(column_haystack))
{
StringRef str = col_const_str->getDataColumn().getDataAt(0);
std::string_view str = col_const_str->getDataColumn().getDataAt(0).toView();
uint64_t matches_count = countMatches(str, re, matches);
return result_type->createColumnConst(input_rows_count, matches_count);
}
@ -95,13 +95,13 @@ public:
throw Exception(ErrorCodes::LOGICAL_ERROR, "Error in FunctionCountMatches::getReturnTypeImpl()");
}
static uint64_t countMatches(StringRef src, const Regexps::Regexp & re, OptimizedRegularExpression::MatchVec & matches)
static uint64_t countMatches(std::string_view src, const Regexps::Regexp & re, OptimizedRegularExpression::MatchVec & matches)
{
/// Only one match is required, no need to copy more.
static const unsigned matches_limit = 1;
Pos pos = reinterpret_cast<Pos>(src.data);
Pos end = reinterpret_cast<Pos>(src.data + src.size);
Pos pos = reinterpret_cast<Pos>(src.data());
Pos end = reinterpret_cast<Pos>(src.data() + src.size());
uint64_t match_count = 0;
while (true)

View File

@ -56,7 +56,7 @@ private:
throw Exception{"The argument of function " + String{name} + " should be a constant string with the name of a setting",
ErrorCodes::ILLEGAL_COLUMN};
std::string_view setting_name{column->getDataAt(0)};
std::string_view setting_name{column->getDataAt(0).toView()};
return getContext()->getSettingsRef().get(setting_name);
}
};

View File

@ -75,21 +75,20 @@ struct IPAddressCIDR
UInt8 prefix;
};
IPAddressCIDR parseIPWithCIDR(StringRef cidr_str)
IPAddressCIDR parseIPWithCIDR(std::string_view cidr_str)
{
std::string_view cidr_str_view(cidr_str);
size_t pos_slash = cidr_str_view.find('/');
size_t pos_slash = cidr_str.find('/');
if (pos_slash == 0)
throw DB::Exception("Error parsing IP address with prefix: " + std::string(cidr_str), DB::ErrorCodes::CANNOT_PARSE_TEXT);
if (pos_slash == std::string_view::npos)
throw DB::Exception("The text does not contain '/': " + std::string(cidr_str), DB::ErrorCodes::CANNOT_PARSE_TEXT);
std::string_view addr_str = cidr_str_view.substr(0, pos_slash);
std::string_view addr_str = cidr_str.substr(0, pos_slash);
IPAddressVariant addr(StringRef{addr_str.data(), addr_str.size()});
uint8_t prefix = 0;
auto prefix_str = cidr_str_view.substr(pos_slash+1);
auto prefix_str = cidr_str.substr(pos_slash+1);
const auto * prefix_str_end = prefix_str.data() + prefix_str.size();
auto [parse_end, parse_error] = std::from_chars(prefix_str.data(), prefix_str_end, prefix);
@ -190,7 +189,7 @@ namespace DB
const auto & col_cidr = col_cidr_const.getDataColumn();
const auto addr = IPAddressVariant(col_addr.getDataAt(0));
const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(0));
const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(0).toView());
ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(1);
ColumnUInt8::Container & vec_res = col_res->getData();
@ -212,7 +211,7 @@ namespace DB
for (size_t i = 0; i < input_rows_count; ++i)
{
const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(i));
const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(i).toView());
vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0;
}
return col_res;
@ -223,7 +222,7 @@ namespace DB
{
const auto & col_cidr = col_cidr_const.getDataColumn();
const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(0));
const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(0).toView());
ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(input_rows_count);
ColumnUInt8::Container & vec_res = col_res->getData();
@ -244,7 +243,7 @@ namespace DB
for (size_t i = 0; i < input_rows_count; ++i)
{
const auto addr = IPAddressVariant(col_addr.getDataAt(i));
const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(i));
const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(i).toView());
vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0;
}

View File

@ -288,9 +288,9 @@ private:
ColumnFixedString::Offset offset = 0;
for (size_t i = 0; i < rows; ++i)
{
StringRef data = src.getDataAt(i);
std::string_view data = src.getDataAt(i).toView();
memcpy(&data_to[offset], data.data, std::min(n, data.size));
memcpy(&data_to[offset], data.data(), std::min(n, data.size()));
offset += n;
}
}

View File

@ -84,7 +84,7 @@ private:
auto h3index = h3index_source.getWhole();
// convert to std::string and get the c_str to have the delimiting \0 at the end.
auto h3index_str = StringRef(h3index.data, h3index.size).toString();
auto h3index_str = std::string(reinterpret_cast<const char *>(h3index.data), h3index.size);
res_data[row_num] = stringToH3(h3index_str.c_str());
if (res_data[row_num] == 0)

View File

@ -68,7 +68,7 @@ void writeException(const Exception & e, WriteBuffer & buf, bool with_stack_trac
template <typename F>
static inline void writeProbablyQuotedStringImpl(StringRef s, WriteBuffer & buf, F && write_quoted_string)
{
if (isValidIdentifier(std::string_view{s})
if (isValidIdentifier(s.toView())
/// These are valid identifiers but are problematic if present unquoted in an SQL query.
&& !(s.size == strlen("distinct") && 0 == strncasecmp(s.data, "distinct", strlen("distinct")))
&& !(s.size == strlen("all") && 0 == strncasecmp(s.data, "all", strlen("all"))))

View File

@ -1191,7 +1191,7 @@ void Context::setSettings(const Settings & settings_)
}
void Context::setSetting(StringRef name, const String & value)
void Context::setSetting(std::string_view name, const String & value)
{
auto lock = getLock();
if (name == "profile")
@ -1199,14 +1199,14 @@ void Context::setSetting(StringRef name, const String & value)
setCurrentProfile(value);
return;
}
settings.set(std::string_view{name}, value);
settings.set(name, value);
if (name == "readonly" || name == "allow_ddl" || name == "allow_introspection_functions")
calculateAccessRights();
}
void Context::setSetting(StringRef name, const Field & value)
void Context::setSetting(std::string_view name, const Field & value)
{
auto lock = getLock();
if (name == "profile")
@ -1214,7 +1214,7 @@ void Context::setSetting(StringRef name, const Field & value)
setCurrentProfile(value.safeGet<String>());
return;
}
settings.set(std::string_view{name}, value);
settings.set(name, value);
if (name == "readonly" || name == "allow_ddl" || name == "allow_introspection_functions")
calculateAccessRights();

View File

@ -607,8 +607,8 @@ public:
void setSettings(const Settings & settings_);
/// Set settings by name.
void setSetting(StringRef name, const String & value);
void setSetting(StringRef name, const Field & value);
void setSetting(std::string_view name, const String & value);
void setSetting(std::string_view name, const Field & value);
void applySettingChange(const SettingChange & change);
void applySettingsChanges(const SettingsChanges & changes);

View File

@ -210,6 +210,7 @@ scope_guard MergeTreeTransaction::beforeCommit()
void MergeTreeTransaction::afterCommit(CSN assigned_csn) noexcept
{
LockMemoryExceptionInThread memory_tracker_lock(VariableContext::Global);
/// Write allocated CSN into version metadata, so we will know CSN without reading it from transaction log
/// and we will be able to remove old entries from transaction log in ZK.
/// It's not a problem if the server crashes before the CSN is written, because we already have the TID in the data part and an entry in the log.
@ -245,6 +246,7 @@ void MergeTreeTransaction::afterCommit(CSN assigned_csn) noexcept
bool MergeTreeTransaction::rollback() noexcept
{
LockMemoryExceptionInThread memory_tracker_lock(VariableContext::Global);
CSN expected = Tx::UnknownCSN;
bool need_rollback = csn.compare_exchange_strong(expected, Tx::RolledBackCSN);

View File

@ -201,8 +201,8 @@ bool PartLog::addNewParts(
{
PartLogElement elem;
if (query_id.data && query_id.size)
elem.query_id.insert(0, query_id.data, query_id.size);
if (!query_id.empty())
elem.query_id.insert(0, query_id.data(), query_id.size());
elem.event_type = PartLogElement::NEW_PART; //-V1048

View File

@ -457,6 +457,7 @@ CSN TransactionLog::commitTransaction(const MergeTreeTransactionPtr & txn, bool
CSN TransactionLog::finalizeCommittedTransaction(MergeTreeTransaction * txn, CSN allocated_csn, scope_guard & state_guard) noexcept
{
LockMemoryExceptionInThread memory_tracker_lock(VariableContext::Global);
chassert(!allocated_csn == txn->isReadOnly());
if (allocated_csn)
{
@ -502,6 +503,7 @@ bool TransactionLog::waitForCSNLoaded(CSN csn) const
void TransactionLog::rollbackTransaction(const MergeTreeTransactionPtr & txn) noexcept
{
LockMemoryExceptionInThread memory_tracker_lock(VariableContext::Global);
LOG_TRACE(log, "Rolling back transaction {}{}", txn->tid,
std::uncaught_exceptions() ? fmt::format(" due to uncaught exception (code: {})", getCurrentExceptionCode()) : "");

View File

@ -55,7 +55,7 @@ void TransactionsInfoLogElement::fillCommonFields(const TransactionInfoContext *
event_time = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
thread_id = getThreadId();
query_id = CurrentThread::getQueryId().toString();
query_id = std::string(CurrentThread::getQueryId());
if (!context)
return;

View File

@ -28,8 +28,8 @@ ExtendedLogMessage ExtendedLogMessage::getFrom(const Poco::Message & base)
if (current_thread)
{
auto query_id_ref = CurrentThread::getQueryId();
if (query_id_ref.size)
msg_ext.query_id.assign(query_id_ref.data, query_id_ref.size);
if (!query_id_ref.empty())
msg_ext.query_id.assign(query_id_ref.data(), query_id_ref.size());
}
msg_ext.thread_id = getThreadId();

View File

@ -1,7 +1,6 @@
#pragma once
#include <Parsers/IAST.h>
#include <base/StringRef.h>
namespace DB

View File

@ -179,15 +179,15 @@ AvroSerializer::SchemaWithSerializeFn AvroSerializer::createSchemaWithSerializeF
if (traits->isStringAsString(column_name))
return {avro::StringSchema(), [](const IColumn & column, size_t row_num, avro::Encoder & encoder)
{
const StringRef & s = assert_cast<const ColumnString &>(column).getDataAt(row_num);
encoder.encodeString(s.toString());
const std::string_view & s = assert_cast<const ColumnString &>(column).getDataAt(row_num).toView();
encoder.encodeString(std::string(s));
}
};
else
return {avro::BytesSchema(), [](const IColumn & column, size_t row_num, avro::Encoder & encoder)
{
const StringRef & s = assert_cast<const ColumnString &>(column).getDataAt(row_num);
encoder.encodeBytes(reinterpret_cast<const uint8_t *>(s.data), s.size);
const std::string_view & s = assert_cast<const ColumnString &>(column).getDataAt(row_num).toView();
encoder.encodeBytes(reinterpret_cast<const uint8_t *>(s.data()), s.size());
}
};
case TypeIndex::FixedString:
@ -196,8 +196,8 @@ AvroSerializer::SchemaWithSerializeFn AvroSerializer::createSchemaWithSerializeF
auto schema = avro::FixedSchema(size, "fixed_" + toString(type_name_increment));
return {schema, [](const IColumn & column, size_t row_num, avro::Encoder & encoder)
{
const StringRef & s = assert_cast<const ColumnFixedString &>(column).getDataAt(row_num);
encoder.encodeFixed(reinterpret_cast<const uint8_t *>(s.data), s.size);
const std::string_view & s = assert_cast<const ColumnFixedString &>(column).getDataAt(row_num).toView();
encoder.encodeFixed(reinterpret_cast<const uint8_t *>(s.data()), s.size());
}};
}
case TypeIndex::Enum8:
@ -343,8 +343,8 @@ AvroSerializer::SchemaWithSerializeFn AvroSerializer::createSchemaWithSerializeF
auto keys_serializer = [](const IColumn & column, size_t row_num, avro::Encoder & encoder)
{
const StringRef & s = column.getDataAt(row_num);
encoder.encodeString(s.toString());
const std::string_view & s = column.getDataAt(row_num).toView();
encoder.encodeString(std::string(s));
};
const auto & values_type = map_type.getValueType();

View File

@ -365,8 +365,8 @@ namespace DB
}
else
{
StringRef string_ref = internal_column.getDataAt(string_i);
status = builder.Append(string_ref.data, string_ref.size);
std::string_view string_ref = internal_column.getDataAt(string_i).toView();
status = builder.Append(string_ref.data(), string_ref.size());
}
checkStatus(status, write_column->getName(), format_name);
}

View File

@ -98,16 +98,16 @@ void MsgPackRowOutputFormat::serializeField(const IColumn & column, DataTypePtr
}
case TypeIndex::String:
{
const StringRef & string = assert_cast<const ColumnString &>(column).getDataAt(row_num);
packer.pack_bin(string.size);
packer.pack_bin_body(string.data, string.size);
const std::string_view & string = assert_cast<const ColumnString &>(column).getDataAt(row_num).toView();
packer.pack_bin(string.size());
packer.pack_bin_body(string.data(), string.size());
return;
}
case TypeIndex::FixedString:
{
const StringRef & string = assert_cast<const ColumnFixedString &>(column).getDataAt(row_num);
packer.pack_bin(string.size);
packer.pack_bin_body(string.data, string.size);
const std::string_view & string = assert_cast<const ColumnFixedString &>(column).getDataAt(row_num).toView();
packer.pack_bin(string.size());
packer.pack_bin_body(string.data(), string.size());
return;
}
case TypeIndex::Array:
@ -178,18 +178,18 @@ void MsgPackRowOutputFormat::serializeField(const IColumn & column, DataTypePtr
{
WriteBufferFromOwnString buf;
writeBinary(uuid_column.getElement(row_num), buf);
StringRef uuid_bin = buf.stringRef();
packer.pack_bin(uuid_bin.size);
packer.pack_bin_body(uuid_bin.data, uuid_bin.size);
std::string_view uuid_bin = buf.stringRef().toView();
packer.pack_bin(uuid_bin.size());
packer.pack_bin_body(uuid_bin.data(), uuid_bin.size());
return;
}
case FormatSettings::MsgPackUUIDRepresentation::STR:
{
WriteBufferFromOwnString buf;
writeText(uuid_column.getElement(row_num), buf);
StringRef uuid_text = buf.stringRef();
packer.pack_str(uuid_text.size);
packer.pack_bin_body(uuid_text.data, uuid_text.size);
std::string_view uuid_text = buf.stringRef().toView();
packer.pack_str(uuid_text.size());
packer.pack_bin_body(uuid_text.data(), uuid_text.size());
return;
}
case FormatSettings::MsgPackUUIDRepresentation::EXT:
@ -198,9 +198,9 @@ void MsgPackRowOutputFormat::serializeField(const IColumn & column, DataTypePtr
UUID value = uuid_column.getElement(row_num);
writeBinaryBigEndian(value.toUnderType().items[0], buf);
writeBinaryBigEndian(value.toUnderType().items[1], buf);
StringRef uuid_ext = buf.stringRef();
std::string_view uuid_ext = buf.stringRef().toView();
packer.pack_ext(sizeof(UUID), int8_t(MsgPackExtensionTypes::UUIDType));
packer.pack_ext_body(uuid_ext.data, uuid_ext.size);
packer.pack_ext_body(uuid_ext.data(), uuid_ext.size());
return;
}
}

View File

@ -225,9 +225,9 @@ void ORCBlockOutputFormat::writeStrings(
}
string_orc_column.notNull[i] = 1;
const StringRef & string = string_column.getDataAt(i);
string_orc_column.data[i] = const_cast<char *>(string.data);
string_orc_column.length[i] = string.size;
const std::string_view & string = string_column.getDataAt(i).toView();
string_orc_column.data[i] = const_cast<char *>(string.data());
string_orc_column.length[i] = string.size();
}
string_orc_column.numElements = string_column.size();
}

View File

@ -17,8 +17,8 @@ RawBLOBRowOutputFormat::RawBLOBRowOutputFormat(
void RawBLOBRowOutputFormat::writeField(const IColumn & column, const ISerialization &, size_t row_num)
{
StringRef value = column.getDataAt(row_num);
out.write(value.data, value.size);
std::string_view value = column.getDataAt(row_num).toView();
out.write(value.data(), value.size());
}

View File

@ -1573,14 +1573,14 @@ namespace
auto & log_entry = *result.add_logs();
log_entry.set_time(column_time.getElement(row));
log_entry.set_time_microseconds(column_time_microseconds.getElement(row));
StringRef query_id = column_query_id.getDataAt(row);
log_entry.set_query_id(query_id.data, query_id.size);
std::string_view query_id = column_query_id.getDataAt(row).toView();
log_entry.set_query_id(query_id.data(), query_id.size());
log_entry.set_thread_id(column_thread_id.getElement(row));
log_entry.set_level(static_cast<::clickhouse::grpc::LogsLevel>(column_level.getElement(row)));
StringRef source = column_source.getDataAt(row);
log_entry.set_source(source.data, source.size);
StringRef text = column_text.getDataAt(row);
log_entry.set_text(text.data, text.size);
std::string_view source = column_source.getDataAt(row).toView();
log_entry.set_source(source.data(), source.size());
std::string_view text = column_text.getDataAt(row).toView();
log_entry.set_text(text.data(), text.size());
}
}
}

View File

@ -3,7 +3,6 @@
#include <Server/HTTP/HTTPServerRequest.h>
#include <Common/Exception.h>
#include <Common/StringUtils/StringUtils.h>
#include <base/StringRef.h>
#include <base/find_symbols.h>
#include <re2/re2.h>
@ -23,16 +22,16 @@ namespace ErrorCodes
using CompiledRegexPtr = std::shared_ptr<const re2::RE2>;
static inline bool checkRegexExpression(StringRef match_str, const CompiledRegexPtr & compiled_regex)
static inline bool checkRegexExpression(std::string_view match_str, const CompiledRegexPtr & compiled_regex)
{
int num_captures = compiled_regex->NumberOfCapturingGroups() + 1;
re2::StringPiece matches[num_captures];
re2::StringPiece match_input(match_str.data, match_str.size);
return compiled_regex->Match(match_input, 0, match_str.size, re2::RE2::Anchor::ANCHOR_BOTH, matches, num_captures);
re2::StringPiece match_input(match_str.data(), match_str.size());
return compiled_regex->Match(match_input, 0, match_str.size(), re2::RE2::Anchor::ANCHOR_BOTH, matches, num_captures);
}
static inline bool checkExpression(StringRef match_str, const std::pair<String, CompiledRegexPtr> & expression)
static inline bool checkExpression(std::string_view match_str, const std::pair<String, CompiledRegexPtr> & expression)
{
if (expression.second)
return checkRegexExpression(match_str, expression.second);
@ -71,7 +70,7 @@ static inline auto urlFilter(Poco::Util::AbstractConfiguration & config, const s
const auto & uri = request.getURI();
const auto & end = find_first_symbols<'?'>(uri.data(), uri.data() + uri.size());
return checkExpression(StringRef(uri.data(), end - uri.data()), expression);
return checkExpression(std::string_view(uri.data(), end - uri.data()), expression);
};
}
@ -93,7 +92,7 @@ static inline auto headersFilter(Poco::Util::AbstractConfiguration & config, con
for (const auto & [header_name, header_expression] : headers_expression)
{
const auto & header_value = request.get(header_name, "");
if (!checkExpression(StringRef(header_value.data(), header_value.size()), header_expression))
if (!checkExpression(std::string_view(header_value.data(), header_value.size()), header_expression))
return false;
}

View File

@ -27,7 +27,7 @@ MemoryTrackerThreadSwitcher::MemoryTrackerThreadSwitcher(MergeListEntry & merge_
prev_untracked_memory = current_thread->untracked_memory;
current_thread->untracked_memory = merge_list_entry->untracked_memory;
prev_query_id = current_thread->getQueryId().toString();
prev_query_id = std::string(current_thread->getQueryId());
current_thread->setQueryId(merge_list_entry->query_id);
}

View File

@ -15,9 +15,21 @@ bool PartitionPruner::canBePruned(const DataPart & part)
{
const auto & partition_value = part.partition.value;
std::vector<FieldRef> index_value(partition_value.begin(), partition_value.end());
for (auto & field : index_value)
{
// NULL_LAST
if (field.isNull())
field = POSITIVE_INFINITY;
}
is_valid = partition_condition.mayBeTrueInRange(
partition_value.size(), index_value.data(), index_value.data(), partition_key.data_types);
partition_filter_map.emplace(partition_id, is_valid);
if (!is_valid)
{
WriteBufferFromOwnString buf;
part.partition.serializeText(part.storage, buf, FormatSettings{});
LOG_TRACE(&Poco::Logger::get("PartitionPruner"), "Partition {} gets pruned", buf.str());
}
}
return !is_valid;
}
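
The added loop replaces NULL partition values with POSITIVE_INFINITY before calling mayBeTrueInRange, i.e. NULLs are treated as sorting after every real value (the NULL_LAST comment). A simplified, self-contained illustration of the same substitution, using std::optional in place of the real FieldRef type:

```cpp
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

/// Sketch of the NULL_LAST substitution above: a missing (NULL) partition value is
/// replaced with the largest possible value, so the point range [value, value] used
/// for the min/max check treats NULL as greater than any real value.
std::vector<int64_t> toIndexValue(const std::vector<std::optional<int64_t>> & partition_value)
{
    std::vector<int64_t> index_value;
    index_value.reserve(partition_value.size());
    for (const auto & field : partition_value)
        index_value.push_back(field ? *field : std::numeric_limits<int64_t>::max());
    return index_value;
}
```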

View File

@ -31,6 +31,7 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
extern const int BAD_ARGUMENTS;
extern const int POSTGRESQL_REPLICATION_INTERNAL_ERROR;
extern const int QUERY_NOT_ALLOWED;
}
class TemporaryReplicationSlot
@ -188,6 +189,17 @@ void PostgreSQLReplicationHandler::shutdown()
}
void PostgreSQLReplicationHandler::assertInitialized() const
{
if (!replication_handler_initialized)
{
throw Exception(
ErrorCodes::QUERY_NOT_ALLOWED,
"PostgreSQL replication initialization did not finish successfully. Please check logs for error messages");
}
}
void PostgreSQLReplicationHandler::startSynchronization(bool throw_on_error)
{
postgres::Connection replication_connection(connection_info, /* replication */true);
@ -239,7 +251,7 @@ void PostgreSQLReplicationHandler::startSynchronization(bool throw_on_error)
/// Throw in case of single MaterializedPostgreSQL storage, because initial setup is done immediately
/// (unlike database engine where it is done in a separate thread).
if (throw_on_error)
if (throw_on_error && !is_materialized_postgresql_database)
throw;
}
}
@ -314,6 +326,8 @@ void PostgreSQLReplicationHandler::startSynchronization(bool throw_on_error)
/// Do not rely anymore on saved storage pointers.
materialized_storages.clear();
replication_handler_initialized = true;
}
@ -393,12 +407,20 @@ void PostgreSQLReplicationHandler::cleanupFunc()
cleanup_task->scheduleAfter(CLEANUP_RESCHEDULE_MS);
}
PostgreSQLReplicationHandler::ConsumerPtr PostgreSQLReplicationHandler::getConsumer()
{
if (!consumer)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Consumer not initialized");
return consumer;
}
void PostgreSQLReplicationHandler::consumerFunc()
{
assertInitialized();
std::vector<std::pair<Int32, String>> skipped_tables;
bool schedule_now = consumer->consume(skipped_tables);
bool schedule_now = getConsumer()->consume(skipped_tables);
LOG_DEBUG(log, "checking for skipped tables: {}", skipped_tables.size());
if (!skipped_tables.empty())
@ -603,8 +625,10 @@ void PostgreSQLReplicationHandler::removeTableFromPublication(pqxx::nontransacti
void PostgreSQLReplicationHandler::setSetting(const SettingChange & setting)
{
assertInitialized();
consumer_task->deactivate();
consumer->setSetting(setting);
getConsumer()->setSetting(setting);
consumer_task->activateAndSchedule();
}
@ -758,6 +782,15 @@ std::set<String> PostgreSQLReplicationHandler::fetchRequiredTables()
{
pqxx::nontransaction tx(connection.getRef());
result_tables = fetchPostgreSQLTablesList(tx, schema_list.empty() ? postgres_schema : schema_list);
std::string tables_string;
for (const auto & table : result_tables)
{
if (!tables_string.empty())
tables_string += ", ";
tables_string += table;
}
LOG_DEBUG(log, "Tables list was fetched from PostgreSQL directly: {}", tables_string);
}
}
}
@ -824,6 +857,8 @@ PostgreSQLTableStructurePtr PostgreSQLReplicationHandler::fetchTableStructure(
void PostgreSQLReplicationHandler::addTableToReplication(StorageMaterializedPostgreSQL * materialized_storage, const String & postgres_table_name)
{
assertInitialized();
/// Note: we have to ensure that replication consumer task is stopped when we reload table, because otherwise
/// it can read wal beyond start lsn position (from which this table is being loaded), which will result in losing data.
consumer_task->deactivate();
@ -858,7 +893,7 @@ void PostgreSQLReplicationHandler::addTableToReplication(StorageMaterializedPost
}
/// Pass storage to consumer and lsn position, from which to start receiving replication messages for this table.
consumer->addNested(postgres_table_name, nested_storage_info, start_lsn);
getConsumer()->addNested(postgres_table_name, nested_storage_info, start_lsn);
LOG_TRACE(log, "Table `{}` successfully added to replication", postgres_table_name);
}
catch (...)
@ -876,6 +911,8 @@ void PostgreSQLReplicationHandler::addTableToReplication(StorageMaterializedPost
void PostgreSQLReplicationHandler::removeTableFromReplication(const String & postgres_table_name)
{
assertInitialized();
consumer_task->deactivate();
try
{
@ -887,7 +924,7 @@ void PostgreSQLReplicationHandler::removeTableFromReplication(const String & pos
}
/// Pass storage to consumer and lsn position, from which to start receiving replication messages for this table.
consumer->removeNested(postgres_table_name);
getConsumer()->removeNested(postgres_table_name);
}
catch (...)
{
@ -966,7 +1003,7 @@ void PostgreSQLReplicationHandler::reloadFromSnapshot(const std::vector<std::pai
nested_storage->getStorageID().getNameForLogs(), nested_sample_block.dumpStructure());
/// Pass pointer to new nested table into replication consumer, remove current table from skip list and set start lsn position.
consumer->updateNested(table_name, StorageInfo(nested_storage, std::move(table_attributes)), relation_id, start_lsn);
getConsumer()->updateNested(table_name, StorageInfo(nested_storage, std::move(table_attributes)), relation_id, start_lsn);
auto table_to_drop = DatabaseCatalog::instance().getTable(StorageID(temp_table_id.database_name, temp_table_id.table_name, table_id.uuid), nested_context);
auto drop_table_id = table_to_drop->getStorageID();
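
The new assertInitialized()/getConsumer() pair turns "handler used before startup finished" and "consumer is null" into explicit exceptions instead of crashes or silent misbehaviour. A stripped-down sketch of that guard pattern with simplified names (not the actual ClickHouse classes):

```cpp
#include <memory>
#include <stdexcept>

/// Stripped-down sketch of the guard pattern introduced above; names are simplified.
class ReplicationHandlerSketch
{
public:
    void finishStartup()
    {
        consumer = std::make_shared<Consumer>();
        initialized = true;
    }

    void consumerFunc()
    {
        assertInitialized();       /// user-facing error: startup has not finished yet
        getConsumer()->consume();  /// internal invariant: consumer must exist here
    }

private:
    struct Consumer { void consume() {} };
    using ConsumerPtr = std::shared_ptr<Consumer>;

    void assertInitialized() const
    {
        if (!initialized)
            throw std::runtime_error("Replication initialization did not finish successfully");
    }

    ConsumerPtr getConsumer() const
    {
        if (!consumer)
            throw std::logic_error("Consumer not initialized");
        return consumer;
    }

    bool initialized = false;
    ConsumerPtr consumer;
};
```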

View File

@ -18,6 +18,8 @@ class PostgreSQLReplicationHandler : WithContext
friend class TemporaryReplicationSlot;
public:
using ConsumerPtr = std::shared_ptr<MaterializedPostgreSQLConsumer>;
PostgreSQLReplicationHandler(
const String & replication_identifier,
const String & postgres_database_,
@ -87,6 +89,8 @@ private:
void consumerFunc();
ConsumerPtr getConsumer();
StorageInfo loadFromSnapshot(postgres::Connection & connection, std::string & snapshot_name, const String & table_name, StorageMaterializedPostgreSQL * materialized_storage);
void reloadFromSnapshot(const std::vector<std::pair<Int32, String>> & relation_data);
@ -97,6 +101,8 @@ private:
std::pair<String, String> getSchemaAndTableName(const String & table_name) const;
void assertInitialized() const;
Poco::Logger * log;
/// If it is not attach, i.e. a create query, then if publication already exists - always drop it.
@ -134,7 +140,7 @@ private:
String replication_slot, publication_name;
/// Replication consumer. Manages decoding of replication stream and syncing into tables.
std::shared_ptr<MaterializedPostgreSQLConsumer> consumer;
ConsumerPtr consumer;
BackgroundSchedulePool::TaskHolder startup_task;
BackgroundSchedulePool::TaskHolder consumer_task;
@ -146,6 +152,8 @@ private:
MaterializedStorages materialized_storages;
UInt64 milliseconds_to_wait;
bool replication_handler_initialized = false;
};
}

View File

@ -90,10 +90,10 @@ namespace
const ucontext_t signal_context = *reinterpret_cast<ucontext_t *>(context);
stack_trace = StackTrace(signal_context);
StringRef query_id = CurrentThread::getQueryId();
query_id_size = std::min(query_id.size, max_query_id_size);
if (query_id.data && query_id.size)
memcpy(query_id_data, query_id.data, query_id_size);
std::string_view query_id = CurrentThread::getQueryId();
query_id_size = std::min(query_id.size(), max_query_id_size);
if (!query_id.empty())
memcpy(query_id_data, query_id.data(), query_id_size);
/// This is unneeded (because we synchronize through pipe) but makes TSan happy.
data_ready_num.store(notification_num, std::memory_order_release);

View File

@ -1,70 +1,495 @@
#!/usr/bin/env python3
"""
A plan:
- TODO: consider receiving the GH objects cache from S3, but it's really only
  a few API requests currently
- Get all open release PRs (20.10, 21.8, 22.5, etc.)
- Get all pull-requests merged since the merge-base date of the oldest release PR
  that have the label pr-must-backport or a version-specific label such as
  v21.8-must-backport, but do not have pr-backported
- Iterate over the fetched PRs:
- for pr-must-backport:
- check if all backport-PRs are created. If yes,
set pr-backported label and finish
- If not, create either cherrypick PRs or merge the cherrypick (in the same
  stage, if mergeable) and create backport-PRs
- If successful, set the pr-backported label on the PR
- for version-specific labels:
- the same, check, cherry-pick, backport, pr-backported
Cherry-pick stage:
- From time to time the cherry-pick fails, e.g. when it was already done manually. In that
  case we check whether it's even needed, and mark the release as done somehow.
"""
import argparse
import logging
import os
import subprocess
from contextlib import contextmanager
from datetime import date, timedelta
from subprocess import CalledProcessError
from typing import List, Optional
from env_helper import GITHUB_WORKSPACE, TEMP_PATH
from env_helper import TEMP_PATH
from get_robot_token import get_best_robot_token
from git_helper import git_runner, is_shallow
from github_helper import (
GitHub,
PullRequest,
PullRequests,
Repository,
)
from ssh import SSHKey
from cherry_pick_utils.backport import Backport
from cherry_pick_utils.cherrypick import CherryPick
class Labels:
LABEL_MUST_BACKPORT = "pr-must-backport"
LABEL_BACKPORT = "pr-backport"
LABEL_BACKPORTED = "pr-backported"
LABEL_CHERRYPICK = "pr-cherrypick"
LABEL_DO_NOT_TEST = "do not test"
class ReleaseBranch:
CHERRYPICK_DESCRIPTION = """This pull-request is a first step of an automated \
backporting.
It contains changes similar to those produced by calling `git cherry-pick` locally.
If you intend to continue backporting these changes, then resolve all conflicts if any.
Otherwise, if you do not want to backport them, then just close this pull-request.
The check results do not matter at this step - you can safely ignore them.
Also this pull-request will be merged automatically once it reaches the mergeable state, \
but you can always merge it manually.
"""
BACKPORT_DESCRIPTION = """This pull-request is a last step of an automated \
backporting.
Treat it as a standard pull-request: look at the checks and resolve conflicts.
Merge it only if you intend to backport changes to the target branch, otherwise just \
close it.
"""
REMOTE = ""
def __init__(self, name: str, pr: PullRequest):
self.name = name
self.pr = pr
self.cherrypick_branch = f"cherrypick/{name}/{pr.merge_commit_sha}"
self.backport_branch = f"backport/{name}/{pr.number}"
self.cherrypick_pr = None # type: Optional[PullRequest]
self.backport_pr = None # type: Optional[PullRequest]
self._backported = None # type: Optional[bool]
self.git_prefix = ( # All commits to cherrypick are done as robot-clickhouse
"git -c user.email=robot-clickhouse@clickhouse.com "
"-c user.name=robot-clickhouse -c commit.gpgsign=false"
)
self.pre_check()
def pre_check(self):
branch_updated = git_runner(
f"git branch -a --contains={self.pr.merge_commit_sha} "
f"{self.REMOTE}/{self.name}"
)
if branch_updated:
self._backported = True
def pop_prs(self, prs: PullRequests):
to_pop = [] # type: List[int]
for i, pr in enumerate(prs):
if self.name not in pr.head.ref:
continue
if pr.head.ref.startswith(f"cherrypick/{self.name}"):
self.cherrypick_pr = pr
to_pop.append(i)
elif pr.head.ref.startswith(f"backport/{self.name}"):
self.backport_pr = pr
to_pop.append(i)
else:
logging.error(
"PR #%s doesn't head ref starting with known suffix",
pr.number,
)
for i in reversed(to_pop):
# Going from the tail to keep the order and pop greater index first
prs.pop(i)
def process(self, dry_run: bool):
if self.backported:
return
if not self.cherrypick_pr:
if dry_run:
logging.info(
"DRY RUN: Would create cherrypick PR for #%s", self.pr.number
)
return
self.create_cherrypick()
if self.backported:
return
if self.cherrypick_pr is not None:
# Try to merge cherrypick instantly
if self.cherrypick_pr.mergeable and self.cherrypick_pr.state != "closed":
self.cherrypick_pr.merge()
# The PR needs update, since PR.merge doesn't update the object
self.cherrypick_pr.update()
if self.cherrypick_pr.merged:
if dry_run:
logging.info(
"DRY RUN: Would create backport PR for #%s", self.pr.number
)
return
self.create_backport()
return
elif self.cherrypick_pr.state == "closed":
logging.info(
"The cherrypick PR #%s for PR #%s is discarded",
self.cherrypick_pr.number,
self.pr.number,
)
self._backported = True
return
logging.info(
"Cherrypick PR #%s for PR #%s have conflicts and unable to be merged",
self.cherrypick_pr.number,
self.pr.number,
)
def create_cherrypick(self):
# First, create backport branch:
# Checkout release branch with discarding every change
git_runner(f"{self.git_prefix} checkout -f {self.name}")
# Create or reset backport branch
git_runner(f"{self.git_prefix} checkout -B {self.backport_branch}")
# Merge all changes from the PR's first parent commit w/o applying anything.
# It allows creating a merge commit as if it were a cherry-pick
first_parent = git_runner(f"git rev-parse {self.pr.merge_commit_sha}^1")
git_runner(f"{self.git_prefix} merge -s ours --no-edit {first_parent}")
# Second step, create cherrypick branch
git_runner(
f"{self.git_prefix} branch -f "
f"{self.cherrypick_branch} {self.pr.merge_commit_sha}"
)
# Check if there are actually any changes between the branches. If not, no
# other actions are required. This can happen when the changes were already
# backported to the release branch manually
try:
output = git_runner(
f"{self.git_prefix} merge --no-commit --no-ff {self.cherrypick_branch}"
)
# 'up-to-date', 'up to date', who knows what else (╯°v°)╯ ^┻━┻
if output.startswith("Already up") and output.endswith("date."):
# The changes are already in the release branch, we are done here
logging.info(
"Release branch %s already contain changes from %s",
self.name,
self.pr.number,
)
self._backported = True
return
except CalledProcessError:
# There are most probably conflicts, they'll be resolved in PR
git_runner(f"{self.git_prefix} reset --merge")
else:
# There are changes to apply, so continue
git_runner(f"{self.git_prefix} reset --merge")
# Push, create the cherrypick PR, label and assign it
for branch in [self.cherrypick_branch, self.backport_branch]:
git_runner(f"{self.git_prefix} push -f {self.REMOTE} {branch}:{branch}")
self.cherrypick_pr = self.pr.base.repo.create_pull(
title=f"Cherry pick #{self.pr.number} to {self.name}: {self.pr.title}",
body=f"Original pull-request #{self.pr.number}\n\n"
f"{self.CHERRYPICK_DESCRIPTION}",
base=self.backport_branch,
head=self.cherrypick_branch,
)
self.cherrypick_pr.add_to_labels(Labels.LABEL_CHERRYPICK)
self.cherrypick_pr.add_to_labels(Labels.LABEL_DO_NOT_TEST)
self.cherrypick_pr.add_to_assignees(self.pr.assignee)
self.cherrypick_pr.add_to_assignees(self.pr.user)
def create_backport(self):
# Checkout the backport branch from the remote and rework all changes so they
# apply as a single cherry-pick commit on top of the release branch
git_runner(f"{self.git_prefix} checkout -f {self.backport_branch}")
git_runner(
f"{self.git_prefix} pull --ff-only {self.REMOTE} {self.backport_branch}"
)
merge_base = git_runner(
f"{self.git_prefix} merge-base "
f"{self.REMOTE}/{self.name} {self.backport_branch}"
)
git_runner(f"{self.git_prefix} reset --soft {merge_base}")
title = f"Backport #{self.pr.number} to {self.name}: {self.pr.title}"
git_runner(f"{self.git_prefix} commit -a --allow-empty -F -", input=title)
# Push with force, create the backport PR, label and assign it
git_runner(
f"{self.git_prefix} push -f {self.REMOTE} "
f"{self.backport_branch}:{self.backport_branch}"
)
self.backport_pr = self.pr.base.repo.create_pull(
title=title,
body=f"Original pull-request #{self.pr.number}\n"
f"Cherry-pick pull-request #{self.cherrypick_pr.number}\n\n"
f"{self.BACKPORT_DESCRIPTION}",
base=self.name,
head=self.backport_branch,
)
self.backport_pr.add_to_labels(Labels.LABEL_BACKPORT)
self.backport_pr.add_to_assignees(self.pr.assignee)
self.backport_pr.add_to_assignees(self.pr.user)
@property
def backported(self) -> bool:
if self._backported is not None:
return self._backported
return self.backport_pr is not None
def __repr__(self):
return self.name
class Backport:
def __init__(self, gh: GitHub, repo: str, dry_run: bool):
self.gh = gh
self._repo_name = repo
self.dry_run = dry_run
self._query = f"type:pr repo:{repo}"
self._remote = ""
self._repo = None # type: Optional[Repository]
self.release_prs = [] # type: PullRequests
self.release_branches = [] # type: List[str]
self.labels_to_backport = [] # type: List[str]
self.prs_for_backport = [] # type: PullRequests
self.error = None # type: Optional[Exception]
@property
def remote(self) -> str:
if not self._remote:
# lines of "origin git@github.com:ClickHouse/ClickHouse.git (fetch)"
remotes = git_runner("git remote -v").split("\n")
# We need the first word from the first matching result
self._remote = tuple(
remote.split(maxsplit=1)[0]
for remote in remotes
if f"github.com/{self._repo_name}" in remote # https
or f"github.com:{self._repo_name}" in remote # ssh
)[0]
git_runner(f"git fetch {self._remote}")
ReleaseBranch.REMOTE = self._remote
return self._remote
def receive_release_prs(self):
logging.info("Getting release PRs")
self.release_prs = self.gh.get_pulls_from_search(
query=f"{self._query} is:open",
sort="created",
order="asc",
label="release",
)
self.release_branches = [pr.head.ref for pr in self.release_prs]
self.labels_to_backport = [
f"v{branch}-must-backport" for branch in self.release_branches
]
logging.info("Active releases: %s", ", ".join(self.release_branches))
def receive_prs_for_backport(self):
# The commit is the oldest open release branch's merge-base
since_commit = git_runner(
f"git merge-base {self.remote}/{self.release_branches[0]} "
f"{self.remote}/{self.default_branch}"
)
since_date = date.fromisoformat(
git_runner.run(f"git log -1 --format=format:%cs {since_commit}")
)
# To avoid possible TZ issues
tomorrow = date.today() + timedelta(days=1)
logging.info("Receive PRs suppose to be backported")
self.prs_for_backport = self.gh.get_pulls_from_search(
query=f"{self._query} -label:pr-backported",
label=",".join(self.labels_to_backport + [Labels.LABEL_MUST_BACKPORT]),
merged=[since_date, tomorrow],
)
logging.info(
"PRs to be backported:\n %s",
"\n ".join([pr.html_url for pr in self.prs_for_backport]),
)
def process_backports(self):
for pr in self.prs_for_backport:
try:
self.process_pr(pr)
except Exception as e:
logging.error(
"During processing the PR #%s error occured: %s", pr.number, e
)
self.error = e
def process_pr(self, pr: PullRequest):
pr_labels = [label.name for label in pr.labels]
if Labels.LABEL_MUST_BACKPORT in pr_labels:
branches = [
ReleaseBranch(br, pr) for br in self.release_branches
] # type: List[ReleaseBranch]
else:
branches = [
ReleaseBranch(br, pr)
for br in [
label.split("-", 1)[0][1:] # v21.8-must-backport
for label in pr_labels
if label in self.labels_to_backport
]
]
if not branches:
# This is definitely some error. There must be at least one branch.
# It also makes the whole program exit code non-zero
self.error = Exception(
f"There are no branches to backport PR #{pr.number}, logical error"
)
raise self.error
logging.info(
" PR #%s is suppose to be backported to %s",
pr.number,
", ".join(map(str, branches)),
)
# All PRs for cherrypick and backport branches as heads
query_suffix = " ".join(
[
f"head:{branch.backport_branch} head:{branch.cherrypick_branch}"
for branch in branches
]
)
bp_cp_prs = self.gh.get_pulls_from_search(
query=f"{self._query} {query_suffix}",
)
for br in branches:
br.pop_prs(bp_cp_prs)
if bp_cp_prs:
# This is definitely some error. All PRs must be consumed by
# branches with ReleaseBranch.pop_prs. It also makes the whole
# program exit code non-zero
self.error = Exception(
"The following PRs are not filtered by release branches:\n"
"\n".join(map(str, bp_cp_prs))
)
raise self.error
if all(br.backported for br in branches):
# Let's check if the PR is already backported
self.mark_pr_backported(pr)
return
for br in branches:
br.process(self.dry_run)
if all(br.backported for br in branches):
# And check it again after the run
self.mark_pr_backported(pr)
def mark_pr_backported(self, pr: PullRequest):
if self.dry_run:
logging.info("DRY RUN: would mark PR #%s as done", pr.number)
return
pr.add_to_labels(Labels.LABEL_BACKPORTED)
logging.info(
"PR #%s is successfully labeled with `%s`",
pr.number,
Labels.LABEL_BACKPORTED,
)
@property
def repo(self) -> Repository:
if self._repo is None:
try:
self._repo = self.release_prs[0].base.repo
except IndexError as exc:
raise Exception(
"`repo` is available only after the `receive_release_prs`"
) from exc
return self._repo
@property
def default_branch(self) -> str:
return self.repo.default_branch
def parse_args():
parser = argparse.ArgumentParser("Create cherry-pick and backport PRs")
parser.add_argument("--token", help="github token, if not set, used from smm")
parser.add_argument(
"--repo", default="ClickHouse/ClickHouse", help="repo owner/name"
)
parser.add_argument("--dry-run", action="store_true", help="do not create anything")
parser.add_argument(
"--debug-helpers",
action="store_true",
help="add debug logging for git_helper and github_helper",
)
return parser.parse_args()
@contextmanager
def clear_repo():
orig_ref = git_runner("git branch --show-current") or git_runner(
"git rev-parse HEAD"
)
try:
yield
except (Exception, KeyboardInterrupt):
git_runner(f"git checkout -f {orig_ref}")
raise
else:
git_runner(f"git checkout -f {orig_ref}")
@contextmanager
def stash():
need_stash = bool(git_runner("git diff HEAD"))
if need_stash:
git_runner("git stash push --no-keep-index -m 'running cherry_pick.py'")
try:
with clear_repo():
yield
except (Exception, KeyboardInterrupt):
if need_stash:
git_runner("git stash pop")
raise
else:
if need_stash:
git_runner("git stash pop")
def main():
if not os.path.exists(TEMP_PATH):
os.makedirs(TEMP_PATH)
args = parse_args()
if args.debug_helpers:
logging.getLogger("github_helper").setLevel(logging.DEBUG)
logging.getLogger("git_helper").setLevel(logging.DEBUG)
token = args.token or get_best_robot_token()
bp = Backport(
token,
os.environ.get("REPO_OWNER"),
os.environ.get("REPO_NAME"),
os.environ.get("REPO_TEAM"),
)
cherry_pick = CherryPick(
token,
os.environ.get("REPO_OWNER"),
os.environ.get("REPO_NAME"),
os.environ.get("REPO_TEAM"),
1,
"master",
)
# Use the same _gh in both objects to have a proper cost
# pylint: disable=protected-access
for key in bp._gh.api_costs:
if key in cherry_pick._gh.api_costs:
bp._gh.api_costs[key] += cherry_pick._gh.api_costs[key]
for key in cherry_pick._gh.api_costs:
if key not in bp._gh.api_costs:
bp._gh.api_costs[key] = cherry_pick._gh.api_costs[key]
cherry_pick._gh = bp._gh
# pylint: enable=protected-access
def cherrypick_run(pr_data, branch):
cherry_pick.update_pr_branch(pr_data, branch)
return cherry_pick.execute(GITHUB_WORKSPACE, args.dry_run)
try:
bp.execute(GITHUB_WORKSPACE, "origin", None, cherrypick_run)
except subprocess.CalledProcessError as e:
logging.error(e.output)
gh = GitHub(token, per_page=100)
bp = Backport(gh, args.repo, args.dry_run)
bp.gh.cache_path = f"{TEMP_PATH}/gh_cache"
bp.receive_release_prs()
bp.receive_prs_for_backport()
bp.process_backports()
if bp.error is not None:
logging.error("Finished successfully, but errors occured!")
raise bp.error
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
if not os.path.exists(TEMP_PATH):
os.makedirs(TEMP_PATH)
assert not is_shallow()
with stash():
if os.getenv("ROBOT_CLICKHOUSE_SSH_KEY", ""):
with SSHKey("ROBOT_CLICKHOUSE_SSH_KEY"):
main()

View File

@ -1,2 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

View File

@ -1,190 +0,0 @@
# -*- coding: utf-8 -*-
import argparse
import logging
import os
import re
import sys
sys.path.append(os.path.dirname(__file__))
from cherrypick import CherryPick
from query import Query as RemoteRepo
from local import Repository as LocalRepo
class Backport:
def __init__(self, token, owner, name, team):
self._gh = RemoteRepo(
token, owner=owner, name=name, team=team, max_page_size=60, min_page_size=7
)
self._token = token
self.default_branch_name = self._gh.default_branch
self.ssh_url = self._gh.ssh_url
def getPullRequests(self, from_commit):
return self._gh.get_pull_requests(from_commit)
def getBranchesWithRelease(self):
branches = set()
for pull_request in self._gh.find_pull_requests("release"):
branches.add(pull_request["headRefName"])
return branches
def execute(self, repo, upstream, until_commit, run_cherrypick):
repo = LocalRepo(repo, upstream, self.default_branch_name)
all_branches = repo.get_release_branches() # [(branch_name, base_commit)]
release_branches = self.getBranchesWithRelease()
branches = []
# iterate over all branches to preserve their precedence.
for branch in all_branches:
if branch[0] in release_branches:
branches.append(branch)
if not branches:
logging.info("No release branches found!")
return
logging.info(
"Found release branches: %s", ", ".join([br[0] for br in branches])
)
if not until_commit:
until_commit = branches[0][1]
pull_requests = self.getPullRequests(until_commit)
backport_map = {}
pr_map = {pr["number"]: pr for pr in pull_requests}
RE_MUST_BACKPORT = re.compile(r"^v(\d+\.\d+)-must-backport$")
RE_NO_BACKPORT = re.compile(r"^v(\d+\.\d+)-no-backport$")
RE_BACKPORTED = re.compile(r"^v(\d+\.\d+)-backported$")
# pull-requests are sorted by ancestry from the most recent.
for pr in pull_requests:
while repo.comparator(branches[-1][1]) >= repo.comparator(
pr["mergeCommit"]["oid"]
):
logging.info(
"PR #%s is already inside %s. Dropping this branch for further PRs",
pr["number"],
branches[-1][0],
)
branches.pop()
logging.info("Processing PR #%s", pr["number"])
assert len(branches) != 0
branch_set = {branch[0] for branch in branches}
# First pass. Find all must-backports
for label in pr["labels"]["nodes"]:
if label["name"] == "pr-must-backport":
backport_map[pr["number"]] = branch_set.copy()
continue
matched = RE_MUST_BACKPORT.match(label["name"])
if matched:
if pr["number"] not in backport_map:
backport_map[pr["number"]] = set()
backport_map[pr["number"]].add(matched.group(1))
# Second pass. Find all no-backports
for label in pr["labels"]["nodes"]:
if label["name"] == "pr-no-backport" and pr["number"] in backport_map:
del backport_map[pr["number"]]
break
matched_no_backport = RE_NO_BACKPORT.match(label["name"])
matched_backported = RE_BACKPORTED.match(label["name"])
if (
matched_no_backport
and pr["number"] in backport_map
and matched_no_backport.group(1) in backport_map[pr["number"]]
):
backport_map[pr["number"]].remove(matched_no_backport.group(1))
logging.info(
"\tskipping %s because of forced no-backport",
matched_no_backport.group(1),
)
elif (
matched_backported
and pr["number"] in backport_map
and matched_backported.group(1) in backport_map[pr["number"]]
):
backport_map[pr["number"]].remove(matched_backported.group(1))
logging.info(
"\tskipping %s because it's already backported manually",
matched_backported.group(1),
)
for pr, branches in list(backport_map.items()):
statuses = []
for branch in branches:
branch_status = run_cherrypick(pr_map[pr], branch)
statuses.append(f"{branch}, and the status is: {branch_status}")
logging.info(
"PR #%s needs to be backported to:\n\t%s", pr, "\n\t".join(statuses)
)
# print API costs
logging.info("\nGitHub API total costs for backporting per query:")
for name, value in list(self._gh.api_costs.items()):
logging.info("%s : %s", name, value)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--token", type=str, required=True, help="token for Github access"
)
parser.add_argument(
"--repo",
type=str,
required=True,
help="path to full repository",
metavar="PATH",
)
parser.add_argument(
"--til", type=str, help="check PRs from HEAD til this commit", metavar="COMMIT"
)
parser.add_argument(
"--dry-run",
action="store_true",
help="do not create or merge any PRs",
default=False,
)
parser.add_argument(
"--verbose",
"-v",
action="store_true",
help="more verbose output",
default=False,
)
parser.add_argument(
"--upstream",
"-u",
type=str,
help="remote name of upstream in repository",
default="origin",
)
args = parser.parse_args()
if args.verbose:
logging.basicConfig(
format="%(message)s", stream=sys.stdout, level=logging.DEBUG
)
else:
logging.basicConfig(format="%(message)s", stream=sys.stdout, level=logging.INFO)
cherry_pick = CherryPick(
args.token, "ClickHouse", "ClickHouse", "core", 1, "master"
)
def cherrypick_run(pr_data, branch):
cherry_pick.update_pr_branch(pr_data, branch)
return cherry_pick.execute(args.repo, args.dry_run)
bp = Backport(args.token, "ClickHouse", "ClickHouse", "core")
bp.execute(args.repo, args.upstream, args.til, cherrypick_run)

View File

@ -1,319 +0,0 @@
# -*- coding: utf-8 -*-
"""
Backports changes from a PR to a release branch.
The implementation requires multiple separate runs.
The first run does the following:
1. Merge the release branch with the first parent of the PR's merge-commit (using the 'ours' strategy). (branch: backport/{branch}/{pr})
2. Create a temporary branch over the merge-commit to use it for PR creation. (branch: cherrypick/{branch}/{merge_commit})
3. Create a PR from the temporary branch to the backport branch (emulating cherry-pick).
The second run checks whether the PR from the previous run is merged, or at least mergeable; if it is not merged, it tries to merge it.
The third run creates a PR from the backport branch (with the previous PR merged) to the release branch.
"""
import argparse
from enum import Enum
import logging
import os
import subprocess
import sys
sys.path.append(os.path.dirname(__file__))
from query import Query as RemoteRepo
class CherryPick:
class Status(Enum):
DISCARDED = "discarded"
NOT_INITIATED = "not started"
FIRST_MERGEABLE = "waiting for 1st stage"
FIRST_CONFLICTS = "conflicts on 1st stage"
SECOND_MERGEABLE = "waiting for 2nd stage"
SECOND_CONFLICTS = "conflicts on 2nd stage"
MERGED = "backported"
def _run(self, args):
out = subprocess.check_output(args).rstrip()
logging.debug(out)
return out
def __init__(self, token, owner, name, team, pr_number, target_branch):
self._gh = RemoteRepo(token, owner=owner, name=name, team=team)
self._pr = self._gh.get_pull_request(pr_number)
self.target_branch = target_branch
self.ssh_url = self._gh.ssh_url
# TODO: check if pull-request is merged.
self.update_pr_branch(self._pr, self.target_branch)
def update_pr_branch(self, pr_data, target_branch):
"""The method is here to avoid unnecessary creation of new objects"""
self._pr = pr_data
self.target_branch = target_branch
self.merge_commit_oid = self._pr["mergeCommit"]["oid"]
self.backport_branch = f"backport/{target_branch}/{pr_data['number']}"
self.cherrypick_branch = f"cherrypick/{target_branch}/{self.merge_commit_oid}"
def getCherryPickPullRequest(self):
return self._gh.find_pull_request(
base=self.backport_branch, head=self.cherrypick_branch
)
def createCherryPickPullRequest(self, repo_path):
DESCRIPTION = (
"This pull-request is a first step of an automated backporting.\n"
"It contains changes like after calling a local command `git cherry-pick`.\n"
"If you intend to continue backporting this changes, then resolve all conflicts if any.\n"
"Otherwise, if you do not want to backport them, then just close this pull-request.\n"
"\n"
"The check results does not matter at this step - you can safely ignore them.\n"
"Also this pull-request will be merged automatically as it reaches the mergeable state, but you always can merge it manually.\n"
)
# FIXME: replace with something better than os.system()
git_prefix = [
"git",
"-C",
repo_path,
"-c",
"user.email=robot-clickhouse@yandex-team.ru",
"-c",
"user.name=robot-clickhouse",
]
base_commit_oid = self._pr["mergeCommit"]["parents"]["nodes"][0]["oid"]
# Create separate branch for backporting, and make it look like real cherry-pick.
self._run(git_prefix + ["checkout", "-f", self.target_branch])
self._run(git_prefix + ["checkout", "-B", self.backport_branch])
self._run(git_prefix + ["merge", "-s", "ours", "--no-edit", base_commit_oid])
# Create secondary branch to allow pull request with cherry-picked commit.
self._run(
git_prefix + ["branch", "-f", self.cherrypick_branch, self.merge_commit_oid]
)
self._run(
git_prefix
+ [
"push",
"-f",
"origin",
"{branch}:{branch}".format(branch=self.backport_branch),
]
)
self._run(
git_prefix
+ [
"push",
"-f",
"origin",
"{branch}:{branch}".format(branch=self.cherrypick_branch),
]
)
# Create pull-request like a local cherry-pick
title = self._pr["title"].replace('"', r"\"")
pr = self._gh.create_pull_request(
source=self.cherrypick_branch,
target=self.backport_branch,
title=(
f'Cherry pick #{self._pr["number"]} '
f"to {self.target_branch}: "
f"{title}"
),
description=f'Original pull-request #{self._pr["number"]}\n\n{DESCRIPTION}',
)
# FIXME: use `team` to leave a single eligible assignee.
self._gh.add_assignee(pr, self._pr["author"])
self._gh.add_assignee(pr, self._pr["mergedBy"])
self._gh.set_label(pr, "do not test")
self._gh.set_label(pr, "pr-cherrypick")
return pr
def mergeCherryPickPullRequest(self, cherrypick_pr):
return self._gh.merge_pull_request(cherrypick_pr["id"])
def getBackportPullRequest(self):
return self._gh.find_pull_request(
base=self.target_branch, head=self.backport_branch
)
def createBackportPullRequest(self, cherrypick_pr, repo_path):
DESCRIPTION = (
"This pull-request is a last step of an automated backporting.\n"
"Treat it as a standard pull-request: look at the checks and resolve conflicts.\n"
"Merge it only if you intend to backport changes to the target branch, otherwise just close it.\n"
)
git_prefix = [
"git",
"-C",
repo_path,
"-c",
"user.email=robot-clickhouse@clickhouse.com",
"-c",
"user.name=robot-clickhouse",
]
title = self._pr["title"].replace('"', r"\"")
pr_title = f"Backport #{self._pr['number']} to {self.target_branch}: {title}"
self._run(git_prefix + ["checkout", "-f", self.backport_branch])
self._run(git_prefix + ["pull", "--ff-only", "origin", self.backport_branch])
self._run(
git_prefix
+ [
"reset",
"--soft",
self._run(
git_prefix
+ [
"merge-base",
"origin/" + self.target_branch,
self.backport_branch,
]
),
]
)
self._run(git_prefix + ["commit", "-a", "--allow-empty", "-m", pr_title])
self._run(
git_prefix
+ [
"push",
"-f",
"origin",
"{branch}:{branch}".format(branch=self.backport_branch),
]
)
pr = self._gh.create_pull_request(
source=self.backport_branch,
target=self.target_branch,
title=pr_title,
description=f"Original pull-request #{self._pr['number']}\n"
f"Cherry-pick pull-request #{cherrypick_pr['number']}\n\n{DESCRIPTION}",
)
# FIXME: use `team` to leave a single eligible assignee.
self._gh.add_assignee(pr, self._pr["author"])
self._gh.add_assignee(pr, self._pr["mergedBy"])
self._gh.set_label(pr, "pr-backport")
return pr
def execute(self, repo_path, dry_run=False):
pr1 = self.getCherryPickPullRequest()
if not pr1:
if not dry_run:
pr1 = self.createCherryPickPullRequest(repo_path)
logging.debug(
"Created PR with cherry-pick of %s to %s: %s",
self._pr["number"],
self.target_branch,
pr1["url"],
)
else:
return CherryPick.Status.NOT_INITIATED
else:
logging.debug(
"Found PR with cherry-pick of %s to %s: %s",
self._pr["number"],
self.target_branch,
pr1["url"],
)
if not pr1["merged"] and pr1["mergeable"] == "MERGEABLE" and not pr1["closed"]:
if not dry_run:
pr1 = self.mergeCherryPickPullRequest(pr1)
logging.debug(
"Merged PR with cherry-pick of %s to %s: %s",
self._pr["number"],
self.target_branch,
pr1["url"],
)
if not pr1["merged"]:
logging.debug(
"Waiting for PR with cherry-pick of %s to %s: %s",
self._pr["number"],
self.target_branch,
pr1["url"],
)
if pr1["closed"]:
return CherryPick.Status.DISCARDED
elif pr1["mergeable"] == "CONFLICTING":
return CherryPick.Status.FIRST_CONFLICTS
else:
return CherryPick.Status.FIRST_MERGEABLE
pr2 = self.getBackportPullRequest()
if not pr2:
if not dry_run:
pr2 = self.createBackportPullRequest(pr1, repo_path)
logging.debug(
"Created PR with backport of %s to %s: %s",
self._pr["number"],
self.target_branch,
pr2["url"],
)
else:
return CherryPick.Status.FIRST_MERGEABLE
else:
logging.debug(
"Found PR with backport of %s to %s: %s",
self._pr["number"],
self.target_branch,
pr2["url"],
)
if pr2["merged"]:
return CherryPick.Status.MERGED
elif pr2["closed"]:
return CherryPick.Status.DISCARDED
elif pr2["mergeable"] == "CONFLICTING":
return CherryPick.Status.SECOND_CONFLICTS
else:
return CherryPick.Status.SECOND_MERGEABLE
if __name__ == "__main__":
logging.basicConfig(format="%(message)s", stream=sys.stdout, level=logging.DEBUG)
parser = argparse.ArgumentParser()
parser.add_argument(
"--token", "-t", type=str, required=True, help="token for Github access"
)
parser.add_argument("--pr", type=str, required=True, help="PR# to cherry-pick")
parser.add_argument(
"--branch",
"-b",
type=str,
required=True,
help="target branch name for cherry-pick",
)
parser.add_argument(
"--repo",
"-r",
type=str,
required=True,
help="path to full repository",
metavar="PATH",
)
args = parser.parse_args()
cp = CherryPick(
args.token, "ClickHouse", "ClickHouse", "core", args.pr, args.branch
)
cp.execute(args.repo)

View File

@ -1,109 +0,0 @@
# -*- coding: utf-8 -*-
import functools
import logging
import os
import re
import git
class RepositoryBase:
def __init__(self, repo_path):
self._repo = git.Repo(repo_path, search_parent_directories=(not repo_path))
# comparator of commits
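# cmp(x, y) < 0 when x is an ancestor of y, so sorting with this key puts ancestor (older) commits first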
def cmp(x, y):
if str(x) == str(y):
return 0
if self._repo.is_ancestor(x, y):
return -1
else:
return 1
self.comparator = functools.cmp_to_key(cmp)
def iterate(self, begin, end):
rev_range = f"{begin}...{end}"
for commit in self._repo.iter_commits(rev_range, first_parent=True):
yield commit
class Repository(RepositoryBase):
def __init__(self, repo_path, remote_name, default_branch_name):
super().__init__(repo_path)
self._remote = self._repo.remotes[remote_name]
self._remote.fetch()
self._default = self._remote.refs[default_branch_name]
def get_head_commit(self):
return self._repo.commit(self._default)
def get_release_branches(self):
"""
Returns sorted list of tuples:
* remote branch (git.refs.remote.RemoteReference),
* base commit (git.Commit),
* head (git.Commit)).
List is sorted by commits in ascending order.
"""
release_branches = []
RE_RELEASE_BRANCH_REF = re.compile(r"^refs/remotes/.+/\d+\.\d+$")
for branch in [
r for r in self._remote.refs if RE_RELEASE_BRANCH_REF.match(r.path)
]:
base = self._repo.merge_base(self._default, self._repo.commit(branch))
if not base:
logging.info(
"Branch %s is not based on branch %s. Ignoring.",
branch.path,
self._default,
)
elif len(base) > 1:
logging.info(
"Branch %s has more than one base commit. Ignoring.", branch.path
)
else:
release_branches.append((os.path.basename(branch.name), base[0]))
return sorted(release_branches, key=lambda x: self.comparator(x[1]))
class BareRepository(RepositoryBase):
def __init__(self, repo_path, default_branch_name):
super().__init__(repo_path)
self._default = self._repo.branches[default_branch_name]
def get_release_branches(self):
"""
Returns sorted list of tuples:
* branch (git.refs.head?),
* base commit (git.Commit),
* head (git.Commit)).
List is sorted by commits in ascending order.
"""
release_branches = []
RE_RELEASE_BRANCH_REF = re.compile(r"^refs/heads/\d+\.\d+$")
for branch in [
r for r in self._repo.branches if RE_RELEASE_BRANCH_REF.match(r.path)
]:
base = self._repo.merge_base(self._default, self._repo.commit(branch))
if not base:
logging.info(
"Branch %s is not based on branch %s. Ignoring.",
branch.path,
self._default,
)
elif len(base) > 1:
logging.info(
"Branch %s has more than one base commit. Ignoring.", branch.path
)
else:
release_branches.append((os.path.basename(branch.name), base[0]))
return sorted(release_branches, key=lambda x: self.comparator(x[1]))

View File

@ -1,56 +0,0 @@
# -*- coding: utf-8 -*-
class Description:
"""Parsed description representation"""
MAP_CATEGORY_TO_LABEL = {
"New Feature": "pr-feature",
"Bug Fix": "pr-bugfix",
"Improvement": "pr-improvement",
"Performance Improvement": "pr-performance",
# 'Backward Incompatible Change': doesn't match anything
"Build/Testing/Packaging Improvement": "pr-build",
"Non-significant (changelog entry is not needed)": "pr-non-significant",
"Non-significant (changelog entry is not required)": "pr-non-significant",
"Non-significant": "pr-non-significant",
"Documentation (changelog entry is not required)": "pr-documentation",
# 'Other': doesn't match anything
}
def __init__(self, pull_request):
self.label_name = str()
self._parse(pull_request["bodyText"])
def _parse(self, text):
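"""Scan the body line by line, take the first non-empty line after a category header as the category, and map it to a label name"""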
lines = text.splitlines()
next_category = False
category = str()
for line in lines:
stripped = line.strip()
if not stripped:
continue
if next_category:
category = stripped
next_category = False
category_headers = (
"Category (leave one):",
"Changelog category (leave one):",
"Changelog category:",
"Category:",
)
if stripped in category_headers:
next_category = True
if category in Description.MAP_CATEGORY_TO_LABEL:
self.label_name = Description.MAP_CATEGORY_TO_LABEL[category]
else:
if not category:
print("Cannot find category in pr description")
else:
print(("Unknown category: " + category))

View File

@ -1,532 +0,0 @@
# -*- coding: utf-8 -*-
import json
import inspect
import logging
import time
from urllib3.util.retry import Retry # type: ignore
import requests # type: ignore
from requests.adapters import HTTPAdapter # type: ignore
class Query:
"""
Implements queries to the Github API using GraphQL
"""
_PULL_REQUEST = """
author {{
... on User {{
id
login
}}
}}
baseRepository {{
nameWithOwner
}}
mergeCommit {{
oid
parents(first: {min_page_size}) {{
totalCount
nodes {{
oid
}}
}}
}}
mergedBy {{
... on User {{
id
login
}}
}}
baseRefName
closed
headRefName
id
mergeable
merged
number
title
url
"""
def __init__(self, token, owner, name, team, max_page_size=100, min_page_size=10):
self._PULL_REQUEST = Query._PULL_REQUEST.format(min_page_size=min_page_size)
self._token = token
self._owner = owner
self._name = name
self._team = team
self._session = None
self._max_page_size = max_page_size
self._min_page_size = min_page_size
self.api_costs = {}
repo = self.get_repository()
self._id = repo["id"]
self.ssh_url = repo["sshUrl"]
self.default_branch = repo["defaultBranchRef"]["name"]
self.members = set(self.get_members())
def get_repository(self):
_QUERY = """
repository(owner: "{owner}" name: "{name}") {{
defaultBranchRef {{
name
}}
id
sshUrl
}}
"""
query = _QUERY.format(owner=self._owner, name=self._name)
return self._run(query)["repository"]
def get_members(self):
"""Get all team members for organization
Returns:
members: a map of members' logins to ids
"""
_QUERY = """
organization(login: "{organization}") {{
team(slug: "{team}") {{
members(first: {max_page_size} {next}) {{
pageInfo {{
hasNextPage
endCursor
}}
nodes {{
id
login
}}
}}
}}
}}
"""
members = {}
not_end = True
query = _QUERY.format(
organization=self._owner,
team=self._team,
max_page_size=self._max_page_size,
next="",
)
while not_end:
result = self._run(query)["organization"]["team"]
if result is None:
break
result = result["members"]
not_end = result["pageInfo"]["hasNextPage"]
query = _QUERY.format(
organization=self._owner,
team=self._team,
max_page_size=self._max_page_size,
next=f'after: "{result["pageInfo"]["endCursor"]}"',
)
# Update members with new nodes compatible with py3.8-py3.10
members = {
**members,
**{node["login"]: node["id"] for node in result["nodes"]},
}
return members
def get_pull_request(self, number):
_QUERY = """
repository(owner: "{owner}" name: "{name}") {{
pullRequest(number: {number}) {{
{pull_request_data}
}}
}}
"""
query = _QUERY.format(
owner=self._owner,
name=self._name,
number=number,
pull_request_data=self._PULL_REQUEST,
min_page_size=self._min_page_size,
)
return self._run(query)["repository"]["pullRequest"]
def find_pull_request(self, base, head):
_QUERY = """
repository(owner: "{owner}" name: "{name}") {{
pullRequests(
first: {min_page_size} baseRefName: "{base}" headRefName: "{head}"
) {{
nodes {{
{pull_request_data}
}}
totalCount
}}
}}
"""
query = _QUERY.format(
owner=self._owner,
name=self._name,
base=base,
head=head,
pull_request_data=self._PULL_REQUEST,
min_page_size=self._min_page_size,
)
result = self._run(query)["repository"]["pullRequests"]
if result["totalCount"] > 0:
return result["nodes"][0]
else:
return {}
def find_pull_requests(self, label_name):
"""
Get all pull-requests filtered by label name
"""
_QUERY = """
repository(owner: "{owner}" name: "{name}") {{
pullRequests(first: {min_page_size} labels: "{label_name}" states: OPEN) {{
nodes {{
{pull_request_data}
}}
}}
}}
"""
query = _QUERY.format(
owner=self._owner,
name=self._name,
label_name=label_name,
pull_request_data=self._PULL_REQUEST,
min_page_size=self._min_page_size,
)
return self._run(query)["repository"]["pullRequests"]["nodes"]
def get_pull_requests(self, before_commit):
"""
Get all merged pull-requests from the HEAD of the default branch down to `before_commit` (exclusive)
"""
_QUERY = """
repository(owner: "{owner}" name: "{name}") {{
defaultBranchRef {{
target {{
... on Commit {{
history(first: {max_page_size} {next}) {{
pageInfo {{
hasNextPage
endCursor
}}
nodes {{
oid
associatedPullRequests(first: {min_page_size}) {{
totalCount
nodes {{
... on PullRequest {{
{pull_request_data}
labels(first: {min_page_size}) {{
totalCount
pageInfo {{
hasNextPage
endCursor
}}
nodes {{
name
color
}}
}}
}}
}}
}}
}}
}}
}}
}}
}}
}}
"""
pull_requests = []
not_end = True
query = _QUERY.format(
owner=self._owner,
name=self._name,
max_page_size=self._max_page_size,
min_page_size=self._min_page_size,
pull_request_data=self._PULL_REQUEST,
next="",
)
while not_end:
result = self._run(query)["repository"]["defaultBranchRef"]["target"][
"history"
]
not_end = result["pageInfo"]["hasNextPage"]
query = _QUERY.format(
owner=self._owner,
name=self._name,
max_page_size=self._max_page_size,
min_page_size=self._min_page_size,
pull_request_data=self._PULL_REQUEST,
next=f'after: "{result["pageInfo"]["endCursor"]}"',
)
for commit in result["nodes"]:
# FIXME: maybe include `before_commit`?
if str(commit["oid"]) == str(before_commit):
not_end = False
break
# TODO: fetch all pull-requests that were merged in a single commit.
assert (
commit["associatedPullRequests"]["totalCount"]
<= self._min_page_size
)
for pull_request in commit["associatedPullRequests"]["nodes"]:
if (
pull_request["baseRepository"]["nameWithOwner"]
== f"{self._owner}/{self._name}"
and pull_request["baseRefName"] == self.default_branch
and pull_request["mergeCommit"]["oid"] == commit["oid"]
):
pull_requests.append(pull_request)
return pull_requests
def create_pull_request(
self, source, target, title, description="", draft=False, can_modify=True
):
_QUERY = """
createPullRequest(input: {{
baseRefName: "{target}",
headRefName: "{source}",
repositoryId: "{id}",
title: "{title}",
body: "{body}",
draft: {draft},
maintainerCanModify: {modify}
}}) {{
pullRequest {{
{pull_request_data}
}}
}}
"""
query = _QUERY.format(
target=target,
source=source,
id=self._id,
title=title,
body=description,
draft="true" if draft else "false",
modify="true" if can_modify else "false",
pull_request_data=self._PULL_REQUEST,
)
return self._run(query, is_mutation=True)["createPullRequest"]["pullRequest"]
def merge_pull_request(self, pr_id):
_QUERY = """
mergePullRequest(input: {{
pullRequestId: "{pr_id}"
}}) {{
pullRequest {{
{pull_request_data}
}}
}}
"""
query = _QUERY.format(pr_id=pr_id, pull_request_data=self._PULL_REQUEST)
return self._run(query, is_mutation=True)["mergePullRequest"]["pullRequest"]
# FIXME: figure out how to add more assignees at once
def add_assignee(self, pr, assignee):
_QUERY = """
addAssigneesToAssignable(input: {{
assignableId: "{id1}",
assigneeIds: "{id2}"
}}) {{
clientMutationId
}}
"""
query = _QUERY.format(id1=pr["id"], id2=assignee["id"])
self._run(query, is_mutation=True)
def set_label(self, pull_request, label_name):
"""
Set label by name to the pull request
Args:
pull_request: JSON object returned by `get_pull_requests()`
label_name (string): label name
"""
_GET_LABEL = """
repository(owner: "{owner}" name: "{name}") {{
labels(first: {max_page_size} {next} query: "{label_name}") {{
pageInfo {{
hasNextPage
endCursor
}}
nodes {{
id
name
color
}}
}}
}}
"""
_SET_LABEL = """
addLabelsToLabelable(input: {{
labelableId: "{pr_id}",
labelIds: "{label_id}"
}}) {{
clientMutationId
}}
"""
labels = []
not_end = True
query = _GET_LABEL.format(
owner=self._owner,
name=self._name,
label_name=label_name,
max_page_size=self._max_page_size,
next="",
)
while not_end:
result = self._run(query)["repository"]["labels"]
not_end = result["pageInfo"]["hasNextPage"]
query = _GET_LABEL.format(
owner=self._owner,
name=self._name,
label_name=label_name,
max_page_size=self._max_page_size,
next=f'after: "{result["pageInfo"]["endCursor"]}"',
)
labels += list(result["nodes"])
if not labels:
return
query = _SET_LABEL.format(pr_id=pull_request["id"], label_id=labels[0]["id"])
self._run(query, is_mutation=True)
@property
def session(self):
if self._session is not None:
return self._session
retries = 5
self._session = requests.Session()
retry = Retry(
total=retries,
read=retries,
connect=retries,
backoff_factor=1,
status_forcelist=(403, 500, 502, 504),
)
adapter = HTTPAdapter(max_retries=retry)
self._session.mount("http://", adapter)
self._session.mount("https://", adapter)
return self._session
def _run(self, query, is_mutation=False):
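"""Wrap the query into a GraphQL query/mutation document, POST it to the GitHub API, track per-caller rate-limit costs, and retry known transient failures"""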
# Get caller and parameters from the stack to track the progress
frame = inspect.getouterframes(inspect.currentframe(), 2)[1]
caller = frame[3]
f_parameters = inspect.signature(getattr(self, caller)).parameters
parameters = ", ".join(str(frame[0].f_locals[p]) for p in f_parameters)
mutation = ""
if is_mutation:
mutation = ", is mutation"
print(f"---GraphQL request for {caller}({parameters}){mutation}---")
headers = {"Authorization": f"bearer {self._token}"}
if is_mutation:
query = f"""
mutation {{
{query}
}}
"""
else:
query = f"""
query {{
{query}
rateLimit {{
cost
remaining
}}
}}
"""
def request_with_retry(retry=0):
max_retries = 5
# From time to time we face certain errors for which it is worth
# retrying instead of failing completely
# We should sleep progressively
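# 5 * (1 + 2 + ... + retry): retry=1 -> 5s, retry=2 -> 15s, retry=3 -> 30s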
progressive_sleep = 5 * sum(i + 1 for i in range(retry))
if progressive_sleep:
logging.warning(
"Retry GraphQL request %s time, sleep %s seconds",
retry,
progressive_sleep,
)
time.sleep(progressive_sleep)
response = self.session.post(
"https://api.github.com/graphql", json={"query": query}, headers=headers
)
result = response.json()
if response.status_code == 200:
if "errors" in result:
raise Exception(
f"Errors occurred: {result['errors']}\nOriginal query: {query}"
)
if not is_mutation:
if caller not in self.api_costs:
self.api_costs[caller] = 0
self.api_costs[caller] += result["data"]["rateLimit"]["cost"]
return result["data"]
elif (
response.status_code == 403
and "secondary rate limit" in result["message"]
):
if retry <= max_retries:
logging.warning("Secondary rate limit reached")
return request_with_retry(retry + 1)
elif response.status_code == 502 and "errors" in result:
too_many_data = any(
True
for err in result["errors"]
if "message" in err
and "This may be the result of a timeout" in err["message"]
)
if too_many_data:
logging.warning(
"Too many data is requested, decreasing page size %s by 10%%",
self._max_page_size,
)
self._max_page_size = int(self._max_page_size * 0.9)
return request_with_retry(retry)
data = json.dumps(result, indent=4)
raise Exception(f"Query failed with code {response.status_code}:\n{data}")
return request_with_retry()

View File

@ -1,3 +0,0 @@
# Some scripts for backports implementation
TODO: Remove copy from utils/github

View File

@ -1,10 +1,13 @@
#!/usr/bin/env python
import argparse
import logging
import os.path as p
import re
import subprocess
from typing import List, Optional
logger = logging.getLogger(__name__)
# ^ and $ match subline in `multiple\nlines`
# \A and \Z match only start and end of the whole string
RELEASE_BRANCH_REGEXP = r"\A\d+[.]\d+\Z"
@ -55,6 +58,7 @@ class Runner:
def run(self, cmd: str, cwd: Optional[str] = None, **kwargs) -> str:
if cwd is None:
cwd = self.cwd
logger.debug("Running command: %s", cmd)
return subprocess.check_output(
cmd, shell=True, cwd=cwd, encoding="utf-8", **kwargs
).strip()
@ -70,6 +74,9 @@ class Runner:
return
self._cwd = value
def __call__(self, *args, **kwargs):
return self.run(*args, **kwargs)
git_runner = Runner()
# Set cwd to abs path of git root
@ -109,8 +116,8 @@ class Git:
def update(self):
"""Is used to refresh all attributes after updates, e.g. checkout or commit"""
self.branch = self.run("git branch --show-current")
self.sha = self.run("git rev-parse HEAD")
self.branch = self.run("git branch --show-current") or self.sha
self.sha_short = self.sha[:11]
# The following command shows the most recent tag in a graph
# Format should match TAG_REGEXP

tests/ci/github_helper.py (new file, 172 lines)
View File

@ -0,0 +1,172 @@
#!/usr/bin/env python
"""Helper for GitHub API requests"""
import logging
from datetime import date, datetime, timedelta
from pathlib import Path
from os import path as p
from time import sleep
from typing import List, Optional
import github
from github.GithubException import RateLimitExceededException
from github.Issue import Issue
from github.PullRequest import PullRequest
from github.Repository import Repository
CACHE_PATH = p.join(p.dirname(p.realpath(__file__)), "gh_cache")
logger = logging.getLogger(__name__)
PullRequests = List[PullRequest]
Issues = List[Issue]
class GitHub(github.Github):
def __init__(self, *args, **kwargs):
# Define meta attribute
self._cache_path = Path(CACHE_PATH)
# And set Path
super().__init__(*args, **kwargs)
self._retries = 0
# pylint: disable=signature-differs
def search_issues(self, *args, **kwargs) -> Issues: # type: ignore
"""Wrapper around search method with throttling and splitting by date.
We split only by the first"""
splittable = False
for arg, value in kwargs.items():
if arg in ["closed", "created", "merged", "updated"]:
if hasattr(value, "__iter__") and not isinstance(value, str):
assert [True for v in value if isinstance(v, (date, datetime))]
assert len(value) == 2
kwargs[arg] = f"{value[0].isoformat()}..{value[1].isoformat()}"
if not splittable:
# We split only by the first splittable argument we encounter
preserved_arg = arg
preserved_value = value
middle_value = value[0] + (value[1] - value[0]) / 2
splittable = middle_value not in value
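# when the midpoint coincides with an endpoint (e.g. a one-day date range), the interval cannot be split further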
continue
assert isinstance(value, (date, datetime, str))
inter_result = [] # type: Issues
for i in range(self.retries):
try:
logger.debug("Search issues, args=%s, kwargs=%s", args, kwargs)
result = super().search_issues(*args, **kwargs)
if result.totalCount == 1000 and splittable:
# The hard limit is 1000. If it's splittable, then we make
# two subrequests with smaller time frames
logger.debug(
"The search result contain exactly 1000 results, "
"splitting %s=%s by middle point %s",
preserved_arg,
kwargs[preserved_arg],
middle_value,
)
kwargs[preserved_arg] = [preserved_value[0], middle_value]
inter_result.extend(self.search_issues(*args, **kwargs))
if isinstance(middle_value, date):
# When middle_value is a date, 2022-01-01..2022-01-03
# would be split into 2022-01-01..2022-01-02 and
# 2022-01-02..2022-01-03, so results for 2022-01-02
# would be returned twice. Instead we split it into
# 2022-01-01..2022-01-02 and 2022-01-03..2022-01-03.
# 2022-01-01..2022-01-02 isn't split further, see `splittable`
middle_value += timedelta(days=1)
kwargs[preserved_arg] = [middle_value, preserved_value[1]]
inter_result.extend(self.search_issues(*args, **kwargs))
return inter_result
inter_result.extend(result)
return inter_result
except RateLimitExceededException as e:
if i == self.retries - 1:
exception = e
self.sleep_on_rate_limit()
raise exception
# pylint: enable=signature-differs
def get_pulls_from_search(self, *args, **kwargs) -> PullRequests:
"""The search api returns actually issues, so we need to fetch PullRequests"""
issues = self.search_issues(*args, **kwargs)
repos = {}
prs = [] # type: PullRequests
for issue in issues:
# See https://github.com/PyGithub/PyGithub/issues/2202,
# obj._rawData doesn't spend additional API requests
# pylint: disable=protected-access
repo_url = issue._rawData["repository_url"] # type: ignore
if repo_url not in repos:
repos[repo_url] = issue.repository
prs.append(
self.get_pull_cached(repos[repo_url], issue.number, issue.updated_at)
)
return prs
def sleep_on_rate_limit(self):
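"""Find the first exhausted rate-limit bucket, sleep until it resets, and return"""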
for limit, data in self.get_rate_limit().raw_data.items():
if data["remaining"] == 0:
sleep_time = data["reset"] - int(datetime.now().timestamp()) + 1
if sleep_time > 0:
logger.warning(
"Faced rate limit for '%s' requests type, sleeping %s",
limit,
sleep_time,
)
sleep(sleep_time)
return
def get_pull_cached(
self, repo: Repository, number: int, updated_at: Optional[datetime] = None
) -> PullRequest:
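"""Return the PR from the on-disk pickle cache if it is at least as fresh as updated_at; otherwise fetch it from the API and cache it"""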
pr_cache_file = self.cache_path / f"{number}.pickle"
if updated_at is None:
updated_at = datetime.now() - timedelta(hours=-1)
def _get_pr(path: Path) -> PullRequest:
with open(path, "rb") as prfd:
return self.load(prfd) # type: ignore
if pr_cache_file.is_file():
cached_pr = _get_pr(pr_cache_file)
if updated_at <= cached_pr.updated_at:
logger.debug("Getting PR #%s from cache", number)
return cached_pr
logger.debug("Getting PR #%s from API", number)
for i in range(self.retries):
try:
pr = repo.get_pull(number)
break
except RateLimitExceededException:
if i == self.retries - 1:
raise
self.sleep_on_rate_limit()
logger.debug("Caching PR #%s from API in %s", number, pr_cache_file)
with open(pr_cache_file, "wb") as prfd:
self.dump(pr, prfd) # type: ignore
return pr
@property
def cache_path(self):
return self._cache_path
@cache_path.setter
def cache_path(self, value: str):
self._cache_path = Path(value)
if self._cache_path.exists():
assert self._cache_path.is_dir()
else:
self._cache_path.mkdir(parents=True)
@property
def retries(self):
if self._retries == 0:
self._retries = 3
return self._retries
@retries.setter
def retries(self, value: int):
self._retries = value

View File

@ -2099,7 +2099,7 @@ if __name__ == "__main__":
group.add_argument(
"--backward-compatibility-check",
action="store_true",
help="Run tests for further backwoard compatibility testing by ignoring all"
help="Run tests for further backward compatibility testing by ignoring all"
"drop queries in tests for collecting data from new version of server",
)
parser.add_argument(

View File

@ -237,6 +237,18 @@ def enable_consistent_hash_plugin(rabbitmq_id):
return p.returncode == 0
def extract_test_name(base_path):
"""Extracts the name of the test based to a path to its test*.py file
Must be unique in each test directory (because it's used to make instances dir and to stop docker containers from previous run)
"""
name = p.basename(base_path)
if name == "test.py":
name = ""
elif name.startswith("test_") and name.endswith(".py"):
name = name[len("test_") : (len(name) - len(".py"))]
return name
def get_instances_dir():
if (
"INTEGRATION_TESTS_RUN_ID" in os.environ
@ -274,7 +286,7 @@ class ClickHouseCluster:
logging.debug("ENV %40s %s" % (param, os.environ[param]))
self.base_path = base_path
self.base_dir = p.dirname(base_path)
self.name = name if name is not None else ""
self.name = name if name is not None else extract_test_name(base_path)
self.base_config_dir = base_config_dir or os.environ.get(
"CLICKHOUSE_TESTS_BASE_CONFIG_DIR", "/etc/clickhouse-server/"

View File

@ -2,7 +2,7 @@ import pytest
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__, name="aggregate_fixed_key")
cluster = ClickHouseCluster(__file__)
node1 = cluster.add_instance(
"node1",
with_zookeeper=True,

View File

@ -2,7 +2,7 @@ import pytest
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__, name="aggregate_state")
cluster = ClickHouseCluster(__file__)
node1 = cluster.add_instance(
"node1",
with_zookeeper=False,

View File

@ -1,7 +1,7 @@
import pytest
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__, name="convert_ordinary")
cluster = ClickHouseCluster(__file__)
node = cluster.add_instance(
"node",
image="yandex/clickhouse-server",

View File

@ -2,7 +2,7 @@ import pytest
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__, name="cte_distributed")
cluster = ClickHouseCluster(__file__)
node1 = cluster.add_instance("node1", with_zookeeper=False)
node2 = cluster.add_instance(
"node2",

View File

@ -5,7 +5,7 @@
import pytest
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__, name="skipping_indices")
cluster = ClickHouseCluster(__file__)
node = cluster.add_instance(
"node",
image="yandex/clickhouse-server",

View File

@ -2,7 +2,7 @@ import pytest
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__, name="detach")
cluster = ClickHouseCluster(__file__)
# Version 21.6.3.14 has incompatible partition id for tables with UUID in partition key.
node_21_6 = cluster.add_instance(
"node_21_6",

View File

@ -6,7 +6,7 @@ import pytest
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__, name="insert_profile_events")
cluster = ClickHouseCluster(__file__)
upstream_node = cluster.add_instance("upstream_node")
old_node = cluster.add_instance(
"old_node",

View File

@ -2,7 +2,7 @@ import pytest
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__, name="aggregate_alias_column")
cluster = ClickHouseCluster(__file__)
node1 = cluster.add_instance("node1", with_zookeeper=False)
node2 = cluster.add_instance(
"node2",

View File

@ -2,7 +2,7 @@ import pytest
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__, name="short_strings")
cluster = ClickHouseCluster(__file__)
node1 = cluster.add_instance(
"node1",
with_zookeeper=False,

View File

@ -197,7 +197,10 @@ def test_store_cleanup(started_cluster):
node1.exec_in_container(["mkdir", f"{path_to_data}/store/kek"])
node1.exec_in_container(["touch", f"{path_to_data}/store/12"])
try:
node1.exec_in_container(["mkdir", f"{path_to_data}/store/456"])
except Exception as e:
print("Failed to create 456/:", str(e))
node1.exec_in_container(["mkdir", f"{path_to_data}/store/456/testgarbage"])
node1.exec_in_container(
["mkdir", f"{path_to_data}/store/456/30000000-1000-4000-8000-000000000003"]
@ -218,7 +221,7 @@ def test_store_cleanup(started_cluster):
timeout=60,
look_behind_lines=1000,
)
node1.wait_for_log_line("directories from store")
node1.wait_for_log_line("directories from store", look_behind_lines=1000)
store = node1.exec_in_container(["ls", f"{path_to_data}/store"])
assert "100" in store

View File

@ -18,7 +18,7 @@ sys.path.insert(0, os.path.dirname(CURRENT_TEST_DIR))
COPYING_FAIL_PROBABILITY = 0.2
MOVING_FAIL_PROBABILITY = 0.2
cluster = ClickHouseCluster(__file__, name="copier_test")
cluster = ClickHouseCluster(__file__)
def generateRandomString(count):

View File

@ -12,7 +12,7 @@ import docker
CURRENT_TEST_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(CURRENT_TEST_DIR))
cluster = ClickHouseCluster(__file__, name="copier_test_three_nodes")
cluster = ClickHouseCluster(__file__)
@pytest.fixture(scope="module")

View File

@ -19,7 +19,7 @@ sys.path.insert(0, os.path.dirname(CURRENT_TEST_DIR))
COPYING_FAIL_PROBABILITY = 0.1
MOVING_FAIL_PROBABILITY = 0.1
cluster = ClickHouseCluster(__file__, name="copier_test_trivial")
cluster = ClickHouseCluster(__file__)
def generateRandomString(count):

View File

@ -12,7 +12,7 @@ import docker
CURRENT_TEST_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(CURRENT_TEST_DIR))
cluster = ClickHouseCluster(__file__, name="copier_test_two_nodes")
cluster = ClickHouseCluster(__file__)
@pytest.fixture(scope="module")

View File

@ -24,7 +24,7 @@ def setup_module(module):
global complex_tester
global ranged_tester
cluster = ClickHouseCluster(__file__, name=test_name)
cluster = ClickHouseCluster(__file__)
SOURCE = SourceCassandra(
"Cassandra",

View File

@ -38,7 +38,7 @@ def setup_module(module):
ranged_tester.create_dictionaries(SOURCE)
# Since that all .xml configs were created
cluster = ClickHouseCluster(__file__, name=test_name)
cluster = ClickHouseCluster(__file__)
main_configs = []
main_configs.append(os.path.join("configs", "disable_ssl_verification.xml"))

View File

@ -38,7 +38,7 @@ def setup_module(module):
ranged_tester.create_dictionaries(SOURCE)
# Since that all .xml configs were created
cluster = ClickHouseCluster(__file__, name=test_name)
cluster = ClickHouseCluster(__file__)
main_configs = []
main_configs.append(os.path.join("configs", "disable_ssl_verification.xml"))

View File

@ -38,7 +38,7 @@ def setup_module(module):
ranged_tester.create_dictionaries(SOURCE)
# Since that all .xml configs were created
cluster = ClickHouseCluster(__file__, name=test_name)
cluster = ClickHouseCluster(__file__)
main_configs = []
main_configs.append(os.path.join("configs", "disable_ssl_verification.xml"))

View File

@ -38,7 +38,7 @@ def setup_module(module):
ranged_tester.create_dictionaries(SOURCE)
# Since that all .xml configs were created
cluster = ClickHouseCluster(__file__, name=test_name)
cluster = ClickHouseCluster(__file__)
main_configs = []
main_configs.append(os.path.join("configs", "disable_ssl_verification.xml"))

View File

@ -36,7 +36,7 @@ def setup_module(module):
ranged_tester.create_dictionaries(SOURCE)
# Since that all .xml configs were created
cluster = ClickHouseCluster(__file__, name=test_name)
cluster = ClickHouseCluster(__file__)
main_configs = []
main_configs.append(os.path.join("configs", "disable_ssl_verification.xml"))

View File

@ -36,7 +36,7 @@ def setup_module(module):
ranged_tester.create_dictionaries(SOURCE)
# Since that all .xml configs were created
cluster = ClickHouseCluster(__file__, name=test_name)
cluster = ClickHouseCluster(__file__)
main_configs = []
main_configs.append(os.path.join("configs", "disable_ssl_verification.xml"))

View File

@ -38,7 +38,7 @@ def setup_module(module):
ranged_tester.create_dictionaries(SOURCE)
# Since that all .xml configs were created
cluster = ClickHouseCluster(__file__, name=test_name)
cluster = ClickHouseCluster(__file__)
main_configs = []
main_configs.append(os.path.join("configs", "disable_ssl_verification.xml"))

View File

@ -24,7 +24,7 @@ def setup_module(module):
global complex_tester
global ranged_tester
cluster = ClickHouseCluster(__file__, name=test_name)
cluster = ClickHouseCluster(__file__)
SOURCE = SourceMongo(
"MongoDB",
"localhost",

View File

@ -24,7 +24,7 @@ def setup_module(module):
global complex_tester
global ranged_tester
cluster = ClickHouseCluster(__file__, name=test_name)
cluster = ClickHouseCluster(__file__)
SOURCE = SourceMongoURI(
"MongoDB",

View File

@ -24,7 +24,7 @@ def setup_module(module):
global complex_tester
global ranged_tester
cluster = ClickHouseCluster(__file__, name=test_name)
cluster = ClickHouseCluster(__file__)
SOURCE = SourceMySQL(
"MySQL",

View File

@ -2,7 +2,7 @@ import pytest
from helpers.cluster import ClickHouseCluster
import redis
cluster = ClickHouseCluster(__file__, name="long")
cluster = ClickHouseCluster(__file__)
node = cluster.add_instance("node", with_redis=True)

View File

@ -5,7 +5,7 @@ from helpers.cluster import ClickHouseCluster
from helpers.cluster import ClickHouseKiller
from helpers.network import PartitionManager
cluster = ClickHouseCluster(__file__, name="reading")
cluster = ClickHouseCluster(__file__)
dictionary_node = cluster.add_instance("dictionary_node", stay_alive=True)
main_node = cluster.add_instance(

View File

@ -7,7 +7,7 @@ import pytest
from helpers.cluster import ClickHouseCluster
from helpers.test_tools import TSV
cluster = ClickHouseCluster(__file__, name="string")
cluster = ClickHouseCluster(__file__)
dictionary_node = cluster.add_instance("dictionary_node", stay_alive=True)
main_node = cluster.add_instance(

View File

@ -5,7 +5,7 @@ from helpers.cluster import ClickHouseCluster
from helpers.cluster import ClickHouseKiller
from helpers.network import PartitionManager
cluster = ClickHouseCluster(__file__, name="default")
cluster = ClickHouseCluster(__file__)
dictionary_node = cluster.add_instance("dictionary_node", stay_alive=True)
main_node = cluster.add_instance(

View File

@ -8,8 +8,6 @@ from helpers.cluster import ClickHouseCluster
from helpers.network import PartitionManager
from helpers.test_tools import TSV
cluster = ClickHouseCluster(__file__)
NODES = {"node" + str(i): None for i in (1, 2)}
IS_DEBUG = False

View File

@ -3,7 +3,7 @@ import pytest
from helpers.cluster import ClickHouseCluster
import time
cluster = ClickHouseCluster(__file__, name="test_keeper_4lw_allow_list")
cluster = ClickHouseCluster(__file__)
node1 = cluster.add_instance(
"node1", main_configs=["configs/keeper_config_with_allow_list.xml"], stay_alive=True
)

View File

@ -3,7 +3,7 @@ import re
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__, name="log_quries_probability")
cluster = ClickHouseCluster(__file__)
node = cluster.add_instance("node", with_zookeeper=False)
config = """<clickhouse>

View File

@ -2,7 +2,7 @@ import pytest
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__, name="log_quries_probability")
cluster = ClickHouseCluster(__file__)
node1 = cluster.add_instance("node1", with_zookeeper=False)
node2 = cluster.add_instance("node2", with_zookeeper=False)

View File

@ -2,7 +2,7 @@ import time
import pytest
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__, name="password")
cluster = ClickHouseCluster(__file__)
# TODO ACL not implemented in Keeper.
node1 = cluster.add_instance(

View File

@ -9,7 +9,6 @@ TEST_DIR = os.path.dirname(__file__)
cluster = ClickHouseCluster(
__file__,
name="secure",
zookeeper_certfile=os.path.join(TEST_DIR, "configs_secure", "client.crt"),
zookeeper_keyfile=os.path.join(TEST_DIR, "configs_secure", "client.key"),
)

View File

@ -36,14 +36,15 @@ function thread_insert_rollback()
function thread_select()
{
while true; do
# Result of `uniq | wc -l` must be 1 if the first and the last queries got the same result
# The first and the last queries must get the same result
$CLICKHOUSE_CLIENT --multiquery --query "
BEGIN TRANSACTION;
SELECT arraySort(groupArray(n)), arraySort(groupArray(m)), arraySort(groupArray(_part)) FROM mt;
SET throw_on_unsupported_query_inside_transaction=0;
CREATE TEMPORARY TABLE tmp AS SELECT arraySort(groupArray(n)), arraySort(groupArray(m)), arraySort(groupArray(_part)) FROM mt FORMAT Null;
SELECT throwIf((SELECT sum(n) FROM mt) != 0) FORMAT Null;
SELECT throwIf((SELECT count() FROM mt) % 2 != 0) FORMAT Null;
SELECT arraySort(groupArray(n)), arraySort(groupArray(m)), arraySort(groupArray(_part)) FROM mt;
COMMIT;" | uniq | wc -l | grep -v "^1$" ||:
select throwIf((SELECT * FROM tmp) != (SELECT arraySort(groupArray(n)), arraySort(groupArray(m)), arraySort(groupArray(_part)) FROM mt)) FORMAT Null;
COMMIT;"
done
}

Some files were not shown because too many files have changed in this diff.