Reimplement any and anyLast

This commit is contained in:
Raúl Marín 2024-01-09 14:55:38 +01:00
parent 067aea6eb2
commit 227baac06a
5 changed files with 265 additions and 271 deletions

View File

@ -8,43 +8,268 @@
namespace DB
{
struct Settings;
//
//AggregateFunctionPtr createAggregateFunctionAny(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
//{
// return AggregateFunctionPtr(createAggregateFunctionSingleValue<AggregateFunctionsSingleValue, AggregateFunctionAnyData>(name, argument_types, parameters, settings));
//}
//
//
//AggregateFunctionPtr createAggregateFunctionAnyLast(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
//{
// return AggregateFunctionPtr(createAggregateFunctionSingleValue<AggregateFunctionsSingleValue, AggregateFunctionAnyLastData>(name, argument_types, parameters, settings));
//}
//
//
//AggregateFunctionPtr createAggregateFunctionAnyHeavy(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
//{
// return AggregateFunctionPtr(createAggregateFunctionSingleValue<AggregateFunctionsSingleValue, AggregateFunctionAnyHeavyData>(name, argument_types, parameters, settings));
//}
//
//}
//
void registerAggregateFunctionsAny(AggregateFunctionFactory &)
namespace
{
// AggregateFunctionProperties default_properties = {.returns_default_when_only_null = false, .is_order_dependent = true};
// AggregateFunctionProperties default_properties_for_respect_nulls
// = {.returns_default_when_only_null = false, .is_order_dependent = true, .is_window_function = true};
//
// factory.registerFunction("any", {createAggregateFunctionAny, default_properties});
// factory.registerAlias("any_value", "any", AggregateFunctionFactory::CaseInsensitive);
// factory.registerAlias("first_value", "any", AggregateFunctionFactory::CaseInsensitive);
////
// factory.registerFunction("anyLast", {createAggregateFunctionAnyLast, default_properties});
// factory.registerAlias("last_value", "anyLast", AggregateFunctionFactory::CaseInsensitive);
////
// factory.registerFunction("anyHeavy", {createAggregateFunctionAnyHeavy, default_properties});
//
// factory.registerNullsActionTransformation("any", "any_respect_nulls");
// factory.registerNullsActionTransformation("anyLast", "anyLast_respect_nulls");
/// Aggregate function `any`.
/// Once the state holds a value, every entry point that adds rows or merges states leaves it untouched,
/// so the first value that reaches the state is the one that is kept.
/// `Data` is the state type; this class uses its has(), set(), write(), read(), insertResultInto()
/// and the static allocatesMemoryInArena().
template <typename Data>
class AggregateFunctionAny final : public IAggregateFunctionDataHelper<Data, AggregateFunctionAny<Data>>
{
private:
    /// Default serialization of the argument type, used to (de)serialize the state.
    SerializationPtr serialization;

public:
    /// The single argument type is also used as the result type.
    explicit AggregateFunctionAny(const DataTypePtr & type)
        : IAggregateFunctionDataHelper<Data, AggregateFunctionAny<Data>>({type}, {}, type)
        , serialization(type->getDefaultSerialization())
    {
    }

    String getName() const override { return "any"; }

    /// Stores the row only while the state is still empty.
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
    {
        Data & state = this->data(place);
        if (state.has())
            return;

        state.set(*columns[0], row_num, arena);
    }

    /// Batch variant: takes the first row of [row_begin, row_end) that passes the optional
    /// condition column (`if_argument_pos >= 0`), provided the state is still empty.
    void addBatchSinglePlace(
        size_t row_begin,
        size_t row_end,
        AggregateDataPtr __restrict place,
        const IColumn ** __restrict columns,
        Arena * arena,
        ssize_t if_argument_pos) const override
    {
        Data & state = this->data(place);
        if (state.has())
            return;
        if (row_begin >= row_end)
            return;

        /// Without a condition column the very first row of the range is the answer.
        if (if_argument_pos < 0)
        {
            state.set(*columns[0], row_begin, arena);
            return;
        }

        const auto & condition = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
        const auto * flags = condition.data();

        size_t row = row_begin;
        while (row < row_end && flags[row] == 0)
            ++row;

        if (row < row_end)
            state.set(*columns[0], row, arena);
    }

