Merge pull request #58640 from Algunenano/argmin_optimization

ArgMin / ArgMax / any / anyLast / anyHeavy optimization
This commit is contained in:
Raúl Marín 2024-03-01 12:45:58 +01:00 committed by GitHub
commit 885952a03b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
34 changed files with 3568 additions and 2594 deletions

View File

@ -1,5 +1,5 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/HelpersMinMaxAny.h>
#include <AggregateFunctions/SingleValueData.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <base/defines.h>
@ -11,219 +11,347 @@ struct Settings;
namespace ErrorCodes
{
extern const int INCORRECT_DATA;
extern const int LOGICAL_ERROR;
extern const int NOT_IMPLEMENTED;
}
namespace
{
struct AggregateFunctionAnyRespectNullsData
template <typename Data>
class AggregateFunctionAny final : public IAggregateFunctionDataHelper<Data, AggregateFunctionAny<Data>>
{
enum Status : UInt8
{
NotSet = 1,
SetNull = 2,
SetOther = 3
};
Status status = Status::NotSet;
Field value;
bool isSet() const { return status != Status::NotSet; }
void setNull() { status = Status::SetNull; }
void setOther() { status = Status::SetOther; }
};
template <bool First>
class AggregateFunctionAnyRespectNulls final
: public IAggregateFunctionDataHelper<AggregateFunctionAnyRespectNullsData, AggregateFunctionAnyRespectNulls<First>>
{
public:
using Data = AggregateFunctionAnyRespectNullsData;
private:
SerializationPtr serialization;
const bool returns_nullable_type = false;
explicit AggregateFunctionAnyRespectNulls(const DataTypePtr & type)
: IAggregateFunctionDataHelper<Data, AggregateFunctionAnyRespectNulls<First>>({type}, {}, type)
, serialization(type->getDefaultSerialization())
, returns_nullable_type(type->isNullable())
public:
explicit AggregateFunctionAny(const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionAny<Data>>(argument_types_, {}, argument_types_[0])
, serialization(this->result_type->getDefaultSerialization())
{
}
String getName() const override
{
if constexpr (First)
return "any_respect_nulls";
else
return "anyLast_respect_nulls";
}
String getName() const override { return "any"; }
bool allocatesMemoryInArena() const override { return false; }
void addNull(AggregateDataPtr __restrict place) const
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
chassert(returns_nullable_type);
auto & d = this->data(place);
if (First && d.isSet())
return;
d.setNull();
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{
if (columns[0]->isNullable())
{
if (columns[0]->isNullAt(row_num))
return addNull(place);
}
auto & d = this->data(place);
if (First && d.isSet())
return;
d.setOther();
columns[0]->get(row_num, d.value);
}
void addManyDefaults(AggregateDataPtr __restrict place, const IColumn ** columns, size_t, Arena * arena) const override
{
if (columns[0]->isNullable())
addNull(place);
else
add(place, columns, 0, arena);
if (!this->data(place).has())
this->data(place).set(*columns[0], row_num, arena);
}
void addBatchSinglePlace(
size_t row_begin, size_t row_end, AggregateDataPtr place, const IColumn ** columns, Arena * arena, ssize_t if_argument_pos)
const override
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
Arena * arena,
ssize_t if_argument_pos) const override
{
if (this->data(place).has() || row_begin >= row_end)
return;
if (if_argument_pos >= 0)
{
const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
size_t size = row_end - row_begin;
for (size_t i = 0; i < size; ++i)
const auto & if_map = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
for (size_t i = row_begin; i < row_end; i++)
{
size_t pos = First ? row_begin + i : row_end - 1 - i;
if (flags[pos])
if (if_map.data()[i] != 0)
{
add(place, columns, pos, arena);
break;
this->data(place).set(*columns[0], i, arena);
return;
}
}
}
else if (row_begin < row_end)
else
{
size_t pos = First ? row_begin : row_end - 1;
add(place, columns, pos, arena);
this->data(place).set(*columns[0], row_begin, arena);
}
}
void addBatchSinglePlaceNotNull(
size_t, size_t, AggregateDataPtr __restrict, const IColumn **, const UInt8 *, Arena *, ssize_t) const override
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
const UInt8 * __restrict null_map,
Arena * arena,
ssize_t if_argument_pos) const override
{
/// This should not happen since it means somebody else has preprocessed the data (NULLs or IFs) and might
/// have discarded values that we need (NULLs)
throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "AggregateFunctionAnyRespectNulls::addBatchSinglePlaceNotNull called");
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
{
auto & d = this->data(place);
if (First && d.isSet())
if (this->data(place).has() || row_begin >= row_end)
return;
auto & other = this->data(rhs);
if (other.isSet())
if (if_argument_pos >= 0)
{
d.status = other.status;
d.value = other.value;
const auto & if_map = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
for (size_t i = row_begin; i < row_end; i++)
{
if (if_map.data()[i] != 0 && null_map[i] == 0)
{
this->data(place).set(*columns[0], i, arena);
return;
}
}
}
else
{
for (size_t i = row_begin; i < row_end; i++)
{
if (null_map[i] == 0)
{
this->data(place).set(*columns[0], i, arena);
return;
}
}
}
}
void addManyDefaults(AggregateDataPtr __restrict place, const IColumn ** columns, size_t, Arena * arena) const override
{
if (!this->data(place).has())
this->data(place).set(*columns[0], 0, arena);
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
if (!this->data(place).has())
this->data(place).set(this->data(rhs), arena);
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
auto & d = this->data(place);
UInt8 k = d.status;
writeBinaryLittleEndian<UInt8>(k, buf);
if (k == Data::Status::SetOther)
serialization->serializeBinary(d.value, buf, {});
this->data(place).write(buf, *serialization);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
auto & d = this->data(place);
UInt8 k = Data::Status::NotSet;
readBinaryLittleEndian<UInt8>(k, buf);
d.status = static_cast<Data::Status>(k);
if (d.status == Data::Status::NotSet)
return;
else if (d.status == Data::Status::SetNull)
{
if (!returns_nullable_type)
throw Exception(ErrorCodes::INCORRECT_DATA, "Incorrect type (NULL) in non-nullable {}State", getName());
return;
}
else if (d.status == Data::Status::SetOther)
serialization->deserializeBinary(d.value, buf, {});
else
throw Exception(ErrorCodes::INCORRECT_DATA, "Incorrect type ({}) in {}State", static_cast<Int8>(k), getName());
this->data(place).read(buf, *serialization, arena);
}
bool allocatesMemoryInArena() const override { return Data::allocatesMemoryInArena(); }
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
auto & d = this->data(place);
if (d.status == Data::Status::SetOther)
to.insert(d.value);
else
to.insertDefault();
this->data(place).insertResultInto(to);
}
AggregateFunctionPtr getOwnNullAdapter(
const AggregateFunctionPtr & original_function,
const DataTypes & /*arguments*/,
const Array & /*params*/,
const AggregateFunctionProperties & /*properties*/) const override
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
return original_function;
if constexpr (!Data::is_compilable)
return false;
else
return Data::isCompilable(*this->argument_types[0]);
}
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
if constexpr (Data::is_compilable)
Data::compileCreate(builder, aggregate_data_ptr);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override
{
if constexpr (Data::is_compilable)
Data::compileAny(builder, aggregate_data_ptr, arguments[0].value);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
void
compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
{
if constexpr (Data::is_compilable)
Data::compileAnyMerge(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
if constexpr (Data::is_compilable)
return Data::compileGetResult(builder, aggregate_data_ptr);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
#endif
};
template <bool First>
IAggregateFunction * createAggregateFunctionSingleValueRespectNulls(
const String & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
AggregateFunctionPtr
createAggregateFunctionAny(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
return new AggregateFunctionAnyRespectNulls<First>(argument_types[0]);
return AggregateFunctionPtr(
createAggregateFunctionSingleValue<AggregateFunctionAny, /* unary */ true>(name, argument_types, parameters, settings));
}
AggregateFunctionPtr createAggregateFunctionAny(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionSingleValue<AggregateFunctionsSingleValue, AggregateFunctionAnyData>(name, argument_types, parameters, settings));
}
AggregateFunctionPtr createAggregateFunctionAnyRespectNulls(
template <typename Data>
class AggregateFunctionAnyLast final : public IAggregateFunctionDataHelper<Data, AggregateFunctionAnyLast<Data>>
{
private:
SerializationPtr serialization;
public:
explicit AggregateFunctionAnyLast(const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionAnyLast<Data>>(argument_types_, {}, argument_types_[0])
, serialization(this->result_type->getDefaultSerialization())
{
}
String getName() const override { return "anyLast"; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
this->data(place).set(*columns[0], row_num, arena);
}
void addBatchSinglePlace(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
Arena * arena,
ssize_t if_argument_pos) const override
{
if (row_begin >= row_end)
return;
size_t batch_size = row_end - row_begin;
if (if_argument_pos >= 0)
{
const auto & if_map = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
for (size_t i = 0; i < batch_size; i++)
{
size_t pos = (row_end - 1) - i;
if (if_map.data()[pos] != 0)
{
this->data(place).set(*columns[0], pos, arena);
return;
}
}
}
else
{
this->data(place).set(*columns[0], row_end - 1, arena);
}
}
void addBatchSinglePlaceNotNull(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
const UInt8 * __restrict null_map,
Arena * arena,
ssize_t if_argument_pos) const override
{
if (row_begin >= row_end)
return;
size_t batch_size = row_end - row_begin;
if (if_argument_pos >= 0)
{
const auto & if_map = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
for (size_t i = 0; i < batch_size; i++)
{
size_t pos = (row_end - 1) - i;
if (if_map.data()[pos] != 0 && null_map[pos] == 0)
{
this->data(place).set(*columns[0], pos, arena);
return;
}
}
}
else
{
for (size_t i = 0; i < batch_size; i++)
{
size_t pos = (row_end - 1) - i;
if (null_map[pos] == 0)
{
this->data(place).set(*columns[0], pos, arena);
return;
}
}
}
}
void addManyDefaults(AggregateDataPtr __restrict place, const IColumn ** columns, size_t, Arena * arena) const override
{
this->data(place).set(*columns[0], 0, arena);
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
this->data(place).set(this->data(rhs), arena);
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
this->data(place).write(buf, *serialization);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
this->data(place).read(buf, *serialization, arena);
}
bool allocatesMemoryInArena() const override { return Data::allocatesMemoryInArena(); }
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
this->data(place).insertResultInto(to);
}
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
if constexpr (!Data::is_compilable)
return false;
else
return Data::isCompilable(*this->argument_types[0]);
}
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
if constexpr (Data::is_compilable)
Data::compileCreate(builder, aggregate_data_ptr);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override
{
if constexpr (Data::is_compilable)
Data::compileAnyLast(builder, aggregate_data_ptr, arguments[0].value);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
void
compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
{
if constexpr (Data::is_compilable)
Data::compileAnyLastMerge(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
if constexpr (Data::is_compilable)
return Data::compileGetResult(builder, aggregate_data_ptr);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
#endif
};
AggregateFunctionPtr createAggregateFunctionAnyLast(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionSingleValueRespectNulls<true>(name, argument_types, parameters, settings));
}
AggregateFunctionPtr createAggregateFunctionAnyLast(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionSingleValue<AggregateFunctionsSingleValue, AggregateFunctionAnyLastData>(name, argument_types, parameters, settings));
}
AggregateFunctionPtr createAggregateFunctionAnyLastRespectNulls(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionSingleValueRespectNulls<false>(name, argument_types, parameters, settings));
}
AggregateFunctionPtr createAggregateFunctionAnyHeavy(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionSingleValue<AggregateFunctionsSingleValue, AggregateFunctionAnyHeavyData>(name, argument_types, parameters, settings));
return AggregateFunctionPtr(
createAggregateFunctionSingleValue<AggregateFunctionAnyLast, /* unary */ true>(name, argument_types, parameters, settings));
}
}
@ -231,27 +359,11 @@ AggregateFunctionPtr createAggregateFunctionAnyHeavy(const std::string & name, c
void registerAggregateFunctionsAny(AggregateFunctionFactory & factory)
{
AggregateFunctionProperties default_properties = {.returns_default_when_only_null = false, .is_order_dependent = true};
AggregateFunctionProperties default_properties_for_respect_nulls
= {.returns_default_when_only_null = false, .is_order_dependent = true, .is_window_function = true};
factory.registerFunction("any", {createAggregateFunctionAny, default_properties});
factory.registerAlias("any_value", "any", AggregateFunctionFactory::CaseInsensitive);
factory.registerAlias("first_value", "any", AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("any_respect_nulls", {createAggregateFunctionAnyRespectNulls, default_properties_for_respect_nulls});
factory.registerAlias("any_value_respect_nulls", "any_respect_nulls", AggregateFunctionFactory::CaseInsensitive);
factory.registerAlias("first_value_respect_nulls", "any_respect_nulls", AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("anyLast", {createAggregateFunctionAnyLast, default_properties});
factory.registerAlias("last_value", "anyLast", AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("anyLast_respect_nulls", {createAggregateFunctionAnyLastRespectNulls, default_properties_for_respect_nulls});
factory.registerAlias("last_value_respect_nulls", "anyLast_respect_nulls", AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("anyHeavy", {createAggregateFunctionAnyHeavy, default_properties});
factory.registerNullsActionTransformation("any", "any_respect_nulls");
factory.registerNullsActionTransformation("anyLast", "anyLast_respect_nulls");
}
}

View File

@ -0,0 +1,168 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/SingleValueData.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <base/defines.h>
namespace DB
{
struct Settings;
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
namespace
{
/** Implement 'heavy hitters' algorithm.
* Selects most frequent value if its frequency is more than 50% in each thread of execution.
* Otherwise, selects some arbitrary value.
* http://www.cs.umd.edu/~samir/498/karp.pdf
*/
struct AggregateFunctionAnyHeavyData
{
using Self = AggregateFunctionAnyHeavyData;
private:
SingleValueDataBaseMemoryBlock v_data;
UInt64 counter = 0;
public:
[[noreturn]] explicit AggregateFunctionAnyHeavyData()
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "AggregateFunctionAnyHeavyData initialized empty");
}
explicit AggregateFunctionAnyHeavyData(TypeIndex value_type) { generateSingleValueFromTypeIndex(value_type, v_data); }
~AggregateFunctionAnyHeavyData() { data().~SingleValueDataBase(); }
SingleValueDataBase & data() { return v_data.get(); }
const SingleValueDataBase & data() const { return v_data.get(); }
void add(const IColumn & column, size_t row_num, Arena * arena)
{
if (data().isEqualTo(column, row_num))
{
++counter;
}
else if (counter == 0)
{
data().set(column, row_num, arena);
++counter;
}
else
{
--counter;
}
}
void add(const Self & to, Arena * arena)
{
if (!to.data().has())
return;
if (data().isEqualTo(to.data()))
counter += to.counter;
else if (!data().has() || counter < to.counter)
data().set(to.data(), arena);
else
counter -= to.counter;
}
void addManyDefaults(const IColumn & column, size_t length, Arena * arena)
{
for (size_t i = 0; i < length; ++i)
add(column, 0, arena);
}
void write(WriteBuffer & buf, const ISerialization & serialization) const
{
data().write(buf, serialization);
writeBinaryLittleEndian(counter, buf);
}
void read(ReadBuffer & buf, const ISerialization & serialization, Arena * arena)
{
data().read(buf, serialization, arena);
readBinaryLittleEndian(counter, buf);
}
void insertResultInto(IColumn & to) const { data().insertResultInto(to); }
};
class AggregateFunctionAnyHeavy final : public IAggregateFunctionDataHelper<AggregateFunctionAnyHeavyData, AggregateFunctionAnyHeavy>
{
private:
SerializationPtr serialization;
const TypeIndex value_type_index;
public:
explicit AggregateFunctionAnyHeavy(const DataTypePtr & type)
: IAggregateFunctionDataHelper<AggregateFunctionAnyHeavyData, AggregateFunctionAnyHeavy>({type}, {}, type)
, serialization(type->getDefaultSerialization())
, value_type_index(WhichDataType(type).idx)
{
}
void create(AggregateDataPtr __restrict place) const override { new (place) AggregateFunctionAnyHeavyData(value_type_index); }
String getName() const override { return "anyHeavy"; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
this->data(place).add(*columns[0], row_num, arena);
}
void addManyDefaults(AggregateDataPtr __restrict place, const IColumn ** columns, size_t, Arena * arena) const override
{
this->data(place).addManyDefaults(*columns[0], 0, arena);
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
this->data(place).add(this->data(rhs), arena);
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
this->data(place).write(buf, *serialization);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
this->data(place).read(buf, *serialization, arena);
}
bool allocatesMemoryInArena() const override { return singleValueTypeAllocatesMemoryInArena(value_type_index); }
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
this->data(place).insertResultInto(to);
}
};
AggregateFunctionPtr
createAggregateFunctionAnyHeavy(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
const DataTypePtr & res_type = argument_types[0];
return AggregateFunctionPtr(new AggregateFunctionAnyHeavy(res_type));
}
}
void registerAggregateFunctionAnyHeavy(AggregateFunctionFactory & factory)
{
AggregateFunctionProperties default_properties = {.returns_default_when_only_null = false, .is_order_dependent = true};
factory.registerFunction("anyHeavy", {createAggregateFunctionAnyHeavy, default_properties});
}
}

View File

@ -0,0 +1,235 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/SingleValueData.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <base/defines.h>
namespace DB
{
struct Settings;
namespace ErrorCodes
{
extern const int INCORRECT_DATA;
extern const int LOGICAL_ERROR;
}
namespace
{
struct AggregateFunctionAnyRespectNullsData
{
enum class Status : UInt8
{
NotSet = 1,
SetNull = 2,
SetOther = 3
};
Status status = Status::NotSet;
Field value;
bool isSet() const { return status != Status::NotSet; }
void setNull() { status = Status::SetNull; }
void setOther() { status = Status::SetOther; }
};
template <bool First>
class AggregateFunctionAnyRespectNulls final
: public IAggregateFunctionDataHelper<AggregateFunctionAnyRespectNullsData, AggregateFunctionAnyRespectNulls<First>>
{
public:
using Data = AggregateFunctionAnyRespectNullsData;
SerializationPtr serialization;
const bool returns_nullable_type = false;
explicit AggregateFunctionAnyRespectNulls(const DataTypePtr & type)
: IAggregateFunctionDataHelper<Data, AggregateFunctionAnyRespectNulls<First>>({type}, {}, type)
, serialization(type->getDefaultSerialization())
, returns_nullable_type(type->isNullable())
{
}
String getName() const override
{
if constexpr (First)
return "any_respect_nulls";
else
return "anyLast_respect_nulls";
}
bool allocatesMemoryInArena() const override { return false; }
void addNull(AggregateDataPtr __restrict place) const
{
chassert(returns_nullable_type);
auto & d = this->data(place);
if (First && d.isSet())
return;
d.setNull();
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{
if (columns[0]->isNullable())
{
if (columns[0]->isNullAt(row_num))
return addNull(place);
}
auto & d = this->data(place);
if (First && d.isSet())
return;
d.setOther();
columns[0]->get(row_num, d.value);
}
void addManyDefaults(AggregateDataPtr __restrict place, const IColumn ** columns, size_t, Arena * arena) const override
{
if (columns[0]->isNullable())
addNull(place);
else
add(place, columns, 0, arena);
}
void addBatchSinglePlace(
size_t row_begin, size_t row_end, AggregateDataPtr place, const IColumn ** columns, Arena * arena, ssize_t if_argument_pos)
const override
{
if (if_argument_pos >= 0)
{
const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
size_t size = row_end - row_begin;
for (size_t i = 0; i < size; ++i)
{
size_t pos = First ? row_begin + i : row_end - 1 - i;
if (flags[pos])
{
add(place, columns, pos, arena);
break;
}
}
}
else if (row_begin < row_end)
{
size_t pos = First ? row_begin : row_end - 1;
add(place, columns, pos, arena);
}
}
void addBatchSinglePlaceNotNull(
size_t, size_t, AggregateDataPtr __restrict, const IColumn **, const UInt8 *, Arena *, ssize_t) const override
{
/// This should not happen since it means somebody else has preprocessed the data (NULLs or IFs) and might
/// have discarded values that we need (NULLs)
throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "AggregateFunctionAnyRespectNulls::addBatchSinglePlaceNotNull called");
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
{
auto & d = this->data(place);
if (First && d.isSet())
return;
auto & other = this->data(rhs);
if (other.isSet())
{
d.status = other.status;
d.value = other.value;
}
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
auto & d = this->data(place);
UInt8 k = static_cast<UInt8>(d.status);
writeBinaryLittleEndian<UInt8>(k, buf);
if (d.status == Data::Status::SetOther)
serialization->serializeBinary(d.value, buf, {});
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
{
auto & d = this->data(place);
UInt8 k = 0;
readBinaryLittleEndian<UInt8>(k, buf);
d.status = static_cast<Data::Status>(k);
if (d.status == Data::Status::NotSet)
return;
else if (d.status == Data::Status::SetNull)
{
if (!returns_nullable_type)
throw Exception(ErrorCodes::INCORRECT_DATA, "Incorrect type (NULL) in non-nullable {}State", getName());
return;
}
else if (d.status == Data::Status::SetOther)
{
serialization->deserializeBinary(d.value, buf, {});
return;
}
throw Exception(ErrorCodes::INCORRECT_DATA, "Incorrect type ({}) in {}State", static_cast<Int8>(k), getName());
}
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
auto & d = this->data(place);
if (d.status == Data::Status::SetOther)
to.insert(d.value);
else
to.insertDefault();
}
AggregateFunctionPtr getOwnNullAdapter(
const AggregateFunctionPtr & original_function,
const DataTypes & /*arguments*/,
const Array & /*params*/,
const AggregateFunctionProperties & /*properties*/) const override
{
return original_function;
}
};
template <bool First>
IAggregateFunction * createAggregateFunctionSingleValueRespectNulls(
const String & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
return new AggregateFunctionAnyRespectNulls<First>(argument_types[0]);
}
AggregateFunctionPtr createAggregateFunctionAnyRespectNulls(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionSingleValueRespectNulls<true>(name, argument_types, parameters, settings));
}
AggregateFunctionPtr createAggregateFunctionAnyLastRespectNulls(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionSingleValueRespectNulls<false>(name, argument_types, parameters, settings));
}
}
void registerAggregateFunctionsAnyRespectNulls(AggregateFunctionFactory & factory)
{
AggregateFunctionProperties default_properties_for_respect_nulls
= {.returns_default_when_only_null = false, .is_order_dependent = true, .is_window_function = true};
factory.registerFunction("any_respect_nulls", {createAggregateFunctionAnyRespectNulls, default_properties_for_respect_nulls});
factory.registerAlias("any_value_respect_nulls", "any_respect_nulls", AggregateFunctionFactory::CaseInsensitive);
factory.registerAlias("first_value_respect_nulls", "any_respect_nulls", AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("anyLast_respect_nulls", {createAggregateFunctionAnyLastRespectNulls, default_properties_for_respect_nulls});
factory.registerAlias("last_value_respect_nulls", "anyLast_respect_nulls", AggregateFunctionFactory::CaseInsensitive);
/// Must happen after registering any and anyLast
factory.registerNullsActionTransformation("any", "any_respect_nulls");
factory.registerNullsActionTransformation("anyLast", "anyLast_respect_nulls");
}
}

View File

@ -1,107 +0,0 @@
#pragma once
#include <base/StringRef.h>
#include <DataTypes/IDataType.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/AggregateFunctionMinMaxAny.h> // SingleValueDataString used in embedded compiler
namespace DB
{
struct Settings;
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int CORRUPTED_DATA;
}
/// For possible values for template parameters, see 'AggregateFunctionMinMaxAny.h'.
template <typename ResultData, typename ValueData>
struct AggregateFunctionArgMinMaxData
{
using ResultData_t = ResultData;
using ValueData_t = ValueData;
ResultData result; // the argument at which the minimum/maximum value is reached.
ValueData value; // value for which the minimum/maximum is calculated.
static bool allocatesMemoryInArena()
{
return ResultData::allocatesMemoryInArena() || ValueData::allocatesMemoryInArena();
}
};
/// Returns the first arg value found for the minimum/maximum value. Example: argMax(arg, value).
template <typename Data>
class AggregateFunctionArgMinMax final : public IAggregateFunctionDataHelper<Data, AggregateFunctionArgMinMax<Data>>
{
private:
const DataTypePtr & type_val;
const SerializationPtr serialization_res;
const SerializationPtr serialization_val;
using Base = IAggregateFunctionDataHelper<Data, AggregateFunctionArgMinMax<Data>>;
public:
AggregateFunctionArgMinMax(const DataTypePtr & type_res_, const DataTypePtr & type_val_)
: Base({type_res_, type_val_}, {}, type_res_)
, type_val(this->argument_types[1])
, serialization_res(type_res_->getDefaultSerialization())
, serialization_val(type_val->getDefaultSerialization())
{
if (!type_val->isComparable())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of second argument of "
"aggregate function {} because the values of that data type are not comparable",
type_val->getName(), getName());
}
String getName() const override
{
return StringRef(Data::ValueData_t::name()) == StringRef("min") ? "argMin" : "argMax";
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
if (this->data(place).value.changeIfBetter(*columns[1], row_num, arena))
this->data(place).result.change(*columns[0], row_num, arena);
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
if (this->data(place).value.changeIfBetter(this->data(rhs).value, arena))
this->data(place).result.change(this->data(rhs).result, arena);
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
this->data(place).result.write(buf, *serialization_res);
this->data(place).value.write(buf, *serialization_val);
}
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
this->data(place).result.read(buf, *serialization_res, arena);
this->data(place).value.read(buf, *serialization_val, arena);
if (unlikely(this->data(place).value.has() != this->data(place).result.has()))
throw Exception(
ErrorCodes::CORRUPTED_DATA,
"Invalid state of the aggregate function {}: has_value ({}) != has_result ({})",
getName(),
this->data(place).value.has(),
this->data(place).result.has());
}
bool allocatesMemoryInArena() const override
{
return Data::allocatesMemoryInArena();
}
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
this->data(place).result.insertResultInto(to);
}
};
}

