devirtualize compareAt calls

This commit is contained in:
Albert Kidrachev 2020-06-01 15:10:32 +03:00
parent 43c0499e87
commit fe170508bd
21 changed files with 123 additions and 10 deletions

View File

@ -192,6 +192,11 @@ public:
return 0;
}
std::vector<UInt8> compareAt(const IColumn &, size_t, const std::vector<UInt8> &, int) const override
{
return std::vector<UInt8>(getData().size(), 0);
}
void getPermutation(bool reverse, size_t limit, int nan_direction_hint, Permutation & res) const override;
void updatePermutation(bool reverse, size_t limit, int, Permutation & res, EqualRanges & equal_range) const override;

View File

@ -309,6 +309,10 @@ int ColumnArray::compareAt(size_t n, size_t m, const IColumn & rhs_, int nan_dir
: 1);
}
/// Batch form of compareAt: casts `rhs` to ColumnArray and hands the masked,
/// row-by-row comparison against row `rhs_row_num` over to compareImpl.
std::vector<UInt8> ColumnArray::compareAt(const IColumn & rhs, size_t rhs_row_num, const std::vector<UInt8> & mask, int nan_direction_hint) const
{
    const auto & rhs_array = assert_cast<const ColumnArray &>(rhs);
    return compareImpl<ColumnArray>(rhs_array, rhs_row_num, mask, nan_direction_hint);
}
namespace
{

View File

@ -72,6 +72,7 @@ public:
ColumnPtr index(const IColumn & indexes, size_t limit) const override;
template <typename Type> ColumnPtr indexImpl(const PaddedPODArray<Type> & indexes, size_t limit) const;
int compareAt(size_t n, size_t m, const IColumn & rhs_, int nan_direction_hint) const override;
/// Batch counterpart of the scalar compareAt above: takes one row (`rhs_row_num`) of `rhs`
/// plus a `mask`, and returns a vector of UInt8 comparison results.
std::vector<UInt8> compareAt(const IColumn & rhs, size_t rhs_row_num, const std::vector<UInt8> & mask, int nan_direction_hint) const override;
void getPermutation(bool reverse, size_t limit, int nan_direction_hint, Permutation & res) const override;
void updatePermutation(bool reverse, size_t limit, int nan_direction_hint, Permutation & res, EqualRanges & equal_range) const override;
void reserve(size_t n) override;

View File

@ -187,6 +187,11 @@ public:
return data->compareAt(0, 0, *assert_cast<const ColumnConst &>(rhs).data, nan_direction_hint);
}
std::vector<UInt8> compareAt(const IColumn & rhs, size_t rhs_row_num, const std::vector<UInt8> & mask, int nan_direction_hint) const override
{
return data->compareAt(rhs, rhs_row_num, mask, nan_direction_hint);
}
MutableColumns scatter(ColumnIndex num_columns, const Selector & selector) const override;
void gather(ColumnGathererStream &) override

View File