    /// Same as addBatchSinglePlace, but rows whose `null_map` entry is non-zero are skipped as well.
    void addBatchSinglePlaceNotNull(
        size_t row_begin,
        size_t row_end,
        AggregateDataPtr __restrict place,
        const IColumn ** __restrict columns,
        const UInt8 * __restrict null_map,
        Arena * arena,
        ssize_t if_argument_pos) const override
    {
        Data & state = this->data(place);
        if (state.has())
            return;
        if (row_begin >= row_end)
            return;

        if (if_argument_pos < 0)
        {
            for (size_t row = row_begin; row < row_end; ++row)
            {
                if (null_map[row] != 0)
                    continue;

                state.set(*columns[0], row, arena);
                return;
            }
            return;
        }

        const auto & condition = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
        const auto * flags = condition.data();

        for (size_t row = row_begin; row < row_end; ++row)
        {
            /// The row must pass the condition and must not be NULL.
            if (flags[row] == 0 || null_map[row] != 0)
                continue;

            state.set(*columns[0], row, arena);
            return;
        }
    }

    /// The number of default rows is irrelevant: row 0 of the column is stored if the state is empty.
    void addManyDefaults(AggregateDataPtr __restrict place, const IColumn ** columns, size_t, Arena * arena) const override
    {
        Data & state = this->data(place);
        if (state.has())
            return;

        state.set(*columns[0], 0, arena);
    }

    /// The right-hand state is consulted only when this state holds nothing yet.
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
    {
        Data & state = this->data(place);
        if (state.has())
            return;

        state.set(this->data(rhs), arena);
    }

    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
    {
        const Data & state = this->data(place);
        state.write(buf, *serialization);
    }

    void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
    {
        Data & state = this->data(place);
        state.read(buf, *serialization, arena);
    }

    bool allocatesMemoryInArena() const override { return Data::allocatesMemoryInArena(); }

    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
    {
        Data & state = this->data(place);
        state.insertResultInto(to);
    }
};
/// Creator callback for `any`: forwards all arguments to createAggregateFunctionSingleValue,
/// instantiated with AggregateFunctionAny, and wraps the result in an AggregateFunctionPtr.
AggregateFunctionPtr createAggregateFunctionAny(
    const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
    AggregateFunctionPtr function(createAggregateFunctionSingleValue<AggregateFunctionAny>(name, argument_types, parameters, settings));
    return function;
}
/// Aggregate function `anyLast`.
/// Every accepted row is written into the state via set(), so rows seen later are the ones passed last.
/// Batches are scanned from the end, so at most one set() call is made per batch.
/// `Data` is the state type; this class uses its set(), write(), read(), insertResultInto()
/// and the static allocatesMemoryInArena().
template <typename Data>
class AggregateFunctionAnyLast final : public IAggregateFunctionDataHelper<Data, AggregateFunctionAnyLast<Data>>
{
private:
    /// Default serialization of the argument type, used to (de)serialize the state.
    SerializationPtr serialization;

public:
    /// The single argument type is also used as the result type.
    explicit AggregateFunctionAnyLast(const DataTypePtr & type)
        : IAggregateFunctionDataHelper<Data, AggregateFunctionAnyLast<Data>>({type}, {}, type)
        , serialization(type->getDefaultSerialization())
    {
    }

    String getName() const override { return "anyLast"; }

    /// Unconditionally passes the row to the state.
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
    {
        Data & state = this->data(place);
        state.set(*columns[0], row_num, arena);
    }