View File

@ -1,238 +0,0 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/HelpersMinMaxAny.h>
#include <Common/Concepts.h>
#include <Common/findExtreme.h>
namespace DB
{
struct Settings;
namespace
{
template <typename Data>
class AggregateFunctionsSingleValueMax final : public AggregateFunctionsSingleValue<Data>
{
using Parent = AggregateFunctionsSingleValue<Data>;
public:
explicit AggregateFunctionsSingleValueMax(const DataTypePtr & type) : Parent(type) { }
/// Specializations for native numeric types
void addBatchSinglePlace(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
Arena * arena,
ssize_t if_argument_pos) const override;
void addBatchSinglePlaceNotNull(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
const UInt8 * __restrict null_map,
Arena * arena,
ssize_t if_argument_pos) const override;
};
// NOLINTBEGIN(bugprone-macro-parentheses)
#define SPECIALIZE(TYPE) \
template <> \
void AggregateFunctionsSingleValueMax<typename DB::AggregateFunctionMaxData<SingleValueDataFixed<TYPE>>>::addBatchSinglePlace( \
size_t row_begin, \
size_t row_end, \
AggregateDataPtr __restrict place, \
const IColumn ** __restrict columns, \
Arena *, \
ssize_t if_argument_pos) const \
{ \
const auto & column = assert_cast<const DB::AggregateFunctionMaxData<SingleValueDataFixed<TYPE>>::ColVecType &>(*columns[0]); \
std::optional<TYPE> opt; \
if (if_argument_pos >= 0) \
{ \
const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData(); \
opt = findExtremeMaxIf(column.getData().data(), flags.data(), row_begin, row_end); \
} \
else \
opt = findExtremeMax(column.getData().data(), row_begin, row_end); \
if (opt.has_value()) \
this->data(place).changeIfGreater(opt.value()); \
}
// NOLINTEND(bugprone-macro-parentheses)
FOR_BASIC_NUMERIC_TYPES(SPECIALIZE)
#undef SPECIALIZE
template <typename Data>
void AggregateFunctionsSingleValueMax<Data>::addBatchSinglePlace(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
Arena * arena,
ssize_t if_argument_pos) const
{
if constexpr (!is_any_of<typename Data::Impl, SingleValueDataString, SingleValueDataGeneric>)
{
/// Leave other numeric types (large integers, decimals, etc) to keep doing the comparison as it's
/// faster than doing a permutation
return Parent::addBatchSinglePlace(row_begin, row_end, place, columns, arena, if_argument_pos);
}
constexpr int nan_null_direction_hint = -1;
auto const & column = *columns[0];
if (if_argument_pos >= 0)
{
size_t index = row_begin;
const auto & if_flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
while (if_flags[index] == 0 && index < row_end)
index++;
if (index >= row_end)
return;
for (size_t i = index + 1; i < row_end; i++)
{
if ((if_flags[i] != 0) && (column.compareAt(i, index, column, nan_null_direction_hint) > 0))
index = i;
}
this->data(place).changeIfGreater(column, index, arena);
}
else
{
if (row_begin >= row_end)
return;
/// TODO: Introduce row_begin and row_end to getPermutation
if (row_begin != 0 || row_end != column.size())
{
size_t index = row_begin;
for (size_t i = index + 1; i < row_end; i++)
{
if (column.compareAt(i, index, column, nan_null_direction_hint) > 0)
index = i;
}
this->data(place).changeIfGreater(column, index, arena);
}
else
{
constexpr IColumn::PermutationSortDirection direction = IColumn::PermutationSortDirection::Descending;
constexpr IColumn::PermutationSortStability stability = IColumn::PermutationSortStability::Unstable;
IColumn::Permutation permutation;
constexpr UInt64 limit = 1;
column.getPermutation(direction, stability, limit, nan_null_direction_hint, permutation);
this->data(place).changeIfGreater(column, permutation[0], arena);
}
}
}
// NOLINTBEGIN(bugprone-macro-parentheses)
#define SPECIALIZE(TYPE) \
template <> \
void AggregateFunctionsSingleValueMax<typename DB::AggregateFunctionMaxData<SingleValueDataFixed<TYPE>>>::addBatchSinglePlaceNotNull( \
size_t row_begin, \
size_t row_end, \
AggregateDataPtr __restrict place, \
const IColumn ** __restrict columns, \
const UInt8 * __restrict null_map, \
Arena *, \
ssize_t if_argument_pos) const \
{ \
const auto & column = assert_cast<const DB::AggregateFunctionMaxData<SingleValueDataFixed<TYPE>>::ColVecType &>(*columns[0]); \
std::optional<TYPE> opt; \
if (if_argument_pos >= 0) \
{ \
const auto * if_flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData().data(); \
auto final_flags = std::make_unique<UInt8[]>(row_end); \
for (size_t i = row_begin; i < row_end; ++i) \
final_flags[i] = (!null_map[i]) & !!if_flags[i]; \
opt = findExtremeMaxIf(column.getData().data(), final_flags.get(), row_begin, row_end); \
} \
else \
opt = findExtremeMaxNotNull(column.getData().data(), null_map, row_begin, row_end); \
if (opt.has_value()) \
this->data(place).changeIfGreater(opt.value()); \
}
// NOLINTEND(bugprone-macro-parentheses)
FOR_BASIC_NUMERIC_TYPES(SPECIALIZE)
#undef SPECIALIZE
template <typename Data>
void AggregateFunctionsSingleValueMax<Data>::addBatchSinglePlaceNotNull(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
const UInt8 * __restrict null_map,
Arena * arena,
ssize_t if_argument_pos) const
{
if constexpr (!is_any_of<typename Data::Impl, SingleValueDataString, SingleValueDataGeneric>)
{
/// Leave other numeric types (large integers, decimals, etc) to keep doing the comparison as it's
/// faster than doing a permutation
return Parent::addBatchSinglePlaceNotNull(row_begin, row_end, place, columns, null_map, arena, if_argument_pos);
}
constexpr int nan_null_direction_hint = -1;
auto const & column = *columns[0];
if (if_argument_pos >= 0)
{
size_t index = row_begin;
const auto & if_flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
while ((if_flags[index] == 0 || null_map[index] != 0) && (index < row_end))
index++;
if (index >= row_end)
return;
for (size_t i = index + 1; i < row_end; i++)
{
if ((if_flags[i] != 0) && (null_map[i] == 0) && (column.compareAt(i, index, column, nan_null_direction_hint) > 0))
index = i;
}
this->data(place).changeIfGreater(column, index, arena);
}
else
{
size_t index = row_begin;
while ((null_map[index] != 0) && (index < row_end))
index++;
if (index >= row_end)
return;
for (size_t i = index + 1; i < row_end; i++)
{
if ((null_map[i] == 0) && (column.compareAt(i, index, column, nan_null_direction_hint) > 0))
index = i;
}
this->data(place).changeIfGreater(column, index, arena);
}
}
AggregateFunctionPtr createAggregateFunctionMax(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionSingleValue<AggregateFunctionsSingleValueMax, AggregateFunctionMaxData>(name, argument_types, parameters, settings));
}
AggregateFunctionPtr createAggregateFunctionArgMax(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionArgMinMax<AggregateFunctionMaxData>(name, argument_types, parameters, settings));
}
}
void registerAggregateFunctionsMax(AggregateFunctionFactory & factory)
{
factory.registerFunction("max", createAggregateFunctionMax, AggregateFunctionFactory::CaseInsensitive);
/// The functions below depend on the order of data.
AggregateFunctionProperties properties = { .returns_default_when_only_null = false, .is_order_dependent = true };
factory.registerFunction("argMax", { createAggregateFunctionArgMax, properties });
}
}