@ -39,6 +39,12 @@ int ColumnDecimal<T>::compareAt(size_t n, size_t m, const IColumn & rhs_, int) c
return decimalLess<T>(b, a, other.scale, scale) ? 1 : (decimalLess<T>(a, b, scale, other.scale) ? -1 : 0);
}
/// Batch form of compareAt: treats `rhs` as the same decimal column type and delegates the
/// masked, row-by-row comparison against row `rhs_row_num` to compareImpl.
template <typename T>
std::vector<UInt8> ColumnDecimal<T>::compareAt(const IColumn & rhs, size_t rhs_row_num, const std::vector<UInt8> & mask, int nan_direction_hint) const
{
    /// `this->template` keeps the call well-formed even if compareImpl lives in a dependent
    /// base class, where an unqualified template-id would not be found during lookup.
    return this->template compareImpl<ColumnDecimal>(static_cast<const Self &>(rhs), rhs_row_num, mask, nan_direction_hint);
}
template <typename T>
StringRef ColumnDecimal<T>::serializeValueIntoArena(size_t n, Arena & arena, char const *& begin) const
{

View File

@ -107,6 +107,7 @@ public:
void updateHashWithValue(size_t n, SipHash & hash) const override;
void updateWeakHash32(WeakHash32 & hash) const override;
int compareAt(size_t n, size_t m, const IColumn & rhs_, int nan_direction_hint) const override;
/// Batch counterpart of the scalar compareAt above: takes one row (`rhs_row_num`) of `rhs`
/// plus a `mask`, and returns a vector of UInt8 comparison results.
std::vector<UInt8> compareAt(const IColumn & rhs, size_t rhs_row_num, const std::vector<UInt8> & mask, int nan_direction_hint) const override;
void getPermutation(bool reverse, size_t limit, int nan_direction_hint, IColumn::Permutation & res) const override;
void updatePermutation(bool reverse, size_t limit, int, IColumn::Permutation & res, EqualRanges& equal_range) const override;

View File

@ -116,6 +116,11 @@ public:
return memcmpSmallAllowOverflow15(chars.data() + p1 * n, rhs.chars.data() + p2 * n, n);
}
/// Batch form of compareAt: casts `rhs` to ColumnFixedString and lets compareImpl run the
/// masked, row-by-row comparison against row `rhs_row_num`.
std::vector<UInt8> compareAt(const IColumn & rhs, size_t rhs_row_num, const std::vector<UInt8> & mask, int nan_direction_hint) const override
{
    const auto & rhs_fixed_string = assert_cast<const ColumnFixedString &>(rhs);
    return compareImpl<ColumnFixedString>(rhs_fixed_string, rhs_row_num, mask, nan_direction_hint);
}
void getPermutation(bool reverse, size_t limit, int nan_direction_hint, Permutation & res) const override;
void updatePermutation(bool reverse, size_t limit, int nan_direction_hint, Permutation & res, EqualRanges & equal_range) const override;

View File

