Merge pull request #8446 from ClickHouse/agg-throw

Added aggregate function `aggThrow`
This commit is contained in:
alexey-milovidov 2019-12-28 18:30:10 +03:00 committed by GitHub
commit 95b43aa5ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 140 additions and 5 deletions

View File

@ -0,0 +1,119 @@
#include <memory>
#include <random>
#include <DataTypes/DataTypesNumber.h>
#include <Common/thread_local_rng.h>
#include <IO/ReadBuffer.h>
#include <IO/WriteBuffer.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
namespace DB
{
namespace ErrorCodes
{
extern const int AGGREGATE_FUNCTION_THROW;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
namespace
{
struct AggregateFunctionThrowData
{
bool allocated;
AggregateFunctionThrowData() : allocated(true) {}
~AggregateFunctionThrowData()
{
volatile bool * allocated_ptr = &allocated;
if (*allocated_ptr)
*allocated_ptr = false;
else
abort();
}
};
/** Throw on creation with probability specified in parameter.
* It will check correct destruction of the state.
* This is intended to check for exception safety.
*/
class AggregateFunctionThrow final : public IAggregateFunctionDataHelper<AggregateFunctionThrowData, AggregateFunctionThrow>
{
private:
Float64 throw_probability;
public:
AggregateFunctionThrow(const DataTypes & argument_types_, const Array & parameters_, Float64 throw_probability_)
: IAggregateFunctionDataHelper(argument_types_, parameters_), throw_probability(throw_probability_) {}
String getName() const override
{
return "aggThrow";
}
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeUInt8>();
}
void create(AggregateDataPtr place) const override
{
if (std::uniform_real_distribution<>(0.0, 1.0)(thread_local_rng) <= throw_probability)
throw Exception("Aggregate function " + getName() + " has thrown exception successfully", ErrorCodes::AGGREGATE_FUNCTION_THROW);
new (place) Data;
}
void destroy(AggregateDataPtr place) const noexcept override
{
data(place).~Data();
}
void add(AggregateDataPtr, const IColumn **, size_t, Arena *) const override
{
}
void merge(AggregateDataPtr, ConstAggregateDataPtr, Arena *) const override
{
}
void serialize(ConstAggregateDataPtr, WriteBuffer & buf) const override
{
char c = 0;
buf.write(c);
}
void deserialize(AggregateDataPtr, ReadBuffer & buf, Arena *) const override
{
char c = 0;
buf.read(c);
}
void insertResultInto(ConstAggregateDataPtr, IColumn & to) const override
{
to.insertDefault();
}
};
}
void registerAggregateFunctionAggThrow(AggregateFunctionFactory & factory)
{
factory.registerFunction("aggThrow", [](const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
Float64 throw_probability = 1.0;
if (parameters.size() == 1)
throw_probability = parameters[0].safeGet<Float64>();
else if (parameters.size() > 1)
throw Exception("Aggregate function " + name + " cannot have more than one parameter", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
return std::make_shared<AggregateFunctionThrow>(argument_types, parameters, throw_probability);
});
}
}

View File

@ -100,7 +100,18 @@ public:
void create(AggregateDataPtr place) const override
{
for (size_t i = 0; i < total; ++i)
nested_function->create(place + i * size_of_data);
{
try
{
nested_function->create(place + i * size_of_data);
}
catch (...)
{
for (size_t j = 0; j < i; ++j)
nested_function->destroy(place + j * size_of_data);
throw;
}
}
}
void destroy(AggregateDataPtr place) const noexcept override

View File

@ -23,13 +23,13 @@ inline void assertNoParameters(const std::string & name, const Array & parameter
inline void assertUnary(const std::string & name, const DataTypes & argument_types)
{
if (argument_types.size() != 1)
throw Exception("Aggregate function " + name + " require single argument", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
throw Exception("Aggregate function " + name + " requires single argument", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
}
inline void assertBinary(const std::string & name, const DataTypes & argument_types)
{
if (argument_types.size() != 2)
throw Exception("Aggregate function " + name + " require two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
throw Exception("Aggregate function " + name + " requires two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
}
template<std::size_t maximal_arity>

View File

@ -213,7 +213,7 @@ protected:
public:
IAggregateFunctionDataHelper(const DataTypes & argument_types_, const Array & parameters_)
: IAggregateFunctionHelper<Derived>(argument_types_, parameters_) {}
: IAggregateFunctionHelper<Derived>(argument_types_, parameters_) {}
void create(AggregateDataPtr place) const override
{

View File

@ -42,6 +42,7 @@ void registerAggregateFunctions()
registerAggregateFunctionSimpleLinearRegression(factory);
registerAggregateFunctionMoving(factory);
registerAggregateFunctionCategoricalIV(factory);
registerAggregateFunctionAggThrow(factory);
}
{

View File

@ -34,6 +34,7 @@ void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory &);
void registerAggregateFunctionMoving(AggregateFunctionFactory &);
void registerAggregateFunctionCategoricalIV(AggregateFunctionFactory &);
void registerAggregateFunctionAggThrow(AggregateFunctionFactory &);
class AggregateFunctionCombinatorFactory;
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);

View File

@ -477,6 +477,7 @@ namespace ErrorCodes
extern const int CANNOT_CREATE_DICTIONARY_FROM_METADATA = 500;
extern const int CANNOT_CREATE_DATABASE = 501;
extern const int CANNOT_SIGQUEUE = 502;
extern const int AGGREGATE_FUNCTION_THROW = 503;
extern const int KEEPER_EXCEPTION = 999;
extern const int POCO_EXCEPTION = 1000;

View File

@ -168,7 +168,8 @@ void FunctionArrayReduce::executeImpl(Block & block, const ColumnNumbers & argum
}
catch (...)
{
agg_func.destroy(places[i]);
for (size_t j = 0; j < i; ++j)
agg_func.destroy(places[j]);
throw;
}
}

View File

@ -0,0 +1 @@
SELECT arrayReduce('aggThrow(0.0001)', range(number % 10)) FROM system.numbers; -- { serverError 503 }