View File

@ -1,240 +0,0 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/HelpersMinMaxAny.h>
#include <Common/Concepts.h>
#include <Common/findExtreme.h>
namespace DB
{
struct Settings;
namespace
{
template <typename Data>
class AggregateFunctionsSingleValueMin final : public AggregateFunctionsSingleValue<Data>
{
using Parent = AggregateFunctionsSingleValue<Data>;
public:
explicit AggregateFunctionsSingleValueMin(const DataTypePtr & type) : Parent(type) { }
/// Specializations for native numeric types
void addBatchSinglePlace(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
Arena * arena,
ssize_t if_argument_pos) const override;
void addBatchSinglePlaceNotNull(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
const UInt8 * __restrict null_map,
Arena * arena,
ssize_t if_argument_pos) const override;
};
// NOLINTBEGIN(bugprone-macro-parentheses)
#define SPECIALIZE(TYPE) \
template <> \
void AggregateFunctionsSingleValueMin<typename DB::AggregateFunctionMinData<SingleValueDataFixed<TYPE>>>::addBatchSinglePlace( \
size_t row_begin, \
size_t row_end, \
AggregateDataPtr __restrict place, \
const IColumn ** __restrict columns, \
Arena *, \
ssize_t if_argument_pos) const \
{ \
const auto & column = assert_cast<const DB::AggregateFunctionMinData<SingleValueDataFixed<TYPE>>::ColVecType &>(*columns[0]); \
std::optional<TYPE> opt; \
if (if_argument_pos >= 0) \
{ \
const auto & flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData(); \
opt = findExtremeMinIf(column.getData().data(), flags.data(), row_begin, row_end); \
} \
else \
opt = findExtremeMin(column.getData().data(), row_begin, row_end); \
if (opt.has_value()) \
this->data(place).changeIfLess(opt.value()); \
}
// NOLINTEND(bugprone-macro-parentheses)
FOR_BASIC_NUMERIC_TYPES(SPECIALIZE)
#undef SPECIALIZE
template <typename Data>
void AggregateFunctionsSingleValueMin<Data>::addBatchSinglePlace(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
Arena * arena,
ssize_t if_argument_pos) const
{
if constexpr (!is_any_of<typename Data::Impl, SingleValueDataString, SingleValueDataGeneric>)
{
/// Leave other numeric types (large integers, decimals, etc) to keep doing the comparison as it's
/// faster than doing a permutation
return Parent::addBatchSinglePlace(row_begin, row_end, place, columns, arena, if_argument_pos);
}
constexpr int nan_null_direction_hint = 1;
auto const & column = *columns[0];
if (if_argument_pos >= 0)
{
size_t index = row_begin;
const auto & if_flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
while (if_flags[index] == 0 && index < row_end)
index++;
if (index >= row_end)
return;
for (size_t i = index + 1; i < row_end; i++)
{
if ((if_flags[i] != 0) && (column.compareAt(i, index, column, nan_null_direction_hint) < 0))
index = i;
}
this->data(place).changeIfLess(column, index, arena);
}
else
{
if (row_begin >= row_end)
return;
/// TODO: Introduce row_begin and row_end to getPermutation
if (row_begin != 0 || row_end != column.size())
{
size_t index = row_begin;
for (size_t i = index + 1; i < row_end; i++)
{
if (column.compareAt(i, index, column, nan_null_direction_hint) < 0)
index = i;
}
this->data(place).changeIfLess(column, index, arena);
}
else
{
constexpr IColumn::PermutationSortDirection direction = IColumn::PermutationSortDirection::Ascending;
constexpr IColumn::PermutationSortStability stability = IColumn::PermutationSortStability::Unstable;
IColumn::Permutation permutation;
constexpr UInt64 limit = 1;
column.getPermutation(direction, stability, limit, nan_null_direction_hint, permutation);
this->data(place).changeIfLess(column, permutation[0], arena);
}
}
}
// NOLINTBEGIN(bugprone-macro-parentheses)
#define SPECIALIZE(TYPE) \
template <> \
void AggregateFunctionsSingleValueMin<typename DB::AggregateFunctionMinData<SingleValueDataFixed<TYPE>>>::addBatchSinglePlaceNotNull( \
size_t row_begin, \
size_t row_end, \
AggregateDataPtr __restrict place, \
const IColumn ** __restrict columns, \
const UInt8 * __restrict null_map, \
Arena *, \
ssize_t if_argument_pos) const \
{ \
const auto & column = assert_cast<const DB::AggregateFunctionMinData<SingleValueDataFixed<TYPE>>::ColVecType &>(*columns[0]); \
std::optional<TYPE> opt; \
if (if_argument_pos >= 0) \
{ \
const auto * if_flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData().data(); \
auto final_flags = std::make_unique<UInt8[]>(row_end); \
for (size_t i = row_begin; i < row_end; ++i) \
final_flags[i] = (!null_map[i]) & !!if_flags[i]; \
opt = findExtremeMinIf(column.getData().data(), final_flags.get(), row_begin, row_end); \
} \
else \
opt = findExtremeMinNotNull(column.getData().data(), null_map, row_begin, row_end); \
if (opt.has_value()) \
this->data(place).changeIfLess(opt.value()); \
}
// NOLINTEND(bugprone-macro-parentheses)
FOR_BASIC_NUMERIC_TYPES(SPECIALIZE)
#undef SPECIALIZE
template <typename Data>
void AggregateFunctionsSingleValueMin<Data>::addBatchSinglePlaceNotNull(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
const UInt8 * __restrict null_map,
Arena * arena,
ssize_t if_argument_pos) const
{
if constexpr (!is_any_of<typename Data::Impl, SingleValueDataString, SingleValueDataGeneric>)
{
/// Leave other numeric types (large integers, decimals, etc) to keep doing the comparison as it's
/// faster than doing a permutation
return Parent::addBatchSinglePlaceNotNull(row_begin, row_end, place, columns, null_map, arena, if_argument_pos);
}
constexpr int nan_null_direction_hint = 1;
auto const & column = *columns[0];
if (if_argument_pos >= 0)
{
size_t index = row_begin;
const auto & if_flags = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
while ((if_flags[index] == 0 || null_map[index] != 0) && (index < row_end))
index++;
if (index >= row_end)
return;
for (size_t i = index + 1; i < row_end; i++)
{
if ((if_flags[i] != 0) && (null_map[index] == 0) && (column.compareAt(i, index, column, nan_null_direction_hint) < 0))
index = i;
}
this->data(place).changeIfLess(column, index, arena);
}
else
{
size_t index = row_begin;
while ((null_map[index] != 0) && (index < row_end))
index++;
if (index >= row_end)
return;
for (size_t i = index + 1; i < row_end; i++)
{
if ((null_map[i] == 0) && (column.compareAt(i, index, column, nan_null_direction_hint) < 0))
index = i;
}
this->data(place).changeIfLess(column, index, arena);
}
}
AggregateFunctionPtr createAggregateFunctionMin(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionSingleValue<AggregateFunctionsSingleValueMin, AggregateFunctionMinData>(
name, argument_types, parameters, settings));
}
AggregateFunctionPtr createAggregateFunctionArgMin(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionArgMinMax<AggregateFunctionMinData>(name, argument_types, parameters, settings));
}
}
void registerAggregateFunctionsMin(AggregateFunctionFactory & factory)
{
factory.registerFunction("min", createAggregateFunctionMin, AggregateFunctionFactory::CaseInsensitive);
/// The functions below depend on the order of data.
AggregateFunctionProperties properties = { .returns_default_when_only_null = false, .is_order_dependent = true };
factory.registerFunction("argMin", { createAggregateFunctionArgMin, properties });
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,19 +1,193 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/HelpersMinMaxAny.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include "registerAggregateFunctions.h"
#include <AggregateFunctions/SingleValueData.h>
#include <Columns/ColumnNullable.h>
#include <DataTypes/DataTypeNullable.h>
namespace DB
{
struct Settings;
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
namespace
{
/** The aggregate function 'singleValueOrNull' is used to implement subquery operators,
* such as x = ALL (SELECT ...)
* It checks if there is only one unique non-NULL value in the data.
* If there is only one unique value - returns it.
* If there are zero or at least two distinct values - returns NULL.
*/
AggregateFunctionPtr createAggregateFunctionSingleValueOrNull(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
struct AggregateFunctionSingleValueOrNullData
{
return AggregateFunctionPtr(createAggregateFunctionSingleValue<AggregateFunctionsSingleValue, AggregateFunctionSingleValueOrNullData>(name, argument_types, parameters, settings));
using Self = AggregateFunctionSingleValueOrNullData;
private:
SingleValueDataBaseMemoryBlock v_data;
bool first_value = true;
bool is_null = false;
public:
[[noreturn]] explicit AggregateFunctionSingleValueOrNullData()
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "AggregateFunctionSingleValueOrNullData initialized empty");
}
explicit AggregateFunctionSingleValueOrNullData(TypeIndex value_type) { generateSingleValueFromTypeIndex(value_type, v_data); }
~AggregateFunctionSingleValueOrNullData() { data().~SingleValueDataBase(); }
SingleValueDataBase & data() { return v_data.get(); }
const SingleValueDataBase & data() const { return v_data.get(); }
bool isNull() const { return is_null; }
void add(const IColumn & column, size_t row_num, Arena * arena)
{
if (first_value)
{
first_value = false;
data().set(column, row_num, arena);
}
else if (!data().isEqualTo(column, row_num))
{
is_null = true;
}
}
void add(const Self & to, Arena * arena)
{
if (!to.data().has())
return;
if (first_value && !to.first_value)
{
first_value = false;
data().set(to.data(), arena);
}
else if (!data().isEqualTo(to.data()))
{
is_null = true;
}
}
/// TODO: Methods write and read lose data (first_value and is_null)
/// Fixing it requires a breaking change (but it's probably necessary)
void write(WriteBuffer & buf, const ISerialization & serialization) const { data().write(buf, serialization); }
void read(ReadBuffer & buf, const ISerialization & serialization, Arena * arena) { data().read(buf, serialization, arena); }
void insertResultInto(IColumn & to) const
{
if (is_null || first_value)
{
to.insertDefault();
}
else
{
ColumnNullable & col = typeid_cast<ColumnNullable &>(to);
col.getNullMapColumn().insertDefault();
data().insertResultInto(col.getNestedColumn());
}
}
};
class AggregateFunctionSingleValueOrNull final
: public IAggregateFunctionDataHelper<AggregateFunctionSingleValueOrNullData, AggregateFunctionSingleValueOrNull>
{
private:
SerializationPtr serialization;
const TypeIndex value_type_index;
public:
explicit AggregateFunctionSingleValueOrNull(const DataTypePtr & type)
: IAggregateFunctionDataHelper<AggregateFunctionSingleValueOrNullData, AggregateFunctionSingleValueOrNull>(
{type}, {}, makeNullable(type))
, serialization(type->getDefaultSerialization())
, value_type_index(WhichDataType(type).idx)
{
}
void create(AggregateDataPtr __restrict place) const override { new (place) AggregateFunctionSingleValueOrNullData(value_type_index); }
String getName() const override { return "singleValueOrNull"; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
this->data(place).add(*columns[0], row_num, arena);
}
void addBatchSinglePlace(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
Arena * arena,
ssize_t if_argument_pos) const override
{
if (this->data(place).isNull())
return;
IAggregateFunctionDataHelper<Data, AggregateFunctionSingleValueOrNull>::addBatchSinglePlace(
row_begin, row_end, place, columns, arena, if_argument_pos);
}
void addBatchSinglePlaceNotNull(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
const UInt8 * __restrict null_map,
Arena * arena,
ssize_t if_argument_pos) const override
{
if (this->data(place).isNull())
return;
IAggregateFunctionDataHelper<Data, AggregateFunctionSingleValueOrNull>::addBatchSinglePlaceNotNull(
row_begin, row_end, place, columns, null_map, arena, if_argument_pos);
}
void addManyDefaults(AggregateDataPtr __restrict place, const IColumn ** columns, size_t, Arena * arena) const override
{
this->data(place).add(*columns[0], 0, arena);
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
this->data(place).add(this->data(rhs), arena);
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
this->data(place).write(buf, *serialization);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
this->data(place).read(buf, *serialization, arena);
}
bool allocatesMemoryInArena() const override { return singleValueTypeAllocatesMemoryInArena(value_type_index); }
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
this->data(place).insertResultInto(to);
}
};
AggregateFunctionPtr createAggregateFunctionSingleValueOrNull(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
const DataTypePtr & res_type = argument_types[0];
return AggregateFunctionPtr(new AggregateFunctionSingleValueOrNull(res_type));
}
}
@ -22,6 +196,4 @@ void registerAggregateFunctionSingleValueOrNull(AggregateFunctionFactory & facto
{
factory.registerFunction("singleValueOrNull", createAggregateFunctionSingleValueOrNull);
}
}