    /// Batch variant: takes the last row of [row_begin, row_end) that passes the optional
    /// condition column (`if_argument_pos >= 0`). Nothing is stored if no row qualifies.
    void addBatchSinglePlace(
        size_t row_begin,
        size_t row_end,
        AggregateDataPtr __restrict place,
        const IColumn ** __restrict columns,
        Arena * arena,
        ssize_t if_argument_pos) const override
    {
        if (row_begin >= row_end)
            return;

        Data & state = this->data(place);

        /// Without a condition column the final row of the range is the answer.
        if (if_argument_pos < 0)
        {
            state.set(*columns[0], row_end - 1, arena);
            return;
        }

        const auto & condition = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
        const auto * flags = condition.data();

        /// Walk backwards: `next` is one past the row being examined.
        for (size_t next = row_end; next > row_begin; --next)
        {
            const size_t row = next - 1;
            if (flags[row] == 0)
                continue;

            state.set(*columns[0], row, arena);
            return;
        }
    }

    /// Same as addBatchSinglePlace, but rows whose `null_map` entry is non-zero are skipped as well.
    void addBatchSinglePlaceNotNull(
        size_t row_begin,
        size_t row_end,
        AggregateDataPtr __restrict place,
        const IColumn ** __restrict columns,
        const UInt8 * __restrict null_map,
        Arena * arena,
        ssize_t if_argument_pos) const override
    {
        if (row_begin >= row_end)
            return;

        Data & state = this->data(place);

        if (if_argument_pos < 0)
        {
            for (size_t next = row_end; next > row_begin; --next)
            {
                const size_t row = next - 1;
                if (null_map[row] != 0)
                    continue;

                state.set(*columns[0], row, arena);
                return;
            }
            return;
        }

        const auto & condition = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
        const auto * flags = condition.data();

        for (size_t next = row_end; next > row_begin; --next)
        {
            const size_t row = next - 1;

            /// The row must pass the condition and must not be NULL.
            if (flags[row] == 0 || null_map[row] != 0)
                continue;

            state.set(*columns[0], row, arena);
            return;
        }
    }

    /// The number of default rows is irrelevant: row 0 of the column is passed to the state once.
    void addManyDefaults(AggregateDataPtr __restrict place, const IColumn ** columns, size_t, Arena * arena) const override
    {
        Data & state = this->data(place);
        state.set(*columns[0], 0, arena);
    }

    /// Unconditionally passes the right-hand state to set().
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
    {
        Data & state = this->data(place);
        state.set(this->data(rhs), arena);
    }

    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
    {
        const Data & state = this->data(place);
        state.write(buf, *serialization);
    }

    void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
    {
        Data & state = this->data(place);
        state.read(buf, *serialization, arena);
    }

    bool allocatesMemoryInArena() const override { return Data::allocatesMemoryInArena(); }

    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
    {
        Data & state = this->data(place);
        state.insertResultInto(to);
    }
};
/// Creator callback for `anyLast`: forwards all arguments to createAggregateFunctionSingleValue,
/// instantiated with AggregateFunctionAnyLast, and wraps the result in an AggregateFunctionPtr.
AggregateFunctionPtr
createAggregateFunctionAnyLast(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
    AggregateFunctionPtr function(
        createAggregateFunctionSingleValue<AggregateFunctionAnyLast>(name, argument_types, parameters, settings));
    return function;
}
}
/// Registers `any` and `anyLast` in the factory, together with their case-insensitive aliases.
void registerAggregateFunctionsAny(AggregateFunctionFactory & factory)
{
    /// Both functions are registered with the same properties.
    AggregateFunctionProperties shared_properties
        = {.returns_default_when_only_null = false, .is_order_dependent = true};

    /// any (aliases: any_value, first_value)
    factory.registerFunction("any", {createAggregateFunctionAny, shared_properties});
    factory.registerAlias("any_value", "any", AggregateFunctionFactory::CaseInsensitive);
    factory.registerAlias("first_value", "any", AggregateFunctionFactory::CaseInsensitive);

    /// anyLast (alias: last_value)
    factory.registerFunction("anyLast", {createAggregateFunctionAnyLast, shared_properties});
    factory.registerAlias("last_value", "anyLast", AggregateFunctionFactory::CaseInsensitive);
}
}

