From afef5c6c70888f45f82f0abea448d6199e350f9b Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sat, 10 Aug 2019 20:51:47 +0300 Subject: [PATCH] Added stack protection; added a test --- dbms/src/Common/ErrorCodes.cpp | 1 + dbms/src/Common/checkStackSize.cpp | 62 +++++++++++++++++++ dbms/src/Common/checkStackSize.h | 7 +++ .../ClusterProxy/SelectStreamFactory.cpp | 3 + .../Interpreters/InterpreterInsertQuery.cpp | 2 + .../Interpreters/InterpreterSelectQuery.cpp | 3 + dbms/src/Storages/StorageMerge.cpp | 2 + .../00985_merge_stack_overflow.reference | 0 .../00985_merge_stack_overflow.sql | 11 ++++ 9 files changed, 91 insertions(+) create mode 100644 dbms/src/Common/checkStackSize.cpp create mode 100644 dbms/src/Common/checkStackSize.h create mode 100644 dbms/tests/queries/0_stateless/00985_merge_stack_overflow.reference create mode 100644 dbms/tests/queries/0_stateless/00985_merge_stack_overflow.sql diff --git a/dbms/src/Common/ErrorCodes.cpp b/dbms/src/Common/ErrorCodes.cpp index cd4601e5b3d..4128ddb8edc 100644 --- a/dbms/src/Common/ErrorCodes.cpp +++ b/dbms/src/Common/ErrorCodes.cpp @@ -442,6 +442,7 @@ namespace ErrorCodes extern const int CANNOT_PARSE_DWARF = 465; extern const int INSECURE_PATH = 466; extern const int CANNOT_PARSE_BOOL = 467; + extern const int CANNOT_PTHREAD_ATTR = 468; extern const int KEEPER_EXCEPTION = 999; extern const int POCO_EXCEPTION = 1000; diff --git a/dbms/src/Common/checkStackSize.cpp b/dbms/src/Common/checkStackSize.cpp new file mode 100644 index 00000000000..e7f91bc3330 --- /dev/null +++ b/dbms/src/Common/checkStackSize.cpp @@ -0,0 +1,62 @@ +#include +#include +#include + +#include +#include +#include + + +namespace DB +{ + namespace ErrorCodes + { + extern const int CANNOT_PTHREAD_ATTR; + extern const int LOGICAL_ERROR; + extern const int TOO_DEEP_RECURSION; + } +} + + +static thread_local void * stack_address = nullptr; +static thread_local size_t max_stack_size = 0; + +void checkStackSize() +{ + using namespace DB; + + if (!stack_address) + { + pthread_attr_t attr; + if (0 != pthread_getattr_np(pthread_self(), &attr)) + throwFromErrno("Cannot pthread_getattr_np", ErrorCodes::CANNOT_PTHREAD_ATTR); + + SCOPE_EXIT({ pthread_attr_destroy(&attr); }); + + if (0 != pthread_attr_getstack(&attr, &stack_address, &max_stack_size)) + throwFromErrno("Cannot pthread_getattr_np", ErrorCodes::CANNOT_PTHREAD_ATTR); + } + + const void * frame_address = __builtin_frame_address(0); + uintptr_t int_frame_address = reinterpret_cast(frame_address); + uintptr_t int_stack_address = reinterpret_cast(stack_address); + + /// We assume that stack grows towards lower addresses. And that it starts to grow from the end of a chunk of memory of max_stack_size. + if (int_frame_address > int_stack_address + max_stack_size) + throw Exception("Logical error: frame address is greater than stack begin address", ErrorCodes::LOGICAL_ERROR); + + size_t stack_size = int_stack_address + max_stack_size - int_frame_address; + + /// Just check if we have already eat more than a half of stack size. It's a bit overkill (a half of stack size is wasted). + /// It's safe to assume that overflow in multiplying by two cannot occur. + if (stack_size * 2 > max_stack_size) + { + std::stringstream message; + message << "Stack size too large" + << ". Stack address: " << stack_address + << ", frame address: " << frame_address + << ", stack size: " << stack_size + << ", maximum stack size: " << max_stack_size; + throw Exception(message.str(), ErrorCodes::TOO_DEEP_RECURSION); + } +} diff --git a/dbms/src/Common/checkStackSize.h b/dbms/src/Common/checkStackSize.h new file mode 100644 index 00000000000..355ceed430b --- /dev/null +++ b/dbms/src/Common/checkStackSize.h @@ -0,0 +1,7 @@ +#pragma once + +/** If the stack is large enough and is near its size, throw an exception. + * You can call this function in "heavy" functions that may be called recursively + * to prevent possible stack overflows. + */ +void checkStackSize(); diff --git a/dbms/src/Interpreters/ClusterProxy/SelectStreamFactory.cpp b/dbms/src/Interpreters/ClusterProxy/SelectStreamFactory.cpp index ba0571d1863..9e49d302100 100644 --- a/dbms/src/Interpreters/ClusterProxy/SelectStreamFactory.cpp +++ b/dbms/src/Interpreters/ClusterProxy/SelectStreamFactory.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -58,6 +59,8 @@ namespace BlockInputStreamPtr createLocalStream(const ASTPtr & query_ast, const Context & context, QueryProcessingStage::Enum processed_stage) { + checkStackSize(); + InterpreterSelectQuery interpreter{query_ast, context, SelectQueryOptions(processed_stage)}; BlockInputStreamPtr stream = interpreter.execute().in; diff --git a/dbms/src/Interpreters/InterpreterInsertQuery.cpp b/dbms/src/Interpreters/InterpreterInsertQuery.cpp index dbb90028316..648f13bec62 100644 --- a/dbms/src/Interpreters/InterpreterInsertQuery.cpp +++ b/dbms/src/Interpreters/InterpreterInsertQuery.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -39,6 +40,7 @@ InterpreterInsertQuery::InterpreterInsertQuery( const ASTPtr & query_ptr_, const Context & context_, bool allow_materialized_) : query_ptr(query_ptr_), context(context_), allow_materialized(allow_materialized_) { + checkStackSize(); } diff --git a/dbms/src/Interpreters/InterpreterSelectQuery.cpp b/dbms/src/Interpreters/InterpreterSelectQuery.cpp index 9682d0e29e4..aea37c7fa36 100644 --- a/dbms/src/Interpreters/InterpreterSelectQuery.cpp +++ b/dbms/src/Interpreters/InterpreterSelectQuery.cpp @@ -58,6 +58,7 @@ #include #include #include +#include #include #include #include @@ -211,6 +212,8 @@ InterpreterSelectQuery::InterpreterSelectQuery( , input(input_) , log(&Logger::get("InterpreterSelectQuery")) { + checkStackSize(); + initSettings(); const Settings & settings = context.getSettingsRef(); diff --git a/dbms/src/Storages/StorageMerge.cpp b/dbms/src/Storages/StorageMerge.cpp index 3487a1becf5..913ab6af4ea 100644 --- a/dbms/src/Storages/StorageMerge.cpp +++ b/dbms/src/Storages/StorageMerge.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -387,6 +388,7 @@ StorageMerge::StorageListWithLocks StorageMerge::getSelectedTables(const ASTPtr DatabaseIteratorPtr StorageMerge::getDatabaseIterator(const Context & context) const { + checkStackSize(); auto database = context.getDatabase(source_database); auto table_name_match = [this](const String & table_name_) { return table_name_regexp.match(table_name_); }; return database->getIterator(global_context, table_name_match); diff --git a/dbms/tests/queries/0_stateless/00985_merge_stack_overflow.reference b/dbms/tests/queries/0_stateless/00985_merge_stack_overflow.reference new file mode 100644 index 00000000000..e69de29bb2d diff --git a/dbms/tests/queries/0_stateless/00985_merge_stack_overflow.sql b/dbms/tests/queries/0_stateless/00985_merge_stack_overflow.sql new file mode 100644 index 00000000000..3a3e5640a38 --- /dev/null +++ b/dbms/tests/queries/0_stateless/00985_merge_stack_overflow.sql @@ -0,0 +1,11 @@ +DROP TABLE IF EXISTS merge1; +DROP TABLE IF EXISTS merge2; + +CREATE TABLE IF NOT EXISTS merge1 (x UInt64) ENGINE = Merge(currentDatabase(), '^merge\\d$'); +CREATE TABLE IF NOT EXISTS merge2 (x UInt64) ENGINE = Merge(currentDatabase(), '^merge\\d$'); + +SELECT * FROM merge1; -- { serverError 306 } +SELECT * FROM merge2; -- { serverError 306 } + +DROP TABLE merge1; +DROP TABLE merge2;