View File

@ -0,0 +1,236 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/SingleValueData.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/IDataType.h>
namespace DB
{
struct Settings;
namespace ErrorCodes
{
extern const int CORRUPTED_DATA;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int LOGICAL_ERROR;
}
namespace
{
template <class ValueType>
struct AggregateFunctionArgMinMaxData
{
private:
SingleValueDataBaseMemoryBlock result_data;
ValueType value_data;
public:
SingleValueDataBase & result() { return result_data.get(); }
const SingleValueDataBase & result() const { return result_data.get(); }
ValueType & value() { return value_data; }
const ValueType & value() const { return value_data; }
[[noreturn]] explicit AggregateFunctionArgMinMaxData()
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "AggregateFunctionArgMinMaxData initialized empty");
}
explicit AggregateFunctionArgMinMaxData(TypeIndex result_type) : value_data()
{
generateSingleValueFromTypeIndex(result_type, result_data);
}
~AggregateFunctionArgMinMaxData() { result().~SingleValueDataBase(); }
};
static_assert(
sizeof(AggregateFunctionArgMinMaxData<Int8>) <= 2 * SingleValueDataBase::MAX_STORAGE_SIZE,
"Incorrect size of AggregateFunctionArgMinMaxData struct");
/// Returns the first arg value found for the minimum/maximum value. Example: argMin(arg, value).
template <typename ValueData, bool isMin>
class AggregateFunctionArgMinMax final
: public IAggregateFunctionDataHelper<AggregateFunctionArgMinMaxData<ValueData>, AggregateFunctionArgMinMax<ValueData, isMin>>
{
private:
const DataTypePtr & type_val;
const SerializationPtr serialization_res;
const SerializationPtr serialization_val;
const TypeIndex result_type_index;
using Base = IAggregateFunctionDataHelper<AggregateFunctionArgMinMaxData<ValueData>, AggregateFunctionArgMinMax<ValueData, isMin>>;
public:
explicit AggregateFunctionArgMinMax(const DataTypes & argument_types_)
: Base(argument_types_, {}, argument_types_[0])
, type_val(this->argument_types[1])
, serialization_res(this->argument_types[0]->getDefaultSerialization())
, serialization_val(this->argument_types[1]->getDefaultSerialization())
, result_type_index(WhichDataType(this->argument_types[0]).idx)
{
if (!type_val->isComparable())
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of second argument of aggregate function {} because the values of that data type are not comparable",
type_val->getName(),
getName());
}
void create(AggregateDataPtr __restrict place) const override /// NOLINT
{
new (place) AggregateFunctionArgMinMaxData<ValueData>(result_type_index);
}
String getName() const override
{
if constexpr (isMin)
return "argMin";
else
return "argMax";
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
if constexpr (isMin)
{
if (this->data(place).value().setIfSmaller(*columns[1], row_num, arena))
this->data(place).result().set(*columns[0], row_num, arena);
}
else
{
if (this->data(place).value().setIfGreater(*columns[1], row_num, arena))
this->data(place).result().set(*columns[0], row_num, arena);
}
}
void addManyDefaults(AggregateDataPtr __restrict place, const IColumn ** columns, size_t, Arena * arena) const override
{
add(place, columns, 0, arena);
}
void addBatchSinglePlace(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
Arena * arena,
ssize_t if_argument_pos) const override
{
std::optional<size_t> idx;
if (if_argument_pos >= 0)
{
const auto & if_map = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
if constexpr (isMin)
idx = this->data(place).value().getSmallestIndexNotNullIf(*columns[1], nullptr, if_map.data(), row_begin, row_end);
else
idx = this->data(place).value().getGreatestIndexNotNullIf(*columns[1], nullptr, if_map.data(), row_begin, row_end);
}
else
{
if constexpr (isMin)
idx = this->data(place).value().getSmallestIndex(*columns[1], row_begin, row_end);
else
idx = this->data(place).value().getGreatestIndex(*columns[1], row_begin, row_end);
}
if (idx)
add(place, columns, *idx, arena);
}
void addBatchSinglePlaceNotNull(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
const UInt8 * __restrict null_map,
Arena * arena,
ssize_t if_argument_pos) const override
{
std::optional<size_t> idx;
if (if_argument_pos >= 0)
{
const auto & if_map = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
if constexpr (isMin)
idx = this->data(place).value().getSmallestIndexNotNullIf(*columns[1], null_map, if_map.data(), row_begin, row_end);
else
idx = this->data(place).value().getGreatestIndexNotNullIf(*columns[1], null_map, if_map.data(), row_begin, row_end);
}
else
{
if constexpr (isMin)
idx = this->data(place).value().getSmallestIndexNotNullIf(*columns[1], null_map, nullptr, row_begin, row_end);
else
idx = this->data(place).value().getGreatestIndexNotNullIf(*columns[1], null_map, nullptr, row_begin, row_end);
}
if (idx)
add(place, columns, *idx, arena);
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
if constexpr (isMin)
{
if (this->data(place).value().setIfSmaller(this->data(rhs).value(), arena))
this->data(place).result().set(this->data(rhs).result(), arena);
}
else
{
if (this->data(place).value().setIfGreater(this->data(rhs).value(), arena))
this->data(place).result().set(this->data(rhs).result(), arena);
}
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
this->data(place).result().write(buf, *serialization_res);
this->data(place).value().write(buf, *serialization_val);
}
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
this->data(place).result().read(buf, *serialization_res, arena);
this->data(place).value().read(buf, *serialization_val, arena);
if (unlikely(this->data(place).value().has() != this->data(place).result().has()))
throw Exception(
ErrorCodes::CORRUPTED_DATA,
"Invalid state of the aggregate function {}: has_value ({}) != has_result ({})",
getName(),
this->data(place).value().has(),
this->data(place).result().has());
}
bool allocatesMemoryInArena() const override
{
return singleValueTypeAllocatesMemoryInArena(result_type_index) || ValueData::allocatesMemoryInArena();
}
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
this->data(place).result().insertResultInto(to);
}
};
template <bool isMin>
AggregateFunctionPtr createAggregateFunctionArgMinMax(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(createAggregateFunctionSingleValue<AggregateFunctionArgMinMax, /* unary */ false, isMin>(
name, argument_types, parameters, settings));
}
}
void registerAggregateFunctionsArgMinArgMax(AggregateFunctionFactory & factory)
{
AggregateFunctionProperties properties = {.returns_default_when_only_null = false, .is_order_dependent = true};
factory.registerFunction("argMin", {createAggregateFunctionArgMinMax<true>, properties});
factory.registerFunction("argMax", {createAggregateFunctionArgMinMax<false>, properties});
}
}

View File

@ -0,0 +1,202 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/SingleValueData.h>
#include <Common/Concepts.h>
#include <Common/findExtreme.h>
namespace DB
{
struct Settings;
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NOT_IMPLEMENTED;
}
namespace
{
template <typename Data, bool isMin>
class AggregateFunctionMinMax final : public IAggregateFunctionDataHelper<Data, AggregateFunctionMinMax<Data, isMin>>
{
private:
SerializationPtr serialization;
public:
explicit AggregateFunctionMinMax(const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionMinMax<Data, isMin>>(argument_types_, {}, argument_types_[0])
, serialization(this->result_type->getDefaultSerialization())
{
if (!this->result_type->isComparable())
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of argument of aggregate function {} because the values of that data type are not comparable",
this->result_type->getName(),
getName());
}
String getName() const override
{
if constexpr (isMin)
return "min";
else
return "max";
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
if constexpr (isMin)
this->data(place).setIfSmaller(*columns[0], row_num, arena);
else
this->data(place).setIfGreater(*columns[0], row_num, arena);
}
void addManyDefaults(AggregateDataPtr __restrict place, const IColumn ** columns, size_t, Arena * arena) const override
{
add(place, columns, 0, arena);
}
void addBatchSinglePlace(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
Arena * arena,
ssize_t if_argument_pos) const override
{
if (if_argument_pos >= 0)
{
const auto & if_map = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
if constexpr (isMin)
this->data(place).setSmallestNotNullIf(*columns[0], nullptr, if_map.data(), row_begin, row_end, arena);
else
this->data(place).setGreatestNotNullIf(*columns[0], nullptr, if_map.data(), row_begin, row_end, arena);
}
else
{
if constexpr (isMin)
this->data(place).setSmallest(*columns[0], row_begin, row_end, arena);
else
this->data(place).setGreatest(*columns[0], row_begin, row_end, arena);
}
}
void addBatchSinglePlaceNotNull(
size_t row_begin,
size_t row_end,
AggregateDataPtr __restrict place,
const IColumn ** __restrict columns,
const UInt8 * __restrict null_map,
Arena * arena,
ssize_t if_argument_pos) const override
{
if (if_argument_pos >= 0)
{
const auto & if_map = assert_cast<const ColumnUInt8 &>(*columns[if_argument_pos]).getData();
if constexpr (isMin)
this->data(place).setSmallestNotNullIf(*columns[0], null_map, if_map.data(), row_begin, row_end, arena);
else
this->data(place).setGreatestNotNullIf(*columns[0], null_map, if_map.data(), row_begin, row_end, arena);
}
else
{
if constexpr (isMin)
this->data(place).setSmallestNotNullIf(*columns[0], null_map, nullptr, row_begin, row_end, arena);
else
this->data(place).setGreatestNotNullIf(*columns[0], null_map, nullptr, row_begin, row_end, arena);
}
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
if constexpr (isMin)
this->data(place).setIfSmaller(this->data(rhs), arena);
else
this->data(place).setIfGreater(this->data(rhs), arena);
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
this->data(place).write(buf, *serialization);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
this->data(place).read(buf, *serialization, arena);
}
bool allocatesMemoryInArena() const override { return Data::allocatesMemoryInArena(); }
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
this->data(place).insertResultInto(to);
}
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
if constexpr (!Data::is_compilable)
return false;
else
return Data::isCompilable(*this->argument_types[0]);
}
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
if constexpr (Data::is_compilable)
Data::compileCreate(builder, aggregate_data_ptr);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override
{
if constexpr (Data::is_compilable)
if constexpr (isMin)
Data::compileMin(builder, aggregate_data_ptr, arguments[0].value);
else
Data::compileMax(builder, aggregate_data_ptr, arguments[0].value);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
void
compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
{
if constexpr (Data::is_compilable)
if constexpr (isMin)
Data::compileMinMerge(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
else
Data::compileMaxMerge(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
if constexpr (Data::is_compilable)
return Data::compileGetResult(builder, aggregate_data_ptr);
else
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName());
}
#endif
};
template <bool isMin>
AggregateFunctionPtr createAggregateFunctionMinMax(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
return AggregateFunctionPtr(
createAggregateFunctionSingleValue<AggregateFunctionMinMax, /* unary */ true, isMin>(name, argument_types, parameters, settings));
}
}
void registerAggregateFunctionsMinMax(AggregateFunctionFactory & factory)
{
factory.registerFunction("min", createAggregateFunctionMinMax<true>, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("max", createAggregateFunctionMinMax<false>, AggregateFunctionFactory::CaseInsensitive);
}
}

