diff --git a/dbms/src/AggregateFunctions/AggregateFunctionAggThrow.cpp b/dbms/src/AggregateFunctions/AggregateFunctionAggThrow.cpp new file mode 100644 index 00000000000..2bf00676d77 --- /dev/null +++ b/dbms/src/AggregateFunctions/AggregateFunctionAggThrow.cpp @@ -0,0 +1,119 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int AGGREGATE_FUNCTION_THROW; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} + +namespace +{ + +struct AggregateFunctionThrowData +{ + bool allocated; + + AggregateFunctionThrowData() : allocated(true) {} + ~AggregateFunctionThrowData() + { + volatile bool * allocated_ptr = &allocated; + + if (*allocated_ptr) + *allocated_ptr = false; + else + abort(); + } +}; + +/** Throw on creation with probability specified in parameter. + * It will check correct destruction of the state. + * This is intended to check for exception safety. + */ +class AggregateFunctionThrow final : public IAggregateFunctionDataHelper +{ +private: + Float64 throw_probability; + +public: + AggregateFunctionThrow(const DataTypes & argument_types_, const Array & parameters_, Float64 throw_probability_) + : IAggregateFunctionDataHelper(argument_types_, parameters_), throw_probability(throw_probability_) {} + + String getName() const override + { + return "aggThrow"; + } + + DataTypePtr getReturnType() const override + { + return std::make_shared(); + } + + void create(AggregateDataPtr place) const override + { + if (std::uniform_real_distribution<>(0.0, 1.0)(thread_local_rng) <= throw_probability) + throw Exception("Aggregate function " + getName() + " has thrown exception successfully", ErrorCodes::AGGREGATE_FUNCTION_THROW); + + new (place) Data; + } + + void destroy(AggregateDataPtr place) const noexcept override + { + data(place).~Data(); + } + + void add(AggregateDataPtr, const IColumn **, size_t, Arena *) const override + { + } + + void merge(AggregateDataPtr, ConstAggregateDataPtr, Arena *) const override + { + } + + void serialize(ConstAggregateDataPtr, WriteBuffer & buf) const override + { + char c = 0; + buf.write(c); + } + + void deserialize(AggregateDataPtr, ReadBuffer & buf, Arena *) const override + { + char c = 0; + buf.read(c); + } + + void insertResultInto(ConstAggregateDataPtr, IColumn & to) const override + { + to.insertDefault(); + } +}; + +} + +void registerAggregateFunctionAggThrow(AggregateFunctionFactory & factory) +{ + factory.registerFunction("aggThrow", [](const std::string & name, const DataTypes & argument_types, const Array & parameters) + { + Float64 throw_probability = 1.0; + if (parameters.size() == 1) + throw_probability = parameters[0].safeGet(); + else if (parameters.size() > 1) + throw Exception("Aggregate function " + name + " cannot have more than one parameter", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + return std::make_shared(argument_types, parameters, throw_probability); + }); +} + +} + diff --git a/dbms/src/AggregateFunctions/AggregateFunctionResample.h b/dbms/src/AggregateFunctions/AggregateFunctionResample.h index 33b03fcdee0..0f348899884 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionResample.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionResample.h @@ -100,7 +100,18 @@ public: void create(AggregateDataPtr place) const override { for (size_t i = 0; i < total; ++i) - nested_function->create(place + i * size_of_data); + { + try + { + nested_function->create(place + i * size_of_data); + } + catch (...) + { + for (size_t j = 0; j < i; ++j) + nested_function->destroy(place + j * size_of_data); + throw; + } + } } void destroy(AggregateDataPtr place) const noexcept override diff --git a/dbms/src/AggregateFunctions/FactoryHelpers.h b/dbms/src/AggregateFunctions/FactoryHelpers.h index 183116df54e..aff7ff0ff36 100644 --- a/dbms/src/AggregateFunctions/FactoryHelpers.h +++ b/dbms/src/AggregateFunctions/FactoryHelpers.h @@ -23,13 +23,13 @@ inline void assertNoParameters(const std::string & name, const Array & parameter inline void assertUnary(const std::string & name, const DataTypes & argument_types) { if (argument_types.size() != 1) - throw Exception("Aggregate function " + name + " require single argument", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + throw Exception("Aggregate function " + name + " requires single argument", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); } inline void assertBinary(const std::string & name, const DataTypes & argument_types) { if (argument_types.size() != 2) - throw Exception("Aggregate function " + name + " require two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + throw Exception("Aggregate function " + name + " requires two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); } template diff --git a/dbms/src/AggregateFunctions/IAggregateFunction.h b/dbms/src/AggregateFunctions/IAggregateFunction.h index 94dcf4cbcab..4811937a08b 100644 --- a/dbms/src/AggregateFunctions/IAggregateFunction.h +++ b/dbms/src/AggregateFunctions/IAggregateFunction.h @@ -213,7 +213,7 @@ protected: public: IAggregateFunctionDataHelper(const DataTypes & argument_types_, const Array & parameters_) - : IAggregateFunctionHelper(argument_types_, parameters_) {} + : IAggregateFunctionHelper(argument_types_, parameters_) {} void create(AggregateDataPtr place) const override { diff --git a/dbms/src/AggregateFunctions/registerAggregateFunctions.cpp b/dbms/src/AggregateFunctions/registerAggregateFunctions.cpp index d36603df081..a4fc41e9c06 100644 --- a/dbms/src/AggregateFunctions/registerAggregateFunctions.cpp +++ b/dbms/src/AggregateFunctions/registerAggregateFunctions.cpp @@ -42,6 +42,7 @@ void registerAggregateFunctions() registerAggregateFunctionSimpleLinearRegression(factory); registerAggregateFunctionMoving(factory); registerAggregateFunctionCategoricalIV(factory); + registerAggregateFunctionAggThrow(factory); } { diff --git a/dbms/src/AggregateFunctions/registerAggregateFunctions.h b/dbms/src/AggregateFunctions/registerAggregateFunctions.h index 897e5d52a61..88cdf4a504d 100644 --- a/dbms/src/AggregateFunctions/registerAggregateFunctions.h +++ b/dbms/src/AggregateFunctions/registerAggregateFunctions.h @@ -34,6 +34,7 @@ void registerAggregateFunctionEntropy(AggregateFunctionFactory &); void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory &); void registerAggregateFunctionMoving(AggregateFunctionFactory &); void registerAggregateFunctionCategoricalIV(AggregateFunctionFactory &); +void registerAggregateFunctionAggThrow(AggregateFunctionFactory &); class AggregateFunctionCombinatorFactory; void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &); diff --git a/dbms/src/Common/ErrorCodes.cpp b/dbms/src/Common/ErrorCodes.cpp index 25d1b015a03..7a30c832759 100644 --- a/dbms/src/Common/ErrorCodes.cpp +++ b/dbms/src/Common/ErrorCodes.cpp @@ -477,6 +477,7 @@ namespace ErrorCodes extern const int CANNOT_CREATE_DICTIONARY_FROM_METADATA = 500; extern const int CANNOT_CREATE_DATABASE = 501; extern const int CANNOT_SIGQUEUE = 502; + extern const int AGGREGATE_FUNCTION_THROW = 503; extern const int KEEPER_EXCEPTION = 999; extern const int POCO_EXCEPTION = 1000; diff --git a/dbms/src/Functions/array/arrayReduce.cpp b/dbms/src/Functions/array/arrayReduce.cpp index 103d0fe5fa8..9e8b2ddc3df 100644 --- a/dbms/src/Functions/array/arrayReduce.cpp +++ b/dbms/src/Functions/array/arrayReduce.cpp @@ -168,7 +168,8 @@ void FunctionArrayReduce::executeImpl(Block & block, const ColumnNumbers & argum } catch (...) { - agg_func.destroy(places[i]); + for (size_t j = 0; j < i; ++j) + agg_func.destroy(places[j]); throw; } } diff --git a/dbms/tests/queries/0_stateless/01052_array_reduce_exception.reference b/dbms/tests/queries/0_stateless/01052_array_reduce_exception.reference new file mode 100644 index 00000000000..e69de29bb2d diff --git a/dbms/tests/queries/0_stateless/01052_array_reduce_exception.sql b/dbms/tests/queries/0_stateless/01052_array_reduce_exception.sql new file mode 100644 index 00000000000..71c030a055c --- /dev/null +++ b/dbms/tests/queries/0_stateless/01052_array_reduce_exception.sql @@ -0,0 +1 @@ +SELECT arrayReduce('aggThrow(0.0001)', range(number % 10)) FROM system.numbers; -- { serverError 503 }