View File

@ -226,8 +226,8 @@ void registerAggregateFunctionsAnyRespectNulls(AggregateFunctionFactory & factor
factory.registerAlias("last_value_respect_nulls", "anyLast_respect_nulls", AggregateFunctionFactory::CaseInsensitive);
/// Must happen after registering any and anyLast
// factory.registerNullsActionTransformation("any", "any_respect_nulls");
// factory.registerNullsActionTransformation("anyLast", "anyLast_respect_nulls");
factory.registerNullsActionTransformation("any", "any_respect_nulls");
factory.registerNullsActionTransformation("anyLast", "anyLast_respect_nulls");
}
}

View File

@ -1,6 +1,7 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/HelpersMinMaxAny.h>
#include <Columns/ColumnNullable.h>
#include <DataTypes/DataTypeNullable.h>

View File

@ -1,5 +1,5 @@
#include <AggregateFunctions/SingleValueData.h>
#include <Columns/ColumnString.h>
#include <Common/findExtreme.h>
namespace DB

View File

@ -5,12 +5,6 @@
#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypeNullable.h> /// TODO: Remove
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/IDataType.h>
#include <base/StringRef.h>
#include <Common/Arena.h>
#include <Common/assert_cast.h>
@ -427,230 +421,4 @@ public:
static bool allocatesMemoryInArena() { return false; }
};
/** Implement 'heavy hitters' algorithm.
* Selects most frequent value if its frequency is more than 50% in each thread of execution.
* Otherwise, selects some arbitrary value.
* http://www.cs.umd.edu/~samir/498/karp.pdf
*/
template <typename Data>
struct AggregateFunctionAnyHeavyData : Data
{
UInt64 counter = 0;
using Self = AggregateFunctionAnyHeavyData;
bool changeIfBetter(const IColumn & column, size_t row_num, Arena * arena)
{
if (this->isEqualTo(column, row_num))
{
++counter;
}
else if (counter == 0)
{
this->change(column, row_num, arena);
++counter;
return true;
}
else
--counter;
return false;
}
bool changeIfBetter(const Self & to, Arena * arena)
{
if (!to.has())
return false;
if (this->isEqualTo(to))
{
counter += to.counter;
}
else if ((!this->has() && to.has()) || counter < to.counter)
{
this->change(to, arena);
return true;
}
else
counter -= to.counter;
return false;
}
void addManyDefaults(const IColumn & column, size_t length, Arena * arena)
{
for (size_t i = 0; i < length; ++i)
changeIfBetter(column, 0, arena);
}
void write(WriteBuffer & buf, const ISerialization & serialization) const
{
Data::write(buf, serialization);
writeBinaryLittleEndian(counter, buf);
}
void read(ReadBuffer & buf, const ISerialization & serialization, Arena * arena)
{
Data::read(buf, serialization, arena);
readBinaryLittleEndian(counter, buf);
}
static const char * name() { return "anyHeavy"; }
#if USE_EMBEDDED_COMPILER
static constexpr bool is_compilable = false;
#endif
};
/// Generic aggregate function that keeps a single value in its state.
/// All decisions are delegated to `Data`: every incoming row or state is offered to
/// Data::changeIfBetter(). When Data::is_any is set, batch processing stops after the first
/// accepted row and is skipped entirely once the state holds a value.
template <typename Data>
class AggregateFunctionsSingleValue : public IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data>>
{
    static constexpr bool is_any = Data::is_any;

private:
    /// Default serialization of the argument type, used to (de)serialize the state.
    SerializationPtr serialization;

public:
    /// For states named "min" or "max" the argument type must be comparable,
    /// otherwise ILLEGAL_TYPE_OF_ARGUMENT is thrown.
    explicit AggregateFunctionsSingleValue(const DataTypePtr & type)
        : IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data>>({type}, {}, createResultType(type))
        , serialization(type->getDefaultSerialization())
    {
        const StringRef function_name(Data::name());
        const bool needs_comparable_type = function_name == StringRef("min") || function_name == StringRef("max");

        if (needs_comparable_type && !type->isComparable())
            throw Exception(
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                "Illegal type {} of argument of aggregate function {} because the values of that data type are not comparable",
                type->getName(),
                Data::name());
    }