View File

@ -1,93 +0,0 @@
#include "AggregateFunctionArgMinMax.h"
#include "AggregateFunctionCombinatorFactory.h"
#include <AggregateFunctions/AggregateFunctionMinMaxAny.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeString.h>
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
namespace
{
template <template <typename> class Data>
class AggregateFunctionCombinatorArgMinMax final : public IAggregateFunctionCombinator
{
public:
String getName() const override { return Data<SingleValueDataGeneric>::name(); }
DataTypes transformArguments(const DataTypes & arguments) const override
{
if (arguments.empty())
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Incorrect number of arguments for aggregate function with {} suffix",
getName());
return DataTypes(arguments.begin(), arguments.end() - 1);
}
AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array & params) const override
{
const DataTypePtr & argument_type = arguments.back();
WhichDataType which(argument_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return std::make_shared<AggregateFunctionArgMinMax<Data<SingleValueDataFixed<TYPE>>>>(nested_function, arguments, params); /// NOLINT
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Date)
return std::make_shared<AggregateFunctionArgMinMax<Data<SingleValueDataFixed<DataTypeDate::FieldType>>>>(
nested_function, arguments, params);
if (which.idx == TypeIndex::DateTime)
return std::make_shared<AggregateFunctionArgMinMax<Data<SingleValueDataFixed<DataTypeDateTime::FieldType>>>>(
nested_function, arguments, params);
if (which.idx == TypeIndex::DateTime64)
return std::make_shared<AggregateFunctionArgMinMax<Data<SingleValueDataFixed<DateTime64>>>>(nested_function, arguments, params);
if (which.idx == TypeIndex::Decimal32)
return std::make_shared<AggregateFunctionArgMinMax<Data<SingleValueDataFixed<Decimal32>>>>(nested_function, arguments, params);
if (which.idx == TypeIndex::Decimal64)
return std::make_shared<AggregateFunctionArgMinMax<Data<SingleValueDataFixed<Decimal64>>>>(nested_function, arguments, params);
if (which.idx == TypeIndex::Decimal128)
return std::make_shared<AggregateFunctionArgMinMax<Data<SingleValueDataFixed<Decimal128>>>>(nested_function, arguments, params);
if (which.idx == TypeIndex::Decimal256)
return std::make_shared<AggregateFunctionArgMinMax<Data<SingleValueDataFixed<Decimal256>>>>(nested_function, arguments, params);
if (which.idx == TypeIndex::String)
return std::make_shared<AggregateFunctionArgMinMax<Data<SingleValueDataString>>>(nested_function, arguments, params);
return std::make_shared<AggregateFunctionArgMinMax<Data<SingleValueDataGeneric>>>(nested_function, arguments, params);
}
};
template <typename Data>
struct AggregateFunctionArgMinDataCapitalized : AggregateFunctionMinData<Data>
{
static const char * name() { return "ArgMin"; }
};
template <typename Data>
struct AggregateFunctionArgMaxDataCapitalized : AggregateFunctionMaxData<Data>
{
static const char * name() { return "ArgMax"; }
};
}
void registerAggregateFunctionCombinatorMinMax(AggregateFunctionCombinatorFactory & factory)
{
factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorArgMinMax<AggregateFunctionArgMinDataCapitalized>>());
factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorArgMinMax<AggregateFunctionArgMaxDataCapitalized>>());
}
}

View File

@ -1,111 +0,0 @@
#pragma once
#include <AggregateFunctions/IAggregateFunction.h>
namespace DB
{
template <typename Key>
class AggregateFunctionArgMinMax final : public IAggregateFunctionHelper<AggregateFunctionArgMinMax<Key>>
{
private:
AggregateFunctionPtr nested_function;
SerializationPtr serialization;
size_t key_col;
size_t key_offset;
Key & key(AggregateDataPtr __restrict place) const { return *reinterpret_cast<Key *>(place + key_offset); }
const Key & key(ConstAggregateDataPtr __restrict place) const { return *reinterpret_cast<const Key *>(place + key_offset); }
public:
AggregateFunctionArgMinMax(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: IAggregateFunctionHelper<AggregateFunctionArgMinMax<Key>>{arguments, params, nested_function_->getResultType()}
, nested_function{nested_function_}
, serialization(arguments.back()->getDefaultSerialization())
, key_col{arguments.size() - 1}
, key_offset{(nested_function->sizeOfData() + alignof(Key) - 1) / alignof(Key) * alignof(Key)}
{
}
String getName() const override { return nested_function->getName() + Key::name(); }
bool isState() const override { return nested_function->isState(); }
bool isVersioned() const override { return nested_function->isVersioned(); }
size_t getVersionFromRevision(size_t revision) const override { return nested_function->getVersionFromRevision(revision); }
size_t getDefaultVersion() const override { return nested_function->getDefaultVersion(); }
bool allocatesMemoryInArena() const override { return nested_function->allocatesMemoryInArena() || Key::allocatesMemoryInArena(); }
bool hasTrivialDestructor() const override { return nested_function->hasTrivialDestructor(); }
size_t sizeOfData() const override { return key_offset + sizeof(Key); }
size_t alignOfData() const override { return nested_function->alignOfData(); }
void create(AggregateDataPtr __restrict place) const override
{
nested_function->create(place);
new (place + key_offset) Key;
}
void destroy(AggregateDataPtr __restrict place) const noexcept override { nested_function->destroy(place); }
void destroyUpToState(AggregateDataPtr __restrict place) const noexcept override { nested_function->destroyUpToState(place); }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
if (key(place).changeIfBetter(*columns[key_col], row_num, arena))
{
nested_function->destroy(place);
nested_function->create(place);
nested_function->add(place, columns, row_num, arena);
}
else if (key(place).isEqualTo(*columns[key_col], row_num))
{
nested_function->add(place, columns, row_num, arena);
}
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
if (key(place).changeIfBetter(key(rhs), arena))
{
nested_function->destroy(place);
nested_function->create(place);
nested_function->merge(place, rhs, arena);
}
else if (key(place).isEqualTo(key(rhs)))
{
nested_function->merge(place, rhs, arena);
}
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version) const override
{
nested_function->serialize(place, buf, version);
key(place).write(buf, *serialization);
}
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> version, Arena * arena) const override
{
nested_function->deserialize(place, buf, version, arena);
key(place).read(buf, *serialization, arena);
}
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
{
nested_function->insertResultInto(place, to, arena);
}
void insertMergeResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
{
nested_function->insertMergeResultInto(place, to, arena);
}
AggregateFunctionPtr getNestedFunction() const override { return nested_function; }
};
}

View File

@ -0,0 +1,212 @@
#include <AggregateFunctions/Combinators/AggregateFunctionCombinatorFactory.h>
#include <AggregateFunctions/SingleValueData.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
namespace
{
struct AggregateFunctionCombinatorArgMinArgMaxData
{
private:
SingleValueDataBaseMemoryBlock v_data;
public:
explicit AggregateFunctionCombinatorArgMinArgMaxData(TypeIndex value_type) { generateSingleValueFromTypeIndex(value_type, v_data); }
~AggregateFunctionCombinatorArgMinArgMaxData() { data().~SingleValueDataBase(); }
SingleValueDataBase & data() { return v_data.get(); }
const SingleValueDataBase & data() const { return v_data.get(); }
};
template <bool isMin>
class AggregateFunctionCombinatorArgMinArgMax final : public IAggregateFunctionHelper<AggregateFunctionCombinatorArgMinArgMax<isMin>>
{
using Key = AggregateFunctionCombinatorArgMinArgMaxData;
private:
AggregateFunctionPtr nested_function;
SerializationPtr serialization;
const size_t key_col;
const size_t key_offset;
const TypeIndex key_type_index;
AggregateFunctionCombinatorArgMinArgMaxData & data(AggregateDataPtr __restrict place) const /// NOLINT
{
return *reinterpret_cast<Key *>(place + key_offset);
}
const AggregateFunctionCombinatorArgMinArgMaxData & data(ConstAggregateDataPtr __restrict place) const
{
return *reinterpret_cast<const Key *>(place + key_offset);
}
public:
AggregateFunctionCombinatorArgMinArgMax(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: IAggregateFunctionHelper<AggregateFunctionCombinatorArgMinArgMax<isMin>>{arguments, params, nested_function_->getResultType()}
, nested_function{nested_function_}
, serialization(arguments.back()->getDefaultSerialization())
, key_col{arguments.size() - 1}
, key_offset{((nested_function->sizeOfData() + alignof(Key) - 1) / alignof(Key)) * alignof(Key)}
, key_type_index(WhichDataType(arguments[key_col]).idx)
{
if (!arguments[key_col]->isComparable())
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} for combinator {} because the values of that data type are not comparable",
arguments[key_col]->getName(),
getName());
}
String getName() const override
{
if constexpr (isMin)
return "ArgMin";
else
return "ArgMax";
}
bool isState() const override { return nested_function->isState(); }
bool isVersioned() const override { return nested_function->isVersioned(); }
size_t getVersionFromRevision(size_t revision) const override { return nested_function->getVersionFromRevision(revision); }
size_t getDefaultVersion() const override { return nested_function->getDefaultVersion(); }
bool allocatesMemoryInArena() const override
{
return nested_function->allocatesMemoryInArena() || singleValueTypeAllocatesMemoryInArena(key_type_index);
}
bool hasTrivialDestructor() const override
{
return nested_function->hasTrivialDestructor() && /*false*/ std::is_trivially_destructible_v<SingleValueDataBase>;
}
size_t sizeOfData() const override { return key_offset + sizeof(Key); }
size_t alignOfData() const override { return std::max(nested_function->alignOfData(), alignof(SingleValueDataBaseMemoryBlock)); }
void create(AggregateDataPtr __restrict place) const override
{
nested_function->create(place);
new (place + key_offset) Key(key_type_index);
}
void destroy(AggregateDataPtr __restrict place) const noexcept override
{
data(place).~Key();
nested_function->destroy(place);
}
void destroyUpToState(AggregateDataPtr __restrict place) const noexcept override
{
data(place).~Key();
nested_function->destroyUpToState(place);
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
if ((isMin && data(place).data().setIfSmaller(*columns[key_col], row_num, arena))
|| (!isMin && data(place).data().setIfGreater(*columns[key_col], row_num, arena)))
{
nested_function->destroy(place);
nested_function->create(place);
nested_function->add(place, columns, row_num, arena);
}
else if (data(place).data().isEqualTo(*columns[key_col], row_num))
{
nested_function->add(place, columns, row_num, arena);
}
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
if ((isMin && data(place).data().setIfSmaller(data(rhs).data(), arena))
|| (!isMin && data(place).data().setIfGreater(data(rhs).data(), arena)))
{
nested_function->destroy(place);
nested_function->create(place);
nested_function->merge(place, rhs, arena);
}
else if (data(place).data().isEqualTo(data(rhs).data()))
{
nested_function->merge(place, rhs, arena);
}
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version) const override
{
nested_function->serialize(place, buf, version);
data(place).data().write(buf, *serialization);
}
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> version, Arena * arena) const override
{
nested_function->deserialize(place, buf, version, arena);
data(place).data().read(buf, *serialization, arena);
}
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
{
nested_function->insertResultInto(place, to, arena);
}
void insertMergeResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
{
nested_function->insertMergeResultInto(place, to, arena);
}
AggregateFunctionPtr getNestedFunction() const override { return nested_function; }
};
template <bool isMin>
class CombinatorArgMinArgMax final : public IAggregateFunctionCombinator
{
public:
String getName() const override
{
if constexpr (isMin)
return "ArgMin";
else
return "ArgMax";
}
DataTypes transformArguments(const DataTypes & arguments) const override
{
if (arguments.empty())
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Incorrect number of arguments for aggregate function with {} suffix",
getName());
return DataTypes(arguments.begin(), arguments.end() - 1);
}
AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array & params) const override
{
return std::make_shared<AggregateFunctionCombinatorArgMinArgMax<isMin>>(nested_function, arguments, params);
}
};
}
void registerAggregateFunctionCombinatorsArgMinArgMax(AggregateFunctionCombinatorFactory & factory)
{
factory.registerCombinator(std::make_shared<CombinatorArgMinArgMax<true>>());
factory.registerCombinator(std::make_shared<CombinatorArgMinArgMax<false>>());
}
}

View File

