mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
Added stack protection; added a test
This commit is contained in:
parent
c80aeb0ef1
commit
afef5c6c70
@ -442,6 +442,7 @@ namespace ErrorCodes
|
||||
extern const int CANNOT_PARSE_DWARF = 465;
|
||||
extern const int INSECURE_PATH = 466;
|
||||
extern const int CANNOT_PARSE_BOOL = 467;
|
||||
extern const int CANNOT_PTHREAD_ATTR = 468;
|
||||
|
||||
extern const int KEEPER_EXCEPTION = 999;
|
||||
extern const int POCO_EXCEPTION = 1000;
|
||||
|
62
dbms/src/Common/checkStackSize.cpp
Normal file
62
dbms/src/Common/checkStackSize.cpp
Normal file
@ -0,0 +1,62 @@
|
||||
#include <Common/checkStackSize.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <ext/scope_guard.h>
|
||||
|
||||
#include <pthread.h>
|
||||
#include <cstdint>
|
||||
#include <sstream>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int CANNOT_PTHREAD_ATTR;
|
||||
extern const int LOGICAL_ERROR;
|
||||
extern const int TOO_DEEP_RECURSION;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static thread_local void * stack_address = nullptr;
|
||||
static thread_local size_t max_stack_size = 0;
|
||||
|
||||
void checkStackSize()
|
||||
{
|
||||
using namespace DB;
|
||||
|
||||
if (!stack_address)
|
||||
{
|
||||
pthread_attr_t attr;
|
||||
if (0 != pthread_getattr_np(pthread_self(), &attr))
|
||||
throwFromErrno("Cannot pthread_getattr_np", ErrorCodes::CANNOT_PTHREAD_ATTR);
|
||||
|
||||
SCOPE_EXIT({ pthread_attr_destroy(&attr); });
|
||||
|
||||
if (0 != pthread_attr_getstack(&attr, &stack_address, &max_stack_size))
|
||||
throwFromErrno("Cannot pthread_getattr_np", ErrorCodes::CANNOT_PTHREAD_ATTR);
|
||||
}
|
||||
|
||||
const void * frame_address = __builtin_frame_address(0);
|
||||
uintptr_t int_frame_address = reinterpret_cast<uintptr_t>(frame_address);
|
||||
uintptr_t int_stack_address = reinterpret_cast<uintptr_t>(stack_address);
|
||||
|
||||
/// We assume that stack grows towards lower addresses. And that it starts to grow from the end of a chunk of memory of max_stack_size.
|
||||
if (int_frame_address > int_stack_address + max_stack_size)
|
||||
throw Exception("Logical error: frame address is greater than stack begin address", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
size_t stack_size = int_stack_address + max_stack_size - int_frame_address;
|
||||
|
||||
/// Just check if we have already eat more than a half of stack size. It's a bit overkill (a half of stack size is wasted).
|
||||
/// It's safe to assume that overflow in multiplying by two cannot occur.
|
||||
if (stack_size * 2 > max_stack_size)
|
||||
{
|
||||
std::stringstream message;
|
||||
message << "Stack size too large"
|
||||
<< ". Stack address: " << stack_address
|
||||
<< ", frame address: " << frame_address
|
||||
<< ", stack size: " << stack_size
|
||||
<< ", maximum stack size: " << max_stack_size;
|
||||
throw Exception(message.str(), ErrorCodes::TOO_DEEP_RECURSION);
|
||||
}
|
||||
}
|
7
dbms/src/Common/checkStackSize.h
Normal file
7
dbms/src/Common/checkStackSize.h
Normal file
@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
/** If the stack is large enough and is near its size, throw an exception.
|
||||
* You can call this function in "heavy" functions that may be called recursively
|
||||
* to prevent possible stack overflows.
|
||||
*/
|
||||
void checkStackSize();
|
@ -6,6 +6,7 @@
|
||||
#include <Storages/StorageReplicatedMergeTree.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/ProfileEvents.h>
|
||||
#include <Common/checkStackSize.h>
|
||||
#include <TableFunctions/TableFunctionFactory.h>
|
||||
|
||||
#include <common/logger_useful.h>
|
||||
@ -58,6 +59,8 @@ namespace
|
||||
|
||||
BlockInputStreamPtr createLocalStream(const ASTPtr & query_ast, const Context & context, QueryProcessingStage::Enum processed_stage)
|
||||
{
|
||||
checkStackSize();
|
||||
|
||||
InterpreterSelectQuery interpreter{query_ast, context, SelectQueryOptions(processed_stage)};
|
||||
BlockInputStreamPtr stream = interpreter.execute().in;
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
#include <IO/ReadBufferFromMemory.h>
|
||||
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <Common/checkStackSize.h>
|
||||
|
||||
#include <DataStreams/AddingDefaultBlockOutputStream.h>
|
||||
#include <DataStreams/AddingDefaultsBlockInputStream.h>
|
||||
@ -39,6 +40,7 @@ InterpreterInsertQuery::InterpreterInsertQuery(
|
||||
const ASTPtr & query_ptr_, const Context & context_, bool allow_materialized_)
|
||||
: query_ptr(query_ptr_), context(context_), allow_materialized(allow_materialized_)
|
||||
{
|
||||
checkStackSize();
|
||||
}
|
||||
|
||||
|
||||
|
@ -58,6 +58,7 @@
|
||||
#include <Core/Types.h>
|
||||
#include <Columns/Collator.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <Common/checkStackSize.h>
|
||||
#include <Parsers/queryToString.h>
|
||||
#include <ext/map.h>
|
||||
#include <memory>
|
||||
@ -211,6 +212,8 @@ InterpreterSelectQuery::InterpreterSelectQuery(
|
||||
, input(input_)
|
||||
, log(&Logger::get("InterpreterSelectQuery"))
|
||||
{
|
||||
checkStackSize();
|
||||
|
||||
initSettings();
|
||||
const Settings & settings = context.getSettingsRef();
|
||||
|
||||
|
@ -23,6 +23,7 @@
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <Common/checkStackSize.h>
|
||||
#include <Databases/IDatabase.h>
|
||||
#include <Core/SettingsCommon.h>
|
||||
#include <ext/range.h>
|
||||
@ -387,6 +388,7 @@ StorageMerge::StorageListWithLocks StorageMerge::getSelectedTables(const ASTPtr
|
||||
|
||||
DatabaseIteratorPtr StorageMerge::getDatabaseIterator(const Context & context) const
|
||||
{
|
||||
checkStackSize();
|
||||
auto database = context.getDatabase(source_database);
|
||||
auto table_name_match = [this](const String & table_name_) { return table_name_regexp.match(table_name_); };
|
||||
return database->getIterator(global_context, table_name_match);
|
||||
|
@ -0,0 +1,11 @@
|
||||
DROP TABLE IF EXISTS merge1;
|
||||
DROP TABLE IF EXISTS merge2;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS merge1 (x UInt64) ENGINE = Merge(currentDatabase(), '^merge\\d$');
|
||||
CREATE TABLE IF NOT EXISTS merge2 (x UInt64) ENGINE = Merge(currentDatabase(), '^merge\\d$');
|
||||
|
||||
SELECT * FROM merge1; -- { serverError 306 }
|
||||
SELECT * FROM merge2; -- { serverError 306 }
|
||||
|
||||
DROP TABLE merge1;
|
||||
DROP TABLE merge2;
|
Loading…
Reference in New Issue
Block a user