This commit is contained in:
Pavel Kartavyy 2015-12-04 11:54:09 +03:00
commit 703b0f285a
316 changed files with 18350 additions and 4708 deletions

View File

@ -24,17 +24,17 @@ private:
public: public:
AggregateFunctionArray(AggregateFunctionPtr nested_) : nested_func_owner(nested_), nested_func(nested_func_owner.get()) {} AggregateFunctionArray(AggregateFunctionPtr nested_) : nested_func_owner(nested_), nested_func(nested_func_owner.get()) {}
String getName() const String getName() const override
{ {
return nested_func->getName() + "Array"; return nested_func->getName() + "Array";
} }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return nested_func->getReturnType(); return nested_func->getReturnType();
} }
void setArguments(const DataTypes & arguments) void setArguments(const DataTypes & arguments) override
{ {
num_agruments = arguments.size(); num_agruments = arguments.size();
@ -49,37 +49,37 @@ public:
nested_func->setArguments(nested_arguments); nested_func->setArguments(nested_arguments);
} }
void setParameters(const Array & params) void setParameters(const Array & params) override
{ {
nested_func->setParameters(params); nested_func->setParameters(params);
} }
void create(AggregateDataPtr place) const void create(AggregateDataPtr place) const override
{ {
nested_func->create(place); nested_func->create(place);
} }
void destroy(AggregateDataPtr place) const noexcept void destroy(AggregateDataPtr place) const noexcept override
{ {
nested_func->destroy(place); nested_func->destroy(place);
} }
bool hasTrivialDestructor() const bool hasTrivialDestructor() const override
{ {
return nested_func->hasTrivialDestructor(); return nested_func->hasTrivialDestructor();
} }
size_t sizeOfData() const size_t sizeOfData() const override
{ {
return nested_func->sizeOfData(); return nested_func->sizeOfData();
} }
size_t alignOfData() const size_t alignOfData() const override
{ {
return nested_func->alignOfData(); return nested_func->alignOfData();
} }
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num) const void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num) const override
{ {
const IColumn * nested[num_agruments]; const IColumn * nested[num_agruments];
@ -96,25 +96,32 @@ public:
nested_func->add(place, nested, i); nested_func->add(place, nested, i);
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
nested_func->merge(place, rhs); nested_func->merge(place, rhs);
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
nested_func->serialize(place, buf); nested_func->serialize(place, buf);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
nested_func->deserializeMerge(place, buf); nested_func->deserializeMerge(place, buf);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
nested_func->insertResultInto(place, to); nested_func->insertResultInto(place, to);
} }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num)
{
return static_cast<const AggregateFunctionArray &>(*that).add(place, columns, row_num);
}
IAggregateFunction::AddFunc getAddressOfAddFunction() const override final { return &addFree; }
}; };
} }

View File

@ -27,9 +27,9 @@ template <typename T>
class AggregateFunctionAvg final : public IUnaryAggregateFunction<AggregateFunctionAvgData<typename NearestFieldType<T>::Type>, AggregateFunctionAvg<T> > class AggregateFunctionAvg final : public IUnaryAggregateFunction<AggregateFunctionAvgData<typename NearestFieldType<T>::Type>, AggregateFunctionAvg<T> >
{ {
public: public:
String getName() const { return "avg"; } String getName() const override { return "avg"; }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return new DataTypeFloat64; return new DataTypeFloat64;
} }
@ -42,25 +42,25 @@ public:
} }
void addOne(AggregateDataPtr place, const IColumn & column, size_t row_num) const void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const
{ {
this->data(place).sum += static_cast<const ColumnVector<T> &>(column).getData()[row_num]; this->data(place).sum += static_cast<const ColumnVector<T> &>(column).getData()[row_num];
++this->data(place).count; ++this->data(place).count;
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
this->data(place).sum += this->data(rhs).sum; this->data(place).sum += this->data(rhs).sum;
this->data(place).count += this->data(rhs).count; this->data(place).count += this->data(rhs).count;
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
writeBinary(this->data(place).sum, buf); writeBinary(this->data(place).sum, buf);
writeVarUInt(this->data(place).count, buf); writeVarUInt(this->data(place).count, buf);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
typename NearestFieldType<T>::Type tmp_sum = 0; typename NearestFieldType<T>::Type tmp_sum = 0;
UInt64 tmp_count = 0; UInt64 tmp_count = 0;
@ -70,7 +70,7 @@ public:
this->data(place).count += tmp_count; this->data(place).count += tmp_count;
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
static_cast<ColumnFloat64 &>(to).getData().push_back( static_cast<ColumnFloat64 &>(to).getData().push_back(
static_cast<Float64>(this->data(place).sum) / this->data(place).count); static_cast<Float64>(this->data(place).sum) / this->data(place).count);

View File

@ -23,37 +23,37 @@ struct AggregateFunctionCountData
class AggregateFunctionCount final : public INullaryAggregateFunction<AggregateFunctionCountData, AggregateFunctionCount> class AggregateFunctionCount final : public INullaryAggregateFunction<AggregateFunctionCountData, AggregateFunctionCount>
{ {
public: public:
String getName() const { return "count"; } String getName() const override { return "count"; }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return new DataTypeUInt64; return new DataTypeUInt64;
} }
void addZero(AggregateDataPtr place) const void addImpl(AggregateDataPtr place) const
{ {
++data(place).count; ++data(place).count;
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
data(place).count += data(rhs).count; data(place).count += data(rhs).count;
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
writeVarUInt(data(place).count, buf); writeVarUInt(data(place).count, buf);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
UInt64 tmp; UInt64 tmp;
readVarUInt(tmp, buf); readVarUInt(tmp, buf);
data(place).count += tmp; data(place).count += tmp;
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
static_cast<ColumnUInt64 &>(to).getData().push_back(data(place).count); static_cast<ColumnUInt64 &>(to).getData().push_back(data(place).count);
} }

View File

@ -4,6 +4,10 @@
#include <DB/IO/ReadHelpers.h> #include <DB/IO/ReadHelpers.h>
#include <DB/DataTypes/DataTypeArray.h> #include <DB/DataTypes/DataTypeArray.h>
#include <DB/DataTypes/DataTypesNumberFixed.h>
#include <DB/Columns/ColumnVector.h>
#include <DB/Columns/ColumnArray.h>
#include <DB/AggregateFunctions/IUnaryAggregateFunction.h> #include <DB/AggregateFunctions/IUnaryAggregateFunction.h>
@ -13,22 +17,102 @@
namespace DB namespace DB
{ {
struct AggregateFunctionGroupArrayData
/// Частный случай - реализация для числовых типов.
template <typename T>
struct AggregateFunctionGroupArrayDataNumeric
{
/// Сразу будет выделена память на несколько элементов так, чтобы состояние занимало 64 байта.
static constexpr size_t bytes_in_arena = 64 - sizeof(PODArray<T>);
using Array = PODArray<T, bytes_in_arena / sizeof(T), AllocatorWithStackMemory<Allocator<false>, bytes_in_arena>>;
Array value;
};
/// Общий случай (неэффективно). NOTE Можно ещё реализовать частный случай для строк.
struct AggregateFunctionGroupArrayDataGeneric
{ {
Array value; /// TODO Добавить MemoryTracker Array value; /// TODO Добавить MemoryTracker
}; };
/// Складывает все значения в массив. Реализовано неэффективно. template <typename T>
class AggregateFunctionGroupArray final : public IUnaryAggregateFunction<AggregateFunctionGroupArrayData, AggregateFunctionGroupArray> class AggregateFunctionGroupArrayNumeric final
: public IUnaryAggregateFunction<AggregateFunctionGroupArrayDataNumeric<T>, AggregateFunctionGroupArrayNumeric<T>>
{
public:
String getName() const override { return "groupArray"; }
DataTypePtr getReturnType() const override
{
return new DataTypeArray(new typename DataTypeFromFieldType<T>::Type);
}
void setArgument(const DataTypePtr & argument)
{
}
void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const
{
this->data(place).value.push_back(static_cast<const ColumnVector<T> &>(column).getData()[row_num]);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{
this->data(place).value.insert(this->data(rhs).value.begin(), this->data(rhs).value.end());
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
const auto & value = this->data(place).value;
size_t size = value.size();
writeVarUInt(size, buf);
buf.write(reinterpret_cast<const char *>(&value[0]), size * sizeof(value[0]));
}
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{
size_t size = 0;
readVarUInt(size, buf);
if (size > AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE)
throw Exception("Too large array size", ErrorCodes::TOO_LARGE_ARRAY_SIZE);
auto & value = this->data(place).value;
size_t old_size = value.size();
value.resize(old_size + size);
buf.read(reinterpret_cast<char *>(&value[old_size]), size * sizeof(value[0]));
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
const auto & value = this->data(place).value;
size_t size = value.size();
ColumnArray & arr_to = static_cast<ColumnArray &>(to);
ColumnArray::Offsets_t & offsets_to = arr_to.getOffsets();
offsets_to.push_back((offsets_to.size() == 0 ? 0 : offsets_to.back()) + size);
typename ColumnVector<T>::Container_t & data_to = static_cast<ColumnVector<T> &>(arr_to.getData()).getData();
data_to.insert(this->data(place).value.begin(), this->data(place).value.end());
}
};
/// Складывает все значения в массив, общий случай. Реализовано неэффективно.
class AggregateFunctionGroupArrayGeneric final
: public IUnaryAggregateFunction<AggregateFunctionGroupArrayDataGeneric, AggregateFunctionGroupArrayGeneric>
{ {
private: private:
DataTypePtr type; DataTypePtr type;
public: public:
String getName() const { return "groupArray"; } String getName() const override { return "groupArray"; }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return new DataTypeArray(type); return new DataTypeArray(type);
} }
@ -39,18 +123,18 @@ public:
} }
void addOne(AggregateDataPtr place, const IColumn & column, size_t row_num) const void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const
{ {
data(place).value.push_back(Array::value_type()); data(place).value.push_back(Array::value_type());
column.get(row_num, data(place).value.back()); column.get(row_num, data(place).value.back());
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
data(place).value.insert(data(place).value.end(), data(rhs).value.begin(), data(rhs).value.end()); data(place).value.insert(data(place).value.end(), data(rhs).value.begin(), data(rhs).value.end());
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
const Array & value = data(place).value; const Array & value = data(place).value;
size_t size = value.size(); size_t size = value.size();
@ -59,7 +143,7 @@ public:
type->serializeBinary(value[i], buf); type->serializeBinary(value[i], buf);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
size_t size = 0; size_t size = 0;
readVarUInt(size, buf); readVarUInt(size, buf);
@ -75,7 +159,7 @@ public:
type->deserializeBinary(value[old_size + i], buf); type->deserializeBinary(value[old_size + i], buf);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
to.insert(data(place).value); to.insert(data(place).value);
} }

View File

@ -42,9 +42,9 @@ private:
typedef AggregateFunctionGroupUniqArrayData<T> State; typedef AggregateFunctionGroupUniqArrayData<T> State;
public: public:
String getName() const { return "groupUniqArray"; } String getName() const override { return "groupUniqArray"; }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return new DataTypeArray(new typename DataTypeFromFieldType<T>::Type); return new DataTypeArray(new typename DataTypeFromFieldType<T>::Type);
} }
@ -54,17 +54,17 @@ public:
} }
void addOne(AggregateDataPtr place, const IColumn & column, size_t row_num) const void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const
{ {
this->data(place).value.insert(static_cast<const ColumnVector<T> &>(column).getData()[row_num]); this->data(place).value.insert(static_cast<const ColumnVector<T> &>(column).getData()[row_num]);
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
this->data(place).value.merge(this->data(rhs).value); this->data(place).value.merge(this->data(rhs).value);
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
const typename State::Set & set = this->data(place).value; const typename State::Set & set = this->data(place).value;
size_t size = set.size(); size_t size = set.size();
@ -73,12 +73,12 @@ public:
writeIntBinary(*it, buf); writeIntBinary(*it, buf);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
this->data(place).value.readAndMerge(buf); this->data(place).value.readAndMerge(buf);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
ColumnArray & arr_to = static_cast<ColumnArray &>(to); ColumnArray & arr_to = static_cast<ColumnArray &>(to);
ColumnArray::Offsets_t & offsets_to = arr_to.getOffsets(); ColumnArray::Offsets_t & offsets_to = arr_to.getOffsets();
@ -106,7 +106,7 @@ template <typename T>
class AggregateFunctionGroupUniqArrays final : public AggregateFunctionGroupUniqArray<T> class AggregateFunctionGroupUniqArrays final : public AggregateFunctionGroupUniqArray<T>
{ {
public: public:
void addOne(AggregateDataPtr place, const IColumn & column, size_t row_num) const void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const
{ {
const ColumnArray & arr = static_cast<const ColumnArray &>(column); const ColumnArray & arr = static_cast<const ColumnArray &>(column);
const ColumnArray::Offsets_t & offsets = arr.getOffsets(); const ColumnArray::Offsets_t & offsets = arr.getOffsets();

View File

@ -23,17 +23,17 @@ private:
public: public:
AggregateFunctionIf(AggregateFunctionPtr nested_) : nested_func_owner(nested_), nested_func(nested_func_owner.get()) {} AggregateFunctionIf(AggregateFunctionPtr nested_) : nested_func_owner(nested_), nested_func(nested_func_owner.get()) {}
String getName() const String getName() const override
{ {
return nested_func->getName() + "If"; return nested_func->getName() + "If";
} }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return nested_func->getReturnType(); return nested_func->getReturnType();
} }
void setArguments(const DataTypes & arguments) void setArguments(const DataTypes & arguments) override
{ {
num_agruments = arguments.size(); num_agruments = arguments.size();
@ -47,61 +47,68 @@ public:
nested_func->setArguments(nested_arguments); nested_func->setArguments(nested_arguments);
} }
void setParameters(const Array & params) void setParameters(const Array & params) override
{ {
nested_func->setParameters(params); nested_func->setParameters(params);
} }
void create(AggregateDataPtr place) const void create(AggregateDataPtr place) const override
{ {
nested_func->create(place); nested_func->create(place);
} }
void destroy(AggregateDataPtr place) const noexcept void destroy(AggregateDataPtr place) const noexcept override
{ {
nested_func->destroy(place); nested_func->destroy(place);
} }
bool hasTrivialDestructor() const bool hasTrivialDestructor() const override
{ {
return nested_func->hasTrivialDestructor(); return nested_func->hasTrivialDestructor();
} }
size_t sizeOfData() const size_t sizeOfData() const override
{ {
return nested_func->sizeOfData(); return nested_func->sizeOfData();
} }
size_t alignOfData() const size_t alignOfData() const override
{ {
return nested_func->alignOfData(); return nested_func->alignOfData();
} }
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num) const void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num) const override
{ {
if (static_cast<const ColumnUInt8 &>(*columns[num_agruments - 1]).getData()[row_num]) if (static_cast<const ColumnUInt8 &>(*columns[num_agruments - 1]).getData()[row_num])
nested_func->add(place, columns, row_num); nested_func->add(place, columns, row_num);
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
nested_func->merge(place, rhs); nested_func->merge(place, rhs);
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
nested_func->serialize(place, buf); nested_func->serialize(place, buf);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
nested_func->deserializeMerge(place, buf); nested_func->deserializeMerge(place, buf);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
nested_func->insertResultInto(place, to); nested_func->insertResultInto(place, to);
} }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num)
{
return static_cast<const AggregateFunctionIf &>(*that).add(place, columns, row_num);
}
IAggregateFunction::AddFunc getAddressOfAddFunction() const override final { return &addFree; }
}; };
} }

View File

@ -24,17 +24,17 @@ private:
public: public:
AggregateFunctionMerge(AggregateFunctionPtr nested_) : nested_func_owner(nested_), nested_func(nested_func_owner.get()) {} AggregateFunctionMerge(AggregateFunctionPtr nested_) : nested_func_owner(nested_), nested_func(nested_func_owner.get()) {}
String getName() const String getName() const override
{ {
return nested_func->getName() + "Merge"; return nested_func->getName() + "Merge";
} }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return nested_func->getReturnType(); return nested_func->getReturnType();
} }
void setArguments(const DataTypes & arguments) void setArguments(const DataTypes & arguments) override
{ {
if (arguments.size() != 1) if (arguments.size() != 1)
throw Exception("Passed " + toString(arguments.size()) + " arguments to unary aggregate function " + this->getName(), throw Exception("Passed " + toString(arguments.size()) + " arguments to unary aggregate function " + this->getName(),
@ -49,60 +49,67 @@ public:
nested_func->setArguments(data_type->getArgumentsDataTypes()); nested_func->setArguments(data_type->getArgumentsDataTypes());
} }
void setParameters(const Array & params) void setParameters(const Array & params) override
{ {
nested_func->setParameters(params); nested_func->setParameters(params);
} }
void create(AggregateDataPtr place) const void create(AggregateDataPtr place) const override
{ {
nested_func->create(place); nested_func->create(place);
} }
void destroy(AggregateDataPtr place) const noexcept void destroy(AggregateDataPtr place) const noexcept override
{ {
nested_func->destroy(place); nested_func->destroy(place);
} }
bool hasTrivialDestructor() const bool hasTrivialDestructor() const override
{ {
return nested_func->hasTrivialDestructor(); return nested_func->hasTrivialDestructor();
} }
size_t sizeOfData() const size_t sizeOfData() const override
{ {
return nested_func->sizeOfData(); return nested_func->sizeOfData();
} }
size_t alignOfData() const size_t alignOfData() const override
{ {
return nested_func->alignOfData(); return nested_func->alignOfData();
} }
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num) const void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num) const override
{ {
nested_func->merge(place, static_cast<const ColumnAggregateFunction &>(*columns[0]).getData()[row_num]); nested_func->merge(place, static_cast<const ColumnAggregateFunction &>(*columns[0]).getData()[row_num]);
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
nested_func->merge(place, rhs); nested_func->merge(place, rhs);
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
nested_func->serialize(place, buf); nested_func->serialize(place, buf);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
nested_func->deserializeMerge(place, buf); nested_func->deserializeMerge(place, buf);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
nested_func->insertResultInto(place, to); nested_func->insertResultInto(place, to);
} }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num)
{
return static_cast<const AggregateFunctionMerge &>(*that).add(place, columns, row_num);
}
IAggregateFunction::AddFunc getAddressOfAddFunction() const override final { return &addFree; }
}; };
} }

View File

@ -1,6 +1,6 @@
#pragma once #pragma once
#include <stats/ReservoirSampler.h> #include <DB/AggregateFunctions/ReservoirSampler.h>
#include <DB/Core/FieldVisitors.h> #include <DB/Core/FieldVisitors.h>
@ -32,7 +32,8 @@ struct AggregateFunctionQuantileData
* Для дат и дат-с-временем returns_float следует задавать равным false. * Для дат и дат-с-временем returns_float следует задавать равным false.
*/ */
template <typename ArgumentFieldType, bool returns_float = true> template <typename ArgumentFieldType, bool returns_float = true>
class AggregateFunctionQuantile final : public IUnaryAggregateFunction<AggregateFunctionQuantileData<ArgumentFieldType>, AggregateFunctionQuantile<ArgumentFieldType, returns_float> > class AggregateFunctionQuantile final
: public IUnaryAggregateFunction<AggregateFunctionQuantileData<ArgumentFieldType>, AggregateFunctionQuantile<ArgumentFieldType, returns_float> >
{ {
private: private:
using Sample = typename AggregateFunctionQuantileData<ArgumentFieldType>::Sample; using Sample = typename AggregateFunctionQuantileData<ArgumentFieldType>::Sample;
@ -43,9 +44,9 @@ private:
public: public:
AggregateFunctionQuantile(double level_ = 0.5) : level(level_) {} AggregateFunctionQuantile(double level_ = 0.5) : level(level_) {}
String getName() const { return "quantile"; } String getName() const override { return "quantile"; }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return type; return type;
} }
@ -58,7 +59,7 @@ public:
type = argument; type = argument;
} }
void setParameters(const Array & params) void setParameters(const Array & params) override
{ {
if (params.size() != 1) if (params.size() != 1)
throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
@ -67,29 +68,29 @@ public:
} }
void addOne(AggregateDataPtr place, const IColumn & column, size_t row_num) const void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const
{ {
this->data(place).sample.insert(static_cast<const ColumnVector<ArgumentFieldType> &>(column).getData()[row_num]); this->data(place).sample.insert(static_cast<const ColumnVector<ArgumentFieldType> &>(column).getData()[row_num]);
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
this->data(place).sample.merge(this->data(rhs).sample); this->data(place).sample.merge(this->data(rhs).sample);
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
this->data(place).sample.write(buf); this->data(place).sample.write(buf);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
Sample tmp_sample; Sample tmp_sample;
tmp_sample.read(buf); tmp_sample.read(buf);
this->data(place).sample.merge(tmp_sample); this->data(place).sample.merge(tmp_sample);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
/// Sample может отсортироваться при получении квантиля, но в этом контексте можно не считать это нарушением константности. /// Sample может отсортироваться при получении квантиля, но в этом контексте можно не считать это нарушением константности.
Sample & sample = const_cast<Sample &>(this->data(place).sample); Sample & sample = const_cast<Sample &>(this->data(place).sample);
@ -107,19 +108,20 @@ public:
* Возвращает массив результатов. * Возвращает массив результатов.
*/ */
template <typename ArgumentFieldType, bool returns_float = true> template <typename ArgumentFieldType, bool returns_float = true>
class AggregateFunctionQuantiles final : public IUnaryAggregateFunction<AggregateFunctionQuantileData<ArgumentFieldType>, AggregateFunctionQuantiles<ArgumentFieldType, returns_float> > class AggregateFunctionQuantiles final
: public IUnaryAggregateFunction<AggregateFunctionQuantileData<ArgumentFieldType>, AggregateFunctionQuantiles<ArgumentFieldType, returns_float> >
{ {
private: private:
using Sample = typename AggregateFunctionQuantileData<ArgumentFieldType>::Sample; using Sample = typename AggregateFunctionQuantileData<ArgumentFieldType>::Sample;
typedef std::vector<double> Levels; using Levels = std::vector<double>;
Levels levels; Levels levels;
DataTypePtr type; DataTypePtr type;
public: public:
String getName() const { return "quantiles"; } String getName() const override { return "quantiles"; }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return new DataTypeArray(type); return new DataTypeArray(type);
} }
@ -132,7 +134,7 @@ public:
type = argument; type = argument;
} }
void setParameters(const Array & params) void setParameters(const Array & params) override
{ {
if (params.empty()) if (params.empty())
throw Exception("Aggregate function " + getName() + " requires at least one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); throw Exception("Aggregate function " + getName() + " requires at least one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
@ -145,29 +147,29 @@ public:
} }
void addOne(AggregateDataPtr place, const IColumn & column, size_t row_num) const void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const
{ {
this->data(place).sample.insert(static_cast<const ColumnVector<ArgumentFieldType> &>(column).getData()[row_num]); this->data(place).sample.insert(static_cast<const ColumnVector<ArgumentFieldType> &>(column).getData()[row_num]);
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
this->data(place).sample.merge(this->data(rhs).sample); this->data(place).sample.merge(this->data(rhs).sample);
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
this->data(place).sample.write(buf); this->data(place).sample.write(buf);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
Sample tmp_sample; Sample tmp_sample;
tmp_sample.read(buf); tmp_sample.read(buf);
this->data(place).sample.merge(tmp_sample); this->data(place).sample.merge(tmp_sample);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
/// Sample может отсортироваться при получении квантиля, но в этом контексте можно не считать это нарушением константности. /// Sample может отсортироваться при получении квантиля, но в этом контексте можно не считать это нарушением константности.
Sample & sample = const_cast<Sample &>(this->data(place).sample); Sample & sample = const_cast<Sample &>(this->data(place).sample);

View File

@ -46,9 +46,9 @@ private:
public: public:
AggregateFunctionQuantileDeterministic(double level_ = 0.5) : level(level_) {} AggregateFunctionQuantileDeterministic(double level_ = 0.5) : level(level_) {}
String getName() const { return "quantileDeterministic"; } String getName() const override { return "quantileDeterministic"; }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return type; return type;
} }
@ -65,7 +65,7 @@ public:
}; };
} }
void setParameters(const Array & params) void setParameters(const Array & params) override
{ {
if (params.size() != 1) if (params.size() != 1)
throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
@ -74,30 +74,30 @@ public:
} }
void addOne(AggregateDataPtr place, const IColumn & column, const IColumn & determinator, size_t row_num) const void addImpl(AggregateDataPtr place, const IColumn & column, const IColumn & determinator, size_t row_num) const
{ {
this->data(place).sample.insert(static_cast<const ColumnVector<ArgumentFieldType> &>(column).getData()[row_num], this->data(place).sample.insert(static_cast<const ColumnVector<ArgumentFieldType> &>(column).getData()[row_num],
determinator.get64(row_num)); determinator.get64(row_num));
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
this->data(place).sample.merge(this->data(rhs).sample); this->data(place).sample.merge(this->data(rhs).sample);
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
this->data(place).sample.write(buf); this->data(place).sample.write(buf);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
Sample tmp_sample; Sample tmp_sample;
tmp_sample.read(buf); tmp_sample.read(buf);
this->data(place).sample.merge(tmp_sample); this->data(place).sample.merge(tmp_sample);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
/// Sample может отсортироваться при получении квантиля, но в этом контексте можно не считать это нарушением константности. /// Sample может отсортироваться при получении квантиля, но в этом контексте можно не считать это нарушением константности.
Sample & sample = const_cast<Sample &>(this->data(place).sample); Sample & sample = const_cast<Sample &>(this->data(place).sample);
@ -123,14 +123,14 @@ class AggregateFunctionQuantilesDeterministic final
private: private:
using Sample = typename AggregateFunctionQuantileDeterministicData<ArgumentFieldType>::Sample; using Sample = typename AggregateFunctionQuantileDeterministicData<ArgumentFieldType>::Sample;
typedef std::vector<double> Levels; using Levels = std::vector<double>;
Levels levels; Levels levels;
DataTypePtr type; DataTypePtr type;
public: public:
String getName() const { return "quantilesDeterministic"; } String getName() const override { return "quantilesDeterministic"; }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return new DataTypeArray(type); return new DataTypeArray(type);
} }
@ -147,7 +147,7 @@ public:
}; };
} }
void setParameters(const Array & params) void setParameters(const Array & params) override
{ {
if (params.empty()) if (params.empty())
throw Exception("Aggregate function " + getName() + " requires at least one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); throw Exception("Aggregate function " + getName() + " requires at least one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
@ -160,30 +160,30 @@ public:
} }
void addOne(AggregateDataPtr place, const IColumn & column, const IColumn & determinator, size_t row_num) const void addImpl(AggregateDataPtr place, const IColumn & column, const IColumn & determinator, size_t row_num) const
{ {
this->data(place).sample.insert(static_cast<const ColumnVector<ArgumentFieldType> &>(column).getData()[row_num], this->data(place).sample.insert(static_cast<const ColumnVector<ArgumentFieldType> &>(column).getData()[row_num],
determinator.get64(row_num)); determinator.get64(row_num));
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
this->data(place).sample.merge(this->data(rhs).sample); this->data(place).sample.merge(this->data(rhs).sample);
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
this->data(place).sample.write(buf); this->data(place).sample.write(buf);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
Sample tmp_sample; Sample tmp_sample;
tmp_sample.read(buf); tmp_sample.read(buf);
this->data(place).sample.merge(tmp_sample); this->data(place).sample.merge(tmp_sample);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
/// Sample может отсортироваться при получении квантиля, но в этом контексте можно не считать это нарушением константности. /// Sample может отсортироваться при получении квантиля, но в этом контексте можно не считать это нарушением константности.
Sample & sample = const_cast<Sample &>(this->data(place).sample); Sample & sample = const_cast<Sample &>(this->data(place).sample);

View File

@ -0,0 +1,224 @@
#pragma once
#include <DB/Common/PODArray.h>
#include <DB/IO/WriteHelpers.h>
#include <DB/IO/ReadHelpers.h>
#include <DB/DataTypes/DataTypesNumberFixed.h>
#include <DB/DataTypes/DataTypeArray.h>
#include <DB/AggregateFunctions/IUnaryAggregateFunction.h>
#include <DB/AggregateFunctions/QuantilesCommon.h>
#include <DB/Columns/ColumnArray.h>
namespace DB
{
/** В качестве состояния используется массив, в который складываются все значения.
* NOTE Если различных значений мало, то это не оптимально.
* Для 8 и 16-битных значений возможно, было бы лучше использовать lookup-таблицу.
*/
template <typename T>
struct AggregateFunctionQuantileExactData
{
/// Сразу будет выделена память на несколько элементов так, чтобы состояние занимало 64 байта.
static constexpr size_t bytes_in_arena = 64 - sizeof(PODArray<T>);
using Array = PODArray<T, bytes_in_arena / sizeof(T), AllocatorWithStackMemory<Allocator<false>, bytes_in_arena>>;
Array array;
};
/** Точно вычисляет квантиль.
* В качестве типа аргумента может быть только числовой тип (в том числе, дата и дата-с-временем).
* Тип результата совпадает с типом аргумента.
*/
template <typename T>
class AggregateFunctionQuantileExact final
: public IUnaryAggregateFunction<AggregateFunctionQuantileExactData<T>, AggregateFunctionQuantileExact<T>>
{
private:
double level;
DataTypePtr type;
public:
AggregateFunctionQuantileExact(double level_ = 0.5) : level(level_) {}
String getName() const override { return "quantileExact"; }
DataTypePtr getReturnType() const override
{
return type;
}
void setArgument(const DataTypePtr & argument)
{
type = argument;
}
void setParameters(const Array & params) override
{
if (params.size() != 1)
throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
level = apply_visitor(FieldVisitorConvertToNumber<Float64>(), params[0]);
}
void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const
{
this->data(place).array.push_back(static_cast<const ColumnVector<T> &>(column).getData()[row_num]);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{
this->data(place).array.insert(this->data(rhs).array.begin(), this->data(rhs).array.end());
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
const auto & array = this->data(place).array;
size_t size = array.size();
writeVarUInt(size, buf);
buf.write(reinterpret_cast<const char *>(&array[0]), size * sizeof(array[0]));
}
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{
auto & array = this->data(place).array;
size_t size = 0;
readVarUInt(size, buf);
size_t old_size = array.size();
array.resize(old_size + size);
buf.read(reinterpret_cast<char *>(&array[old_size]), size * sizeof(array[0]));
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
/// Сортировка массива не будет считаться нарушением константности.
auto & array = const_cast<typename AggregateFunctionQuantileExactData<T>::Array &>(this->data(place).array);
T quantile = T();
if (!array.empty())
{
size_t n = level < 1
? level * array.size()
: (array.size() - 1);
std::nth_element(array.begin(), array.begin() + n, array.end()); /// NOTE Можно придумать алгоритм radix-select.
quantile = array[n];
}
static_cast<ColumnVector<T> &>(to).getData().push_back(quantile);
}
};
/** То же самое, но позволяет вычислить сразу несколько квантилей.
* Для этого, принимает в качестве параметров несколько уровней. Пример: quantilesExact(0.5, 0.8, 0.9, 0.95)(ConnectTiming).
* Возвращает массив результатов.
*/
template <typename T>
class AggregateFunctionQuantilesExact final
: public IUnaryAggregateFunction<AggregateFunctionQuantileExactData<T>, AggregateFunctionQuantilesExact<T>>
{
private:
QuantileLevels<double> levels;
DataTypePtr type;
public:
String getName() const override { return "quantilesExact"; }
DataTypePtr getReturnType() const override
{
return new DataTypeArray(type);
}
void setArgument(const DataTypePtr & argument)
{
type = argument;
}
void setParameters(const Array & params) override
{
levels.set(params);
}
void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const
{
this->data(place).array.push_back(static_cast<const ColumnVector<T> &>(column).getData()[row_num]);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{
this->data(place).array.insert(this->data(rhs).array.begin(), this->data(rhs).array.end());
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
const auto & array = this->data(place).array;
size_t size = array.size();
writeVarUInt(size, buf);
buf.write(reinterpret_cast<const char *>(&array[0]), size * sizeof(array[0]));
}
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{
auto & array = this->data(place).array;
size_t size = 0;
readVarUInt(size, buf);
size_t old_size = array.size();
array.resize(old_size + size);
buf.read(reinterpret_cast<char *>(&array[old_size]), size * sizeof(array[0]));
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
/// Сортировка массива не будет считаться нарушением константности.
auto & array = const_cast<typename AggregateFunctionQuantileExactData<T>::Array &>(this->data(place).array);
ColumnArray & arr_to = static_cast<ColumnArray &>(to);
ColumnArray::Offsets_t & offsets_to = arr_to.getOffsets();
size_t num_levels = levels.size();
offsets_to.push_back((offsets_to.size() == 0 ? 0 : offsets_to.back()) + num_levels);
typename ColumnVector<T>::Container_t & data_to = static_cast<ColumnVector<T> &>(arr_to.getData()).getData();
size_t old_size = data_to.size();
data_to.resize(old_size + num_levels);
if (!array.empty())
{
size_t prev_n = 0;
for (auto level_index : levels.permutation)
{
auto level = levels.levels[level_index];
size_t n = level < 1
? level * array.size()
: (array.size() - 1);
std::nth_element(array.begin() + prev_n, array.begin() + n, array.end());
data_to[old_size + level_index] = array[n];
prev_n = n;
}
}
else
{
for (size_t i = 0; i < num_levels; ++i)
data_to[old_size + i] = T();
}
}
};
}

View File

@ -0,0 +1,280 @@
#pragma once
#include <DB/Common/HashTable/HashMap.h>
#include <DB/DataTypes/DataTypesNumberFixed.h>
#include <DB/DataTypes/DataTypeArray.h>
#include <DB/AggregateFunctions/IBinaryAggregateFunction.h>
#include <DB/AggregateFunctions/QuantilesCommon.h>
#include <DB/Columns/ColumnArray.h>
namespace DB
{
/** В качестве состояния используется хэш-таблица вида: значение -> сколько раз встретилось.
*/
template <typename T>
struct AggregateFunctionQuantileExactWeightedData
{
using Key = T;
using Weight = UInt64;
/// При создании, хэш-таблица должна быть небольшой.
using Map = HashMap<
Key, Weight,
HashCRC32<Key>,
HashTableGrower<4>,
HashTableAllocatorWithStackMemory<sizeof(std::pair<Key, Weight>) * (1 << 3)>
>;
Map map;
};
/** Точно вычисляет квантиль по множеству значений, для каждого из которых задан вес - сколько раз значение встречалось.
* Можно рассматривать набор пар value, weight - как набор гистограмм,
* в которых value - значение, округлённое до середины столбика, а weight - высота столбика.
* В качестве типа аргумента может быть только числовой тип (в том числе, дата и дата-с-временем).
* Тип результата совпадает с типом аргумента.
*/
template <typename ValueType, typename WeightType>
class AggregateFunctionQuantileExactWeighted final
: public IBinaryAggregateFunction<
AggregateFunctionQuantileExactWeightedData<ValueType>,
AggregateFunctionQuantileExactWeighted<ValueType, WeightType>>
{
private:
double level;
DataTypePtr type;
public:
AggregateFunctionQuantileExactWeighted(double level_ = 0.5) : level(level_) {}
String getName() const override { return "quantileExactWeighted"; }
DataTypePtr getReturnType() const override
{
return type;
}
void setArgumentsImpl(const DataTypes & arguments)
{
type = arguments[0];
}
void setParameters(const Array & params) override
{
if (params.size() != 1)
throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
level = apply_visitor(FieldVisitorConvertToNumber<Float64>(), params[0]);
}
void addImpl(AggregateDataPtr place, const IColumn & column_value, const IColumn & column_weight, size_t row_num) const
{
this->data(place)
.map[static_cast<const ColumnVector<ValueType> &>(column_value).getData()[row_num]]
+= static_cast<const ColumnVector<WeightType> &>(column_weight).getData()[row_num];
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{
auto & map = this->data(place).map;
const auto & rhs_map = this->data(rhs).map;
for (const auto & pair : rhs_map)
map[pair.first] += pair.second;
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
this->data(place).map.write(buf);
}
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{
typename AggregateFunctionQuantileExactWeightedData<ValueType>::Map::Reader reader(buf);
auto & map = this->data(place).map;
while (reader.next())
{
const auto & pair = reader.get();
map[pair.first] += pair.second;
}
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
auto & map = this->data(place).map;
size_t size = map.size();
if (0 == size)
{
static_cast<ColumnVector<ValueType> &>(to).getData().push_back(ValueType());
return;
}
/// Копируем данные во временный массив, чтобы получить нужный по порядку элемент.
using Pair = typename AggregateFunctionQuantileExactWeightedData<ValueType>::Map::value_type;
std::unique_ptr<Pair[]> array_holder(new Pair[size]);
Pair * array = array_holder.get();
size_t i = 0;
UInt64 sum_weight = 0;
for (const auto & pair : map)
{
sum_weight += pair.second;
array[i] = pair;
++i;
}
std::sort(array, array + size, [](const Pair & a, const Pair & b) { return a.first < b.first; });
UInt64 threshold = sum_weight * level;
UInt64 accumulated = 0;
const Pair * it = array;
const Pair * end = array + size;
while (it < end && accumulated < threshold)
{
accumulated += it->second;
++it;
}
if (it == end)
--it;
static_cast<ColumnVector<ValueType> &>(to).getData().push_back(it->first);
}
};
/** То же самое, но позволяет вычислить сразу несколько квантилей.
* Для этого, принимает в качестве параметров несколько уровней. Пример: quantilesExactWeighted(0.5, 0.8, 0.9, 0.95)(ConnectTiming, Weight).
* Возвращает массив результатов.
*/
template <typename ValueType, typename WeightType>
class AggregateFunctionQuantilesExactWeighted final
: public IBinaryAggregateFunction<
AggregateFunctionQuantileExactWeightedData<ValueType>,
AggregateFunctionQuantilesExactWeighted<ValueType, WeightType>>
{
private:
QuantileLevels<double> levels;
DataTypePtr type;
public:
String getName() const override { return "quantilesExactWeighted"; }
DataTypePtr getReturnType() const override
{
return new DataTypeArray(type);
}
void setArgumentsImpl(const DataTypes & arguments)
{
type = arguments[0];
}
void setParameters(const Array & params) override
{
levels.set(params);
}
void addImpl(AggregateDataPtr place, const IColumn & column_value, const IColumn & column_weight, size_t row_num) const
{
this->data(place)
.map[static_cast<const ColumnVector<ValueType> &>(column_value).getData()[row_num]]
+= static_cast<const ColumnVector<WeightType> &>(column_weight).getData()[row_num];
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{
auto & map = this->data(place).map;
const auto & rhs_map = this->data(rhs).map;
for (const auto & pair : rhs_map)
map[pair.first] += pair.second;
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
this->data(place).map.write(buf);
}
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{
typename AggregateFunctionQuantileExactWeightedData<ValueType>::Map::Reader reader(buf);
auto & map = this->data(place).map;
while (reader.next())
{
const auto & pair = reader.get();
map[pair.first] += pair.second;
}
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
auto & map = this->data(place).map;
size_t size = map.size();
ColumnArray & arr_to = static_cast<ColumnArray &>(to);
ColumnArray::Offsets_t & offsets_to = arr_to.getOffsets();
size_t num_levels = levels.size();
offsets_to.push_back((offsets_to.size() == 0 ? 0 : offsets_to.back()) + num_levels);
typename ColumnVector<ValueType>::Container_t & data_to = static_cast<ColumnVector<ValueType> &>(arr_to.getData()).getData();
size_t old_size = data_to.size();
data_to.resize(old_size + num_levels);
if (0 == size)
{
for (size_t i = 0; i < num_levels; ++i)
data_to[old_size + i] = ValueType();
return;
}
/// Копируем данные во временный массив, чтобы получить нужный по порядку элемент.
using Pair = typename AggregateFunctionQuantileExactWeightedData<ValueType>::Map::value_type;
std::unique_ptr<Pair[]> array_holder(new Pair[size]);
Pair * array = array_holder.get();
size_t i = 0;
UInt64 sum_weight = 0;
for (const auto & pair : map)
{
sum_weight += pair.second;
array[i] = pair;
++i;
}
std::sort(array, array + size, [](const Pair & a, const Pair & b) { return a.first < b.first; });
UInt64 accumulated = 0;
const Pair * it = array;
const Pair * end = array + size;
for (auto level_index : levels.permutation)
{
UInt64 threshold = sum_weight * levels.levels[level_index];
while (it < end && accumulated < threshold)
{
accumulated += it->second;
++it;
}
data_to[old_size + level_index] = it < end ? it->first : it[-1].first;
}
}
};
}

View File

@ -0,0 +1,643 @@
#pragma once
#include <cmath>
#include <cstdint>
#include <cassert>
#include <vector>
#include <algorithm>
#include <DB/Common/RadixSort.h>
#include <DB/Common/PODArray.h>
#include <DB/Columns/ColumnArray.h>
#include <DB/AggregateFunctions/IUnaryAggregateFunction.h>
#include <DB/AggregateFunctions/IBinaryAggregateFunction.h>
#include <DB/AggregateFunctions/QuantilesCommon.h>
#include <DB/DataTypes/DataTypesNumberFixed.h>
#include <DB/DataTypes/DataTypeArray.h>
/** Алгоритм реализовал Алексей Борзенков https://███████████.yandex-team.ru/snaury
* Ему принадлежит авторство кода и половины комментариев в данном namespace,
* за исключением слияния, сериализации и сортировки, а также выбора типов и других изменений.
* Мы благодарим Алексея Борзенкова за написание изначального кода.
*/
namespace tdigest
{
/**
* Центроид хранит вес точек вокруг их среднего значения
*/
template <typename Value, typename Count>
struct Centroid
{
Value mean;
Count count;
Centroid() = default;
explicit Centroid(Value mean, Count count = 1)
: mean(mean)
, count(count)
{}
Centroid & operator+=(const Centroid & other)
{
count += other.count;
mean += other.count * (other.mean - mean) / count;
return *this;
}
bool operator<(const Centroid & other) const
{
return mean < other.mean;
}
};
/** :param epsilon: значение \delta из статьи - погрешность в районе
* квантиля 0.5 (по-умолчанию 0.01, т.е. 1%)
* :param max_unmerged: при накоплении кол-ва новых точек сверх этого
* значения запускается компрессия центроидов
* (по-умолчанию 2048, чем выше значение - тем
* больше требуется памяти, но повышается
* амортизация времени выполнения)
*/
template <typename Value>
struct Params
{
Value epsilon = 0.01;
size_t max_unmerged = 2048;
};
/** Реализация алгоритма t-digest (https://github.com/tdunning/t-digest).
* Этот вариант очень похож на MergingDigest на java, однако решение об
* объединении принимается на основе оригинального условия из статьи
* (через ограничение на размер, используя апроксимацию квантиля каждого
* центроида, а не расстояние на кривой положения их границ). MergingDigest
* на java даёт значительно меньше центроидов, чем данный вариант, что
* негативно влияет на точность при том же факторе компрессии, но даёт
* гарантии размера. Сам автор на предложение об этом варианте сказал, что
* размер дайжеста растёт как O(log(n)), в то время как вариант на java
* не зависит от предполагаемого кол-ва точек. Кроме того вариант на java
* использует asin, чем немного замедляет алгоритм.
*/
template <typename Value, typename CentroidCount, typename TotalCount>
class MergingDigest
{
using Params = tdigest::Params<Value>;
using Centroid = tdigest::Centroid<Value, CentroidCount>;
/// Сразу будет выделена память на несколько элементов так, чтобы состояние занимало 64 байта.
static constexpr size_t bytes_in_arena = 64 - sizeof(DB::PODArray<Centroid>) - sizeof(TotalCount) - sizeof(uint32_t);
using Summary = DB::PODArray<Centroid, bytes_in_arena / sizeof(Centroid), AllocatorWithStackMemory<Allocator<false>, bytes_in_arena>>;
Summary summary;
TotalCount count = 0;
uint32_t unmerged = 0;
/** Линейная интерполяция в точке x на прямой (x1, y1)..(x2, y2)
*/
static Value interpolate(Value x, Value x1, Value y1, Value x2, Value y2)
{
double k = (x - x1) / (x2 - x1);
return y1 + k * (y2 - y1);
}
struct RadixSortTraits
{
using Element = Centroid;
using Key = Value;
using CountType = uint32_t;
using KeyBits = uint32_t;
static constexpr size_t PART_SIZE_BITS = 8;
using Transform = RadixSortFloatTransform<KeyBits>;
using Allocator = RadixSortMallocAllocator;
/// Функция получения ключа из элемента массива.
static Key & extractKey(Element & elem) { return elem.mean; }
};
public:
/** Добавляет к дайджесту изменение x с весом cnt (по-умолчанию 1)
*/
void add(const Params & params, Value x, CentroidCount cnt = 1)
{
add(params, Centroid(x, cnt));
}
/** Добавляет к дайджесту центроид c
*/
void add(const Params & params, const Centroid & c)
{
summary.push_back(c);
count += c.count;
++unmerged;
if (unmerged >= params.max_unmerged)
compress(params);
}
/** Выполняет компрессию накопленных центроидов
* При объединении сохраняется инвариант на максимальный размер каждого
* центроида, не превышающий 4 q (1 - q) \delta N.
*/
void compress(const Params & params)
{
if (unmerged > 0)
{
RadixSort<RadixSortTraits>::execute(&summary[0], summary.size());
if (summary.size() > 3)
{
/// Пара подряд идущих столбиков гистограммы.
auto l = summary.begin();
auto r = std::next(l);
TotalCount sum = 0;
while (r != summary.end())
{
// we use quantile which gives us the smallest error
/// Отношение части гистограммы до l, включая половинку l ко всей гистограмме. То есть, какого уровня квантиль в позиции l.
Value ql = (sum + l->count * 0.5) / count;
Value err = ql * (1 - ql);
/// Отношение части гистограммы до l, включая l и половинку r ко всей гистограмме. То есть, какого уровня квантиль в позиции r.
Value qr = (sum + l->count + r->count * 0.5) / count;
Value err2 = qr * (1 - qr);
if (err > err2)
err = err2;
Value k = 4 * count * err * params.epsilon;
/** Отношение веса склеенной пары столбиков ко всем значениям не больше,
* чем epsilon умножить на некий квадратичный коэффициент, который в медиане равен 1 (4 * 1/2 * 1/2),
* а по краям убывает и примерно равен расстоянию до края * 4.
*/
if (l->count + r->count <= k)
{
// it is possible to merge left and right
/// Левый столбик "съедает" правый.
*l += *r;
}
else
{
// not enough capacity, check the next pair
sum += l->count;
++l;
/// Пропускаем все "съеденные" ранее значения.
if (l != r)
*l = *r;
}
++r;
}
/// По окончании цикла, все значения правее l были "съедены".
summary.resize(l - summary.begin() + 1);
}
unmerged = 0;
}
}
/** Вычисляет квантиль q [0, 1] на основе дайджеста
* Для пустого дайджеста возвращает NaN.
*/
Value getQuantile(const Params & params, Value q)
{
if (summary.empty())
return NAN;
compress(params);
if (summary.size() == 1)
return summary.front().mean;
Value x = q * count;
TotalCount sum = 0;
Value prev_mean = summary.front().mean;
Value prev_x = 0;
for (const auto & c : summary)
{
Value current_x = sum + c.count * 0.5;
if (current_x >= x)
return interpolate(x, prev_x, prev_mean, current_x, c.mean);
sum += c.count;
prev_mean = c.mean;
prev_x = current_x;
}
return summary.back().mean;
}
/** Получить несколько квантилей (size штук).
* levels - массив уровней нужных квантилей. Они идут в произвольном порядке.
* levels_permutation - массив-перестановка уровней. На i-ой позиции будет лежать индекс i-го по возрастанию уровня в массиве levels.
* result - массив, куда сложить результаты, в порядке levels,
*/
template <typename ResultType>
void getManyQuantiles(const Params & params, const Value * levels, const size_t * levels_permutation, size_t size, ResultType * result)
{
if (summary.empty())
{
for (size_t result_num = 0; result_num < size; ++result_num)
result[result_num] = std::is_floating_point<ResultType>::value ? NAN : 0;
return;
}
compress(params);
if (summary.size() == 1)
{
for (size_t result_num = 0; result_num < size; ++result_num)
result[result_num] = summary.front().mean;
return;
}
Value x = levels[levels_permutation[0]] * count;
TotalCount sum = 0;
Value prev_mean = summary.front().mean;
Value prev_x = 0;
size_t result_num = 0;
for (const auto & c : summary)
{
Value current_x = sum + c.count * 0.5;
while (current_x >= x)
{
result[levels_permutation[result_num]] = interpolate(x, prev_x, prev_mean, current_x, c.mean);
++result_num;
if (result_num >= size)
return;
x = levels[levels_permutation[result_num]] * count;
}
sum += c.count;
prev_mean = c.mean;
prev_x = current_x;
}
auto rest_of_results = summary.back().mean;
for (; result_num < size; ++result_num)
result[levels_permutation[result_num]] = rest_of_results;
}
/** Объединить с другим состоянием.
*/
void merge(const Params & params, const MergingDigest & other)
{
for (const auto & c : other.summary)
add(params, c);
}
/** Записать в поток.
*/
void write(const Params & params, DB::WriteBuffer & buf)
{
compress(params);
DB::writeVarUInt(summary.size(), buf);
buf.write(reinterpret_cast<const char *>(&summary[0]), summary.size() * sizeof(summary[0]));
}
/** Прочитать из потока и объединить с текущим состоянием.
*/
void readAndMerge(const Params & params, DB::ReadBuffer & buf)
{
size_t size = 0;
DB::readVarUInt(size, buf);
if (size > params.max_unmerged)
throw DB::Exception("Too large t-digest summary size", DB::ErrorCodes::TOO_LARGE_ARRAY_SIZE);
for (size_t i = 0; i < size; ++i)
{
Centroid c;
DB::readPODBinary(c, buf);
add(params, c);
}
}
};
}
namespace DB
{
struct AggregateFunctionQuantileTDigestData
{
tdigest::MergingDigest<Float32, Float32, Float32> digest;
};
template <typename T, bool returns_float = true>
class AggregateFunctionQuantileTDigest final
: public IUnaryAggregateFunction<AggregateFunctionQuantileTDigestData, AggregateFunctionQuantileTDigest<T>>
{
private:
Float32 level;
tdigest::Params<Float32> params;
DataTypePtr type;
public:
AggregateFunctionQuantileTDigest(double level_ = 0.5) : level(level_) {}
String getName() const override { return "quantileTDigest"; }
DataTypePtr getReturnType() const override
{
return type;
}
void setArgument(const DataTypePtr & argument)
{
if (returns_float)
type = new DataTypeFloat32;
else
type = argument;
}
void setParameters(const Array & params) override
{
if (params.size() != 1)
throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
level = apply_visitor(FieldVisitorConvertToNumber<Float32>(), params[0]);
}
void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const
{
this->data(place).digest.add(params, static_cast<const ColumnVector<T> &>(column).getData()[row_num]);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{
this->data(place).digest.merge(params, this->data(rhs).digest);
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
this->data(const_cast<AggregateDataPtr>(place)).digest.write(params, buf);
}
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{
this->data(place).digest.readAndMerge(params, buf);
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
auto quantile = this->data(const_cast<AggregateDataPtr>(place)).digest.getQuantile(params, level);
if (returns_float)
static_cast<ColumnFloat32 &>(to).getData().push_back(quantile);
else
static_cast<ColumnVector<T> &>(to).getData().push_back(quantile);
}
};
template <typename T, typename Weight, bool returns_float = true>
class AggregateFunctionQuantileTDigestWeighted final
: public IBinaryAggregateFunction<AggregateFunctionQuantileTDigestData, AggregateFunctionQuantileTDigestWeighted<T, Weight, returns_float>>
{
private:
Float32 level;
tdigest::Params<Float32> params;
DataTypePtr type;
public:
AggregateFunctionQuantileTDigestWeighted(double level_ = 0.5) : level(level_) {}
String getName() const override { return "quantileTDigestWeighted"; }
DataTypePtr getReturnType() const override
{
return type;
}
void setArgumentsImpl(const DataTypes & arguments)
{
if (returns_float)
type = new DataTypeFloat32;
else
type = arguments.at(0);
}
void setParameters(const Array & params) override
{
if (params.size() != 1)
throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
level = apply_visitor(FieldVisitorConvertToNumber<Float32>(), params[0]);
}
void addImpl(AggregateDataPtr place, const IColumn & column_value, const IColumn & column_weight, size_t row_num) const
{
this->data(place).digest.add(params,
static_cast<const ColumnVector<T> &>(column_value).getData()[row_num],
static_cast<const ColumnVector<Weight> &>(column_weight).getData()[row_num]);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{
this->data(place).digest.merge(params, this->data(rhs).digest);
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
this->data(const_cast<AggregateDataPtr>(place)).digest.write(params, buf);
}
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{
this->data(place).digest.readAndMerge(params, buf);
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
auto quantile = this->data(const_cast<AggregateDataPtr>(place)).digest.getQuantile(params, level);
if (returns_float)
static_cast<ColumnFloat32 &>(to).getData().push_back(quantile);
else
static_cast<ColumnVector<T> &>(to).getData().push_back(quantile);
}
};
template <typename T, bool returns_float = true>
class AggregateFunctionQuantilesTDigest final
: public IUnaryAggregateFunction<AggregateFunctionQuantileTDigestData, AggregateFunctionQuantilesTDigest<T>>
{
private:
QuantileLevels<Float32> levels;
tdigest::Params<Float32> params;
DataTypePtr type;
public:
String getName() const override { return "quantilesTDigest"; }
DataTypePtr getReturnType() const override
{
return new DataTypeArray(type);
}
void setArgument(const DataTypePtr & argument)
{
if (returns_float)
type = new DataTypeFloat32;
else
type = argument;
}
void setParameters(const Array & params) override
{
levels.set(params);
}
void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const
{
this->data(place).digest.add(params, static_cast<const ColumnVector<T> &>(column).getData()[row_num]);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{
this->data(place).digest.merge(params, this->data(rhs).digest);
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
this->data(const_cast<AggregateDataPtr>(place)).digest.write(params, buf);
}
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{
this->data(place).digest.readAndMerge(params, buf);
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
ColumnArray & arr_to = static_cast<ColumnArray &>(to);
ColumnArray::Offsets_t & offsets_to = arr_to.getOffsets();
size_t size = levels.size();
offsets_to.push_back((offsets_to.size() == 0 ? 0 : offsets_to.back()) + size);
if (returns_float)
{
typename ColumnFloat32::Container_t & data_to = static_cast<ColumnFloat32 &>(arr_to.getData()).getData();
size_t old_size = data_to.size();
data_to.resize(data_to.size() + size);
this->data(const_cast<AggregateDataPtr>(place)).digest.getManyQuantiles(
params, &levels.levels[0], &levels.permutation[0], size, &data_to[old_size]);
}
else
{
typename ColumnVector<T>::Container_t & data_to = static_cast<ColumnVector<T> &>(arr_to.getData()).getData();
size_t old_size = data_to.size();
data_to.resize(data_to.size() + size);
this->data(const_cast<AggregateDataPtr>(place)).digest.getManyQuantiles(
params, &levels.levels[0], &levels.permutation[0], size, &data_to[old_size]);
}
}
};
template <typename T, typename Weight, bool returns_float = true>
class AggregateFunctionQuantilesTDigestWeighted final
: public IBinaryAggregateFunction<AggregateFunctionQuantileTDigestData, AggregateFunctionQuantilesTDigestWeighted<T, Weight, returns_float>>
{
private:
QuantileLevels<Float32> levels;
tdigest::Params<Float32> params;
DataTypePtr type;
public:
String getName() const override { return "quantilesTDigest"; }
DataTypePtr getReturnType() const override
{
return new DataTypeArray(type);
}
void setArgumentsImpl(const DataTypes & arguments)
{
if (returns_float)
type = new DataTypeFloat32;
else
type = arguments.at(0);
}
void setParameters(const Array & params) override
{
levels.set(params);
}
void addImpl(AggregateDataPtr place, const IColumn & column_value, const IColumn & column_weight, size_t row_num) const
{
this->data(place).digest.add(params,
static_cast<const ColumnVector<T> &>(column_value).getData()[row_num],
static_cast<const ColumnVector<Weight> &>(column_weight).getData()[row_num]);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{
this->data(place).digest.merge(params, this->data(rhs).digest);
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
this->data(const_cast<AggregateDataPtr>(place)).digest.write(params, buf);
}
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{
this->data(place).digest.readAndMerge(params, buf);
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
ColumnArray & arr_to = static_cast<ColumnArray &>(to);
ColumnArray::Offsets_t & offsets_to = arr_to.getOffsets();
size_t size = levels.size();
offsets_to.push_back((offsets_to.size() == 0 ? 0 : offsets_to.back()) + size);
if (returns_float)
{
typename ColumnFloat32::Container_t & data_to = static_cast<ColumnFloat32 &>(arr_to.getData()).getData();
size_t old_size = data_to.size();
data_to.resize(data_to.size() + size);
this->data(const_cast<AggregateDataPtr>(place)).digest.getManyQuantiles(
params, &levels.levels[0], &levels.permutation[0], size, &data_to[old_size]);
}
else
{
typename ColumnVector<T>::Container_t & data_to = static_cast<ColumnVector<T> &>(arr_to.getData()).getData();
size_t old_size = data_to.size();
data_to.resize(data_to.size() + size);
this->data(const_cast<AggregateDataPtr>(place)).digest.getManyQuantiles(
params, &levels.levels[0], &levels.permutation[0], size, &data_to[old_size]);
}
}
};
}

View File

@ -3,8 +3,7 @@
#include <limits> #include <limits>
#include <DB/Common/MemoryTracker.h> #include <DB/Common/MemoryTracker.h>
#include <DB/Common/HashTable/Hash.h>
#include <DB/Core/FieldVisitors.h>
#include <DB/IO/WriteHelpers.h> #include <DB/IO/WriteHelpers.h>
#include <DB/IO/ReadHelpers.h> #include <DB/IO/ReadHelpers.h>
@ -13,10 +12,11 @@
#include <DB/DataTypes/DataTypeArray.h> #include <DB/DataTypes/DataTypeArray.h>
#include <DB/AggregateFunctions/IUnaryAggregateFunction.h> #include <DB/AggregateFunctions/IUnaryAggregateFunction.h>
#include <DB/AggregateFunctions/IBinaryAggregateFunction.h>
#include <DB/AggregateFunctions/QuantilesCommon.h>
#include <DB/Columns/ColumnArray.h> #include <DB/Columns/ColumnArray.h>
#include <stats/IntHash.h>
#include <ext/range.hpp> #include <ext/range.hpp>
@ -234,15 +234,10 @@ namespace detail
} }
/// Получить значения size квантилей уровней levels. Записать size результатов начиная с адреса result. /// Получить значения size квантилей уровней levels. Записать size результатов начиная с адреса result.
/// indices - массив индексов levels такой, что соответствующие элементы будут идти в порядке по возрастанию.
template <typename ResultType> template <typename ResultType>
void getMany(const double * levels, size_t size, ResultType * result) const void getMany(const double * levels, const size_t * indices, size_t size, ResultType * result) const
{ {
std::size_t indices[size];
std::copy(ext::range_iterator<size_t>{}, ext::make_range_iterator(size), indices);
std::sort(indices, indices + size, [levels] (auto i1, auto i2) {
return levels[i1] < levels[i2];
});
const auto indices_end = indices + size; const auto indices_end = indices + size;
auto index = indices; auto index = indices;
@ -310,10 +305,10 @@ namespace detail
: std::numeric_limits<float>::quiet_NaN(); : std::numeric_limits<float>::quiet_NaN();
} }
void getManyFloat(const double * levels, size_t size, float * result) const void getManyFloat(const double * levels, const size_t * levels_permutation, size_t size, float * result) const
{ {
if (count) if (count)
getMany(levels, size, result); getMany(levels, levels_permutation, size, result);
else else
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
result[i] = std::numeric_limits<float>::quiet_NaN(); result[i] = std::numeric_limits<float>::quiet_NaN();
@ -502,11 +497,11 @@ public:
/// Получить значения size квантилей уровней levels. Записать size результатов начиная с адреса result. /// Получить значения size квантилей уровней levels. Записать size результатов начиная с адреса result.
template <typename ResultType> template <typename ResultType>
void getMany(const double * levels, size_t size, ResultType * result) const void getMany(const double * levels, const size_t * levels_permutation, size_t size, ResultType * result) const
{ {
if (isLarge()) if (isLarge())
{ {
return large->getMany(levels, size, result); return large->getMany(levels, levels_permutation, size, result);
} }
else else
{ {
@ -523,10 +518,10 @@ public:
: std::numeric_limits<float>::quiet_NaN(); : std::numeric_limits<float>::quiet_NaN();
} }
void getManyFloat(const double * levels, size_t size, float * result) const void getManyFloat(const double * levels, const size_t * levels_permutation, size_t size, float * result) const
{ {
if (tiny.count) if (tiny.count)
getMany(levels, size, result); getMany(levels, levels_permutation, size, result);
else else
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
result[i] = std::numeric_limits<float>::quiet_NaN(); result[i] = std::numeric_limits<float>::quiet_NaN();
@ -549,9 +544,9 @@ private:
public: public:
AggregateFunctionQuantileTiming(double level_ = 0.5) : level(level_) {} AggregateFunctionQuantileTiming(double level_ = 0.5) : level(level_) {}
String getName() const { return "quantileTiming"; } String getName() const override { return "quantileTiming"; }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return new DataTypeFloat32; return new DataTypeFloat32;
} }
@ -560,7 +555,7 @@ public:
{ {
} }
void setParameters(const Array & params) void setParameters(const Array & params) override
{ {
if (params.size() != 1) if (params.size() != 1)
throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
@ -569,27 +564,27 @@ public:
} }
void addOne(AggregateDataPtr place, const IColumn & column, size_t row_num) const void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const
{ {
this->data(place).insert(static_cast<const ColumnVector<ArgumentFieldType> &>(column).getData()[row_num]); this->data(place).insert(static_cast<const ColumnVector<ArgumentFieldType> &>(column).getData()[row_num]);
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
this->data(place).merge(this->data(rhs)); this->data(place).merge(this->data(rhs));
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
this->data(place).serialize(buf); this->data(place).serialize(buf);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
this->data(place).deserializeMerge(buf); this->data(place).deserializeMerge(buf);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
static_cast<ColumnFloat32 &>(to).getData().push_back(this->data(place).getFloat(level)); static_cast<ColumnFloat32 &>(to).getData().push_back(this->data(place).getFloat(level));
} }
@ -599,7 +594,8 @@ public:
/** То же самое, но с двумя аргументами. Второй аргумент - "вес" (целое число) - сколько раз учитывать значение. /** То же самое, но с двумя аргументами. Второй аргумент - "вес" (целое число) - сколько раз учитывать значение.
*/ */
template <typename ArgumentFieldType, typename WeightFieldType> template <typename ArgumentFieldType, typename WeightFieldType>
class AggregateFunctionQuantileTimingWeighted final : public IAggregateFunctionHelper<QuantileTiming> class AggregateFunctionQuantileTimingWeighted final
: public IBinaryAggregateFunction<QuantileTiming, AggregateFunctionQuantileTimingWeighted<ArgumentFieldType, WeightFieldType>>
{ {
private: private:
double level; double level;
@ -607,18 +603,18 @@ private:
public: public:
AggregateFunctionQuantileTimingWeighted(double level_ = 0.5) : level(level_) {} AggregateFunctionQuantileTimingWeighted(double level_ = 0.5) : level(level_) {}
String getName() const { return "quantileTimingWeighted"; } String getName() const override { return "quantileTimingWeighted"; }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return new DataTypeFloat32; return new DataTypeFloat32;
} }
void setArguments(const DataTypes & arguments) void setArgumentsImpl(const DataTypes & arguments)
{ {
} }
void setParameters(const Array & params) void setParameters(const Array & params) override
{ {
if (params.size() != 1) if (params.size() != 1)
throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
@ -626,30 +622,29 @@ public:
level = apply_visitor(FieldVisitorConvertToNumber<Float64>(), params[0]); level = apply_visitor(FieldVisitorConvertToNumber<Float64>(), params[0]);
} }
void addImpl(AggregateDataPtr place, const IColumn & column_value, const IColumn & column_weight, size_t row_num) const
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num) const
{ {
this->data(place).insertWeighted( this->data(place).insertWeighted(
static_cast<const ColumnVector<ArgumentFieldType> &>(*columns[0]).getData()[row_num], static_cast<const ColumnVector<ArgumentFieldType> &>(column_value).getData()[row_num],
static_cast<const ColumnVector<WeightFieldType> &>(*columns[1]).getData()[row_num]); static_cast<const ColumnVector<WeightFieldType> &>(column_weight).getData()[row_num]);
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
this->data(place).merge(this->data(rhs)); this->data(place).merge(this->data(rhs));
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
this->data(place).serialize(buf); this->data(place).serialize(buf);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
this->data(place).deserializeMerge(buf); this->data(place).deserializeMerge(buf);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
static_cast<ColumnFloat32 &>(to).getData().push_back(this->data(place).getFloat(level)); static_cast<ColumnFloat32 &>(to).getData().push_back(this->data(place).getFloat(level));
} }
@ -664,13 +659,12 @@ template <typename ArgumentFieldType>
class AggregateFunctionQuantilesTiming final : public IUnaryAggregateFunction<QuantileTiming, AggregateFunctionQuantilesTiming<ArgumentFieldType> > class AggregateFunctionQuantilesTiming final : public IUnaryAggregateFunction<QuantileTiming, AggregateFunctionQuantilesTiming<ArgumentFieldType> >
{ {
private: private:
typedef std::vector<double> Levels; QuantileLevels<double> levels;
Levels levels;
public: public:
String getName() const { return "quantilesTiming"; } String getName() const override { return "quantilesTiming"; }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return new DataTypeArray(new DataTypeFloat32); return new DataTypeArray(new DataTypeFloat32);
} }
@ -679,40 +673,33 @@ public:
{ {
} }
void setParameters(const Array & params) void setParameters(const Array & params) override
{ {
if (params.empty()) levels.set(params);
throw Exception("Aggregate function " + getName() + " requires at least one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
size_t size = params.size();
levels.resize(size);
for (size_t i = 0; i < size; ++i)
levels[i] = apply_visitor(FieldVisitorConvertToNumber<Float64>(), params[i]);
} }
void addOne(AggregateDataPtr place, const IColumn & column, size_t row_num) const void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const
{ {
this->data(place).insert(static_cast<const ColumnVector<ArgumentFieldType> &>(column).getData()[row_num]); this->data(place).insert(static_cast<const ColumnVector<ArgumentFieldType> &>(column).getData()[row_num]);
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
this->data(place).merge(this->data(rhs)); this->data(place).merge(this->data(rhs));
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
this->data(place).serialize(buf); this->data(place).serialize(buf);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
this->data(place).deserializeMerge(buf); this->data(place).deserializeMerge(buf);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
ColumnArray & arr_to = static_cast<ColumnArray &>(to); ColumnArray & arr_to = static_cast<ColumnArray &>(to);
ColumnArray::Offsets_t & offsets_to = arr_to.getOffsets(); ColumnArray::Offsets_t & offsets_to = arr_to.getOffsets();
@ -724,65 +711,58 @@ public:
size_t old_size = data_to.size(); size_t old_size = data_to.size();
data_to.resize(data_to.size() + size); data_to.resize(data_to.size() + size);
this->data(place).getManyFloat(&levels[0], size, &data_to[old_size]); this->data(place).getManyFloat(&levels.levels[0], &levels.permutation[0], size, &data_to[old_size]);
} }
}; };
template <typename ArgumentFieldType, typename WeightFieldType> template <typename ArgumentFieldType, typename WeightFieldType>
class AggregateFunctionQuantilesTimingWeighted final : public IAggregateFunctionHelper<QuantileTiming> class AggregateFunctionQuantilesTimingWeighted final
: public IBinaryAggregateFunction<QuantileTiming, AggregateFunctionQuantilesTimingWeighted<ArgumentFieldType, WeightFieldType>>
{ {
private: private:
typedef std::vector<double> Levels; QuantileLevels<double> levels;
Levels levels;
public: public:
String getName() const { return "quantilesTimingWeighted"; } String getName() const override { return "quantilesTimingWeighted"; }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return new DataTypeArray(new DataTypeFloat32); return new DataTypeArray(new DataTypeFloat32);
} }
void setArguments(const DataTypes & arguments) void setArgumentsImpl(const DataTypes & arguments)
{ {
} }
void setParameters(const Array & params) void setParameters(const Array & params) override
{ {
if (params.empty()) levels.set(params);
throw Exception("Aggregate function " + getName() + " requires at least one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
size_t size = params.size();
levels.resize(size);
for (size_t i = 0; i < size; ++i)
levels[i] = apply_visitor(FieldVisitorConvertToNumber<Float64>(), params[i]);
} }
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num) const void addImpl(AggregateDataPtr place, const IColumn & column_value, const IColumn & column_weight, size_t row_num) const
{ {
this->data(place).insertWeighted( this->data(place).insertWeighted(
static_cast<const ColumnVector<ArgumentFieldType> &>(*columns[0]).getData()[row_num], static_cast<const ColumnVector<ArgumentFieldType> &>(column_value).getData()[row_num],
static_cast<const ColumnVector<WeightFieldType> &>(*columns[1]).getData()[row_num]); static_cast<const ColumnVector<WeightFieldType> &>(column_weight).getData()[row_num]);
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
this->data(place).merge(this->data(rhs)); this->data(place).merge(this->data(rhs));
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
this->data(place).serialize(buf); this->data(place).serialize(buf);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
this->data(place).deserializeMerge(buf); this->data(place).deserializeMerge(buf);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
ColumnArray & arr_to = static_cast<ColumnArray &>(to); ColumnArray & arr_to = static_cast<ColumnArray &>(to);
ColumnArray::Offsets_t & offsets_to = arr_to.getOffsets(); ColumnArray::Offsets_t & offsets_to = arr_to.getOffsets();
@ -794,7 +774,7 @@ public:
size_t old_size = data_to.size(); size_t old_size = data_to.size();
data_to.resize(data_to.size() + size); data_to.resize(data_to.size() + size);
this->data(place).getManyFloat(&levels[0], size, &data_to[old_size]); this->data(place).getManyFloat(&levels.levels[0], &levels.permutation[0], size, &data_to[old_size]);
} }
}; };

View File

@ -229,6 +229,13 @@ public:
static_cast<ColumnUInt8 &>(to).getData().push_back(match(events_it, events_end)); static_cast<ColumnUInt8 &>(to).getData().push_back(match(events_it, events_end));
} }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num)
{
return static_cast<const AggregateFunctionSequenceMatch &>(*that).add(place, columns, row_num);
}
IAggregateFunction::AddFunc getAddressOfAddFunction() const override final { return &addFree; }
private: private:
enum class PatternActionType enum class PatternActionType
{ {

View File

@ -25,81 +25,89 @@ private:
public: public:
AggregateFunctionState(AggregateFunctionPtr nested_) : nested_func_owner(nested_), nested_func(nested_func_owner.get()) {} AggregateFunctionState(AggregateFunctionPtr nested_) : nested_func_owner(nested_), nested_func(nested_func_owner.get()) {}
String getName() const String getName() const override
{ {
return nested_func->getName() + "State"; return nested_func->getName() + "State";
} }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return new DataTypeAggregateFunction(nested_func_owner, arguments, params); return new DataTypeAggregateFunction(nested_func_owner, arguments, params);
} }
void setArguments(const DataTypes & arguments_) void setArguments(const DataTypes & arguments_) override
{ {
arguments = arguments_; arguments = arguments_;
nested_func->setArguments(arguments); nested_func->setArguments(arguments);
} }
void setParameters(const Array & params_) void setParameters(const Array & params_) override
{ {
params = params_; params = params_;
nested_func->setParameters(params); nested_func->setParameters(params);
} }
void create(AggregateDataPtr place) const void create(AggregateDataPtr place) const override
{ {
nested_func->create(place); nested_func->create(place);
} }
void destroy(AggregateDataPtr place) const noexcept void destroy(AggregateDataPtr place) const noexcept override
{ {
nested_func->destroy(place); nested_func->destroy(place);
} }
bool hasTrivialDestructor() const bool hasTrivialDestructor() const override
{ {
return nested_func->hasTrivialDestructor(); return nested_func->hasTrivialDestructor();
} }
size_t sizeOfData() const size_t sizeOfData() const override
{ {
return nested_func->sizeOfData(); return nested_func->sizeOfData();
} }
size_t alignOfData() const size_t alignOfData() const override
{ {
return nested_func->alignOfData(); return nested_func->alignOfData();
} }
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num) const void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num) const override
{ {
nested_func->add(place, columns, row_num); nested_func->add(place, columns, row_num);
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
nested_func->merge(place, rhs); nested_func->merge(place, rhs);
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
nested_func->serialize(place, buf); nested_func->serialize(place, buf);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
nested_func->deserializeMerge(place, buf); nested_func->deserializeMerge(place, buf);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
static_cast<ColumnAggregateFunction &>(to).getData().push_back(const_cast<AggregateDataPtr>(place)); static_cast<ColumnAggregateFunction &>(to).getData().push_back(const_cast<AggregateDataPtr>(place));
} }
/// Аггрегатная функция или состояние аггрегатной функции. /// Аггрегатная функция или состояние аггрегатной функции.
bool isState() const { return true; } bool isState() const override { return true; }
AggregateFunctionPtr getNestedFunction() const { return nested_func_owner; }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num)
{
return static_cast<const AggregateFunctionState &>(*that).add(place, columns, row_num);
}
IAggregateFunction::AddFunc getAddressOfAddFunction() const override final { return &addFree; }
}; };
} }

View File

@ -25,9 +25,9 @@ template <typename T>
class AggregateFunctionSum final : public IUnaryAggregateFunction<AggregateFunctionSumData<typename NearestFieldType<T>::Type>, AggregateFunctionSum<T> > class AggregateFunctionSum final : public IUnaryAggregateFunction<AggregateFunctionSumData<typename NearestFieldType<T>::Type>, AggregateFunctionSum<T> >
{ {
public: public:
String getName() const { return "sum"; } String getName() const override { return "sum"; }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return new typename DataTypeFromFieldType<typename NearestFieldType<T>::Type>::Type; return new typename DataTypeFromFieldType<typename NearestFieldType<T>::Type>::Type;
} }
@ -40,29 +40,29 @@ public:
} }
void addOne(AggregateDataPtr place, const IColumn & column, size_t row_num) const void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const
{ {
this->data(place).sum += static_cast<const ColumnVector<T> &>(column).getData()[row_num]; this->data(place).sum += static_cast<const ColumnVector<T> &>(column).getData()[row_num];
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
this->data(place).sum += this->data(rhs).sum; this->data(place).sum += this->data(rhs).sum;
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
writeBinary(this->data(place).sum, buf); writeBinary(this->data(place).sum, buf);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
typename NearestFieldType<T>::Type tmp; typename NearestFieldType<T>::Type tmp;
readBinary(tmp, buf); readBinary(tmp, buf);
this->data(place).sum += tmp; this->data(place).sum += tmp;
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
static_cast<ColumnVector<typename NearestFieldType<T>::Type> &>(to).getData().push_back(this->data(place).sum); static_cast<ColumnVector<typename NearestFieldType<T>::Type> &>(to).getData().push_back(this->data(place).sum);
} }

View File

@ -3,7 +3,7 @@
#include <city.h> #include <city.h>
#include <type_traits> #include <type_traits>
#include <stats/UniquesHashSet.h> #include <DB/AggregateFunctions/UniquesHashSet.h>
#include <DB/IO/WriteHelpers.h> #include <DB/IO/WriteHelpers.h>
#include <DB/IO/ReadHelpers.h> #include <DB/IO/ReadHelpers.h>
@ -256,7 +256,7 @@ struct OneAdder<T, Data, typename std::enable_if<
std::is_same<Data, AggregateFunctionUniqHLL12Data<T> >::value>::type> std::is_same<Data, AggregateFunctionUniqHLL12Data<T> >::value>::type>
{ {
template <typename T2 = T> template <typename T2 = T>
static void addOne(Data & data, const IColumn & column, size_t row_num, static void addImpl(Data & data, const IColumn & column, size_t row_num,
typename std::enable_if<!std::is_same<T2, String>::value>::type * = nullptr) typename std::enable_if<!std::is_same<T2, String>::value>::type * = nullptr)
{ {
const auto & value = static_cast<const ColumnVector<T2> &>(column).getData()[row_num]; const auto & value = static_cast<const ColumnVector<T2> &>(column).getData()[row_num];
@ -264,7 +264,7 @@ struct OneAdder<T, Data, typename std::enable_if<
} }
template <typename T2 = T> template <typename T2 = T>
static void addOne(Data & data, const IColumn & column, size_t row_num, static void addImpl(Data & data, const IColumn & column, size_t row_num,
typename std::enable_if<std::is_same<T2, String>::value>::type * = nullptr) typename std::enable_if<std::is_same<T2, String>::value>::type * = nullptr)
{ {
StringRef value = column.getDataAt(row_num); StringRef value = column.getDataAt(row_num);
@ -280,7 +280,7 @@ struct OneAdder<T, Data, typename std::enable_if<
std::is_same<Data, AggregateFunctionUniqCombinedData<T> >::value>::type> std::is_same<Data, AggregateFunctionUniqCombinedData<T> >::value>::type>
{ {
template <typename T2 = T> template <typename T2 = T>
static void addOne(Data & data, const IColumn & column, size_t row_num, static void addImpl(Data & data, const IColumn & column, size_t row_num,
typename std::enable_if<!std::is_same<T2, String>::value>::type * = nullptr) typename std::enable_if<!std::is_same<T2, String>::value>::type * = nullptr)
{ {
const auto & value = static_cast<const ColumnVector<T2> &>(column).getData()[row_num]; const auto & value = static_cast<const ColumnVector<T2> &>(column).getData()[row_num];
@ -288,7 +288,7 @@ struct OneAdder<T, Data, typename std::enable_if<
} }
template <typename T2 = T> template <typename T2 = T>
static void addOne(Data & data, const IColumn & column, size_t row_num, static void addImpl(Data & data, const IColumn & column, size_t row_num,
typename std::enable_if<std::is_same<T2, String>::value>::type * = nullptr) typename std::enable_if<std::is_same<T2, String>::value>::type * = nullptr)
{ {
StringRef value = column.getDataAt(row_num); StringRef value = column.getDataAt(row_num);
@ -301,14 +301,14 @@ struct OneAdder<T, Data, typename std::enable_if<
std::is_same<Data, AggregateFunctionUniqExactData<T> >::value>::type> std::is_same<Data, AggregateFunctionUniqExactData<T> >::value>::type>
{ {
template <typename T2 = T> template <typename T2 = T>
static void addOne(Data & data, const IColumn & column, size_t row_num, static void addImpl(Data & data, const IColumn & column, size_t row_num,
typename std::enable_if<!std::is_same<T2, String>::value>::type * = nullptr) typename std::enable_if<!std::is_same<T2, String>::value>::type * = nullptr)
{ {
data.set.insert(static_cast<const ColumnVector<T2> &>(column).getData()[row_num]); data.set.insert(static_cast<const ColumnVector<T2> &>(column).getData()[row_num]);
} }
template <typename T2 = T> template <typename T2 = T>
static void addOne(Data & data, const IColumn & column, size_t row_num, static void addImpl(Data & data, const IColumn & column, size_t row_num,
typename std::enable_if<std::is_same<T2, String>::value>::type * = nullptr) typename std::enable_if<std::is_same<T2, String>::value>::type * = nullptr)
{ {
StringRef value = column.getDataAt(row_num); StringRef value = column.getDataAt(row_num);
@ -330,9 +330,9 @@ template <typename T, typename Data>
class AggregateFunctionUniq final : public IUnaryAggregateFunction<Data, AggregateFunctionUniq<T, Data> > class AggregateFunctionUniq final : public IUnaryAggregateFunction<Data, AggregateFunctionUniq<T, Data> >
{ {
public: public:
String getName() const { return Data::getName(); } String getName() const override { return Data::getName(); }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return new DataTypeUInt64; return new DataTypeUInt64;
} }
@ -341,27 +341,27 @@ public:
{ {
} }
void addOne(AggregateDataPtr place, const IColumn & column, size_t row_num) const void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const
{ {
detail::OneAdder<T, Data>::addOne(this->data(place), column, row_num); detail::OneAdder<T, Data>::addImpl(this->data(place), column, row_num);
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
this->data(place).set.merge(this->data(rhs).set); this->data(place).set.merge(this->data(rhs).set);
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
this->data(place).set.write(buf); this->data(place).set.write(buf);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
this->data(place).set.readAndMerge(buf); this->data(place).set.readAndMerge(buf);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
static_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size()); static_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size());
} }
@ -381,14 +381,14 @@ private:
size_t num_args = 0; size_t num_args = 0;
public: public:
String getName() const { return Data::getName(); } String getName() const override { return Data::getName(); }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return new DataTypeUInt64; return new DataTypeUInt64;
} }
void setArguments(const DataTypes & arguments) void setArguments(const DataTypes & arguments) override
{ {
if (argument_is_tuple) if (argument_is_tuple)
num_args = typeid_cast<const DataTypeTuple &>(*arguments[0]).getElements().size(); num_args = typeid_cast<const DataTypeTuple &>(*arguments[0]).getElements().size();
@ -396,30 +396,37 @@ public:
num_args = arguments.size(); num_args = arguments.size();
} }
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num) const void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num) const override
{ {
this->data(place).set.insert(UniqVariadicHash<is_exact, argument_is_tuple>::apply(num_args, columns, row_num)); this->data(place).set.insert(UniqVariadicHash<is_exact, argument_is_tuple>::apply(num_args, columns, row_num));
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
this->data(place).set.merge(this->data(rhs).set); this->data(place).set.merge(this->data(rhs).set);
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
this->data(place).set.write(buf); this->data(place).set.write(buf);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
this->data(place).set.readAndMerge(buf); this->data(place).set.readAndMerge(buf);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
static_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size()); static_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size());
} }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num)
{
return static_cast<const AggregateFunctionUniqVariadic &>(*that).add(place, columns, row_num);
}
IAggregateFunction::AddFunc getAddressOfAddFunction() const override final { return &addFree; }
}; };

View File

@ -104,7 +104,7 @@ struct __attribute__((__packed__)) AggregateFunctionUniqUpToData
} }
void addOne(const IColumn & column, size_t row_num, UInt8 threshold) void addImpl(const IColumn & column, size_t row_num, UInt8 threshold)
{ {
insert(static_cast<const ColumnVector<T> &>(column).getData()[row_num], threshold); insert(static_cast<const ColumnVector<T> &>(column).getData()[row_num], threshold);
} }
@ -115,7 +115,7 @@ struct __attribute__((__packed__)) AggregateFunctionUniqUpToData
template <> template <>
struct AggregateFunctionUniqUpToData<String> : AggregateFunctionUniqUpToData<UInt64> struct AggregateFunctionUniqUpToData<String> : AggregateFunctionUniqUpToData<UInt64>
{ {
void addOne(const IColumn & column, size_t row_num, UInt8 threshold) void addImpl(const IColumn & column, size_t row_num, UInt8 threshold)
{ {
/// Имейте ввиду, что вычисление приближённое. /// Имейте ввиду, что вычисление приближённое.
StringRef value = column.getDataAt(row_num); StringRef value = column.getDataAt(row_num);
@ -133,14 +133,14 @@ private:
UInt8 threshold = 5; /// Значение по-умолчанию, если параметр не указан. UInt8 threshold = 5; /// Значение по-умолчанию, если параметр не указан.
public: public:
size_t sizeOfData() const size_t sizeOfData() const override
{ {
return sizeof(AggregateFunctionUniqUpToData<T>) + sizeof(T) * threshold; return sizeof(AggregateFunctionUniqUpToData<T>) + sizeof(T) * threshold;
} }
String getName() const { return "uniqUpTo"; } String getName() const override { return "uniqUpTo"; }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return new DataTypeUInt64; return new DataTypeUInt64;
} }
@ -149,7 +149,7 @@ public:
{ {
} }
void setParameters(const Array & params) void setParameters(const Array & params) override
{ {
if (params.size() != 1) if (params.size() != 1)
throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
@ -163,27 +163,27 @@ public:
threshold = threshold_param; threshold = threshold_param;
} }
void addOne(AggregateDataPtr place, const IColumn & column, size_t row_num) const void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const
{ {
this->data(place).addOne(column, row_num, threshold); this->data(place).addImpl(column, row_num, threshold);
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
this->data(place).merge(this->data(rhs), threshold); this->data(place).merge(this->data(rhs), threshold);
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
this->data(place).write(buf, threshold); this->data(place).write(buf, threshold);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
this->data(place).readAndMerge(buf, threshold); this->data(place).readAndMerge(buf, threshold);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
static_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).size()); static_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).size());
} }
@ -202,14 +202,19 @@ private:
UInt8 threshold = 5; /// Значение по-умолчанию, если параметр не указан. UInt8 threshold = 5; /// Значение по-умолчанию, если параметр не указан.
public: public:
String getName() const { return "uniqUpTo"; } size_t sizeOfData() const override
{
return sizeof(AggregateFunctionUniqUpToData<UInt64>) + sizeof(UInt64) * threshold;
}
DataTypePtr getReturnType() const String getName() const override { return "uniqUpTo"; }
DataTypePtr getReturnType() const override
{ {
return new DataTypeUInt64; return new DataTypeUInt64;
} }
void setArguments(const DataTypes & arguments) void setArguments(const DataTypes & arguments) override
{ {
if (argument_is_tuple) if (argument_is_tuple)
num_args = typeid_cast<const DataTypeTuple &>(*arguments[0]).getElements().size(); num_args = typeid_cast<const DataTypeTuple &>(*arguments[0]).getElements().size();
@ -217,7 +222,7 @@ public:
num_args = arguments.size(); num_args = arguments.size();
} }
void setParameters(const Array & params) void setParameters(const Array & params) override
{ {
if (params.size() != 1) if (params.size() != 1)
throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
@ -231,30 +236,37 @@ public:
threshold = threshold_param; threshold = threshold_param;
} }
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num) const void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num) const override
{ {
this->data(place).insert(UniqVariadicHash<false, argument_is_tuple>::apply(num_args, columns, row_num), threshold); this->data(place).insert(UniqVariadicHash<false, argument_is_tuple>::apply(num_args, columns, row_num), threshold);
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
this->data(place).merge(this->data(rhs), threshold); this->data(place).merge(this->data(rhs), threshold);
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
this->data(place).write(buf, threshold); this->data(place).write(buf, threshold);
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
this->data(place).readAndMerge(buf, threshold); this->data(place).readAndMerge(buf, threshold);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
static_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).size()); static_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).size());
} }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num)
{
return static_cast<const AggregateFunctionUniqUpToVariadic &>(*that).add(place, columns, row_num);
}
IAggregateFunction::AddFunc getAddressOfAddFunction() const override final { return &addFree; }
}; };

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <DB/AggregateFunctions/AggregateFunctionsMinMaxAny.h> #include <DB/AggregateFunctions/AggregateFunctionsMinMaxAny.h>
#include <DB/AggregateFunctions/IBinaryAggregateFunction.h>
namespace DB namespace DB
@ -20,48 +21,45 @@ struct AggregateFunctionsArgMinMaxData
/// Возвращает первое попавшееся значение arg для минимального/максимального value. Пример: argMax(arg, value). /// Возвращает первое попавшееся значение arg для минимального/максимального value. Пример: argMax(arg, value).
template <typename Data> template <typename Data>
class AggregateFunctionsArgMinMax final : public IAggregateFunctionHelper<Data> class AggregateFunctionsArgMinMax final : public IBinaryAggregateFunction<Data, AggregateFunctionsArgMinMax<Data>>
{ {
private: private:
DataTypePtr type_res; DataTypePtr type_res;
DataTypePtr type_val; DataTypePtr type_val;
public: public:
String getName() const { return (0 == strcmp(Data::ValueData_t::name(), "min")) ? "argMin" : "argMax"; } String getName() const override { return (0 == strcmp(Data::ValueData_t::name(), "min")) ? "argMin" : "argMax"; }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return type_res; return type_res;
} }
void setArguments(const DataTypes & arguments) void setArgumentsImpl(const DataTypes & arguments)
{ {
if (arguments.size() != 2)
throw Exception("Aggregate function " + getName() + " requires exactly two arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
type_res = arguments[0]; type_res = arguments[0];
type_val = arguments[1]; type_val = arguments[1];
} }
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num) const void addImpl(AggregateDataPtr place, const IColumn & column_arg, const IColumn & column_max, size_t row_num) const
{ {
if (this->data(place).value.changeIfBetter(*columns[1], row_num)) if (this->data(place).value.changeIfBetter(column_max, row_num))
this->data(place).result.change(*columns[0], row_num); this->data(place).result.change(column_arg, row_num);
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
if (this->data(place).value.changeIfBetter(this->data(rhs).value)) if (this->data(place).value.changeIfBetter(this->data(rhs).value))
this->data(place).result.change(this->data(rhs).result); this->data(place).result.change(this->data(rhs).result);
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
this->data(place).result.write(buf, *type_res.get()); this->data(place).result.write(buf, *type_res.get());
this->data(place).value.write(buf, *type_val.get()); this->data(place).value.write(buf, *type_val.get());
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
Data rhs; /// Для строчек не очень оптимально, так как может делаться одна лишняя аллокация. Data rhs; /// Для строчек не очень оптимально, так как может делаться одна лишняя аллокация.
@ -72,7 +70,7 @@ public:
this->data(place).result.change(rhs.result); this->data(place).result.change(rhs.result);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
this->data(place).result.insertResultInto(to); this->data(place).result.insertResultInto(to);
} }

View File

@ -5,6 +5,7 @@
#include <DB/Columns/ColumnVector.h> #include <DB/Columns/ColumnVector.h>
#include <DB/Columns/ColumnString.h> #include <DB/Columns/ColumnString.h>
#include <DB/DataTypes/DataTypeAggregateFunction.h>
#include <DB/AggregateFunctions/IUnaryAggregateFunction.h> #include <DB/AggregateFunctions/IUnaryAggregateFunction.h>
@ -138,7 +139,7 @@ struct SingleValueDataFixed
/** Для строк. Короткие строки хранятся в самой структуре, а длинные выделяются отдельно. /** Для строк. Короткие строки хранятся в самой структуре, а длинные выделяются отдельно.
* NOTE Могло бы подойти также для массивов чисел. * NOTE Могло бы подойти также для массивов чисел.
*/ */
struct __attribute__((__packed__)) SingleValueDataString struct __attribute__((__packed__, __aligned__(1))) SingleValueDataString
{ {
typedef SingleValueDataString Self; typedef SingleValueDataString Self;
@ -147,10 +148,10 @@ struct __attribute__((__packed__)) SingleValueDataString
static constexpr Int32 AUTOMATIC_STORAGE_SIZE = 64; static constexpr Int32 AUTOMATIC_STORAGE_SIZE = 64;
static constexpr Int32 MAX_SMALL_STRING_SIZE = AUTOMATIC_STORAGE_SIZE - sizeof(size); static constexpr Int32 MAX_SMALL_STRING_SIZE = AUTOMATIC_STORAGE_SIZE - sizeof(size);
union __attribute__((__aligned__(1))) union __attribute__((__packed__, __aligned__(1)))
{ {
char small_data[MAX_SMALL_STRING_SIZE]; /// Включая завершающий ноль. char small_data[MAX_SMALL_STRING_SIZE]; /// Включая завершающий ноль.
char * __attribute__((__aligned__(1))) large_data; char * __attribute__((__packed__, __aligned__(1))) large_data;
}; };
~SingleValueDataString() ~SingleValueDataString()
@ -335,6 +336,10 @@ struct __attribute__((__packed__)) SingleValueDataString
} }
}; };
static_assert(
sizeof(SingleValueDataString) == SingleValueDataString::AUTOMATIC_STORAGE_SIZE,
"Incorrect size of SingleValueDataString struct");
/// Для любых других типов значений. /// Для любых других типов значений.
struct SingleValueDataGeneric struct SingleValueDataGeneric
@ -531,9 +536,9 @@ private:
DataTypePtr type; DataTypePtr type;
public: public:
String getName() const { return Data::name(); } String getName() const override { return Data::name(); }
DataTypePtr getReturnType() const DataTypePtr getReturnType() const override
{ {
return type; return type;
} }
@ -541,25 +546,28 @@ public:
void setArgument(const DataTypePtr & argument) void setArgument(const DataTypePtr & argument)
{ {
type = argument; type = argument;
if (typeid_cast<const DataTypeAggregateFunction *>(type.get()))
throw Exception("Illegal type " + type->getName() + " of argument of aggregate function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
} }
void addOne(AggregateDataPtr place, const IColumn & column, size_t row_num) const void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const
{ {
this->data(place).changeIfBetter(column, row_num); this->data(place).changeIfBetter(column, row_num);
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override
{ {
this->data(place).changeIfBetter(this->data(rhs)); this->data(place).changeIfBetter(this->data(rhs));
} }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
this->data(place).write(buf, *type.get()); this->data(place).write(buf, *type.get());
} }
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
{ {
Data rhs; /// Для строчек не очень оптимально, так как может делаться одна лишняя аллокация. Data rhs; /// Для строчек не очень оптимально, так как может делаться одна лишняя аллокация.
rhs.read(buf, *type.get()); rhs.read(buf, *type.get());
@ -567,7 +575,7 @@ public:
this->data(place).changeIfBetter(rhs); this->data(place).changeIfBetter(rhs);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
this->data(place).insertResultInto(to); this->data(place).insertResultInto(to);
} }

View File

@ -122,14 +122,14 @@ public:
return new DataTypeFloat64; return new DataTypeFloat64;
} }
void setArgument(const DataTypePtr & argument) override void setArgument(const DataTypePtr & argument)
{ {
if (!argument->behavesAsNumber()) if (!argument->behavesAsNumber())
throw Exception("Illegal type " + argument->getName() + " of argument for aggregate function " + getName(), throw Exception("Illegal type " + argument->getName() + " of argument for aggregate function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
} }
void addOne(AggregateDataPtr place, const IColumn & column, size_t row_num) const void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const
{ {
this->data(place).update(column, row_num); this->data(place).update(column, row_num);
} }
@ -400,7 +400,7 @@ public:
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
} }
void addOne(AggregateDataPtr place, const IColumn & column_left, const IColumn & column_right, size_t row_num) const void addImpl(AggregateDataPtr place, const IColumn & column_left, const IColumn & column_right, size_t row_num) const
{ {
this->data(place).update(column_left, column_right, row_num); this->data(place).update(column_left, column_right, row_num);
} }

View File

@ -90,6 +90,16 @@ public:
* Они выполняются как другие агрегатные функции, но не финализируются (возвращают состояние агрегации, которое может быть объединено с другим). * Они выполняются как другие агрегатные функции, но не финализируются (возвращают состояние агрегации, которое может быть объединено с другим).
*/ */
virtual bool isState() const { return false; } virtual bool isState() const { return false; }
/** Внутренний цикл, использующий указатель на функцию, получается лучше, чем использующий виртуальную функцию.
* Причина в том, что в случае виртуальных функций, GCC 5.1.2 генерирует код,
* который на каждой итерации цикла заново грузит из памяти в регистр адрес функции (значение по смещению в таблице виртуальных функций).
* Это даёт падение производительности на простых запросах в районе 12%.
* После появления более хороших компиляторов, код можно будет убрать.
*/
using AddFunc = void (*)(const IAggregateFunction *, AggregateDataPtr, const IColumn **, size_t);
virtual AddFunc getAddressOfAddFunction() const = 0;
}; };
@ -104,28 +114,28 @@ protected:
static const Data & data(ConstAggregateDataPtr place) { return *reinterpret_cast<const Data*>(place); } static const Data & data(ConstAggregateDataPtr place) { return *reinterpret_cast<const Data*>(place); }
public: public:
void create(AggregateDataPtr place) const void create(AggregateDataPtr place) const override
{ {
new (place) Data; new (place) Data;
} }
void destroy(AggregateDataPtr place) const noexcept void destroy(AggregateDataPtr place) const noexcept override
{ {
data(place).~Data(); data(place).~Data();
} }
bool hasTrivialDestructor() const bool hasTrivialDestructor() const override
{ {
return __has_trivial_destructor(Data); return __has_trivial_destructor(Data);
} }
size_t sizeOfData() const size_t sizeOfData() const override
{ {
return sizeof(Data); return sizeof(Data);
} }
/// NOTE: Сейчас не используется (структуры с состоянием агрегации кладутся без выравнивания). /// NOTE: Сейчас не используется (структуры с состоянием агрегации кладутся без выравнивания).
size_t alignOfData() const size_t alignOfData() const override
{ {
return __alignof__(Data); return __alignof__(Data);
} }

View File

@ -8,11 +8,12 @@ namespace DB
template <typename T, typename Derived> template <typename T, typename Derived>
class IBinaryAggregateFunction : public IAggregateFunctionHelper<T> class IBinaryAggregateFunction : public IAggregateFunctionHelper<T>
{ {
private:
Derived & getDerived() { return static_cast<Derived &>(*this); } Derived & getDerived() { return static_cast<Derived &>(*this); }
const Derived & getDerived() const { return static_cast<const Derived &>(*this); } const Derived & getDerived() const { return static_cast<const Derived &>(*this); }
public: public:
void setArguments(const DataTypes & arguments) void setArguments(const DataTypes & arguments) override final
{ {
if (arguments.size() != 2) if (arguments.size() != 2)
throw Exception{ throw Exception{
@ -23,10 +24,17 @@ public:
getDerived().setArgumentsImpl(arguments); getDerived().setArgumentsImpl(arguments);
} }
void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num) const void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num) const override final
{ {
getDerived().addOne(place, *columns[0], *columns[1], row_num); getDerived().addImpl(place, *columns[0], *columns[1], row_num);
} }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num)
{
return static_cast<const Derived &>(*that).addImpl(place, *columns[0], *columns[1], row_num);
}
IAggregateFunction::AddFunc getAddressOfAddFunction() const override final { return &addFree; }
}; };
} }

View File

@ -11,9 +11,13 @@ namespace DB
template <typename T, typename Derived> template <typename T, typename Derived>
class INullaryAggregateFunction : public IAggregateFunctionHelper<T> class INullaryAggregateFunction : public IAggregateFunctionHelper<T>
{ {
private:
Derived & getDerived() { return static_cast<Derived &>(*this); }
const Derived & getDerived() const { return static_cast<const Derived &>(*this); }
public: public:
/// Получить тип результата по типам аргументов. Если функция неприменима для данных аргументов - кинуть исключение. /// Получить тип результата по типам аргументов. Если функция неприменима для данных аргументов - кинуть исключение.
void setArguments(const DataTypes & arguments) void setArguments(const DataTypes & arguments) override final
{ {
if (arguments.size() != 0) if (arguments.size() != 0)
throw Exception("Passed " + toString(arguments.size()) + " arguments to nullary aggregate function " + this->getName(), throw Exception("Passed " + toString(arguments.size()) + " arguments to nullary aggregate function " + this->getName(),
@ -21,13 +25,20 @@ public:
} }
/// Добавить значение. /// Добавить значение.
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num) const void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num) const override final
{ {
static_cast<const Derived &>(*this).addZero(place); getDerived().addImpl(place);
} }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num)
{
return static_cast<const Derived &>(*that).addImpl(place);
}
IAggregateFunction::AddFunc getAddressOfAddFunction() const override final { return &addFree; }
/** Реализуйте это в классе-наследнике: /** Реализуйте это в классе-наследнике:
* void addZero(AggregateDataPtr place) const; * void addImpl(AggregateDataPtr place) const;
*/ */
}; };

View File

@ -11,25 +11,36 @@ namespace DB
template <typename T, typename Derived> template <typename T, typename Derived>
class IUnaryAggregateFunction : public IAggregateFunctionHelper<T> class IUnaryAggregateFunction : public IAggregateFunctionHelper<T>
{ {
private:
Derived & getDerived() { return static_cast<Derived &>(*this); }
const Derived & getDerived() const { return static_cast<const Derived &>(*this); }
public: public:
void setArguments(const DataTypes & arguments) void setArguments(const DataTypes & arguments) override final
{ {
if (arguments.size() != 1) if (arguments.size() != 1)
throw Exception("Passed " + toString(arguments.size()) + " arguments to unary aggregate function " + this->getName(), throw Exception("Passed " + toString(arguments.size()) + " arguments to unary aggregate function " + this->getName(),
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
setArgument(arguments[0]);
}
virtual void setArgument(const DataTypePtr & argument) = 0; getDerived().setArgument(arguments[0]);
}
/// Добавить значение. /// Добавить значение.
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num) const void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num) const override final
{ {
static_cast<const Derived &>(*this).addOne(place, *columns[0], row_num); getDerived().addImpl(place, *columns[0], row_num);
} }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num)
{
return static_cast<const Derived &>(*that).addImpl(place, *columns[0], row_num);
}
IAggregateFunction::AddFunc getAddressOfAddFunction() const override final { return &addFree; }
/** Реализуйте это в классе-наследнике: /** Реализуйте это в классе-наследнике:
* void addOne(AggregateDataPtr place, const IColumn & column, size_t row_num) const; * void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const;
* void setArgument(const DataTypePtr & argument);
*/ */
}; };

View File

@ -0,0 +1,52 @@
#pragma once
#include <vector>
#include <DB/Core/Field.h>
#include <DB/Core/FieldVisitors.h>
namespace DB
{
/** Параметры разных функций quantilesSomething.
* - список уровней квантилей.
* Также необходимо вычислить массив индексов уровней, идущих по возрастанию.
*
* Пример: quantiles(0.5, 0.99, 0.95)(x).
* levels: 0.5, 0.99, 0.95
* levels_permutation: 0, 2, 1
*/
template <typename T> /// float или double
struct QuantileLevels
{
using Levels = std::vector<T>;
using Permutation = std::vector<size_t>;
Levels levels;
Permutation permutation; /// Индекс i-го по величине уровня в массиве levels.
size_t size() const { return levels.size(); }
void set(const Array & params)
{
if (params.empty())
throw Exception("Aggregate function quantiles requires at least one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
size_t size = params.size();
levels.resize(size);
permutation.resize(size);
for (size_t i = 0; i < size; ++i)
{
levels[i] = apply_visitor(FieldVisitorConvertToNumber<Float64>(), params[i]);
permutation[i] = i;
}
std::sort(permutation.begin(), permutation.end(), [this] (size_t a, size_t b) { return levels[a] < levels[b]; });
}
};
}

View File

@ -0,0 +1,237 @@
#pragma once
#include <limits>
#include <algorithm>
#include <climits>
#include <sstream>
#include <common/Common.h>
#include <DB/IO/ReadBuffer.h>
#include <DB/IO/ReadHelpers.h>
#include <DB/IO/WriteHelpers.h>
#include <DB/Common/PODArray.h>
#include <Poco/Exception.h>
#include <boost/random.hpp>
/// Реализация алгоритма Reservoir Sampling. Инкрементально выбирает из добавленных объектов случайное подмножество размера sample_count.
/// Умеет приближенно получать квантили.
/// Вызов quantile занимает O(sample_count log sample_count), если после предыдущего вызова quantile был хотя бы один вызов insert. Иначе, O(1).
/// То есть, имеет смысл сначала добавлять, потом получать квантили, не добавляя.
const size_t DEFAULT_SAMPLE_COUNT = 8192;
/// Что делать, если нет ни одного значения - кинуть исключение, или вернуть 0 или NaN в случае double?
namespace ReservoirSamplerOnEmpty
{
enum Enum
{
THROW,
RETURN_NAN_OR_ZERO,
};
}
template<typename ResultType, bool IsFloatingPoint>
struct NanLikeValueConstructor
{
static ResultType getValue()
{
return std::numeric_limits<ResultType>::quiet_NaN();
}
};
template<typename ResultType>
struct NanLikeValueConstructor<ResultType, false>
{
static ResultType getValue()
{
return ResultType();
}
};
template<typename T, ReservoirSamplerOnEmpty::Enum OnEmpty = ReservoirSamplerOnEmpty::THROW, typename Comparer = std::less<T> >
class ReservoirSampler
{
public:
ReservoirSampler(size_t sample_count_ = DEFAULT_SAMPLE_COUNT)
: sample_count(sample_count_)
{
rng.seed(123456);
}
void clear()
{
samples.clear();
sorted = false;
total_values = 0;
rng.seed(123456);
}
void insert(const T & v)
{
sorted = false;
++total_values;
if (samples.size() < sample_count)
{
samples.push_back(v);
}
else
{
UInt64 rnd = genRandom(total_values);
if (rnd < sample_count)
samples[rnd] = v;
}
}
size_t size() const
{
return total_values;
}
T quantileNearest(double level)
{
if (samples.empty())
return onEmpty<T>();
sortIfNeeded();
double index = level * (samples.size() - 1);
size_t int_index = static_cast<size_t>(index + 0.5);
int_index = std::max(0LU, std::min(samples.size() - 1, int_index));
return samples[int_index];
}
/** Если T не числовой тип, использование этого метода вызывает ошибку компиляции,
* но использование класса ошибки не вызывает. SFINAE.
*/
double quantileInterpolated(double level)
{
if (samples.empty())
return onEmpty<double>();
sortIfNeeded();
double index = std::max(0., std::min(samples.size() - 1., level * (samples.size() - 1)));
/// Чтобы получить значение по дробному индексу линейно интерполируем между соседними значениями.
size_t left_index = static_cast<size_t>(index);
size_t right_index = left_index + 1;
if (right_index == samples.size())
return samples[left_index];
double left_coef = right_index - index;
double right_coef = index - left_index;
return samples[left_index] * left_coef + samples[right_index] * right_coef;
}
void merge(const ReservoirSampler<T, OnEmpty> & b)
{
if (sample_count != b.sample_count)
throw Poco::Exception("Cannot merge ReservoirSampler's with different sample_count");
sorted = false;
if (b.total_values <= sample_count)
{
for (size_t i = 0; i < b.samples.size(); ++i)
insert(b.samples[i]);
}
else if (total_values <= sample_count)
{
Array from = std::move(samples);
samples.assign(b.samples.begin(), b.samples.end());
total_values = b.total_values;
for (size_t i = 0; i < from.size(); ++i)
insert(from[i]);
}
else
{
randomShuffle(samples);
total_values += b.total_values;
for (size_t i = 0; i < sample_count; ++i)
{
UInt64 rnd = genRandom(total_values);
if (rnd < b.total_values)
samples[i] = b.samples[i];
}
}
}
void read(DB::ReadBuffer & buf)
{
DB::readIntBinary<size_t>(sample_count, buf);
DB::readIntBinary<size_t>(total_values, buf);
samples.resize(std::min(total_values, sample_count));
std::string rng_string;
DB::readStringBinary(rng_string, buf);
std::istringstream rng_stream(rng_string);
rng_stream >> rng;
for (size_t i = 0; i < samples.size(); ++i)
DB::readBinary(samples[i], buf);
sorted = false;
}
void write(DB::WriteBuffer & buf) const
{
DB::writeIntBinary<size_t>(sample_count, buf);
DB::writeIntBinary<size_t>(total_values, buf);
std::ostringstream rng_stream;
rng_stream << rng;
DB::writeStringBinary(rng_stream.str(), buf);
for (size_t i = 0; i < std::min(sample_count, total_values); ++i)
DB::writeBinary(samples[i], buf);
}
private:
friend void qdigest_test(int normal_size, UInt64 value_limit, const std::vector<UInt64> & values, int queries_count, bool verbose);
friend void rs_perf_test();
/// Будем выделять немного памяти на стеке - чтобы избежать аллокаций, когда есть много объектов с маленьким количеством элементов.
static constexpr size_t bytes_on_stack = 64;
using Array = DB::PODArray<T, bytes_on_stack / sizeof(T), AllocatorWithStackMemory<Allocator<false>, bytes_on_stack>>;
size_t sample_count;
size_t total_values = 0;
Array samples;
boost::taus88 rng;
bool sorted = false;
UInt64 genRandom(size_t lim)
{
/// При большом количестве значений будем генерировать случайные числа в несколько раз медленнее.
if (lim <= static_cast<UInt64>(rng.max()))
return static_cast<UInt32>(rng()) % static_cast<UInt32>(lim);
else
return (static_cast<UInt64>(rng()) * (static_cast<UInt64>(rng.max()) + 1ULL) + static_cast<UInt64>(rng())) % lim;
}
void randomShuffle(Array & v)
{
for (size_t i = 1; i < v.size(); ++i)
{
size_t j = genRandom(i + 1);
std::swap(v[i], v[j]);
}
}
void sortIfNeeded()
{
if (sorted)
return;
sorted = true;
std::sort(samples.begin(), samples.end(), Comparer());
}
template <typename ResultType>
ResultType onEmpty() const
{
if (OnEmpty == ReservoirSamplerOnEmpty::THROW)
throw Poco::Exception("Quantile of empty ReservoirSampler");
else
return NanLikeValueConstructor<ResultType, std::is_floating_point<ResultType>::value>::getValue();
}
};

View File

@ -1,16 +1,16 @@
#pragma once #pragma once
#include <limits> #include <limits>
#include <vector>
#include <algorithm> #include <algorithm>
#include <climits> #include <climits>
#include <sstream> #include <sstream>
#include <stats/ReservoirSampler.h> #include <DB/AggregateFunctions/ReservoirSampler.h>
#include <common/Common.h> #include <common/Common.h>
#include <DB/Common/HashTable/Hash.h> #include <DB/Common/HashTable/Hash.h>
#include <DB/IO/ReadBuffer.h> #include <DB/IO/ReadBuffer.h>
#include <DB/IO/ReadHelpers.h> #include <DB/IO/ReadHelpers.h>
#include <DB/IO/WriteHelpers.h> #include <DB/IO/WriteHelpers.h>
#include <DB/Common/PODArray.h>
#include <Poco/Exception.h> #include <Poco/Exception.h>
#include <boost/random.hpp> #include <boost/random.hpp>
@ -150,13 +150,15 @@ public:
} }
private: private:
friend void rs_perf_test(); /// Будем выделять немного памяти на стеке - чтобы избежать аллокаций, когда есть много объектов с маленьким количеством элементов.
friend void qdigest_test(int, UInt64, const std::vector<UInt64> &, int, bool); static constexpr size_t bytes_on_stack = 64;
using Element = std::pair<T, UInt32>;
using Array = DB::PODArray<Element, bytes_on_stack / sizeof(Element), AllocatorWithStackMemory<Allocator<false>, bytes_on_stack>>;
size_t sample_count; size_t sample_count;
size_t total_values{}; size_t total_values{};
bool sorted{}; bool sorted{};
std::vector<std::pair<T, UInt32>> samples; Array samples;
UInt8 skip_degree{}; UInt8 skip_degree{};
void insertImpl(const T & v, const UInt32 hash) void insertImpl(const T & v, const UInt32 hash)

View File

@ -0,0 +1,540 @@
#pragma once
#include <math.h>
#include <common/Common.h>
#include <DB/IO/WriteBuffer.h>
#include <DB/IO/WriteHelpers.h>
#include <DB/IO/ReadBuffer.h>
#include <DB/IO/ReadHelpers.h>
#include <DB/IO/VarInt.h>
#include <DB/Common/HashTable/HashTableAllocator.h>
#include <DB/Common/HashTable/Hash.h>
/** Приближённый рассчёт чего-угодно, как правило, построен по следующей схеме:
* - для рассчёта значения X используется некоторая структура данных;
* - в структуру данных добавляются не все значения, а только избранные (согласно некоторому критерию избранности);
* - после обработки всех элементов, структура данных находится в некотором состоянии S;
* - в качестве приближённого значения X возвращается значание, посчитанное по принципу максимального правдоподобия:
* при каком реальном значении X, вероятность нахождения структуры данных в полученном состоянии S максимальна.
*/
/** В частности, то, что описано ниже, можно найти по названию BJKST algorithm.
*/
/** Очень простое хэш-множество для приближённого подсчёта количества уникальных значений.
* Работает так:
* - вставлять можно UInt64;
* - перед вставкой, сначала вычисляется хэш-функция UInt64 -> UInt32;
* - исходное значение не сохраняется (теряется);
* - далее все операции производятся с этими хэшами;
* - хэш таблица построена по схеме:
* - open addressing (один буфер, позиция в буфере вычисляется взятием остатка от деления на его размер);
* - linear probing (если в ячейке уже есть значение, то берётся ячейка, следующая за ней и т. д.);
* - отсутствующее значение кодируется нулём; чтобы запомнить наличие в множестве нуля, используется отдельная переменная типа bool;
* - рост буфера в 2 раза при заполнении более чем на 50%;
* - если в множестве больше UNIQUES_HASH_MAX_SIZE элементов, то из множества удаляются все элементы,
* не делящиеся на 2, и затем все элементы, которые не делятся на 2, не вставляются в множество;
* - если ситуация повторяется, то берутся только элементы делящиеся на 4 и т. п.
* - метод size() возвращает приблизительное количество элементов, которые были вставлены в множество;
* - есть методы для быстрого чтения и записи в бинарный и текстовый вид.
*/
/// Максимальная степень размера буфера перед тем, как значения будут выкидываться
#define UNIQUES_HASH_MAX_SIZE_DEGREE 17
/// Максимальное количество элементов перед тем, как значения будут выкидываться
#define UNIQUES_HASH_MAX_SIZE (1 << (UNIQUES_HASH_MAX_SIZE_DEGREE - 1))
/** Количество младших бит, использующихся для прореживания. Оставшиеся старшие биты используются для определения позиции в хэш-таблице.
* (старшие биты берутся потому что младшие будут постоянными после выкидывания части значений)
*/
#define UNIQUES_HASH_BITS_FOR_SKIP (32 - UNIQUES_HASH_MAX_SIZE_DEGREE)
/// Начальная степень размера буфера
#define UNIQUES_HASH_SET_INITIAL_SIZE_DEGREE 4
/** Эта хэш-функция не самая оптимальная, но состояния UniquesHashSet, посчитанные с ней,
* хранятся много где на дисках (в Метраже), поэтому она продолжает использоваться.
*/
struct UniquesHashSetDefaultHash
{
size_t operator() (UInt64 x) const
{
return intHash32<0>(x);
}
};
template <typename Hash = UniquesHashSetDefaultHash>
class UniquesHashSet : private HashTableAllocatorWithStackMemory<(1 << UNIQUES_HASH_SET_INITIAL_SIZE_DEGREE) * sizeof(UInt32)>
{
private:
typedef UInt64 Value_t;
typedef UInt32 HashValue_t;
typedef HashTableAllocatorWithStackMemory<(1 << UNIQUES_HASH_SET_INITIAL_SIZE_DEGREE) * sizeof(UInt32)> Allocator;
UInt32 m_size; /// Количество элементов
UInt8 size_degree; /// Размер таблицы в виде степени двух
UInt8 skip_degree; /// Пропускать элементы не делящиеся на 2 ^ skip_degree
bool has_zero; /// Хэш-таблица содержит элемент со значением хэш-функции = 0.
HashValue_t * buf;
#ifdef UNIQUES_HASH_SET_COUNT_COLLISIONS
/// Для профилирования.
mutable size_t collisions;
#endif
void alloc(UInt8 new_size_degree)
{
buf = reinterpret_cast<HashValue_t *>(Allocator::alloc((1 << new_size_degree) * sizeof(buf[0])));
size_degree = new_size_degree;
}
void free()
{
if (buf)
{
Allocator::free(buf, buf_size() * sizeof(buf[0]));
buf = nullptr;
}
}
inline size_t buf_size() const { return 1 << size_degree; }
inline size_t max_fill() const { return 1 << (size_degree - 1); }
inline size_t mask() const { return buf_size() - 1; }
inline size_t place(HashValue_t x) const { return (x >> UNIQUES_HASH_BITS_FOR_SKIP) & mask(); }
/// Значение делится на 2 ^ skip_degree
inline bool good(HashValue_t hash) const
{
return hash == ((hash >> skip_degree) << skip_degree);
}
HashValue_t hash(Value_t key) const
{
return Hash()(key);
}
/// Удалить все значения, хэши которых не делятся на 2 ^ skip_degree
void rehash()
{
for (size_t i = 0; i < buf_size(); ++i)
{
if (buf[i] && !good(buf[i]))
{
buf[i] = 0;
--m_size;
}
}
/** После удаления элементов, возможно, освободилось место для элементов,
* которые были помещены дальше, чем нужно, из-за коллизии.
* Надо переместить их.
*/
for (size_t i = 0; i < buf_size(); ++i)
{
if (unlikely(buf[i] && i != place(buf[i])))
{
HashValue_t x = buf[i];
buf[i] = 0;
reinsertImpl(x);
}
}
}
/// Увеличить размер буфера в 2 раза или до new_size_degree, если указана ненулевая.
void resize(size_t new_size_degree = 0)
{
size_t old_size = buf_size();
if (!new_size_degree)
new_size_degree = size_degree + 1;
/// Расширим пространство.
buf = reinterpret_cast<HashValue_t *>(Allocator::realloc(buf, old_size * sizeof(buf[0]), (1 << new_size_degree) * sizeof(buf[0])));
size_degree = new_size_degree;
/** Теперь некоторые элементы может потребоваться переместить на новое место.
* Элемент может остаться на месте, или переместиться в новое место "справа",
* или переместиться левее по цепочке разрешения коллизий, из-за того, что элементы левее него были перемещены в новое место "справа".
* Также имеется особый случай:
* если элемент должен был быть в конце старого буфера, [ x]
* но находится в начале из-за цепочки разрешения коллизий, [o x]
* то после ресайза, он сначала снова окажется не на своём месте, [ xo ]
* и для того, чтобы перенести его куда надо,
* надо будет после переноса всех элементов из старой половинки [ o x ]
* обработать ещё хвостик из цепочки разрешения коллизий сразу после неё [ o x ]
* Именно для этого написано || buf[i] ниже.
*/
for (size_t i = 0; i < old_size || buf[i]; ++i)
{
HashValue_t x = buf[i];
if (!x)
continue;
size_t place_value = place(x);
/// Элемент на своём месте.
if (place_value == i)
continue;
while (buf[place_value] && buf[place_value] != x)
{
++place_value;
place_value &= mask();
#ifdef UNIQUES_HASH_SET_COUNT_COLLISIONS
++collisions;
#endif
}
/// Элемент остался на своём месте.
if (buf[place_value] == x)
continue;
buf[place_value] = x;
buf[i] = 0;
}
}
/// Вставить значение.
void insertImpl(HashValue_t x)
{
if (x == 0)
{
m_size += !has_zero;
has_zero = true;
return;
}
size_t place_value = place(x);
while (buf[place_value] && buf[place_value] != x)
{
++place_value;
place_value &= mask();
#ifdef UNIQUES_HASH_SET_COUNT_COLLISIONS
++collisions;
#endif
}
if (buf[place_value] == x)
return;
buf[place_value] = x;
++m_size;
}
/** Вставить в новый буфер значение, которое было в старом буфере.
* Используется при увеличении размера буфера, а также при чтении из файла.
*/
void reinsertImpl(HashValue_t x)
{
size_t place_value = place(x);
while (buf[place_value])
{
++place_value;
place_value &= mask();
#ifdef UNIQUES_HASH_SET_COUNT_COLLISIONS
++collisions;
#endif
}
buf[place_value] = x;
}
/** Если хэш-таблица достаточно заполнена, то сделать resize.
* Если элементов слишком много - то выкидывать половину, пока их не станет достаточно мало.
*/
void shrinkIfNeed()
{
if (unlikely(m_size > max_fill()))
{
if (m_size > UNIQUES_HASH_MAX_SIZE)
{
while (m_size > UNIQUES_HASH_MAX_SIZE)
{
++skip_degree;
rehash();
}
}
else
resize();
}
}
public:
UniquesHashSet() :
m_size(0),
skip_degree(0),
has_zero(false)
{
alloc(UNIQUES_HASH_SET_INITIAL_SIZE_DEGREE);
#ifdef UNIQUES_HASH_SET_COUNT_COLLISIONS
collisions = 0;
#endif
}
UniquesHashSet(const UniquesHashSet & rhs)
: m_size(rhs.m_size), skip_degree(rhs.skip_degree), has_zero(rhs.has_zero)
{
alloc(rhs.size_degree);
memcpy(buf, rhs.buf, buf_size() * sizeof(buf[0]));
}
UniquesHashSet & operator= (const UniquesHashSet & rhs)
{
if (size_degree != rhs.size_degree)
{
free();
alloc(rhs.size_degree);
}
m_size = rhs.m_size;
skip_degree = rhs.skip_degree;
has_zero = rhs.has_zero;
memcpy(buf, rhs.buf, buf_size() * sizeof(buf[0]));
return *this;
}
~UniquesHashSet()
{
free();
}
void insert(Value_t x)
{
HashValue_t hash_value = hash(x);
if (!good(hash_value))
return;
insertImpl(hash_value);
shrinkIfNeed();
}
size_t size() const
{
if (0 == skip_degree)
return m_size;
size_t res = m_size * (1 << skip_degree);
/** Псевдослучайный остаток - для того, чтобы не было видно,
* что количество делится на степень двух.
*/
res += (intHashCRC32(m_size) & ((1 << skip_degree) - 1));
/** Коррекция систематической погрешности из-за коллизий при хэшировании в UInt32.
* Формула fixed_res(res)
* - при каком количестве разных элементов fixed_res,
* при их случайном разбрасывании по 2^32 корзинам,
* получается в среднем res заполненных корзин.
*/
size_t p32 = 1ULL << 32;
size_t fixed_res = round(p32 * (log(p32) - log(p32 - res)));
return fixed_res;
}
void merge(const UniquesHashSet & rhs)
{
if (rhs.skip_degree > skip_degree)
{
skip_degree = rhs.skip_degree;
rehash();
}
if (!has_zero && rhs.has_zero)
{
has_zero = true;
++m_size;
shrinkIfNeed();
}
for (size_t i = 0; i < rhs.buf_size(); ++i)
{
if (rhs.buf[i] && good(rhs.buf[i]))
{
insertImpl(rhs.buf[i]);
shrinkIfNeed();
}
}
}
void write(DB::WriteBuffer & wb) const
{
if (m_size > UNIQUES_HASH_MAX_SIZE)
throw Poco::Exception("Cannot write UniquesHashSet: too large size_degree.");
DB::writeIntBinary(skip_degree, wb);
DB::writeVarUInt(m_size, wb);
if (has_zero)
{
HashValue_t x = 0;
DB::writeIntBinary(x, wb);
}
for (size_t i = 0; i < buf_size(); ++i)
if (buf[i])
DB::writeIntBinary(buf[i], wb);
}
void read(DB::ReadBuffer & rb)
{
has_zero = false;
DB::readIntBinary(skip_degree, rb);
DB::readVarUInt(m_size, rb);
if (m_size > UNIQUES_HASH_MAX_SIZE)
throw Poco::Exception("Cannot read UniquesHashSet: too large size_degree.");
free();
UInt8 new_size_degree = m_size <= 1
? UNIQUES_HASH_SET_INITIAL_SIZE_DEGREE
: std::max(UNIQUES_HASH_SET_INITIAL_SIZE_DEGREE, static_cast<int>(log2(m_size - 1)) + 2);
alloc(new_size_degree);
for (size_t i = 0; i < m_size; ++i)
{
HashValue_t x = 0;
DB::readIntBinary(x, rb);
if (x == 0)
has_zero = true;
else
reinsertImpl(x);
}
}
void readAndMerge(DB::ReadBuffer & rb)
{
UInt8 rhs_skip_degree = 0;
DB::readIntBinary(rhs_skip_degree, rb);
if (rhs_skip_degree > skip_degree)
{
skip_degree = rhs_skip_degree;
rehash();
}
size_t rhs_size = 0;
DB::readVarUInt(rhs_size, rb);
if (rhs_size > UNIQUES_HASH_MAX_SIZE)
throw Poco::Exception("Cannot read UniquesHashSet: too large size_degree.");
if ((1U << size_degree) < rhs_size)
{
UInt8 new_size_degree = std::max(UNIQUES_HASH_SET_INITIAL_SIZE_DEGREE, static_cast<int>(log2(rhs_size - 1)) + 2);
resize(new_size_degree);
}
for (size_t i = 0; i < rhs_size; ++i)
{
HashValue_t x = 0;
DB::readIntBinary(x, rb);
insertHash(x);
}
}
static void skip(DB::ReadBuffer & rb)
{
size_t size = 0;
rb.ignore();
DB::readVarUInt(size, rb);
if (size > UNIQUES_HASH_MAX_SIZE)
throw Poco::Exception("Cannot read UniquesHashSet: too large size_degree.");
rb.ignore(sizeof(HashValue_t) * size);
}
void writeText(DB::WriteBuffer & wb) const
{
if (m_size > UNIQUES_HASH_MAX_SIZE)
throw Poco::Exception("Cannot write UniquesHashSet: too large size_degree.");
DB::writeIntText(skip_degree, wb);
wb.write(",", 1);
DB::writeIntText(m_size, wb);
if (has_zero)
wb.write(",0", 2);
for (size_t i = 0; i < buf_size(); ++i)
{
if (buf[i])
{
wb.write(",", 1);
DB::writeIntText(buf[i], wb);
}
}
}
void readText(DB::ReadBuffer & rb)
{
has_zero = false;
DB::readIntText(skip_degree, rb);
DB::assertString(",", rb);
DB::readIntText(m_size, rb);
if (m_size > UNIQUES_HASH_MAX_SIZE)
throw Poco::Exception("Cannot read UniquesHashSet: too large size_degree.");
free();
UInt8 new_size_degree = m_size <= 1
? UNIQUES_HASH_SET_INITIAL_SIZE_DEGREE
: std::max(UNIQUES_HASH_SET_INITIAL_SIZE_DEGREE, static_cast<int>(log2(m_size - 1)) + 2);
alloc(new_size_degree);
for (size_t i = 0; i < m_size; ++i)
{
HashValue_t x = 0;
DB::assertString(",", rb);
DB::readIntText(x, rb);
if (x == 0)
has_zero = true;
else
reinsertImpl(x);
}
}
void insertHash(HashValue_t hash_value)
{
if (!good(hash_value))
return;
insertImpl(hash_value);
shrinkIfNeed();
}
#ifdef UNIQUES_HASH_SET_COUNT_COLLISIONS
size_t getCollisions() const
{
return collisions;
}
#endif
};
#undef UNIQUES_HASH_MAX_SIZE_DEGREE
#undef UNIQUES_HASH_MAX_SIZE
#undef UNIQUES_HASH_BITS_FOR_SKIP
#undef UNIQUES_HASH_SET_INITIAL_SIZE_DEGREE

View File

@ -26,8 +26,6 @@ namespace DB
using Poco::SharedPtr; using Poco::SharedPtr;
class ParallelReplicas;
/// Поток блоков читающих из таблицы и ее имя /// Поток блоков читающих из таблицы и ее имя
typedef std::pair<BlockInputStreamPtr, std::string> ExternalTableData; typedef std::pair<BlockInputStreamPtr, std::string> ExternalTableData;
/// Вектор пар, описывающих таблицы /// Вектор пар, описывающих таблицы
@ -44,6 +42,7 @@ typedef std::vector<ExternalTableData> ExternalTablesData;
class Connection : private boost::noncopyable class Connection : private boost::noncopyable
{ {
friend class ParallelReplicas; friend class ParallelReplicas;
friend class MultiplexedConnections;
public: public:
Connection(const String & host_, UInt16 port_, const String & default_database_, Connection(const String & host_, UInt16 port_, const String & default_database_,

View File

@ -57,7 +57,7 @@ protected:
typedef SharedPtr<IConnectionPool> ConnectionPoolPtr; typedef SharedPtr<IConnectionPool> ConnectionPoolPtr;
typedef std::vector<ConnectionPoolPtr> ConnectionPools; typedef std::vector<ConnectionPoolPtr> ConnectionPools;
typedef SharedPtr<ConnectionPools> ConnectionPoolsPtr;
/** Обычный пул соединений, без отказоустойчивости. /** Обычный пул соединений, без отказоустойчивости.

View File

@ -6,30 +6,38 @@
#include <Poco/ScopedLock.h> #include <Poco/ScopedLock.h>
#include <Poco/Mutex.h> #include <Poco/Mutex.h>
namespace DB namespace DB
{ {
/** Для получения данных сразу из нескольких реплик (соединений) в рамках одного потока. /** Для получения данных сразу из нескольких реплик (соединений) из одного или нексольких шардов
* В качестве вырожденного случая, может также работать с одним соединением. * в рамках одного потока. В качестве вырожденного случая, может также работать с одним соединением.
* Предполагается, что все функции кроме sendCancel всегда выполняются в одном потоке. * Предполагается, что все функции кроме sendCancel всегда выполняются в одном потоке.
* *
* Интерфейс почти совпадает с Connection. * Интерфейс почти совпадает с Connection.
*/ */
class ParallelReplicas final : private boost::noncopyable class MultiplexedConnections final : private boost::noncopyable
{ {
public: public:
/// Принимает готовое соединение. /// Принимает готовое соединение.
ParallelReplicas(Connection * connection_, const Settings * settings_, ThrottlerPtr throttler_); MultiplexedConnections(Connection * connection_, const Settings * settings_, ThrottlerPtr throttler_);
/** Принимает пул, из которого нужно будет достать одно или несколько соединений. /** Принимает пул, из которого нужно будет достать одно или несколько соединений.
* Если флаг append_extra_info установлен, к каждому полученному блоку прилагается * Если флаг append_extra_info установлен, к каждому полученному блоку прилагается
* дополнительная информация. * дополнительная информация.
* Если флаг get_all_replicas установлен, достаются все соединения. * Если флаг get_all_replicas установлен, достаются все соединения.
*/ */
ParallelReplicas(IConnectionPool * pool_, const Settings * settings_, ThrottlerPtr throttler_, MultiplexedConnections(IConnectionPool * pool_, const Settings * settings_, ThrottlerPtr throttler_,
bool append_extra_info = false, bool get_all_replicas = false); bool append_extra_info = false, bool do_broadcast = false);
/** Принимает пулы, один для каждого шарда, из которих нужно будет достать одно или несколько
* соединений.
* Если флаг append_extra_info установлен, к каждому полученному блоку прилагается
* дополнительная информация.
* Если флаг do_broadcast установлен, достаются все соединения.
*/
MultiplexedConnections(ConnectionPools & pools_, const Settings * settings_, ThrottlerPtr throttler_,
bool append_extra_info = false, bool do_broadcast = false);
/// Отправить на реплики всё содержимое внешних таблиц. /// Отправить на реплики всё содержимое внешних таблиц.
void sendExternalTablesData(std::vector<ExternalTablesData> & data); void sendExternalTablesData(std::vector<ExternalTablesData> & data);
@ -65,15 +73,44 @@ public:
/// Проверить, есть ли действительные реплики. /// Проверить, есть ли действительные реплики.
/// Без блокировки, потому что sendCancel() не меняет состояние реплик. /// Без блокировки, потому что sendCancel() не меняет состояние реплик.
bool hasActiveReplicas() const { return active_replica_count > 0; } bool hasActiveConnections() const { return active_connection_total_count > 0; }
private: private:
/// Реплики хэшированные по id сокета /// Соединения 1-го шарда, затем соединения 2-го шарда, и т.д.
using ReplicaMap = std::unordered_map<int, Connection *>; using Connections = std::vector<Connection *>;
/// Состояние соединений одного шарда.
struct ShardState
{
/// Количество выделенных соединений, т.е. реплик, для этого шарда.
size_t allocated_connection_count;
/// Текущее количество действительных соединений к репликам этого шарда.
size_t active_connection_count;
};
/// Описание одной реплики.
struct ReplicaState
{
/// Индекс соединения.
size_t connection_index;
/// Владелец этой реплики.
ShardState * shard_state;
};
/// Реплики хэшированные по id сокета.
using ReplicaMap = std::unordered_map<int, ReplicaState>;
/// Состояние каждого шарда.
using ShardStates = std::vector<ShardState>;
private: private:
/// Зарегистрировать реплику. void initFromShard(IConnectionPool * pool);
void registerReplica(Connection * connection);
/// Зарегистрировать шарды.
void registerShards();
/// Зарегистрировать реплики одного шарда.
void registerReplicas(size_t index_begin, size_t index_end, ShardState & shard_state);
/// Внутренняя версия функции receivePacket без блокировки. /// Внутренняя версия функции receivePacket без блокировки.
Connection::Packet receivePacketUnlocked(); Connection::Packet receivePacketUnlocked();
@ -94,13 +131,15 @@ private:
private: private:
const Settings * settings; const Settings * settings;
Connections connections;
ReplicaMap replica_map; ReplicaMap replica_map;
ShardStates shard_states;
/// Если не nullptr, то используется, чтобы ограничить сетевой трафик. /// Если не nullptr, то используется, чтобы ограничить сетевой трафик.
ThrottlerPtr throttler; ThrottlerPtr throttler;
std::vector<ConnectionPool::Entry> pool_entries; std::vector<ConnectionPool::Entry> pool_entries;
ConnectionPool::Entry pool_entry;
/// Соединение, c которого был получен последний блок. /// Соединение, c которого был получен последний блок.
Connection * current_connection; Connection * current_connection;
@ -108,7 +147,7 @@ private:
std::unique_ptr<BlockExtraInfo> block_extra_info; std::unique_ptr<BlockExtraInfo> block_extra_info;
/// Текущее количество действительных соединений к репликам. /// Текущее количество действительных соединений к репликам.
size_t active_replica_count; size_t active_connection_total_count = 0;
/// Запрос выполняется параллельно на нескольких репликах. /// Запрос выполняется параллельно на нескольких репликах.
bool supports_parallel_execution; bool supports_parallel_execution;
/// Отправили запрос /// Отправили запрос
@ -116,6 +155,8 @@ private:
/// Отменили запрос /// Отменили запрос
bool cancelled = false; bool cancelled = false;
bool do_broadcast = false;
/// Мьютекс для того, чтобы функция sendCancel могла выполняться безопасно /// Мьютекс для того, чтобы функция sendCancel могла выполняться безопасно
/// в отдельном потоке. /// в отдельном потоке.
mutable Poco::FastMutex cancel_mutex; mutable Poco::FastMutex cancel_mutex;

View File

@ -103,18 +103,9 @@ public:
arenas.push_back(arena_); arenas.push_back(arena_);
} }
ColumnPtr convertToValues() const /** Преобразовать столбец состояний агрегатной функции в столбец с готовыми значениями результатов.
{ */
const IAggregateFunction * function = holder->func; ColumnPtr convertToValues() const;
ColumnPtr res = function->getReturnType()->createColumn();
IColumn & column = *res;
res->reserve(getData().size());
for (auto val : getData())
function->insertResultInto(val, column);
return res;
}
std::string getName() const override { return "ColumnAggregateFunction"; } std::string getName() const override { return "ColumnAggregateFunction"; }
@ -174,6 +165,9 @@ public:
{ {
IAggregateFunction * function = holder.get()->func; IAggregateFunction * function = holder.get()->func;
if (unlikely(arenas.empty()))
arenas.emplace_back(new Arena);
getData().push_back(arenas.back().get()->alloc(function->sizeOfData())); getData().push_back(arenas.back().get()->alloc(function->sizeOfData()));
function->create(getData().back()); function->create(getData().back());
ReadBufferFromString read_buffer(x.get<const String &>()); ReadBufferFromString read_buffer(x.get<const String &>());
@ -200,22 +194,21 @@ public:
return getData().size() * sizeof(getData()[0]); return getData().size() * sizeof(getData()[0]);
} }
ColumnPtr cut(size_t start, size_t length) const override void insertRangeFrom(const IColumn & src, size_t start, size_t length) override
{ {
if (start + length > getData().size()) const ColumnAggregateFunction & src_concrete = static_cast<const ColumnAggregateFunction &>(src);
if (start + length > src_concrete.getData().size())
throw Exception("Parameters start = " throw Exception("Parameters start = "
+ toString(start) + ", length = " + toString(start) + ", length = "
+ toString(length) + " are out of bound in ColumnAggregateFunction::cut() method" + toString(length) + " are out of bound in ColumnAggregateFunction::insertRangeFrom method"
" (data.size() = " + toString(getData().size()) + ").", " (data.size() = " + toString(src_concrete.getData().size()) + ").",
ErrorCodes::PARAMETER_OUT_OF_BOUND); ErrorCodes::PARAMETER_OUT_OF_BOUND);
ColumnAggregateFunction * res_ = new ColumnAggregateFunction(*this); auto & data = getData();
ColumnPtr res = res_; size_t old_size = data.size();
data.resize(old_size + length);
res_->getData().resize(length); memcpy(&data[old_size], &src_concrete.getData()[start], length * sizeof(data[0]));
for (size_t i = 0; i < length; ++i)
res_->getData()[i] = getData()[start + i];
return res;
} }
ColumnPtr filter(const Filter & filter) const override ColumnPtr filter(const Filter & filter) const override

View File

@ -147,7 +147,7 @@ public:
return pos; return pos;
} }
ColumnPtr cut(size_t start, size_t length) const override; void insertRangeFrom(const IColumn & src, size_t start, size_t length) override;
void insert(const Field & x) override void insert(const Field & x) override
{ {
@ -293,9 +293,9 @@ private:
size_t ALWAYS_INLINE sizeAt(size_t i) const { return i == 0 ? getOffsets()[0] : (getOffsets()[i] - getOffsets()[i - 1]); } size_t ALWAYS_INLINE sizeAt(size_t i) const { return i == 0 ? getOffsets()[0] : (getOffsets()[i] - getOffsets()[i - 1]); }
/// Размножить значения, если вложенный столбец - ColumnArray<T>. /// Размножить значения, если вложенный столбец - ColumnVector<T>.
template <typename T> template <typename T>
ColumnPtr replicate(const Offsets_t & replicate_offsets) const; ColumnPtr replicateNumber(const Offsets_t & replicate_offsets) const;
/// Размножить значения, если вложенный столбец - ColumnString. Код слишком сложный. /// Размножить значения, если вложенный столбец - ColumnString. Код слишком сложный.
ColumnPtr replicateString(const Offsets_t & replicate_offsets) const; ColumnPtr replicateString(const Offsets_t & replicate_offsets) const;
@ -306,6 +306,14 @@ private:
* Только ради неё сделана реализация метода replicate для ColumnArray(ColumnConst). * Только ради неё сделана реализация метода replicate для ColumnArray(ColumnConst).
*/ */
ColumnPtr replicateConst(const Offsets_t & replicate_offsets) const; ColumnPtr replicateConst(const Offsets_t & replicate_offsets) const;
/// Специализации для функции filter.
template <typename T>
ColumnPtr filterNumber(const Filter & filt) const;
ColumnPtr filterString(const Filter & filt) const;
ColumnPtr filterGeneric(const Filter & filt) const;
}; };

View File

@ -7,6 +7,7 @@
#include <DB/Core/ErrorCodes.h> #include <DB/Core/ErrorCodes.h>
#include <DB/Columns/ColumnVector.h> #include <DB/Columns/ColumnVector.h>
#include <DB/Columns/IColumn.h> #include <DB/Columns/IColumn.h>
#include <DB/Columns/ColumnsCommon.h>
#include <DB/DataTypes/IDataType.h> #include <DB/DataTypes/IDataType.h>
@ -62,9 +63,13 @@ public:
Field operator[](size_t n) const override { return FieldType(getDataFromHolder()); } Field operator[](size_t n) const override { return FieldType(getDataFromHolder()); }
void get(size_t n, Field & res) const override { res = FieldType(getDataFromHolder()); } void get(size_t n, Field & res) const override { res = FieldType(getDataFromHolder()); }
ColumnPtr cut(size_t start, size_t length) const override void insertRangeFrom(const IColumn & src, size_t start, size_t length) override
{ {
return new Derived(length, data, data_type); if (getDataFromHolder() != static_cast<const Derived &>(src).getDataFromHolder())
throw Exception("Cannot insert different element into constant column " + getName(),
ErrorCodes::CANNOT_INSERT_ELEMENT_INTO_CONSTANT_COLUMN);
s += length;
} }
void insert(const Field & x) override void insert(const Field & x) override

View File

@ -5,6 +5,7 @@
#include <DB/Common/PODArray.h> #include <DB/Common/PODArray.h>
#include <DB/Common/Arena.h> #include <DB/Common/Arena.h>
#include <DB/Columns/IColumn.h> #include <DB/Columns/IColumn.h>
#include <DB/IO/ReadHelpers.h>
namespace DB namespace DB
@ -172,13 +173,20 @@ public:
} }
} }
ColumnPtr cut(size_t start, size_t length) const override void insertRangeFrom(const IColumn & src, size_t start, size_t length) override
{ {
ColumnFixedString * res_ = new ColumnFixedString(n); const ColumnFixedString & src_concrete = static_cast<const ColumnFixedString &>(src);
ColumnPtr res = res_;
res_->chars.resize(n * length); if (start + length > src_concrete.size())
memcpy(&res_->chars[0], &chars[n * start], n * length); throw Exception("Parameters start = "
return res; + toString(start) + ", length = "
+ toString(length) + " are out of bound in ColumnFixedString::insertRangeFrom method"
" (size() = " + toString(src_concrete.size()) + ").",
ErrorCodes::PARAMETER_OUT_OF_BOUND);
size_t old_size = chars.size();
chars.resize(old_size + length * n);
memcpy(&chars[old_size], &src_concrete.chars[start * n], length * n);
} }
ColumnPtr filter(const IColumn::Filter & filt) const override ColumnPtr filter(const IColumn::Filter & filt) const override

View File

@ -14,7 +14,7 @@ namespace DB
class ColumnSet final : public IColumnDummy class ColumnSet final : public IColumnDummy
{ {
public: public:
ColumnSet(size_t s_, SetPtr data_) : IColumnDummy(s_), data(data_) {} ColumnSet(size_t s_, ConstSetPtr data_) : IColumnDummy(s_), data(data_) {}
/// Столбец не константный. Иначе столбец будет использоваться в вычислениях в ExpressionActions::prepare, когда множество из подзапроса ещё не готово. /// Столбец не константный. Иначе столбец будет использоваться в вычислениях в ExpressionActions::prepare, когда множество из подзапроса ещё не готово.
bool isConst() const override { return false; } bool isConst() const override { return false; }
@ -22,11 +22,10 @@ public:
std::string getName() const override { return "ColumnSet"; } std::string getName() const override { return "ColumnSet"; }
ColumnPtr cloneDummy(size_t s_) const override { return new ColumnSet(s_, data); } ColumnPtr cloneDummy(size_t s_) const override { return new ColumnSet(s_, data); }
SetPtr & getData() { return data; } ConstSetPtr getData() const { return data; }
const SetPtr & getData() const { return data; }
private: private:
SetPtr data; ConstSetPtr data;
}; };
} }

View File

@ -5,6 +5,7 @@
#include <DB/Core/Defines.h> #include <DB/Core/Defines.h>
#include <DB/Columns/IColumn.h> #include <DB/Columns/IColumn.h>
#include <DB/Columns/ColumnsCommon.h>
#include <DB/Common/Collator.h> #include <DB/Common/Collator.h>
#include <DB/Common/PODArray.h> #include <DB/Common/PODArray.h>
#include <DB/Common/Arena.h> #include <DB/Common/Arena.h>
@ -144,48 +145,42 @@ public:
return pos + string_size; return pos + string_size;
} }
ColumnPtr cut(size_t start, size_t length) const override void insertRangeFrom(const IColumn & src, size_t start, size_t length) override
{ {
if (length == 0) if (length == 0)
return new ColumnString; return;
if (start + length > offsets.size()) const ColumnString & src_concrete = static_cast<const ColumnString &>(src);
throw Exception("Parameter out of bound in IColumnString::cut() method.",
if (start + length > src_concrete.offsets.size())
throw Exception("Parameter out of bound in IColumnString::insertRangeFrom method.",
ErrorCodes::PARAMETER_OUT_OF_BOUND); ErrorCodes::PARAMETER_OUT_OF_BOUND);
size_t nested_offset = offsetAt(start); size_t nested_offset = src_concrete.offsetAt(start);
size_t nested_length = offsets[start + length - 1] - nested_offset; size_t nested_length = src_concrete.offsets[start + length - 1] - nested_offset;
ColumnString * res_ = new ColumnString; size_t old_chars_size = chars.size();
ColumnPtr res = res_; chars.resize(old_chars_size + nested_length);
memcpy(&chars[old_chars_size], &src_concrete.chars[nested_offset], nested_length);
res_->chars.resize(nested_length); if (start == 0 && offsets.empty())
memcpy(&res_->chars[0], &chars[nested_offset], nested_length);
Offsets_t & res_offsets = res_->offsets;
if (start == 0)
{ {
res_offsets.assign(offsets.begin(), offsets.begin() + length); offsets.assign(src_concrete.offsets.begin(), src_concrete.offsets.begin() + length);
} }
else else
{ {
res_offsets.resize(length); size_t old_size = offsets.size();
size_t prev_max_offset = old_size ? offsets.back() : 0;
offsets.resize(old_size + length);
for (size_t i = 0; i < length; ++i) for (size_t i = 0; i < length; ++i)
res_offsets[i] = offsets[start + i] - nested_offset; offsets[old_size + i] = src_concrete.offsets[start + i] - nested_offset + prev_max_offset;
} }
return res;
} }
ColumnPtr filter(const Filter & filt) const override ColumnPtr filter(const Filter & filt) const override
{ {
const size_t size = offsets.size(); if (offsets.size() == 0)
if (size != filt.size())
throw Exception("Size of filter doesn't match size of column.", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);
if (size == 0)
return new ColumnString; return new ColumnString;
auto res = new ColumnString; auto res = new ColumnString;
@ -193,96 +188,8 @@ public:
Chars_t & res_chars = res->chars; Chars_t & res_chars = res->chars;
Offsets_t & res_offsets = res->offsets; Offsets_t & res_offsets = res->offsets;
res_chars.reserve(chars.size());
res_offsets.reserve(size);
Offset_t current_offset = 0;
const UInt8 * filt_pos = &filt[0];
const auto filt_end = filt_pos + size;
const auto filt_end_aligned = filt_pos + size / 16 * 16;
auto offsets_pos = &offsets[0];
const auto offsets_begin = offsets_pos;
const __m128i zero16 = _mm_setzero_si128();
/// copy string ending at *end_offset_ptr
const auto copy_string = [&] (const Offset_t * offset_ptr) {
const auto offset = offset_ptr == offsets_begin ? 0 : offset_ptr[-1];
const auto size = *offset_ptr - offset;
current_offset += size;
res_offsets.push_back(current_offset);
const auto chars_size_old = res_chars.size();
res_chars.resize_assume_reserved(chars_size_old + size);
memcpy(&res_chars[chars_size_old], &chars[offset], size);
};
while (filt_pos < filt_end_aligned)
{
const auto mask = _mm_movemask_epi8(_mm_cmpgt_epi8(
_mm_loadu_si128(reinterpret_cast<const __m128i *>(filt_pos)),
zero16));
if (mask == 0)
{
/// 16 consecutive rows do not pass the filter
}
else if (mask == 0xffff)
{
/// 16 consecutive rows pass the filter
const auto first = offsets_pos == offsets_begin;
const auto chunk_offset = first ? 0 : offsets_pos[-1];
const auto chunk_size = offsets_pos[16 - 1] - chunk_offset;
const auto offsets_size_old = res_offsets.size();
res_offsets.resize(offsets_size_old + 16);
memcpy(&res_offsets[offsets_size_old], offsets_pos, 16 * sizeof(Offset_t));
if (!first)
{
/// difference between current and actual offset
const auto diff_offset = chunk_offset - current_offset;
if (diff_offset > 0)
{
const auto res_offsets_pos = &res_offsets[offsets_size_old];
/// adjust offsets
for (size_t i = 0; i < 16; ++i)
res_offsets_pos[i] -= diff_offset;
}
}
current_offset += chunk_size;
/// copy characters for 16 strings at once
const auto chars_size_old = res_chars.size();
res_chars.resize(chars_size_old + chunk_size);
memcpy(&res_chars[chars_size_old], &chars[chunk_offset], chunk_size);
}
else
{
for (size_t i = 0; i < 16; ++i)
if (filt_pos[i])
copy_string(offsets_pos + i);
}
filt_pos += 16;
offsets_pos += 16;
}
while (filt_pos < filt_end)
{
if (*filt_pos)
copy_string(offsets_pos);
++filt_pos;
++offsets_pos;
}
filterArraysImpl<UInt8>(chars, offsets, res_chars, res_offsets, filt);
return res_; return res_;
} }

View File

@ -23,7 +23,7 @@ public:
size_t size = data.columns(); size_t size = data.columns();
columns.resize(size); columns.resize(size);
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
columns[i] = data.getByPosition(i).column; columns[i] = data.unsafeGetByPosition(i).column;
} }
std::string getName() const override { return "Tuple"; } std::string getName() const override { return "Tuple"; }
@ -115,15 +115,12 @@ public:
return pos; return pos;
} }
void insertRangeFrom(const IColumn & src, size_t start, size_t length) override
ColumnPtr cut(size_t start, size_t length) const override
{ {
Block res_block = data.cloneEmpty();
for (size_t i = 0; i < columns.size(); ++i) for (size_t i = 0; i < columns.size(); ++i)
res_block.getByPosition(i).column = data.getByPosition(i).column->cut(start, length); data.unsafeGetByPosition(i).column->insertRangeFrom(
*static_cast<const ColumnTuple &>(src).data.unsafeGetByPosition(i).column.get(),
return new ColumnTuple(res_block); start, length);
} }
ColumnPtr filter(const Filter & filt) const override ColumnPtr filter(const Filter & filt) const override
@ -131,7 +128,7 @@ public:
Block res_block = data.cloneEmpty(); Block res_block = data.cloneEmpty();
for (size_t i = 0; i < columns.size(); ++i) for (size_t i = 0; i < columns.size(); ++i)
res_block.getByPosition(i).column = data.getByPosition(i).column->filter(filt); res_block.unsafeGetByPosition(i).column = data.unsafeGetByPosition(i).column->filter(filt);
return new ColumnTuple(res_block); return new ColumnTuple(res_block);
} }
@ -141,7 +138,7 @@ public:
Block res_block = data.cloneEmpty(); Block res_block = data.cloneEmpty();
for (size_t i = 0; i < columns.size(); ++i) for (size_t i = 0; i < columns.size(); ++i)
res_block.getByPosition(i).column = data.getByPosition(i).column->permute(perm, limit); res_block.unsafeGetByPosition(i).column = data.unsafeGetByPosition(i).column->permute(perm, limit);
return new ColumnTuple(res_block); return new ColumnTuple(res_block);
} }
@ -151,7 +148,7 @@ public:
Block res_block = data.cloneEmpty(); Block res_block = data.cloneEmpty();
for (size_t i = 0; i < columns.size(); ++i) for (size_t i = 0; i < columns.size(); ++i)
res_block.getByPosition(i).column = data.getByPosition(i).column->replicate(offsets); res_block.unsafeGetByPosition(i).column = data.unsafeGetByPosition(i).column->replicate(offsets);
return new ColumnTuple(res_block); return new ColumnTuple(res_block);
} }

View File

@ -254,18 +254,20 @@ public:
data.push_back(DB::get<typename NearestFieldType<T>::Type>(x)); data.push_back(DB::get<typename NearestFieldType<T>::Type>(x));
} }
ColumnPtr cut(size_t start, size_t length) const override void insertRangeFrom(const IColumn & src, size_t start, size_t length) override
{ {
if (start + length > data.size()) const ColumnVector & src_vec = static_cast<const ColumnVector &>(src);
if (start + length > src_vec.data.size())
throw Exception("Parameters start = " throw Exception("Parameters start = "
+ toString(start) + ", length = " + toString(start) + ", length = "
+ toString(length) + " are out of bound in IColumnVector<T>::cut() method" + toString(length) + " are out of bound in ColumnVector::insertRangeFrom method"
" (data.size() = " + toString(data.size()) + ").", " (data.size() = " + toString(src_vec.data.size()) + ").",
ErrorCodes::PARAMETER_OUT_OF_BOUND); ErrorCodes::PARAMETER_OUT_OF_BOUND);
Self * res = new Self(length); size_t old_size = data.size();
memcpy(&res->getData()[0], &data[start], length * sizeof(data[0])); data.resize(old_size + length);
return res; memcpy(&data[old_size], &src_vec.data[start], length * sizeof(data[0]));
} }
ColumnPtr filter(const IColumn::Filter & filt) const override ColumnPtr filter(const IColumn::Filter & filt) const override

View File

@ -0,0 +1,22 @@
#pragma once
#include <DB/Columns/IColumn.h>
/// Общие вспомогательные методы для реализации разных столбцов.
namespace DB
{
/// Считает, сколько байт в filt больше нуля.
size_t countBytesInFilter(const IColumn::Filter & filt);
/// Общая реализация функции filter для ColumnArray и ColumnString.
template <typename T>
void filterArraysImpl(
const PODArray<T> & src_elems, const IColumn::Offsets_t & src_offsets,
PODArray<T> & res_elems, IColumn::Offsets_t & res_offsets,
const IColumn::Filter & filt);
}

View File

@ -111,7 +111,12 @@ public:
/** Удалить всё кроме диапазона элементов. /** Удалить всё кроме диапазона элементов.
* Используется, например, для операции LIMIT. * Используется, например, для операции LIMIT.
*/ */
virtual SharedPtr<IColumn> cut(size_t start, size_t length) const = 0; virtual SharedPtr<IColumn> cut(size_t start, size_t length) const
{
SharedPtr<IColumn> res = cloneEmpty();
res.get()->insertRangeFrom(*this, start, length);
return res;
}
/** Вставить значение в конец столбца (количество значений увеличится на 1). /** Вставить значение в конец столбца (количество значений увеличится на 1).
* Используется для преобразования из строк в блоки (например, при чтении значений из текстового дампа) * Используется для преобразования из строк в блоки (например, при чтении значений из текстового дампа)
@ -123,6 +128,11 @@ public:
*/ */
virtual void insertFrom(const IColumn & src, size_t n) { insert(src[n]); } virtual void insertFrom(const IColumn & src, size_t n) { insert(src[n]); }
/** Вставить в конец столбца диапазон элементов из другого столбца.
* Может использоваться для склейки столбцов.
*/
virtual void insertRangeFrom(const IColumn & src, size_t start, size_t length) = 0;
/** Вставить данные, расположенные в указанном куске памяти, если возможно. /** Вставить данные, расположенные в указанном куске памяти, если возможно.
* (если не реализуемо - кидает исключение) * (если не реализуемо - кидает исключение)
* Используется для оптимизации некоторых вычислений (например, агрегации). * Используется для оптимизации некоторых вычислений (например, агрегации).
@ -227,8 +237,4 @@ public:
}; };
/// Считает, сколько байт в filt больше нуля.
size_t countBytesInFilter(const IColumn::Filter & filt);
} }

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <DB/Columns/IColumn.h> #include <DB/Columns/IColumn.h>
#include <DB/Columns/ColumnsCommon.h>
namespace DB namespace DB
@ -44,9 +45,9 @@ public:
throw Exception("Method getExtremes is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); throw Exception("Method getExtremes is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
} }
ColumnPtr cut(size_t start, size_t length) const override void insertRangeFrom(const IColumn & src, size_t start, size_t length) override
{ {
return cloneDummy(length); s += length;
} }
ColumnPtr filter(const Filter & filt) const override ColumnPtr filter(const Filter & filt) const override

View File

@ -1,12 +1,19 @@
#pragma once #pragma once
#include <linux/aio_abi.h> #include <DB/Core/ErrorCodes.h>
#include <unistd.h>
#include <sys/syscall.h>
#include <boost/noncopyable.hpp>
#include <DB/Common/Exception.h> #include <DB/Common/Exception.h>
#include <common/logger_useful.h>
#include <common/singleton.h>
#include <Poco/Logger.h>
#include <boost/range/iterator_range.hpp>
#include <boost/noncopyable.hpp>
#include <condition_variable>
#include <future>
#include <mutex>
#include <map>
#include <linux/aio_abi.h>
#include <sys/syscall.h>
#include <unistd.h>
/** Небольшие обёртки для асинхронного ввода-вывода. /** Небольшие обёртки для асинхронного ввода-вывода.
@ -23,7 +30,8 @@ inline int io_destroy(aio_context_t ctx)
return syscall(__NR_io_destroy, ctx); return syscall(__NR_io_destroy, ctx);
} }
inline int io_submit(aio_context_t ctx, long nr, struct iocb **iocbpp) /// last argument is an array of pointers technically speaking
inline int io_submit(aio_context_t ctx, long nr, struct iocb * iocbpp[])
{ {
return syscall(__NR_io_submit, ctx, nr, iocbpp); return syscall(__NR_io_submit, ctx, nr, iocbpp);
} }
@ -50,3 +58,159 @@ struct AIOContext : private boost::noncopyable
io_destroy(ctx); io_destroy(ctx);
} }
}; };
namespace DB
{
class AIOContextPool : public Singleton<AIOContextPool>
{
friend class Singleton<AIOContextPool>;
static const auto max_concurrent_events = 128;
static const auto timeout_sec = 1;
AIOContext aio_context{max_concurrent_events};
using id_t = size_t;
using bytes_read_t = ssize_t;
/// Autoincremental id used to identify completed requests
id_t id{};
mutable std::mutex mutex;
mutable std::condition_variable have_resources;
std::map<id_t, std::promise<bytes_read_t>> promises;
std::atomic<bool> cancelled{false};
std::thread io_completion_monitor{&AIOContextPool::doMonitor, this};
~AIOContextPool()
{
cancelled.store(true, std::memory_order_relaxed);
io_completion_monitor.join();
}
void doMonitor()
{
/// continue checking for events unless cancelled
while (!cancelled.load(std::memory_order_relaxed))
waitForCompletion();
/// wait until all requests have been completed
while (!promises.empty())
waitForCompletion();
}
void waitForCompletion()
{
/// array to hold completion events
io_event events[max_concurrent_events];
try
{
const auto num_events = getCompletionEvents(events, max_concurrent_events);
fulfillPromises(events, num_events);
notifyProducers(num_events);
}
catch (...)
{
/// there was an error, log it, return to any producer and continue
reportExceptionToAnyProducer();
tryLogCurrentException("AIOContextPool::waitForCompletion()");
}
}
int getCompletionEvents(io_event events[], const int max_events)
{
timespec timeout{timeout_sec};
auto num_events = 0;
/// request 1 to `max_events` events
while ((num_events = io_getevents(aio_context.ctx, 1, max_events, events, &timeout)) < 0)
if (errno != EINTR)
throwFromErrno("io_getevents: Failed to wait for asynchronous IO completion",
ErrorCodes::AIO_COMPLETION_ERROR, errno);
return num_events;
}
void fulfillPromises(const io_event events[], const int num_events)
{
if (num_events == 0)
return;
const std::lock_guard<std::mutex> lock{mutex};
/// look at returned events and find corresponding promise, set result and erase promise from map
for (const auto & event : boost::make_iterator_range(events, events + num_events))
{
/// get id from event
const auto id = event.data;
/// set value via promise and release it
const auto it = promises.find(id);
if (it == std::end(promises))
{
LOG_CRITICAL(&Poco::Logger::get("AIOcontextPool"), "Found io_event with unknown id " << id);
continue;
}
it->second.set_value(event.res);
promises.erase(it);
}
}
void notifyProducers(const int num_producers) const
{
if (num_producers == 0)
return;
if (num_producers > 1)
have_resources.notify_all();
else
have_resources.notify_one();
}
void reportExceptionToAnyProducer()
{
const std::lock_guard<std::mutex> lock{mutex};
const auto any_promise_it = std::begin(promises);
any_promise_it->second.set_exception(std::current_exception());
}
public:
/// Request AIO read operation for iocb, returns a future with number of bytes read
std::future<bytes_read_t> post(struct iocb & iocb)
{
std::unique_lock<std::mutex> lock{mutex};
/// get current id and increment it by one
const auto request_id = id++;
/// create a promise and put request in "queue"
promises.emplace(request_id, std::promise<bytes_read_t>{});
/// store id in AIO request for further identification
iocb.aio_data = request_id;
auto num_requests = 0;
struct iocb * requests[] { &iocb };
/// submit a request
while ((num_requests = io_submit(aio_context.ctx, 1, requests)) < 0)
{
if (errno == EAGAIN)
/// wait until at least one event has been completed (or a spurious wakeup) and try again
have_resources.wait(lock);
else if (errno != EINTR)
throwFromErrno("io_submit: Failed to submit a request for asynchronous IO",
ErrorCodes::AIO_SUBMIT_ERROR, errno);
}
return promises[request_id].get_future();
}
};
}

View File

@ -9,19 +9,42 @@
#include <DB/Core/ErrorCodes.h> #include <DB/Core/ErrorCodes.h>
/** При использовании AllocatorWithStackMemory, размещённом на стеке,
* GCC 4.9 ошибочно делает предположение, что мы можем вызывать free от указателя на стек.
* На самом деле, комбинация условий внутри AllocatorWithStackMemory этого не допускает.
*/
#if !__clang__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wfree-nonheap-object"
#endif
/** Отвечает за выделение/освобождение памяти. Используется, например, в PODArray, Arena. /** Отвечает за выделение/освобождение памяти. Используется, например, в PODArray, Arena.
* Также используется в хэш-таблицах.
* Интерфейс отличается от std::allocator * Интерфейс отличается от std::allocator
* - наличием метода realloc, который для больших кусков памяти использует mremap; * - наличием метода realloc, который для больших кусков памяти использует mremap;
* - передачей размера в метод free; * - передачей размера в метод free;
* - наличием аргумента alignment; * - наличием аргумента alignment;
* - возможностью зануления памяти (используется в хэш-таблицах);
*/ */
template <bool clear_memory_>
class Allocator class Allocator
{ {
protected:
static constexpr bool clear_memory = clear_memory_;
private: private:
/** См. комментарий в HashTableAllocator.h /** Многие современные аллокаторы (например, tcmalloc) не умеют делать mremap для realloc,
* даже в случае достаточно больших кусков памяти.
* Хотя это позволяет увеличить производительность и уменьшить потребление памяти во время realloc-а.
* Чтобы это исправить, делаем mremap самостоятельно, если кусок памяти достаточно большой.
* Порог (64 МБ) выбран достаточно большим, так как изменение адресного пространства
* довольно сильно тормозит, особенно в случае наличия большого количества потоков.
* Рассчитываем, что набор операций mmap/что-то сделать/mremap может выполняться всего лишь около 1000 раз в секунду.
*
* PS. Также это требуется, потому что tcmalloc не может выделить кусок памяти больше 16 GB.
*/ */
static constexpr size_t MMAP_THRESHOLD = 64 * (1 << 20); static constexpr size_t MMAP_THRESHOLD = 64 * (1 << 20);
static constexpr size_t HUGE_PAGE_SIZE = 2 * (1 << 20);
static constexpr size_t MMAP_MIN_ALIGNMENT = 4096; static constexpr size_t MMAP_MIN_ALIGNMENT = 4096;
static constexpr size_t MALLOC_MIN_ALIGNMENT = 8; static constexpr size_t MALLOC_MIN_ALIGNMENT = 8;
@ -43,14 +66,15 @@ public:
if (MAP_FAILED == buf) if (MAP_FAILED == buf)
DB::throwFromErrno("Allocator: Cannot mmap.", DB::ErrorCodes::CANNOT_ALLOCATE_MEMORY); DB::throwFromErrno("Allocator: Cannot mmap.", DB::ErrorCodes::CANNOT_ALLOCATE_MEMORY);
/// См. комментарий в HashTableAllocator.h /// Заполнение нулями не нужно - mmap сам это делает.
if (size >= HUGE_PAGE_SIZE && 0 != madvise(buf, size, MADV_HUGEPAGE))
DB::throwFromErrno("HashTableAllocator: Cannot madvise with MADV_HUGEPAGE.", DB::ErrorCodes::CANNOT_ALLOCATE_MEMORY);
} }
else else
{ {
if (alignment <= MALLOC_MIN_ALIGNMENT) if (alignment <= MALLOC_MIN_ALIGNMENT)
{ {
if (clear_memory)
buf = ::calloc(size, 1);
else
buf = ::malloc(size); buf = ::malloc(size);
if (nullptr == buf) if (nullptr == buf)
@ -63,6 +87,9 @@ public:
if (0 != res) if (0 != res)
DB::throwFromErrno("Cannot allocate memory (posix_memalign)", DB::ErrorCodes::CANNOT_ALLOCATE_MEMORY, res); DB::throwFromErrno("Cannot allocate memory (posix_memalign)", DB::ErrorCodes::CANNOT_ALLOCATE_MEMORY, res);
if (clear_memory)
memset(buf, 0, size);
} }
} }
@ -101,6 +128,9 @@ public:
if (nullptr == buf) if (nullptr == buf)
DB::throwFromErrno("Allocator: Cannot realloc.", DB::ErrorCodes::CANNOT_ALLOCATE_MEMORY); DB::throwFromErrno("Allocator: Cannot realloc.", DB::ErrorCodes::CANNOT_ALLOCATE_MEMORY);
if (clear_memory)
memset(reinterpret_cast<char *>(buf) + old_size, 0, new_size - old_size);
} }
else if (old_size >= MMAP_THRESHOLD && new_size >= MMAP_THRESHOLD) else if (old_size >= MMAP_THRESHOLD && new_size >= MMAP_THRESHOLD)
{ {
@ -110,6 +140,8 @@ public:
buf = mremap(buf, old_size, new_size, MREMAP_MAYMOVE); buf = mremap(buf, old_size, new_size, MREMAP_MAYMOVE);
if (MAP_FAILED == buf) if (MAP_FAILED == buf)
DB::throwFromErrno("Allocator: Cannot mremap.", DB::ErrorCodes::CANNOT_MREMAP); DB::throwFromErrno("Allocator: Cannot mremap.", DB::ErrorCodes::CANNOT_MREMAP);
/// Заполнение нулями не нужно.
} }
else else
{ {
@ -122,3 +154,53 @@ public:
return buf; return buf;
} }
}; };
/** Аллокатор с оптимизацией для маленьких кусков памяти.
*/
template <typename Base, size_t N = 64>
class AllocatorWithStackMemory : private Base
{
private:
char stack_memory[N];
public:
void * alloc(size_t size)
{
if (size <= N)
{
if (Base::clear_memory)
memset(stack_memory, 0, N);
return stack_memory;
}
return Base::alloc(size);
}
void free(void * buf, size_t size)
{
if (size > N)
Base::free(buf, size);
}
void * realloc(void * buf, size_t old_size, size_t new_size)
{
/// Было в stack_memory, там и останется.
if (new_size <= N)
return buf;
/// Уже не помещалось в stack_memory.
if (old_size > N)
return Base::realloc(buf, old_size, new_size);
/// Было в stack_memory, но теперь не помещается.
void * new_buf = Base::alloc(new_size);
memcpy(new_buf, buf, old_size);
return new_buf;
}
};
#if !__clang__
#pragma GCC diagnostic pop
#endif

View File

@ -26,7 +26,7 @@ class Arena
{ {
private: private:
/// Непрерывный кусок памяти и указатель на свободное место в нём. Односвязный список. /// Непрерывный кусок памяти и указатель на свободное место в нём. Односвязный список.
struct Chunk : private Allocator /// empty base optimization struct Chunk : private Allocator<false> /// empty base optimization
{ {
char * begin; char * begin;
char * pos; char * pos;

View File

@ -0,0 +1,100 @@
#pragma once
#include <DB/Common/Arena.h>
#include <ext/bit_cast.hpp>
#include <ext/size.hpp>
#include <cstdlib>
#include <memory>
#include <array>
namespace DB
{
class ArenaWithFreeLists : private Allocator<false>
{
private:
struct Block { Block * next; };
static const std::array<std::size_t, 14> & getSizes()
{
static constexpr std::array<std::size_t, 14> sizes{
8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536
};
static_assert(sizes.front() >= sizeof(Block), "Can't make allocations smaller than sizeof(Block)");
return sizes;
}
static auto sizeToPreviousPowerOfTwo(const int size) { return _bit_scan_reverse(size - 1); }
static auto getMinBucketNum()
{
static const auto val = sizeToPreviousPowerOfTwo(getSizes().front());
return val;
}
static auto getMaxFixedBlockSize() { return getSizes().back(); }
Arena pool;
const std::unique_ptr<Block * []> free_lists = std::make_unique<Block * []>(ext::size(getSizes()));
static std::size_t findFreeListIndex(const std::size_t size)
{
/// shift powers of two into previous bucket by subtracting 1
const auto bucket_num = sizeToPreviousPowerOfTwo(size);
return std::max(bucket_num, getMinBucketNum()) - getMinBucketNum();
}
public:
ArenaWithFreeLists(
const std::size_t initial_size = 4096, const std::size_t growth_factor = 2,
const std::size_t linear_growth_threshold = 128 * 1024 * 1024)
: pool{initial_size, growth_factor, linear_growth_threshold}
{
}
char * alloc(const std::size_t size)
{
if (size > getMaxFixedBlockSize())
return static_cast<char *>(Allocator::alloc(size));
/// find list of required size
const auto list_idx = findFreeListIndex(size);
if (auto & block = free_lists[list_idx])
{
const auto res = ext::bit_cast<char *>(block);
block = block->next;
return res;
}
/// no block of corresponding size, allocate a new one
return pool.alloc(getSizes()[list_idx]);
}
void free(const void * ptr, const std::size_t size)
{
if (size > getMaxFixedBlockSize())
return Allocator::free(const_cast<void *>(ptr), size);
/// find list of required size
const auto list_idx = findFreeListIndex(size);
auto & block = free_lists[list_idx];
const auto old = block;
block = ext::bit_cast<Block *>(ptr);
block->next = old;
}
/// Размер выделенного пула в байтах
size_t size() const
{
return pool.size();
}
};
}

View File

@ -10,11 +10,11 @@
namespace DB namespace DB
{ {
/** Компактный массив для хранения данных, размер L, в битах, которых составляет /** Компактный массив для хранения данных, размер content_width, в битах, которых составляет
* меньше одного байта. Вместо того, чтобы хранить каждое значение в отдельный * меньше одного байта. Вместо того, чтобы хранить каждое значение в отдельный
* байт, что приводит к растрате 37.5% пространства для L=5, CompactArray хранит * байт, что приводит к растрате 37.5% пространства для content_width=5, CompactArray хранит
* смежные L-битные значения в массиве байтов, т.е. фактически CompactArray * смежные content_width-битные значения в массиве байтов, т.е. фактически CompactArray
* симулирует массив L-битных значений. * симулирует массив content_width-битных значений.
*/ */
template <typename BucketIndex, UInt8 content_width, size_t bucket_count> template <typename BucketIndex, UInt8 content_width, size_t bucket_count>
class __attribute__ ((packed)) CompactArray final class __attribute__ ((packed)) CompactArray final

View File

@ -121,3 +121,42 @@ struct TrivialHash
return key; return key;
} }
}; };
/** Сравнительно неплохая некриптографическая хэш функция из UInt64 в UInt32.
* Но хуже (и по качеству и по скорости), чем просто срезка intHash64.
* Взята отсюда: http://www.concentric.net/~ttwang/tech/inthash.htm
*
* Немного изменена по сравнению с функцией по ссылке: сдвиги вправо случайно заменены на цикличесвие сдвиги вправо.
* Это изменение никак не повлияло на результаты тестов smhasher.
*
* Рекомендуется для разных задач использовать разные salt.
* А то был случай, что в БД значения сортировались по хэшу (для некачественного псевдослучайного разбрасывания),
* а в другом месте, в агрегатной функции, в хэш таблице использовался такой же хэш,
* в результате чего, эта агрегатная функция чудовищно тормозила из-за коллизий.
*/
template <DB::UInt64 salt>
inline DB::UInt32 intHash32(DB::UInt64 key)
{
key ^= salt;
key = (~key) + (key << 18);
key = key ^ ((key >> 31) | (key << 33));
key = key * 21;
key = key ^ ((key >> 11) | (key << 53));
key = key + (key << 6);
key = key ^ ((key >> 22) | (key << 42));
return key;
}
/// Для контейнеров.
template <typename T, DB::UInt64 salt = 0>
struct IntHash32
{
size_t operator() (const T & key) const
{
return intHash32<salt>(key);
}
};

View File

@ -11,8 +11,6 @@
#include <common/likely.h> #include <common/likely.h>
#include <stats/IntHash.h>
#include <DB/Core/Defines.h> #include <DB/Core/Defines.h>
#include <DB/Core/Types.h> #include <DB/Core/Types.h>
#include <DB/Common/Exception.h> #include <DB/Common/Exception.h>

View File

@ -1,184 +1,9 @@
#pragma once #pragma once
#include <malloc.h> #include <DB/Common/Allocator.h>
#include <string.h>
#include <sys/mman.h>
#include <DB/Common/MemoryTracker.h>
#include <DB/Common/Exception.h>
#include <DB/Core/ErrorCodes.h>
/** При использовании HashTableAllocatorWithStackMemory, размещённом на стеке,
* GCC 4.9 ошибочно делает предположение, что мы можем вызывать free от указателя на стек.
* На самом деле, комбинация условий внутри HashTableAllocatorWithStackMemory этого не допускает.
*/
#if !__clang__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wfree-nonheap-object"
#endif
/** Общая часть разных хэш-таблиц, отвечающая за выделение/освобождение памяти. using HashTableAllocator = Allocator<true>;
* Отличается от Allocator тем, что зануляет память.
* Используется в качестве параметра шаблона (есть несколько реализаций с таким же интерфейсом).
*/
class HashTableAllocator
{
private:
/** Многие современные аллокаторы (например, tcmalloc) не умеют делать mremap для realloc,
* даже в случае достаточно больших кусков памяти.
* Хотя это позволяет увеличить производительность и уменьшить потребление памяти во время realloc-а.
* Чтобы это исправить, делаем mremap самостоятельно, если кусок памяти достаточно большой.
* Порог (64 МБ) выбран достаточно большим, так как изменение адресного пространства
* довольно сильно тормозит, особенно в случае наличия большого количества потоков.
* Рассчитываем, что набор операций mmap/что-то сделать/mremap может выполняться всего лишь около 1000 раз в секунду.
*
* PS. Также это требуется, потому что tcmalloc не может выделить кусок памяти больше 16 GB.
*/
static constexpr size_t MMAP_THRESHOLD = 64 * (1 << 20);
static constexpr size_t HUGE_PAGE_SIZE = 2 * (1 << 20);
public:
/// Выделить кусок памяти и заполнить его нулями.
void * alloc(size_t size)
{
if (current_memory_tracker)
current_memory_tracker->alloc(size);
void * buf;
if (size >= MMAP_THRESHOLD)
{
buf = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
if (MAP_FAILED == buf)
DB::throwFromErrno("HashTableAllocator: Cannot mmap.", DB::ErrorCodes::CANNOT_ALLOCATE_MEMORY);
/** Использование huge pages позволяет увеличить производительность более чем в три раза
* в запросе SELECT number % 1000000 AS k, count() FROM system.numbers GROUP BY k,
* (хэш-таблица на 1 000 000 элементов)
* и примерно на 15% в случае хэш-таблицы на 100 000 000 элементов.
*/
if (size >= HUGE_PAGE_SIZE && 0 != madvise(buf, size, MADV_HUGEPAGE))
DB::throwFromErrno("HashTableAllocator: Cannot madvise with MADV_HUGEPAGE.", DB::ErrorCodes::CANNOT_ALLOCATE_MEMORY);
/// Заполнение нулями не нужно - mmap сам это делает.
}
else
{
buf = ::calloc(size, 1);
if (nullptr == buf)
DB::throwFromErrno("HashTableAllocator: Cannot calloc.", DB::ErrorCodes::CANNOT_ALLOCATE_MEMORY);
}
return buf;
}
/// Освободить память.
void free(void * buf, size_t size)
{
if (size >= MMAP_THRESHOLD)
{
if (0 != munmap(buf, size))
DB::throwFromErrno("HashTableAllocator: Cannot munmap.", DB::ErrorCodes::CANNOT_MUNMAP);
}
else
{
::free(buf);
}
if (current_memory_tracker)
current_memory_tracker->free(size);
}
/** Увеличить размер куска памяти.
* Содержимое старого куска памяти переезжает в начало нового.
* Оставшаяся часть заполняется нулями.
* Положение куска памяти может измениться.
*/
void * realloc(void * buf, size_t old_size, size_t new_size)
{
if (old_size < MMAP_THRESHOLD && new_size < MMAP_THRESHOLD)
{
if (current_memory_tracker)
current_memory_tracker->realloc(old_size, new_size);
buf = ::realloc(buf, new_size);
if (nullptr == buf)
DB::throwFromErrno("HashTableAllocator: Cannot realloc.", DB::ErrorCodes::CANNOT_ALLOCATE_MEMORY);
memset(reinterpret_cast<char *>(buf) + old_size, 0, new_size - old_size);
}
else if (old_size >= MMAP_THRESHOLD && new_size >= MMAP_THRESHOLD)
{
if (current_memory_tracker)
current_memory_tracker->realloc(old_size, new_size);
buf = mremap(buf, old_size, new_size, MREMAP_MAYMOVE);
if (MAP_FAILED == buf)
DB::throwFromErrno("HashTableAllocator: Cannot mremap.", DB::ErrorCodes::CANNOT_MREMAP);
/** Здесь не получается сделать madvise с MADV_HUGEPAGE.
* Похоже, что при mremap, huge pages сами расширяются на новую область.
*/
/// Заполнение нулями не нужно.
}
else
{
void * new_buf = alloc(new_size);
memcpy(new_buf, buf, old_size);
free(buf, old_size);
buf = new_buf;
}
return buf;
}
};
/** Аллокатор с оптимизацией для маленьких кусков памяти.
*/
template <size_t N = 64> template <size_t N = 64>
class HashTableAllocatorWithStackMemory : private HashTableAllocator using HashTableAllocatorWithStackMemory = AllocatorWithStackMemory<HashTableAllocator, N>;
{
private:
char stack_memory[N];
public:
void * alloc(size_t size)
{
if (size <= N)
{
memset(stack_memory, 0, N);
return stack_memory;
}
return HashTableAllocator::alloc(size);
}
void free(void * buf, size_t size)
{
if (size > N)
HashTableAllocator::free(buf, size);
}
void * realloc(void * buf, size_t old_size, size_t new_size)
{
/// Было в stack_memory, там и останется.
if (new_size <= N)
return buf;
/// Уже не помещалось в stack_memory.
if (old_size > N)
return HashTableAllocator::realloc(buf, old_size, new_size);
/// Было в stack_memory, но теперь не помещается.
void * new_buf = HashTableAllocator::alloc(new_size);
memcpy(new_buf, buf, old_size);
return new_buf;
}
};
#if !__clang__
#pragma GCC diagnostic pop
#endif

View File

@ -1,9 +1,9 @@
#pragma once #pragma once
#include <common/Common.h> #include <common/Common.h>
#include <stats/IntHash.h>
#include <DB/Common/HyperLogLogBiasEstimator.h> #include <DB/Common/HyperLogLogBiasEstimator.h>
#include <DB/Common/CompactArray.h> #include <DB/Common/CompactArray.h>
#include <DB/Common/HashTable/Hash.h>
#include <DB/IO/ReadBuffer.h> #include <DB/IO/ReadBuffer.h>
#include <DB/IO/WriteBuffer.h> #include <DB/IO/WriteBuffer.h>

View File

@ -29,14 +29,13 @@ namespace DB
* Поддерживается только часть интерфейса std::vector. * Поддерживается только часть интерфейса std::vector.
* *
* Конструктор по-умолчанию создаёт пустой объект, который не выделяет память. * Конструктор по-умолчанию создаёт пустой объект, который не выделяет память.
* Затем выделяется память минимум под POD_ARRAY_INITIAL_SIZE элементов. * Затем выделяется память минимум под INITIAL_SIZE элементов.
* *
* Если вставлять элементы push_back-ом, не делая reserve, то PODArray примерно в 2.5 раза быстрее std::vector. * Если вставлять элементы push_back-ом, не делая reserve, то PODArray примерно в 2.5 раза быстрее std::vector.
*/ */
#define POD_ARRAY_INITIAL_SIZE 4096UL
template <typename T> template <typename T, size_t INITIAL_SIZE = 4096, typename TAllocator = Allocator<false>>
class PODArray : private boost::noncopyable, private Allocator /// empty base optimization class PODArray : private boost::noncopyable, private TAllocator /// empty base optimization
{ {
private: private:
char * c_start; char * c_start;
@ -79,7 +78,7 @@ private:
size_t bytes_to_alloc = to_size(n); size_t bytes_to_alloc = to_size(n);
c_start = c_end = reinterpret_cast<char *>(Allocator::alloc(bytes_to_alloc)); c_start = c_end = reinterpret_cast<char *>(TAllocator::alloc(bytes_to_alloc));
c_end_of_storage = c_start + bytes_to_alloc; c_end_of_storage = c_start + bytes_to_alloc;
} }
@ -88,7 +87,7 @@ private:
if (c_start == nullptr) if (c_start == nullptr)
return; return;
Allocator::free(c_start, storage_size()); TAllocator::free(c_start, storage_size());
} }
void realloc(size_t n) void realloc(size_t n)
@ -102,7 +101,7 @@ private:
ptrdiff_t end_diff = c_end - c_start; ptrdiff_t end_diff = c_end - c_start;
size_t bytes_to_alloc = to_size(n); size_t bytes_to_alloc = to_size(n);
c_start = reinterpret_cast<char *>(Allocator::realloc(c_start, storage_size(), bytes_to_alloc)); c_start = reinterpret_cast<char *>(TAllocator::realloc(c_start, storage_size(), bytes_to_alloc));
c_end = c_start + end_diff; c_end = c_start + end_diff;
c_end_of_storage = c_start + bytes_to_alloc; c_end_of_storage = c_start + bytes_to_alloc;
@ -133,7 +132,17 @@ public:
PODArray(const_iterator from_begin, const_iterator from_end) { alloc(from_end - from_begin); insert(from_begin, from_end); } PODArray(const_iterator from_begin, const_iterator from_end) { alloc(from_end - from_begin); insert(from_begin, from_end); }
~PODArray() { dealloc(); } ~PODArray() { dealloc(); }
PODArray(PODArray && other) { *this = std::move(other); } PODArray(PODArray && other)
{
c_start = other.c_start;
c_end = other.c_end;
c_end_of_storage = other.c_end_of_storage;
other.c_start = nullptr;
other.c_end = nullptr;
other.c_end_of_storage = nullptr;
}
PODArray & operator=(PODArray && other) PODArray & operator=(PODArray && other)
{ {
std::swap(c_start, other.c_start); std::swap(c_start, other.c_start);
@ -174,7 +183,7 @@ public:
void reserve() void reserve()
{ {
if (size() == 0) if (size() == 0)
realloc(POD_ARRAY_INITIAL_SIZE); realloc(INITIAL_SIZE);
else else
realloc(size() * 2); realloc(size() * 2);
} }
@ -227,6 +236,16 @@ public:
c_end += byte_size(1); c_end += byte_size(1);
} }
template <typename... Args>
void emplace_back(Args &&... args)
{
if (unlikely(c_end == c_end_of_storage))
reserve();
new (t_end()) T(std::forward<Args>(args)...);
c_end += byte_size(1);
}
/// Не вставляйте в массив кусок самого себя. Потому что при ресайзе, итераторы на самого себя могут инвалидироваться. /// Не вставляйте в массив кусок самого себя. Потому что при ресайзе, итераторы на самого себя могут инвалидироваться.
template <typename It1, typename It2> template <typename It1, typename It2>
void insert(It1 from_begin, It2 from_end) void insert(It1 from_begin, It2 from_end)
@ -246,7 +265,7 @@ public:
c_end += bytes_to_copy; c_end += bytes_to_copy;
} }
void swap(PODArray<T> & rhs) void swap(PODArray & rhs)
{ {
std::swap(c_start, rhs.c_start); std::swap(c_start, rhs.c_start);
std::swap(c_end, rhs.c_end); std::swap(c_end, rhs.c_end);
@ -271,13 +290,13 @@ public:
c_end = c_start + bytes_to_copy; c_end = c_start + bytes_to_copy;
} }
void assign(const PODArray<T> & from) void assign(const PODArray & from)
{ {
assign(from.begin(), from.end()); assign(from.begin(), from.end());
} }
bool operator== (const PODArray<T> & other) const bool operator== (const PODArray & other) const
{ {
if (size() != other.size()) if (size() != other.size())
return false; return false;
@ -297,7 +316,7 @@ public:
return true; return true;
} }
bool operator!= (const PODArray<T> & other) const bool operator!= (const PODArray & other) const
{ {
return !operator==(other); return !operator==(other);
} }

View File

@ -70,6 +70,13 @@
M(CompileAttempt) \ M(CompileAttempt) \
M(CompileSuccess) \ M(CompileSuccess) \
\ \
M(ExternalSortWritePart) \
M(ExternalSortMerge) \
M(ExternalAggregationWritePart) \
M(ExternalAggregationMerge) \
M(ExternalAggregationCompressedBytes) \
M(ExternalAggregationUncompressedBytes) \
\
M(END) M(END)
namespace ProfileEvents namespace ProfileEvents

View File

@ -0,0 +1,256 @@
#pragma once
#include <string.h>
#include <malloc.h>
#include <cstdint>
#include <type_traits>
#include <ext/bit_cast.hpp>
#include <DB/Core/Defines.h>
/** Поразрядная сортировка, обладает следующей функциональностью:
* Может сортировать unsigned, signed числа, а также float-ы.
* Может сортировать массив элементов фиксированной длины, которые содержат что-то ещё кроме ключа.
* Настраиваемый размер разряда.
*
* LSB, stable.
* NOTE Для некоторых приложений имеет смысл добавить MSB-radix-sort,
* а также алгоритмы radix-select, radix-partial-sort, radix-get-permutation на его основе.
*/
/** Используется в качестве параметра шаблона. См. ниже.
*/
struct RadixSortMallocAllocator
{
void * allocate(size_t size)
{
return malloc(size);
}
void deallocate(void * ptr, size_t size)
{
return free(ptr);
}
};
/** Преобразование, которое переводит битовое представление ключа в такое целое беззнаковое число,
* что отношение порядка над ключами будет соответствовать отношению порядка над полученными беззнаковыми числами.
* Для float-ов это преобразование делает следующее:
* если выставлен знаковый бит, то переворачивает все остальные биты.
*/
template <typename KeyBits>
struct RadixSortFloatTransform
{
/// Стоит ли записывать результат в память, или лучше делать его каждый раз заново?
static constexpr bool transform_is_simple = false;
static KeyBits forward(KeyBits x)
{
return x ^ (-((x >> (sizeof(KeyBits) * 8 - 1) | (KeyBits(1) << (sizeof(KeyBits) * 8 - 1)))));
}
static KeyBits backward(KeyBits x)
{
return x ^ (((x >> (sizeof(KeyBits) * 8 - 1)) - 1) | (KeyBits(1) << (sizeof(KeyBits) * 8 - 1)));
}
};
template <typename Float>
struct RadixSortFloatTraits
{
using Element = Float; /// Тип элемента. Это может быть структура с ключём и ещё каким-то payload-ом. Либо просто ключ.
using Key = Float; /// Ключ, по которому нужно сортировать.
using CountType = uint32_t; /// Тип для подсчёта гистограмм. В случае заведомо маленького количества элементов, может быть меньше чем size_t.
/// Тип, в который переводится ключ, чтобы делать битовые операции. Это UInt такого же размера, как ключ.
using KeyBits = typename std::conditional<sizeof(Float) == 8, uint64_t, uint32_t>::type;
static constexpr size_t PART_SIZE_BITS = 8; /// Какими кусочками ключа в количестве бит делать один проход - перестановку массива.
/// Преобразования ключа в KeyBits такое, что отношение порядка над ключём соответствует отношению порядка над KeyBits.
using Transform = RadixSortFloatTransform<KeyBits>;
/// Объект с функциями allocate и deallocate.
/// Может быть использован, например, чтобы выделить память для временного массива на стеке.
/// Для этого сам аллокатор создаётся на стеке.
using Allocator = RadixSortMallocAllocator;
/// Функция получения ключа из элемента массива.
static Key & extractKey(Element & elem) { return elem; }
};
template <typename KeyBits>
struct RadixSortIdentityTransform
{
static constexpr bool transform_is_simple = true;
static KeyBits forward(KeyBits x) { return x; }
static KeyBits backward(KeyBits x) { return x; }
};
template <typename KeyBits>
struct RadixSortSignedTransform
{
static constexpr bool transform_is_simple = true;
static KeyBits forward(KeyBits x) { return x ^ (KeyBits(1) << (sizeof(KeyBits) * 8 - 1)); }
static KeyBits backward(KeyBits x) { return x ^ (KeyBits(1) << (sizeof(KeyBits) * 8 - 1)); }
};
template <typename UInt>
struct RadixSortUIntTraits
{
using Element = UInt;
using Key = UInt;
using CountType = uint32_t;
using KeyBits = UInt;
static constexpr size_t PART_SIZE_BITS = 8;
using Transform = RadixSortIdentityTransform<KeyBits>;
using Allocator = RadixSortMallocAllocator;
/// Функция получения ключа из элемента массива.
static Key & extractKey(Element & elem) { return elem; }
};
template <typename Int>
struct RadixSortIntTraits
{
using Element = Int;
using Key = Int;
using CountType = uint32_t;
using KeyBits = typename std::make_unsigned<Int>::type;
static constexpr size_t PART_SIZE_BITS = 8;
using Transform = RadixSortSignedTransform<KeyBits>;
using Allocator = RadixSortMallocAllocator;
/// Функция получения ключа из элемента массива.
static Key & extractKey(Element & elem) { return elem; }
};
template <typename Traits>
struct RadixSort
{
private:
using Element = typename Traits::Element;
using Key = typename Traits::Key;
using CountType = typename Traits::CountType;
using KeyBits = typename Traits::KeyBits;
static constexpr size_t HISTOGRAM_SIZE = 1 << Traits::PART_SIZE_BITS;
static constexpr size_t PART_BITMASK = HISTOGRAM_SIZE - 1;
static constexpr size_t KEY_BITS = sizeof(Key) * 8;
static constexpr size_t NUM_PASSES = (KEY_BITS + (Traits::PART_SIZE_BITS - 1)) / Traits::PART_SIZE_BITS;
static ALWAYS_INLINE KeyBits getPart(size_t N, KeyBits x)
{
if (Traits::Transform::transform_is_simple)
x = Traits::Transform::forward(x);
return (x >> (N * Traits::PART_SIZE_BITS)) & PART_BITMASK;
}
static KeyBits keyToBits(Key x) { return ext::bit_cast<KeyBits>(x); }
static Key bitsToKey(KeyBits x) { return ext::bit_cast<Key>(x); }
public:
static void execute(Element * arr, size_t size)
{
/// Если массив имеет размер меньше 256, то лучше использовать другой алгоритм.
/// Здесь есть циклы по NUM_PASSES. Очень важно, что они разворачиваются в compile-time.
/// Для каждого из NUM_PASSES кусков бит ключа, считаем, сколько раз каждое значение этого куска встретилось.
CountType histograms[HISTOGRAM_SIZE * NUM_PASSES] = {0};
typename Traits::Allocator allocator;
/// Будем делать несколько проходов по массиву. На каждом проходе, данные перекладываются в другой массив. Выделим этот временный массив.
Element * swap_buffer = reinterpret_cast<Element *>(allocator.allocate(size * sizeof(Element)));
/// Трансформируем массив и вычисляем гистограмму.
for (size_t i = 0; i < size; ++i)
{
if (!Traits::Transform::transform_is_simple)
Traits::extractKey(arr[i]) = bitsToKey(Traits::Transform::forward(keyToBits(Traits::extractKey(arr[i]))));
for (size_t j = 0; j < NUM_PASSES; ++j)
++histograms[j * HISTOGRAM_SIZE + getPart(j, keyToBits(Traits::extractKey(arr[i])))];
}
{
/// Заменяем гистограммы на суммы с накоплением: значение в позиции i равно сумме в предыдущих позициях минус один.
size_t sums[NUM_PASSES] = {0};
for (size_t i = 0; i < HISTOGRAM_SIZE; ++i)
{
for (size_t j = 0; j < NUM_PASSES; ++j)
{
size_t tmp = histograms[j * HISTOGRAM_SIZE + i] + sums[j];
histograms[j * HISTOGRAM_SIZE + i] = sums[j] - 1;
sums[j] = tmp;
}
}
}
/// Перекладываем элементы в порядке начиная от младшего куска бит, и далее делаем несколько проходов по количеству кусков.
for (size_t j = 0; j < NUM_PASSES; ++j)
{
Element * writer = j % 2 ? arr : swap_buffer;
Element * reader = j % 2 ? swap_buffer : arr;
for (size_t i = 0; i < size; ++i)
{
size_t pos = getPart(j, keyToBits(Traits::extractKey(reader[i])));
/// Размещаем элемент на следующей свободной позиции.
auto & dest = writer[++histograms[j * HISTOGRAM_SIZE + pos]];
dest = reader[i];
/// На последнем перекладывании, делаем обратную трансформацию.
if (!Traits::Transform::transform_is_simple && j == NUM_PASSES - 1)
Traits::extractKey(dest) = bitsToKey(Traits::Transform::backward(keyToBits(Traits::extractKey(reader[i]))));
}
}
/// Если число проходов нечётное, то результирующий массив находится во временном буфере. Скопируем его на место исходного массива.
if (NUM_PASSES % 2)
memcpy(arr, swap_buffer, size * sizeof(Element));
allocator.deallocate(swap_buffer, size * sizeof(Element));
}
};
template <typename T>
typename std::enable_if<std::is_unsigned<T>::value && std::is_integral<T>::value, void>::type
radixSort(T * arr, size_t size)
{
return RadixSort<RadixSortUIntTraits<T>>::execute(arr, size);
}
template <typename T>
typename std::enable_if<std::is_signed<T>::value && std::is_integral<T>::value, void>::type
radixSort(T * arr, size_t size)
{
return RadixSort<RadixSortIntTraits<T>>::execute(arr, size);
}
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value, void>::type
radixSort(T * arr, size_t size)
{
return RadixSort<RadixSortFloatTraits<T>>::execute(arr, size);
}

View File

@ -0,0 +1,88 @@
#pragma once
#include <DB/Core/ErrorCodes.h>
#include <DB/Common/Arena.h>
#include <common/likely.h>
#include <ext/range.hpp>
#include <ext/size.hpp>
#include <ext/bit_cast.hpp>
#include <cstdlib>
#include <memory>
namespace DB
{
class SmallObjectPool
{
private:
struct Block { Block * next; };
const std::size_t object_size;
Arena pool;
Block * free_list{};
public:
SmallObjectPool(
const std::size_t object_size, const std::size_t initial_size = 4096, const std::size_t growth_factor = 2,
const std::size_t linear_growth_threshold = 128 * 1024 * 1024)
: object_size{object_size}, pool{initial_size, growth_factor, linear_growth_threshold}
{
if (object_size < sizeof(Block))
throw Exception{
"Can't make allocations smaller than sizeof(Block) = " + std::to_string(sizeof(Block)),
ErrorCodes::LOGICAL_ERROR
};
if (pool.size() < object_size)
return;
const auto num_objects = pool.size() / object_size;
auto head = free_list = ext::bit_cast<Block *>(pool.alloc(num_objects * object_size));
for (const auto i : ext::range(0, num_objects - 1))
{
(void) i;
head->next = ext::bit_cast<Block *>(ext::bit_cast<char *>(head) + object_size);
head = head->next;
}
head->next = nullptr;
}
char * alloc()
{
if (free_list)
{
const auto res = reinterpret_cast<char *>(free_list);
free_list = free_list->next;
return res;
}
return pool.alloc(object_size);
}
void free(const void * ptr)
{
union {
const void * p_v;
Block * block;
};
p_v = ptr;
block->next = free_list;
free_list = block;
}
/// Размер выделенного пула в байтах
size_t size() const
{
return pool.size();
}
};
}

View File

@ -1,8 +1,12 @@
#pragma once #pragma once
#include <string> #include <string>
#include <DB/IO/ReadHelpers.h>
namespace DB
{
inline std::string escapeForFileName(const std::string & s) inline std::string escapeForFileName(const std::string & s)
{ {
std::string res; std::string res;
@ -30,18 +34,6 @@ inline std::string escapeForFileName(const std::string & s)
return res; return res;
} }
inline char unhex(char c)
{
switch (c)
{
case '0' ... '9':
return c - '0';
case 'A' ... 'F':
return c - 'A' + 10;
default:
return 0;
}
}
inline std::string unescapeForFileName(const std::string & s) inline std::string unescapeForFileName(const std::string & s)
{ {
@ -71,3 +63,5 @@ inline std::string unescapeForFileName(const std::string & s)
} }
return res; return res;
} }
}

View File

@ -104,6 +104,7 @@ struct SortCursorImpl
rows = all_columns[0]->size(); rows = all_columns[0]->size();
} }
bool isFirst() const { return pos == 0; }
bool isLast() const { return pos + 1 >= rows; } bool isLast() const { return pos + 1 >= rows; }
void next() { ++pos; } void next() { ++pos; }
}; };
@ -118,13 +119,13 @@ struct SortCursor
SortCursorImpl * operator-> () { return impl; } SortCursorImpl * operator-> () { return impl; }
const SortCursorImpl * operator-> () const { return impl; } const SortCursorImpl * operator-> () const { return impl; }
/// Инвертировано, чтобы из priority queue элементы вынимались в нужном порядке. /// Указанная строка данного курсора больше указанной строки другого курсора.
bool operator< (const SortCursor & rhs) const bool greaterAt(const SortCursor & rhs, size_t lhs_pos, size_t rhs_pos) const
{ {
for (size_t i = 0; i < impl->sort_columns_size; ++i) for (size_t i = 0; i < impl->sort_columns_size; ++i)
{ {
int direction = impl->desc[i].direction; int direction = impl->desc[i].direction;
int res = direction * impl->sort_columns[i]->compareAt(impl->pos, rhs.impl->pos, *(rhs.impl->sort_columns[i]), direction); int res = direction * impl->sort_columns[i]->compareAt(lhs_pos, rhs_pos, *(rhs.impl->sort_columns[i]), direction);
if (res > 0) if (res > 0)
return true; return true;
if (res < 0) if (res < 0)
@ -132,6 +133,27 @@ struct SortCursor
} }
return impl->order > rhs.impl->order; return impl->order > rhs.impl->order;
} }
/// Проверяет, что все строки в текущем блоке данного курсора меньше или равны, чем все строки текущего блока другого курсора.
bool totallyLessOrEquals(const SortCursor & rhs) const
{
if (impl->rows == 0 || rhs.impl->rows == 0)
return false;
/// Последняя строка данного курсора не больше первой строки другого.
return !greaterAt(rhs, impl->rows - 1, 0);
}
bool greater(const SortCursor & rhs) const
{
return greaterAt(rhs, impl->pos, rhs.impl->pos);
}
/// Инвертировано, чтобы из priority queue элементы вынимались в порядке по возрастанию.
bool operator< (const SortCursor & rhs) const
{
return greater(rhs);
}
}; };
@ -144,8 +166,7 @@ struct SortCursorWithCollation
SortCursorImpl * operator-> () { return impl; } SortCursorImpl * operator-> () { return impl; }
const SortCursorImpl * operator-> () const { return impl; } const SortCursorImpl * operator-> () const { return impl; }
/// Инвертировано, чтобы из priority queue элементы вынимались в нужном порядке. bool greaterAt(const SortCursorWithCollation & rhs, size_t lhs_pos, size_t rhs_pos) const
bool operator< (const SortCursorWithCollation & rhs) const
{ {
for (size_t i = 0; i < impl->sort_columns_size; ++i) for (size_t i = 0; i < impl->sort_columns_size; ++i)
{ {
@ -154,10 +175,10 @@ struct SortCursorWithCollation
if (impl->need_collation[i]) if (impl->need_collation[i])
{ {
const ColumnString & column_string = typeid_cast<const ColumnString &>(*impl->sort_columns[i]); const ColumnString & column_string = typeid_cast<const ColumnString &>(*impl->sort_columns[i]);
res = column_string.compareAtWithCollation(impl->pos, rhs.impl->pos, *(rhs.impl->sort_columns[i]), *impl->desc[i].collator); res = column_string.compareAtWithCollation(lhs_pos, rhs_pos, *(rhs.impl->sort_columns[i]), *impl->desc[i].collator);
} }
else else
res = impl->sort_columns[i]->compareAt(impl->pos, rhs.impl->pos, *(rhs.impl->sort_columns[i]), direction); res = impl->sort_columns[i]->compareAt(lhs_pos, rhs_pos, *(rhs.impl->sort_columns[i]), direction);
res *= direction; res *= direction;
if (res > 0) if (res > 0)
@ -167,6 +188,25 @@ struct SortCursorWithCollation
} }
return impl->order > rhs.impl->order; return impl->order > rhs.impl->order;
} }
bool totallyLessOrEquals(const SortCursorWithCollation & rhs) const
{
if (impl->rows == 0 || rhs.impl->rows == 0)
return false;
/// Последняя строка данного курсора не больше первой строки другого.
return !greaterAt(rhs, impl->rows - 1, 0);
}
bool greater(const SortCursorWithCollation & rhs) const
{
return greaterAt(rhs, impl->pos, rhs.impl->pos);
}
bool operator< (const SortCursorWithCollation & rhs) const
{
return greater(rhs);
}
}; };
} }

View File

@ -60,4 +60,7 @@ template <> struct TypeName<Float32> { static std::string get() { return "Float
template <> struct TypeName<Float64> { static std::string get() { return "Float64"; } }; template <> struct TypeName<Float64> { static std::string get() { return "Float64"; } };
template <> struct TypeName<String> { static std::string get() { return "String"; } }; template <> struct TypeName<String> { static std::string get() { return "String"; } };
/// Этот тип не поддерживается СУБД, но используется в некоторых внутренних преобразованиях.
template <> struct TypeName<long double>{ static std::string get() { return "long double"; } };
} }

View File

@ -30,11 +30,6 @@ public:
children.push_back(input_); children.push_back(input_);
} }
AddingDefaultBlockInputStream(BlockInputStreamPtr input_, NamesAndTypesListPtr required_columns_, const Context & context_)
: AddingDefaultBlockInputStream{input_, required_columns_, ColumnDefaults{}, context_}
{
}
String getName() const override { return "AddingDefault"; } String getName() const override { return "AddingDefault"; }
String getID() const override String getID() const override
@ -65,7 +60,7 @@ protected:
private: private:
NamesAndTypesListPtr required_columns; NamesAndTypesListPtr required_columns;
const ColumnDefaults & column_defaults; const ColumnDefaults column_defaults;
Context context; Context context;
}; };

View File

@ -58,7 +58,7 @@ public:
private: private:
BlockOutputStreamPtr output; BlockOutputStreamPtr output;
NamesAndTypesListPtr required_columns; NamesAndTypesListPtr required_columns;
const ColumnDefaults & column_defaults; const ColumnDefaults column_defaults;
Context context; Context context;
bool only_explicit_column_defaults; bool only_explicit_column_defaults;
}; };

View File

@ -1,7 +1,11 @@
#pragma once #pragma once
#include <DB/Interpreters/Aggregator.h> #include <DB/Interpreters/Aggregator.h>
#include <DB/IO/ReadBufferFromFile.h>
#include <DB/IO/CompressedReadBuffer.h>
#include <DB/DataStreams/IProfilingBlockInputStream.h> #include <DB/DataStreams/IProfilingBlockInputStream.h>
#include <DB/DataStreams/NativeBlockInputStream.h>
#include <common/Revision.h>
namespace DB namespace DB
@ -22,12 +26,8 @@ public:
* Агрегатные функции ищутся везде в выражении. * Агрегатные функции ищутся везде в выражении.
* Столбцы, соответствующие keys и аргументам агрегатных функций, уже должны быть вычислены. * Столбцы, соответствующие keys и аргументам агрегатных функций, уже должны быть вычислены.
*/ */
AggregatingBlockInputStream(BlockInputStreamPtr input_, const Names & key_names, const AggregateDescriptions & aggregates, AggregatingBlockInputStream(BlockInputStreamPtr input_, const Aggregator::Params & params_, bool final_)
bool overflow_row_, bool final_, size_t max_rows_to_group_by_, OverflowMode group_by_overflow_mode_, : params(params_), aggregator(params), final(final_)
Compiler * compiler_, UInt32 min_count_to_compile_, size_t group_by_two_level_threshold_)
: aggregator(key_names, aggregates, overflow_row_, max_rows_to_group_by_, group_by_overflow_mode_,
compiler_, min_count_to_compile_, group_by_two_level_threshold_),
final(final_)
{ {
children.push_back(input_); children.push_back(input_);
} }
@ -44,12 +44,28 @@ public:
protected: protected:
Block readImpl() override; Block readImpl() override;
Aggregator::Params params;
Aggregator aggregator; Aggregator aggregator;
bool final; bool final;
bool executed = false; bool executed = false;
BlocksList blocks;
BlocksList::iterator it; /// Для чтения сброшенных во временный файл данных.
struct TemporaryFileStream
{
ReadBufferFromFile file_in;
CompressedReadBuffer compressed_in;
BlockInputStreamPtr block_in;
TemporaryFileStream(const std::string & path)
: file_in(path), compressed_in(file_in), block_in(new NativeBlockInputStream(compressed_in, Revision::get())) {}
};
std::vector<std::unique_ptr<TemporaryFileStream>> temporary_inputs;
/** Отсюда будем доставать готовые блоки после агрегации. */
std::unique_ptr<IBlockInputStream> impl;
Logger * log = &Logger::get("AggregatingBlockInputStream");
}; };
} }

View File

@ -22,6 +22,8 @@ public:
void flush() override { ostr.next(); } void flush() override { ostr.next(); }
String getContentType() const override { return "application/octet-stream"; }
protected: protected:
WriteBuffer & ostr; WriteBuffer & ostr;
const Block sample; const Block sample;

View File

@ -24,6 +24,8 @@ public:
void setTotals(const Block & totals) override; void setTotals(const Block & totals) override;
void setExtremes(const Block & extremes) override; void setExtremes(const Block & extremes) override;
String getContentType() const override { return row_output->getContentType(); }
private: private:
RowOutputStreamPtr row_output; RowOutputStreamPtr row_output;
bool first_row; bool first_row;

View File

@ -0,0 +1,49 @@
#pragma once
#include <DB/DataStreams/IProfilingBlockInputStream.h>
namespace DB
{
/** Поток блоков, из которого можно прочитать следующий блок из явно предоставленного списка.
* Также смотрите OneBlockInputStream.
*/
class BlocksListBlockInputStream : public IProfilingBlockInputStream
{
public:
/// Захватывает владение списком блоков.
BlocksListBlockInputStream(BlocksList && list_)
: list(std::move(list_)), it(list.begin()), end(list.end()) {}
/// Использует лежащий где-то ещё список блоков.
BlocksListBlockInputStream(BlocksList::iterator & begin_, BlocksList::iterator & end_)
: it(begin_), end(end_) {}
String getName() const override { return "BlocksList"; }
String getID() const override
{
std::stringstream res;
res << this;
return res.str();
}
protected:
Block readImpl() override
{
if (it == end)
return Block();
Block res = *it;
++it;
return res;
}
private:
BlocksList list;
BlocksList::iterator it;
const BlocksList::iterator end;
};
}

View File

@ -42,6 +42,10 @@ public:
virtual void setTotals(const Block & totals) {} virtual void setTotals(const Block & totals) {}
virtual void setExtremes(const Block & extremes) {} virtual void setExtremes(const Block & extremes) {}
/** Выставлять такой Content-Type при отдаче по HTTP.
*/
virtual String getContentType() const { return "text/plain; charset=UTF-8"; }
virtual ~IBlockOutputStream() {} virtual ~IBlockOutputStream() {}
/** Не давать изменить таблицу, пока жив поток блоков. /** Не давать изменить таблицу, пока жив поток блоков.

View File

@ -169,6 +169,9 @@ protected:
/// Информация о приблизительном общем количестве строк собрана в родительском источнике. /// Информация о приблизительном общем количестве строк собрана в родительском источнике.
bool collected_total_rows_approx = false; bool collected_total_rows_approx = false;
/// Превышено ограничение на количество строк/байт, и нужно прекратить выполнение на следующем вызове read, как будто поток иссяк.
bool limit_exceeded_need_break = false;
/// Ограничения и квоты. /// Ограничения и квоты.
LocalLimits limits; LocalLimits limits;

View File

@ -41,6 +41,9 @@ public:
virtual void setTotals(const Block & totals) {} virtual void setTotals(const Block & totals) {}
virtual void setExtremes(const Block & extremes) {} virtual void setExtremes(const Block & extremes) {}
/** Выставлять такой Content-Type при отдаче по HTTP. */
virtual String getContentType() const { return "text/plain; charset=UTF-8"; }
virtual ~IRowOutputStream() {} virtual ~IRowOutputStream() {}
}; };

View File

@ -42,6 +42,8 @@ public:
void setTotals(const Block & totals_) override { totals = totals_; } void setTotals(const Block & totals_) override { totals = totals_; }
void setExtremes(const Block & extremes_) override { extremes = extremes_; } void setExtremes(const Block & extremes_) override { extremes = extremes_; }
String getContentType() const override { return "application/json; charset=UTF-8"; }
protected: protected:
void writeRowsBeforeLimitAtLeast(); void writeRowsBeforeLimitAtLeast();

View File

@ -16,10 +16,8 @@ using Poco::SharedPtr;
class MergingAggregatedBlockInputStream : public IProfilingBlockInputStream class MergingAggregatedBlockInputStream : public IProfilingBlockInputStream
{ {
public: public:
MergingAggregatedBlockInputStream(BlockInputStreamPtr input_, const Names & keys_names_, MergingAggregatedBlockInputStream(BlockInputStreamPtr input_, const Aggregator::Params & params, bool final_, size_t max_threads_)
const AggregateDescriptions & aggregates_, bool overflow_row_, bool final_, size_t max_threads_) : aggregator(params), final(final_), max_threads(max_threads_)
: aggregator(keys_names_, aggregates_, overflow_row_, 0, OverflowMode::THROW, nullptr, 0, 0),
final(final_), max_threads(max_threads_)
{ {
children.push_back(input_); children.push_back(input_);
} }

View File

@ -1,7 +1,9 @@
#pragma once #pragma once
#include <common/threadpool.hpp>
#include <DB/Interpreters/Aggregator.h> #include <DB/Interpreters/Aggregator.h>
#include <DB/DataStreams/IProfilingBlockInputStream.h> #include <DB/DataStreams/IProfilingBlockInputStream.h>
#include <DB/Common/ConcurrentBoundedQueue.h>
namespace DB namespace DB
@ -19,18 +21,14 @@ namespace DB
* удалённых серверов делаются последовательно, при этом, чтение упирается в CPU. * удалённых серверов делаются последовательно, при этом, чтение упирается в CPU.
* Это несложно исправить. * Это несложно исправить.
* *
* Также, чтения и вычисления (слияние состояний) делаются по очереди.
* Есть возможность делать чтения асинхронно - при этом будет расходоваться в два раза больше памяти, но всё-равно немного.
* Это можно сделать с помощью UnionBlockInputStream.
*
* Можно держать в памяти не по одному блоку из каждого источника, а по несколько, и распараллелить мердж. * Можно держать в памяти не по одному блоку из каждого источника, а по несколько, и распараллелить мердж.
* При этом будет расходоваться кратно больше оперативки. * При этом будет расходоваться кратно больше оперативки.
*/ */
class MergingAggregatedMemoryEfficientBlockInputStream : public IProfilingBlockInputStream class MergingAggregatedMemoryEfficientBlockInputStream : public IProfilingBlockInputStream
{ {
public: public:
MergingAggregatedMemoryEfficientBlockInputStream(BlockInputStreams inputs_, const Names & keys_names_, MergingAggregatedMemoryEfficientBlockInputStream(
const AggregateDescriptions & aggregates_, bool overflow_row_, bool final_); BlockInputStreams inputs_, const Aggregator::Params & params, bool final_, size_t threads_);
String getName() const override { return "MergingAggregatedMemoryEfficient"; } String getName() const override { return "MergingAggregatedMemoryEfficient"; }
@ -42,6 +40,7 @@ protected:
private: private:
Aggregator aggregator; Aggregator aggregator;
bool final; bool final;
size_t threads;
bool started = false; bool started = false;
bool has_two_level = false; bool has_two_level = false;
@ -60,6 +59,37 @@ private:
}; };
std::vector<Input> inputs; std::vector<Input> inputs;
using BlocksToMerge = Poco::SharedPtr<BlocksList>;
/// Получить блоки, которые можно мерджить. Это позволяет мерджить их параллельно в отдельных потоках.
BlocksToMerge getNextBlocksToMerge();
/// Для параллельного мерджа.
struct OutputData
{
Block block;
std::exception_ptr exception;
OutputData() {}
OutputData(Block && block_) : block(std::move(block_)) {}
OutputData(std::exception_ptr && exception_) : exception(std::move(exception_)) {}
};
struct ParallelMergeData
{
boost::threadpool::pool pool;
std::mutex get_next_blocks_mutex;
ConcurrentBoundedQueue<OutputData> result_queue;
bool exhausted = false;
std::atomic<size_t> active_threads;
ParallelMergeData(size_t max_threads) : pool(max_threads), result_queue(max_threads), active_threads(max_threads) {}
};
std::unique_ptr<ParallelMergeData> parallel_merge_data;
void mergeThread(MemoryTracker * memory_tracker);
}; };
} }

View File

@ -136,7 +136,7 @@ private:
void initQueue(std::priority_queue<TSortCursor> & queue); void initQueue(std::priority_queue<TSortCursor> & queue);
template <typename TSortCursor> template <typename TSortCursor>
void merge(ColumnPlainPtrs & merged_columns, std::priority_queue<TSortCursor> & queue); void merge(Block & merged_block, ColumnPlainPtrs & merged_columns, std::priority_queue<TSortCursor> & queue);
Logger * log = &Logger::get("MergingSortedBlockInputStream"); Logger * log = &Logger::get("MergingSortedBlockInputStream");

View File

@ -30,6 +30,8 @@ public:
static void writeData(const IDataType & type, const ColumnPtr & column, WriteBuffer & ostr, size_t offset, size_t limit); static void writeData(const IDataType & type, const ColumnPtr & column, WriteBuffer & ostr, size_t offset, size_t limit);
String getContentType() const override { return "application/octet-stream"; }
private: private:
WriteBuffer & ostr; WriteBuffer & ostr;
UInt64 client_revision; UInt64 client_revision;

View File

@ -0,0 +1,30 @@
#pragma once
#include <DB/DataStreams/IBlockOutputStream.h>
namespace DB
{
/** Формат данных, предназначенный для упрощения реализации ODBC драйвера.
* ODBC драйвер предназначен для сборки под разные платформы без зависимостей от основного кода,
* поэтому формат сделан так, чтобы в нём можно было как можно проще его распарсить.
* Выводится заголовок с нужной информацией.
* Затем данные выводятся в порядке строк. Каждое значение выводится так: длина в формате VarUInt, затем данные в текстовом виде.
*/
class ODBCBlockOutputStream : public IBlockOutputStream
{
public:
ODBCBlockOutputStream(WriteBuffer & out_);
void write(const Block & block) override;
void flush() override { out.next(); }
String getContentType() const override { return "application/octet-stream"; }
private:
bool is_first = true;
WriteBuffer & out;
};
}

View File

@ -1,17 +1,13 @@
#pragma once #pragma once
#include <Poco/SharedPtr.h>
#include <DB/DataStreams/IProfilingBlockInputStream.h> #include <DB/DataStreams/IProfilingBlockInputStream.h>
namespace DB namespace DB
{ {
using Poco::SharedPtr;
/** Поток блоков, из которого можно прочитать один блок. /** Поток блоков, из которого можно прочитать один блок.
* Также смотрите BlocksListBlockInputStream.
*/ */
class OneBlockInputStream : public IProfilingBlockInputStream class OneBlockInputStream : public IProfilingBlockInputStream
{ {

View File

@ -1,8 +1,14 @@
#pragma once #pragma once
#include <DB/Interpreters/Aggregator.h> #include <DB/Interpreters/Aggregator.h>
#include <DB/IO/ReadBufferFromFile.h>
#include <DB/IO/CompressedReadBuffer.h>
#include <DB/DataStreams/IProfilingBlockInputStream.h> #include <DB/DataStreams/IProfilingBlockInputStream.h>
#include <DB/DataStreams/BlocksListBlockInputStream.h>
#include <DB/DataStreams/NativeBlockInputStream.h>
#include <DB/DataStreams/MergingAggregatedMemoryEfficientBlockInputStream.h>
#include <DB/DataStreams/ParallelInputsProcessor.h> #include <DB/DataStreams/ParallelInputsProcessor.h>
#include <common/Revision.h>
namespace DB namespace DB
@ -23,14 +29,10 @@ public:
*/ */
ParallelAggregatingBlockInputStream( ParallelAggregatingBlockInputStream(
BlockInputStreams inputs, BlockInputStreamPtr additional_input_at_end, BlockInputStreams inputs, BlockInputStreamPtr additional_input_at_end,
const Names & key_names, const AggregateDescriptions & aggregates, const Aggregator::Params & params_, bool final_, size_t max_threads_, size_t temporary_data_merge_threads_)
bool overflow_row_, bool final_, size_t max_threads_, : params(params_), aggregator(params),
size_t max_rows_to_group_by_, OverflowMode group_by_overflow_mode_, final(final_), max_threads(std::min(inputs.size(), max_threads_)), temporary_data_merge_threads(temporary_data_merge_threads_),
Compiler * compiler_, UInt32 min_count_to_compile_, size_t group_by_two_level_threshold_) keys_size(params.keys_size), aggregates_size(params.aggregates_size),
: aggregator(key_names, aggregates, overflow_row_, max_rows_to_group_by_, group_by_overflow_mode_,
compiler_, min_count_to_compile_, group_by_two_level_threshold_),
final(final_), max_threads(std::min(inputs.size(), max_threads_)),
keys_size(aggregator.getNumberOfKeys()), aggregates_size(aggregator.getNumberOfAggregates()),
handler(*this), processor(inputs, additional_input_at_end, max_threads, handler) handler(*this), processor(inputs, additional_input_at_end, max_threads, handler)
{ {
children = inputs; children = inputs;
@ -78,28 +80,59 @@ protected:
Aggregator::CancellationHook hook = [&]() { return this->isCancelled(); }; Aggregator::CancellationHook hook = [&]() { return this->isCancelled(); };
aggregator.setCancellationHook(hook); aggregator.setCancellationHook(hook);
AggregatedDataVariantsPtr data_variants = executeAndMerge(); execute();
if (isCancelled())
return {};
if (!aggregator.hasTemporaryFiles())
{
/** Если все частично-агрегированные данные в оперативке, то мерджим их параллельно, тоже в оперативке.
* NOTE Если израсходовано больше половины допустимой памяти, то мерджить следовало бы более экономно.
*/
AggregatedDataVariantsPtr data_variants = aggregator.merge(many_data, max_threads);
if (data_variants) if (data_variants)
blocks = aggregator.convertToBlocks(*data_variants, final, max_threads); impl.reset(new BlocksListBlockInputStream(
aggregator.convertToBlocks(*data_variants, final, max_threads)));
}
else
{
/** Если есть временные файлы с частично-агрегированными данными на диске,
* то читаем и мерджим их, расходуя минимальное количество памяти.
*/
it = blocks.begin(); ProfileEvents::increment(ProfileEvents::ExternalAggregationMerge);
const auto & files = aggregator.getTemporaryFiles();
BlockInputStreams input_streams;
for (const auto & file : files.files)
{
temporary_inputs.emplace_back(new TemporaryFileStream(file->path()));
input_streams.emplace_back(temporary_inputs.back()->block_in);
}
LOG_TRACE(log, "Will merge " << files.files.size() << " temporary files of size "
<< (files.sum_size_compressed / 1048576.0) << " MiB compressed, "
<< (files.sum_size_uncompressed / 1048576.0) << " MiB uncompressed.");
impl.reset(new MergingAggregatedMemoryEfficientBlockInputStream(input_streams, params, final, temporary_data_merge_threads));
}
} }
Block res; Block res;
if (isCancelled() || it == blocks.end()) if (isCancelled() || !impl)
return res; return res;
res = *it; return impl->read();
++it;
return res;
} }
private: private:
Aggregator::Params params;
Aggregator aggregator; Aggregator aggregator;
bool final; bool final;
size_t max_threads; size_t max_threads;
size_t temporary_data_merge_threads;
size_t keys_size; size_t keys_size;
size_t aggregates_size; size_t aggregates_size;
@ -112,8 +145,22 @@ private:
bool no_more_keys = false; bool no_more_keys = false;
bool executed = false; bool executed = false;
BlocksList blocks;
BlocksList::iterator it; /// Для чтения сброшенных во временный файл данных.
struct TemporaryFileStream
{
ReadBufferFromFile file_in;
CompressedReadBuffer compressed_in;
BlockInputStreamPtr block_in;
TemporaryFileStream(const std::string & path)
: file_in(path), compressed_in(file_in), block_in(new NativeBlockInputStream(compressed_in, Revision::get())) {}
};
std::vector<std::unique_ptr<TemporaryFileStream>> temporary_inputs;
/** Отсюда будем доставать готовые блоки после агрегации.
*/
std::unique_ptr<IBlockInputStream> impl;
Logger * log = &Logger::get("ParallelAggregatingBlockInputStream"); Logger * log = &Logger::get("ParallelAggregatingBlockInputStream");
@ -159,8 +206,31 @@ private:
parent.threads_data[thread_num].src_bytes += block.bytes(); parent.threads_data[thread_num].src_bytes += block.bytes();
} }
void onFinishThread(size_t thread_num)
{
if (parent.aggregator.hasTemporaryFiles())
{
/// Сбросим имеющиеся в оперативке данные тоже на диск. Так проще их потом объединять.
auto & data = *parent.many_data[thread_num];
size_t rows = data.sizeWithoutOverflowRow();
if (rows)
parent.aggregator.writeToTemporaryFile(data, rows);
}
}
void onFinish() void onFinish()
{ {
if (parent.aggregator.hasTemporaryFiles())
{
/// Может так получиться, что какие-то данные ещё не сброшены на диск,
/// потому что во время вызова onFinishThread ещё никакие данные не были сброшены на диск, а потом какие-то - были.
for (auto & data : parent.many_data)
{
size_t rows = data->sizeWithoutOverflowRow();
if (rows)
parent.aggregator.writeToTemporaryFile(*data, rows);
}
}
} }
void onException(std::exception_ptr & exception, size_t thread_num) void onException(std::exception_ptr & exception, size_t thread_num)
@ -176,7 +246,7 @@ private:
ParallelInputsProcessor<Handler> processor; ParallelInputsProcessor<Handler> processor;
AggregatedDataVariantsPtr executeAndMerge() void execute()
{ {
many_data.resize(max_threads); many_data.resize(max_threads);
exceptions.resize(max_threads); exceptions.resize(max_threads);
@ -197,7 +267,7 @@ private:
rethrowFirstException(exceptions); rethrowFirstException(exceptions);
if (isCancelled()) if (isCancelled())
return nullptr; return;
double elapsed_seconds = watch.elapsedSeconds(); double elapsed_seconds = watch.elapsedSeconds();
@ -220,11 +290,6 @@ private:
<< "Total aggregated. " << total_src_rows << " rows (from " << total_src_bytes / 1048576.0 << " MiB)" << "Total aggregated. " << total_src_rows << " rows (from " << total_src_bytes / 1048576.0 << " MiB)"
<< " in " << elapsed_seconds << " sec." << " in " << elapsed_seconds << " sec."
<< " (" << total_src_rows / elapsed_seconds << " rows/sec., " << total_src_bytes / elapsed_seconds / 1048576.0 << " MiB/sec.)"); << " (" << total_src_rows / elapsed_seconds << " rows/sec., " << total_src_bytes / elapsed_seconds / 1048576.0 << " MiB/sec.)");
if (isCancelled())
return nullptr;
return aggregator.merge(many_data, max_threads);
} }
}; };

View File

@ -43,6 +43,11 @@ struct ParallelInputsHandler
/// Обработка блока данных + дополнительных информаций. /// Обработка блока данных + дополнительных информаций.
void onBlock(Block & block, BlockExtraInfo & extra_info, size_t thread_num) {} void onBlock(Block & block, BlockExtraInfo & extra_info, size_t thread_num) {}
/// Вызывается для каждого потока, когда потоку стало больше нечего делать.
/// Из-за того, что иссякла часть источников, и сейчас источников осталось меньше, чем потоков.
/// Вызывается, если метод onException не кидает исключение; вызывается до метода onFinish.
void onFinishThread(size_t thread_num) {}
/// Блоки закончились. Из-за того, что все источники иссякли или из-за отмены работы. /// Блоки закончились. Из-за того, что все источники иссякли или из-за отмены работы.
/// Этот метод всегда вызывается ровно один раз, в конце работы, если метод onException не кидает исключение. /// Этот метод всегда вызывается ровно один раз, в конце работы, если метод onException не кидает исключение.
void onFinish() {} void onFinish() {}
@ -182,6 +187,8 @@ private:
handler.onException(exception, thread_num); handler.onException(exception, thread_num);
} }
handler.onFinishThread(thread_num);
/// Последний поток при выходе сообщает, что данных больше нет. /// Последний поток при выходе сообщает, что данных больше нет.
if (0 == --active_threads) if (0 == --active_threads)
{ {

View File

@ -3,67 +3,63 @@
#include <common/logger_useful.h> #include <common/logger_useful.h>
#include <DB/DataStreams/IProfilingBlockInputStream.h> #include <DB/DataStreams/IProfilingBlockInputStream.h>
#include <DB/DataStreams/OneBlockInputStream.h>
#include <DB/Common/VirtualColumnUtils.h>
#include <DB/Common/Throttler.h> #include <DB/Common/Throttler.h>
#include <DB/Interpreters/Context.h> #include <DB/Interpreters/Context.h>
#include <DB/Client/ConnectionPool.h> #include <DB/Client/ConnectionPool.h>
#include <DB/Client/ParallelReplicas.h> #include <DB/Client/MultiplexedConnections.h>
namespace DB namespace DB
{ {
/** Позволяет выполнить запрос (SELECT) на удалённых репликах одного шарда и получить результат. /** Позволяет выполнить запрос на удалённых репликах одного шарда и получить результат.
*/ */
class RemoteBlockInputStream : public IProfilingBlockInputStream class RemoteBlockInputStream : public IProfilingBlockInputStream
{ {
private:
void init(const Settings * settings_)
{
if (settings_)
{
send_settings = true;
settings = *settings_;
}
else
send_settings = false;
}
public: public:
/// Принимает готовое соединение. /// Принимает готовое соединение.
RemoteBlockInputStream(Connection & connection_, const String & query_, const Settings * settings_, ThrottlerPtr throttler_ = nullptr, RemoteBlockInputStream(Connection & connection_, const String & query_, const Settings * settings_,
const Tables & external_tables_ = Tables(), QueryProcessingStage::Enum stage_ = QueryProcessingStage::Complete, ThrottlerPtr throttler_ = nullptr, const Tables & external_tables_ = Tables(),
const Context & context = getDefaultContext()) QueryProcessingStage::Enum stage_ = QueryProcessingStage::Complete,
: connection(&connection_), query(query_), throttler(throttler_), external_tables(external_tables_), stage(stage_), context(context) const Context & context_ = getDefaultContext());
{
init(settings_);
}
/// Принимает готовое соединение. Захватывает владение соединением из пула. /// Принимает готовое соединение. Захватывает владение соединением из пула.
RemoteBlockInputStream(ConnectionPool::Entry & pool_entry_, const String & query_, const Settings * settings_, ThrottlerPtr throttler_ = nullptr, RemoteBlockInputStream(ConnectionPool::Entry & pool_entry_, const String & query_, const Settings * settings_,
const Tables & external_tables_ = Tables(), QueryProcessingStage::Enum stage_ = QueryProcessingStage::Complete, ThrottlerPtr throttler_ = nullptr, const Tables & external_tables_ = Tables(),
const Context & context = getDefaultContext()) QueryProcessingStage::Enum stage_ = QueryProcessingStage::Complete,
: pool_entry(pool_entry_), connection(&*pool_entry_), query(query_), throttler(throttler_), const Context & context_ = getDefaultContext());
external_tables(external_tables_), stage(stage_), context(context)
{
init(settings_);
}
/// Принимает пул, из которого нужно будет достать одно или несколько соединений. /// Принимает пул, из которого нужно будет достать одно или несколько соединений.
RemoteBlockInputStream(IConnectionPool * pool_, const String & query_, const Settings * settings_, ThrottlerPtr throttler_ = nullptr, RemoteBlockInputStream(IConnectionPool * pool_, const String & query_, const Settings * settings_,
const Tables & external_tables_ = Tables(), QueryProcessingStage::Enum stage_ = QueryProcessingStage::Complete, ThrottlerPtr throttler_ = nullptr, const Tables & external_tables_ = Tables(),
const Context & context = getDefaultContext()) QueryProcessingStage::Enum stage_ = QueryProcessingStage::Complete,
: pool(pool_), query(query_), throttler(throttler_), external_tables(external_tables_), stage(stage_), context(context) const Context & context_ = getDefaultContext());
{
init(settings_);
}
/// Принимает пулы - один для каждого шарда, из которых нужно будет достать одно или несколько соединений.
RemoteBlockInputStream(ConnectionPoolsPtr & pools_, const String & query_, const Settings * settings_,
ThrottlerPtr throttler_ = nullptr, const Tables & external_tables_ = Tables(),
QueryProcessingStage::Enum stage_ = QueryProcessingStage::Complete,
const Context & context_ = getDefaultContext());
~RemoteBlockInputStream() override;
/// Отправить запрос на все существующие реплики.
void doBroadcast();
/// Кроме блоков, получить информацию о блоках.
void appendExtraInfo();
/// Отправляет запрос (инициирует вычисления) раньше, чем read.
void readPrefix() override;
/** Отменяем умолчальное уведомление о прогрессе,
* так как колбэк прогресса вызывается самостоятельно.
*/
void progress(const Progress & value) override {}
void cancel() override;
String getName() const override { return "Remote"; } String getName() const override { return "Remote"; }
String getID() const override String getID() const override
{ {
std::stringstream res; std::stringstream res;
@ -71,249 +67,35 @@ public:
return res.str(); return res.str();
} }
/** Отменяем умолчальное уведомление о прогрессе,
* так как колбэк прогресса вызывается самостоятельно.
*/
void progress(const Progress & value) override {}
void cancel() override
{
bool old_val = false;
if (!is_cancelled.compare_exchange_strong(old_val, true, std::memory_order_seq_cst, std::memory_order_relaxed))
return;
{
std::lock_guard<std::mutex> lock(external_tables_mutex);
/// Останавливаем отправку внешних данных.
for (auto & vec : external_tables_data)
for (auto & elem : vec)
if (IProfilingBlockInputStream * stream = dynamic_cast<IProfilingBlockInputStream *>(elem.first.get()))
stream->cancel();
}
if (!isQueryPending() || hasThrownException())
return;
tryCancel("Cancelling query");
}
~RemoteBlockInputStream() override
{
/** Если прервались в середине цикла общения с репликами, то прервываем
* все соединения, затем читаем и пропускаем оставшиеся пакеты чтобы
* эти соединения не остались висеть в рассихронизированном состоянии.
*/
if (established || isQueryPending())
parallel_replicas->disconnect();
}
/// Отправляет запрос (инициирует вычисления) раньше, чем read.
void readPrefix() override
{
if (!sent_query)
sendQuery();
}
/// Отправить запрос на все существующие реплики.
void reachAllReplicas()
{
reach_all_replicas = true;
}
/// Кроме блоков, получить информацию о блоках.
void appendExtraInfo()
{
append_extra_info = true;
}
BlockExtraInfo getBlockExtraInfo() const override BlockExtraInfo getBlockExtraInfo() const override
{ {
return parallel_replicas->getBlockExtraInfo(); return multiplexed_connections->getBlockExtraInfo();
} }
protected: protected:
/// Отправить на удаленные серверы все временные таблицы. /// Отправить на удаленные серверы все временные таблицы.
void sendExternalTables() void sendExternalTables();
{
size_t count = parallel_replicas->size();
{ Block readImpl() override;
std::lock_guard<std::mutex> lock(external_tables_mutex);
external_tables_data.reserve(count); void readSuffixImpl() override;
for (size_t i = 0; i < count; ++i)
{
ExternalTablesData res;
for (const auto & table : external_tables)
{
StoragePtr cur = table.second;
QueryProcessingStage::Enum stage = QueryProcessingStage::Complete;
DB::BlockInputStreams input = cur->read(cur->getColumnNamesList(), ASTPtr(), context, settings,
stage, DEFAULT_BLOCK_SIZE, 1);
if (input.size() == 0)
res.push_back(std::make_pair(new OneBlockInputStream(cur->getSampleBlock()), table.first));
else
res.push_back(std::make_pair(input[0], table.first));
}
external_tables_data.push_back(std::move(res));
}
}
parallel_replicas->sendExternalTablesData(external_tables_data);
}
Block readImpl() override
{
if (!sent_query)
{
sendQuery();
if (settings.skip_unavailable_shards && 0 == parallel_replicas->size())
return {};
}
while (true)
{
if (isCancelled())
return Block();
Connection::Packet packet = parallel_replicas->receivePacket();
switch (packet.type)
{
case Protocol::Server::Data:
/// Если блок не пуст и не является заголовочным блоком
if (packet.block && packet.block.rows() > 0)
return packet.block;
break; /// Если блок пустой - получим другие пакеты до EndOfStream.
case Protocol::Server::Exception:
got_exception_from_replica = true;
packet.exception->rethrow();
break;
case Protocol::Server::EndOfStream:
if (!parallel_replicas->hasActiveReplicas())
{
finished = true;
return Block();
}
break;
case Protocol::Server::Progress:
/** Используем прогресс с удалённого сервера.
* В том числе, запишем его в ProcessList,
* и будем использовать его для проверки
* ограничений (например, минимальная скорость выполнения запроса)
* и квот (например, на количество строчек для чтения).
*/
progressImpl(packet.progress);
break;
case Protocol::Server::ProfileInfo:
info = packet.profile_info;
break;
case Protocol::Server::Totals:
totals = packet.block;
break;
case Protocol::Server::Extremes:
extremes = packet.block;
break;
default:
got_unknown_packet_from_replica = true;
throw Exception("Unknown packet from server", ErrorCodes::UNKNOWN_PACKET_FROM_SERVER);
}
}
}
void readSuffixImpl() override
{
/** Если одно из:
* - ничего не начинали делать;
* - получили все пакеты до EndOfStream;
* - получили с одной реплики эксепшен;
* - получили с одной реплики неизвестный пакет;
* - то больше читать ничего не нужно.
*/
if (!isQueryPending() || hasThrownException())
return;
/** Если ещё прочитали не все данные, но они больше не нужны.
* Это может быть из-за того, что данных достаточно (например, при использовании LIMIT).
*/
/// Отправим просьбу прервать выполнение запроса, если ещё не отправляли.
tryCancel("Cancelling query because enough data has been read");
/// Получим оставшиеся пакеты, чтобы не было рассинхронизации в соединениях с репликами.
Connection::Packet packet = parallel_replicas->drain();
switch (packet.type)
{
case Protocol::Server::EndOfStream:
finished = true;
break;
case Protocol::Server::Exception:
got_exception_from_replica = true;
packet.exception->rethrow();
break;
default:
got_unknown_packet_from_replica = true;
throw Exception("Unknown packet from server", ErrorCodes::UNKNOWN_PACKET_FROM_SERVER);
}
}
/// Создать объект для общения с репликами одного шарда, на которых должен выполниться запрос. /// Создать объект для общения с репликами одного шарда, на которых должен выполниться запрос.
void createParallelReplicas() void createMultiplexedConnections();
{
Settings * parallel_replicas_settings = send_settings ? &settings : nullptr;
if (connection != nullptr)
parallel_replicas = std::make_unique<ParallelReplicas>(connection, parallel_replicas_settings, throttler);
else
parallel_replicas = std::make_unique<ParallelReplicas>(pool, parallel_replicas_settings, throttler,
append_extra_info, reach_all_replicas);
}
/// Возвращает true, если запрос отправлен. /// Возвращает true, если запрос отправлен.
bool isQueryPending() const bool isQueryPending() const;
{
return sent_query && !finished;
}
/// Возвращает true, если исключение было выкинуто. /// Возвращает true, если исключение было выкинуто.
bool hasThrownException() const bool hasThrownException() const;
{
return got_exception_from_replica || got_unknown_packet_from_replica;
}
private: private:
void sendQuery() void init(const Settings * settings_);
{
createParallelReplicas();
if (settings.skip_unavailable_shards && 0 == parallel_replicas->size()) void sendQuery();
return;
established = true; /// Отправить запрос на отмену всех соединений к репликам, если такой запрос ещё не был отправлен.
void tryCancel(const char * reason);
parallel_replicas->sendQuery(query, "", stage, true);
established = false;
sent_query = true;
sendExternalTables();
}
/// ITable::read requires a Context, therefore we should create one if the user can't supply it /// ITable::read requires a Context, therefore we should create one if the user can't supply it
static Context & getDefaultContext() static Context & getDefaultContext()
@ -322,23 +104,18 @@ private:
return instance; return instance;
} }
/// Отправить запрос на отмену всех соединений к репликам, если такой запрос ещё не был отправлен.
void tryCancel(const char * reason)
{
bool old_val = false;
if (!was_cancelled.compare_exchange_strong(old_val, true, std::memory_order_seq_cst, std::memory_order_relaxed))
return;
LOG_TRACE(log, "(" << parallel_replicas->dumpAddresses() << ") " << reason);
parallel_replicas->sendCancel();
}
private: private:
IConnectionPool * pool = nullptr; /// Готовое соединение.
ConnectionPool::Entry pool_entry; ConnectionPool::Entry pool_entry;
Connection * connection = nullptr; Connection * connection = nullptr;
std::unique_ptr<ParallelReplicas> parallel_replicas;
/// Пул соединений одного шарда.
IConnectionPool * pool = nullptr;
/// Пулы соединений одного или нескольких шардов.
ConnectionPoolsPtr pools;
std::unique_ptr<MultiplexedConnections> multiplexed_connections;
const String query; const String query;
bool send_settings; bool send_settings;
@ -384,7 +161,7 @@ private:
std::atomic<bool> got_unknown_packet_from_replica { false }; std::atomic<bool> got_unknown_packet_from_replica { false };
bool append_extra_info = false; bool append_extra_info = false;
bool reach_all_replicas = false; bool do_broadcast = false;
Logger * log = &Logger::get("RemoteBlockInputStream"); Logger * log = &Logger::get("RemoteBlockInputStream");
}; };

View File

@ -129,154 +129,13 @@ private:
* все элементы - нулевые. * все элементы - нулевые.
*/ */
template <class TSortCursor> template <class TSortCursor>
bool mergeMaps(Row & row, TSortCursor & cursor) bool mergeMaps(Row & row, TSortCursor & cursor);
{
auto non_empty_map_present = false;
/// merge nested maps
for (const auto & map : maps_to_sum)
{
const auto val_count = map.val_col_nums.size();
/// fetch key array reference from accumulator-row
auto & key_array_lhs = row[map.key_col_num].get<Array>();
/// returns a Field for pos-th item of val_index-th value
const auto val_getter_lhs = [&] (const auto val_index, const auto pos) -> decltype(auto) {
return row[map.val_col_nums[val_index]].get<Array>()[pos];
};
/// we will be sorting key positions, not the entire rows, to minimize actions
std::vector<std::size_t> key_pos_lhs(ext::range_iterator<std::size_t>{0},
ext::range_iterator<std::size_t>{key_array_lhs.size()});
std::sort(std::begin(key_pos_lhs), std::end(key_pos_lhs), [&] (const auto pos1, const auto pos2) {
return key_array_lhs[pos1] < key_array_lhs[pos2];
});
/// copy key field from current row under cursor
const auto key_field_rhs = (*cursor->all_columns[map.key_col_num])[cursor->pos];
/// for each element of `map.val_col_nums` copy corresponding array under cursor into vector
const auto val_fields_rhs = ext::map<std::vector>(map.val_col_nums,
[&] (const auto col_num) -> decltype(auto) {
return (*cursor->all_columns[col_num])[cursor->pos];
});
/// fetch key array reference from row under cursor
const auto & key_array_rhs = key_field_rhs.get<Array>();
/// returns a Field for pos-th item of val_index-th value
const auto val_getter_rhs = [&] (const auto val_index, const auto pos) -> decltype(auto) {
return val_fields_rhs[val_index].get<Array>()[pos];
};
std::vector<std::size_t> key_pos_rhs(ext::range_iterator<std::size_t>{0},
ext::range_iterator<std::size_t>{key_array_rhs.size()});
std::sort(std::begin(key_pos_rhs), std::end(key_pos_rhs), [&] (const auto pos1, const auto pos2) {
return key_array_rhs[pos1] < key_array_rhs[pos2];
});
/// max size after merge estimation
const auto max_size = key_pos_lhs.size() + key_pos_rhs.size();
/// create arrays with a single element (it will be overwritten on first iteration)
Array key_array_result(1);
key_array_result.reserve(max_size);
std::vector<Array> val_arrays_result(val_count, Array(1));
for (auto & val_array_result : val_arrays_result)
val_array_result.reserve(max_size);
/// discard first element
auto discard_prev = true;
/// either insert or merge new element
const auto insert_or_sum = [&] (std::size_t & index, const std::vector<std::size_t> & key_pos,
const auto & key_array, auto && val_getter) {
const auto pos = key_pos[index++];
const auto & key = key_array[pos];
if (discard_prev)
{
discard_prev = false;
key_array_result.back() = key;
for (const auto val_index : ext::range(0, val_count))
val_arrays_result[val_index].back() = val_getter(val_index, pos);
}
else if (key_array_result.back() == key)
{
/// merge with same key
auto should_discard = true;
for (const auto val_index : ext::range(0, val_count))
if (apply_visitor(FieldVisitorSum{val_getter(val_index, pos)},
val_arrays_result[val_index].back()))
should_discard = false;
discard_prev = should_discard;
}
else
{
/// append new key
key_array_result.emplace_back(key);
for (const auto val_index : ext::range(0, val_count))
val_arrays_result[val_index].emplace_back(val_getter(val_index, pos));
}
};
std::size_t index_lhs = 0;
std::size_t index_rhs = 0;
/// perform 2-way merge
while (true)
if (index_lhs < key_pos_lhs.size() && index_rhs == key_pos_rhs.size())
insert_or_sum(index_lhs, key_pos_lhs, key_array_lhs, val_getter_lhs);
else if (index_lhs == key_pos_lhs.size() && index_rhs < key_pos_rhs.size())
insert_or_sum(index_rhs, key_pos_rhs, key_array_rhs, val_getter_rhs);
else if (index_lhs < key_pos_lhs.size() && index_rhs < key_pos_rhs.size())
if (key_array_lhs[key_pos_lhs[index_lhs]] < key_array_rhs[key_pos_rhs[index_rhs]])
insert_or_sum(index_lhs, key_pos_lhs, key_array_lhs, val_getter_lhs);
else
insert_or_sum(index_rhs, key_pos_rhs, key_array_rhs, val_getter_rhs);
else
break;
/// discard last row if necessary
if (discard_prev)
key_array_result.pop_back();
/// store results into accumulator-row
key_array_lhs = std::move(key_array_result);
for (const auto val_col_index : ext::range(0, val_count))
{
/// discard last row if necessary
if (discard_prev)
val_arrays_result[val_col_index].pop_back();
row[map.val_col_nums[val_col_index]].get<Array>() = std::move(val_arrays_result[val_col_index]);
}
if (!key_array_lhs.empty())
non_empty_map_present = true;
}
return non_empty_map_present;
}
/** Прибавить строчку под курсором к row. /** Прибавить строчку под курсором к row.
* Возвращает false, если результат получился нулевым. * Возвращает false, если результат получился нулевым.
*/ */
template <class TSortCursor> template <class TSortCursor>
bool addRow(Row & row, TSortCursor & cursor) bool addRow(Row & row, TSortCursor & cursor);
{
bool res = mergeMaps(row, cursor); /// Есть ли хотя бы одно ненулевое число или непустой массив
for (size_t i = 0, size = column_numbers_to_sum.size(); i < size; ++i)
{
size_t j = column_numbers_to_sum[i];
if (apply_visitor(FieldVisitorSum((*cursor->all_columns[j])[cursor->pos]), row[j]))
res = true;
}
return res;
}
}; };
} }

View File

@ -31,6 +31,9 @@ public:
void setTotals(const Block & totals_) override { totals = totals_; } void setTotals(const Block & totals_) override { totals = totals_; }
void setExtremes(const Block & extremes_) override { extremes = extremes_; } void setExtremes(const Block & extremes_) override { extremes = extremes_; }
/// https://www.iana.org/assignments/media-types/text/tab-separated-values
String getContentType() const override { return "text/tab-separated-values; charset=UTF-8"; }
protected: protected:
void writeTotals(); void writeTotals();
void writeExtremes(); void writeExtremes();

View File

@ -271,6 +271,10 @@ private:
parent.output_queue.push(Payload()); parent.output_queue.push(Payload());
} }
void onFinishThread(size_t thread_num)
{
}
void onException(std::exception_ptr & exception, size_t thread_num) void onException(std::exception_ptr & exception, size_t thread_num)
{ {
//std::cerr << "pushing exception\n"; //std::cerr << "pushing exception\n";

View File

@ -4,9 +4,13 @@
#include <DB/Dictionaries/IDictionarySource.h> #include <DB/Dictionaries/IDictionarySource.h>
#include <DB/Dictionaries/DictionaryStructure.h> #include <DB/Dictionaries/DictionaryStructure.h>
#include <DB/Common/HashTable/HashMap.h> #include <DB/Common/HashTable/HashMap.h>
#include <DB/Common/ArenaWithFreeLists.h>
#include <DB/Columns/ColumnString.h> #include <DB/Columns/ColumnString.h>
#include <DB/Common/HashTable/HashMap.h>
#include <ext/scope_guard.hpp> #include <ext/scope_guard.hpp>
#include <ext/bit_cast.hpp>
#include <ext/range.hpp>
#include <ext/size.hpp>
#include <ext/map.hpp>
#include <Poco/RWLock.h> #include <Poco/RWLock.h>
#include <cmath> #include <cmath>
#include <atomic> #include <atomic>
@ -15,6 +19,7 @@
#include <map> #include <map>
#include <tuple> #include <tuple>
namespace DB namespace DB
{ {
@ -48,7 +53,7 @@ public:
std::string getTypeName() const override { return "Cache"; } std::string getTypeName() const override { return "Cache"; }
std::size_t getBytesAllocated() const override { return bytes_allocated; } std::size_t getBytesAllocated() const override { return bytes_allocated + (string_arena ? string_arena->size() : 0); }
std::size_t getQueryCount() const override { return query_count.load(std::memory_order_relaxed); } std::size_t getQueryCount() const override { return query_count.load(std::memory_order_relaxed); }
@ -89,11 +94,13 @@ public:
void toParent(const PODArray<id_t> & ids, PODArray<id_t> & out) const override void toParent(const PODArray<id_t> & ids, PODArray<id_t> & out) const override
{ {
getItems<UInt64>(*hierarchical_attribute, ids, out); const auto null_value = std::get<UInt64>(hierarchical_attribute->null_values);
getItems<UInt64>(*hierarchical_attribute, ids, out, [&] (const std::size_t) { return null_value; });
} }
#define DECLARE_MULTIPLE_GETTER(TYPE)\ #define DECLARE(TYPE)\
void get##TYPE(const std::string & attribute_name, const PODArray<id_t> & ids, PODArray<TYPE> & out) const override\ void get##TYPE(const std::string & attribute_name, const PODArray<id_t> & ids, PODArray<TYPE> & out) const\
{\ {\
auto & attribute = getAttribute(attribute_name);\ auto & attribute = getAttribute(attribute_name);\
if (attribute.type != AttributeUnderlyingType::TYPE)\ if (attribute.type != AttributeUnderlyingType::TYPE)\
@ -102,20 +109,22 @@ public:
ErrorCodes::TYPE_MISMATCH\ ErrorCodes::TYPE_MISMATCH\
};\ };\
\ \
getItems<TYPE>(attribute, ids, out);\ const auto null_value = std::get<TYPE>(attribute.null_values);\
\
getItems<TYPE>(attribute, ids, out, [&] (const std::size_t) { return null_value; });\
} }
DECLARE_MULTIPLE_GETTER(UInt8) DECLARE(UInt8)
DECLARE_MULTIPLE_GETTER(UInt16) DECLARE(UInt16)
DECLARE_MULTIPLE_GETTER(UInt32) DECLARE(UInt32)
DECLARE_MULTIPLE_GETTER(UInt64) DECLARE(UInt64)
DECLARE_MULTIPLE_GETTER(Int8) DECLARE(Int8)
DECLARE_MULTIPLE_GETTER(Int16) DECLARE(Int16)
DECLARE_MULTIPLE_GETTER(Int32) DECLARE(Int32)
DECLARE_MULTIPLE_GETTER(Int64) DECLARE(Int64)
DECLARE_MULTIPLE_GETTER(Float32) DECLARE(Float32)
DECLARE_MULTIPLE_GETTER(Float64) DECLARE(Float64)
#undef DECLARE_MULTIPLE_GETTER #undef DECLARE
void getString(const std::string & attribute_name, const PODArray<id_t> & ids, ColumnString * out) const override void getString(const std::string & attribute_name, const PODArray<id_t> & ids, ColumnString * out) const
{ {
auto & attribute = getAttribute(attribute_name); auto & attribute = getAttribute(attribute_name);
if (attribute.type != AttributeUnderlyingType::String) if (attribute.type != AttributeUnderlyingType::String)
@ -124,34 +133,175 @@ public:
ErrorCodes::TYPE_MISMATCH ErrorCodes::TYPE_MISMATCH
}; };
getItems(attribute, ids, out); const auto null_value = StringRef{std::get<String>(attribute.null_values)};
getItems(attribute, ids, out, [&] (const std::size_t) { return null_value; });
}
#define DECLARE(TYPE)\
void get##TYPE(\
const std::string & attribute_name, const PODArray<id_t> & ids, const PODArray<TYPE> & def,\
PODArray<TYPE> & out) const\
{\
auto & attribute = getAttribute(attribute_name);\
if (attribute.type != AttributeUnderlyingType::TYPE)\
throw Exception{\
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),\
ErrorCodes::TYPE_MISMATCH\
};\
\
getItems<TYPE>(attribute, ids, out, [&] (const std::size_t row) { return def[row]; });\
}
DECLARE(UInt8)
DECLARE(UInt16)
DECLARE(UInt32)
DECLARE(UInt64)
DECLARE(Int8)
DECLARE(Int16)
DECLARE(Int32)
DECLARE(Int64)
DECLARE(Float32)
DECLARE(Float64)
#undef DECLARE
void getString(
const std::string & attribute_name, const PODArray<id_t> & ids, const ColumnString * const def,
ColumnString * const out) const
{
auto & attribute = getAttribute(attribute_name);
if (attribute.type != AttributeUnderlyingType::String)
throw Exception{
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH
};
getItems(attribute, ids, out, [&] (const std::size_t row) { return def->getDataAt(row); });
}
#define DECLARE(TYPE)\
void get##TYPE(\
const std::string & attribute_name, const PODArray<id_t> & ids, const TYPE def, PODArray<TYPE> & out) const\
{\
auto & attribute = getAttribute(attribute_name);\
if (attribute.type != AttributeUnderlyingType::TYPE)\
throw Exception{\
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),\
ErrorCodes::TYPE_MISMATCH\
};\
\
getItems<TYPE>(attribute, ids, out, [&] (const std::size_t) { return def; });\
}
DECLARE(UInt8)
DECLARE(UInt16)
DECLARE(UInt32)
DECLARE(UInt64)
DECLARE(Int8)
DECLARE(Int16)
DECLARE(Int32)
DECLARE(Int64)
DECLARE(Float32)
DECLARE(Float64)
#undef DECLARE
void getString(
const std::string & attribute_name, const PODArray<id_t> & ids, const String & def,
ColumnString * const out) const
{
auto & attribute = getAttribute(attribute_name);
if (attribute.type != AttributeUnderlyingType::String)
throw Exception{
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH
};
getItems(attribute, ids, out, [&] (const std::size_t) { return StringRef{def}; });
}
void has(const PODArray<id_t> & ids, PODArray<UInt8> & out) const override
{
/// Mapping: <id> -> { all indices `i` of `ids` such that `ids[i]` = <id> }
MapType<std::vector<std::size_t>> outdated_ids;
const auto rows = ext::size(ids);
{
const Poco::ScopedReadRWLock read_lock{rw_lock};
const auto now = std::chrono::system_clock::now();
/// fetch up-to-date values, decide which ones require update
for (const auto row : ext::range(0, rows))
{
const auto id = ids[row];
const auto cell_idx = getCellIdx(id);
const auto & cell = cells[cell_idx];
/** cell should be updated if either:
* 1. ids do not match,
* 2. cell has expired,
* 3. explicit defaults were specified and cell was set default. */
if (cell.id != id || cell.expiresAt() < now)
outdated_ids[id].push_back(row);
else
out[row] = !cell.isDefault();
}
}
query_count.fetch_add(rows, std::memory_order_relaxed);
hit_count.fetch_add(rows - outdated_ids.size(), std::memory_order_release);
if (outdated_ids.empty())
return;
std::vector<id_t> required_ids(outdated_ids.size());
std::transform(std::begin(outdated_ids), std::end(outdated_ids), std::begin(required_ids),
[] (auto & pair) { return pair.first; });
/// request new values
update(required_ids, [&] (const auto id, const auto) {
for (const auto row : outdated_ids[id])
out[row] = true;
}, [&] (const auto id, const auto) {
for (const auto row : outdated_ids[id])
out[row] = false;
});
} }
private: private:
template <typename Value> using MapType = HashMap<id_t, Value>;
template <typename Value> using ContainerType = Value[];
template <typename Value> using ContainerPtrType = std::unique_ptr<ContainerType<Value>>;
struct cell_metadata_t final struct cell_metadata_t final
{ {
using time_point_t = std::chrono::system_clock::time_point;
using time_point_rep_t = time_point_t::rep;
using time_point_urep_t = std::make_unsigned_t<time_point_rep_t>;
static constexpr std::uint64_t EXPIRES_AT_MASK = std::numeric_limits<time_point_rep_t>::max();
static constexpr std::uint64_t IS_DEFAULT_MASK = ~EXPIRES_AT_MASK;
std::uint64_t id; std::uint64_t id;
std::chrono::system_clock::time_point expires_at; /// Stores both expiration time and `is_default` flag in the most significant bit
time_point_urep_t data;
/// Sets expiration time, resets `is_default` flag to false
time_point_t expiresAt() const { return ext::safe_bit_cast<time_point_t>(data & EXPIRES_AT_MASK); }
void setExpiresAt(const time_point_t & t) { data = ext::safe_bit_cast<time_point_urep_t>(t); }
bool isDefault() const { return (data & IS_DEFAULT_MASK) == IS_DEFAULT_MASK; }
void setDefault() { data |= IS_DEFAULT_MASK; }
}; };
struct attribute_t final struct attribute_t final
{ {
AttributeUnderlyingType type; AttributeUnderlyingType type;
std::tuple<UInt8, UInt16, UInt32, UInt64, std::tuple<
UInt8, UInt16, UInt32, UInt64,
Int8, Int16, Int32, Int64, Int8, Int16, Int32, Int64,
Float32, Float64, Float32, Float64,
String> null_values; String> null_values;
std::tuple<std::unique_ptr<UInt8[]>, std::tuple<
std::unique_ptr<UInt16[]>, ContainerPtrType<UInt8>, ContainerPtrType<UInt16>, ContainerPtrType<UInt32>, ContainerPtrType<UInt64>,
std::unique_ptr<UInt32[]>, ContainerPtrType<Int8>, ContainerPtrType<Int16>, ContainerPtrType<Int32>, ContainerPtrType<Int64>,
std::unique_ptr<UInt64[]>, ContainerPtrType<Float32>, ContainerPtrType<Float64>,
std::unique_ptr<Int8[]>, ContainerPtrType<StringRef>> arrays;
std::unique_ptr<Int16[]>,
std::unique_ptr<Int32[]>,
std::unique_ptr<Int64[]>,
std::unique_ptr<Float32[]>,
std::unique_ptr<Float64[]>,
std::unique_ptr<StringRef[]>> arrays;
}; };
void createAttributes() void createAttributes()
@ -188,123 +338,129 @@ private:
{ {
case AttributeUnderlyingType::UInt8: case AttributeUnderlyingType::UInt8:
std::get<UInt8>(attr.null_values) = null_value.get<UInt64>(); std::get<UInt8>(attr.null_values) = null_value.get<UInt64>();
std::get<std::unique_ptr<UInt8[]>>(attr.arrays) = std::make_unique<UInt8[]>(size); std::get<ContainerPtrType<UInt8>>(attr.arrays) = std::make_unique<ContainerType<UInt8>>(size);
bytes_allocated += size * sizeof(UInt8); bytes_allocated += size * sizeof(UInt8);
break; break;
case AttributeUnderlyingType::UInt16: case AttributeUnderlyingType::UInt16:
std::get<UInt16>(attr.null_values) = null_value.get<UInt64>(); std::get<UInt16>(attr.null_values) = null_value.get<UInt64>();
std::get<std::unique_ptr<UInt16[]>>(attr.arrays) = std::make_unique<UInt16[]>(size); std::get<ContainerPtrType<UInt16>>(attr.arrays) = std::make_unique<ContainerType<UInt16>>(size);
bytes_allocated += size * sizeof(UInt16); bytes_allocated += size * sizeof(UInt16);
break; break;
case AttributeUnderlyingType::UInt32: case AttributeUnderlyingType::UInt32:
std::get<UInt32>(attr.null_values) = null_value.get<UInt64>(); std::get<UInt32>(attr.null_values) = null_value.get<UInt64>();
std::get<std::unique_ptr<UInt32[]>>(attr.arrays) = std::make_unique<UInt32[]>(size); std::get<ContainerPtrType<UInt32>>(attr.arrays) = std::make_unique<ContainerType<UInt32>>(size);
bytes_allocated += size * sizeof(UInt32); bytes_allocated += size * sizeof(UInt32);
break; break;
case AttributeUnderlyingType::UInt64: case AttributeUnderlyingType::UInt64:
std::get<UInt64>(attr.null_values) = null_value.get<UInt64>(); std::get<UInt64>(attr.null_values) = null_value.get<UInt64>();
std::get<std::unique_ptr<UInt64[]>>(attr.arrays) = std::make_unique<UInt64[]>(size); std::get<ContainerPtrType<UInt64>>(attr.arrays) = std::make_unique<ContainerType<UInt64>>(size);
bytes_allocated += size * sizeof(UInt64); bytes_allocated += size * sizeof(UInt64);
break; break;
case AttributeUnderlyingType::Int8: case AttributeUnderlyingType::Int8:
std::get<Int8>(attr.null_values) = null_value.get<Int64>(); std::get<Int8>(attr.null_values) = null_value.get<Int64>();
std::get<std::unique_ptr<Int8[]>>(attr.arrays) = std::make_unique<Int8[]>(size); std::get<ContainerPtrType<Int8>>(attr.arrays) = std::make_unique<ContainerType<Int8>>(size);
bytes_allocated += size * sizeof(Int8); bytes_allocated += size * sizeof(Int8);
break; break;
case AttributeUnderlyingType::Int16: case AttributeUnderlyingType::Int16:
std::get<Int16>(attr.null_values) = null_value.get<Int64>(); std::get<Int16>(attr.null_values) = null_value.get<Int64>();
std::get<std::unique_ptr<Int16[]>>(attr.arrays) = std::make_unique<Int16[]>(size); std::get<ContainerPtrType<Int16>>(attr.arrays) = std::make_unique<ContainerType<Int16>>(size);
bytes_allocated += size * sizeof(Int16); bytes_allocated += size * sizeof(Int16);
break; break;
case AttributeUnderlyingType::Int32: case AttributeUnderlyingType::Int32:
std::get<Int32>(attr.null_values) = null_value.get<Int64>(); std::get<Int32>(attr.null_values) = null_value.get<Int64>();
std::get<std::unique_ptr<Int32[]>>(attr.arrays) = std::make_unique<Int32[]>(size); std::get<ContainerPtrType<Int32>>(attr.arrays) = std::make_unique<ContainerType<Int32>>(size);
bytes_allocated += size * sizeof(Int32); bytes_allocated += size * sizeof(Int32);
break; break;
case AttributeUnderlyingType::Int64: case AttributeUnderlyingType::Int64:
std::get<Int64>(attr.null_values) = null_value.get<Int64>(); std::get<Int64>(attr.null_values) = null_value.get<Int64>();
std::get<std::unique_ptr<Int64[]>>(attr.arrays) = std::make_unique<Int64[]>(size); std::get<ContainerPtrType<Int64>>(attr.arrays) = std::make_unique<ContainerType<Int64>>(size);
bytes_allocated += size * sizeof(Int64); bytes_allocated += size * sizeof(Int64);
break; break;
case AttributeUnderlyingType::Float32: case AttributeUnderlyingType::Float32:
std::get<Float32>(attr.null_values) = null_value.get<Float64>(); std::get<Float32>(attr.null_values) = null_value.get<Float64>();
std::get<std::unique_ptr<Float32[]>>(attr.arrays) = std::make_unique<Float32[]>(size); std::get<ContainerPtrType<Float32>>(attr.arrays) = std::make_unique<ContainerType<Float32>>(size);
bytes_allocated += size * sizeof(Float32); bytes_allocated += size * sizeof(Float32);
break; break;
case AttributeUnderlyingType::Float64: case AttributeUnderlyingType::Float64:
std::get<Float64>(attr.null_values) = null_value.get<Float64>(); std::get<Float64>(attr.null_values) = null_value.get<Float64>();
std::get<std::unique_ptr<Float64[]>>(attr.arrays) = std::make_unique<Float64[]>(size); std::get<ContainerPtrType<Float64>>(attr.arrays) = std::make_unique<ContainerType<Float64>>(size);
bytes_allocated += size * sizeof(Float64); bytes_allocated += size * sizeof(Float64);
break; break;
case AttributeUnderlyingType::String: case AttributeUnderlyingType::String:
std::get<String>(attr.null_values) = null_value.get<String>(); std::get<String>(attr.null_values) = null_value.get<String>();
std::get<std::unique_ptr<StringRef[]>>(attr.arrays) = std::make_unique<StringRef[]>(size); std::get<ContainerPtrType<StringRef>>(attr.arrays) = std::make_unique<ContainerType<StringRef>>(size);
bytes_allocated += size * sizeof(StringRef); bytes_allocated += size * sizeof(StringRef);
if (!string_arena)
string_arena = std::make_unique<ArenaWithFreeLists>();
break; break;
} }
return attr; return attr;
} }
template <typename T> template <typename T, typename DefaultGetter>
void getItems(attribute_t & attribute, const PODArray<id_t> & ids, PODArray<T> & out) const void getItems(
attribute_t & attribute, const PODArray<id_t> & ids, PODArray<T> & out, DefaultGetter && get_default) const
{ {
HashMap<id_t, std::vector<std::size_t>> outdated_ids; /// Mapping: <id> -> { all indices `i` of `ids` such that `ids[i]` = <id> }
auto & attribute_array = std::get<std::unique_ptr<T[]>>(attribute.arrays); MapType<std::vector<std::size_t>> outdated_ids;
auto & attribute_array = std::get<ContainerPtrType<T>>(attribute.arrays);
const auto rows = ext::size(ids);
{ {
const Poco::ScopedReadRWLock read_lock{rw_lock}; const Poco::ScopedReadRWLock read_lock{rw_lock};
const auto now = std::chrono::system_clock::now(); const auto now = std::chrono::system_clock::now();
/// fetch up-to-date values, decide which ones require update /// fetch up-to-date values, decide which ones require update
for (const auto i : ext::range(0, ids.size())) for (const auto row : ext::range(0, rows))
{ {
const auto id = ids[i]; const auto id = ids[row];
if (id == 0)
{
out[i] = std::get<T>(attribute.null_values);
continue;
}
const auto cell_idx = getCellIdx(id); const auto cell_idx = getCellIdx(id);
const auto & cell = cells[cell_idx]; const auto & cell = cells[cell_idx];
if (cell.id != id || cell.expires_at < now) /** cell should be updated if either:
{ * 1. ids do not match,
out[i] = std::get<T>(attribute.null_values); * 2. cell has expired,
outdated_ids[id].push_back(i); * 3. explicit defaults were specified and cell was set default. */
} if (cell.id != id || cell.expiresAt() < now)
outdated_ids[id].push_back(row);
else else
out[i] = attribute_array[cell_idx]; out[row] = cell.isDefault() ? get_default(row) : attribute_array[cell_idx];
} }
} }
query_count.fetch_add(ids.size(), std::memory_order_relaxed); query_count.fetch_add(rows, std::memory_order_relaxed);
hit_count.fetch_add(ids.size() - outdated_ids.size(), std::memory_order_release); hit_count.fetch_add(rows - outdated_ids.size(), std::memory_order_release);
if (outdated_ids.empty()) if (outdated_ids.empty())
return; return;
/// request new values
std::vector<id_t> required_ids(outdated_ids.size()); std::vector<id_t> required_ids(outdated_ids.size());
std::transform(std::begin(outdated_ids), std::end(outdated_ids), std::begin(required_ids), std::transform(std::begin(outdated_ids), std::end(outdated_ids), std::begin(required_ids),
[] (auto & pair) { return pair.first; }); [] (auto & pair) { return pair.first; });
/// request new values
update(required_ids, [&] (const auto id, const auto cell_idx) { update(required_ids, [&] (const auto id, const auto cell_idx) {
const auto attribute_value = attribute_array[cell_idx]; const auto attribute_value = attribute_array[cell_idx];
/// set missing values to out for (const auto row : outdated_ids[id])
for (const auto out_idx : outdated_ids[id]) out[row] = attribute_value;
out[out_idx] = attribute_value; }, [&] (const auto id, const auto cell_idx) {
for (const auto row : outdated_ids[id])
out[row] = get_default(row);
}); });
} }
void getItems(attribute_t & attribute, const PODArray<id_t> & ids, ColumnString * out) const template <typename DefaultGetter>
void getItems(
attribute_t & attribute, const PODArray<id_t> & ids, ColumnString * out, DefaultGetter && get_default) const
{ {
/// save on some allocations const auto rows = ext::size(ids);
out->getOffsets().reserve(ids.size());
auto & attribute_array = std::get<std::unique_ptr<StringRef[]>>(attribute.arrays); /// save on some allocations
out->getOffsets().reserve(rows);
auto & attribute_array = std::get<ContainerPtrType<StringRef>>(attribute.arrays);
auto found_outdated_values = false; auto found_outdated_values = false;
@ -314,20 +470,20 @@ private:
const auto now = std::chrono::system_clock::now(); const auto now = std::chrono::system_clock::now();
/// fetch up-to-date values, discard on fail /// fetch up-to-date values, discard on fail
for (const auto i : ext::range(0, ids.size())) for (const auto row : ext::range(0, rows))
{ {
const auto id = ids[i]; const auto id = ids[row];
const auto cell_idx = getCellIdx(id); const auto cell_idx = getCellIdx(id);
const auto & cell = cells[cell_idx]; const auto & cell = cells[cell_idx];
if (cell.id != id || cell.expires_at < now) if (cell.id != id || cell.expiresAt() < now)
{ {
found_outdated_values = true; found_outdated_values = true;
break; break;
} }
else else
{ {
const auto string_ref = attribute_array[cell_idx]; const auto string_ref = cell.isDefault() ? get_default(row) : attribute_array[cell_idx];
out->insertData(string_ref.data, string_ref.size); out->insertData(string_ref.data, string_ref.size);
} }
} }
@ -336,44 +492,47 @@ private:
/// optimistic code completed successfully /// optimistic code completed successfully
if (!found_outdated_values) if (!found_outdated_values)
{ {
query_count.fetch_add(ids.size(), std::memory_order_relaxed); query_count.fetch_add(rows, std::memory_order_relaxed);
hit_count.fetch_add(ids.size(), std::memory_order_release); hit_count.fetch_add(rows, std::memory_order_release);
return; return;
} }
/// now onto the pessimistic one, discard possibly partial results from the optimistic path /// now onto the pessimistic one, discard possible partial results from the optimistic path
out->getChars().resize_assume_reserved(0); out->getChars().resize_assume_reserved(0);
out->getOffsets().resize_assume_reserved(0); out->getOffsets().resize_assume_reserved(0);
/// outdated ids joined number of times they've been requested /// Mapping: <id> -> { all indices `i` of `ids` such that `ids[i]` = <id> }
HashMap<id_t, std::size_t> outdated_ids; MapType<std::vector<std::size_t>> outdated_ids;
/// we are going to store every string separately /// we are going to store every string separately
HashMap<id_t, String> map; MapType<String> map;
std::size_t total_length = 0; std::size_t total_length = 0;
{ {
const Poco::ScopedReadRWLock read_lock{rw_lock}; const Poco::ScopedReadRWLock read_lock{rw_lock};
const auto now = std::chrono::system_clock::now(); const auto now = std::chrono::system_clock::now();
for (const auto i : ext::range(0, ids.size())) for (const auto row : ext::range(0, ids.size()))
{ {
const auto id = ids[i]; const auto id = ids[row];
const auto cell_idx = getCellIdx(id); const auto cell_idx = getCellIdx(id);
const auto & cell = cells[cell_idx]; const auto & cell = cells[cell_idx];
if (cell.id != id || cell.expires_at < now) if (cell.id != id || cell.expiresAt() < now)
outdated_ids[id] += 1; outdated_ids[id].push_back(row);
else else
{ {
const auto string_ref = attribute_array[cell_idx]; const auto string_ref = cell.isDefault() ? get_default(row) : attribute_array[cell_idx];
if (!cell.isDefault())
map[id] = String{string_ref}; map[id] = String{string_ref};
total_length += string_ref.size + 1; total_length += string_ref.size + 1;
} }
} }
} }
query_count.fetch_add(ids.size(), std::memory_order_relaxed); query_count.fetch_add(rows, std::memory_order_relaxed);
hit_count.fetch_add(ids.size() - outdated_ids.size(), std::memory_order_release); hit_count.fetch_add(rows - outdated_ids.size(), std::memory_order_release);
/// request new values /// request new values
if (!outdated_ids.empty()) if (!outdated_ids.empty())
@ -386,28 +545,35 @@ private:
const auto attribute_value = attribute_array[cell_idx]; const auto attribute_value = attribute_array[cell_idx];
map[id] = String{attribute_value}; map[id] = String{attribute_value};
total_length += (attribute_value.size + 1) * outdated_ids[id]; total_length += (attribute_value.size + 1) * outdated_ids[id].size();
}, [&] (const auto id, const auto cell_idx) {
for (const auto row : outdated_ids[id])
total_length += get_default(row).size + 1;
}); });
} }
out->getChars().reserve(total_length); out->getChars().reserve(total_length);
for (const auto id : ids) for (const auto row : ext::range(0, ext::size(ids)))
{ {
const auto id = ids[row];
const auto it = map.find(id); const auto it = map.find(id);
const auto string = it != map.end() ? it->second : std::get<String>(attribute.null_values);
out->insertData(string.data(), string.size()); const auto string_ref = it != std::end(map) ? StringRef{it->second} : get_default(row);
out->insertData(string_ref.data, string_ref.size);
} }
} }
template <typename F> template <typename PresentIdHandler, typename AbsentIdHandler>
void update(const std::vector<id_t> ids, F && on_cell_updated) const void update(
const std::vector<id_t> & requested_ids, PresentIdHandler && on_cell_updated,
AbsentIdHandler && on_id_not_found) const
{ {
auto stream = source_ptr->loadIds(ids); auto stream = source_ptr->loadIds(requested_ids);
stream->readPrefix(); stream->readPrefix();
HashMap<UInt64, UInt8> remaining_ids{ids.size()}; MapType<UInt8> remaining_ids{requested_ids.size()};
for (const auto id : ids) for (const auto id : requested_ids)
remaining_ids.insert({ id, 0 }); remaining_ids.insert({ id, 0 });
std::uniform_int_distribution<std::uint64_t> distribution{ std::uniform_int_distribution<std::uint64_t> distribution{
@ -429,9 +595,9 @@ private:
const auto & ids = id_column->getData(); const auto & ids = id_column->getData();
/// cache column pointers /// cache column pointers
std::vector<const IColumn *> column_ptrs(attributes.size()); const auto column_ptrs = ext::map<std::vector>(ext::range(0, attributes.size()), [&block] (const auto & i) {
for (const auto i : ext::range(0, attributes.size())) return block.getByPosition(i + 1).column.get();
column_ptrs[i] = block.getByPosition(i + 1).column.get(); });
for (const auto i : ext::range(0, ids.size())) for (const auto i : ext::range(0, ids.size()))
{ {
@ -453,17 +619,20 @@ private:
cell.id = id; cell.id = id;
if (dict_lifetime.min_sec != 0 && dict_lifetime.max_sec != 0) if (dict_lifetime.min_sec != 0 && dict_lifetime.max_sec != 0)
cell.expires_at = std::chrono::system_clock::now() + std::chrono::seconds{distribution(rnd_engine)}; cell.setExpiresAt(std::chrono::system_clock::now() + std::chrono::seconds{distribution(rnd_engine)});
else else
cell.expires_at = std::chrono::time_point<std::chrono::system_clock>::max(); cell.setExpiresAt(std::chrono::time_point<std::chrono::system_clock>::max());
/// inform caller
on_cell_updated(id, cell_idx); on_cell_updated(id, cell_idx);
/// mark corresponding id as found
remaining_ids[id] = 1; remaining_ids[id] = 1;
} }
} }
stream->readSuffix(); stream->readSuffix();
/// Check which ids have not been found and require setting null_value
for (const auto id_found_pair : remaining_ids) for (const auto id_found_pair : remaining_ids)
{ {
if (id_found_pair.second) if (id_found_pair.second)
@ -473,19 +642,24 @@ private:
const auto cell_idx = getCellIdx(id); const auto cell_idx = getCellIdx(id);
auto & cell = cells[cell_idx]; auto & cell = cells[cell_idx];
/// Set null_value for each attribute
for (auto & attribute : attributes) for (auto & attribute : attributes)
setDefaultAttributeValue(attribute, cell_idx); setDefaultAttributeValue(attribute, cell_idx);
/// Check if cell had not been occupied before and increment element counter if it hadn't
if (cell.id == 0 && cell_idx != zero_cell_idx) if (cell.id == 0 && cell_idx != zero_cell_idx)
element_count.fetch_add(1, std::memory_order_relaxed); element_count.fetch_add(1, std::memory_order_relaxed);
cell.id = id; cell.id = id;
if (dict_lifetime.min_sec != 0 && dict_lifetime.max_sec != 0) if (dict_lifetime.min_sec != 0 && dict_lifetime.max_sec != 0)
cell.expires_at = std::chrono::system_clock::now() + std::chrono::seconds{distribution(rnd_engine)}; cell.setExpiresAt(std::chrono::system_clock::now() + std::chrono::seconds{distribution(rnd_engine)});
else else
cell.expires_at = std::chrono::time_point<std::chrono::system_clock>::max(); cell.setExpiresAt(std::chrono::time_point<std::chrono::system_clock>::max());
on_cell_updated(id, cell_idx); cell.setDefault();
/// inform caller that the cell has not been found
on_id_not_found(id, cell_idx);
} }
} }
@ -500,28 +674,28 @@ private:
{ {
switch (attribute.type) switch (attribute.type)
{ {
case AttributeUnderlyingType::UInt8: std::get<std::unique_ptr<UInt8[]>>(attribute.arrays)[idx] = std::get<UInt8>(attribute.null_values); break; case AttributeUnderlyingType::UInt8: std::get<ContainerPtrType<UInt8>>(attribute.arrays)[idx] = std::get<UInt8>(attribute.null_values); break;
case AttributeUnderlyingType::UInt16: std::get<std::unique_ptr<UInt16[]>>(attribute.arrays)[idx] = std::get<UInt16>(attribute.null_values); break; case AttributeUnderlyingType::UInt16: std::get<ContainerPtrType<UInt16>>(attribute.arrays)[idx] = std::get<UInt16>(attribute.null_values); break;
case AttributeUnderlyingType::UInt32: std::get<std::unique_ptr<UInt32[]>>(attribute.arrays)[idx] = std::get<UInt32>(attribute.null_values); break; case AttributeUnderlyingType::UInt32: std::get<ContainerPtrType<UInt32>>(attribute.arrays)[idx] = std::get<UInt32>(attribute.null_values); break;
case AttributeUnderlyingType::UInt64: std::get<std::unique_ptr<UInt64[]>>(attribute.arrays)[idx] = std::get<UInt64>(attribute.null_values); break; case AttributeUnderlyingType::UInt64: std::get<ContainerPtrType<UInt64>>(attribute.arrays)[idx] = std::get<UInt64>(attribute.null_values); break;
case AttributeUnderlyingType::Int8: std::get<std::unique_ptr<Int8[]>>(attribute.arrays)[idx] = std::get<Int8>(attribute.null_values); break; case AttributeUnderlyingType::Int8: std::get<ContainerPtrType<Int8>>(attribute.arrays)[idx] = std::get<Int8>(attribute.null_values); break;
case AttributeUnderlyingType::Int16: std::get<std::unique_ptr<Int16[]>>(attribute.arrays)[idx] = std::get<Int16>(attribute.null_values); break; case AttributeUnderlyingType::Int16: std::get<ContainerPtrType<Int16>>(attribute.arrays)[idx] = std::get<Int16>(attribute.null_values); break;
case AttributeUnderlyingType::Int32: std::get<std::unique_ptr<Int32[]>>(attribute.arrays)[idx] = std::get<Int32>(attribute.null_values); break; case AttributeUnderlyingType::Int32: std::get<ContainerPtrType<Int32>>(attribute.arrays)[idx] = std::get<Int32>(attribute.null_values); break;
case AttributeUnderlyingType::Int64: std::get<std::unique_ptr<Int64[]>>(attribute.arrays)[idx] = std::get<Int64>(attribute.null_values); break; case AttributeUnderlyingType::Int64: std::get<ContainerPtrType<Int64>>(attribute.arrays)[idx] = std::get<Int64>(attribute.null_values); break;
case AttributeUnderlyingType::Float32: std::get<std::unique_ptr<Float32[]>>(attribute.arrays)[idx] = std::get<Float32>(attribute.null_values); break; case AttributeUnderlyingType::Float32: std::get<ContainerPtrType<Float32>>(attribute.arrays)[idx] = std::get<Float32>(attribute.null_values); break;
case AttributeUnderlyingType::Float64: std::get<std::unique_ptr<Float64[]>>(attribute.arrays)[idx] = std::get<Float64>(attribute.null_values); break; case AttributeUnderlyingType::Float64: std::get<ContainerPtrType<Float64>>(attribute.arrays)[idx] = std::get<Float64>(attribute.null_values); break;
case AttributeUnderlyingType::String: case AttributeUnderlyingType::String:
{ {
const auto & null_value_ref = std::get<String>(attribute.null_values); const auto & null_value_ref = std::get<String>(attribute.null_values);
auto & string_ref = std::get<std::unique_ptr<StringRef[]>>(attribute.arrays)[idx]; auto & string_ref = std::get<ContainerPtrType<StringRef>>(attribute.arrays)[idx];
if (string_ref.data == null_value_ref.data())
return;
if (string_ref.size != 0) if (string_ref.data != null_value_ref.data())
bytes_allocated -= string_ref.size + 1; {
const std::unique_ptr<const char[]> deleter{string_ref.data}; if (string_ref.data)
string_arena->free(string_ref.data, string_ref.size);
string_ref = StringRef{null_value_ref}; string_ref = StringRef{null_value_ref};
}
break; break;
} }
@ -532,36 +706,32 @@ private:
{ {
switch (attribute.type) switch (attribute.type)
{ {
case AttributeUnderlyingType::UInt8: std::get<std::unique_ptr<UInt8[]>>(attribute.arrays)[idx] = value.get<UInt64>(); break; case AttributeUnderlyingType::UInt8: std::get<ContainerPtrType<UInt8>>(attribute.arrays)[idx] = value.get<UInt64>(); break;
case AttributeUnderlyingType::UInt16: std::get<std::unique_ptr<UInt16[]>>(attribute.arrays)[idx] = value.get<UInt64>(); break; case AttributeUnderlyingType::UInt16: std::get<ContainerPtrType<UInt16>>(attribute.arrays)[idx] = value.get<UInt64>(); break;
case AttributeUnderlyingType::UInt32: std::get<std::unique_ptr<UInt32[]>>(attribute.arrays)[idx] = value.get<UInt64>(); break; case AttributeUnderlyingType::UInt32: std::get<ContainerPtrType<UInt32>>(attribute.arrays)[idx] = value.get<UInt64>(); break;
case AttributeUnderlyingType::UInt64: std::get<std::unique_ptr<UInt64[]>>(attribute.arrays)[idx] = value.get<UInt64>(); break; case AttributeUnderlyingType::UInt64: std::get<ContainerPtrType<UInt64>>(attribute.arrays)[idx] = value.get<UInt64>(); break;
case AttributeUnderlyingType::Int8: std::get<std::unique_ptr<Int8[]>>(attribute.arrays)[idx] = value.get<Int64>(); break; case AttributeUnderlyingType::Int8: std::get<ContainerPtrType<Int8>>(attribute.arrays)[idx] = value.get<Int64>(); break;
case AttributeUnderlyingType::Int16: std::get<std::unique_ptr<Int16[]>>(attribute.arrays)[idx] = value.get<Int64>(); break; case AttributeUnderlyingType::Int16: std::get<ContainerPtrType<Int16>>(attribute.arrays)[idx] = value.get<Int64>(); break;
case AttributeUnderlyingType::Int32: std::get<std::unique_ptr<Int32[]>>(attribute.arrays)[idx] = value.get<Int64>(); break; case AttributeUnderlyingType::Int32: std::get<ContainerPtrType<Int32>>(attribute.arrays)[idx] = value.get<Int64>(); break;
case AttributeUnderlyingType::Int64: std::get<std::unique_ptr<Int64[]>>(attribute.arrays)[idx] = value.get<Int64>(); break; case AttributeUnderlyingType::Int64: std::get<ContainerPtrType<Int64>>(attribute.arrays)[idx] = value.get<Int64>(); break;
case AttributeUnderlyingType::Float32: std::get<std::unique_ptr<Float32[]>>(attribute.arrays)[idx] = value.get<Float64>(); break; case AttributeUnderlyingType::Float32: std::get<ContainerPtrType<Float32>>(attribute.arrays)[idx] = value.get<Float64>(); break;
case AttributeUnderlyingType::Float64: std::get<std::unique_ptr<Float64[]>>(attribute.arrays)[idx] = value.get<Float64>(); break; case AttributeUnderlyingType::Float64: std::get<ContainerPtrType<Float64>>(attribute.arrays)[idx] = value.get<Float64>(); break;
case AttributeUnderlyingType::String: case AttributeUnderlyingType::String:
{ {
const auto & string = value.get<String>(); const auto & string = value.get<String>();
auto & string_ref = std::get<std::unique_ptr<StringRef[]>>(attribute.arrays)[idx]; auto & string_ref = std::get<ContainerPtrType<StringRef>>(attribute.arrays)[idx];
const auto & null_value_ref = std::get<String>(attribute.null_values); const auto & null_value_ref = std::get<String>(attribute.null_values);
if (string_ref.data != null_value_ref.data())
{ /// free memory unless it points to a null_value
if (string_ref.size != 0) if (string_ref.data && string_ref.data != null_value_ref.data())
bytes_allocated -= string_ref.size + 1; string_arena->free(string_ref.data, string_ref.size);
/// avoid explicit delete, let unique_ptr handle it
const std::unique_ptr<const char[]> deleter{string_ref.data};
}
const auto size = string.size(); const auto size = string.size();
if (size != 0) if (size != 0)
{ {
auto string_ptr = std::make_unique<char[]>(size + 1); auto string_ptr = string_arena->alloc(size + 1);
std::copy(string.data(), string.data() + size + 1, string_ptr.get()); std::copy(string.data(), string.data() + size + 1, string_ptr);
string_ref = StringRef{string_ptr.release(), size}; string_ref = StringRef{string_ptr, size};
bytes_allocated += size + 1;
} }
else else
string_ref = {}; string_ref = {};
@ -616,6 +786,7 @@ private:
mutable std::vector<attribute_t> attributes; mutable std::vector<attribute_t> attributes;
mutable std::vector<cell_metadata_t> cells; mutable std::vector<cell_metadata_t> cells;
attribute_t * hierarchical_attribute = nullptr; attribute_t * hierarchical_attribute = nullptr;
std::unique_ptr<ArenaWithFreeLists> string_arena;
mutable std::mt19937_64 rnd_engine{getSeed()}; mutable std::mt19937_64 rnd_engine{getSeed()};

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <DB/Dictionaries/IDictionarySource.h> #include <DB/Dictionaries/IDictionarySource.h>
#include <DB/Dictionaries/DictionaryStructure.h>
#include <DB/Client/ConnectionPool.h> #include <DB/Client/ConnectionPool.h>
#include <DB/DataStreams/RemoteBlockInputStream.h> #include <DB/DataStreams/RemoteBlockInputStream.h>
#include <DB/Interpreters/executeQuery.h> #include <DB/Interpreters/executeQuery.h>
@ -40,7 +41,8 @@ public:
max_connections, host, port, db, user, password, max_connections, host, port, db, user, password,
"ClickHouseDictionarySource") "ClickHouseDictionarySource")
}, },
load_all_query{composeLoadAllQuery()} load_all_query{composeLoadAllQuery()},
key_tuple_definition{dict_struct.key ? composeKeyTupleDefinition() : std::string{}}
{} {}
/// copy-constructor is provided in order to support cloneability /// copy-constructor is provided in order to support cloneability
@ -69,11 +71,13 @@ public:
BlockInputStreamPtr loadIds(const std::vector<std::uint64_t> & ids) override BlockInputStreamPtr loadIds(const std::vector<std::uint64_t> & ids) override
{ {
const auto query = composeLoadIdsQuery(ids); return createStreamForSelectiveLoad(composeLoadIdsQuery(ids));
}
if (is_local) BlockInputStreamPtr loadKeys(
return executeQuery(query, context, true).in; const ConstColumnPlainPtrs & key_columns, const std::vector<std::size_t> & requested_rows) override
return new RemoteBlockInputStream{pool.get(), query, nullptr}; {
return createStreamForSelectiveLoad(composeLoadKeysQuery(key_columns, requested_rows));
} }
bool isModified() const override { return true; } bool isModified() const override { return true; }
@ -95,13 +99,15 @@ private:
WriteBufferFromString out{query}; WriteBufferFromString out{query};
writeString("SELECT ", out); writeString("SELECT ", out);
if (!dict_struct.id.expression.empty()) if (dict_struct.id)
{ {
writeParenthesisedString(dict_struct.id.expression, out); if (!dict_struct.id->expression.empty())
{
writeParenthesisedString(dict_struct.id->expression, out);
writeString(" AS ", out); writeString(" AS ", out);
} }
writeProbablyBackQuotedString(dict_struct.id.name, out); writeProbablyBackQuotedString(dict_struct.id->name, out);
if (dict_struct.range_min && dict_struct.range_max) if (dict_struct.range_min && dict_struct.range_max)
{ {
@ -125,6 +131,26 @@ private:
writeProbablyBackQuotedString(dict_struct.range_max->name, out); writeProbablyBackQuotedString(dict_struct.range_max->name, out);
} }
}
else if (dict_struct.key)
{
auto first = true;
for (const auto & key : *dict_struct.key)
{
if (!first)
writeString(", ", out);
first = false;
if (!key.expression.empty())
{
writeParenthesisedString(key.expression, out);
writeString(" AS ", out);
}
writeProbablyBackQuotedString(key.name, out);
}
}
for (const auto & attr : dict_struct.attributes) for (const auto & attr : dict_struct.attributes)
{ {
@ -161,19 +187,22 @@ private:
std::string composeLoadIdsQuery(const std::vector<std::uint64_t> ids) std::string composeLoadIdsQuery(const std::vector<std::uint64_t> ids)
{ {
if (!dict_struct.id)
throw Exception{"Simple key required for method", ErrorCodes::UNSUPPORTED_METHOD};
std::string query; std::string query;
{ {
WriteBufferFromString out{query}; WriteBufferFromString out{query};
writeString("SELECT ", out); writeString("SELECT ", out);
if (!dict_struct.id.expression.empty()) if (!dict_struct.id->expression.empty())
{ {
writeParenthesisedString(dict_struct.id.expression, out); writeParenthesisedString(dict_struct.id->expression, out);
writeString(" AS ", out); writeString(" AS ", out);
} }
writeProbablyBackQuotedString(dict_struct.id.name, out); writeProbablyBackQuotedString(dict_struct.id->name, out);
for (const auto & attr : dict_struct.attributes) for (const auto & attr : dict_struct.attributes)
{ {
@ -204,7 +233,7 @@ private:
writeString(" AND ", out); writeString(" AND ", out);
} }
writeProbablyBackQuotedString(dict_struct.id.name, out); writeProbablyBackQuotedString(dict_struct.id->name, out);
writeString(" IN (", out); writeString(" IN (", out);
auto first = true; auto first = true;
@ -223,6 +252,118 @@ private:
return query; return query;
} }
std::string composeLoadKeysQuery(
const ConstColumnPlainPtrs & key_columns, const std::vector<std::size_t> & requested_rows)
{
if (!dict_struct.key)
throw Exception{"Composite key required for method", ErrorCodes::UNSUPPORTED_METHOD};
std::string query;
{
WriteBufferFromString out{query};
writeString("SELECT ", out);
auto first = true;
for (const auto & key_or_attribute : boost::join(*dict_struct.key, dict_struct.attributes))
{
if (!first)
writeString(", ", out);
first = false;
if (!key_or_attribute.expression.empty())
{
writeParenthesisedString(key_or_attribute.expression, out);
writeString(" AS ", out);
}
writeProbablyBackQuotedString(key_or_attribute.name, out);
}
writeString(" FROM ", out);
if (!db.empty())
{
writeProbablyBackQuotedString(db, out);
writeChar('.', out);
}
writeProbablyBackQuotedString(table, out);
writeString(" WHERE ", out);
if (!where.empty())
{
writeString(where, out);
writeString(" AND ", out);
}
writeString(key_tuple_definition, out);
writeString(" IN (", out);
first = true;
for (const auto row : requested_rows)
{
if (!first)
writeString(", ", out);
first = false;
composeKeyTuple(key_columns, row, out);
}
writeString(");", out);
}
return query;
}
std::string composeKeyTupleDefinition() const
{
if (!dict_struct.key)
throw Exception{"Composite key required for method", ErrorCodes::UNSUPPORTED_METHOD};
std::string result{"("};
auto first = true;
for (const auto & key : *dict_struct.key)
{
if (!first)
result += ", ";
first = false;
result += key.name;
}
result += ")";
return result;
}
void composeKeyTuple(const ConstColumnPlainPtrs & key_columns, const std::size_t row, WriteBuffer & out) const
{
writeString("(", out);
const auto keys_size = key_columns.size();
auto first = true;
for (const auto i : ext::range(0, keys_size))
{
if (!first)
writeString(", ", out);
first = false;
const auto & value = (*key_columns[i])[row];
(*dict_struct.key)[i].type->serializeTextQuoted(value, out);
}
writeString(")", out);
}
BlockInputStreamPtr createStreamForSelectiveLoad(const std::string query)
{
if (is_local)
return executeQuery(query, context, true).in;
return new RemoteBlockInputStream{pool.get(), query, nullptr};
}
const DictionaryStructure dict_struct; const DictionaryStructure dict_struct;
const std::string host; const std::string host;
const UInt16 port; const UInt16 port;
@ -236,6 +377,7 @@ private:
const bool is_local; const bool is_local;
std::unique_ptr<ConnectionPool> pool; std::unique_ptr<ConnectionPool> pool;
const std::string load_all_query; const std::string load_all_query;
const std::string key_tuple_definition;
}; };
} }

View File

@ -0,0 +1,941 @@
#pragma once
#include <DB/Dictionaries/IDictionary.h>
#include <DB/Dictionaries/IDictionarySource.h>
#include <DB/Dictionaries/DictionaryStructure.h>
#include <DB/Common/Arena.h>
#include <DB/Common/ArenaWithFreeLists.h>
#include <DB/Common/SmallObjectPool.h>
#include <DB/Common/HashTable/HashMap.h>
#include <DB/Columns/ColumnString.h>
#include <DB/Core/StringRef.h>
#include <ext/enumerate.hpp>
#include <ext/scope_guard.hpp>
#include <ext/bit_cast.hpp>
#include <ext/range.hpp>
#include <ext/map.hpp>
#include <Poco/RWLock.h>
#include <cmath>
#include <atomic>
#include <chrono>
#include <vector>
#include <map>
#include <tuple>
namespace DB
{
class ComplexKeyCacheDictionary final : public IDictionaryBase
{
public:
ComplexKeyCacheDictionary(const std::string & name, const DictionaryStructure & dict_struct,
DictionarySourcePtr source_ptr, const DictionaryLifetime dict_lifetime,
const std::size_t size)
: name{name}, dict_struct(dict_struct), source_ptr{std::move(source_ptr)}, dict_lifetime(dict_lifetime),
size{round_up_to_power_of_two(size)}
{
if (!this->source_ptr->supportsSelectiveLoad())
throw Exception{
name + ": source cannot be used with ComplexKeyCacheDictionary",
ErrorCodes::UNSUPPORTED_METHOD
};
createAttributes();
}
ComplexKeyCacheDictionary(const ComplexKeyCacheDictionary & other)
: ComplexKeyCacheDictionary{other.name, other.dict_struct, other.source_ptr->clone(), other.dict_lifetime, other.size}
{}
std::string getKeyDescription() const { return key_description; };
std::exception_ptr getCreationException() const override { return {}; }
std::string getName() const override { return name; }
std::string getTypeName() const override { return "ComplexKeyCache"; }
std::size_t getBytesAllocated() const override
{
return bytes_allocated + (key_size_is_fixed ? fixed_size_keys_pool->size() : keys_pool->size()) +
(string_arena ? string_arena->size() : 0);
}
std::size_t getQueryCount() const override { return query_count.load(std::memory_order_relaxed); }
double getHitRate() const override
{
return static_cast<double>(hit_count.load(std::memory_order_acquire)) /
query_count.load(std::memory_order_relaxed);
}
std::size_t getElementCount() const override { return element_count.load(std::memory_order_relaxed); }
double getLoadFactor() const override
{
return static_cast<double>(element_count.load(std::memory_order_relaxed)) / size;
}
bool isCached() const override { return true; }
DictionaryPtr clone() const override { return std::make_unique<ComplexKeyCacheDictionary>(*this); }
const IDictionarySource * getSource() const override { return source_ptr.get(); }
const DictionaryLifetime & getLifetime() const override { return dict_lifetime; }
const DictionaryStructure & getStructure() const override { return dict_struct; }
std::chrono::time_point<std::chrono::system_clock> getCreationTime() const override
{
return creation_time;
}
bool isInjective(const std::string & attribute_name) const override
{
return dict_struct.attributes[&getAttribute(attribute_name) - attributes.data()].injective;
}
#define DECLARE(TYPE)\
void get##TYPE(\
const std::string & attribute_name, const ConstColumnPlainPtrs & key_columns, const DataTypes & key_types,\
PODArray<TYPE> & out) const\
{\
dict_struct.validateKeyTypes(key_types);\
\
auto & attribute = getAttribute(attribute_name);\
if (attribute.type != AttributeUnderlyingType::TYPE)\
throw Exception{\
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),\
ErrorCodes::TYPE_MISMATCH\
};\
\
const auto null_value = std::get<TYPE>(attribute.null_values);\
\
getItems<TYPE>(attribute, key_columns, out, [&] (const std::size_t) { return null_value; });\
}
DECLARE(UInt8)
DECLARE(UInt16)
DECLARE(UInt32)
DECLARE(UInt64)
DECLARE(Int8)
DECLARE(Int16)
DECLARE(Int32)
DECLARE(Int64)
DECLARE(Float32)
DECLARE(Float64)
#undef DECLARE
void getString(
const std::string & attribute_name, const ConstColumnPlainPtrs & key_columns, const DataTypes & key_types,
ColumnString * out) const
{
dict_struct.validateKeyTypes(key_types);
auto & attribute = getAttribute(attribute_name);
if (attribute.type != AttributeUnderlyingType::String)
throw Exception{
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH
};
const auto null_value = StringRef{std::get<String>(attribute.null_values)};
getItems(attribute, key_columns, out, [&] (const std::size_t) { return null_value; });
}
#define DECLARE(TYPE)\
void get##TYPE(\
const std::string & attribute_name, const ConstColumnPlainPtrs & key_columns, const DataTypes & key_types,\
const PODArray<TYPE> & def, PODArray<TYPE> & out) const\
{\
dict_struct.validateKeyTypes(key_types);\
\
auto & attribute = getAttribute(attribute_name);\
if (attribute.type != AttributeUnderlyingType::TYPE)\
throw Exception{\
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),\
ErrorCodes::TYPE_MISMATCH\
};\
\
getItems<TYPE>(attribute, key_columns, out, [&] (const std::size_t row) { return def[row]; });\
}
DECLARE(UInt8)
DECLARE(UInt16)
DECLARE(UInt32)
DECLARE(UInt64)
DECLARE(Int8)
DECLARE(Int16)
DECLARE(Int32)
DECLARE(Int64)
DECLARE(Float32)
DECLARE(Float64)
#undef DECLARE
void getString(
const std::string & attribute_name, const ConstColumnPlainPtrs & key_columns, const DataTypes & key_types,
const ColumnString * const def, ColumnString * const out) const
{
dict_struct.validateKeyTypes(key_types);
auto & attribute = getAttribute(attribute_name);
if (attribute.type != AttributeUnderlyingType::String)
throw Exception{
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH
};
getItems(attribute, key_columns, out, [&] (const std::size_t row) { return def->getDataAt(row); });
}
#define DECLARE(TYPE)\
void get##TYPE(\
const std::string & attribute_name, const ConstColumnPlainPtrs & key_columns, const DataTypes & key_types,\
const TYPE def, PODArray<TYPE> & out) const\
{\
dict_struct.validateKeyTypes(key_types);\
\
auto & attribute = getAttribute(attribute_name);\
if (attribute.type != AttributeUnderlyingType::TYPE)\
throw Exception{\
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),\
ErrorCodes::TYPE_MISMATCH\
};\
\
getItems<TYPE>(attribute, key_columns, out, [&] (const std::size_t) { return def; });\
}
DECLARE(UInt8)
DECLARE(UInt16)
DECLARE(UInt32)
DECLARE(UInt64)
DECLARE(Int8)
DECLARE(Int16)
DECLARE(Int32)
DECLARE(Int64)
DECLARE(Float32)
DECLARE(Float64)
#undef DECLARE
void getString(
const std::string & attribute_name, const ConstColumnPlainPtrs & key_columns, const DataTypes & key_types,
const String & def, ColumnString * const out) const
{
dict_struct.validateKeyTypes(key_types);
auto & attribute = getAttribute(attribute_name);
if (attribute.type != AttributeUnderlyingType::String)
throw Exception{
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH
};
getItems(attribute, key_columns, out, [&] (const std::size_t) { return StringRef{def}; });
}
void has(const ConstColumnPlainPtrs & key_columns, const DataTypes & key_types, PODArray<UInt8> & out) const
{
dict_struct.validateKeyTypes(key_types);
/// Mapping: <key> -> { all indices `i` of `key_columns` such that `key_columns[i]` = <key> }
MapType<std::vector<std::size_t>> outdated_keys;
const auto rows = key_columns.front()->size();
const auto keys_size = dict_struct.key->size();
StringRefs keys(keys_size);
Arena temporary_keys_pool;
PODArray<StringRef> keys_array(rows);
{
const Poco::ScopedReadRWLock read_lock{rw_lock};
const auto now = std::chrono::system_clock::now();
/// fetch up-to-date values, decide which ones require update
for (const auto row : ext::range(0, rows))
{
const auto key = placeKeysInPool(row, key_columns, keys, temporary_keys_pool);
keys_array[row] = key;
const auto hash = StringRefHash{}(key);
const auto cell_idx = hash & (size - 1);
const auto & cell = cells[cell_idx];
/** cell should be updated if either:
* 1. keys (or hash) do not match,
* 2. cell has expired,
* 3. explicit defaults were specified and cell was set default. */
if (cell.hash != hash || cell.key != key || cell.expiresAt() < now)
outdated_keys[key].push_back(row);
else
out[row] = !cell.isDefault();
}
}
query_count.fetch_add(rows, std::memory_order_relaxed);
hit_count.fetch_add(rows - outdated_keys.size(), std::memory_order_release);
if (outdated_keys.empty())
return;
std::vector<std::size_t> required_rows(outdated_keys.size());
std::transform(std::begin(outdated_keys), std::end(outdated_keys), std::begin(required_rows),
[] (auto & pair) { return pair.second.front(); });
/// request new values
update(key_columns, keys_array, required_rows, [&] (const auto key, const auto) {
for (const auto out_idx : outdated_keys[key])
out[out_idx] = true;
}, [&] (const auto key, const auto) {
for (const auto out_idx : outdated_keys[key])
out[out_idx] = false;
});
}
private:
template <typename Value> using MapType = HashMapWithSavedHash<StringRef, Value, StringRefHash>;
template <typename Value> using ContainerType = Value[];
template <typename Value> using ContainerPtrType = std::unique_ptr<ContainerType<Value>>;
struct cell_metadata_t final
{
using time_point_t = std::chrono::system_clock::time_point;
using time_point_rep_t = time_point_t::rep;
using time_point_urep_t = std::make_unsigned_t<time_point_rep_t>;
static constexpr std::uint64_t EXPIRES_AT_MASK = std::numeric_limits<time_point_rep_t>::max();
static constexpr std::uint64_t IS_DEFAULT_MASK = ~EXPIRES_AT_MASK;
StringRef key;
decltype(StringRefHash{}(key)) hash;
/// Stores both expiration time and `is_default` flag in the most significant bit
time_point_urep_t data;
/// Sets expiration time, resets `is_default` flag to false
time_point_t expiresAt() const { return ext::safe_bit_cast<time_point_t>(data & EXPIRES_AT_MASK); }
void setExpiresAt(const time_point_t & t) { data = ext::safe_bit_cast<time_point_urep_t>(t); }
bool isDefault() const { return (data & IS_DEFAULT_MASK) == IS_DEFAULT_MASK; }
void setDefault() { data |= IS_DEFAULT_MASK; }
};
struct attribute_t final
{
AttributeUnderlyingType type;
std::tuple<
UInt8, UInt16, UInt32, UInt64,
Int8, Int16, Int32, Int64,
Float32, Float64,
String> null_values;
std::tuple<
ContainerPtrType<UInt8>, ContainerPtrType<UInt16>, ContainerPtrType<UInt32>, ContainerPtrType<UInt64>,
ContainerPtrType<Int8>, ContainerPtrType<Int16>, ContainerPtrType<Int32>, ContainerPtrType<Int64>,
ContainerPtrType<Float32>, ContainerPtrType<Float64>,
ContainerPtrType<StringRef>> arrays;
};
void createAttributes()
{
const auto size = dict_struct.attributes.size();
attributes.reserve(size);
bytes_allocated += size * sizeof(cell_metadata_t);
bytes_allocated += size * sizeof(attributes.front());
for (const auto & attribute : dict_struct.attributes)
{
attribute_index_by_name.emplace(attribute.name, attributes.size());
attributes.push_back(createAttributeWithType(attribute.underlying_type, attribute.null_value));
if (attribute.hierarchical)
throw Exception{
name + ": hierarchical attributes not supported for dictionary of type " + getTypeName(),
ErrorCodes::TYPE_MISMATCH
};
}
}
attribute_t createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value)
{
attribute_t attr{type};
switch (type)
{
case AttributeUnderlyingType::UInt8:
std::get<UInt8>(attr.null_values) = null_value.get<UInt64>();
std::get<ContainerPtrType<UInt8>>(attr.arrays) = std::make_unique<ContainerType<UInt8>>(size);
bytes_allocated += size * sizeof(UInt8);
break;
case AttributeUnderlyingType::UInt16:
std::get<UInt16>(attr.null_values) = null_value.get<UInt64>();
std::get<ContainerPtrType<UInt16>>(attr.arrays) = std::make_unique<ContainerType<UInt16>>(size);
bytes_allocated += size * sizeof(UInt16);
break;
case AttributeUnderlyingType::UInt32:
std::get<UInt32>(attr.null_values) = null_value.get<UInt64>();
std::get<ContainerPtrType<UInt32>>(attr.arrays) = std::make_unique<ContainerType<UInt32>>(size);
bytes_allocated += size * sizeof(UInt32);
break;
case AttributeUnderlyingType::UInt64:
std::get<UInt64>(attr.null_values) = null_value.get<UInt64>();
std::get<ContainerPtrType<UInt64>>(attr.arrays) = std::make_unique<ContainerType<UInt64>>(size);
bytes_allocated += size * sizeof(UInt64);
break;
case AttributeUnderlyingType::Int8:
std::get<Int8>(attr.null_values) = null_value.get<Int64>();
std::get<ContainerPtrType<Int8>>(attr.arrays) = std::make_unique<ContainerType<Int8>>(size);
bytes_allocated += size * sizeof(Int8);
break;
case AttributeUnderlyingType::Int16:
std::get<Int16>(attr.null_values) = null_value.get<Int64>();
std::get<ContainerPtrType<Int16>>(attr.arrays) = std::make_unique<ContainerType<Int16>>(size);
bytes_allocated += size * sizeof(Int16);
break;
case AttributeUnderlyingType::Int32:
std::get<Int32>(attr.null_values) = null_value.get<Int64>();
std::get<ContainerPtrType<Int32>>(attr.arrays) = std::make_unique<ContainerType<Int32>>(size);
bytes_allocated += size * sizeof(Int32);
break;
case AttributeUnderlyingType::Int64:
std::get<Int64>(attr.null_values) = null_value.get<Int64>();
std::get<ContainerPtrType<Int64>>(attr.arrays) = std::make_unique<ContainerType<Int64>>(size);
bytes_allocated += size * sizeof(Int64);
break;
case AttributeUnderlyingType::Float32:
std::get<Float32>(attr.null_values) = null_value.get<Float64>();
std::get<ContainerPtrType<Float32>>(attr.arrays) = std::make_unique<ContainerType<Float32>>(size);
bytes_allocated += size * sizeof(Float32);
break;
case AttributeUnderlyingType::Float64:
std::get<Float64>(attr.null_values) = null_value.get<Float64>();
std::get<ContainerPtrType<Float64>>(attr.arrays) = std::make_unique<ContainerType<Float64>>(size);
bytes_allocated += size * sizeof(Float64);
break;
case AttributeUnderlyingType::String:
std::get<String>(attr.null_values) = null_value.get<String>();
std::get<ContainerPtrType<StringRef>>(attr.arrays) = std::make_unique<ContainerType<StringRef>>(size);
bytes_allocated += size * sizeof(StringRef);
if (!string_arena)
string_arena = std::make_unique<ArenaWithFreeLists>();
break;
}
return attr;
}
template <typename T, typename DefaultGetter>
void getItems(
attribute_t & attribute, const ConstColumnPlainPtrs & key_columns, PODArray<T> & out,
DefaultGetter && get_default) const
{
/// Mapping: <key> -> { all indices `i` of `key_columns` such that `key_columns[i]` = <key> }
MapType<std::vector<std::size_t>> outdated_keys;
auto & attribute_array = std::get<ContainerPtrType<T>>(attribute.arrays);
const auto rows = key_columns.front()->size();
const auto keys_size = dict_struct.key->size();
StringRefs keys(keys_size);
Arena temporary_keys_pool;
PODArray<StringRef> keys_array(rows);
{
const Poco::ScopedReadRWLock read_lock{rw_lock};
const auto now = std::chrono::system_clock::now();
/// fetch up-to-date values, decide which ones require update
for (const auto row : ext::range(0, rows))
{
const auto key = placeKeysInPool(row, key_columns, keys, temporary_keys_pool);
keys_array[row] = key;
const auto hash = StringRefHash{}(key);
const auto cell_idx = hash & (size - 1);
const auto & cell = cells[cell_idx];
/** cell should be updated if either:
* 1. keys (or hash) do not match,
* 2. cell has expired,
* 3. explicit defaults were specified and cell was set default. */
if (cell.hash != hash || cell.key != key || cell.expiresAt() < now)
outdated_keys[key].push_back(row);
else
out[row] = cell.isDefault() ? get_default(row) : attribute_array[cell_idx];
}
}
query_count.fetch_add(rows, std::memory_order_relaxed);
hit_count.fetch_add(rows - outdated_keys.size(), std::memory_order_release);
if (outdated_keys.empty())
return;
std::vector<std::size_t> required_rows(outdated_keys.size());
std::transform(std::begin(outdated_keys), std::end(outdated_keys), std::begin(required_rows),
[] (auto & pair) { return pair.second.front(); });
/// request new values
update(key_columns, keys_array, required_rows, [&] (const auto key, const auto cell_idx) {
for (const auto row : outdated_keys[key])
out[row] = attribute_array[cell_idx];
}, [&] (const auto key, const auto cell_idx) {
for (const auto row : outdated_keys[key])
out[row] = get_default(row);
});
}
template <typename DefaultGetter>
void getItems(
attribute_t & attribute, const ConstColumnPlainPtrs & key_columns, ColumnString * out,
DefaultGetter && get_default) const
{
const auto rows = key_columns.front()->size();
/// save on some allocations
out->getOffsets().reserve(rows);
const auto keys_size = dict_struct.key->size();
StringRefs keys(keys_size);
Arena temporary_keys_pool;
auto & attribute_array = std::get<ContainerPtrType<StringRef>>(attribute.arrays);
auto found_outdated_values = false;
/// perform optimistic version, fallback to pessimistic if failed
{
const Poco::ScopedReadRWLock read_lock{rw_lock};
const auto now = std::chrono::system_clock::now();
/// fetch up-to-date values, discard on fail
for (const auto row : ext::range(0, rows))
{
const auto key = placeKeysInPool(row, key_columns, keys, temporary_keys_pool);
SCOPE_EXIT(temporary_keys_pool.rollback(key.size));
const auto hash = StringRefHash{}(key);
const auto cell_idx = hash & (size - 1);
const auto & cell = cells[cell_idx];
if (cell.hash != hash || cell.key != key || cell.expiresAt() < now)
{
found_outdated_values = true;
break;
}
else
{
const auto string_ref = cell.isDefault() ? get_default(row) : attribute_array[cell_idx];
out->insertData(string_ref.data, string_ref.size);
}
}
}
/// optimistic code completed successfully
if (!found_outdated_values)
{
query_count.fetch_add(rows, std::memory_order_relaxed);
hit_count.fetch_add(rows, std::memory_order_release);
return;
}
/// now onto the pessimistic one, discard possible partial results from the optimistic path
out->getChars().resize_assume_reserved(0);
out->getOffsets().resize_assume_reserved(0);
/// Mapping: <key> -> { all indices `i` of `key_columns` such that `key_columns[i]` = <key> }
MapType<std::vector<std::size_t>> outdated_keys;
/// we are going to store every string separately
MapType<String> map;
PODArray<StringRef> keys_array(rows);
std::size_t total_length = 0;
{
const Poco::ScopedReadRWLock read_lock{rw_lock};
const auto now = std::chrono::system_clock::now();
for (const auto row : ext::range(0, rows))
{
const auto key = placeKeysInPool(row, key_columns, keys, temporary_keys_pool);
keys_array[row] = key;
const auto hash = StringRefHash{}(key);
const auto cell_idx = hash & (size - 1);
const auto & cell = cells[cell_idx];
if (cell.hash != hash || cell.key != key || cell.expiresAt() < now)
outdated_keys[key].push_back(row);
else
{
const auto string_ref = cell.isDefault() ? get_default(row) : attribute_array[cell_idx];
if (!cell.isDefault())
map[key] = String{string_ref};
total_length += string_ref.size + 1;
}
}
}
query_count.fetch_add(rows, std::memory_order_relaxed);
hit_count.fetch_add(rows - outdated_keys.size(), std::memory_order_release);
/// request new values
if (!outdated_keys.empty())
{
std::vector<std::size_t> required_rows(outdated_keys.size());
std::transform(std::begin(outdated_keys), std::end(outdated_keys), std::begin(required_rows),
[] (auto & pair) { return pair.second.front(); });
update(key_columns, keys_array, required_rows, [&] (const auto key, const auto cell_idx) {
const auto attribute_value = attribute_array[cell_idx];
map[key] = String{attribute_value};
total_length += (attribute_value.size + 1) * outdated_keys[key].size();
}, [&] (const auto key, const auto cell_idx) {
for (const auto row : outdated_keys[key])
total_length += get_default(row).size + 1;
});
}
out->getChars().reserve(total_length);
for (const auto row : ext::range(0, ext::size(keys_array)))
{
const auto key = keys_array[row];
const auto it = map.find(key);
const auto string_ref = it != std::end(map) ? StringRef{it->second} : get_default(row);
out->insertData(string_ref.data, string_ref.size);
}
}
template <typename PresentKeyHandler, typename AbsentKeyHandler>
void update(
const ConstColumnPlainPtrs & in_key_columns, const PODArray<StringRef> & in_keys,
const std::vector<std::size_t> & in_requested_rows, PresentKeyHandler && on_cell_updated,
AbsentKeyHandler && on_key_not_found) const
{
auto stream = source_ptr->loadKeys(in_key_columns, in_requested_rows);
stream->readPrefix();
MapType<bool> remaining_keys{in_requested_rows.size()};
for (const auto row : in_requested_rows)
remaining_keys.insert({ in_keys[row], false });
std::uniform_int_distribution<std::uint64_t> distribution{
dict_lifetime.min_sec,
dict_lifetime.max_sec
};
const Poco::ScopedWriteRWLock write_lock{rw_lock};
const auto keys_size = dict_struct.key->size();
StringRefs keys(keys_size);
const auto attributes_size = attributes.size();
while (const auto block = stream->read())
{
/// cache column pointers
const auto key_columns = ext::map<ConstColumnPlainPtrs>(ext::range(0, keys_size),
[&] (const std::size_t attribute_idx) {
return block.getByPosition(attribute_idx).column.get();
});
const auto attribute_columns = ext::map<ConstColumnPlainPtrs>(ext::range(0, attributes_size),
[&] (const std::size_t attribute_idx) {
return block.getByPosition(keys_size + attribute_idx).column.get();
});
const auto rows = block.rowsInFirstColumn();
for (const auto row : ext::range(0, rows))
{
auto key = allocKey(row, key_columns, keys);
const auto hash = StringRefHash{}(key);
const auto cell_idx = hash & (size - 1);
auto & cell = cells[cell_idx];
for (const auto attribute_idx : ext::range(0, attributes.size()))
{
const auto & attribute_column = *attribute_columns[attribute_idx];
auto & attribute = attributes[attribute_idx];
setAttributeValue(attribute, cell_idx, attribute_column[row]);
}
/// if cell id is zero and zero does not map to this cell, then the cell is unused
if (cell.key == StringRef{} && cell_idx != zero_cell_idx)
element_count.fetch_add(1, std::memory_order_relaxed);
/// handle memory allocated for old key
if (key == cell.key)
{
freeKey(key);
key = cell.key;
}
else
{
/// new key is different from the old one
if (cell.key.data)
freeKey(cell.key);
cell.key = key;
}
cell.hash = hash;
if (dict_lifetime.min_sec != 0 && dict_lifetime.max_sec != 0)
cell.setExpiresAt(std::chrono::system_clock::now() + std::chrono::seconds{distribution(rnd_engine)});
else
cell.setExpiresAt(std::chrono::time_point<std::chrono::system_clock>::max());
/// inform caller
on_cell_updated(key, cell_idx);
/// mark corresponding id as found
remaining_keys[key] = true;
}
}
stream->readSuffix();
/// Check which ids have not been found and require setting null_value
for (const auto key_found_pair : remaining_keys)
{
if (key_found_pair.second)
continue;
auto key = key_found_pair.first;
const auto hash = StringRefHash{}(key);
const auto cell_idx = hash & (size - 1);
auto & cell = cells[cell_idx];
/// Set null_value for each attribute
for (auto & attribute : attributes)
setDefaultAttributeValue(attribute, cell_idx);
/// Check if cell had not been occupied before and increment element counter if it hadn't
if (cell.key == StringRef{} && cell_idx != zero_cell_idx)
element_count.fetch_add(1, std::memory_order_relaxed);
if (key == cell.key)
key = cell.key;
else
{
if (cell.key.data)
freeKey(cell.key);
/// copy key from temporary pool
key = copyKey(key);
cell.key = key;
}
cell.hash = hash;
if (dict_lifetime.min_sec != 0 && dict_lifetime.max_sec != 0)
cell.setExpiresAt(std::chrono::system_clock::now() + std::chrono::seconds{distribution(rnd_engine)});
else
cell.setExpiresAt(std::chrono::time_point<std::chrono::system_clock>::max());
cell.setDefault();
/// inform caller that the cell has not been found
on_key_not_found(key, cell_idx);
}
}
std::uint64_t getCellIdx(const StringRef key) const
{
const auto hash = StringRefHash{}(key);
const auto idx = hash & (size - 1);
return idx;
}
void setDefaultAttributeValue(attribute_t & attribute, const std::size_t idx) const
{
switch (attribute.type)
{
case AttributeUnderlyingType::UInt8: std::get<ContainerPtrType<UInt8>>(attribute.arrays)[idx] = std::get<UInt8>(attribute.null_values); break;
case AttributeUnderlyingType::UInt16: std::get<ContainerPtrType<UInt16>>(attribute.arrays)[idx] = std::get<UInt16>(attribute.null_values); break;
case AttributeUnderlyingType::UInt32: std::get<ContainerPtrType<UInt32>>(attribute.arrays)[idx] = std::get<UInt32>(attribute.null_values); break;
case AttributeUnderlyingType::UInt64: std::get<ContainerPtrType<UInt64>>(attribute.arrays)[idx] = std::get<UInt64>(attribute.null_values); break;
case AttributeUnderlyingType::Int8: std::get<ContainerPtrType<Int8>>(attribute.arrays)[idx] = std::get<Int8>(attribute.null_values); break;
case AttributeUnderlyingType::Int16: std::get<ContainerPtrType<Int16>>(attribute.arrays)[idx] = std::get<Int16>(attribute.null_values); break;
case AttributeUnderlyingType::Int32: std::get<ContainerPtrType<Int32>>(attribute.arrays)[idx] = std::get<Int32>(attribute.null_values); break;
case AttributeUnderlyingType::Int64: std::get<ContainerPtrType<Int64>>(attribute.arrays)[idx] = std::get<Int64>(attribute.null_values); break;
case AttributeUnderlyingType::Float32: std::get<ContainerPtrType<Float32>>(attribute.arrays)[idx] = std::get<Float32>(attribute.null_values); break;
case AttributeUnderlyingType::Float64: std::get<ContainerPtrType<Float64>>(attribute.arrays)[idx] = std::get<Float64>(attribute.null_values); break;
case AttributeUnderlyingType::String:
{
const auto & null_value_ref = std::get<String>(attribute.null_values);
auto & string_ref = std::get<ContainerPtrType<StringRef>>(attribute.arrays)[idx];
if (string_ref.data != null_value_ref.data())
{
if (string_ref.data)
string_arena->free(string_ref.data, string_ref.size);
string_ref = StringRef{null_value_ref};
}
break;
}
}
}
void setAttributeValue(attribute_t & attribute, const std::size_t idx, const Field & value) const
{
switch (attribute.type)
{
case AttributeUnderlyingType::UInt8: std::get<ContainerPtrType<UInt8>>(attribute.arrays)[idx] = value.get<UInt64>(); break;
case AttributeUnderlyingType::UInt16: std::get<ContainerPtrType<UInt16>>(attribute.arrays)[idx] = value.get<UInt64>(); break;
case AttributeUnderlyingType::UInt32: std::get<ContainerPtrType<UInt32>>(attribute.arrays)[idx] = value.get<UInt64>(); break;
case AttributeUnderlyingType::UInt64: std::get<ContainerPtrType<UInt64>>(attribute.arrays)[idx] = value.get<UInt64>(); break;
case AttributeUnderlyingType::Int8: std::get<ContainerPtrType<Int8>>(attribute.arrays)[idx] = value.get<Int64>(); break;
case AttributeUnderlyingType::Int16: std::get<ContainerPtrType<Int16>>(attribute.arrays)[idx] = value.get<Int64>(); break;
case AttributeUnderlyingType::Int32: std::get<ContainerPtrType<Int32>>(attribute.arrays)[idx] = value.get<Int64>(); break;
case AttributeUnderlyingType::Int64: std::get<ContainerPtrType<Int64>>(attribute.arrays)[idx] = value.get<Int64>(); break;
case AttributeUnderlyingType::Float32: std::get<ContainerPtrType<Float32>>(attribute.arrays)[idx] = value.get<Float64>(); break;
case AttributeUnderlyingType::Float64: std::get<ContainerPtrType<Float64>>(attribute.arrays)[idx] = value.get<Float64>(); break;
case AttributeUnderlyingType::String:
{
const auto & string = value.get<String>();
auto & string_ref = std::get<ContainerPtrType<StringRef>>(attribute.arrays)[idx];
const auto & null_value_ref = std::get<String>(attribute.null_values);
/// free memory unless it points to a null_value
if (string_ref.data && string_ref.data != null_value_ref.data())
string_arena->free(string_ref.data, string_ref.size);
const auto size = string.size();
if (size != 0)
{
auto string_ptr = string_arena->alloc(size + 1);
std::copy(string.data(), string.data() + size + 1, string_ptr);
string_ref = StringRef{string_ptr, size};
}
else
string_ref = {};
break;
}
}
}
attribute_t & getAttribute(const std::string & attribute_name) const
{
const auto it = attribute_index_by_name.find(attribute_name);
if (it == std::end(attribute_index_by_name))
throw Exception{
name + ": no such attribute '" + attribute_name + "'",
ErrorCodes::BAD_ARGUMENTS
};
return attributes[it->second];
}
StringRef allocKey(const std::size_t row, const ConstColumnPlainPtrs & key_columns, StringRefs & keys) const
{
if (key_size_is_fixed)
return placeKeysInFixedSizePool(row, key_columns);
return placeKeysInPool(row, key_columns, keys, *keys_pool);
}
void freeKey(const StringRef key) const
{
if (key_size_is_fixed)
fixed_size_keys_pool->free(key.data);
else
keys_pool->free(key.data, key.size);
}
static std::size_t round_up_to_power_of_two(std::size_t n)
{
--n;
n |= n >> 1;
n |= n >> 2;
n |= n >> 4;
n |= n >> 8;
n |= n >> 16;
n |= n >> 32;
++n;
return n;
}
static std::uint64_t getSeed()
{
timespec ts;
clock_gettime(CLOCK_MONOTONIC, &ts);
return ts.tv_nsec ^ getpid();
}
template <typename Arena>
static StringRef placeKeysInPool(
const std::size_t row, const ConstColumnPlainPtrs & key_columns, StringRefs & keys, Arena & pool)
{
const auto keys_size = key_columns.size();
size_t sum_keys_size{};
for (const auto i : ext::range(0, keys_size))
{
keys[i] = key_columns[i]->getDataAtWithTerminatingZero(row);
sum_keys_size += keys[i].size;
}
const auto res = pool.alloc(sum_keys_size);
auto place = res;
for (size_t j = 0; j < keys_size; ++j)
{
memcpy(place, keys[j].data, keys[j].size);
place += keys[j].size;
}
return { res, sum_keys_size };
}
StringRef placeKeysInFixedSizePool(
const std::size_t row, const ConstColumnPlainPtrs & key_columns) const
{
const auto res = fixed_size_keys_pool->alloc();
auto place = res;
for (const auto & key_column : key_columns)
{
const auto key = key_column->getDataAt(row);
memcpy(place, key.data, key.size);
place += key.size;
}
return { res, key_size };
}
StringRef copyKey(const StringRef key) const
{
const auto res = key_size_is_fixed ? fixed_size_keys_pool->alloc() : keys_pool->alloc(key.size);
memcpy(res, key.data, key.size);
return { res, key.size };
}
const std::string name;
const DictionaryStructure dict_struct;
const DictionarySourcePtr source_ptr;
const DictionaryLifetime dict_lifetime;
const std::string key_description{dict_struct.getKeyDescription()};
mutable Poco::RWLock rw_lock;
const std::size_t size;
const std::uint64_t zero_cell_idx{getCellIdx(StringRef{})};
std::map<std::string, std::size_t> attribute_index_by_name;
mutable std::vector<attribute_t> attributes;
mutable std::vector<cell_metadata_t> cells{size};
const bool key_size_is_fixed{dict_struct.isKeySizeFixed()};
std::size_t key_size{key_size_is_fixed ? dict_struct.getKeySize() : 0};
std::unique_ptr<ArenaWithFreeLists> keys_pool = key_size_is_fixed ? nullptr :
std::make_unique<ArenaWithFreeLists>();
std::unique_ptr<SmallObjectPool> fixed_size_keys_pool = key_size_is_fixed ?
std::make_unique<SmallObjectPool>(key_size) : nullptr;
std::unique_ptr<ArenaWithFreeLists> string_arena;
mutable std::mt19937_64 rnd_engine{getSeed()};
mutable std::size_t bytes_allocated = 0;
mutable std::atomic<std::size_t> element_count{0};
mutable std::atomic<std::size_t> hit_count{0};
mutable std::atomic<std::size_t> query_count{0};
const std::chrono::time_point<std::chrono::system_clock> creation_time = std::chrono::system_clock::now();
};
}

View File

@ -0,0 +1,565 @@
#pragma once
#include <DB/Dictionaries/IDictionary.h>
#include <DB/Dictionaries/IDictionarySource.h>
#include <DB/Dictionaries/DictionaryStructure.h>
#include <DB/Core/StringRef.h>
#include <DB/Common/HashTable/HashMap.h>
#include <DB/Columns/ColumnString.h>
#include <DB/Common/Arena.h>
#include <ext/range.hpp>
#include <atomic>
#include <memory>
#include <tuple>
namespace DB
{
class ComplexKeyHashedDictionary final : public IDictionaryBase
{
public:
ComplexKeyHashedDictionary(
const std::string & name, const DictionaryStructure & dict_struct, DictionarySourcePtr source_ptr,
const DictionaryLifetime dict_lifetime, bool require_nonempty)
: name{name}, dict_struct(dict_struct), source_ptr{std::move(source_ptr)}, dict_lifetime(dict_lifetime),
require_nonempty(require_nonempty)
{
createAttributes();
try
{
loadData();
calculateBytesAllocated();
}
catch (...)
{
creation_exception = std::current_exception();
}
creation_time = std::chrono::system_clock::now();
}
ComplexKeyHashedDictionary(const ComplexKeyHashedDictionary & other)
: ComplexKeyHashedDictionary{other.name, other.dict_struct, other.source_ptr->clone(), other.dict_lifetime, other.require_nonempty}
{}
std::string getKeyDescription() const { return key_description; };
std::exception_ptr getCreationException() const override { return creation_exception; }
std::string getName() const override { return name; }
std::string getTypeName() const override { return "ComplexKeyHashed"; }
std::size_t getBytesAllocated() const override { return bytes_allocated; }
std::size_t getQueryCount() const override { return query_count.load(std::memory_order_relaxed); }
double getHitRate() const override { return 1.0; }
std::size_t getElementCount() const override { return element_count; }
double getLoadFactor() const override { return static_cast<double>(element_count) / bucket_count; }
bool isCached() const override { return false; }
DictionaryPtr clone() const override { return std::make_unique<ComplexKeyHashedDictionary>(*this); }
const IDictionarySource * getSource() const override { return source_ptr.get(); }
const DictionaryLifetime & getLifetime() const override { return dict_lifetime; }
const DictionaryStructure & getStructure() const override { return dict_struct; }
std::chrono::time_point<std::chrono::system_clock> getCreationTime() const override
{
return creation_time;
}
bool isInjective(const std::string & attribute_name) const override
{
return dict_struct.attributes[&getAttribute(attribute_name) - attributes.data()].injective;
}
#define DECLARE(TYPE)\
void get##TYPE(\
const std::string & attribute_name, const ConstColumnPlainPtrs & key_columns, const DataTypes & key_types,\
PODArray<TYPE> & out) const\
{\
dict_struct.validateKeyTypes(key_types);\
\
const auto & attribute = getAttribute(attribute_name);\
if (attribute.type != AttributeUnderlyingType::TYPE)\
throw Exception{\
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),\
ErrorCodes::TYPE_MISMATCH\
};\
\
const auto null_value = std::get<TYPE>(attribute.null_values);\
\
getItems<TYPE>(attribute, key_columns,\
[&] (const std::size_t row, const auto value) { out[row] = value; },\
[&] (const std::size_t) { return null_value; });\
}
DECLARE(UInt8)
DECLARE(UInt16)
DECLARE(UInt32)
DECLARE(UInt64)
DECLARE(Int8)
DECLARE(Int16)
DECLARE(Int32)
DECLARE(Int64)
DECLARE(Float32)
DECLARE(Float64)
#undef DECLARE
void getString(
const std::string & attribute_name, const ConstColumnPlainPtrs & key_columns, const DataTypes & key_types,
ColumnString * out) const
{
dict_struct.validateKeyTypes(key_types);
const auto & attribute = getAttribute(attribute_name);
if (attribute.type != AttributeUnderlyingType::String)
throw Exception{
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH
};
const auto & null_value = StringRef{std::get<String>(attribute.null_values)};
getItems<StringRef>(attribute, key_columns,
[&] (const std::size_t row, const StringRef value) { out->insertData(value.data, value.size); },
[&] (const std::size_t) { return null_value; });
}
#define DECLARE(TYPE)\
void get##TYPE(\
const std::string & attribute_name, const ConstColumnPlainPtrs & key_columns, const DataTypes & key_types,\
const PODArray<TYPE> & def, PODArray<TYPE> & out) const\
{\
dict_struct.validateKeyTypes(key_types);\
\
const auto & attribute = getAttribute(attribute_name);\
if (attribute.type != AttributeUnderlyingType::TYPE)\
throw Exception{\
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),\
ErrorCodes::TYPE_MISMATCH\
};\
\
getItems<TYPE>(attribute, key_columns,\
[&] (const std::size_t row, const auto value) { out[row] = value; },\
[&] (const std::size_t row) { return def[row]; });\
}
DECLARE(UInt8)
DECLARE(UInt16)
DECLARE(UInt32)
DECLARE(UInt64)
DECLARE(Int8)
DECLARE(Int16)
DECLARE(Int32)
DECLARE(Int64)
DECLARE(Float32)
DECLARE(Float64)
#undef DECLARE
void getString(
const std::string & attribute_name, const ConstColumnPlainPtrs & key_columns, const DataTypes & key_types,
const ColumnString * const def, ColumnString * const out) const
{
dict_struct.validateKeyTypes(key_types);
const auto & attribute = getAttribute(attribute_name);
if (attribute.type != AttributeUnderlyingType::String)
throw Exception{
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH
};
getItems<StringRef>(attribute, key_columns,
[&] (const std::size_t row, const StringRef value) { out->insertData(value.data, value.size); },
[&] (const std::size_t row) { return def->getDataAt(row); });
}
#define DECLARE(TYPE)\
void get##TYPE(\
const std::string & attribute_name, const ConstColumnPlainPtrs & key_columns, const DataTypes & key_types,\
const TYPE def, PODArray<TYPE> & out) const\
{\
dict_struct.validateKeyTypes(key_types);\
\
const auto & attribute = getAttribute(attribute_name);\
if (attribute.type != AttributeUnderlyingType::TYPE)\
throw Exception{\
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),\
ErrorCodes::TYPE_MISMATCH\
};\
\
getItems<TYPE>(attribute, key_columns,\
[&] (const std::size_t row, const auto value) { out[row] = value; },\
[&] (const std::size_t) { return def; });\
}
DECLARE(UInt8)
DECLARE(UInt16)
DECLARE(UInt32)
DECLARE(UInt64)
DECLARE(Int8)
DECLARE(Int16)
DECLARE(Int32)
DECLARE(Int64)
DECLARE(Float32)
DECLARE(Float64)
#undef DECLARE
void getString(
const std::string & attribute_name, const ConstColumnPlainPtrs & key_columns, const DataTypes & key_types,
const String & def, ColumnString * const out) const
{
dict_struct.validateKeyTypes(key_types);
const auto & attribute = getAttribute(attribute_name);
if (attribute.type != AttributeUnderlyingType::String)
throw Exception{
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH
};
getItems<StringRef>(attribute, key_columns,
[&] (const std::size_t row, const StringRef value) { out->insertData(value.data, value.size); },
[&] (const std::size_t) { return StringRef{def}; });
}
void has(const ConstColumnPlainPtrs & key_columns, const DataTypes & key_types, PODArray<UInt8> & out) const
{
dict_struct.validateKeyTypes(key_types);
const auto & attribute = attributes.front();
switch (attribute.type)
{
case AttributeUnderlyingType::UInt8: has<UInt8>(attribute, key_columns, out); break;
case AttributeUnderlyingType::UInt16: has<UInt16>(attribute, key_columns, out); break;
case AttributeUnderlyingType::UInt32: has<UInt32>(attribute, key_columns, out); break;
case AttributeUnderlyingType::UInt64: has<UInt64>(attribute, key_columns, out); break;
case AttributeUnderlyingType::Int8: has<Int8>(attribute, key_columns, out); break;
case AttributeUnderlyingType::Int16: has<Int16>(attribute, key_columns, out); break;
case AttributeUnderlyingType::Int32: has<Int32>(attribute, key_columns, out); break;
case AttributeUnderlyingType::Int64: has<Int64>(attribute, key_columns, out); break;
case AttributeUnderlyingType::Float32: has<Float32>(attribute, key_columns, out); break;
case AttributeUnderlyingType::Float64: has<Float64>(attribute, key_columns, out); break;
case AttributeUnderlyingType::String: has<StringRef>(attribute, key_columns, out); break;
}
}
private:
template <typename Value> using ContainerType = HashMapWithSavedHash<StringRef, Value, StringRefHash>;
template <typename Value> using ContainerPtrType = std::unique_ptr<ContainerType<Value>>;
struct attribute_t final
{
AttributeUnderlyingType type;
std::tuple<
UInt8, UInt16, UInt32, UInt64,
Int8, Int16, Int32, Int64,
Float32, Float64,
String> null_values;
std::tuple<
ContainerPtrType<UInt8>, ContainerPtrType<UInt16>, ContainerPtrType<UInt32>, ContainerPtrType<UInt64>,
ContainerPtrType<Int8>, ContainerPtrType<Int16>, ContainerPtrType<Int32>, ContainerPtrType<Int64>,
ContainerPtrType<Float32>, ContainerPtrType<Float64>,
ContainerPtrType<StringRef>> maps;
std::unique_ptr<Arena> string_arena;
};
void createAttributes()
{
const auto size = dict_struct.attributes.size();
attributes.reserve(size);
for (const auto & attribute : dict_struct.attributes)
{
attribute_index_by_name.emplace(attribute.name, attributes.size());
attributes.push_back(createAttributeWithType(attribute.underlying_type, attribute.null_value));
if (attribute.hierarchical)
throw Exception{
name + ": hierarchical attributes not supported for dictionary of type " + getTypeName(),
ErrorCodes::TYPE_MISMATCH
};
}
}
void loadData()
{
auto stream = source_ptr->loadAll();
stream->readPrefix();
/// created upfront to avoid excess allocations
const auto keys_size = dict_struct.key->size();
StringRefs keys(keys_size);
const auto attributes_size = attributes.size();
while (const auto block = stream->read())
{
const auto rows = block.rowsInFirstColumn();
element_count += rows;
const auto key_column_ptrs = ext::map<ConstColumnPlainPtrs>(ext::range(0, keys_size),
[&] (const std::size_t attribute_idx) {
return block.getByPosition(attribute_idx).column.get();
});
const auto attribute_column_ptrs = ext::map<ConstColumnPlainPtrs>(ext::range(0, attributes_size),
[&] (const std::size_t attribute_idx) {
return block.getByPosition(keys_size + attribute_idx).column.get();
});
for (const auto row_idx : ext::range(0, rows))
{
/// calculate key once per row
const auto key = placeKeysInPool(row_idx, key_column_ptrs, keys, keys_pool);
auto should_rollback = false;
for (const auto attribute_idx : ext::range(0, attributes_size))
{
const auto & attribute_column = *attribute_column_ptrs[attribute_idx];
auto & attribute = attributes[attribute_idx];
const auto inserted = setAttributeValue(attribute, key, attribute_column[row_idx]);
if (!inserted)
should_rollback = true;
}
/// @note on multiple equal keys the mapped value for the first one is stored
if (should_rollback)
keys_pool.rollback(key.size);
}
}
stream->readSuffix();
if (require_nonempty && 0 == element_count)
throw Exception{
name + ": dictionary source is empty and 'require_nonempty' property is set.",
ErrorCodes::DICTIONARY_IS_EMPTY
};
}
template <typename T>
void addAttributeSize(const attribute_t & attribute)
{
const auto & map_ref = std::get<ContainerPtrType<T>>(attribute.maps);
bytes_allocated += sizeof(ContainerType<T>) + map_ref->getBufferSizeInBytes();
bucket_count = map_ref->getBufferSizeInCells();
}
void calculateBytesAllocated()
{
bytes_allocated += attributes.size() * sizeof(attributes.front());
for (const auto & attribute : attributes)
{
switch (attribute.type)
{
case AttributeUnderlyingType::UInt8: addAttributeSize<UInt8>(attribute); break;
case AttributeUnderlyingType::UInt16: addAttributeSize<UInt16>(attribute); break;
case AttributeUnderlyingType::UInt32: addAttributeSize<UInt32>(attribute); break;
case AttributeUnderlyingType::UInt64: addAttributeSize<UInt64>(attribute); break;
case AttributeUnderlyingType::Int8: addAttributeSize<Int8>(attribute); break;
case AttributeUnderlyingType::Int16: addAttributeSize<Int16>(attribute); break;
case AttributeUnderlyingType::Int32: addAttributeSize<Int32>(attribute); break;
case AttributeUnderlyingType::Int64: addAttributeSize<Int64>(attribute); break;
case AttributeUnderlyingType::Float32: addAttributeSize<Float32>(attribute); break;
case AttributeUnderlyingType::Float64: addAttributeSize<Float64>(attribute); break;
case AttributeUnderlyingType::String:
{
addAttributeSize<StringRef>(attribute);
bytes_allocated += sizeof(Arena) + attribute.string_arena->size();
break;
}
}
}
bytes_allocated += keys_pool.size();
}
template <typename T>
void createAttributeImpl(attribute_t & attribute, const Field & null_value)
{
std::get<T>(attribute.null_values) = null_value.get<typename NearestFieldType<T>::Type>();
std::get<ContainerPtrType<T>>(attribute.maps) = std::make_unique<ContainerType<T>>();
}
attribute_t createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value)
{
attribute_t attr{type};
switch (type)
{
case AttributeUnderlyingType::UInt8: createAttributeImpl<UInt8>(attr, null_value); break;
case AttributeUnderlyingType::UInt16: createAttributeImpl<UInt16>(attr, null_value); break;
case AttributeUnderlyingType::UInt32: createAttributeImpl<UInt32>(attr, null_value); break;
case AttributeUnderlyingType::UInt64: createAttributeImpl<UInt64>(attr, null_value); break;
case AttributeUnderlyingType::Int8: createAttributeImpl<Int8>(attr, null_value); break;
case AttributeUnderlyingType::Int16: createAttributeImpl<Int16>(attr, null_value); break;
case AttributeUnderlyingType::Int32: createAttributeImpl<Int32>(attr, null_value); break;
case AttributeUnderlyingType::Int64: createAttributeImpl<Int64>(attr, null_value); break;
case AttributeUnderlyingType::Float32: createAttributeImpl<Float32>(attr, null_value); break;
case AttributeUnderlyingType::Float64: createAttributeImpl<Float64>(attr, null_value); break;
case AttributeUnderlyingType::String:
{
std::get<String>(attr.null_values) = null_value.get<String>();
std::get<ContainerPtrType<StringRef>>(attr.maps) = std::make_unique<ContainerType<StringRef>>();
attr.string_arena = std::make_unique<Arena>();
break;
}
}
return attr;
}
template <typename T, typename ValueSetter, typename DefaultGetter>
void getItems(
const attribute_t & attribute, const ConstColumnPlainPtrs & key_columns, ValueSetter && set_value,
DefaultGetter && get_default) const
{
const auto & attr = *std::get<ContainerPtrType<T>>(attribute.maps);
const auto keys_size = key_columns.size();
StringRefs keys(keys_size);
Arena temporary_keys_pool;
const auto rows = key_columns.front()->size();
for (const auto i : ext::range(0, rows))
{
/// copy key data to arena so it is contiguous and return StringRef to it
const auto key = placeKeysInPool(i, key_columns, keys, temporary_keys_pool);
const auto it = attr.find(key);
set_value(i, it != attr.end() ? it->second : get_default(i));
/// free memory allocated for the key
temporary_keys_pool.rollback(key.size);
}
query_count.fetch_add(rows, std::memory_order_relaxed);
}
template <typename T>
bool setAttributeValueImpl(attribute_t & attribute, const StringRef key, const T value)
{
auto & map = *std::get<ContainerPtrType<T>>(attribute.maps);
const auto pair = map.insert({ key, value });
return pair.second;
}
bool setAttributeValue(attribute_t & attribute, const StringRef key, const Field & value)
{
switch (attribute.type)
{
case AttributeUnderlyingType::UInt8: return setAttributeValueImpl<UInt8>(attribute, key, value.get<UInt64>());
case AttributeUnderlyingType::UInt16: return setAttributeValueImpl<UInt16>(attribute, key, value.get<UInt64>());
case AttributeUnderlyingType::UInt32: return setAttributeValueImpl<UInt32>(attribute, key, value.get<UInt64>());
case AttributeUnderlyingType::UInt64: return setAttributeValueImpl<UInt64>(attribute, key, value.get<UInt64>());
case AttributeUnderlyingType::Int8: return setAttributeValueImpl<Int8>(attribute, key, value.get<Int64>());
case AttributeUnderlyingType::Int16: return setAttributeValueImpl<Int16>(attribute, key, value.get<Int64>());
case AttributeUnderlyingType::Int32: return setAttributeValueImpl<Int32>(attribute, key, value.get<Int64>());
case AttributeUnderlyingType::Int64: return setAttributeValueImpl<Int64>(attribute, key, value.get<Int64>());
case AttributeUnderlyingType::Float32: return setAttributeValueImpl<Float32>(attribute, key, value.get<Float64>());
case AttributeUnderlyingType::Float64: return setAttributeValueImpl<Float64>(attribute, key, value.get<Float64>());
case AttributeUnderlyingType::String:
{
auto & map = *std::get<ContainerPtrType<StringRef>>(attribute.maps);
const auto & string = value.get<String>();
const auto string_in_arena = attribute.string_arena->insert(string.data(), string.size());
const auto pair = map.insert({ key, StringRef{string_in_arena, string.size()} });
return pair.second;
}
}
return {};
}
const attribute_t & getAttribute(const std::string & attribute_name) const
{
const auto it = attribute_index_by_name.find(attribute_name);
if (it == std::end(attribute_index_by_name))
throw Exception{
name + ": no such attribute '" + attribute_name + "'",
ErrorCodes::BAD_ARGUMENTS
};
return attributes[it->second];
}
static StringRef placeKeysInPool(
const std::size_t row, const ConstColumnPlainPtrs & key_columns, StringRefs & keys, Arena & pool)
{
const auto keys_size = key_columns.size();
size_t sum_keys_size{};
for (const auto i : ext::range(0, keys_size))
{
keys[i] = key_columns[i]->getDataAtWithTerminatingZero(row);
sum_keys_size += keys[i].size;
}
const auto res = pool.alloc(sum_keys_size);
auto place = res;
for (size_t j = 0; j < keys_size; ++j)
{
memcpy(place, keys[j].data, keys[j].size);
place += keys[j].size;
}
return { res, sum_keys_size };
}
template <typename T>
void has(const attribute_t & attribute, const ConstColumnPlainPtrs & key_columns, PODArray<UInt8> & out) const
{
const auto & attr = *std::get<ContainerPtrType<T>>(attribute.maps);
const auto keys_size = key_columns.size();
StringRefs keys(keys_size);
Arena temporary_keys_pool;
const auto rows = key_columns.front()->size();
for (const auto i : ext::range(0, rows))
{
/// copy key data to arena so it is contiguous and return StringRef to it
const auto key = placeKeysInPool(i, key_columns, keys, temporary_keys_pool);
const auto it = attr.find(key);
out[i] = it != attr.end();
/// free memory allocated for the key
temporary_keys_pool.rollback(key.size);
}
query_count.fetch_add(rows, std::memory_order_relaxed);
}
const std::string name;
const DictionaryStructure dict_struct;
const DictionarySourcePtr source_ptr;
const DictionaryLifetime dict_lifetime;
const bool require_nonempty;
const std::string key_description{dict_struct.getKeyDescription()};
std::map<std::string, std::size_t> attribute_index_by_name;
std::vector<attribute_t> attributes;
Arena keys_pool;
std::size_t bytes_allocated = 0;
std::size_t element_count = 0;
std::size_t bucket_count = 0;
mutable std::atomic<std::size_t> query_count{0};
std::chrono::time_point<std::chrono::system_clock> creation_time;
std::exception_ptr creation_exception;
};
}

View File

@ -20,14 +20,29 @@ namespace
Block createSampleBlock(const DictionaryStructure & dict_struct) Block createSampleBlock(const DictionaryStructure & dict_struct)
{ {
Block block{ Block block;
ColumnWithTypeAndName{new ColumnUInt64{1}, new DataTypeUInt64, dict_struct.id.name}
}; if (dict_struct.id)
block.insert(ColumnWithTypeAndName{
new ColumnUInt64{1}, new DataTypeUInt64, dict_struct.id->name
});
if (dict_struct.key)
{
for (const auto & attribute : *dict_struct.key)
{
auto column = attribute.type->createColumn();
column->insertDefault();
block.insert(ColumnWithTypeAndName{column, attribute.type, attribute.name});
}
}
if (dict_struct.range_min) if (dict_struct.range_min)
for (const auto & attribute : { dict_struct.range_min, dict_struct.range_max }) for (const auto & attribute : { dict_struct.range_min, dict_struct.range_max })
block.insert( block.insert(ColumnWithTypeAndName{
ColumnWithTypeAndName{new ColumnUInt16{1}, new DataTypeDate, attribute->name}); new ColumnUInt16{1}, new DataTypeDate, attribute->name
});
for (const auto & attribute : dict_struct.attributes) for (const auto & attribute : dict_struct.attributes)
{ {

View File

@ -6,6 +6,7 @@
#include <DB/IO/WriteBuffer.h> #include <DB/IO/WriteBuffer.h>
#include <DB/IO/WriteHelpers.h> #include <DB/IO/WriteHelpers.h>
#include <Poco/Util/AbstractConfiguration.h> #include <Poco/Util/AbstractConfiguration.h>
#include <ext/range.hpp>
#include <vector> #include <vector>
#include <string> #include <string>
#include <map> #include <map>
@ -136,20 +137,36 @@ struct DictionarySpecialAttribute final
/// Name of identifier plus list of attributes /// Name of identifier plus list of attributes
struct DictionaryStructure final struct DictionaryStructure final
{ {
DictionarySpecialAttribute id; std::experimental::optional<DictionarySpecialAttribute> id;
std::experimental::optional<std::vector<DictionaryAttribute>> key;
std::vector<DictionaryAttribute> attributes; std::vector<DictionaryAttribute> attributes;
std::experimental::optional<DictionarySpecialAttribute> range_min; std::experimental::optional<DictionarySpecialAttribute> range_min;
std::experimental::optional<DictionarySpecialAttribute> range_max; std::experimental::optional<DictionarySpecialAttribute> range_max;
bool has_expressions = false; bool has_expressions = false;
DictionaryStructure(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix) DictionaryStructure(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix)
: id{config, config_prefix + ".id"}
{ {
if (id.name.empty()) const auto has_id = config.has(config_prefix + ".id");
throw Exception{ const auto has_key = config.has(config_prefix + ".key");
"No 'id' specified for dictionary",
ErrorCodes::BAD_ARGUMENTS if (has_key && has_id)
}; throw Exception{"Only one of 'id' and 'key' should be specified", ErrorCodes::BAD_ARGUMENTS};
if (has_id)
id.emplace(config, config_prefix + ".id");
else if (has_key)
{
key.emplace(getAttributes(config, config_prefix + ".key", false, false));
if (key->empty())
throw Exception{"Empty 'key' supplied", ErrorCodes::BAD_ARGUMENTS};
}
else
throw Exception{"Dictionary structure should specify either 'id' or 'key'", ErrorCodes::BAD_ARGUMENTS};
if (id)
{
if (id->name.empty())
throw Exception{"'id' cannot be empty", ErrorCodes::BAD_ARGUMENTS};
if (config.has(config_prefix + ".range_min")) if (config.has(config_prefix + ".range_min"))
range_min.emplace(config, config_prefix + ".range_min"); range_min.emplace(config, config_prefix + ".range_min");
@ -157,14 +174,93 @@ struct DictionaryStructure final
if (config.has(config_prefix + ".range_max")) if (config.has(config_prefix + ".range_max"))
range_max.emplace(config, config_prefix + ".range_max"); range_max.emplace(config, config_prefix + ".range_max");
if (!id.expression.empty() || if (!id->expression.empty() ||
(range_min && !range_min->expression.empty()) || (range_max && !range_max->expression.empty())) (range_min && !range_min->expression.empty()) ||
(range_max && !range_max->expression.empty()))
has_expressions = true; has_expressions = true;
}
attributes = getAttributes(config, config_prefix);
if (attributes.empty())
throw Exception{"Dictionary has no attributes defined", ErrorCodes::BAD_ARGUMENTS};
}
void validateKeyTypes(const DataTypes & key_types) const
{
if (key_types.size() != key->size())
throw Exception{
"Key structure does not match, expected " + getKeyDescription(),
ErrorCodes::TYPE_MISMATCH
};
for (const auto i : ext::range(0, key_types.size()))
{
const auto & expected_type = (*key)[i].type->getName();
const auto & actual_type = key_types[i]->getName();
if (expected_type != actual_type)
throw Exception{
"Key type at position " + std::to_string(i) + " does not match, expected " + expected_type +
", found " + actual_type,
ErrorCodes::TYPE_MISMATCH
};
}
}
std::string getKeyDescription() const
{
if (id)
return "UInt64";
std::ostringstream out;
out << '(';
auto first = true;
for (const auto & key_i : *key)
{
if (!first)
out << ", ";
first = false;
out << key_i.type->getName();
}
out << ')';
return out.str();
}
bool isKeySizeFixed() const
{
if (!key)
return true;
for (const auto key_i : * key)
if (key_i.underlying_type == AttributeUnderlyingType::String)
return false;
return true;
}
std::size_t getKeySize() const
{
return std::accumulate(std::begin(*key), std::end(*key), std::size_t{},
[] (const auto running_size, const auto & key_i) {return running_size + key_i.type->getSizeOfField(); });
}
private:
std::vector<DictionaryAttribute> getAttributes(
const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix,
const bool hierarchy_allowed = true, const bool allow_null_values = true)
{
Poco::Util::AbstractConfiguration::Keys keys; Poco::Util::AbstractConfiguration::Keys keys;
config.keys(config_prefix, keys); config.keys(config_prefix, keys);
auto has_hierarchy = false; auto has_hierarchy = false;
std::vector<DictionaryAttribute> attributes;
for (const auto & key : keys) for (const auto & key : keys)
{ {
if (0 != strncmp(key.data(), "attribute", strlen("attribute"))) if (0 != strncmp(key.data(), "attribute", strlen("attribute")))
@ -181,8 +277,10 @@ struct DictionaryStructure final
if (!expression.empty()) if (!expression.empty())
has_expressions = true; has_expressions = true;
const auto null_value_string = config.getString(prefix + "null_value");
Field null_value; Field null_value;
if (allow_null_values)
{
const auto null_value_string = config.getString(prefix + "null_value");
try try
{ {
ReadBufferFromString null_value_buffer{null_value_string}; ReadBufferFromString null_value_buffer{null_value_string};
@ -195,6 +293,7 @@ struct DictionaryStructure final
ErrorCodes::BAD_ARGUMENTS ErrorCodes::BAD_ARGUMENTS
}; };
} }
}
const auto hierarchical = config.getBool(prefix + "hierarchical", false); const auto hierarchical = config.getBool(prefix + "hierarchical", false);
const auto injective = config.getBool(prefix + "injective", false); const auto injective = config.getBool(prefix + "injective", false);
@ -204,6 +303,12 @@ struct DictionaryStructure final
ErrorCodes::BAD_ARGUMENTS ErrorCodes::BAD_ARGUMENTS
}; };
if (has_hierarchy && !hierarchy_allowed)
throw Exception{
"Hierarchy not allowed in '" + prefix,
ErrorCodes::BAD_ARGUMENTS
};
if (has_hierarchy && hierarchical) if (has_hierarchy && hierarchical)
throw Exception{ throw Exception{
"Only one hierarchical attribute supported", "Only one hierarchical attribute supported",
@ -217,11 +322,7 @@ struct DictionaryStructure final
}); });
} }
if (attributes.empty()) return attributes;
throw Exception{
"Dictionary has no attributes defined",
ErrorCodes::BAD_ARGUMENTS
};
} }
}; };

View File

@ -41,10 +41,13 @@ public:
BlockInputStreamPtr loadIds(const std::vector<std::uint64_t> & ids) override BlockInputStreamPtr loadIds(const std::vector<std::uint64_t> & ids) override
{ {
throw Exception{ throw Exception{"Method unsupported", ErrorCodes::NOT_IMPLEMENTED};
"Method unsupported", }
ErrorCodes::NOT_IMPLEMENTED
}; BlockInputStreamPtr loadKeys(
const ConstColumnPlainPtrs & key_columns, const std::vector<std::size_t> & requested_rows) override
{
throw Exception{"Method unsupported", ErrorCodes::NOT_IMPLEMENTED};
} }
bool isModified() const override { return getLastModification() > last_modification; } bool isModified() const override { return getLastModification() > last_modification; }

View File

@ -6,10 +6,12 @@
#include <DB/Columns/ColumnString.h> #include <DB/Columns/ColumnString.h>
#include <DB/Common/Arena.h> #include <DB/Common/Arena.h>
#include <ext/range.hpp> #include <ext/range.hpp>
#include <ext/size.hpp>
#include <atomic> #include <atomic>
#include <vector> #include <vector>
#include <tuple> #include <tuple>
namespace DB namespace DB
{ {
@ -84,11 +86,15 @@ public:
void toParent(const PODArray<id_t> & ids, PODArray<id_t> & out) const override void toParent(const PODArray<id_t> & ids, PODArray<id_t> & out) const override
{ {
getItems<UInt64>(*hierarchical_attribute, ids, out); const auto null_value = std::get<UInt64>(hierarchical_attribute->null_values);
getItems<UInt64>(*hierarchical_attribute, ids,
[&] (const std::size_t row, const UInt64 value) { out[row] = value; },
[&] (const std::size_t) { return null_value; });
} }
#define DECLARE_MULTIPLE_GETTER(TYPE)\ #define DECLARE(TYPE)\
void get##TYPE(const std::string & attribute_name, const PODArray<id_t> & ids, PODArray<TYPE> & out) const override\ void get##TYPE(const std::string & attribute_name, const PODArray<id_t> & ids, PODArray<TYPE> & out) const\
{\ {\
const auto & attribute = getAttribute(attribute_name);\ const auto & attribute = getAttribute(attribute_name);\
if (attribute.type != AttributeUnderlyingType::TYPE)\ if (attribute.type != AttributeUnderlyingType::TYPE)\
@ -97,20 +103,24 @@ public:
ErrorCodes::TYPE_MISMATCH\ ErrorCodes::TYPE_MISMATCH\
};\ };\
\ \
getItems<TYPE>(attribute, ids, out);\ const auto null_value = std::get<TYPE>(attribute.null_values);\
\
getItems<TYPE>(attribute, ids,\
[&] (const std::size_t row, const auto value) { out[row] = value; },\
[&] (const std::size_t) { return null_value; });\
} }
DECLARE_MULTIPLE_GETTER(UInt8) DECLARE(UInt8)
DECLARE_MULTIPLE_GETTER(UInt16) DECLARE(UInt16)
DECLARE_MULTIPLE_GETTER(UInt32) DECLARE(UInt32)
DECLARE_MULTIPLE_GETTER(UInt64) DECLARE(UInt64)
DECLARE_MULTIPLE_GETTER(Int8) DECLARE(Int8)
DECLARE_MULTIPLE_GETTER(Int16) DECLARE(Int16)
DECLARE_MULTIPLE_GETTER(Int32) DECLARE(Int32)
DECLARE_MULTIPLE_GETTER(Int64) DECLARE(Int64)
DECLARE_MULTIPLE_GETTER(Float32) DECLARE(Float32)
DECLARE_MULTIPLE_GETTER(Float64) DECLARE(Float64)
#undef DECLARE_MULTIPLE_GETTER #undef DECLARE
void getString(const std::string & attribute_name, const PODArray<id_t> & ids, ColumnString * out) const override void getString(const std::string & attribute_name, const PODArray<id_t> & ids, ColumnString * out) const
{ {
const auto & attribute = getAttribute(attribute_name); const auto & attribute = getAttribute(attribute_name);
if (attribute.type != AttributeUnderlyingType::String) if (attribute.type != AttributeUnderlyingType::String)
@ -119,38 +129,136 @@ public:
ErrorCodes::TYPE_MISMATCH ErrorCodes::TYPE_MISMATCH
}; };
const auto & attr = *std::get<std::unique_ptr<PODArray<StringRef>>>(attribute.arrays); const auto & null_value = StringRef{std::get<String>(attribute.null_values)};
const auto & null_value = std::get<String>(attribute.null_values);
for (const auto i : ext::range(0, ids.size())) getItems<StringRef>(attribute, ids,
{ [&] (const std::size_t row, const StringRef value) { out->insertData(value.data, value.size); },
const auto id = ids[i]; [&] (const std::size_t) { return null_value; });
const auto string_ref = id < attr.size() ? attr[id] : StringRef{null_value};
out->insertData(string_ref.data, string_ref.size);
} }
query_count.fetch_add(ids.size(), std::memory_order_relaxed); #define DECLARE(TYPE)\
void get##TYPE(\
const std::string & attribute_name, const PODArray<id_t> & ids, const PODArray<TYPE> & def,\
PODArray<TYPE> & out) const\
{\
const auto & attribute = getAttribute(attribute_name);\
if (attribute.type != AttributeUnderlyingType::TYPE)\
throw Exception{\
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),\
ErrorCodes::TYPE_MISMATCH\
};\
\
getItems<TYPE>(attribute, ids,\
[&] (const std::size_t row, const auto value) { out[row] = value; },\
[&] (const std::size_t row) { return def[row]; });\
}
DECLARE(UInt8)
DECLARE(UInt16)
DECLARE(UInt32)
DECLARE(UInt64)
DECLARE(Int8)
DECLARE(Int16)
DECLARE(Int32)
DECLARE(Int64)
DECLARE(Float32)
DECLARE(Float64)
#undef DECLARE
void getString(
const std::string & attribute_name, const PODArray<id_t> & ids, const ColumnString * const def,
ColumnString * const out) const
{
const auto & attribute = getAttribute(attribute_name);
if (attribute.type != AttributeUnderlyingType::String)
throw Exception{
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH
};
getItems<StringRef>(attribute, ids,
[&] (const std::size_t row, const StringRef value) { out->insertData(value.data, value.size); },
[&] (const std::size_t row) { return def->getDataAt(row); });
}
#define DECLARE(TYPE)\
void get##TYPE(\
const std::string & attribute_name, const PODArray<id_t> & ids, const TYPE def,\
PODArray<TYPE> & out) const\
{\
const auto & attribute = getAttribute(attribute_name);\
if (attribute.type != AttributeUnderlyingType::TYPE)\
throw Exception{\
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),\
ErrorCodes::TYPE_MISMATCH\
};\
\
getItems<TYPE>(attribute, ids,\
[&] (const std::size_t row, const auto value) { out[row] = value; },\
[&] (const std::size_t) { return def; });\
}
DECLARE(UInt8)
DECLARE(UInt16)
DECLARE(UInt32)
DECLARE(UInt64)
DECLARE(Int8)
DECLARE(Int16)
DECLARE(Int32)
DECLARE(Int64)
DECLARE(Float32)
DECLARE(Float64)
#undef DECLARE
void getString(
const std::string & attribute_name, const PODArray<id_t> & ids, const String & def,
ColumnString * const out) const
{
const auto & attribute = getAttribute(attribute_name);
if (attribute.type != AttributeUnderlyingType::String)
throw Exception{
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH
};
getItems<StringRef>(attribute, ids,
[&] (const std::size_t row, const StringRef value) { out->insertData(value.data, value.size); },
[&] (const std::size_t) { return StringRef{def}; });
}
void has(const PODArray<id_t> & ids, PODArray<UInt8> & out) const override
{
const auto & attribute = attributes.front();
switch (attribute.type)
{
case AttributeUnderlyingType::UInt8: has<UInt8>(attribute, ids, out); break;
case AttributeUnderlyingType::UInt16: has<UInt16>(attribute, ids, out); break;
case AttributeUnderlyingType::UInt32: has<UInt32>(attribute, ids, out); break;
case AttributeUnderlyingType::UInt64: has<UInt64>(attribute, ids, out); break;
case AttributeUnderlyingType::Int8: has<Int8>(attribute, ids, out); break;
case AttributeUnderlyingType::Int16: has<Int16>(attribute, ids, out); break;
case AttributeUnderlyingType::Int32: has<Int32>(attribute, ids, out); break;
case AttributeUnderlyingType::Int64: has<Int64>(attribute, ids, out); break;
case AttributeUnderlyingType::Float32: has<Float32>(attribute, ids, out); break;
case AttributeUnderlyingType::Float64: has<Float64>(attribute, ids, out); break;
case AttributeUnderlyingType::String: has<String>(attribute, ids, out); break;
}
} }
private: private:
template <typename Value> using ContainerType = PODArray<Value>;
template <typename Value> using ContainerPtrType = std::unique_ptr<ContainerType<Value>>;
struct attribute_t final struct attribute_t final
{ {
AttributeUnderlyingType type; AttributeUnderlyingType type;
std::tuple<UInt8, UInt16, UInt32, UInt64, std::tuple<
UInt8, UInt16, UInt32, UInt64,
Int8, Int16, Int32, Int64, Int8, Int16, Int32, Int64,
Float32, Float64, Float32, Float64,
String> null_values; String> null_values;
std::tuple<std::unique_ptr<PODArray<UInt8>>, std::tuple<
std::unique_ptr<PODArray<UInt16>>, ContainerPtrType<UInt8>, ContainerPtrType<UInt16>, ContainerPtrType<UInt32>, ContainerPtrType<UInt64>,
std::unique_ptr<PODArray<UInt32>>, ContainerPtrType<Int8>, ContainerPtrType<Int16>, ContainerPtrType<Int32>, ContainerPtrType<Int64>,
std::unique_ptr<PODArray<UInt64>>, ContainerPtrType<Float32>, ContainerPtrType<Float64>,
std::unique_ptr<PODArray<Int8>>, ContainerPtrType<StringRef>> arrays;
std::unique_ptr<PODArray<Int16>>,
std::unique_ptr<PODArray<Int32>>,
std::unique_ptr<PODArray<Int64>>,
std::unique_ptr<PODArray<Float32>>,
std::unique_ptr<PODArray<Float64>>,
std::unique_ptr<PODArray<StringRef>>> arrays;
std::unique_ptr<Arena> string_arena; std::unique_ptr<Arena> string_arena;
}; };
@ -210,7 +318,7 @@ private:
template <typename T> template <typename T>
void addAttributeSize(const attribute_t & attribute) void addAttributeSize(const attribute_t & attribute)
{ {
const auto & array_ref = std::get<std::unique_ptr<PODArray<T>>>(attribute.arrays); const auto & array_ref = std::get<ContainerPtrType<T>>(attribute.arrays);
bytes_allocated += sizeof(PODArray<T>) + array_ref->storage_size(); bytes_allocated += sizeof(PODArray<T>) + array_ref->storage_size();
bucket_count = array_ref->capacity(); bucket_count = array_ref->capacity();
} }
@ -249,8 +357,8 @@ private:
{ {
const auto & null_value_ref = std::get<T>(attribute.null_values) = const auto & null_value_ref = std::get<T>(attribute.null_values) =
null_value.get<typename NearestFieldType<T>::Type>(); null_value.get<typename NearestFieldType<T>::Type>();
std::get<std::unique_ptr<PODArray<T>>>(attribute.arrays) = std::get<ContainerPtrType<T>>(attribute.arrays) =
std::make_unique<PODArray<T>>(initial_array_size, null_value_ref); std::make_unique<ContainerType<T>>(initial_array_size, null_value_ref);
} }
attribute_t createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value) attribute_t createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value)
@ -272,8 +380,8 @@ private:
case AttributeUnderlyingType::String: case AttributeUnderlyingType::String:
{ {
const auto & null_value_ref = std::get<String>(attr.null_values) = null_value.get<String>(); const auto & null_value_ref = std::get<String>(attr.null_values) = null_value.get<String>();
std::get<std::unique_ptr<PODArray<StringRef>>>(attr.arrays) = std::get<ContainerPtrType<StringRef>>(attr.arrays) =
std::make_unique<PODArray<StringRef>>(initial_array_size, null_value_ref); std::make_unique<ContainerType<StringRef>>(initial_array_size, StringRef{null_value_ref});
attr.string_arena = std::make_unique<Arena>(); attr.string_arena = std::make_unique<Arena>();
break; break;
} }
@ -282,25 +390,29 @@ private:
return attr; return attr;
} }
template <typename T> template <typename T, typename ValueSetter, typename DefaultGetter>
void getItems(const attribute_t & attribute, const PODArray<id_t> & ids, PODArray<T> & out) const void getItems(
const attribute_t & attribute, const PODArray<id_t> & ids, ValueSetter && set_value,
DefaultGetter && get_default) const
{ {
const auto & attr = *std::get<std::unique_ptr<PODArray<T>>>(attribute.arrays); const auto & attr = *std::get<ContainerPtrType<T>>(attribute.arrays);
const auto null_value = std::get<T>(attribute.null_values); const auto rows = ext::size(ids);
using null_value_type = std::conditional_t<std::is_same<T, StringRef>::value, String, T>;
const auto null_value = std::get<null_value_type>(attribute.null_values);
for (const auto i : ext::range(0, ids.size())) for (const auto row : ext::range(0, rows))
{ {
const auto id = ids[i]; const auto id = ids[row];
out[i] = id < attr.size() ? attr[id] : null_value; set_value(row, id < ext::size(attr) && attr[id] != null_value ? attr[id] : get_default(row));
} }
query_count.fetch_add(ids.size(), std::memory_order_relaxed); query_count.fetch_add(rows, std::memory_order_relaxed);
} }
template <typename T> template <typename T>
void setAttributeValueImpl(attribute_t & attribute, const id_t id, const T value) void setAttributeValueImpl(attribute_t & attribute, const id_t id, const T value)
{ {
auto & array = *std::get<std::unique_ptr<PODArray<T>>>(attribute.arrays); auto & array = *std::get<ContainerPtrType<T>>(attribute.arrays);
if (id >= array.size()) if (id >= array.size())
array.resize_fill(id + 1, std::get<T>(attribute.null_values)); array.resize_fill(id + 1, std::get<T>(attribute.null_values));
array[id] = value; array[id] = value;
@ -328,9 +440,9 @@ private:
case AttributeUnderlyingType::Float64: setAttributeValueImpl<Float64>(attribute, id, value.get<Float64>()); break; case AttributeUnderlyingType::Float64: setAttributeValueImpl<Float64>(attribute, id, value.get<Float64>()); break;
case AttributeUnderlyingType::String: case AttributeUnderlyingType::String:
{ {
auto & array = *std::get<std::unique_ptr<PODArray<StringRef>>>(attribute.arrays); auto & array = *std::get<ContainerPtrType<StringRef>>(attribute.arrays);
if (id >= array.size()) if (id >= array.size())
array.resize_fill(id + 1, std::get<String>(attribute.null_values)); array.resize_fill(id + 1, StringRef{std::get<String>(attribute.null_values)});
const auto & string = value.get<String>(); const auto & string = value.get<String>();
const auto string_in_arena = attribute.string_arena->insert(string.data(), string.size()); const auto string_in_arena = attribute.string_arena->insert(string.data(), string.size());
array[id] = StringRef{string_in_arena, string.size()}; array[id] = StringRef{string_in_arena, string.size()};
@ -351,6 +463,23 @@ private:
return attributes[it->second]; return attributes[it->second];
} }
template <typename T>
void has(const attribute_t & attribute, const PODArray<id_t> & ids, PODArray<UInt8> & out) const
{
using stored_type = std::conditional_t<std::is_same<T, String>::value, StringRef, T>;
const auto & attr = *std::get<ContainerPtrType<stored_type>>(attribute.arrays);
const auto & null_value = std::get<T>(attribute.null_values);
const auto rows = ext::size(ids);
for (const auto i : ext::range(0, rows))
{
const auto id = ids[i];
out[i] = id < ext::size(attr) && attr[id] != null_value;
}
query_count.fetch_add(rows, std::memory_order_relaxed);
}
const std::string name; const std::string name;
const DictionaryStructure dict_struct; const DictionaryStructure dict_struct;
const DictionarySourcePtr source_ptr; const DictionarySourcePtr source_ptr;

View File

@ -19,8 +19,7 @@ class HashedDictionary final : public IDictionary
public: public:
HashedDictionary(const std::string & name, const DictionaryStructure & dict_struct, HashedDictionary(const std::string & name, const DictionaryStructure & dict_struct,
DictionarySourcePtr source_ptr, const DictionaryLifetime dict_lifetime, bool require_nonempty) DictionarySourcePtr source_ptr, const DictionaryLifetime dict_lifetime, bool require_nonempty)
: name{name}, dict_struct(dict_struct), : name{name}, dict_struct(dict_struct), source_ptr{std::move(source_ptr)}, dict_lifetime(dict_lifetime),
source_ptr{std::move(source_ptr)}, dict_lifetime(dict_lifetime),
require_nonempty(require_nonempty) require_nonempty(require_nonempty)
{ {
createAttributes(); createAttributes();
@ -82,11 +81,15 @@ public:
void toParent(const PODArray<id_t> & ids, PODArray<id_t> & out) const override void toParent(const PODArray<id_t> & ids, PODArray<id_t> & out) const override
{ {
getItems<UInt64>(*hierarchical_attribute, ids, out); const auto null_value = std::get<UInt64>(hierarchical_attribute->null_values);
getItems<UInt64>(*hierarchical_attribute, ids,
[&] (const std::size_t row, const UInt64 value) { out[row] = value; },
[&] (const std::size_t) { return null_value; });
} }
#define DECLARE_MULTIPLE_GETTER(TYPE)\ #define DECLARE(TYPE)\
void get##TYPE(const std::string & attribute_name, const PODArray<id_t> & ids, PODArray<TYPE> & out) const override\ void get##TYPE(const std::string & attribute_name, const PODArray<id_t> & ids, PODArray<TYPE> & out) const\
{\ {\
const auto & attribute = getAttribute(attribute_name);\ const auto & attribute = getAttribute(attribute_name);\
if (attribute.type != AttributeUnderlyingType::TYPE)\ if (attribute.type != AttributeUnderlyingType::TYPE)\
@ -95,20 +98,24 @@ public:
ErrorCodes::TYPE_MISMATCH\ ErrorCodes::TYPE_MISMATCH\
};\ };\
\ \
getItems<TYPE>(attribute, ids, out);\ const auto null_value = std::get<TYPE>(attribute.null_values);\
\
getItems<TYPE>(attribute, ids,\
[&] (const std::size_t row, const auto value) { out[row] = value; },\
[&] (const std::size_t) { return null_value; });\
} }
DECLARE_MULTIPLE_GETTER(UInt8) DECLARE(UInt8)
DECLARE_MULTIPLE_GETTER(UInt16) DECLARE(UInt16)
DECLARE_MULTIPLE_GETTER(UInt32) DECLARE(UInt32)
DECLARE_MULTIPLE_GETTER(UInt64) DECLARE(UInt64)
DECLARE_MULTIPLE_GETTER(Int8) DECLARE(Int8)
DECLARE_MULTIPLE_GETTER(Int16) DECLARE(Int16)
DECLARE_MULTIPLE_GETTER(Int32) DECLARE(Int32)
DECLARE_MULTIPLE_GETTER(Int64) DECLARE(Int64)
DECLARE_MULTIPLE_GETTER(Float32) DECLARE(Float32)
DECLARE_MULTIPLE_GETTER(Float64) DECLARE(Float64)
#undef DECLARE_MULTIPLE_GETTER #undef DECLARE
void getString(const std::string & attribute_name, const PODArray<id_t> & ids, ColumnString * out) const override void getString(const std::string & attribute_name, const PODArray<id_t> & ids, ColumnString * out) const
{ {
const auto & attribute = getAttribute(attribute_name); const auto & attribute = getAttribute(attribute_name);
if (attribute.type != AttributeUnderlyingType::String) if (attribute.type != AttributeUnderlyingType::String)
@ -117,38 +124,135 @@ public:
ErrorCodes::TYPE_MISMATCH ErrorCodes::TYPE_MISMATCH
}; };
const auto & attr = *std::get<std::unique_ptr<HashMap<UInt64, StringRef>>>(attribute.maps); const auto & null_value = StringRef{std::get<String>(attribute.null_values)};
const auto & null_value = std::get<String>(attribute.null_values);
for (const auto i : ext::range(0, ids.size())) getItems<StringRef>(attribute, ids,
{ [&] (const std::size_t row, const StringRef value) { out->insertData(value.data, value.size); },
const auto it = attr.find(ids[i]); [&] (const std::size_t) { return null_value; });
const auto string_ref = it != attr.end() ? it->second : StringRef{null_value};
out->insertData(string_ref.data, string_ref.size);
} }
query_count.fetch_add(ids.size(), std::memory_order_relaxed); #define DECLARE(TYPE)\
void get##TYPE(\
const std::string & attribute_name, const PODArray<id_t> & ids, const PODArray<TYPE> & def,\
PODArray<TYPE> & out) const\
{\
const auto & attribute = getAttribute(attribute_name);\
if (attribute.type != AttributeUnderlyingType::TYPE)\
throw Exception{\
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),\
ErrorCodes::TYPE_MISMATCH\
};\
\
getItems<TYPE>(attribute, ids,\
[&] (const std::size_t row, const auto value) { out[row] = value; },\
[&] (const std::size_t row) { return def[row]; });\
}
DECLARE(UInt8)
DECLARE(UInt16)
DECLARE(UInt32)
DECLARE(UInt64)
DECLARE(Int8)
DECLARE(Int16)
DECLARE(Int32)
DECLARE(Int64)
DECLARE(Float32)
DECLARE(Float64)
#undef DECLARE
void getString(
const std::string & attribute_name, const PODArray<id_t> & ids, const ColumnString * const def,
ColumnString * const out) const
{
const auto & attribute = getAttribute(attribute_name);
if (attribute.type != AttributeUnderlyingType::String)
throw Exception{
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH
};
getItems<StringRef>(attribute, ids,
[&] (const std::size_t row, const StringRef value) { out->insertData(value.data, value.size); },
[&] (const std::size_t row) { return def->getDataAt(row); });
}
#define DECLARE(TYPE)\
void get##TYPE(\
const std::string & attribute_name, const PODArray<id_t> & ids, const TYPE & def, PODArray<TYPE> & out) const\
{\
const auto & attribute = getAttribute(attribute_name);\
if (attribute.type != AttributeUnderlyingType::TYPE)\
throw Exception{\
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),\
ErrorCodes::TYPE_MISMATCH\
};\
\
getItems<TYPE>(attribute, ids,\
[&] (const std::size_t row, const auto value) { out[row] = value; },\
[&] (const std::size_t) { return def; });\
}
DECLARE(UInt8)
DECLARE(UInt16)
DECLARE(UInt32)
DECLARE(UInt64)
DECLARE(Int8)
DECLARE(Int16)
DECLARE(Int32)
DECLARE(Int64)
DECLARE(Float32)
DECLARE(Float64)
#undef DECLARE
void getString(
const std::string & attribute_name, const PODArray<id_t> & ids, const String & def,
ColumnString * const out) const
{
const auto & attribute = getAttribute(attribute_name);
if (attribute.type != AttributeUnderlyingType::String)
throw Exception{
name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH
};
getItems<StringRef>(attribute, ids,
[&] (const std::size_t row, const StringRef value) { out->insertData(value.data, value.size); },
[&] (const std::size_t) { return StringRef{def}; });
}
void has(const PODArray<id_t> & ids, PODArray<UInt8> & out) const override
{
const auto & attribute = attributes.front();
switch (attribute.type)
{
case AttributeUnderlyingType::UInt8: has<UInt8>(attribute, ids, out); break;
case AttributeUnderlyingType::UInt16: has<UInt16>(attribute, ids, out); break;
case AttributeUnderlyingType::UInt32: has<UInt32>(attribute, ids, out); break;
case AttributeUnderlyingType::UInt64: has<UInt64>(attribute, ids, out); break;
case AttributeUnderlyingType::Int8: has<Int8>(attribute, ids, out); break;
case AttributeUnderlyingType::Int16: has<Int16>(attribute, ids, out); break;
case AttributeUnderlyingType::Int32: has<Int32>(attribute, ids, out); break;
case AttributeUnderlyingType::Int64: has<Int64>(attribute, ids, out); break;
case AttributeUnderlyingType::Float32: has<Float32>(attribute, ids, out); break;
case AttributeUnderlyingType::Float64: has<Float64>(attribute, ids, out); break;
case AttributeUnderlyingType::String: has<StringRef>(attribute, ids, out); break;
}
} }
private: private:
template <typename Value> using CollectionType = HashMap<UInt64, Value>;
template <typename Value> using CollectionPtrType = std::unique_ptr<CollectionType<Value>>;
struct attribute_t final struct attribute_t final
{ {
AttributeUnderlyingType type; AttributeUnderlyingType type;
std::tuple<UInt8, UInt16, UInt32, UInt64, std::tuple<
UInt8, UInt16, UInt32, UInt64,
Int8, Int16, Int32, Int64, Int8, Int16, Int32, Int64,
Float32, Float64, Float32, Float64,
String> null_values; String> null_values;
std::tuple<std::unique_ptr<HashMap<UInt64, UInt8>>, std::tuple<
std::unique_ptr<HashMap<UInt64, UInt16>>, CollectionPtrType<UInt8>, CollectionPtrType<UInt16>, CollectionPtrType<UInt32>, CollectionPtrType<UInt64>,
std::unique_ptr<HashMap<UInt64, UInt32>>, CollectionPtrType<Int8>, CollectionPtrType<Int16>, CollectionPtrType<Int32>, CollectionPtrType<Int64>,
std::unique_ptr<HashMap<UInt64, UInt64>>, CollectionPtrType<Float32>, CollectionPtrType<Float64>,
std::unique_ptr<HashMap<UInt64, Int8>>, CollectionPtrType<StringRef>> maps;
std::unique_ptr<HashMap<UInt64, Int16>>,
std::unique_ptr<HashMap<UInt64, Int32>>,
std::unique_ptr<HashMap<UInt64, Int64>>,
std::unique_ptr<HashMap<UInt64, Float32>>,
std::unique_ptr<HashMap<UInt64, Float64>>,
std::unique_ptr<HashMap<UInt64, StringRef>>> maps;
std::unique_ptr<Arena> string_arena; std::unique_ptr<Arena> string_arena;
}; };
@ -208,8 +312,8 @@ private:
template <typename T> template <typename T>
void addAttributeSize(const attribute_t & attribute) void addAttributeSize(const attribute_t & attribute)
{ {
const auto & map_ref = std::get<std::unique_ptr<HashMap<UInt64, T>>>(attribute.maps); const auto & map_ref = std::get<CollectionPtrType<T>>(attribute.maps);
bytes_allocated += sizeof(HashMap<UInt64, T>) + map_ref->getBufferSizeInBytes(); bytes_allocated += sizeof(CollectionType<T>) + map_ref->getBufferSizeInBytes();
bucket_count = map_ref->getBufferSizeInCells(); bucket_count = map_ref->getBufferSizeInCells();
} }
@ -246,7 +350,7 @@ private:
void createAttributeImpl(attribute_t & attribute, const Field & null_value) void createAttributeImpl(attribute_t & attribute, const Field & null_value)
{ {
std::get<T>(attribute.null_values) = null_value.get<typename NearestFieldType<T>::Type>(); std::get<T>(attribute.null_values) = null_value.get<typename NearestFieldType<T>::Type>();
std::get<std::unique_ptr<HashMap<UInt64, T>>>(attribute.maps) = std::make_unique<HashMap<UInt64, T>>(); std::get<CollectionPtrType<T>>(attribute.maps) = std::make_unique<CollectionType<T>>();
} }
attribute_t createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value) attribute_t createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value)
@ -268,8 +372,7 @@ private:
case AttributeUnderlyingType::String: case AttributeUnderlyingType::String:
{ {
std::get<String>(attr.null_values) = null_value.get<String>(); std::get<String>(attr.null_values) = null_value.get<String>();
std::get<std::unique_ptr<HashMap<UInt64, StringRef>>>(attr.maps) = std::get<CollectionPtrType<StringRef>>(attr.maps) = std::make_unique<CollectionType<StringRef>>();
std::make_unique<HashMap<UInt64, StringRef>>();
attr.string_arena = std::make_unique<Arena>(); attr.string_arena = std::make_unique<Arena>();
break; break;
} }
@ -278,25 +381,27 @@ private:
return attr; return attr;
} }
template <typename T> template <typename T, typename ValueSetter, typename DefaultGetter>
void getItems(const attribute_t & attribute, const PODArray<id_t> & ids, PODArray<T> & out) const void getItems(
const attribute_t & attribute, const PODArray<id_t> & ids, ValueSetter && set_value,
DefaultGetter && get_default) const
{ {
const auto & attr = *std::get<std::unique_ptr<HashMap<UInt64, T>>>(attribute.maps); const auto & attr = *std::get<CollectionPtrType<T>>(attribute.maps);
const auto null_value = std::get<T>(attribute.null_values); const auto rows = ext::size(ids);
for (const auto i : ext::range(0, ids.size())) for (const auto i : ext::range(0, rows))
{ {
const auto it = attr.find(ids[i]); const auto it = attr.find(ids[i]);
out[i] = it != attr.end() ? it->second : null_value; set_value(i, it != attr.end() ? it->second : get_default(i));
} }
query_count.fetch_add(ids.size(), std::memory_order_relaxed); query_count.fetch_add(rows, std::memory_order_relaxed);
} }
template <typename T> template <typename T>
void setAttributeValueImpl(attribute_t & attribute, const id_t id, const T value) void setAttributeValueImpl(attribute_t & attribute, const id_t id, const T value)
{ {
auto & map = *std::get<std::unique_ptr<HashMap<UInt64, T>>>(attribute.maps); auto & map = *std::get<CollectionPtrType<T>>(attribute.maps);
map.insert({ id, value }); map.insert({ id, value });
} }
@ -316,7 +421,7 @@ private:
case AttributeUnderlyingType::Float64: setAttributeValueImpl<Float64>(attribute, id, value.get<Float64>()); break; case AttributeUnderlyingType::Float64: setAttributeValueImpl<Float64>(attribute, id, value.get<Float64>()); break;
case AttributeUnderlyingType::String: case AttributeUnderlyingType::String:
{ {
auto & map = *std::get<std::unique_ptr<HashMap<UInt64, StringRef>>>(attribute.maps); auto & map = *std::get<CollectionPtrType<StringRef>>(attribute.maps);
const auto & string = value.get<String>(); const auto & string = value.get<String>();
const auto string_in_arena = attribute.string_arena->insert(string.data(), string.size()); const auto string_in_arena = attribute.string_arena->insert(string.data(), string.size());
map.insert({ id, StringRef{string_in_arena, string.size()} }); map.insert({ id, StringRef{string_in_arena, string.size()} });
@ -337,6 +442,18 @@ private:
return attributes[it->second]; return attributes[it->second];
} }
template <typename T>
void has(const attribute_t & attribute, const PODArray<id_t> & ids, PODArray<UInt8> & out) const
{
const auto & attr = *std::get<CollectionPtrType<T>>(attribute.maps);
const auto rows = ext::size(ids);
for (const auto i : ext::range(0, rows))
out[i] = attr.find(ids[i]) != std::end(attr);
query_count.fetch_add(rows, std::memory_order_relaxed);
}
const std::string name; const std::string name;
const DictionaryStructure dict_struct; const DictionaryStructure dict_struct;
const DictionarySourcePtr source_ptr; const DictionarySourcePtr source_ptr;

View File

@ -59,6 +59,10 @@ struct IDictionary : IDictionaryBase
{ {
virtual bool hasHierarchy() const = 0; virtual bool hasHierarchy() const = 0;
virtual void toParent(const PODArray<id_t> & ids, PODArray<id_t> & out) const = 0;
virtual void has(const PODArray<id_t> & ids, PODArray<UInt8> & out) const = 0;
/// do not call unless you ensure that hasHierarchy() returns true /// do not call unless you ensure that hasHierarchy() returns true
id_t toParent(id_t id) const id_t toParent(id_t id) const
{ {
@ -70,8 +74,6 @@ struct IDictionary : IDictionaryBase
return out.front(); return out.front();
} }
virtual void toParent(const PODArray<id_t> & ids, PODArray<id_t> & out) const = 0;
bool in(id_t child_id, const id_t ancestor_id) const bool in(id_t child_id, const id_t ancestor_id) const
{ {
while (child_id != 0 && child_id != ancestor_id) while (child_id != 0 && child_id != ancestor_id)
@ -79,19 +81,6 @@ struct IDictionary : IDictionaryBase
return child_id != 0; return child_id != 0;
} }
/// return mapped values for a collection of identifiers
virtual void getUInt8(const std::string & attr_name, const PODArray<id_t> & ids, PODArray<UInt8> & out) const = 0;
virtual void getUInt16(const std::string & attr_name, const PODArray<id_t> & ids, PODArray<UInt16> & out) const = 0;
virtual void getUInt32(const std::string & attr_name, const PODArray<id_t> & ids, PODArray<UInt32> & out) const = 0;
virtual void getUInt64(const std::string & attr_name, const PODArray<id_t> & ids, PODArray<UInt64> & out) const = 0;
virtual void getInt8(const std::string & attr_name, const PODArray<id_t> & ids, PODArray<Int8> & out) const = 0;
virtual void getInt16(const std::string & attr_name, const PODArray<id_t> & ids, PODArray<Int16> & out) const = 0;
virtual void getInt32(const std::string & attr_name, const PODArray<id_t> & ids, PODArray<Int32> & out) const = 0;
virtual void getInt64(const std::string & attr_name, const PODArray<id_t> & ids, PODArray<Int64> & out) const = 0;
virtual void getFloat32(const std::string & attr_name, const PODArray<id_t> & ids, PODArray<Float32> & out) const = 0;
virtual void getFloat64(const std::string & attr_name, const PODArray<id_t> & ids, PODArray<Float64> & out) const = 0;
virtual void getString(const std::string & attr_name, const PODArray<id_t> & ids, ColumnString * out) const = 0;
}; };
} }

View File

@ -27,6 +27,11 @@ public:
/// returns an input stream with the data for a collection of identifiers /// returns an input stream with the data for a collection of identifiers
virtual BlockInputStreamPtr loadIds(const std::vector<std::uint64_t> & ids) = 0; virtual BlockInputStreamPtr loadIds(const std::vector<std::uint64_t> & ids) = 0;
/** returns an input stream with the data for a collection of composite keys.
* `requested_rows` contains indices of all rows containing unique keys. */
virtual BlockInputStreamPtr loadKeys(
const ConstColumnPlainPtrs & key_columns, const std::vector<std::size_t> & requested_rows) = 0;
/// indicates whether the source has been modified since last load* operation /// indicates whether the source has been modified since last load* operation
virtual bool isModified() const = 0; virtual bool isModified() const = 0;

View File

@ -4,6 +4,9 @@
#include <DB/Dictionaries/MongoDBBlockInputStream.h> #include <DB/Dictionaries/MongoDBBlockInputStream.h>
#include <Poco/Util/AbstractConfiguration.h> #include <Poco/Util/AbstractConfiguration.h>
#include <mongo/client/dbclient.h> #include <mongo/client/dbclient.h>
#include <ext/collection_cast.hpp>
#include <ext/enumerate.hpp>
#include <ext/size.hpp>
namespace DB namespace DB
@ -12,6 +15,8 @@ namespace DB
/// Allows loading dictionaries from a MongoDB collection /// Allows loading dictionaries from a MongoDB collection
class MongoDBDictionarySource final : public IDictionarySource class MongoDBDictionarySource final : public IDictionarySource
{ {
static const auto max_block_size = 8192;
MongoDBDictionarySource( MongoDBDictionarySource(
const DictionaryStructure & dict_struct, const std::string & host, const std::string & port, const DictionaryStructure & dict_struct, const std::string & host, const std::string & port,
const std::string & user, const std::string & password, const std::string & user, const std::string & password,
@ -89,7 +94,7 @@ public:
{ {
return new MongoDBBlockInputStream{ return new MongoDBBlockInputStream{
connection.query(db + '.' + collection, {}, 0, 0, &fields_to_query), connection.query(db + '.' + collection, {}, 0, 0, &fields_to_query),
sample_block, 8192 sample_block, max_block_size
}; };
} }
@ -97,13 +102,62 @@ public:
BlockInputStreamPtr loadIds(const std::vector<std::uint64_t> & ids) override BlockInputStreamPtr loadIds(const std::vector<std::uint64_t> & ids) override
{ {
if (!dict_struct.id)
throw Exception{"'id' is required for selective loading", ErrorCodes::UNSUPPORTED_METHOD};
/// mongo::BSONObj has shitty design and does not use fixed width integral types /// mongo::BSONObj has shitty design and does not use fixed width integral types
const std::vector<long long int> iids{std::begin(ids), std::end(ids)}; const auto query = BSON(
const auto ids_enumeration = BSON(dict_struct.id.name << BSON("$in" << iids)); dict_struct.id->name << BSON("$in" << ext::collection_cast<std::vector<long long int>>(ids)));
return new MongoDBBlockInputStream{ return new MongoDBBlockInputStream{
connection.query(db + '.' + collection, ids_enumeration, 0, 0, &fields_to_query), connection.query(db + '.' + collection, query, 0, 0, &fields_to_query),
sample_block, 8192 sample_block, max_block_size
};
}
BlockInputStreamPtr loadKeys(
const ConstColumnPlainPtrs & key_columns, const std::vector<std::size_t> & requested_rows) override
{
if (!dict_struct.key)
throw Exception{"'key' is required for selective loading", ErrorCodes::UNSUPPORTED_METHOD};
std::string query_string;
{
WriteBufferFromString out{query_string};
writeString("{$or:[", out);
auto first = true;
for (const auto row : requested_rows)
{
if (!first)
writeChar(',', out);
first = false;
writeChar('{', out);
for (const auto idx_key : ext::enumerate(*dict_struct.key))
{
if (idx_key.first != 0)
writeChar(',', out);
writeString(idx_key.second.name, out);
writeChar(':', out);
idx_key.second.type->serializeTextQuoted((*key_columns[idx_key.first])[row], out);
}
writeChar('}', out);
}
writeString("]}", out);
}
return new MongoDBBlockInputStream{
connection.query(db + '.' + collection, query_string, 0, 0, &fields_to_query),
sample_block, max_block_size
}; };
} }

View File

@ -58,6 +58,15 @@ public:
return new MySQLBlockInputStream{pool.Get(), query, sample_block, max_block_size}; return new MySQLBlockInputStream{pool.Get(), query, sample_block, max_block_size};
} }
BlockInputStreamPtr loadKeys(
const ConstColumnPlainPtrs & key_columns, const std::vector<std::size_t> & requested_rows) override
{
/// Здесь не логгируем и не обновляем время модификации, так как запрос может быть большим, и часто задаваться.
const auto query = composeLoadKeysQuery(key_columns, requested_rows);
return new MySQLBlockInputStream{pool.Get(), query, sample_block, max_block_size};
}
bool isModified() const override bool isModified() const override
{ {
if (dont_check_update_time) if (dont_check_update_time)
@ -156,13 +165,15 @@ private:
WriteBufferFromString out{query}; WriteBufferFromString out{query};
writeString("SELECT ", out); writeString("SELECT ", out);
if (!dict_struct.id.expression.empty()) if (dict_struct.id)
{ {
writeParenthesisedString(dict_struct.id.expression, out); if (!dict_struct.id->expression.empty())
{
writeParenthesisedString(dict_struct.id->expression, out);
writeString(" AS ", out); writeString(" AS ", out);
} }
writeProbablyBackQuotedString(dict_struct.id.name, out); writeProbablyBackQuotedString(dict_struct.id->name, out);
if (dict_struct.range_min && dict_struct.range_max) if (dict_struct.range_min && dict_struct.range_max)
{ {
@ -186,6 +197,26 @@ private:
writeProbablyBackQuotedString(dict_struct.range_max->name, out); writeProbablyBackQuotedString(dict_struct.range_max->name, out);
} }
}
else if (dict_struct.key)
{
auto first = true;
for (const auto & key : *dict_struct.key)
{
if (!first)
writeString(", ", out);
first = false;
if (!key.expression.empty())
{
writeParenthesisedString(key.expression, out);
writeString(" AS ", out);
}
writeProbablyBackQuotedString(key.name, out);
}
}
for (const auto & attr : dict_struct.attributes) for (const auto & attr : dict_struct.attributes)
{ {
@ -222,19 +253,22 @@ private:
std::string composeLoadIdsQuery(const std::vector<std::uint64_t> & ids) std::string composeLoadIdsQuery(const std::vector<std::uint64_t> & ids)
{ {
if (!dict_struct.id)
throw Exception{"Simple key required for method", ErrorCodes::UNSUPPORTED_METHOD};
std::string query; std::string query;
{ {
WriteBufferFromString out{query}; WriteBufferFromString out{query};
writeString("SELECT ", out); writeString("SELECT ", out);
if (!dict_struct.id.expression.empty()) if (!dict_struct.id->expression.empty())
{ {
writeParenthesisedString(dict_struct.id.expression, out); writeParenthesisedString(dict_struct.id->expression, out);
writeString(" AS ", out); writeString(" AS ", out);
} }
writeProbablyBackQuotedString(dict_struct.id.name, out); writeProbablyBackQuotedString(dict_struct.id->name, out);
for (const auto & attr : dict_struct.attributes) for (const auto & attr : dict_struct.attributes)
{ {
@ -265,7 +299,7 @@ private:
writeString(" AND ", out); writeString(" AND ", out);
} }
writeProbablyBackQuotedString(dict_struct.id.name, out); writeProbablyBackQuotedString(dict_struct.id->name, out);
writeString(" IN (", out); writeString(" IN (", out);
auto first = true; auto first = true;
@ -284,6 +318,92 @@ private:
return query; return query;
} }
std::string composeLoadKeysQuery(
const ConstColumnPlainPtrs & key_columns, const std::vector<std::size_t> & requested_rows)
{
if (!dict_struct.key)
throw Exception{"Composite key required for method", ErrorCodes::UNSUPPORTED_METHOD};
std::string query;
{
WriteBufferFromString out{query};
writeString("SELECT ", out);
auto first = true;
for (const auto & key_or_attribute : boost::join(*dict_struct.key, dict_struct.attributes))
{
if (!first)
writeString(", ", out);
first = false;
if (!key_or_attribute.expression.empty())
{
writeParenthesisedString(key_or_attribute.expression, out);
writeString(" AS ", out);
}
writeProbablyBackQuotedString(key_or_attribute.name, out);
}
writeString(" FROM ", out);
if (!db.empty())
{
writeProbablyBackQuotedString(db, out);
writeChar('.', out);
}
writeProbablyBackQuotedString(table, out);
writeString(" WHERE ", out);
if (!where.empty())
{
writeString(where, out);
writeString(" AND ", out);
}
first = true;
for (const auto row : requested_rows)
{
if (!first)
writeString(" OR ", out);
first = false;
composeKeyCondition(key_columns, row, out);
}
writeString(";", out);
}
return query;
}
void composeKeyCondition(const ConstColumnPlainPtrs & key_columns, const std::size_t row, WriteBuffer & out) const
{
writeString("(", out);
const auto keys_size = key_columns.size();
auto first = true;
for (const auto i : ext::range(0, keys_size))
{
if (!first)
writeString(" AND ", out);
first = false;
const auto & key_description = (*dict_struct.key)[i];
const auto & value = (*key_columns[i])[row];
/// key_i=value_i
writeString(key_description.name, out);
writeString("=", out);
key_description.type->serializeTextQuoted(value, out);
}
writeString(")", out);
}
const DictionaryStructure dict_struct; const DictionaryStructure dict_struct;
const std::string db; const std::string db;
const std::string table; const std::string table;

View File

@ -117,7 +117,7 @@ public:
const auto val_it = std::find_if(std::begin(ranges_and_values), std::end(ranges_and_values), const auto val_it = std::find_if(std::begin(ranges_and_values), std::end(ranges_and_values),
[date] (const value_t<StringRef> & v) { return v.range.contains(date); }); [date] (const value_t<StringRef> & v) { return v.range.contains(date); });
const auto string_ref = val_it != std::end(ranges_and_values) ? val_it->value : null_value; const auto string_ref = val_it != std::end(ranges_and_values) ? val_it->value : StringRef{null_value};
out->insertData(string_ref.data, string_ref.size); out->insertData(string_ref.data, string_ref.size);
} }
else else

View File

@ -25,7 +25,8 @@ private:
public: public:
FunctionFactory(); FunctionFactory();
FunctionPtr get(const String & name, const Context & context) const; FunctionPtr get(const String & name, const Context & context) const; /// Кидает исключение, если не нашлось.
FunctionPtr tryGet(const String & name, const Context & context) const; /// Возвращает nullptr, если не нашлось.
template <typename F> void registerFunction() template <typename F> void registerFunction()
{ {

View File

@ -5,6 +5,7 @@
#include <DB/DataTypes/DataTypeDateTime.h> #include <DB/DataTypes/DataTypeDateTime.h>
#include <DB/Functions/IFunction.h> #include <DB/Functions/IFunction.h>
#include <DB/Functions/NumberTraits.h> #include <DB/Functions/NumberTraits.h>
#include <DB/Core/FieldVisitors.h>
namespace DB namespace DB
@ -713,6 +714,10 @@ public:
}; };
template <typename FunctionName>
struct FunctionUnaryArithmeticMonotonicity;
template <template <typename> class Op, typename Name> template <template <typename> class Op, typename Name>
class FunctionUnaryArithmetic : public IFunction class FunctionUnaryArithmetic : public IFunction
{ {
@ -815,6 +820,16 @@ public:
+ " of argument of function " + getName(), + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN); ErrorCodes::ILLEGAL_COLUMN);
} }
bool hasInformationAboutMonotonicity() const override
{
return FunctionUnaryArithmeticMonotonicity<Name>::has();
}
Monotonicity getMonotonicityForRange(const IDataType & type, const Field & left, const Field & right) const override
{
return FunctionUnaryArithmeticMonotonicity<Name>::get(left, right);
}
}; };
@ -854,6 +869,41 @@ typedef FunctionBinaryArithmetic<BitShiftRightImpl, NameBitShiftRight> Functi
typedef FunctionBinaryArithmetic<LeastImpl, NameLeast> FunctionLeast; typedef FunctionBinaryArithmetic<LeastImpl, NameLeast> FunctionLeast;
typedef FunctionBinaryArithmetic<GreatestImpl, NameGreatest> FunctionGreatest; typedef FunctionBinaryArithmetic<GreatestImpl, NameGreatest> FunctionGreatest;
/// Свойства монотонности для некоторых функций.
template <> struct FunctionUnaryArithmeticMonotonicity<NameNegate>
{
static bool has() { return true; }
static IFunction::Monotonicity get(const Field & left, const Field & right)
{
return { true, false };
}
};
template <> struct FunctionUnaryArithmeticMonotonicity<NameAbs>
{
static bool has() { return true; }
static IFunction::Monotonicity get(const Field & left, const Field & right)
{
Float64 left_float = left.isNull() ? -std::numeric_limits<Float64>::infinity() : apply_visitor(FieldVisitorConvertToNumber<Float64>(), left);
Float64 right_float = right.isNull() ? std::numeric_limits<Float64>::infinity() : apply_visitor(FieldVisitorConvertToNumber<Float64>(), right);
if ((left_float < 0 && right_float > 0) || (left_float > 0 && right_float < 0))
return {};
return { true, (left_float > 0) };
}
};
template <> struct FunctionUnaryArithmeticMonotonicity<NameBitNot>
{
static bool has() { return false; }
static IFunction::Monotonicity get(const Field & left, const Field & right)
{
return {};
}
};
/// Оптимизации для целочисленного деления на константу. /// Оптимизации для целочисленного деления на константу.

View File

@ -11,7 +11,9 @@
#include <DB/Columns/ColumnFixedString.h> #include <DB/Columns/ColumnFixedString.h>
#include <DB/Columns/ColumnConst.h> #include <DB/Columns/ColumnConst.h>
#include <DB/Functions/IFunction.h> #include <DB/Functions/IFunction.h>
#include <DB/Core/FieldVisitors.h>
#include <ext/range.hpp> #include <ext/range.hpp>
#include <type_traits>
namespace DB namespace DB
@ -996,7 +998,7 @@ struct ConvertImpl<DataTypeFixedString, DataTypeString, Name>
/// Предварительное объявление. /// Предварительное объявление.
struct NameToDate { static constexpr auto name = "toDate"; }; struct NameToDate { static constexpr auto name = "toDate"; };
template <typename ToDataType, typename Name> template <typename ToDataType, typename Name, typename Monotonic>
class FunctionConvert : public IFunction class FunctionConvert : public IFunction
{ {
public: public:
@ -1039,6 +1041,16 @@ public:
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
} }
bool hasInformationAboutMonotonicity() const override
{
return Monotonic::has();
}
Monotonicity getMonotonicityForRange(const IDataType & type, const Field & left, const Field & right) const override
{
return Monotonic::get(type, left, right);
}
private: private:
template<typename ToDataType2 = ToDataType, typename Name2 = Name> template<typename ToDataType2 = ToDataType, typename Name2 = Name>
DataTypePtr getReturnTypeImpl(const DataTypes & arguments, DataTypePtr getReturnTypeImpl(const DataTypes & arguments,
@ -1260,6 +1272,113 @@ private:
}; };
/// Монотонность.
struct PositiveMonotonicity
{
static bool has() { return true; }
static IFunction::Monotonicity get(const IDataType & type, const Field & left, const Field & right)
{
return { true };
}
};
template <typename T>
struct ToIntMonotonicity
{
static bool has() { return true; }
template <typename T2 = T>
static UInt64 divideByRangeOfType(typename std::enable_if_t<sizeof(T2) != sizeof(UInt64), UInt64> x) { return x >> (sizeof(T) * 8); };
template <typename T2 = T>
static UInt64 divideByRangeOfType(typename std::enable_if_t<sizeof(T2) == sizeof(UInt64), UInt64> x) { return 0; };
static IFunction::Monotonicity get(const IDataType & type, const Field & left, const Field & right)
{
size_t size_of_type = type.getSizeOfField();
/// Если тип расширяется, то функция монотонна.
if (sizeof(T) > size_of_type)
return { true };
/// Если тип совпадает - тоже.
if (typeid_cast<const typename DataTypeFromFieldType<T>::Type *>(&type))
return { true };
/// В других случаях, для неограниченного диапазона не знаем, будет ли функция монотонной.
if (left.isNull() || right.isNull())
return {};
/// Если преобразуем из float, то аргументы должны помещаться в тип результата.
if (typeid_cast<const DataTypeFloat32 *>(&type)
|| typeid_cast<const DataTypeFloat64 *>(&type))
{
Float64 left_float = left.get<Float64>();
Float64 right_float = right.get<Float64>();
if (left_float >= std::numeric_limits<T>::min() && left_float <= std::numeric_limits<T>::max()
&& right_float >= std::numeric_limits<T>::min() && right_float <= std::numeric_limits<T>::max())
return { true };
return {};
}
/// Если меняем знаковость типа или преобразуем из даты, даты-времени, то аргумент должен быть из одной половинки.
/// На всякий случай, в остальных случаях тоже будем этого требовать.
if ((left.get<Int64>() >= 0) != (right.get<Int64>() >= 0))
return {};
/// Если уменьшаем тип, то все биты кроме тех, которые в него помещаются, должны совпадать.
if (divideByRangeOfType(left.get<UInt64>()) != divideByRangeOfType(right.get<UInt64>()))
return {};
return { true };
}
};
/** Монотонность для функции toString определяем, в основном, для тестовых целей.
* Всерьёз вряд ли кто-нибудь рассчитывает на оптимизацию запросов с условиями toString(CounterID) = 34.
*/
struct ToStringMonotonicity
{
static bool has() { return true; }
static IFunction::Monotonicity get(const IDataType & type, const Field & left, const Field & right)
{
IFunction::Monotonicity positive = { .is_monotonic = true, .is_positive = true };
IFunction::Monotonicity not_monotonic;
/// Функция toString монотонна, если аргумент - Date или DateTime, или неотрицательные числа с одинаковым количеством знаков.
if (typeid_cast<const DataTypeDate *>(&type)
|| typeid_cast<const DataTypeDateTime *>(&type))
return positive;
if (left.isNull() || right.isNull())
return {};
if (left.getType() == Field::Types::UInt64
&& right.getType() == Field::Types::UInt64)
{
return (left.get<Int64>() == 0 && right.get<Int64>() == 0)
|| (floor(log10(left.get<UInt64>())) == floor(log10(right.get<UInt64>())))
? positive : not_monotonic;
}
if (left.getType() == Field::Types::Int64
&& right.getType() == Field::Types::Int64)
{
return (left.get<Int64>() == 0 && right.get<Int64>() == 0)
|| (left.get<Int64>() > 0 && right.get<Int64>() > 0 && floor(log10(left.get<Int64>())) == floor(log10(right.get<Int64>())))
? positive : not_monotonic;
}
return not_monotonic;
}
};
struct NameToUInt8 { static constexpr auto name = "toUInt8"; }; struct NameToUInt8 { static constexpr auto name = "toUInt8"; };
struct NameToUInt16 { static constexpr auto name = "toUInt16"; }; struct NameToUInt16 { static constexpr auto name = "toUInt16"; };
struct NameToUInt32 { static constexpr auto name = "toUInt32"; }; struct NameToUInt32 { static constexpr auto name = "toUInt32"; };
@ -1273,19 +1392,19 @@ struct NameToFloat64 { static constexpr auto name = "toFloat64"; };
struct NameToDateTime { static constexpr auto name = "toDateTime"; }; struct NameToDateTime { static constexpr auto name = "toDateTime"; };
struct NameToString { static constexpr auto name = "toString"; }; struct NameToString { static constexpr auto name = "toString"; };
typedef FunctionConvert<DataTypeUInt8, NameToUInt8> FunctionToUInt8; typedef FunctionConvert<DataTypeUInt8, NameToUInt8, ToIntMonotonicity<UInt8>> FunctionToUInt8;
typedef FunctionConvert<DataTypeUInt16, NameToUInt16> FunctionToUInt16; typedef FunctionConvert<DataTypeUInt16, NameToUInt16, ToIntMonotonicity<UInt16>> FunctionToUInt16;
typedef FunctionConvert<DataTypeUInt32, NameToUInt32> FunctionToUInt32; typedef FunctionConvert<DataTypeUInt32, NameToUInt32, ToIntMonotonicity<UInt32>> FunctionToUInt32;
typedef FunctionConvert<DataTypeUInt64, NameToUInt64> FunctionToUInt64; typedef FunctionConvert<DataTypeUInt64, NameToUInt64, ToIntMonotonicity<UInt64>> FunctionToUInt64;
typedef FunctionConvert<DataTypeInt8, NameToInt8> FunctionToInt8; typedef FunctionConvert<DataTypeInt8, NameToInt8, ToIntMonotonicity<Int8>> FunctionToInt8;
typedef FunctionConvert<DataTypeInt16, NameToInt16> FunctionToInt16; typedef FunctionConvert<DataTypeInt16, NameToInt16, ToIntMonotonicity<Int16>> FunctionToInt16;
typedef FunctionConvert<DataTypeInt32, NameToInt32> FunctionToInt32; typedef FunctionConvert<DataTypeInt32, NameToInt32, ToIntMonotonicity<Int32>> FunctionToInt32;
typedef FunctionConvert<DataTypeInt64, NameToInt64> FunctionToInt64; typedef FunctionConvert<DataTypeInt64, NameToInt64, ToIntMonotonicity<Int64>> FunctionToInt64;
typedef FunctionConvert<DataTypeFloat32, NameToFloat32> FunctionToFloat32; typedef FunctionConvert<DataTypeFloat32, NameToFloat32, PositiveMonotonicity> FunctionToFloat32;
typedef FunctionConvert<DataTypeFloat64, NameToFloat64> FunctionToFloat64; typedef FunctionConvert<DataTypeFloat64, NameToFloat64, PositiveMonotonicity> FunctionToFloat64;
typedef FunctionConvert<DataTypeDate, NameToDate> FunctionToDate; typedef FunctionConvert<DataTypeDate, NameToDate, ToIntMonotonicity<UInt16>> FunctionToDate;
typedef FunctionConvert<DataTypeDateTime, NameToDateTime> FunctionToDateTime; typedef FunctionConvert<DataTypeDateTime, NameToDateTime, ToIntMonotonicity<UInt32>> FunctionToDateTime;
typedef FunctionConvert<DataTypeString, NameToString> FunctionToString; typedef FunctionConvert<DataTypeString, NameToString, ToStringMonotonicity> FunctionToString;
typedef FunctionConvert<DataTypeInt32, NameToUnixTimestamp> FunctionToUnixTimestamp; typedef FunctionConvert<DataTypeInt32, NameToUnixTimestamp, ToIntMonotonicity<UInt32>> FunctionToUnixTimestamp;
} }

View File

@ -48,80 +48,94 @@ namespace DB
#define TIME_SLOT_SIZE 1800 #define TIME_SLOT_SIZE 1800
/** Всевозможные преобразования.
* Представляют собой две функции - от даты-с-временем (UInt32) и от даты (UInt16).
*
* Также для преобразования T определяется "фактор-преобразование" F.
* Это такое преобразование F, что его значение идентифицирует область монотонности T
* (при фиксированном значении F, преобразование T является монотонным).
*
* Или, образно, если T аналогично взятию остатка от деления, то F аналогично делению.
*
* Пример: для преобразования T "получить номер дня в месяце" (2015-02-03 -> 3),
* фактор-преобразованием F является "округлить до месяца" (2015-02-03 -> 2015-02-01).
*/
struct ToYearImpl /// Это фактор-преобразование будет говорить, что функция монотонна всюду.
struct ZeroTransform
{ {
static inline UInt16 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toYear(t); } static inline UInt16 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return 0; }
static inline UInt16 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toYear(DayNum_t(d)); } static inline UInt16 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return 0; }
}; };
struct ToMonthImpl struct ToDateImpl
{ {
static inline UInt8 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toMonth(t); } static inline UInt8 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
static inline UInt8 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toMonth(DayNum_t(d)); }
};
struct ToDayOfMonthImpl
{ {
static inline UInt8 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toDayOfMonth(t); } return remote_date_lut.toDate(t);
static inline UInt8 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toDayOfMonth(DayNum_t(d)); } }
};
struct ToDayOfWeekImpl
{
static inline UInt8 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toDayOfWeek(t); }
static inline UInt8 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toDayOfWeek(DayNum_t(d)); }
};
struct ToHourImpl
{
static inline UInt8 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toHourInaccurate(t); }
static inline UInt8 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) static inline UInt8 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{ {
throw Exception("Illegal type Date of argument for function toHour", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); return d;
} }
};
struct ToMinuteImpl using FactorTransform = ZeroTransform;
{
static inline UInt8 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toMinuteInaccurate(t); }
static inline UInt8 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
throw Exception("Illegal type Date of argument for function toMinute", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
};
struct ToSecondImpl
{
static inline UInt8 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toSecondInaccurate(t); }
static inline UInt8 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
throw Exception("Illegal type Date of argument for function toSecond", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
}; };
struct ToMondayImpl struct ToMondayImpl
{ {
static inline UInt16 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toFirstDayNumOfWeek(remote_date_lut.toDayNum(t)); } static inline UInt16 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
static inline UInt16 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toFirstDayNumOfWeek(DayNum_t(d)); } {
return remote_date_lut.toFirstDayNumOfWeek(remote_date_lut.toDayNum(t));
}
static inline UInt16 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toFirstDayNumOfWeek(DayNum_t(d));
}
using FactorTransform = ZeroTransform;
}; };
struct ToStartOfMonthImpl struct ToStartOfMonthImpl
{ {
static inline UInt16 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toFirstDayNumOfMonth(remote_date_lut.toDayNum(t)); } static inline UInt16 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
static inline UInt16 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toFirstDayNumOfMonth(DayNum_t(d)); } {
return remote_date_lut.toFirstDayNumOfMonth(remote_date_lut.toDayNum(t));
}
static inline UInt16 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toFirstDayNumOfMonth(DayNum_t(d));
}
using FactorTransform = ZeroTransform;
}; };
struct ToStartOfQuarterImpl struct ToStartOfQuarterImpl
{ {
static inline UInt16 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toFirstDayNumOfQuarter(remote_date_lut.toDayNum(t)); } static inline UInt16 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
static inline UInt16 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toFirstDayNumOfQuarter(DayNum_t(d)); } {
return remote_date_lut.toFirstDayNumOfQuarter(remote_date_lut.toDayNum(t));
}
static inline UInt16 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toFirstDayNumOfQuarter(DayNum_t(d));
}
using FactorTransform = ZeroTransform;
}; };
struct ToStartOfYearImpl struct ToStartOfYearImpl
{ {
static inline UInt16 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toFirstDayNumOfYear(remote_date_lut.toDayNum(t)); } static inline UInt16 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
static inline UInt16 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toFirstDayNumOfYear(DayNum_t(d)); } {
return remote_date_lut.toFirstDayNumOfYear(remote_date_lut.toDayNum(t));
}
static inline UInt16 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toFirstDayNumOfYear(DayNum_t(d));
}
using FactorTransform = ZeroTransform;
}; };
@ -148,87 +162,250 @@ struct ToTimeImpl
{ {
throw Exception("Illegal type Date of argument for function toTime", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); throw Exception("Illegal type Date of argument for function toTime", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
} }
using FactorTransform = ToDateImpl;
}; };
struct ToStartOfMinuteImpl struct ToStartOfMinuteImpl
{ {
static inline UInt32 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toStartOfMinuteInaccurate(t); } static inline UInt32 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toStartOfMinuteInaccurate(t);
}
static inline UInt32 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) static inline UInt32 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{ {
throw Exception("Illegal type Date of argument for function toStartOfMinute", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); throw Exception("Illegal type Date of argument for function toStartOfMinute", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
} }
using FactorTransform = ZeroTransform;
}; };
struct ToStartOfFiveMinuteImpl struct ToStartOfFiveMinuteImpl
{ {
static inline UInt32 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toStartOfFiveMinuteInaccurate(t); } static inline UInt32 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toStartOfFiveMinuteInaccurate(t);
}
static inline UInt32 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) static inline UInt32 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{ {
throw Exception("Illegal type Date of argument for function toStartOfFiveMinute", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); throw Exception("Illegal type Date of argument for function toStartOfFiveMinute", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
} }
using FactorTransform = ZeroTransform;
}; };
struct ToStartOfHourImpl struct ToStartOfHourImpl
{ {
static inline UInt32 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toStartOfHourInaccurate(t); } static inline UInt32 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toStartOfHourInaccurate(t);
}
static inline UInt32 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) static inline UInt32 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{ {
throw Exception("Illegal type Date of argument for function toStartOfHour", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); throw Exception("Illegal type Date of argument for function toStartOfHour", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
} }
using FactorTransform = ZeroTransform;
};
struct ToYearImpl
{
static inline UInt16 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toYear(t);
}
static inline UInt16 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toYear(DayNum_t(d));
}
using FactorTransform = ZeroTransform;
};
struct ToMonthImpl
{
static inline UInt8 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toMonth(t);
}
static inline UInt8 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toMonth(DayNum_t(d));
}
using FactorTransform = ToStartOfYearImpl;
};
struct ToDayOfMonthImpl
{
static inline UInt8 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toDayOfMonth(t);
}
static inline UInt8 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toDayOfMonth(DayNum_t(d));
}
using FactorTransform = ToStartOfMonthImpl;
};
struct ToDayOfWeekImpl
{
static inline UInt8 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toDayOfWeek(t);
}
static inline UInt8 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toDayOfWeek(DayNum_t(d));
}
using FactorTransform = ToMondayImpl;
};
struct ToHourImpl
{
static inline UInt8 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toHourInaccurate(t);
}
static inline UInt8 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
throw Exception("Illegal type Date of argument for function toHour", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
using FactorTransform = ToDateImpl;
};
struct ToMinuteImpl
{
static inline UInt8 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toMinuteInaccurate(t);
}
static inline UInt8 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
throw Exception("Illegal type Date of argument for function toMinute", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
using FactorTransform = ToStartOfHourImpl;
};
struct ToSecondImpl
{
static inline UInt8 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toSecondInaccurate(t);
}
static inline UInt8 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
throw Exception("Illegal type Date of argument for function toSecond", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
using FactorTransform = ToStartOfMinuteImpl;
}; };
struct ToRelativeYearNumImpl struct ToRelativeYearNumImpl
{ {
static inline UInt16 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toYear(t); } static inline UInt16 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
static inline UInt16 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toYear(DayNum_t(d)); } {
return remote_date_lut.toYear(t);
}
static inline UInt16 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toYear(DayNum_t(d));
}
using FactorTransform = ZeroTransform;
}; };
struct ToRelativeMonthNumImpl struct ToRelativeMonthNumImpl
{ {
static inline UInt16 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toRelativeMonthNum(t); } static inline UInt16 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
static inline UInt16 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toRelativeMonthNum(DayNum_t(d)); } {
return remote_date_lut.toRelativeMonthNum(t);
}
static inline UInt16 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toRelativeMonthNum(DayNum_t(d));
}
using FactorTransform = ZeroTransform;
}; };
struct ToRelativeWeekNumImpl struct ToRelativeWeekNumImpl
{ {
static inline UInt16 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toRelativeWeekNum(t); } static inline UInt16 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
static inline UInt16 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toRelativeWeekNum(DayNum_t(d)); } {
return remote_date_lut.toRelativeWeekNum(t);
}
static inline UInt16 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toRelativeWeekNum(DayNum_t(d));
}
using FactorTransform = ZeroTransform;
}; };
struct ToRelativeDayNumImpl struct ToRelativeDayNumImpl
{ {
static inline UInt16 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toDayNum(t); } static inline UInt16 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
static inline UInt16 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return static_cast<DayNum_t>(d); } {
return remote_date_lut.toDayNum(t);
}
static inline UInt16 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return static_cast<DayNum_t>(d);
}
using FactorTransform = ZeroTransform;
}; };
struct ToRelativeHourNumImpl struct ToRelativeHourNumImpl
{ {
static inline UInt32 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toRelativeHourNum(t); } static inline UInt32 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toRelativeHourNum(t);
}
static inline UInt32 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) static inline UInt32 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{ {
throw Exception("Illegal type Date of argument for function toRelativeHourNum", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); throw Exception("Illegal type Date of argument for function toRelativeHourNum", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
} }
using FactorTransform = ZeroTransform;
}; };
struct ToRelativeMinuteNumImpl struct ToRelativeMinuteNumImpl
{ {
static inline UInt32 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return remote_date_lut.toRelativeMinuteNum(t); } static inline UInt32 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return remote_date_lut.toRelativeMinuteNum(t);
}
static inline UInt32 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) static inline UInt32 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{ {
throw Exception("Illegal type Date of argument for function toRelativeMinuteNum", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); throw Exception("Illegal type Date of argument for function toRelativeMinuteNum", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
} }
using FactorTransform = ZeroTransform;
}; };
struct ToRelativeSecondNumImpl struct ToRelativeSecondNumImpl
{ {
static inline UInt32 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) { return t; } static inline UInt32 execute(UInt32 t, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{
return t;
}
static inline UInt32 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut) static inline UInt32 execute(UInt16 d, const DateLUTImpl & remote_date_lut, const DateLUTImpl & local_date_lut)
{ {
throw Exception("Illegal type Date of argument for function toRelativeSecondNum", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); throw Exception("Illegal type Date of argument for function toRelativeSecondNum", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
} }
using FactorTransform = ZeroTransform;
}; };
template<typename FromType, typename ToType, typename Transform> template<typename FromType, typename ToType, typename Transform>
struct Transformer struct Transformer
{ {
@ -441,6 +618,42 @@ public:
throw Exception("Illegal type " + block.getByPosition(arguments[0]).type->getName() + " of argument of function " + getName(), throw Exception("Illegal type " + block.getByPosition(arguments[0]).type->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
} }
bool hasInformationAboutMonotonicity() const override
{
return true;
}
Monotonicity getMonotonicityForRange(const IDataType & type, const Field & left, const Field & right) const override
{
IFunction::Monotonicity is_monotonic { true };
IFunction::Monotonicity is_not_monotonic;
if (std::is_same<typename Transform::FactorTransform, ZeroTransform>::value)
return is_monotonic;
/// Этот метод вызывается только если у функции один аргумент. Поэтому, нас пока не волнует не-локальная тайм-зона.
const DateLUTImpl & date_lut = DateLUT::instance();
if (left.isNull() || right.isNull())
return is_not_monotonic;
/// Функция монотонна на отрезке [left, right], если фактор-преобразование возвращает для них одинаковые значения.
if (typeid_cast<const DataTypeDate *>(&type))
{
return Transform::FactorTransform::execute(UInt16(left.get<UInt64>()), date_lut, date_lut)
== Transform::FactorTransform::execute(UInt16(right.get<UInt64>()), date_lut, date_lut)
? is_monotonic : is_not_monotonic;
}
else
{
return Transform::FactorTransform::execute(UInt32(left.get<UInt64>()), date_lut, date_lut)
== Transform::FactorTransform::execute(UInt32(right.get<UInt64>()), date_lut, date_lut)
? is_monotonic : is_not_monotonic;
}
}
}; };

View File

@ -5,6 +5,7 @@
#include <DB/DataTypes/DataTypeString.h> #include <DB/DataTypes/DataTypeString.h>
#include <DB/DataTypes/DataTypeDate.h> #include <DB/DataTypes/DataTypeDate.h>
#include <DB/DataTypes/DataTypeDateTime.h> #include <DB/DataTypes/DataTypeDateTime.h>
#include <DB/DataTypes/DataTypeTuple.h>
#include <DB/Columns/ColumnVector.h> #include <DB/Columns/ColumnVector.h>
#include <DB/Columns/ColumnArray.h> #include <DB/Columns/ColumnArray.h>
@ -18,6 +19,8 @@
#include <DB/Dictionaries/FlatDictionary.h> #include <DB/Dictionaries/FlatDictionary.h>
#include <DB/Dictionaries/HashedDictionary.h> #include <DB/Dictionaries/HashedDictionary.h>
#include <DB/Dictionaries/CacheDictionary.h> #include <DB/Dictionaries/CacheDictionary.h>
#include <DB/Dictionaries/ComplexKeyHashedDictionary.h>
#include <DB/Dictionaries/ComplexKeyCacheDictionary.h>
#include <DB/Dictionaries/RangeHashedDictionary.h> #include <DB/Dictionaries/RangeHashedDictionary.h>
#include <ext/range.hpp> #include <ext/range.hpp>
@ -739,6 +742,156 @@ public:
}; };
class FunctionDictHas final : public IFunction
{
public:
static constexpr auto name = "dictHas";
static IFunction * create(const Context & context)
{
return new FunctionDictHas{context.getExternalDictionaries()};
}
FunctionDictHas(const ExternalDictionaries & dictionaries) : dictionaries(dictionaries) {}
String getName() const override { return name; }
private:
DataTypePtr getReturnType(const DataTypes & arguments) const override
{
if (arguments.size() != 2)
throw Exception{
"Number of arguments for function " + getName() + " doesn't match: passed "
+ toString(arguments.size()) + ", should be 2.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH
};
if (!typeid_cast<const DataTypeString *>(arguments[0].get()))
throw Exception{
"Illegal type " + arguments[0]->getName() + " of first argument of function " + getName()
+ ", expected a string.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT
};
if (!typeid_cast<const DataTypeUInt64 *>(arguments[1].get()) &&
!typeid_cast<const DataTypeTuple *>(arguments[1].get()))
throw Exception{
"Illegal type " + arguments[1]->getName() + " of second argument of function " + getName()
+ ", must be UInt64 or tuple(...).",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT
};
return new DataTypeUInt8;
}
void execute(Block & block, const ColumnNumbers & arguments, const size_t result) override
{
const auto dict_name_col = typeid_cast<const ColumnConst<String> *>(block.getByPosition(arguments[0]).column.get());
if (!dict_name_col)
throw Exception{
"First argument of function " + getName() + " must be a constant string",
ErrorCodes::ILLEGAL_COLUMN
};
auto dict = dictionaries.getDictionary(dict_name_col->getData());
const auto dict_ptr = dict.get();
if (!executeDispatchSimple<FlatDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatchSimple<HashedDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatchSimple<CacheDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatchComplex<ComplexKeyHashedDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatchComplex<ComplexKeyCacheDictionary>(block, arguments, result, dict_ptr))
throw Exception{
"Unsupported dictionary type " + dict_ptr->getTypeName(),
ErrorCodes::UNKNOWN_TYPE
};
}
template <typename DictionaryType>
bool executeDispatchSimple(
Block & block, const ColumnNumbers & arguments, const size_t result, const IDictionaryBase * const dictionary)
{
const auto dict = typeid_cast<const DictionaryType *>(dictionary);
if (!dict)
return false;
if (arguments.size() != 2)
throw Exception{
"Function " + getName() + " for dictionary of type " + dict->getTypeName() +
" requires exactly 2 arguments",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH
};
const auto id_col_untyped = block.getByPosition(arguments[1]).column.get();
if (const auto id_col = typeid_cast<const ColumnVector<UInt64> *>(id_col_untyped))
{
const auto & ids = id_col->getData();
const auto out = new ColumnVector<UInt8>(ext::size(ids));
block.getByPosition(result).column = out;
dict->has(ids, out->getData());
}
else if (const auto id_col = typeid_cast<const ColumnConst<UInt64> *>(id_col_untyped))
{
const PODArray<UInt64> ids(1, id_col->getData());
PODArray<UInt8> out(1);
dict->has(ids, out);
block.getByPosition(result).column = new ColumnConst<UInt8>{id_col->size(), out.front()};
}
else
throw Exception{
"Second argument of function " + getName() + " must be UInt64",
ErrorCodes::ILLEGAL_COLUMN
};
return true;
}
template <typename DictionaryType>
bool executeDispatchComplex(
Block & block, const ColumnNumbers & arguments, const size_t result, const IDictionaryBase * const dictionary)
{
const auto dict = typeid_cast<const DictionaryType *>(dictionary);
if (!dict)
return false;
if (arguments.size() != 2)
throw Exception{
"Function " + getName() + " for dictionary of type " + dict->getTypeName() +
" requires exactly 2 arguments",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH
};
const auto key_col_with_type = block.getByPosition(arguments[1]);
if (const auto key_col = typeid_cast<const ColumnTuple *>(key_col_with_type.column.get()))
{
const auto key_columns = ext::map<ConstColumnPlainPtrs>(key_col->getColumns(), [] (const ColumnPtr & ptr) {
return ptr.get();
});
const auto & key_types = static_cast<const DataTypeTuple &>(*key_col_with_type.type).getElements();
const auto out = new ColumnVector<UInt8>(key_col->size());
block.getByPosition(result).column = out;
dict->has(key_columns, key_types, out->getData());
}
else
throw Exception{
"Second argument of function " + getName() + " must be " + dict->getKeyDescription(),
ErrorCodes::TYPE_MISMATCH
};
return true;
}
const ExternalDictionaries & dictionaries;
};
class FunctionDictGetString final : public IFunction class FunctionDictGetString final : public IFunction
{ {
public: public:
@ -781,11 +934,12 @@ private:
}; };
} }
if (!typeid_cast<const DataTypeUInt64 *>(arguments[2].get())) if (!typeid_cast<const DataTypeUInt64 *>(arguments[2].get()) &&
!typeid_cast<const DataTypeTuple *>(arguments[2].get()))
{ {
throw Exception{ throw Exception{
"Illegal type " + arguments[2]->getName() + " of third argument of function " + getName() "Illegal type " + arguments[2]->getName() + " of third argument of function " + getName()
+ ", must be UInt64.", + ", must be UInt64 or tuple(...).",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT
}; };
} }
@ -817,6 +971,8 @@ private:
if (!executeDispatch<FlatDictionary>(block, arguments, result, dict_ptr) && if (!executeDispatch<FlatDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatch<HashedDictionary>(block, arguments, result, dict_ptr) && !executeDispatch<HashedDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatch<CacheDictionary>(block, arguments, result, dict_ptr) && !executeDispatch<CacheDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatchComplex<ComplexKeyHashedDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatchComplex<ComplexKeyCacheDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatchRange<RangeHashedDictionary>(block, arguments, result, dict_ptr)) !executeDispatchRange<RangeHashedDictionary>(block, arguments, result, dict_ptr))
throw Exception{ throw Exception{
"Unsupported dictionary type " + dict_ptr->getTypeName(), "Unsupported dictionary type " + dict_ptr->getTypeName(),
@ -876,6 +1032,53 @@ private:
return true; return true;
} }
template <typename DictionaryType>
bool executeDispatchComplex(
Block & block, const ColumnNumbers & arguments, const size_t result, const IDictionaryBase * const dictionary)
{
const auto dict = typeid_cast<const DictionaryType *>(dictionary);
if (!dict)
return false;
if (arguments.size() != 3)
throw Exception{
"Function " + getName() + " for dictionary of type " + dict->getTypeName() +
" requires exactly 3 arguments",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH
};
const auto attr_name_col = typeid_cast<const ColumnConst<String> *>(block.getByPosition(arguments[1]).column.get());
if (!attr_name_col)
throw Exception{
"Second argument of function " + getName() + " must be a constant string",
ErrorCodes::ILLEGAL_COLUMN
};
const auto & attr_name = attr_name_col->getData();
const auto key_col_with_type = block.getByPosition(arguments[2]);
if (const auto key_col = typeid_cast<const ColumnTuple *>(key_col_with_type.column.get()))
{
const auto key_columns = ext::map<ConstColumnPlainPtrs>(key_col->getColumns(), [] (const ColumnPtr & ptr) {
return ptr.get();
});
const auto & key_types = static_cast<const DataTypeTuple &>(*key_col_with_type.type).getElements();
const auto out = new ColumnString;
block.getByPosition(result).column = out;
dict->getString(attr_name, key_columns, key_types, out);
}
else
throw Exception{
"Third argument of function " + getName() + " must be " + dict->getKeyDescription(),
ErrorCodes::TYPE_MISMATCH
};
return true;
}
template <typename DictionaryType> template <typename DictionaryType>
bool executeDispatchRange( bool executeDispatchRange(
Block & block, const ColumnNumbers & arguments, const size_t result, const IDictionaryBase * const dictionary) Block & block, const ColumnNumbers & arguments, const size_t result, const IDictionaryBase * const dictionary)
@ -983,6 +1186,253 @@ private:
}; };
class FunctionDictGetStringOrDefault final : public IFunction
{
public:
static constexpr auto name = "dictGetStringOrDefault";
static IFunction * create(const Context & context)
{
return new FunctionDictGetStringOrDefault{context.getExternalDictionaries()};
}
FunctionDictGetStringOrDefault(const ExternalDictionaries & dictionaries) : dictionaries(dictionaries) {}
String getName() const override { return name; }
private:
DataTypePtr getReturnType(const DataTypes & arguments) const override
{
if (arguments.size() != 4)
throw Exception{
"Number of arguments for function " + getName() + " doesn't match: passed " +
toString(arguments.size()) + ", should be 4.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH
};
if (!typeid_cast<const DataTypeString *>(arguments[0].get()))
throw Exception{
"Illegal type " + arguments[0]->getName() + " of first argument of function " + getName() +
", expected a string.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT
};
if (!typeid_cast<const DataTypeString *>(arguments[1].get()))
throw Exception{
"Illegal type " + arguments[1]->getName() + " of second argument of function " + getName() +
", expected a string.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT
};
if (!typeid_cast<const DataTypeUInt64 *>(arguments[2].get()) &&
!typeid_cast<const DataTypeTuple *>(arguments[2].get()))
{
throw Exception{
"Illegal type " + arguments[2]->getName() + " of third argument of function " + getName()
+ ", must be UInt64 or tuple(...).",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT
};
}
if (!typeid_cast<const DataTypeString *>(arguments[3].get()))
throw Exception{
"Illegal type " + arguments[3]->getName() + " of fourth argument of function " + getName() +
", must be String.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT
};
return new DataTypeString;
}
void execute(Block & block, const ColumnNumbers & arguments, const size_t result)
{
const auto dict_name_col = typeid_cast<const ColumnConst<String> *>(block.getByPosition(arguments[0]).column.get());
if (!dict_name_col)
throw Exception{
"First argument of function " + getName() + " must be a constant string",
ErrorCodes::ILLEGAL_COLUMN
};
auto dict = dictionaries.getDictionary(dict_name_col->getData());
const auto dict_ptr = dict.get();
if (!executeDispatch<FlatDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatch<HashedDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatch<CacheDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatchComplex<ComplexKeyHashedDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatchComplex<ComplexKeyCacheDictionary>(block, arguments, result, dict_ptr))
throw Exception{
"Unsupported dictionary type " + dict_ptr->getTypeName(),
ErrorCodes::UNKNOWN_TYPE
};
}
template <typename DictionaryType>
bool executeDispatch(
Block & block, const ColumnNumbers & arguments, const size_t result, const IDictionaryBase * const dictionary)
{
const auto dict = typeid_cast<const DictionaryType *>(dictionary);
if (!dict)
return false;
if (arguments.size() != 4)
throw Exception{
"Function " + getName() + " for dictionary of type " + dict->getTypeName() +
" requires exactly 4 arguments",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH
};
const auto attr_name_col = typeid_cast<const ColumnConst<String> *>(block.getByPosition(arguments[1]).column.get());
if (!attr_name_col)
throw Exception{
"Second argument of function " + getName() + " must be a constant string",
ErrorCodes::ILLEGAL_COLUMN
};
const auto & attr_name = attr_name_col->getData();
const auto id_col_untyped = block.getByPosition(arguments[2]).column.get();
if (const auto id_col = typeid_cast<const ColumnVector<UInt64> *>(id_col_untyped))
executeDispatch(block, arguments, result, dict, attr_name, id_col);
else if (const auto id_col = typeid_cast<const ColumnConst<UInt64> *>(id_col_untyped))
executeDispatch(block, arguments, result, dict, attr_name, id_col);
else
throw Exception{
"Third argument of function " + getName() + " must be UInt64",
ErrorCodes::ILLEGAL_COLUMN
};
return true;
}
template <typename DictionaryType>
void executeDispatch(
Block & block, const ColumnNumbers & arguments, const size_t result, const DictionaryType * const dictionary,
const std::string & attr_name, const ColumnVector<UInt64> * const id_col)
{
const auto default_col_untyped = block.getByPosition(arguments[3]).column.get();
if (const auto default_col = typeid_cast<const ColumnString *>(default_col_untyped))
{
/// vector ids, vector defaults
const auto out = new ColumnString;
block.getByPosition(result).column = out;
const auto & ids = id_col->getData();
dictionary->getString(attr_name, ids, default_col, out);
}
else if (const auto default_col = typeid_cast<const ColumnConst<String> *>(default_col_untyped))
{
/// vector ids, const defaults
const auto out = new ColumnString;
block.getByPosition(result).column = out;
const auto & ids = id_col->getData();
const auto & def = default_col->getData();
dictionary->getString(attr_name, ids, def, out);
}
else
throw Exception{
"Fourth argument of function " + getName() + " must be String",
ErrorCodes::ILLEGAL_COLUMN
};
}
template <typename DictionaryType>
void executeDispatch(
Block & block, const ColumnNumbers & arguments, const size_t result, const DictionaryType * const dictionary,
const std::string & attr_name, const ColumnConst<UInt64> * const id_col)
{
const auto default_col_untyped = block.getByPosition(arguments[3]).column.get();
if (const auto default_col = typeid_cast<const ColumnString *>(default_col_untyped))
{
/// const ids, vector defaults
/// @todo avoid materialization
const PODArray<UInt64> ids(id_col->size(), id_col->getData());
const auto out = new ColumnString;
block.getByPosition(result).column = out;
dictionary->getString(attr_name, ids, default_col, out);
}
else if (const auto default_col = typeid_cast<const ColumnConst<String> *>(default_col_untyped))
{
/// const ids, const defaults
const PODArray<UInt64> ids(1, id_col->getData());
auto out = std::make_unique<ColumnString>();
const auto & def = default_col->getData();
dictionary->getString(attr_name, ids, def, out.get());
block.getByPosition(result).column = new ColumnConst<String>{
id_col->size(), out->getDataAt(0).toString()
};
}
else
throw Exception{
"Fourth argument of function " + getName() + " must be String",
ErrorCodes::ILLEGAL_COLUMN
};
}
template <typename DictionaryType>
bool executeDispatchComplex(
Block & block, const ColumnNumbers & arguments, const size_t result, const IDictionaryBase * const dictionary)
{
const auto dict = typeid_cast<const DictionaryType *>(dictionary);
if (!dict)
return false;
if (arguments.size() != 4)
throw Exception{
"Function " + getName() + " for dictionary of type " + dict->getTypeName() +
" requires exactly 4 arguments",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH
};
const auto attr_name_col = typeid_cast<const ColumnConst<String> *>(block.getByPosition(arguments[1]).column.get());
if (!attr_name_col)
throw Exception{
"Second argument of function " + getName() + " must be a constant string",
ErrorCodes::ILLEGAL_COLUMN
};
const auto & attr_name = attr_name_col->getData();
const auto key_col_with_type = block.getByPosition(arguments[2]);
const auto & key_col = typeid_cast<const ColumnTuple &>(*key_col_with_type.column);
const auto key_columns = ext::map<ConstColumnPlainPtrs>(key_col.getColumns(), [] (const ColumnPtr & ptr) {
return ptr.get();
});
const auto & key_types = static_cast<const DataTypeTuple &>(*key_col_with_type.type).getElements();
const auto out = new ColumnString;
block.getByPosition(result).column = out;
const auto default_col_untyped = block.getByPosition(arguments[3]).column.get();
if (const auto default_col = typeid_cast<const ColumnString *>(default_col_untyped))
dict->getString(attr_name, key_columns, key_types, default_col, out);
else if (const auto default_col = typeid_cast<const ColumnConst<String> *>(default_col_untyped))
{
const auto & def = default_col->getData();
dict->getString(attr_name, key_columns, key_types, def, out);
}
else
throw Exception{
"Fourth argument of function " + getName() + " must be String",
ErrorCodes::ILLEGAL_COLUMN
};
return true;
}
const ExternalDictionaries & dictionaries;
};
template <typename DataType> struct DictGetTraits; template <typename DataType> struct DictGetTraits;
#define DECLARE_DICT_GET_TRAITS(TYPE, DATA_TYPE) \ #define DECLARE_DICT_GET_TRAITS(TYPE, DATA_TYPE) \
template <> struct DictGetTraits<DATA_TYPE>\ template <> struct DictGetTraits<DATA_TYPE>\
@ -995,12 +1445,33 @@ template <> struct DictGetTraits<DATA_TYPE>\
dict->get##TYPE(name, ids, out);\ dict->get##TYPE(name, ids, out);\
}\ }\
template <typename DictionaryType>\ template <typename DictionaryType>\
static void get(\
const DictionaryType * const dict, const std::string & name, const ConstColumnPlainPtrs & key_columns,\
const DataTypes & key_types, PODArray<TYPE> & out)\
{\
dict->get##TYPE(name, key_columns, key_types, out);\
}\
template <typename DictionaryType>\
static void get(\ static void get(\
const DictionaryType * const dict, const std::string & name, const PODArray<UInt64> & ids,\ const DictionaryType * const dict, const std::string & name, const PODArray<UInt64> & ids,\
const PODArray<UInt16> & dates, PODArray<TYPE> & out)\ const PODArray<UInt16> & dates, PODArray<TYPE> & out)\
{\ {\
dict->get##TYPE(name, ids, dates, out);\ dict->get##TYPE(name, ids, dates, out);\
}\ }\
template <typename DictionaryType, typename DefaultsType>\
static void getOrDefault(\
const DictionaryType * const dict, const std::string & name, const PODArray<UInt64> & ids,\
const DefaultsType & def, PODArray<TYPE> & out)\
{\
dict->get##TYPE(name, ids, def, out);\
}\
template <typename DictionaryType, typename DefaultsType>\
static void getOrDefault(\
const DictionaryType * const dict, const std::string & name, const ConstColumnPlainPtrs & key_columns,\
const DataTypes & key_types, const DefaultsType & def, PODArray<TYPE> & out)\
{\
dict->get##TYPE(name, key_columns, key_types, def, out);\
}\
}; };
DECLARE_DICT_GET_TRAITS(UInt8, DataTypeUInt8) DECLARE_DICT_GET_TRAITS(UInt8, DataTypeUInt8)
DECLARE_DICT_GET_TRAITS(UInt16, DataTypeUInt16) DECLARE_DICT_GET_TRAITS(UInt16, DataTypeUInt16)
@ -1061,11 +1532,12 @@ private:
}; };
} }
if (!typeid_cast<const DataTypeUInt64 *>(arguments[2].get())) if (!typeid_cast<const DataTypeUInt64 *>(arguments[2].get()) &&
!typeid_cast<const DataTypeTuple *>(arguments[2].get()))
{ {
throw Exception{ throw Exception{
"Illegal type " + arguments[2]->getName() + " of third argument of function " + getName() "Illegal type " + arguments[2]->getName() + " of third argument of function " + getName()
+ ", must be UInt64.", + ", must be UInt64 or tuple(...).",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT
}; };
} }
@ -1097,6 +1569,8 @@ private:
if (!executeDispatch<FlatDictionary>(block, arguments, result, dict_ptr) && if (!executeDispatch<FlatDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatch<HashedDictionary>(block, arguments, result, dict_ptr) && !executeDispatch<HashedDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatch<CacheDictionary>(block, arguments, result, dict_ptr) && !executeDispatch<CacheDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatchComplex<ComplexKeyHashedDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatchComplex<ComplexKeyCacheDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatchRange<RangeHashedDictionary>(block, arguments, result, dict_ptr)) !executeDispatchRange<RangeHashedDictionary>(block, arguments, result, dict_ptr))
throw Exception{ throw Exception{
"Unsupported dictionary type " + dict_ptr->getTypeName(), "Unsupported dictionary type " + dict_ptr->getTypeName(),
@ -1131,13 +1605,11 @@ private:
const auto id_col_untyped = block.getByPosition(arguments[2]).column.get(); const auto id_col_untyped = block.getByPosition(arguments[2]).column.get();
if (const auto id_col = typeid_cast<const ColumnVector<UInt64> *>(id_col_untyped)) if (const auto id_col = typeid_cast<const ColumnVector<UInt64> *>(id_col_untyped))
{ {
const auto out = new ColumnVector<Type>; const auto out = new ColumnVector<Type>(id_col->size());
block.getByPosition(result).column = out; block.getByPosition(result).column = out;
const auto & ids = id_col->getData(); const auto & ids = id_col->getData();
auto & data = out->getData(); auto & data = out->getData();
const auto size = ids.size();
data.resize(size);
DictGetTraits<DataType>::get(dict, attr_name, ids, data); DictGetTraits<DataType>::get(dict, attr_name, ids, data);
} }
@ -1160,6 +1632,55 @@ private:
return true; return true;
} }
template <typename DictionaryType>
bool executeDispatchComplex(
Block & block, const ColumnNumbers & arguments, const size_t result, const IDictionaryBase * const dictionary)
{
const auto dict = typeid_cast<const DictionaryType *>(dictionary);
if (!dict)
return false;
if (arguments.size() != 3)
throw Exception{
"Function " + getName() + " for dictionary of type " + dict->getTypeName() +
" requires exactly 3 arguments",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH
};
const auto attr_name_col = typeid_cast<const ColumnConst<String> *>(block.getByPosition(arguments[1]).column.get());
if (!attr_name_col)
throw Exception{
"Second argument of function " + getName() + " must be a constant string",
ErrorCodes::ILLEGAL_COLUMN
};
const auto & attr_name = attr_name_col->getData();
const auto key_col_with_type = block.getByPosition(arguments[2]);
if (const auto key_col = typeid_cast<const ColumnTuple *>(key_col_with_type.column.get()))
{
const auto key_columns = ext::map<ConstColumnPlainPtrs>(key_col->getColumns(), [] (const ColumnPtr & ptr) {
return ptr.get();
});
const auto & key_types = static_cast<const DataTypeTuple &>(*key_col_with_type.type).getElements();
const auto out = new ColumnVector<Type>(key_columns.front()->size());
block.getByPosition(result).column = out;
auto & data = out->getData();
DictGetTraits<DataType>::get(dict, attr_name, key_columns, key_types, data);
}
else
throw Exception{
"Third argument of function " + getName() + " must be " + dict->getKeyDescription(),
ErrorCodes::TYPE_MISMATCH
};
return true;
}
template <typename DictionaryType> template <typename DictionaryType>
bool executeDispatchRange( bool executeDispatchRange(
Block & block, const ColumnNumbers & arguments, const size_t result, const IDictionaryBase * const dictionary) Block & block, const ColumnNumbers & arguments, const size_t result, const IDictionaryBase * const dictionary)
@ -1295,6 +1816,292 @@ using FunctionDictGetDate = FunctionDictGet<DataTypeDate>;
using FunctionDictGetDateTime = FunctionDictGet<DataTypeDateTime>; using FunctionDictGetDateTime = FunctionDictGet<DataTypeDateTime>;
template <typename DataType>
class FunctionDictGetOrDefault final : public IFunction
{
using Type = typename DataType::FieldType;
public:
static const std::string name;
static IFunction * create(const Context & context)
{
return new FunctionDictGetOrDefault{context.getExternalDictionaries()};
}
FunctionDictGetOrDefault(const ExternalDictionaries & dictionaries) : dictionaries(dictionaries) {}
String getName() const override { return name; }
private:
DataTypePtr getReturnType(const DataTypes & arguments) const override
{
if (arguments.size() != 4)
throw Exception{
"Number of arguments for function " + getName() + " doesn't match: passed "
+ toString(arguments.size()) + ", should be 4.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH
};
if (!typeid_cast<const DataTypeString *>(arguments[0].get()))
{
throw Exception{
"Illegal type " + arguments[0]->getName() + " of first argument of function " + getName()
+ ", expected a string.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT
};
}
if (!typeid_cast<const DataTypeString *>(arguments[1].get()))
{
throw Exception{
"Illegal type " + arguments[1]->getName() + " of second argument of function " + getName()
+ ", expected a string.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT
};
}
if (!typeid_cast<const DataTypeUInt64 *>(arguments[2].get()) &&
!typeid_cast<const DataTypeTuple *>(arguments[2].get()))
{
throw Exception{
"Illegal type " + arguments[2]->getName() + " of third argument of function " + getName()
+ ", must be UInt64 or tuple(...).",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT
};
}
if (!typeid_cast<const DataType *>(arguments[3].get()))
{
throw Exception{
"Illegal type " + arguments[3]->getName() + " of fourth argument of function " + getName()
+ ", must be " + DataType{}.getName() + ".",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT
};
}
return new DataType;
}
void execute(Block & block, const ColumnNumbers & arguments, const size_t result) override
{
const auto dict_name_col = typeid_cast<const ColumnConst<String> *>(block.getByPosition(arguments[0]).column.get());
if (!dict_name_col)
throw Exception{
"First argument of function " + getName() + " must be a constant string",
ErrorCodes::ILLEGAL_COLUMN
};
auto dict = dictionaries.getDictionary(dict_name_col->getData());
const auto dict_ptr = dict.get();
if (!executeDispatch<FlatDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatch<HashedDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatch<CacheDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatchComplex<ComplexKeyHashedDictionary>(block, arguments, result, dict_ptr) &&
!executeDispatchComplex<ComplexKeyCacheDictionary>(block, arguments, result, dict_ptr))
throw Exception{
"Unsupported dictionary type " + dict_ptr->getTypeName(),
ErrorCodes::UNKNOWN_TYPE
};
}
template <typename DictionaryType>
bool executeDispatch(Block & block, const ColumnNumbers & arguments, const size_t result,
const IDictionaryBase * const dictionary)
{
const auto dict = typeid_cast<const DictionaryType *>(dictionary);
if (!dict)
return false;
if (arguments.size() != 4)
throw Exception{
"Function " + getName() + " for dictionary of type " + dict->getTypeName() +
" requires exactly 4 arguments.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH
};
const auto attr_name_col = typeid_cast<const ColumnConst<String> *>(block.getByPosition(arguments[1]).column.get());
if (!attr_name_col)
throw Exception{
"Second argument of function " + getName() + " must be a constant string",
ErrorCodes::ILLEGAL_COLUMN
};
const auto & attr_name = attr_name_col->getData();
const auto id_col_untyped = block.getByPosition(arguments[2]).column.get();
if (const auto id_col = typeid_cast<const ColumnVector<UInt64> *>(id_col_untyped))
executeDispatch(block, arguments, result, dict, attr_name, id_col);
else if (const auto id_col = typeid_cast<const ColumnConst<UInt64> *>(id_col_untyped))
executeDispatch(block, arguments, result, dict, attr_name, id_col);
else
throw Exception{
"Third argument of function " + getName() + " must be UInt64",
ErrorCodes::ILLEGAL_COLUMN
};
return true;
}
template <typename DictionaryType>
void executeDispatch(
Block & block, const ColumnNumbers & arguments, const size_t result, const DictionaryType * const dictionary,
const std::string & attr_name, const ColumnVector<UInt64> * const id_col)
{
const auto default_col_untyped = block.getByPosition(arguments[3]).column.get();
if (const auto default_col = typeid_cast<const ColumnVector<Type> *>(default_col_untyped))
{
/// vector ids, vector defaults
const auto out = new ColumnVector<Type>(id_col->size());
block.getByPosition(result).column = out;
const auto & ids = id_col->getData();
auto & data = out->getData();
const auto & defs = default_col->getData();
DictGetTraits<DataType>::getOrDefault(dictionary, attr_name, ids, defs, data);
}
else if (const auto default_col = typeid_cast<const ColumnConst<Type> *>(default_col_untyped))
{
/// vector ids, const defaults
const auto out = new ColumnVector<Type>(id_col->size());
block.getByPosition(result).column = out;
const auto & ids = id_col->getData();
auto & data = out->getData();
const auto def = default_col->getData();
DictGetTraits<DataType>::getOrDefault(dictionary, attr_name, ids, def, data);
}
else
throw Exception{
"Fourth argument of function " + getName() + " must be " + DataType{}.getName(),
ErrorCodes::ILLEGAL_COLUMN
};
}
template <typename DictionaryType>
void executeDispatch(
Block & block, const ColumnNumbers & arguments, const size_t result, const DictionaryType * const dictionary,
const std::string & attr_name, const ColumnConst<UInt64> * const id_col)
{
const auto default_col_untyped = block.getByPosition(arguments[3]).column.get();
if (const auto default_col = typeid_cast<const ColumnVector<Type> *>(default_col_untyped))
{
/// const ids, vector defaults
/// @todo avoid materialization
const PODArray<UInt64> ids(id_col->size(), id_col->getData());
const auto out = new ColumnVector<Type>(id_col->size());
block.getByPosition(result).column = out;
auto & data = out->getData();
const auto & defs = default_col->getData();
DictGetTraits<DataType>::getOrDefault(dictionary, attr_name, ids, defs, data);
}
else if (const auto default_col = typeid_cast<const ColumnConst<Type> *>(default_col_untyped))
{
/// const ids, const defaults
const PODArray<UInt64> ids(1, id_col->getData());
PODArray<Type> data(1);
const auto & def = default_col->getData();
DictGetTraits<DataType>::getOrDefault(dictionary, attr_name, ids, def, data);
block.getByPosition(result).column = new ColumnConst<Type>{id_col->size(), data.front()};
}
else
throw Exception{
"Fourth argument of function " + getName() + " must be " + DataType{}.getName(),
ErrorCodes::ILLEGAL_COLUMN
};
}
template <typename DictionaryType>
bool executeDispatchComplex(
Block & block, const ColumnNumbers & arguments, const size_t result, const IDictionaryBase * const dictionary)
{
const auto dict = typeid_cast<const DictionaryType *>(dictionary);
if (!dict)
return false;
if (arguments.size() != 4)
throw Exception{
"Function " + getName() + " for dictionary of type " + dict->getTypeName() +
" requires exactly 4 arguments",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH
};
const auto attr_name_col = typeid_cast<const ColumnConst<String> *>(block.getByPosition(arguments[1]).column.get());
if (!attr_name_col)
throw Exception{
"Second argument of function " + getName() + " must be a constant string",
ErrorCodes::ILLEGAL_COLUMN
};
const auto & attr_name = attr_name_col->getData();
const auto key_col_with_type = block.getByPosition(arguments[2]);
const auto & key_col = typeid_cast<const ColumnTuple &>(*key_col_with_type.column);
const auto key_columns = ext::map<ConstColumnPlainPtrs>(key_col.getColumns(), [] (const ColumnPtr & ptr) {
return ptr.get();
});
const auto & key_types = static_cast<const DataTypeTuple &>(*key_col_with_type.type).getElements();
/// @todo detect when all key columns are constant
const auto rows = key_col.size();
const auto out = new ColumnVector<Type>(rows);
block.getByPosition(result).column = out;
auto & data = out->getData();
const auto default_col_untyped = block.getByPosition(arguments[3]).column.get();
if (const auto default_col = typeid_cast<const ColumnVector<Type> *>(default_col_untyped))
{
/// const defaults
const auto & defs = default_col->getData();
DictGetTraits<DataType>::getOrDefault(dict, attr_name, key_columns, key_types, defs, data);
}
else if (const auto default_col = typeid_cast<const ColumnConst<Type> *>(default_col_untyped))
{
const auto def = default_col->getData();
DictGetTraits<DataType>::getOrDefault(dict, attr_name, key_columns, key_types, def, data);
}
else
throw Exception{
"Fourth argument of function " + getName() + " must be " + DataType{}.getName(),
ErrorCodes::ILLEGAL_COLUMN
};
return true;
}
const ExternalDictionaries & dictionaries;
};
template <typename DataType>
const std::string FunctionDictGetOrDefault<DataType>::name = "dictGet" + DataType{}.getName() + "OrDefault";
using FunctionDictGetUInt8OrDefault = FunctionDictGetOrDefault<DataTypeUInt8>;
using FunctionDictGetUInt16OrDefault = FunctionDictGetOrDefault<DataTypeUInt16>;
using FunctionDictGetUInt32OrDefault = FunctionDictGetOrDefault<DataTypeUInt32>;
using FunctionDictGetUInt64OrDefault = FunctionDictGetOrDefault<DataTypeUInt64>;
using FunctionDictGetInt8OrDefault = FunctionDictGetOrDefault<DataTypeInt8>;
using FunctionDictGetInt16OrDefault = FunctionDictGetOrDefault<DataTypeInt16>;
using FunctionDictGetInt32OrDefault = FunctionDictGetOrDefault<DataTypeInt32>;
using FunctionDictGetInt64OrDefault = FunctionDictGetOrDefault<DataTypeInt64>;
using FunctionDictGetFloat32OrDefault = FunctionDictGetOrDefault<DataTypeFloat32>;
using FunctionDictGetFloat64OrDefault = FunctionDictGetOrDefault<DataTypeFloat64>;
using FunctionDictGetDateOrDefault = FunctionDictGetOrDefault<DataTypeDate>;
using FunctionDictGetDateTimeOrDefault = FunctionDictGetOrDefault<DataTypeDateTime>;
class FunctionDictGetHierarchy final : public IFunction class FunctionDictGetHierarchy final : public IFunction
{ {
public: public:

Some files were not shown because too many files have changed in this diff Show More