@ -43,8 +43,8 @@ template <bool result_is_nullable, bool serialize_flag, typename Derived>
class AggregateFunctionNullBase : public IAggregateFunctionHelper<Derived>
{
protected:
AggregateFunctionPtr nested_function;
size_t prefix_size;
const AggregateFunctionPtr nested_function;
const size_t prefix_size;
/** In addition to data for nested aggregate function, we keep a flag
* indicating - was there at least one non-NULL value accumulated.
@ -55,12 +55,18 @@ protected:
AggregateDataPtr nestedPlace(AggregateDataPtr __restrict place) const noexcept
{
return place + prefix_size;
if constexpr (result_is_nullable)
return place + prefix_size;
else
return place;
}
ConstAggregateDataPtr nestedPlace(ConstAggregateDataPtr __restrict place) const noexcept
{
return place + prefix_size;
if constexpr (result_is_nullable)
return place + prefix_size;
else
return place;
}
static void initFlag(AggregateDataPtr __restrict place) noexcept
@ -87,11 +93,8 @@ public:
AggregateFunctionNullBase(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: IAggregateFunctionHelper<Derived>(arguments, params, createResultType(nested_function_))
, nested_function{nested_function_}
, prefix_size(result_is_nullable ? nested_function->alignOfData() : 0)
{
if constexpr (result_is_nullable)
prefix_size = nested_function->alignOfData();
else
prefix_size = 0;
}
String getName() const override

View File

@ -1,119 +0,0 @@
#pragma once
#include <AggregateFunctions/AggregateFunctionMinMaxAny.h>
#include <AggregateFunctions/AggregateFunctionArgMinMax.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/Helpers.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeString.h>
namespace DB
{
struct Settings;
/// min, max, any, anyLast, anyHeavy, etc...
template <template <typename> class AggregateFunctionTemplate, template <typename, bool...> class Data>
static IAggregateFunction *
createAggregateFunctionSingleValue(const String & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
const DataTypePtr & argument_type = argument_types[0];
WhichDataType which(argument_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<TYPE>>>(argument_type); /// NOLINT
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Date)
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<DataTypeDate::FieldType>>>(argument_type);
if (which.idx == TypeIndex::DateTime)
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<DataTypeDateTime::FieldType>>>(argument_type);
if (which.idx == TypeIndex::DateTime64)
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<DateTime64>>>(argument_type);
if (which.idx == TypeIndex::Decimal32)
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Decimal32>>>(argument_type);
if (which.idx == TypeIndex::Decimal64)
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Decimal64>>>(argument_type);
if (which.idx == TypeIndex::Decimal128)
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Decimal128>>>(argument_type);
if (which.idx == TypeIndex::Decimal256)
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Decimal256>>>(argument_type);
if (which.idx == TypeIndex::String)
return new AggregateFunctionTemplate<Data<SingleValueDataString>>(argument_type);
return new AggregateFunctionTemplate<Data<SingleValueDataGeneric>>(argument_type);
}
/// argMin, argMax
template <template <typename> class MinMaxData, typename ResData>
static IAggregateFunction * createAggregateFunctionArgMinMaxSecond(const DataTypePtr & res_type, const DataTypePtr & val_type)
{
WhichDataType which(val_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<TYPE>>>>(res_type, val_type); /// NOLINT
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Date)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<DataTypeDate::FieldType>>>>(res_type, val_type);
if (which.idx == TypeIndex::DateTime)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<DataTypeDateTime::FieldType>>>>(res_type, val_type);
if (which.idx == TypeIndex::DateTime64)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<DateTime64>>>>(res_type, val_type);
if (which.idx == TypeIndex::Decimal32)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<Decimal32>>>>(res_type, val_type);
if (which.idx == TypeIndex::Decimal64)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<Decimal64>>>>(res_type, val_type);
if (which.idx == TypeIndex::Decimal128)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<Decimal128>>>>(res_type, val_type);
if (which.idx == TypeIndex::Decimal256)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<Decimal256>>>>(res_type, val_type);
if (which.idx == TypeIndex::String)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataString>>>(res_type, val_type);
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataGeneric>>>(res_type, val_type);
}
template <template <typename> class MinMaxData>
static IAggregateFunction * createAggregateFunctionArgMinMax(const String & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
assertNoParameters(name, parameters);
assertBinary(name, argument_types);
const DataTypePtr & res_type = argument_types[0];
const DataTypePtr & val_type = argument_types[1];
WhichDataType which(res_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<TYPE>>(res_type, val_type); /// NOLINT
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Date)
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<DataTypeDate::FieldType>>(res_type, val_type);
if (which.idx == TypeIndex::DateTime)
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<DataTypeDateTime::FieldType>>(res_type, val_type);
if (which.idx == TypeIndex::DateTime64)
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<DateTime64>>(res_type, val_type);
if (which.idx == TypeIndex::Decimal32)
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<Decimal32>>(res_type, val_type);
if (which.idx == TypeIndex::Decimal64)
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<Decimal64>>(res_type, val_type);
if (which.idx == TypeIndex::Decimal128)
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<Decimal128>>(res_type, val_type);
if (which.idx == TypeIndex::Decimal256)
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<Decimal256>>(res_type, val_type);
if (which.idx == TypeIndex::String)
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataString>(res_type, val_type);
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataGeneric>(res_type, val_type);
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,394 @@
#pragma once
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnDecimal.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <base/StringRef.h>
namespace DB
{
class Arena;
class ReadBuffer;
struct Settings;
class WriteBuffer;
/// Base class for Aggregation data that stores one of passed values: min, any, argMax...
/// It's setup as a virtual class so we can avoid templates when we need to extend them (argMax, SingleValueOrNull)
struct SingleValueDataBase
{
/// Any subclass (numeric, string, generic) must be smaller than MAX_STORAGE_SIZE
/// We use this knowledge to create composite data classes that use them directly by reserving a 'memory_block'
/// For example argMin holds 1 of these (for the result), while keeping a template for the value
static constexpr UInt32 MAX_STORAGE_SIZE = 64;
virtual ~SingleValueDataBase() { }
virtual bool has() const = 0;
virtual void insertResultInto(IColumn &) const = 0;
virtual void write(WriteBuffer &, const ISerialization &) const = 0;
virtual void read(ReadBuffer &, const ISerialization &, Arena *) = 0;
virtual bool isEqualTo(const IColumn & column, size_t row_num) const = 0;
virtual bool isEqualTo(const SingleValueDataBase &) const = 0;
virtual void set(const IColumn &, size_t row_num, Arena *) = 0;
virtual void set(const SingleValueDataBase &, Arena *) = 0;
virtual bool setIfSmaller(const IColumn &, size_t row_num, Arena *) = 0;
virtual bool setIfSmaller(const SingleValueDataBase &, Arena *) = 0;
virtual bool setIfGreater(const IColumn &, size_t row_num, Arena *) = 0;
virtual bool setIfGreater(const SingleValueDataBase &, Arena *) = 0;
/// Given a column, sets the internal value to the smallest or greatest value from the column
/// Used to implement batch min/max
virtual void setSmallest(const IColumn & column, size_t row_begin, size_t row_end, Arena * arena);
virtual void setGreatest(const IColumn & column, size_t row_begin, size_t row_end, Arena * arena);
virtual void setSmallestNotNullIf(const IColumn &, const UInt8 * __restrict, const UInt8 * __restrict, size_t, size_t, Arena *);
virtual void setGreatestNotNullIf(const IColumn &, const UInt8 * __restrict, const UInt8 * __restrict, size_t, size_t, Arena *);
/// Given a column returns the index of the smallest or greatest value in it
/// Doesn't return anything if the column is empty
/// There are used to implement argMin / argMax
virtual std::optional<size_t> getSmallestIndex(const IColumn & column, size_t row_begin, size_t row_end) const;
virtual std::optional<size_t> getGreatestIndex(const IColumn & column, size_t row_begin, size_t row_end) const;
virtual std::optional<size_t> getSmallestIndexNotNullIf(
const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end) const;
virtual std::optional<size_t> getGreatestIndexNotNullIf(
const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end) const;
};
#define FOR_SINGLE_VALUE_NUMERIC_TYPES(M) \
M(UInt8) \
M(UInt16) \
M(UInt32) \
M(UInt64) \
M(UInt128) \
M(UInt256) \
M(Int8) \
M(Int16) \
M(Int32) \
M(Int64) \
M(Int128) \
M(Int256) \
M(Float32) \
M(Float64) \
M(Decimal32) \
M(Decimal64) \
M(Decimal128) \
M(Decimal256) \
M(DateTime64)
/// For numeric values (without inheritance, for performance sensitive functions and JIT)
template <typename T>
struct SingleValueDataFixed
{
static constexpr bool is_compilable = true;
using Self = SingleValueDataFixed;
using ColVecType = ColumnVectorOrDecimal<T>;
T value = T{};
/// We need to remember if at least one value has been passed.
/// This is necessary for AggregateFunctionIf, merging states, JIT (where simple add is used), etc
bool has_value = false;
bool has() const { return has_value; }
void insertResultInto(IColumn & to) const;
void write(WriteBuffer & buf, const ISerialization &) const;
void read(ReadBuffer & buf, const ISerialization &, Arena *);
bool isEqualTo(const IColumn & column, size_t index) const;
bool isEqualTo(const Self & to) const;
void set(const IColumn & column, size_t row_num, Arena *);
void set(const Self & to, Arena *);
bool setIfSmaller(const T & to);
bool setIfGreater(const T & to);
bool setIfSmaller(const Self & to, Arena * arena);
bool setIfGreater(const Self & to, Arena * arena);
bool setIfSmaller(const IColumn & column, size_t row_num, Arena * arena);
bool setIfGreater(const IColumn & column, size_t row_num, Arena * arena);
void setSmallest(const IColumn & column, size_t row_begin, size_t row_end, Arena *);
void setGreatest(const IColumn & column, size_t row_begin, size_t row_end, Arena *);
void setSmallestNotNullIf(
const IColumn & column,
const UInt8 * __restrict null_map,
const UInt8 * __restrict if_map,
size_t row_begin,
size_t row_end,
Arena *);
void setGreatestNotNullIf(
const IColumn & column,
const UInt8 * __restrict null_map,
const UInt8 * __restrict if_map,
size_t row_begin,
size_t row_end,
Arena *);
std::optional<size_t> getSmallestIndex(const IColumn & column, size_t row_begin, size_t row_end) const;
std::optional<size_t> getGreatestIndex(const IColumn & column, size_t row_begin, size_t row_end) const;
std::optional<size_t> getSmallestIndexNotNullIf(
const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end) const;
std::optional<size_t> getGreatestIndexNotNullIf(
const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end) const;
static bool allocatesMemoryInArena() { return false; }
#if USE_EMBEDDED_COMPILER
static constexpr size_t has_value_offset = offsetof(Self, has_value);
static constexpr size_t value_offset = offsetof(Self, value);
static bool isCompilable(const IDataType & type);
static llvm::Value * getValuePtrFromAggregateDataPtr(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr);
static llvm::Value * getValueFromAggregateDataPtr(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr);
static llvm::Value * getHasValuePtrFromAggregateDataPtr(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr);
static llvm::Value * getHasValueFromAggregateDataPtr(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr);
static void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr);
static llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr);
static void compileSetValueFromNumber(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check);
static void
compileSetValueFromAggregation(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * aggregate_data_src_ptr);
static void compileAny(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check);
static void compileAnyMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr);
static void compileAnyLast(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check);
static void
compileAnyLastMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr);
template <bool isMin>
static void compileMinMax(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check);
template <bool isMin>
static void
compileMinMaxMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr);
static void compileMin(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check);
static void compileMinMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr);
static void compileMax(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, llvm::Value * value_to_check);
static void compileMaxMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr);
#endif
};
#define DISPATCH(TYPE) \
extern template struct SingleValueDataFixed<TYPE>; \
static_assert( \
sizeof(SingleValueDataFixed<TYPE>) <= SingleValueDataBase::MAX_STORAGE_SIZE, "Incorrect size of SingleValueDataFixed struct");
FOR_SINGLE_VALUE_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
/// For numeric values inheriting from SingleValueDataBase
template <typename T>
struct SingleValueDataNumeric final : public SingleValueDataBase
{
using Self = SingleValueDataNumeric<T>;
using Base = SingleValueDataFixed<T>;
private:
/// 32 bytes for types of 256 bits, + 8 bytes for the virtual table pointer.
static constexpr size_t base_memory_reserved_size = 40;
struct alignas(alignof(Base)) PrivateMemory
{
char memory[base_memory_reserved_size];
Base & get() { return *reinterpret_cast<Base *>(memory); }
const Base & get() const { return *reinterpret_cast<const Base *>(memory); }
};
static_assert(sizeof(Base) <= base_memory_reserved_size);
PrivateMemory memory;
public:
static constexpr bool is_compilable = false;
SingleValueDataNumeric();
~SingleValueDataNumeric() override;
bool has() const override;
void insertResultInto(IColumn & to) const override;
void write(WriteBuffer & buf, const ISerialization & serialization) const override;
void read(ReadBuffer & buf, const ISerialization & serialization, Arena * arena) override;
bool isEqualTo(const IColumn & column, size_t index) const override;
bool isEqualTo(const SingleValueDataBase & to) const override;
void set(const IColumn & column, size_t row_num, Arena * arena) override;
void set(const SingleValueDataBase & to, Arena * arena) override;
bool setIfSmaller(const SingleValueDataBase & to, Arena * arena) override;
bool setIfGreater(const SingleValueDataBase & to, Arena * arena) override;
bool setIfSmaller(const IColumn & column, size_t row_num, Arena * arena) override;
bool setIfGreater(const IColumn & column, size_t row_num, Arena * arena) override;
void setSmallest(const IColumn & column, size_t row_begin, size_t row_end, Arena * arena) override;
void setGreatest(const IColumn & column, size_t row_begin, size_t row_end, Arena * arena) override;
void setSmallestNotNullIf(
const IColumn & column,
const UInt8 * __restrict null_map,
const UInt8 * __restrict if_map,
size_t row_begin,
size_t row_end,
Arena * arena) override;
void setGreatestNotNullIf(
const IColumn & column,
const UInt8 * __restrict null_map,
const UInt8 * __restrict if_map,
size_t row_begin,
size_t row_end,
Arena * arena) override;
std::optional<size_t> getSmallestIndex(const IColumn & column, size_t row_begin, size_t row_end) const override;
std::optional<size_t> getGreatestIndex(const IColumn & column, size_t row_begin, size_t row_end) const override;
std::optional<size_t> getSmallestIndexNotNullIf(
const IColumn & column,
const UInt8 * __restrict null_map,
const UInt8 * __restrict if_map,
size_t row_begin,
size_t row_end) const override;
std::optional<size_t> getGreatestIndexNotNullIf(
const IColumn & column,
const UInt8 * __restrict null_map,
const UInt8 * __restrict if_map,
size_t row_begin,
size_t row_end) const override;
static bool allocatesMemoryInArena() { return false; }
};
#define DISPATCH(TYPE) \
extern template struct SingleValueDataNumeric<TYPE>; \
static_assert( \
sizeof(SingleValueDataNumeric<TYPE>) <= SingleValueDataBase::MAX_STORAGE_SIZE, "Incorrect size of SingleValueDataNumeric struct");
FOR_SINGLE_VALUE_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
/** For strings. Short strings are stored in the object itself, and long strings are allocated separately.
* NOTE It could also be suitable for arrays of numbers.
// */
struct SingleValueDataString final : public SingleValueDataBase
{
static constexpr bool is_compilable = false;
using Self = SingleValueDataString;
/// 0 size indicates that there is no value. Empty string must have terminating '\0' and, therefore, size of empty string is 1
UInt32 size = 0;
UInt32 capacity = 0; /// power of two or zero
char * large_data; /// Always allocated in an arena
//// TODO: Maybe instead of a virtual class we need to go with a std::variant of the 3 to avoid reserving space for the vtable
static constexpr UInt32 MAX_SMALL_STRING_SIZE
= SingleValueDataBase::MAX_STORAGE_SIZE - sizeof(size) - sizeof(capacity) - sizeof(large_data) - sizeof(SingleValueDataBase);
static constexpr UInt32 MAX_STRING_SIZE = std::numeric_limits<Int32>::max();
private:
char small_data[MAX_SMALL_STRING_SIZE]; /// Including the terminating zero.
char * getDataMutable();
const char * getData() const;
StringRef getStringRef() const;
void allocateLargeDataIfNeeded(UInt32 size_to_reserve, Arena * arena);
void changeImpl(StringRef value, Arena * arena);
public:
bool has() const override { return size != 0; }
void insertResultInto(IColumn & to) const override;
void write(WriteBuffer & buf, const ISerialization & /*serialization*/) const override;
void read(ReadBuffer & buf, const ISerialization & /*serialization*/, Arena * arena) override;
bool isEqualTo(const IColumn & column, size_t row_num) const override;
bool isEqualTo(const SingleValueDataBase &) const override;
void set(const IColumn & column, size_t row_num, Arena * arena) override;
void set(const SingleValueDataBase &, Arena * arena) override;
bool setIfSmaller(const IColumn & column, size_t row_num, Arena * arena) override;
bool setIfSmaller(const SingleValueDataBase &, Arena * arena) override;
bool setIfGreater(const IColumn & column, size_t row_num, Arena * arena) override;
bool setIfGreater(const SingleValueDataBase &, Arena * arena) override;
static bool allocatesMemoryInArena() { return true; }
};
static_assert(sizeof(SingleValueDataString) == SingleValueDataBase::MAX_STORAGE_SIZE, "Incorrect size of SingleValueDataString struct");
/// For any other value types.
struct SingleValueDataGeneric final : public SingleValueDataBase
{
static constexpr bool is_compilable = false;
private:
using Self = SingleValueDataGeneric;
Field value;
public:
bool has() const override { return !value.isNull(); }
void insertResultInto(IColumn & to) const override;
void write(WriteBuffer & buf, const ISerialization & serialization) const override;
void read(ReadBuffer & buf, const ISerialization & serialization, Arena *) override;
bool isEqualTo(const IColumn & column, size_t row_num) const override;
bool isEqualTo(const SingleValueDataBase & other) const override;
void set(const IColumn & column, size_t row_num, Arena *) override;
void set(const SingleValueDataBase & other, Arena *) override;
bool setIfSmaller(const IColumn & column, size_t row_num, Arena * arena) override;
bool setIfSmaller(const SingleValueDataBase & other, Arena *) override;
bool setIfGreater(const IColumn & column, size_t row_num, Arena * arena) override;
bool setIfGreater(const SingleValueDataBase & other, Arena *) override;
static bool allocatesMemoryInArena() { return false; }
};
static_assert(sizeof(SingleValueDataGeneric) <= SingleValueDataBase::MAX_STORAGE_SIZE, "Incorrect size of SingleValueDataGeneric struct");
/// min, max, any, anyLast, anyHeavy, etc...
template <template <typename, bool...> class AggregateFunctionTemplate, bool unary, bool... isMin>
static IAggregateFunction *
createAggregateFunctionSingleValue(const String & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
assertNoParameters(name, parameters);
if constexpr (unary)
assertUnary(name, argument_types);
else
assertBinary(name, argument_types);
const DataTypePtr & value_type = unary ? argument_types[0] : argument_types[1];
WhichDataType which(value_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionTemplate<SingleValueDataFixed<TYPE>, isMin...>(argument_types); /// NOLINT
FOR_SINGLE_VALUE_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Date)
return new AggregateFunctionTemplate<SingleValueDataFixed<DataTypeDate::FieldType>, isMin...>(argument_types);
if (which.idx == TypeIndex::DateTime)
return new AggregateFunctionTemplate<SingleValueDataFixed<DataTypeDateTime::FieldType>, isMin...>(argument_types);
if (which.idx == TypeIndex::String)
return new AggregateFunctionTemplate<SingleValueDataString, isMin...>(argument_types);
return new AggregateFunctionTemplate<SingleValueDataGeneric, isMin...>(argument_types);
}
/// Helper to allocate enough memory to store any derived class
struct SingleValueDataBaseMemoryBlock
{
std::aligned_union_t<
SingleValueDataBase::MAX_STORAGE_SIZE,
SingleValueDataNumeric<Decimal256>, /// We check all types in generateSingleValueFromTypeIndex
SingleValueDataString,
SingleValueDataGeneric>
memory;
SingleValueDataBase & get() { return *reinterpret_cast<SingleValueDataBase *>(&memory); }
const SingleValueDataBase & get() const { return *reinterpret_cast<const SingleValueDataBase *>(&memory); }
};
static_assert(alignof(SingleValueDataBaseMemoryBlock) == 8);
/// For Data classes that want to compose on top of SingleValueDataBase values, like argMax or singleValueOrNull
/// It will build the object based on the type idx on the memory block provided
void generateSingleValueFromTypeIndex(TypeIndex idx, SingleValueDataBaseMemoryBlock & data);
bool singleValueTypeAllocatesMemoryInArena(TypeIndex idx);
}