@ -116,6 +116,11 @@ public:
throw Exception("compareAt is not implemented for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}
/// The batch form of compareAt is not supported by this column type:
/// it always throws an Exception with code NOT_IMPLEMENTED.
std::vector<UInt8> compareAt(const IColumn &, size_t, const std::vector<UInt8> &, int) const override
{
    const auto message = "compareAt(const IColumn &, size_t, const std::vector<UInt8> &, int) is not implemented for " + getName();
    throw Exception(message, ErrorCodes::NOT_IMPLEMENTED);
}
void getPermutation(bool, size_t, int, Permutation &) const override
{
throw Exception("getPermutation is not implemented for " + getName(), ErrorCodes::NOT_IMPLEMENTED);

View File

@ -279,6 +279,11 @@ int ColumnLowCardinality::compareAt(size_t n, size_t m, const IColumn & rhs, int
return getDictionary().compareAt(n_index, m_index, low_cardinality_column.getDictionary(), nan_direction_hint);
}
/// Batch form of compareAt: casts `rhs` to ColumnLowCardinality and lets compareImpl run the
/// masked, row-by-row comparison against row `rhs_row_num`.
std::vector<UInt8> ColumnLowCardinality::compareAt(const IColumn & rhs, size_t rhs_row_num, const std::vector<UInt8> & mask, int nan_direction_hint) const
{
    const auto & rhs_low_cardinality = assert_cast<const ColumnLowCardinality &>(rhs);
    return compareImpl<ColumnLowCardinality>(rhs_low_cardinality, rhs_row_num, mask, nan_direction_hint);
}
void ColumnLowCardinality::getPermutation(bool reverse, size_t limit, int nan_direction_hint, Permutation & res) const
{
if (limit == 0)

View File

@ -109,6 +109,8 @@ public:
int compareAt(size_t n, size_t m, const IColumn & rhs, int nan_direction_hint) const override;
/// Batch counterpart of the scalar compareAt above: takes one row (`rhs_row_num`) of `rhs`
/// plus a `mask`, and returns a vector of UInt8 comparison results.
std::vector<UInt8> compareAt(const IColumn & rhs, size_t rhs_row_num, const std::vector<UInt8> & mask, int nan_direction_hint) const override;
void getPermutation(bool reverse, size_t limit, int nan_direction_hint, Permutation & res) const override;
void updatePermutation(bool reverse, size_t limit, int, IColumn::Permutation & res, EqualRanges & equal_range) const override;

View File

@ -248,6 +248,11 @@ int ColumnNullable::compareAt(size_t n, size_t m, const IColumn & rhs_, int null
return getNestedColumn().compareAt(n, m, nested_rhs, null_direction_hint);
}
/// Batch form of compareAt: casts `rhs` to ColumnNullable and lets compareImpl run the
/// masked, row-by-row comparison against row `rhs_row_num`.
std::vector<UInt8> ColumnNullable::compareAt(const IColumn & rhs, size_t rhs_row_num, const std::vector<UInt8> & mask, int nan_direction_hint) const
{
    const auto & rhs_nullable = assert_cast<const ColumnNullable &>(rhs);
    return compareImpl<ColumnNullable>(rhs_nullable, rhs_row_num, mask, nan_direction_hint);
}
void ColumnNullable::getPermutation(bool reverse, size_t limit, int null_direction_hint, Permutation & res) const
{
/// Cannot pass limit because of unknown amount of NULLs.

View File

@ -77,6 +77,7 @@ public:
ColumnPtr permute(const Permutation & perm, size_t limit) const override;
ColumnPtr index(const IColumn & indexes, size_t limit) const override;
int compareAt(size_t n, size_t m, const IColumn & rhs_, int null_direction_hint) const override;
/// Batch counterpart of the scalar compareAt above: takes one row (`rhs_row_num`) of `rhs`
/// plus a `mask`, and returns a vector of UInt8 comparison results.
std::vector<UInt8> compareAt(const IColumn & rhs, size_t rhs_row_num, const std::vector<UInt8> & mask, int nan_direction_hint) const override;
void getPermutation(bool reverse, size_t limit, int null_direction_hint, Permutation & res) const override;
void updatePermutation(bool reverse, size_t limit, int, Permutation & res, EqualRanges & equal_range) const override;
void reserve(size_t n) override;

View File

@ -220,6 +220,11 @@ public:
return memcmpSmallAllowOverflow15(chars.data() + offsetAt(n), sizeAt(n) - 1, rhs.chars.data() + rhs.offsetAt(m), rhs.sizeAt(m) - 1);
}
/// Batch form of compareAt: casts `rhs` to ColumnString and lets compareImpl run the
/// masked, row-by-row comparison against row `rhs_row_num`.
std::vector<UInt8> compareAt(const IColumn & rhs, size_t rhs_row_num, const std::vector<UInt8> & mask, int nan_direction_hint) const override
{
    const auto & rhs_string = assert_cast<const ColumnString &>(rhs);
    return compareImpl<ColumnString>(rhs_string, rhs_row_num, mask, nan_direction_hint);
}
/// Variant of compareAt for string comparison with respect of collation.
int compareAtWithCollation(size_t n, size_t m, const IColumn & rhs_, const Collator & collator) const;

View File

@ -1,4 +1,5 @@
#include <Columns/ColumnTuple.h>
#include <Columns/IColumnImpl.h>
#include <DataStreams/ColumnGathererStream.h>
#include <IO/WriteBufferFromString.h>
#include <IO/Operators.h>
@ -278,6 +279,11 @@ int ColumnTuple::compareAt(size_t n, size_t m, const IColumn & rhs, int nan_dire
return 0;
}
/// Batch form of compareAt: casts `rhs` to ColumnTuple and lets compareImpl run the
/// masked, row-by-row comparison against row `rhs_row_num`.
std::vector<UInt8> ColumnTuple::compareAt(const IColumn & rhs, size_t rhs_row_num, const std::vector<UInt8> & mask, int nan_direction_hint) const
{
    const auto & rhs_tuple = assert_cast<const ColumnTuple &>(rhs);
    return compareImpl<ColumnTuple>(rhs_tuple, rhs_row_num, mask, nan_direction_hint);
}
template <bool positive>
struct ColumnTuple::Less
{

View File

@ -70,6 +70,7 @@ public:
MutableColumns scatter(ColumnIndex num_columns, const Selector & selector) const override;
void gather(ColumnGathererStream & gatherer_stream) override;
int compareAt(size_t n, size_t m, const IColumn & rhs, int nan_direction_hint) const override;
/// Batch counterpart of the scalar compareAt above: takes one row (`rhs_row_num`) of `rhs`
/// plus a `mask`, and returns a vector of UInt8 comparison results.
std::vector<UInt8> compareAt(const IColumn & rhs, size_t rhs_row_num, const std::vector<UInt8> & mask, int nan_direction_hint) const override;
void getExtremes(Field & min, Field & max) const override;
void getPermutation(bool reverse, size_t limit, int nan_direction_hint, Permutation & res) const override;
void updatePermutation(bool reverse, size_t limit, int nan_direction_hint, IColumn::Permutation & res, EqualRanges & equal_range) const override;

View File

@ -77,6 +77,7 @@ public:
}
int compareAt(size_t n, size_t m, const IColumn & rhs, int nan_direction_hint) const override;
/// Batch counterpart of the scalar compareAt above: takes one row (`rhs_row_num`) of `rhs`
/// plus a `mask`, and returns a vector of UInt8 comparison results.
std::vector<UInt8> compareAt(const IColumn & rhs, size_t rhs_row_num, const std::vector<UInt8> & mask, int nan_direction_hint) const override;
void updatePermutation(bool reverse, size_t limit, int nan_direction_hint, IColumn::Permutation & res, EqualRanges & equal_range) const override;
void getExtremes(Field & min, Field & max) const override { column_holder->getExtremes(min, max); }
@ -375,6 +376,12 @@ int ColumnUnique<ColumnType>::compareAt(size_t n, size_t m, const IColumn & rhs,
return getNestedColumn()->compareAt(n, m, *column_unique.getNestedColumn(), nan_direction_hint);
}
/// Batch form of compareAt: treats `rhs` as the same ColumnUnique instantiation and delegates
/// the masked, row-by-row comparison against row `rhs_row_num` to compareImpl.
template <typename ColumnType>
std::vector<UInt8> ColumnUnique<ColumnType>::compareAt(const IColumn & rhs, size_t rhs_row_num, const std::vector<UInt8> & mask, int nan_direction_hint) const
{
    /// `this->template` keeps the call well-formed even if compareImpl lives in a dependent
    /// base class, where an unqualified template-id would not be found during lookup.
    return this->template compareImpl<ColumnUnique<ColumnType>>(static_cast<const ColumnUnique<ColumnType> &>(rhs), rhs_row_num, mask, nan_direction_hint);
}
template <typename ColumnType>
void ColumnUnique<ColumnType>::updatePermutation(bool reverse, size_t limit, int nan_direction_hint, IColumn::Permutation & res, EqualRanges & equal_range) const
{

View File

@ -276,6 +276,11 @@ public:
return typeid(rhs) == typeid(ColumnVector<T>);
}
std::vector<UInt8> compareAt(const IColumn & rhs, size_t rhs_row_num, const std::vector<UInt8> & mask, int nan_direction_hint) const override
{
return this->template compareImpl<Self>(static_cast<const Self &>(rhs), rhs_row_num, mask, nan_direction_hint);
}
/** More efficient methods of manipulation - to manipulate with data directly. */
Container & getData()
{

View File

@ -244,6 +244,10 @@ public:
*/
virtual int compareAt(size_t n, size_t m, const IColumn & rhs, int nan_direction_hint) const = 0;
/** Batch form of compareAt: compares rows of this column with the single row `rhs_row_num` of `rhs`.
  * The implementations built on compareImpl return one entry per row of this column, leave the entry
  * at 0 where `mask` is zero and otherwise store the result of the scalar compareAt for that row.
  * NOTE(review): the scalar result is an int but is stored as UInt8, so a negative result has to be
  * re-interpreted as signed by the caller - confirm this is intended.
  */
virtual std::vector<UInt8> compareAt(const IColumn & rhs, size_t rhs_row_num, const std::vector<UInt8> & mask, int nan_direction_hint) const = 0;
/** Returns a permutation that sorts elements of this column,
* i.e. perm[i]-th element of source column should be i-th element of sorted column.
* reverse - reverse ordering (ascending).
@ -399,7 +403,6 @@ public:
virtual bool lowCardinality() const { return false; }
virtual ~IColumn() = default;
IColumn() = default;
IColumn(const IColumn &) = default;
@ -414,6 +417,9 @@ protected:
/// In derived classes (that use final keyword), implement scatter method as call to scatterImpl.
template <typename Derived>
std::vector<MutablePtr> scatterImpl(ColumnIndex num_columns, const Selector & selector) const;
/// Shared helper for the batch compareAt overloads in derived classes: compares each row of this
/// column selected by `mask` with row `rhs_row_num` of `rhs` and returns one UInt8 per row.
/// Derived classes call it with their own type as `Derived` and `rhs` already cast to that type.
template <typename Derived>
std::vector<UInt8> compareImpl(const Derived & rhs, size_t rhs_row_num, const std::vector<UInt8> & mask, int nan_direction_hint) const;
};
using ColumnPtr = IColumn::Ptr;

View File

@ -35,6 +35,7 @@ public:
size_t byteSize() const override { return 0; }
size_t allocatedBytes() const override { return 0; }
int compareAt(size_t, size_t, const IColumn &, int) const override { return 0; }
std::vector<UInt8> compareAt(const IColumn &, size_t, const std::vector<UInt8> &, int) const override { return std::vector<UInt8>(s, 0); }
Field operator[](size_t) const override { throw Exception("Cannot get value from " + getName(), ErrorCodes::NOT_IMPLEMENTED); }
void get(size_t, Field &) const override { throw Exception("Cannot get value from " + getName(), ErrorCodes::NOT_IMPLEMENTED); }

View File

@ -46,4 +46,18 @@ std::vector<IColumn::MutablePtr> IColumn::scatterImpl(ColumnIndex num_columns,
return columns;
}
template <typename Derived>
std::vector<UInt8> IColumn::compareImpl(const Derived & rhs, size_t rhs_row_num, const std::vector<UInt8> & mask, int nan_direction_hint) const
{
size_t rows_num = size();
std::vector<UInt8> results(rows_num, 0);
for (size_t i = 0; i < rows_num; ++i)
{
if (mask[i])
results[i] = compareAt(i, rhs_row_num, rhs, nan_direction_hint);
}
return results;
}
}

View File

@ -44,6 +44,34 @@ bool less(const ColumnRawPtrs & lhs, UInt64 lhs_row_num,
return false;
}
/** Builds a row filter for `lhs` relative to one row of `rhs`, following the sort description.
  *
  * lhs, rhs    - sort-key columns, indexed in the same order as `description`.
  * rhs_row_num - row of `rhs` that every row of `lhs` is compared with.
  * description - sort description; the `direction` of each entry multiplies the comparison result.
  * rows_num    - number of rows in the `lhs` columns.
  *
  * Returns a filter with one entry per row: 1 (keep) when the row sorts before the `rhs` row or
  * ties with it on every key, 0 (drop) when it sorts after it.
  */
IColumn::Filter getFilterMask(const ColumnRawPtrs & lhs, const ColumnRawPtrs & rhs, size_t rhs_row_num, const SortDescription & description, size_t rows_num)
{
    IColumn::Filter filter(rows_num, 1);

    /// mask[j] != 0 while row j has tied on all keys examined so far and is still undecided.
    std::vector<UInt8> mask(rows_num, 1);

    const size_t size = description.size();
    for (size_t i = 0; i < size; ++i)
    {
        const std::vector<UInt8> compare_result = lhs[i]->compareAt(*rhs[i], rhs_row_num, mask, 1);
        const int direction = description[i].direction;

        for (size_t j = 0; j < rows_num; ++j)
        {
            if (!mask[j])
                continue;

            /// The batch compareAt stores an int result in a UInt8, so -1 arrives as 255;
            /// read it back as a signed 8-bit value to recover the sign.
            const int res = direction * static_cast<signed char>(compare_result[j]);
            if (res)
            {
                /// The first key that differs decides the row: keep it only if it sorts before the rhs row.
                filter[j] = (res < 0);
                mask[j] = 0;
            }
        }
    }

    return filter;
}
void PartialSortingTransform::transform(Chunk & chunk)
{
if (read_rows)
@ -60,18 +88,13 @@ void PartialSortingTransform::transform(Chunk & chunk)
*/
if (!threshold_block_columns.empty())
{
IColumn::Filter filter(rows_num, 1);
block_columns = extractColumns(block, description);
size_t filtered_count = 0;
for (UInt64 i = 0; i < rows_num; ++i)
{
if (less(threshold_block_columns, limit - 1, block_columns, i, description))
{
++filtered_count;
filter[i] = 0;
}
}
IColumn::Filter filter = getFilterMask(block_columns, threshold_block_columns, limit - 1, description, rows_num);
for (const auto & item : filter)
filtered_count += !item;
if (filtered_count)
{