    String getName() const override { return Data::name(); }

    /// The result type is the argument type, wrapped into Nullable when Data::result_is_nullable is set.
    static DataTypePtr createResultType(const DataTypePtr & type_)
    {
        if constexpr (Data::result_is_nullable)
            return makeNullable(type_);
        else
            return type_;
    }

    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
    {
        Data & state = this->data(place);
        state.changeIfBetter(*columns[0], row_num, arena);
    }

    void addManyDefaults(AggregateDataPtr __restrict place, const IColumn ** columns, size_t length, Arena * arena) const override
    {
        Data & state = this->data(place);
        state.addManyDefaults(*columns[0], length, arena);
    }

    /// Offers every row of [row_begin, row_end) that passes the optional condition column
    /// (`if_argument_pos >= 0`) to the state.
    void addBatchSinglePlace(
        size_t row_begin, size_t row_end, AggregateDataPtr place, const IColumn ** columns, Arena * arena, ssize_t if_argument_pos)
        const override
    {
        if constexpr (is_any)
        {
            /// Nothing to do once a value is stored.
            if (this->data(place).has())
                return;
        }

        Data & state = this->data(place);

        if (if_argument_pos < 0)
        {
            for (size_t row = row_begin; row < row_end; ++row)
            {
                state.changeIfBetter(*columns[0], row, arena);
                if constexpr (is_any)
                    return;
            }
            return;
        }

        const auto & condition = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
        for (size_t row = row_begin; row < row_end; ++row)
        {
            if (!condition[row])
                continue;

            state.changeIfBetter(*columns[0], row, arena);
            if constexpr (is_any)
                return;
        }
    }

    /// Same as addBatchSinglePlace, but rows whose `null_map` entry is non-zero are skipped as well.
    void addBatchSinglePlaceNotNull( /// NOLINT
        size_t row_begin,
        size_t row_end,
        AggregateDataPtr place,
        const IColumn ** columns,
        const UInt8 * null_map,
        Arena * arena,
        ssize_t if_argument_pos = -1) const override
    {
        if constexpr (is_any)
        {
            /// Nothing to do once a value is stored.
            if (this->data(place).has())
                return;
        }

        Data & state = this->data(place);

        if (if_argument_pos < 0)
        {
            for (size_t row = row_begin; row < row_end; ++row)
            {
                if (null_map[row])
                    continue;

                state.changeIfBetter(*columns[0], row, arena);
                if constexpr (is_any)
                    return;
            }
            return;
        }

        const auto & condition = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
        for (size_t row = row_begin; row < row_end; ++row)
        {
            /// The row must not be NULL and must pass the condition.
            if (null_map[row] || !condition[row])
                continue;

            state.changeIfBetter(*columns[0], row, arena);
            if constexpr (is_any)
                return;
        }
    }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
    {
        Data & state = this->data(place);
        state.changeIfBetter(this->data(rhs), arena);
    }

    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
    {
        const Data & state = this->data(place);
        state.write(buf, *serialization);
    }

    void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
    {
        Data & state = this->data(place);
        state.read(buf, *serialization, arena);
    }

    bool allocatesMemoryInArena() const override { return Data::allocatesMemoryInArena(); }

    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
    {
        Data & state = this->data(place);
        state.insertResultInto(to);
    }

    /// Returns the function itself as its own null adapter when the state produces a Nullable result
    /// and does not skip null arguments; returns nullptr otherwise.
    AggregateFunctionPtr getOwnNullAdapter(
        const AggregateFunctionPtr & original_function,
        const DataTypes & /*arguments*/,
        const Array & /*params*/,
        const AggregateFunctionProperties & /*properties*/) const override
    {
        if (!Data::result_is_nullable || Data::should_skip_null_arguments)
            return nullptr;
        return original_function;
    }
};
}