View File

@ -39,9 +39,11 @@ void registerAggregateFunctionsQuantileApprox(AggregateFunctionFactory &);
void registerAggregateFunctionsSequenceMatch(AggregateFunctionFactory &);
void registerAggregateFunctionWindowFunnel(AggregateFunctionFactory &);
void registerAggregateFunctionRate(AggregateFunctionFactory &);
void registerAggregateFunctionsMin(AggregateFunctionFactory &);
void registerAggregateFunctionsMax(AggregateFunctionFactory &);
void registerAggregateFunctionsMinMax(AggregateFunctionFactory &);
void registerAggregateFunctionsArgMinArgMax(AggregateFunctionFactory &);
void registerAggregateFunctionsAny(AggregateFunctionFactory &);
void registerAggregateFunctionAnyHeavy(AggregateFunctionFactory &);
void registerAggregateFunctionsAnyRespectNulls(AggregateFunctionFactory &);
void registerAggregateFunctionsStatisticsStable(AggregateFunctionFactory &);
void registerAggregateFunctionsStatisticsSecondMoment(AggregateFunctionFactory &);
void registerAggregateFunctionsStatisticsThirdMoment(AggregateFunctionFactory &);
@ -99,7 +101,7 @@ void registerAggregateFunctionCombinatorOrFill(AggregateFunctionCombinatorFactor
void registerAggregateFunctionCombinatorResample(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorDistinct(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorMap(AggregateFunctionCombinatorFactory & factory);
void registerAggregateFunctionCombinatorMinMax(AggregateFunctionCombinatorFactory & factory);
void registerAggregateFunctionCombinatorsArgMinArgMax(AggregateFunctionCombinatorFactory & factory);
void registerWindowFunctions(AggregateFunctionFactory & factory);
@ -138,9 +140,11 @@ void registerAggregateFunctions()
registerAggregateFunctionsSequenceMatch(factory);
registerAggregateFunctionWindowFunnel(factory);
registerAggregateFunctionRate(factory);
registerAggregateFunctionsMin(factory);
registerAggregateFunctionsMax(factory);
registerAggregateFunctionsMinMax(factory);
registerAggregateFunctionsArgMinArgMax(factory);
registerAggregateFunctionsAny(factory);
registerAggregateFunctionAnyHeavy(factory);
registerAggregateFunctionsAnyRespectNulls(factory);
registerAggregateFunctionsStatisticsStable(factory);
registerAggregateFunctionsStatisticsSecondMoment(factory);
registerAggregateFunctionsStatisticsThirdMoment(factory);
@ -203,7 +207,7 @@ void registerAggregateFunctions()
registerAggregateFunctionCombinatorResample(factory);
registerAggregateFunctionCombinatorDistinct(factory);
registerAggregateFunctionCombinatorMap(factory);
registerAggregateFunctionCombinatorMinMax(factory);
registerAggregateFunctionCombinatorsArgMinArgMax(factory);
}
}

View File

@ -20,6 +20,7 @@
#include <Common/TargetSpecific.h>
#include <Common/WeakHash.h>
#include <Common/assert_cast.h>
#include <Common/findExtreme.h>
#include <Common/iota.h>
#include <bit>
@ -248,6 +249,26 @@ void ColumnVector<T>::getPermutation(IColumn::PermutationSortDirection direction
iota(res.data(), data_size, IColumn::Permutation::value_type(0));
if constexpr (has_find_extreme_implementation<T> && !std::is_floating_point_v<T>)
{
/// Disabled for:floating point
/// * floating point: We don't deal with nan_direction_hint
/// * stability::Stable: We might return any value, not the first
if ((limit == 1) && (stability == IColumn::PermutationSortStability::Unstable))
{
std::optional<size_t> index;
if (direction == IColumn::PermutationSortDirection::Ascending)
index = findExtremeMinIndex(data.data(), 0, data.size());
else
index = findExtremeMaxIndex(data.data(), 0, data.size());
if (index)
{
res.data()[0] = *index;
return;
}
}
}
if constexpr (is_arithmetic_v<T> && !is_big_int_v<T>)
{
if (!limit)

View File

@ -2,23 +2,28 @@
#include <Common/TargetSpecific.h>
#include <Common/findExtreme.h>
#include <limits>
#include <type_traits>
namespace DB
{
template <is_any_native_number T>
template <has_find_extreme_implementation T>
struct MinComparator
{
static ALWAYS_INLINE inline const T & cmp(const T & a, const T & b) { return std::min(a, b); }
};
template <is_any_native_number T>
template <has_find_extreme_implementation T>
struct MaxComparator
{
static ALWAYS_INLINE inline const T & cmp(const T & a, const T & b) { return std::max(a, b); }
};
MULTITARGET_FUNCTION_AVX2_SSE42(
MULTITARGET_FUNCTION_HEADER(template <is_any_native_number T, typename ComparatorClass, bool add_all_elements, bool add_if_cond_zero> static std::optional<T> NO_INLINE),
MULTITARGET_FUNCTION_HEADER(
template <has_find_extreme_implementation T, typename ComparatorClass, bool add_all_elements, bool add_if_cond_zero>
static std::optional<T> NO_INLINE),
findExtremeImpl,
MULTITARGET_FUNCTION_BODY((const T * __restrict ptr, const UInt8 * __restrict condition_map [[maybe_unused]], size_t row_begin, size_t row_end) /// NOLINT
{
@ -65,24 +70,57 @@ MULTITARGET_FUNCTION_AVX2_SSE42(
for (size_t unroll_it = 0; unroll_it < unroll_block; unroll_it++)
ret = ComparatorClass::cmp(ret, partial_min[unroll_it]);
}
}
for (; i < count; i++)
for (; i < count; i++)
{
if (add_all_elements || !condition_map[i] == add_if_cond_zero)
ret = ComparatorClass::cmp(ret, ptr[i]);
}
return ret;
}
else
{
if (add_all_elements || !condition_map[i] == add_if_cond_zero)
ret = ComparatorClass::cmp(ret, ptr[i]);
/// Only native integers
for (; i < count; i++)
{
constexpr bool is_min = std::same_as<ComparatorClass, MinComparator<T>>;
if constexpr (add_all_elements)
{
ret = ComparatorClass::cmp(ret, ptr[i]);
}
else if constexpr (is_min)
{
/// keep_number will be 0 or 1
bool keep_number = !condition_map[i] == add_if_cond_zero;
/// If keep_number = ptr[i] * 1 + 0 * max = ptr[i]
/// If not keep_number = ptr[i] * 0 + 1 * max = max
T final = ptr[i] * T{keep_number} + T{!keep_number} * std::numeric_limits<T>::max();
ret = ComparatorClass::cmp(ret, final);
}
else
{
static_assert(std::same_as<ComparatorClass, MaxComparator<T>>);
/// keep_number will be 0 or 1
bool keep_number = !condition_map[i] == add_if_cond_zero;
/// If keep_number = ptr[i] * 1 + 0 * lowest = ptr[i]
/// If not keep_number = ptr[i] * 0 + 1 * lowest = lowest
T final = ptr[i] * T{keep_number} + T{!keep_number} * std::numeric_limits<T>::lowest();
ret = ComparatorClass::cmp(ret, final);
}
}
return ret;
}
return ret;
}
))
/// Given a vector of T finds the extreme (MIN or MAX) value
template <is_any_native_number T, class ComparatorClass, bool add_all_elements, bool add_if_cond_zero>
template <has_find_extreme_implementation T, class ComparatorClass, bool add_all_elements, bool add_if_cond_zero>
static std::optional<T>
findExtreme(const T * __restrict ptr, const UInt8 * __restrict condition_map [[maybe_unused]], size_t start, size_t end)
{
#if USE_MULTITARGET_CODE
/// In some cases the compiler if able to apply the condition and still generate SIMD, so we still build both
/// conditional and unconditional functions with multiple architectures
/// We see no benefit from using AVX512BW or AVX512F (over AVX2), so we only declare SSE and AVX2
if (isArchSupported(TargetArch::AVX2))
return findExtremeImplAVX2<T, ComparatorClass, add_all_elements, add_if_cond_zero>(ptr, condition_map, start, end);
@ -93,50 +131,90 @@ findExtreme(const T * __restrict ptr, const UInt8 * __restrict condition_map [[m
return findExtremeImpl<T, ComparatorClass, add_all_elements, add_if_cond_zero>(ptr, condition_map, start, end);
}
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMin(const T * __restrict ptr, size_t start, size_t end)
{
return findExtreme<T, MinComparator<T>, true, false>(ptr, nullptr, start, end);
}
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMinNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end)
{
return findExtreme<T, MinComparator<T>, false, true>(ptr, condition_map, start, end);
}
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMinIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end)
{
return findExtreme<T, MinComparator<T>, false, false>(ptr, condition_map, start, end);
}
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMax(const T * __restrict ptr, size_t start, size_t end)
{
return findExtreme<T, MaxComparator<T>, true, false>(ptr, nullptr, start, end);
}
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMaxNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end)
{
return findExtreme<T, MaxComparator<T>, false, true>(ptr, condition_map, start, end);
}
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMaxIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end)
{
return findExtreme<T, MaxComparator<T>, false, false>(ptr, condition_map, start, end);
}
template <has_find_extreme_implementation T>
std::optional<size_t> findExtremeMinIndex(const T * __restrict ptr, size_t start, size_t end)
{
/// This is implemented based on findNumericExtreme and not the other way around (or independently) because getting
/// the MIN or MAX value of an array is possible with SIMD, but getting the index isn't.
/// So what we do is use SIMD to find the lowest value and then iterate again over the array to find its position
std::optional<T> opt = findExtremeMin(ptr, start, end);
if (!opt)
return std::nullopt;
/// Some minimal heuristics for the case the input is sorted
if (*opt == ptr[start])
return {start};
for (size_t i = end - 1; i > start; i--)
if (ptr[i] == *opt)
return {i};
return std::nullopt;
}
template <has_find_extreme_implementation T>
std::optional<size_t> findExtremeMaxIndex(const T * __restrict ptr, size_t start, size_t end)
{
std::optional<T> opt = findExtremeMax(ptr, start, end);
if (!opt)
return std::nullopt;
/// Some minimal heuristics for the case the input is sorted
if (*opt == ptr[start])
return {start};
for (size_t i = end - 1; i > start; i--)
if (ptr[i] == *opt)
return {i};
return std::nullopt;
}
#define INSTANTIATION(T) \
template std::optional<T> findExtremeMin(const T * __restrict ptr, size_t start, size_t end); \
template std::optional<T> findExtremeMinNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<T> findExtremeMinIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<T> findExtremeMinNotNull( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<T> findExtremeMinIf( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<T> findExtremeMax(const T * __restrict ptr, size_t start, size_t end); \
template std::optional<T> findExtremeMaxNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<T> findExtremeMaxIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end);
template std::optional<T> findExtremeMaxNotNull( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<T> findExtremeMaxIf( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<size_t> findExtremeMinIndex(const T * __restrict ptr, size_t start, size_t end); \
template std::optional<size_t> findExtremeMaxIndex(const T * __restrict ptr, size_t start, size_t end);
FOR_BASIC_NUMERIC_TYPES(INSTANTIATION)
#undef INSTANTIATION

View File

@ -11,35 +11,47 @@
namespace DB
{
template <typename T>
concept is_any_native_number = (is_any_of<T, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64>);
concept has_find_extreme_implementation = (is_any_of<T, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64>);
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMin(const T * __restrict ptr, size_t start, size_t end);
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMinNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end);
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMinIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end);
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMax(const T * __restrict ptr, size_t start, size_t end);
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMaxNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end);
template <is_any_native_number T>
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMaxIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end);
template <has_find_extreme_implementation T>
std::optional<size_t> findExtremeMinIndex(const T * __restrict ptr, size_t start, size_t end);
template <has_find_extreme_implementation T>
std::optional<size_t> findExtremeMaxIndex(const T * __restrict ptr, size_t start, size_t end);
#define EXTERN_INSTANTIATION(T) \
extern template std::optional<T> findExtremeMin(const T * __restrict ptr, size_t start, size_t end); \
extern template std::optional<T> findExtremeMinNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<T> findExtremeMinIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<T> findExtremeMinNotNull( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<T> findExtremeMinIf( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<T> findExtremeMax(const T * __restrict ptr, size_t start, size_t end); \
extern template std::optional<T> findExtremeMaxNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<T> findExtremeMaxIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end);
extern template std::optional<T> findExtremeMaxNotNull( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<T> findExtremeMaxIf( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<size_t> findExtremeMinIndex(const T * __restrict ptr, size_t start, size_t end); \
extern template std::optional<size_t> findExtremeMaxIndex(const T * __restrict ptr, size_t start, size_t end);
FOR_BASIC_NUMERIC_TYPES(EXTERN_INSTANTIATION)
FOR_BASIC_NUMERIC_TYPES(EXTERN_INSTANTIATION)
#undef EXTERN_INSTANTIATION
}

View File

@ -0,0 +1,24 @@
<test>
<substitutions>
<substitution>
<name>group_scale</name>
<values>
<value>1000000</value>
</values>
</substitution>
</substitutions>
<query>select argMin(Title, EventTime) from hits_100m_single where Title != '' group by intHash32(UserID) % {group_scale} FORMAT Null</query>
<query>select argMinIf(Title, EventTime, Title != '') from hits_100m_single group by intHash32(UserID) % {group_scale} FORMAT Null</query>
<query>select argMinIf(Title::Nullable(String), EventTime::Nullable(DateTime), Title::Nullable(String) != '') from hits_100m_single group by intHash32(UserID) % {group_scale} FORMAT Null</query>
<query>select argMin(RegionID, EventTime) from hits_100m_single where Title != '' group by intHash32(UserID) % {group_scale} FORMAT Null</query>
<query>select argMin((Title, RegionID), EventTime) from hits_100m_single where Title != '' group by intHash32(UserID) % {group_scale} FORMAT Null</query>
<query>select argMinIf(Title, EventTime, Title != '') from hits_100m_single group by intHash32(UserID) % {group_scale} FORMAT Null</query>
<query>select argMax(WatchID, Age) from hits_100m_single FORMAT Null</query>
<query>select argMax(WatchID, Age::Nullable(UInt8)) from hits_100m_single FORMAT Null</query>
<query>select argMax(WatchID, (EventDate, EventTime)) from hits_100m_single where Title != '' group by intHash32(UserID) % {group_scale} FORMAT Null</query>
<query>select argMax(MobilePhone, MobilePhoneModel) from hits_100m_single</query>
</test>

View File

@ -30,7 +30,7 @@
</create_query>
<create_query>
CREATE TABLE jit_test_merge_tree_nullable (
CREATE TABLE jit_test_memory_nullable (
key UInt64,
value_1 Nullable(UInt64),
value_2 Nullable(UInt64),
@ -42,7 +42,7 @@
</create_query>
<create_query>
CREATE TABLE jit_test_memory_nullable (
CREATE TABLE jit_test_merge_tree_nullable (
key UInt64,
value_1 Nullable(UInt64),
value_2 Nullable(UInt64),

View File

@ -1,4 +1,5 @@
<test>
<query>SELECT number AS n FROM numbers_mt(200000000) ORDER BY n DESC LIMIT 1 FORMAT Null</query>
<query>SELECT number AS n FROM numbers_mt(200000000) ORDER BY n DESC LIMIT 10 FORMAT Null</query>
<query>SELECT number AS n FROM numbers_mt(200000000) ORDER BY n DESC LIMIT 100 FORMAT Null</query>
<query>SELECT number AS n FROM numbers_mt(200000000) ORDER BY n DESC LIMIT 1500 FORMAT Null</query>
@ -7,6 +8,7 @@
<query>SELECT number AS n FROM numbers_mt(200000000) ORDER BY n DESC LIMIT 10000 FORMAT Null</query>
<query>SELECT number AS n FROM numbers_mt(200000000) ORDER BY n DESC LIMIT 65535 FORMAT Null</query>
<query>SELECT intHash64(number) AS n FROM numbers_mt(500000000) ORDER BY n LIMIT 1 FORMAT Null</query>
<query>SELECT intHash64(number) AS n FROM numbers_mt(500000000) ORDER BY n LIMIT 10 FORMAT Null</query>
<query>SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n LIMIT 100 FORMAT Null</query>
<query>SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n LIMIT 1500 FORMAT Null</query>
@ -15,6 +17,7 @@
<query>SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n LIMIT 10000 FORMAT Null</query>
<query>SELECT intHash64(number) AS n FROM numbers_mt(100000000) ORDER BY n LIMIT 65535 FORMAT Null</query>
<query>SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n, n + 1, n + 2 LIMIT 1 FORMAT Null</query>
<query>SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n, n + 1, n + 2 LIMIT 10 FORMAT Null</query>
<query>SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n, n + 1, n + 2 LIMIT 100 FORMAT Null</query>
<query>SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n, n + 1, n + 2 LIMIT 1500 FORMAT Null</query>

View File

@ -5,4 +5,12 @@ select argMin(x.1, x.2), argMax(x.1, x.2) from (select (toDate(number, 'UTC'), t
select argMin(x.1, x.2), argMax(x.1, x.2) from (select (toDecimal32(number, 2), toDecimal64(number, 2) + 1) as x from numbers(10));
-- array
SELECT argMinArray(id, num), argMaxArray(id, num) FROM (SELECT arrayJoin([[10, 4, 3], [7, 5, 6], [8, 8, 2]]) AS num, arrayJoin([[1, 2, 4], [2, 3, 3]]) AS id);
SELECT
argMinArray(id, num),
argMaxArray(id, num)
FROM
(
SELECT
arrayJoin([[10, 4, 3], [7, 5, 6], [8, 8, 2]]) AS num,
arrayJoin([[1, 2, 4]]) AS id
)

View File

@ -1 +1 @@
SELECT argMinArray(id, num), argMaxArray(id, num) FROM (SELECT arrayJoin([[10, 4, 3], [7, 5, 6], [8, 8, 2]]) AS num, arrayJoin([[1, 2, 4], [2, 3, 3]]) AS id)
SELECT argMinArray(id, num), argMaxArray(id, num) FROM (SELECT arrayJoin([[10, 4, 3], [7, 5, 6], [8, 8, 2]]) AS num, arrayJoin([[1, 2, 4]]) AS id)

View File

@ -56,6 +56,10 @@ SELECT min(n::Nullable(String)) from (Select if(number < 15 and number % 2 == 1,
22
SELECT max(n::Nullable(String)) from (Select if(number < 15 and number % 2 == 1, number * 2, NULL) as n from numbers(10, 20));
26
SELECT max(number) from (Select if(number % 2 == 1, NULL, -number::Int8) as number FROM numbers(128));
0
SELECT min(number) from (Select if(number % 2 == 1, NULL, -number::Int8) as number FROM numbers(128));
-126
SELECT argMax(number, now()) FROM (Select number as number from numbers(10, 10000)) settings max_threads=1, max_block_size=100;
10
SELECT argMax(number, now()) FROM (Select number as number from numbers(10, 10000)) settings max_threads=1, max_block_size=20000;
@ -190,3 +194,7 @@ SELECT min(n::Nullable(String)) from (Select if(number < 15 and number % 2 == 1,
22
SELECT max(n::Nullable(String)) from (Select if(number < 15 and number % 2 == 1, number * 2, NULL) as n from numbers(10, 20));
26
SELECT max(number::Nullable(Decimal64(3))) from numbers(11) settings max_block_size=10;
10
SELECT min(-number::Nullable(Decimal64(3))) from numbers(11) settings max_block_size=10;
-10

View File

@ -48,6 +48,9 @@ SELECT maxIf(number::Nullable(String), number < 10) as number from numbers(10, 1
SELECT min(n::Nullable(String)) from (Select if(number < 15 and number % 2 == 1, number * 2, NULL) as n from numbers(10, 20));
SELECT max(n::Nullable(String)) from (Select if(number < 15 and number % 2 == 1, number * 2, NULL) as n from numbers(10, 20));
SELECT max(number) from (Select if(number % 2 == 1, NULL, -number::Int8) as number FROM numbers(128));
SELECT min(number) from (Select if(number % 2 == 1, NULL, -number::Int8) as number FROM numbers(128));
SELECT argMax(number, now()) FROM (Select number as number from numbers(10, 10000)) settings max_threads=1, max_block_size=100;
SELECT argMax(number, now()) FROM (Select number as number from numbers(10, 10000)) settings max_threads=1, max_block_size=20000;
SELECT argMax(number, 1) FROM (Select number as number from numbers(10, 10000)) settings max_threads=1, max_block_size=100;
@ -138,3 +141,6 @@ SELECT maxIf(number::Nullable(String), number < 10) as number from numbers(10, 1
SELECT min(n::Nullable(String)) from (Select if(number < 15 and number % 2 == 1, number * 2, NULL) as n from numbers(10, 20));
SELECT max(n::Nullable(String)) from (Select if(number < 15 and number % 2 == 1, number * 2, NULL) as n from numbers(10, 20));
SELECT max(number::Nullable(Decimal64(3))) from numbers(11) settings max_block_size=10;
SELECT min(-number::Nullable(Decimal64(3))) from numbers(11) settings max_block_size=10;

View File

@ -42,12 +42,12 @@ ORDER BY event_time_microseconds;
-- 1 * 8 + AggregateFunction(argMax, String, DateTime)
--
-- Size of AggregateFunction(argMax, String, DateTime):
-- SingleValueDataString() + SingleValueDataFixed(DateTime)
-- SingleValueDataString = 64B for small strings, 64B + string size + 1 for larger
-- SingleValueDataFixed(DateTime) = 1 + 4. With padding = 8
-- SingleValueDataString Total: 72B
-- 1 Base class + 1 specific/value class:
-- Base class: MAX(sizeOf(SingleValueDataFixed<T>), sizeOf(SingleValueDataString), sizeOf(SingleValueDataGeneric)) = 64
-- Specific class: SingleValueDataFixed(DateTime) = 4 + 1. With padding = 8
-- Total: 8 + 64 + 8 = 80
--
-- ColumnAggregateFunction total: 8 + 72 = 80
-- ColumnAggregateFunction total: 8 + 2 * 64 = 136
SELECT 'AggregateFunction(argMax, String, DateTime)',
read_rows,
read_bytes

View File

@ -1,3 +1,4 @@
200 295
200 245
200 290
999

View File

@ -1,3 +1,10 @@
select sumArgMin(number, number % 20), sumArgMax(number, number % 20) from numbers(100);
select sumArgMin(number, toString(number % 20)), sumArgMax(number, toString(number % 20)) from numbers(100);
select sumArgMinIf(number, number % 20, number % 2 = 0), sumArgMaxIf(number, number % 20, number % 2 = 0) from numbers(100);
select sumArgMin() from numbers(100); -- {serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH}
select sumArgMin(number) from numbers(100); -- {serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH}
-- Try passing a non comparable type, for example an AggregationState
select sumArgMin(number, unhex('0000000000000000')::AggregateFunction(sum, UInt64)) from numbers(100); -- {serverError ILLEGAL_TYPE_OF_ARGUMENT}
-- ASAN (data leak)
SELECT sumArgMax(number, tuple(number, repeat('a', (10 * (number % 100))::Int32))) FROM numbers(1000);

View File

@ -0,0 +1,8 @@
-- When we use SingleValueDataBaseMemoryBlock we must ensure we call the class destructor on destroy
Select argMax((number, number), (number, number)) FROM numbers(100000) format Null;
Select argMin((number, number), (number, number)) FROM numbers(100000) format Null;
Select anyHeavy((number, number)) FROM numbers(100000) format Null;
Select singleValueOrNull(number::Date32) FROM numbers(100000) format Null;
Select anyArgMax(number, (number, number)) FROM numbers(100000) format Null;
Select anyArgMin(number, (number, number)) FROM numbers(100000) format Null;

View File

@ -4,8 +4,6 @@
TU_EXCLUDES=(
CastOverloadResolver
AggregateFunctionMax
AggregateFunctionMin
AggregateFunctionUniq
FunctionsConversion