Refactor, fix bugs, improve performance

This commit is contained in:
Pavel Kruglov 2021-04-27 15:49:58 +03:00
parent bd415b17d2
commit 400cad4d8b
51 changed files with 476 additions and 323 deletions

View File

@ -1,5 +1,6 @@
#include <Columns/ColumnAggregateFunction.h>
#include <Columns/ColumnsCommon.h>
#include <Columns/MaskOperations.h>
#include <Common/assert_cast.h>
#include <DataStreams/ColumnGathererStream.h>
#include <IO/WriteBufferFromArena.h>
@ -308,6 +309,10 @@ ColumnPtr ColumnAggregateFunction::filter(const Filter & filter, ssize_t result_
return res;
}
/// Expands this column in place according to `mask` by delegating to expandDataByMask
/// on the underlying `data` array. `nullptr` is passed as the value used for
/// positions that expandDataByMask fills with the default; `reverse` is forwarded as is.
void ColumnAggregateFunction::expand(const Filter & mask, bool reverse)
{
expandDataByMask<char *>(data, mask, reverse, nullptr);
}
ColumnPtr ColumnAggregateFunction::permute(const Permutation & perm, size_t limit) const
{

View File

@ -175,7 +175,9 @@ public:
void popBack(size_t n) override;
ColumnPtr filter(const Filter & filter, ssize_t result_size_hint, bool reverse = false) const override;
ColumnPtr filter(const Filter & filter, ssize_t result_size_hint, bool reverse) const override;
void expand(const Filter & mask, bool reverse) override;
ColumnPtr permute(const Permutation & perm, size_t limit) const override;

View File

@ -551,6 +551,11 @@ ColumnPtr ColumnArray::filter(const Filter & filt, ssize_t result_size_hint, boo
return filterGeneric(filt, result_size_hint, reverse);
}
/// Expands this array column in place according to `mask`.
/// Only the offsets (as returned by getOffsets()) are rewritten, via expandOffsetsByMask;
/// `reverse` is forwarded unchanged.
void ColumnArray::expand(const IColumn::Filter & mask, bool reverse)
{
expandOffsetsByMask(getOffsets(), mask, reverse);
}
template <typename T>
ColumnPtr ColumnArray::filterNumber(const Filter & filt, ssize_t result_size_hint, bool reverse) const
{

View File

@ -70,7 +70,8 @@ public:
void insertFrom(const IColumn & src_, size_t n) override;
void insertDefault() override;
void popBack(size_t n) override;
ColumnPtr filter(const Filter & filt, ssize_t result_size_hint, bool revers = false) const override;
ColumnPtr filter(const Filter & filt, ssize_t result_size_hint, bool revers) const override;
void expand(const Filter & mask, bool reverse) override;
ColumnPtr permute(const Permutation & perm, size_t limit) const override;
ColumnPtr index(const IColumn & indexes, size_t limit) const override;
template <typename Type> ColumnPtr indexImpl(const PaddedPODArray<Type> & indexes, size_t limit) const;

View File

@ -90,6 +90,7 @@ public:
void updateWeakHash32(WeakHash32 &) const override { throwMustBeDecompressed(); }
void updateHashFast(SipHash &) const override { throwMustBeDecompressed(); }
ColumnPtr filter(const Filter &, ssize_t, bool) const override { throwMustBeDecompressed(); }
void expand(const Filter &, bool) override { throwMustBeDecompressed(); }
ColumnPtr permute(const Permutation &, size_t) const override { throwMustBeDecompressed(); }
ColumnPtr index(const IColumn &, size_t) const override { throwMustBeDecompressed(); }
int compareAt(size_t, size_t, const IColumn &, int) const override { throwMustBeDecompressed(); }

View File

@ -65,6 +65,24 @@ ColumnPtr ColumnConst::filter(const Filter & filt, ssize_t /*result_size_hint*/,
return ColumnConst::create(data, new_size);
}
/// Grows the constant column to mask.size() rows.
/// The number of mask bytes that select a row (non-zero bytes, or the remaining
/// ones when `reverse` is set) must be exactly the current size `s`;
/// otherwise a LOGICAL_ERROR exception is thrown.
void ColumnConst::expand(const Filter & mask, bool reverse)
{
    const size_t mask_size = mask.size();
    if (mask_size < s)
        throw Exception("Mask size should be no less than data size.", ErrorCodes::LOGICAL_ERROR);

    /// Count the mask positions that correspond to already existing rows.
    const size_t counted_bytes = countBytesInFilter(mask);
    const size_t selected_rows = reverse ? mask_size - counted_bytes : counted_bytes;

    if (selected_rows < s)
        throw Exception("Not enough bytes in mask", ErrorCodes::LOGICAL_ERROR);
    if (selected_rows > s)
        throw Exception("Too many bytes in mask", ErrorCodes::LOGICAL_ERROR);

    /// A constant column stores a single value, so expanding only changes the row count.
    s = mask_size;
}
ColumnPtr ColumnConst::replicate(const Offsets & offsets) const
{
if (s != offsets.size())

View File

@ -180,7 +180,9 @@ public:
data->updateHashFast(hash);
}
ColumnPtr filter(const Filter & filt, ssize_t result_size_hint, bool reverse = false) const override;
ColumnPtr filter(const Filter & filt, ssize_t result_size_hint, bool reverse) const override;
void expand(const Filter & mask, bool reverse) override;
ColumnPtr replicate(const Offsets & offsets) const override;
ColumnPtr permute(const Permutation & perm, size_t limit) const override;
ColumnPtr index(const IColumn & indexes, size_t limit) const override;

View File

@ -15,6 +15,7 @@
#include <Columns/ColumnsCommon.h>
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnCompressed.h>
#include <Columns/MaskOperations.h>
#include <DataStreams/ColumnGathererStream.h>
@ -320,6 +321,12 @@ ColumnPtr ColumnDecimal<T>::filter(const IColumn::Filter & filt, ssize_t result_
return res;
}
/// Expands this decimal column in place according to `mask` by delegating to
/// expandDataByMask on `data`; a value-initialized T() is passed as the default value.
template <typename T>
void ColumnDecimal<T>::expand(const IColumn::Filter & mask, bool reverse)
{
expandDataByMask<T>(data, mask, reverse, T());
}
template <typename T>
ColumnPtr ColumnDecimal<T>::index(const IColumn & indexes, size_t limit) const
{

View File

@ -150,7 +150,9 @@ public:
UInt64 get64(size_t n) const override;
bool isDefaultAt(size_t n) const override { return data[n].value == 0; }
ColumnPtr filter(const IColumn::Filter & filt, ssize_t result_size_hint, bool reverse = false) const override;
ColumnPtr filter(const IColumn::Filter & filt, ssize_t result_size_hint, bool reverse) const override;
void expand(const IColumn::Filter & mask, bool reverse) override;
ColumnPtr permute(const IColumn::Permutation & perm, size_t limit) const override;
ColumnPtr index(const IColumn & indexes, size_t limit) const override;

View File

@ -344,6 +344,31 @@ ColumnPtr ColumnFixedString::filter(const IColumn::Filter & filt, ssize_t result
return res;
}
/// Expands the column in place so that it has mask.size() rows.
///
/// Existing rows are moved, in order, to the positions selected by the mask
/// (non-zero bytes, or zero bytes when `reverse` is true); every other position
/// is filled with `n` zero bytes.
///
/// Throws LOGICAL_ERROR if the mask is shorter than the column or if the number of
/// selected positions differs from the current number of rows.
void ColumnFixedString::expand(const IColumn::Filter & mask, bool reverse)
{
    if (mask.size() < size())
        throw Exception("Mask size should be no less than data size.", ErrorCodes::LOGICAL_ERROR);

    /// Signed, pointer-sized indices: both run down to -1 and must not be truncated for large columns.
    /// `from` has to be captured before the storage is resized.
    ssize_t index = static_cast<ssize_t>(mask.size()) - 1;
    ssize_t from = static_cast<ssize_t>(size()) - 1;

    /// Make room for the expanded rows before anything is written past the old end.
    chars.resize(mask.size() * n);

    /// Walk backwards so that a row is never overwritten before it has been moved (index >= from always holds).
    while (index >= 0)
    {
        /// The mask may contain arbitrary non-zero bytes, so normalize to bool before comparing with `reverse`.
        const bool selected = (mask[index] != 0) != reverse;
        if (selected)
        {
            if (from < 0)
                throw Exception("Too many bytes in mask", ErrorCodes::LOGICAL_ERROR);

            /// Exact-size copy from the old position to the new one. An overflowing copy must not be used here:
            /// the bytes right after the destination belong to row index + 1, which is already in place.
            if (index != from)
                memcpy(&chars[index * n], &chars[from * n], n);
            --from;
        }
        else
        {
            /// Unselected position: default (all-zero) value.
            memset(&chars[index * n], 0, n);
        }

        --index;
    }

    if (from != -1)
        throw Exception("Not enough bytes in mask", ErrorCodes::LOGICAL_ERROR);
}
ColumnPtr ColumnFixedString::permute(const Permutation & perm, size_t limit) const
{
size_t col_size = size();

View File

@ -145,7 +145,9 @@ public:
void insertRangeFrom(const IColumn & src, size_t start, size_t length) override;
ColumnPtr filter(const IColumn::Filter & filt, ssize_t result_size_hint, bool reverse = false) const override;
ColumnPtr filter(const IColumn::Filter & filt, ssize_t result_size_hint, bool reverse) const override;
void expand(const IColumn::Filter & mask, bool reverse) override;
ColumnPtr permute(const Permutation & perm, size_t limit) const override;

View File

@ -77,6 +77,17 @@ ColumnPtr ColumnFunction::filter(const Filter & filt, ssize_t result_size_hint,
return ColumnFunction::create(filtered_size, function, capture);
}
/// Expands every captured argument column according to `mask` and sets the
/// row count of this function column to mask.size().
void ColumnFunction::expand(const Filter & mask, bool reverse)
{
for (auto & column : captured_columns)
{
/// Replace the captured column with a same-sized clone and expand the clone in place.
/// NOTE(review): presumably done so that a column shared with other owners is not modified — confirm.
column.column = column.column->cloneResized(column.column->size());
column.column->assumeMutable()->expand(mask, reverse);
}
size_ = mask.size();
}
ColumnPtr ColumnFunction::permute(const Permutation & perm, size_t limit) const
{
if (limit == 0)
@ -194,8 +205,6 @@ void ColumnFunction::appendArgument(const ColumnWithTypeAndName & column)
ColumnWithTypeAndName ColumnFunction::reduce(bool reduce_arguments) const
{
// LOG_DEBUG(&Poco::Logger::get("ColumnFunction"), "Reduce function: {}", function->getName());
auto args = function->getArgumentTypes().size();
auto captured = captured_columns.size();
@ -203,22 +212,17 @@ ColumnWithTypeAndName ColumnFunction::reduce(bool reduce_arguments) const
throw Exception("Cannot call function " + function->getName() + " because is has " + toString(args) +
"arguments but " + toString(captured) + " columns were captured.", ErrorCodes::LOGICAL_ERROR);
ColumnsWithTypeAndName columns;
if (reduce_arguments)
ColumnsWithTypeAndName columns = captured_columns;
if (function->isShortCircuit())
function->executeShortCircuitArguments(columns);
else if (reduce_arguments)
{
columns.reserve(captured_columns.size());
for (const auto & col : captured_columns)
for (auto & col : columns)
{
// LOG_DEBUG(&Poco::Logger::get("ColumnFunction"), "Arg type: {}", col.type->getName());
if (const auto * column_function = typeid_cast<const ColumnFunction *>(col.column.get()))
columns.push_back(column_function->reduce(true));
else
columns.push_back(col);
col = column_function->reduce(true);
}
}
else
columns = captured_columns;
ColumnWithTypeAndName res{nullptr, function->getResultType(), ""};

View File

@ -37,7 +37,8 @@ public:
ColumnPtr cut(size_t start, size_t length) const override;
ColumnPtr replicate(const Offsets & offsets) const override;
ColumnPtr filter(const Filter & filt, ssize_t result_size_hint, bool reverse = false) const override;
ColumnPtr filter(const Filter & filt, ssize_t result_size_hint, bool reverse) const override;
void expand(const Filter & mask, bool reverse) override;
ColumnPtr permute(const Permutation & perm, size_t limit) const override;
ColumnPtr index(const IColumn & indexes, size_t limit) const override;
@ -157,6 +158,7 @@ private:
size_t size_;
FunctionBasePtr function;
ColumnsWithTypeAndName captured_columns;
bool is_short_circuit_argumentz;
void appendArgument(const ColumnWithTypeAndName & column);
};

View File

@ -105,11 +105,16 @@ public:
void updateHashFast(SipHash &) const override;
ColumnPtr filter(const Filter & filt, ssize_t result_size_hint, bool reverse = false) const override
ColumnPtr filter(const Filter & filt, ssize_t result_size_hint, bool reverse) const override
{
return ColumnLowCardinality::create(dictionary.getColumnUniquePtr(), getIndexes().filter(filt, result_size_hint, reverse));
}
/// Expands the column by expanding only the positions column held by `idx`;
/// `mask` and `reverse` are forwarded unchanged.
void expand(const Filter & mask, bool reverse) override
{
idx.getPositionsPtr()->expand(mask, reverse);
}
ColumnPtr permute(const Permutation & perm, size_t limit) const override
{
return ColumnLowCardinality::create(dictionary.getColumnUniquePtr(), getIndexes().permute(perm, limit));

View File

@ -149,6 +149,11 @@ ColumnPtr ColumnMap::filter(const Filter & filt, ssize_t result_size_hint, bool
return ColumnMap::create(filtered);
}
/// Expands the map column by forwarding `mask` and `reverse` to the `nested` column.
void ColumnMap::expand(const IColumn::Filter & mask, bool reverse)
{
nested->expand(mask, reverse);
}
ColumnPtr ColumnMap::permute(const Permutation & perm, size_t limit) const
{
auto permuted = nested->permute(perm, limit);

View File

@ -63,7 +63,8 @@ public:
void updateWeakHash32(WeakHash32 & hash) const override;
void updateHashFast(SipHash & hash) const override;
void insertRangeFrom(const IColumn & src, size_t start, size_t length) override;
ColumnPtr filter(const Filter & filt, ssize_t result_size_hint, bool reverse = false) const override;
ColumnPtr filter(const Filter & filt, ssize_t result_size_hint, bool reverse) const override;
void expand(const Filter & mask, bool reverse) override;
ColumnPtr permute(const Permutation & perm, size_t limit) const override;
ColumnPtr index(const IColumn & indexes, size_t limit) const override;
ColumnPtr replicate(const Offsets & offsets) const override;

View File

@ -221,6 +221,12 @@ ColumnPtr ColumnNullable::filter(const Filter & filt, ssize_t result_size_hint,
return ColumnNullable::create(filtered_data, filtered_null_map);
}
/// Expands both parts of the nullable column with the same `mask` and `reverse` flag:
/// first the nested data column, then the null map.
void ColumnNullable::expand(const IColumn::Filter & mask, bool reverse)
{
nested_column->expand(mask, reverse);
null_map->expand(mask, reverse);
}
ColumnPtr ColumnNullable::permute(const Permutation & perm, size_t limit) const
{
ColumnPtr permuted_data = getNestedColumn().permute(perm, limit);

View File

@ -87,7 +87,8 @@ public:
}
void popBack(size_t n) override;
ColumnPtr filter(const Filter & filt, ssize_t result_size_hint, bool reverse = false) const override;
ColumnPtr filter(const Filter & filt, ssize_t result_size_hint, bool reverse) const override;
void expand(const Filter & mask, bool reverse) override;
ColumnPtr permute(const Permutation & perm, size_t limit) const override;
ColumnPtr index(const IColumn & indexes, size_t limit) const override;
int compareAt(size_t n, size_t m, const IColumn & rhs_, int null_direction_hint) const override;

View File

@ -157,6 +157,11 @@ ColumnPtr ColumnString::filter(const Filter & filt, ssize_t result_size_hint, bo
return res;
}
/// Expands the string column according to `mask` by rewriting `offsets` via expandOffsetsByMask.
/// NOTE(review): only the offsets are touched here, the character data is left as is.
/// If every row is expected to own a terminating byte, rows inserted by the expansion
/// would have zero length — confirm this is acceptable for callers.
void ColumnString::expand(const IColumn::Filter & mask, bool reverse)
{
expandOffsetsByMask(offsets, mask, reverse);
}
ColumnPtr ColumnString::permute(const Permutation & perm, size_t limit) const
{

View File

@ -210,7 +210,9 @@ public:
void insertRangeFrom(const IColumn & src, size_t start, size_t length) override;
ColumnPtr filter(const Filter & filt, ssize_t result_size_hint, bool reverse = false) const override;
ColumnPtr filter(const Filter & filt, ssize_t result_size_hint, bool reverse) const override;
void expand(const Filter & mask, bool reverse) override;
ColumnPtr permute(const Permutation & perm, size_t limit) const override;

View File

@ -232,6 +232,12 @@ ColumnPtr ColumnTuple::filter(const Filter & filt, ssize_t result_size_hint, boo
return ColumnTuple::create(new_columns);
}
/// Expands the tuple column by expanding each element column with the same `mask` and `reverse` flag.
void ColumnTuple::expand(const Filter & mask, bool reverse)
{
for (auto & column : columns)
column->expand(mask, reverse);
}
ColumnPtr ColumnTuple::permute(const Permutation & perm, size_t limit) const
{
const size_t tuple_size = columns.size();

View File

@ -66,7 +66,8 @@ public:
void updateWeakHash32(WeakHash32 & hash) const override;
void updateHashFast(SipHash & hash) const override;
void insertRangeFrom(const IColumn & src, size_t start, size_t length) override;
ColumnPtr filter(const Filter & filt, ssize_t result_size_hint, bool reverse = false) const override;
ColumnPtr filter(const Filter & filt, ssize_t result_size_hint, bool reverse) const override;
void expand(const Filter & mask, bool reverse) override;
ColumnPtr permute(const Permutation & perm, size_t limit) const override;
ColumnPtr index(const IColumn & indexes, size_t limit) const override;
ColumnPtr replicate(const Offsets & offsets) const override;

View File

@ -3,6 +3,7 @@
#include <pdqsort.h>
#include <Columns/ColumnsCommon.h>
#include <Columns/ColumnCompressed.h>
#include <Columns/MaskOperations.h>
#include <DataStreams/ColumnGathererStream.h>
#include <IO/WriteHelpers.h>
#include <Common/Arena.h>
@ -408,6 +409,12 @@ ColumnPtr ColumnVector<T>::filter(const IColumn::Filter & filt, ssize_t result_s
return res;
}
/// Expands this numeric column in place according to `mask` by delegating to
/// expandDataByMask on `data`; a value-initialized T() is passed as the default value.
template <typename T>
void ColumnVector<T>::expand(const IColumn::Filter & mask, bool reverse)
{
expandDataByMask<T>(data, mask, reverse, T());
}
template <typename T>
void ColumnVector<T>::applyZeroMap(const IColumn::Filter & filt, bool inverted)
{

View File

@ -239,6 +239,7 @@ public:
return data[n];
}
void get(size_t n, Field & res) const override
{
res = (*this)[n];
@ -282,7 +283,9 @@ public:
void insertRangeFrom(const IColumn & src, size_t start, size_t length) override;
ColumnPtr filter(const IColumn::Filter & filt, ssize_t result_size_hint, bool reverse = false) const override;
ColumnPtr filter(const IColumn::Filter & filt, ssize_t result_size_hint, bool reverse) const override;
void expand(const IColumn::Filter & mask, bool reverse) override;
ColumnPtr permute(const IColumn::Permutation & perm, size_t limit) const override;

View File

@ -41,6 +41,11 @@ void filterArraysImplOnlyData(
PaddedPODArray<T> & res_elems,
const IColumn::Filter & filt, ssize_t result_size_hint, bool reverse = false);
template <typename Container, typename T>
void expandDataByMask(Container & data, const PaddedPODArray<UInt8> & mask, bool reverse, T default_value);
void expandOffsetsByMask(PaddedPODArray<UInt64> & offsets, const PaddedPODArray<UInt8> & mask, bool reverse);
namespace detail
{
template <typename T>
@ -70,6 +75,7 @@ ColumnPtr selectIndexImpl(const Column & column, const IColumn & indexes, size_t
ErrorCodes::LOGICAL_ERROR);
}
#define INSTANTIATE_INDEX_IMPL(Column) \
template ColumnPtr Column::indexImpl<UInt8>(const PaddedPODArray<UInt8> & indexes, size_t limit) const; \
template ColumnPtr Column::indexImpl<UInt16>(const PaddedPODArray<UInt16> & indexes, size_t limit) const; \

View File

@ -236,6 +236,11 @@ public:
using Filter = PaddedPODArray<UInt8>;
virtual Ptr filter(const Filter & filt, ssize_t result_size_hint, bool reverse = false) const = 0;
/// Expands the column in place according to a mask and a reverse flag.
/// The base implementation always throws NOT_IMPLEMENTED; column types that
/// support expansion override it.
virtual void expand(const Filter &, bool)
{
throw Exception("expand function is not implemented", ErrorCodes::NOT_IMPLEMENTED);
}
/// Permutes elements using specified permutation. Is used in sorting.
/// limit - if it isn't 0, puts only first limit elements in the result.
using Permutation = PaddedPODArray<size_t>;

View File

@ -106,6 +106,14 @@ public:
return cloneDummy(bytes);
}
/// Expands the dummy column to mask.size() rows.
/// A dummy column only tracks its row count `s`, so expanding amounts to adopting
/// the size of the mask, exactly as ColumnConst::expand does. Setting `s` to the number
/// of selected mask bytes (the previous behaviour) would leave the size unchanged for
/// any valid mask, i.e. the column would never actually be expanded.
/// The reverse flag does not influence the resulting size and is therefore unused.
void expand(const IColumn::Filter & mask, bool /*reverse*/) override
{
    s = mask.size();
}
ColumnPtr permute(const Permutation & perm, size_t limit) const override
{
if (s != perm.size())

View File

@ -139,6 +139,11 @@ public:
throw Exception("Method filter is not supported for ColumnUnique.", ErrorCodes::NOT_IMPLEMENTED);
}
/// Expansion is not supported for ColumnUnique: always throws NOT_IMPLEMENTED.
void expand(const IColumn::Filter &, bool) override
{
throw Exception("Method expand is not supported for ColumnUnique.", ErrorCodes::NOT_IMPLEMENTED);
}
ColumnPtr permute(const IColumn::Permutation &, size_t) const override
{
throw Exception("Method permute is not supported for ColumnUnique.", ErrorCodes::NOT_IMPLEMENTED);

View File

@ -4,42 +4,142 @@
#include <Columns/ColumnConst.h>
#include <Columns/ColumnNothing.h>
#include <Columns/ColumnLowCardinality.h>
#include <common/logger_useful.h>
#include <Columns/ColumnsCommon.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int LOGICAL_ERROR;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
void expandColumnByMask(ColumnPtr & column, const PaddedPODArray<UInt8>& mask, Field * field, bool reverse)
template <typename T>
void expandDataByMask(PaddedPODArray<T> & data, const PaddedPODArray<UInt8> & mask, bool reverse, T default_value)
{
MutableColumnPtr res = column->cloneEmpty();
res->reserve(mask.size());
size_t index = 0;
for (size_t i = 0; i != mask.size(); ++i)
{
if (reverse ^ mask[i])
{
if (index >= column->size())
throw Exception("Too many bits in mask", ErrorCodes::LOGICAL_ERROR);
if (mask.size() < data.size())
throw Exception("Mask size should be no less than data size.", ErrorCodes::LOGICAL_ERROR);
res->insert((*column)[index]);
++index;
int from = data.size() - 1;
int index = mask.size() - 1;
data.resize(mask.size());
while (index >= 0)
{
if (mask[index] ^ reverse)
{
if (from < 0)
throw Exception("Too many bytes in mask", ErrorCodes::LOGICAL_ERROR);
data[index] = data[from];
--from;
}
else if (field)
res->insert(*field);
else
res->insertDefault();
data[index] = default_value;
--index;
}
if (index < column->size())
throw Exception("Too less bits in mask", ErrorCodes::LOGICAL_ERROR);
column = std::move(res);
if (from != -1)
throw Exception("Not enough bytes in mask", ErrorCodes::LOGICAL_ERROR);
}
/// Explicit instantiations - not to place the implementation of the function above in the header file.
#define INSTANTIATE(TYPE) \
template void expandDataByMask<TYPE>(PaddedPODArray<TYPE> &, const PaddedPODArray<UInt8> &, bool, TYPE);
INSTANTIATE(UInt8)
INSTANTIATE(UInt16)
INSTANTIATE(UInt32)
INSTANTIATE(UInt64)
INSTANTIATE(UInt128)
INSTANTIATE(UInt256)
INSTANTIATE(Int8)
INSTANTIATE(Int16)
INSTANTIATE(Int32)
INSTANTIATE(Int64)
INSTANTIATE(Int128)
INSTANTIATE(Int256)
INSTANTIATE(Float32)
INSTANTIATE(Float64)
INSTANTIATE(Decimal32)
INSTANTIATE(Decimal64)
INSTANTIATE(Decimal128)
INSTANTIATE(Decimal256)
INSTANTIATE(DateTime64)
INSTANTIATE(char *)
#undef INSTANTIATE
/// Expands an array of cumulative offsets in place so that it has mask.size() entries.
///
/// Existing offsets are moved, in order, to the positions selected by the mask
/// (non-zero bytes, or zero bytes when `reverse` is true). Every other position
/// receives the offset of the closest preceding existing row (0 if there is none),
/// i.e. it describes a zero-length row.
///
/// Throws LOGICAL_ERROR if the mask is shorter than `offsets` or if the number of
/// selected positions differs from the original number of offsets.
void expandOffsetsByMask(PaddedPODArray<UInt64> & offsets, const PaddedPODArray<UInt8> & mask, bool reverse)
{
    if (mask.size() < offsets.size())
        throw Exception("Mask size should be no less than data size.", ErrorCodes::LOGICAL_ERROR);

    /// Signed, pointer-sized indices: both run down to -1 and must not be truncated for large arrays.
    ssize_t index = static_cast<ssize_t>(mask.size()) - 1;
    ssize_t from = static_cast<ssize_t>(offsets.size()) - 1;
    offsets.resize(mask.size());

    /// Offset to use for unselected positions; never index the array with -1.
    UInt64 prev_offset = from >= 0 ? offsets[from] : 0;

    /// Walk backwards so that an offset is never overwritten before it has been moved (index >= from always holds).
    while (index >= 0)
    {
        /// The mask may contain arbitrary non-zero bytes, so normalize to bool before comparing with `reverse`.
        const bool selected = (mask[index] != 0) != reverse;
        if (selected)
        {
            if (from < 0)
                throw Exception("Too many bytes in mask", ErrorCodes::LOGICAL_ERROR);

            offsets[index] = offsets[from];
            --from;
            prev_offset = from >= 0 ? offsets[from] : 0;
        }
        else
            offsets[index] = prev_offset;

        --index;
    }

    if (from != -1)
        throw Exception("Not enough bytes in mask", ErrorCodes::LOGICAL_ERROR);
}
/// Tries to treat `column` as ColumnVector<ValueType> and, on success, expands its data
/// in place with expandDataByMask, using `default_value_for_expanding_mask`
/// (converted to ValueType) as the default value.
/// Returns true if the column had the requested type, false otherwise (column untouched).
template <typename ValueType>
bool tryExpandMaskColumnByMask(const ColumnPtr & column, const PaddedPODArray<UInt8> & mask, bool reverse, UInt8 default_value_for_expanding_mask)
{
if (const auto * col = checkAndGetColumn<ColumnVector<ValueType>>(*column))
{
/// The column is reached through a const pointer, so constness is cast away to modify its data in place.
expandDataByMask<ValueType>(const_cast<ColumnVector<ValueType> *>(col)->getData(), mask, reverse, default_value_for_expanding_mask);
return true;
}
return false;
}
/// Expands a column that is itself used as a mask, in place, according to `mask`.
/// For a nullable column the null map is expanded with default 0 and the nested column
/// with `default_value_for_expanding_mask`. Otherwise the column must be a ColumnVector of
/// one of the numeric types tried below; if none matches, ILLEGAL_TYPE_OF_ARGUMENT is thrown.
void expandMaskColumnByMask(const ColumnPtr & column, const PaddedPODArray<UInt8>& mask, bool reverse, UInt8 default_value_for_expanding_mask)
{
if (const auto * col = checkAndGetColumn<ColumnNullable>(column.get()))
{
expandMaskColumnByMask(col->getNullMapColumnPtr(), mask, reverse, 0);
expandMaskColumnByMask(col->getNestedColumnPtr(), mask, reverse, default_value_for_expanding_mask);
return;
}
/// The && chain stops at the first type that matches, so the column is expanded at most once.
if (!tryExpandMaskColumnByMask<Int8>(column, mask, reverse, default_value_for_expanding_mask) &&
!tryExpandMaskColumnByMask<Int16>(column, mask, reverse, default_value_for_expanding_mask) &&
!tryExpandMaskColumnByMask<Int32>(column, mask, reverse, default_value_for_expanding_mask) &&
!tryExpandMaskColumnByMask<Int64>(column, mask, reverse, default_value_for_expanding_mask) &&
!tryExpandMaskColumnByMask<UInt8>(column, mask, reverse, default_value_for_expanding_mask) &&
!tryExpandMaskColumnByMask<UInt16>(column, mask, reverse, default_value_for_expanding_mask) &&
!tryExpandMaskColumnByMask<UInt32>(column, mask, reverse, default_value_for_expanding_mask) &&
!tryExpandMaskColumnByMask<UInt64>(column, mask, reverse, default_value_for_expanding_mask) &&
!tryExpandMaskColumnByMask<Float32>(column, mask, reverse, default_value_for_expanding_mask) &&
!tryExpandMaskColumnByMask<Float64>(column, mask, reverse, default_value_for_expanding_mask))
throw Exception("Cannot convert column " + column.get()->getName() + " to mask", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
/// Expands `column` in place according to `mask` by calling IColumn::expand on it.
/// The column is accessed through assumeMutable(), i.e. it is modified even though
/// it is passed as a const ColumnPtr reference.
void expandColumnByMask(const ColumnPtr & column, const PaddedPODArray<UInt8>& mask, bool reverse)
{
column->assumeMutable()->expand(mask, reverse);
}
template <typename ValueType>
@ -125,7 +225,7 @@ void disjunctionMasks(PaddedPODArray<UInt8> & mask1, const PaddedPODArray<UInt8>
binaryMasksOperationImpl(mask1, mask2, [](const auto & lhs, const auto & rhs){ return lhs | rhs; });
}
void maskedExecute(ColumnWithTypeAndName & column, const PaddedPODArray<UInt8> & mask, Field * default_value, bool reverse)
void maskedExecute(ColumnWithTypeAndName & column, const PaddedPODArray<UInt8> & mask, bool reverse, const UInt8 * default_value_for_expanding_mask)
{
const auto * column_function = checkAndGetColumn<ColumnFunction>(*column.column);
if (!column_function)
@ -133,7 +233,13 @@ void maskedExecute(ColumnWithTypeAndName & column, const PaddedPODArray<UInt8> &
auto filtered = column_function->filter(mask, -1, reverse);
auto result = typeid_cast<const ColumnFunction *>(filtered.get())->reduce(true);
expandColumnByMask(result.column, mask, default_value, reverse);
if (default_value_for_expanding_mask)
{
result.column = result.column->convertToFullColumnIfLowCardinality();
expandMaskColumnByMask(result.column, mask, reverse, *default_value_for_expanding_mask);
}
else
expandColumnByMask(result.column, mask, reverse);
column = std::move(result);
}
@ -146,4 +252,14 @@ void executeColumnIfNeeded(ColumnWithTypeAndName & column)
column = typeid_cast<const ColumnFunction *>(column_function)->reduce(true);
}
/// Returns true if at least one of `arguments` holds a ColumnFunction, false otherwise.
bool checkArgumentsForColumnFunction(const ColumnsWithTypeAndName & arguments)
{
    for (const auto & argument : arguments)
        if (checkAndGetColumn<ColumnFunction>(*argument.column) != nullptr)
            return true;

    return false;
}
}

View File

@ -1,20 +1,32 @@
#pragma once
#include <Core/ColumnWithTypeAndName.h>
#include <Core/ColumnsWithTypeAndName.h>
#include <Core/Field.h>
#include <Common/PODArray.h>
namespace DB
{
template <typename T>
void expandDataByMask(PaddedPODArray<T> & data, const PaddedPODArray<UInt8> & mask, bool reverse, T default_value);
void expandOffsetsByMask(PaddedPODArray<UInt64> & offsets, const PaddedPODArray<UInt8> & mask, bool reverse);
void getMaskFromColumn(const ColumnPtr & column, PaddedPODArray<UInt8> & mask, bool reverse = false, const PaddedPODArray<UInt8> * null_bytemap = nullptr, UInt8 null_value = 1);
void conjunctionMasks(PaddedPODArray<UInt8> & mask1, const PaddedPODArray<UInt8> & mask2);
void disjunctionMasks(PaddedPODArray<UInt8> & mask1, const PaddedPODArray<UInt8> & mask2);
void maskedExecute(ColumnWithTypeAndName & column, const PaddedPODArray<UInt8> & mask, Field * default_value = nullptr, bool reverse = false);
void maskedExecute(ColumnWithTypeAndName & column, const PaddedPODArray<UInt8> & mask, bool reverse = false, const UInt8 * default_value_for_expanding_mask = nullptr);
void expandColumnByMask(const ColumnPtr & column, const PaddedPODArray<UInt8>& mask, bool reverse);
void expandMaskColumnByMask(const ColumnPtr & column, const PaddedPODArray<UInt8>& mask, bool reverse, UInt8 default_value = 0);
void executeColumnIfNeeded(ColumnWithTypeAndName & column);
bool checkArgumentsForColumnFunction(const ColumnsWithTypeAndName & arguments);
}

View File

@ -955,6 +955,8 @@ public:
size_t getNumberOfArguments() const override { return 2; }
bool isSuitableForShortCircuitArgumentsExecution() const override { return IsOperation<Op>::can_throw; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
return getReturnTypeImplStatic(arguments, context);

View File

@ -20,6 +20,8 @@ private:
size_t getNumberOfArguments() const override { return 0; }
bool isSuitableForShortCircuitArgumentsExecution() const override { return false; }
DataTypePtr getReturnTypeImpl(const DataTypes & /*arguments*/) const override
{
return std::make_shared<DataTypeFloat64>();

View File

@ -41,6 +41,8 @@ private:
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 1; }
bool isSuitableForShortCircuitArgumentsExecution() const override { return false; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
const auto & arg = arguments.front();

View File

@ -120,6 +120,7 @@ public:
size_t getNumberOfArguments() const override { return 1; }
bool isInjective(const ColumnsWithTypeAndName &) const override { return is_injective; }
bool isSuitableForShortCircuitArgumentsExecution() const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }

View File

@ -1071,6 +1071,8 @@ public:
size_t getNumberOfArguments() const override { return 2; }
bool isSuitableForShortCircuitArgumentsExecution() const override { return false; }
/// Get result types by argument types. If the function does not apply to these arguments, throw an exception.
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{

View File

@ -515,9 +515,12 @@ void FunctionAnyArityLogical<Impl, Name>::executeShortCircuitArguments(ColumnsWi
if (Name::name != NameAnd::name && Name::name != NameOr::name)
throw Exception("Function " + getName() + " doesn't support short circuit execution", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
Field default_value = Name::name == NameAnd::name ? 0 : 1;
if (!checkArgumentsForColumnFunction(arguments))
return;
bool reverse = Name::name != NameAnd::name;
UInt8 null_value = Name::name == NameAnd::name ? 1 : 0;
UInt8 value_for_mask_expanding = Name::name == NameAnd::name ? 0 : 1;
executeColumnIfNeeded(arguments[0]);
IColumn::Filter mask;
getMaskFromColumn(arguments[0].column, mask, reverse, nullptr, null_value);
@ -525,7 +528,8 @@ void FunctionAnyArityLogical<Impl, Name>::executeShortCircuitArguments(ColumnsWi
for (size_t i = 1; i < arguments.size(); ++i)
{
if (isColumnFunction(*arguments[i].column))
maskedExecute(arguments[i], mask, &default_value, false);
maskedExecute(arguments[i], mask, false, &value_for_mask_expanding);
getMaskFromColumn(arguments[i].column, mask, reverse, nullptr, null_value);
}
}

View File

@ -29,13 +29,14 @@
* Functions AND and OR provide their own special implementations for ternary logic
*/
namespace DB
{
struct NameAnd { static constexpr auto name = "and"; };
struct NameOr { static constexpr auto name = "or"; };
struct NameXor { static constexpr auto name = "xor"; };
struct NameNot { static constexpr auto name = "not"; };
namespace DB
{
namespace FunctionsLogicalDetail
{
namespace Ternary

View File

@ -60,6 +60,7 @@ public:
bool isDeterministic() const override { return false; }
bool isDeterministicInScopeOfQuery() const override { return false; }
bool useDefaultImplementationForNulls() const override { return false; }
bool isSuitableForShortCircuitArgumentsExecution() const override { return false; }
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }

View File

@ -213,6 +213,8 @@ public:
virtual bool isShortCircuit() const { return false; }
virtual bool isSuitableForShortCircuitArgumentsExecution() const { return true; }
virtual void executeShortCircuitArguments(ColumnsWithTypeAndName & /*arguments*/) const
{
throw Exception("Function " + getName() + " doesn't support short circuit execution", ErrorCodes::NOT_IMPLEMENTED);
@ -277,6 +279,8 @@ public:
virtual bool isShortCircuit() const { return false; }
virtual bool isSuitableForShortCircuitArgumentsExecution() const { return true; }
virtual void executeShortCircuitArguments(ColumnsWithTypeAndName & /*arguments*/) const
{
throw Exception("Function " + getName() + " doesn't support short circuit execution", ErrorCodes::NOT_IMPLEMENTED);

View File

@ -10,154 +10,7 @@ namespace DB
class FunctionToExecutableFunctionAdaptor final : public IExecutableFunction
{
public:
<<<<<<< HEAD
explicit FunctionToExecutableFunctionAdaptor(std::shared_ptr<IFunction> function_) : function(std::move(function_)) {}
=======
explicit ExecutableFunctionAdaptor(ExecutableFunctionImplPtr impl_) : impl(std::move(impl_)) {}
String getName() const final { return impl->getName(); }
ColumnPtr execute(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count, bool dry_run) const final;
void createLowCardinalityResultCache(size_t cache_size) override;
private:
ExecutableFunctionImplPtr impl;
/// Cache is created by function createLowCardinalityResultCache()
ExecutableFunctionLowCardinalityResultCachePtr low_cardinality_result_cache;
ColumnPtr defaultImplementationForConstantArguments(
const ColumnsWithTypeAndName & args, const DataTypePtr & result_type, size_t input_rows_count, bool dry_run) const;
ColumnPtr defaultImplementationForNulls(
const ColumnsWithTypeAndName & args, const DataTypePtr & result_type, size_t input_rows_count, bool dry_run) const;
ColumnPtr executeWithoutLowCardinalityColumns(
const ColumnsWithTypeAndName & args, const DataTypePtr & result_type, size_t input_rows_count, bool dry_run) const;
};
/// Adaptor that exposes an IFunctionBaseImpl object through the IFunctionBase interface.
/// Every member simply forwards the call to the wrapped `impl`; the only extra work is in
/// prepare(), which wraps the prepared implementation into an ExecutableFunctionAdaptor.
class FunctionBaseAdaptor final : public IFunctionBase
{
public:
/// Takes ownership of (a reference to) the implementation object.
explicit FunctionBaseAdaptor(FunctionBaseImplPtr impl_) : impl(std::move(impl_)) {}
String getName() const final { return impl->getName(); }
const DataTypes & getArgumentTypes() const final { return impl->getArgumentTypes(); }
const DataTypePtr & getResultType() const final { return impl->getResultType(); }
/// Asks the implementation to prepare itself for `arguments` and wraps the result
/// into an ExecutableFunctionAdaptor.
ExecutableFunctionPtr prepare(const ColumnsWithTypeAndName & arguments) const final
{
return std::make_shared<ExecutableFunctionAdaptor>(impl->prepare(arguments));
}
/// The compilation-related members exist only when USE_EMBEDDED_COMPILER is enabled.
#if USE_EMBEDDED_COMPILER
bool isCompilable() const final { return impl->isCompilable(); }
/// `values` is taken by value and moved into the implementation's compile().
llvm::Value * compile(llvm::IRBuilderBase & builder, Values values) const override
{
return impl->compile(builder, std::move(values));
}
#endif
bool isStateful() const final { return impl->isStateful(); }
bool isSuitableForConstantFolding() const final { return impl->isSuitableForConstantFolding(); }
ColumnPtr getResultIfAlwaysReturnsConstantAndHasArguments(const ColumnsWithTypeAndName & arguments) const final
{
return impl->getResultIfAlwaysReturnsConstantAndHasArguments(arguments);
}
bool isInjective(const ColumnsWithTypeAndName & sample_columns) const final { return impl->isInjective(sample_columns); }
bool isDeterministic() const final { return impl->isDeterministic(); }
bool isDeterministicInScopeOfQuery() const final { return impl->isDeterministicInScopeOfQuery(); }
bool isShortCircuit() const final { return impl->isShortCircuit(); }
/// Forwards to the implementation; `arguments` is passed by non-const reference,
/// so the implementation may modify it in place.
void executeShortCircuitArguments(ColumnsWithTypeAndName & arguments) const override
{
impl->executeShortCircuitArguments(arguments);
}
bool hasInformationAboutMonotonicity() const final { return impl->hasInformationAboutMonotonicity(); }
Monotonicity getMonotonicityForRange(const IDataType & type, const Field & left, const Field & right) const final
{
return impl->getMonotonicityForRange(type, left, right);
}
/// Non-owning access to the wrapped implementation (not part of IFunctionBase).
const IFunctionBaseImpl * getImpl() const { return impl.get(); }
private:
/// The wrapped implementation all calls are delegated to.
FunctionBaseImplPtr impl;
};
/// Adaptor that exposes an overload-resolver implementation (`impl`) through the
/// IFunctionOverloadResolver interface. Property getters forward to `impl`; building a
/// function additionally computes the return type via getReturnType() and wraps the
/// resulting implementation into a FunctionBaseAdaptor.
class FunctionOverloadResolverAdaptor final : public IFunctionOverloadResolver
{
public:
explicit FunctionOverloadResolverAdaptor(FunctionOverloadResolverImplPtr impl_) : impl(std::move(impl_)) {}
String getName() const final { return impl->getName(); }
bool isDeterministic() const final { return impl->isDeterministic(); }
bool isDeterministicInScopeOfQuery() const final { return impl->isDeterministicInScopeOfQuery(); }
bool isInjective(const ColumnsWithTypeAndName & columns) const final { return impl->isInjective(columns); }
bool isStateful() const final { return impl->isStateful(); }
bool isVariadic() const final { return impl->isVariadic(); }
bool isShortCircuit() const final { return impl->isShortCircuit(); }
size_t getNumberOfArguments() const final { return impl->getNumberOfArguments(); }
/// Declared here, defined outside this class body.
void checkNumberOfArguments(size_t number_of_arguments) const final;
/// Builds the function-base implementation for `arguments`, passing the return type
/// computed by getReturnType(arguments) to the wrapped resolver.
FunctionBaseImplPtr buildImpl(const ColumnsWithTypeAndName & arguments) const
{
return impl->build(arguments, getReturnType(arguments));
}
/// Same as buildImpl(), but the result is wrapped into a FunctionBaseAdaptor.
FunctionBasePtr build(const ColumnsWithTypeAndName & arguments) const final
{
return std::make_shared<FunctionBaseAdaptor>(buildImpl(arguments));
}
/// Checks the argument count first, then lets the implementation process `arguments`
/// (passed by non-const reference).
void getLambdaArgumentTypes(DataTypes & arguments) const final
{
checkNumberOfArguments(arguments.size());
impl->getLambdaArgumentTypes(arguments);
}
ColumnNumbers getArgumentsThatAreAlwaysConstant() const final { return impl->getArgumentsThatAreAlwaysConstant(); }
ColumnNumbers getArgumentsThatDontImplyNullableReturnType(size_t number_of_arguments) const final
{
return impl->getArgumentsThatDontImplyNullableReturnType(number_of_arguments);
}
/// Callback type: computes a return type from the given arguments.
using DefaultReturnTypeGetter = std::function<DataTypePtr(const ColumnsWithTypeAndName &)>;
/// NOTE(review): judging by the name, this derives the return type under the default
/// handling of Nullable arguments using `getter`; the definition is not visible here — confirm.
static DataTypePtr getReturnTypeDefaultImplementationForNulls(const ColumnsWithTypeAndName & arguments, const DefaultReturnTypeGetter & getter);
private:
/// The wrapped resolver implementation all calls are delegated to.
FunctionOverloadResolverImplPtr impl;
/// Return-type helpers; both are defined outside this class body.
DataTypePtr getReturnTypeWithoutLowCardinality(const ColumnsWithTypeAndName & arguments) const;
DataTypePtr getReturnType(const ColumnsWithTypeAndName & arguments) const;
};
/// The following classes implement IExecutableFunctionImpl, IFunctionBaseImpl and IFunctionOverloadResolverImpl via IFunction.
class DefaultExecutable final : public IExecutableFunctionImpl
{
public:
explicit DefaultExecutable(std::shared_ptr<IFunction> function_) : function(std::move(function_)) {}
>>>>>>> Fix tests
String getName() const override { return function->getName(); }
@ -229,11 +82,8 @@ public:
bool isShortCircuit() const override { return function->isShortCircuit(); }
<<<<<<< HEAD
bool isSuitableForShortCircuitArgumentsExecution() const override { return function->isSuitableForShortCircuitArgumentsExecution(); }
=======
>>>>>>> Fix tests
void executeShortCircuitArguments(ColumnsWithTypeAndName & args) const override
{
function->executeShortCircuitArguments(args);
@ -267,16 +117,7 @@ public:
bool isStateful() const override { return function->isStateful(); }
bool isVariadic() const override { return function->isVariadic(); }
bool isShortCircuit() const override { return function->isShortCircuit(); }
<<<<<<< HEAD
bool isSuitableForShortCircuitArgumentsExecution() const override { return function->isSuitableForShortCircuitArgumentsExecution(); }
=======
void executeShortCircuitArguments(ColumnsWithTypeAndName & arguments) const override
{
function->executeShortCircuitArguments(arguments);
}
>>>>>>> Fix tests
size_t getNumberOfArguments() const override { return function->getNumberOfArguments(); }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return function->getArgumentsThatAreAlwaysConstant(); }

View File

@ -61,6 +61,8 @@ struct IsOperation
plus || minus || multiply ||
div_floating || div_int || div_int_or_zero ||
least || greatest;
static constexpr bool can_throw = div_int || modulo;
};
}

View File

@ -922,13 +922,16 @@ public:
/// Pre-processes the arguments before the function itself runs.
/// Returns immediately when checkArgumentsForColumnFunction(arguments) is false.
/// Otherwise arguments[0] is passed to executeColumnIfNeeded(), and if arguments[1] or
/// arguments[2] holds a function column, both are passed to maskedExecute() with a mask
/// obtained from arguments[0].
void executeShortCircuitArguments(ColumnsWithTypeAndName & arguments) const override
{
if (!checkArgumentsForColumnFunction(arguments))
return;
executeColumnIfNeeded(arguments[0]);
if (isColumnFunction(*arguments[1].column) || isColumnFunction(*arguments[2].column))
{
IColumn::Filter mask;
/// Fill `mask` from the first argument's column.
getMaskFromColumn(arguments[0].column, mask);
/// arguments[1] is processed with the mask as is ...
maskedExecute(arguments[1], mask);
/// ... and arguments[2] with `reverse` set to true.
/// NOTE(review): the two calls below target the same argument with two different
/// maskedExecute() parameter lists; they look like the before/after variants of a
/// single line — confirm that only one of them is meant to remain.
maskedExecute(arguments[2], mask, nullptr, /*reverse=*/true);
maskedExecute(arguments[2], mask, /*reverse=*/true);
}
}

View File

@ -110,11 +110,14 @@ public:
void executeShortCircuitArguments(ColumnsWithTypeAndName & arguments) const override
{
if (!checkArgumentsForColumnFunction(arguments))
return;
executeColumnIfNeeded(arguments[0]);
IColumn::Filter current_mask;
IColumn::Filter mask_disjunctions = IColumn::Filter(arguments[0].column->size(), 0);
auto default_value = std::make_unique<Field>(0);
UInt8 default_value_for_mask_expanding = 0;
size_t i = 1;
while (i < arguments.size())
{
@ -124,11 +127,14 @@ public:
maskedExecute(arguments[i], current_mask);
++i;
if (i == arguments.size() - 1)
default_value = nullptr;
if (isColumnFunction(*arguments[i].column))
maskedExecute(arguments[i], mask_disjunctions, default_value.get(), true);
{
if (i < arguments.size() - 1)
maskedExecute(arguments[i], mask_disjunctions, true, &default_value_for_mask_expanding);
else
maskedExecute(arguments[i], mask_disjunctions, true);
}
++i;
}

View File

@ -52,6 +52,8 @@ public:
return return_type;
}
bool isSuitableForShortCircuitArgumentsExecution() const override { return false; }
ExecutableFunctionPtr prepare(const ColumnsWithTypeAndName &) const override
{
return std::make_unique<ExecutableFunctionRandomConstant<ToType, Name>>(value);
@ -79,6 +81,8 @@ public:
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
bool isSuitableForShortCircuitArgumentsExecution() const override { return false; }
static FunctionOverloadResolverPtr create(ContextPtr)
{
return std::make_unique<RandomConstantOverloadResolver<ToType, Name>>();

View File

@ -26,6 +26,10 @@ public:
bool useDefaultImplementationForNulls() const override { return false; }
bool isShortCircuit() const override { return true; }
void executeShortCircuitArguments(ColumnsWithTypeAndName & /*arguments*/) const override {}
size_t getNumberOfArguments() const override
{
return 1;

View File

@ -29,6 +29,7 @@ public:
size_t getNumberOfArguments() const override { return 1; }
bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
bool isSuitableForShortCircuitArgumentsExecution() const override { return false; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{

View File

@ -29,6 +29,7 @@ public:
}
bool useDefaultImplementationForNulls() const override { return false; }
bool isShortCircuit() const override { return true; }
bool useDefaultImplementationForLowCardinalityColumns() const override { return false; }

View File

@ -69,21 +69,84 @@ ExpressionActionsPtr ExpressionActions::clone() const
return std::make_shared<ExpressionActions>(*this);
}
/// Walks `children` (recursively) and flags FUNCTION nodes as lazily executed.
///
/// A child is considered only if `need_outside` contains it with value `false` and it is
/// not already flagged. For such a child:
///  - FUNCTION: its own children are processed first; the node is flagged when anything
///    below it was flagged, or the function reports
///    isSuitableForShortCircuitArgumentsExecution(), or `force_rewrite` is set.
///  - ALIAS: only its children are processed.
///  - any other node type is left untouched.
///
/// Returns true if at least one node was flagged during this call (including nested calls).
bool ExpressionActions::rewriteShortCircuitArguments(
    const ActionsDAG::NodeRawConstPtrs & children,
    const std::unordered_map<const ActionsDAG::Node *, bool> & need_outside,
    bool force_rewrite)
{
    bool any_rewritten = false;

    for (const auto * arg : children)
    {
        /// Single lookup: the node must be known and marked as "not needed outside".
        const auto found = need_outside.find(arg);
        const bool used_only_inside = found != need_outside.end() && !found->second;
        if (!used_only_inside || arg->is_lazy_executed)
            continue;

        if (arg->type == ActionsDAG::ActionType::FUNCTION)
        {
            /// The recursion must run before the other checks: it has side effects on the subtree.
            const bool rewritten_below = rewriteShortCircuitArguments(arg->children, need_outside, force_rewrite);
            if (rewritten_below || arg->function_base->isSuitableForShortCircuitArgumentsExecution() || force_rewrite)
            {
                /// Nodes are reachable only through const pointers here, hence the const_cast.
                const_cast<ActionsDAG::Node *>(arg)->is_lazy_executed = true;
                any_rewritten = true;
            }
        }
        else if (arg->type == ActionsDAG::ActionType::ALIAS)
        {
            if (rewriteShortCircuitArguments(arg->children, need_outside, force_rewrite))
                any_rewritten = true;
        }
    }

    return any_rewritten;
}
/// For every FUNCTION node whose function reports isShortCircuit(), determines which nodes
/// reachable through its arguments are needed outside of that function and then calls
/// rewriteShortCircuitArguments() on the function's direct children.
///
/// `data` holds per-node records and `reverse_index` maps a node to its position in `data`;
/// only the `used_in_result` and `parents` fields are read here.
void ExpressionActions::rewriteArgumentsForShortCircuitFunctions(
    const std::list<ActionsDAG::Node> & nodes,
    const std::vector<Data> & data,
    const std::unordered_map<const ActionsDAG::Node *, size_t> & reverse_index)
{
    for (const auto & node : nodes)
    {
        if (node.type != ActionsDAG::ActionType::FUNCTION || !node.function_base->isShortCircuit())
            continue;

        /// node -> "is its value required by something other than this short-circuit function".
        std::unordered_map<const ActionsDAG::Node *, bool> need_outside;
        need_outside[&node] = false;

        /// Breadth-first traversal starting from the direct arguments.
        std::deque<const ActionsDAG::Node *> pending(node.children.begin(), node.children.end());
        while (!pending.empty())
        {
            const ActionsDAG::Node * current = pending.front();
            pending.pop_front();

            /// Already classified on an earlier visit.
            if (need_outside.contains(current))
                continue;

            const Data & info = data[reverse_index.at(current)];

            /// Needed outside if it is part of the result, or if any parent is either
            /// not classified yet or itself needed outside.
            bool needed = info.used_in_result;
            if (!needed)
            {
                for (const auto * parent : info.parents)
                {
                    const auto known = need_outside.find(parent);
                    if (known == need_outside.end() || known->second)
                    {
                        needed = true;
                        break;
                    }
                }
            }
            need_outside[current] = needed;

            pending.insert(pending.end(), current->children.begin(), current->children.end());
        }

        /// A short-circuit function with a single argument always forces the rewrite.
        const bool force_rewrite = node.children.size() == 1;
        rewriteShortCircuitArguments(node.children, need_outside, force_rewrite);
    }
}
void ExpressionActions::linearizeActions()
{
/// This function does the topological sort on DAG and fills all the fields of ExpressionActions.
/// Algorithm traverses DAG starting from nodes without children.
/// For every node we track the number of created children, and once all children are created, the node is put into the queue.
struct Data
{
const Node * node = nullptr;
size_t num_created_children = 0;
std::vector<const Node *> parents;
ssize_t position = -1;
size_t num_created_parents = 0;
bool used_in_result = false;
};
const auto & nodes = getNodes();
const auto & index = actions_dag->getIndex();
@ -119,6 +182,9 @@ void ExpressionActions::linearizeActions()
ready_nodes.emplace(&node);
}
if (settings.use_short_circuit_function_evaluation)
rewriteArgumentsForShortCircuitFunctions(nodes, data, reverse_index);
/// Every argument will have fixed position in columns list.
/// If an argument is removed, its position may be reused by another action.
std::stack<size_t> free_positions;
@ -246,18 +312,6 @@ std::string ExpressionActions::Action::toString() const
out << ")";
break;
case ActionsDAG::ActionType::COLUMN_FUNCTION:
out << "COLUMN FUNCTION " << (node->is_function_compiled ? "[compiled] " : "")
<< (node->function_base ? node->function_base->getName() : "(no function)") << "(";
for (size_t i = 0; i < node->children.size(); ++i)
{
if (i)
out << ", ";
out << node->children[i]->result_name << " " << arguments[i];
}
out << ")";
break;
case ActionsDAG::ActionType::ARRAY_JOIN:
out << "ARRAY JOIN " << node->children.front()->result_name << " " << arguments.front();
break;
@ -334,11 +388,11 @@ namespace
ColumnsWithTypeAndName & inputs;
ColumnsWithTypeAndName columns = {};
std::vector<ssize_t> inputs_pos = {};
size_t num_rows = 0;
size_t num_rows;
};
}
static void executeAction(const ExpressionActions::Action & action, ExecutionContext & execution_context, bool dry_run, bool use_short_circuit_function_evaluation)
static void executeAction(const ExpressionActions::Action & action, ExecutionContext & execution_context, bool dry_run)
{
auto & inputs = execution_context.inputs;
auto & columns = execution_context.columns;
@ -348,54 +402,6 @@ static void executeAction(const ExpressionActions::Action & action, ExecutionCon
{
case ActionsDAG::ActionType::FUNCTION:
{
// LOG_DEBUG(&Poco::Logger::get("ExpressionActions"), "Execute action FUNCTION: {}", action.node->function_base->getName());
auto & res_column = columns[action.result_position];
if (res_column.type || res_column.column)
throw Exception("Result column is not empty", ErrorCodes::LOGICAL_ERROR);
res_column.type = action.node->result_type;
res_column.name = action.node->result_name;
ColumnsWithTypeAndName arguments(action.arguments.size());
for (size_t i = 0; i < arguments.size(); ++i)
{
auto & column = columns[action.arguments[i].pos];
if (action.node->children[i]->type == ActionsDAG::ActionType::COLUMN_FUNCTION)
{
const ColumnFunction * column_function = typeid_cast<const ColumnFunction *>(column.column.get());
if (column_function && (!action.node->function_base->isShortCircuit() || action.arguments[i].needed_later))
column.column = column_function->reduce(true).column;
}
if (!action.arguments[i].needed_later)
arguments[i] = std::move(column);
else
arguments[i] = column;
}
ProfileEvents::increment(ProfileEvents::FunctionExecute);
if (action.node->is_function_compiled)
ProfileEvents::increment(ProfileEvents::CompiledFunctionExecute);
if (action.node->function_base->isShortCircuit() && use_short_circuit_function_evaluation)
{
// LOG_DEBUG(&Poco::Logger::get("ExpressionActions"), "Execute Short Circuit Arguments");
action.node->function_base->executeShortCircuitArguments(arguments);
}
// LOG_DEBUG(&Poco::Logger::get("ExpressionActions"), "Execute function");
res_column.column = action.node->function->execute(arguments, res_column.type, num_rows, dry_run);
break;
}
case ActionsDAG::ActionType::COLUMN_FUNCTION:
{
// LOG_DEBUG(&Poco::Logger::get("ExpressionActions"), "Execute action COLUMN FUNCTION: {}", action.node->function_base->getName());
auto & res_column = columns[action.result_position];
if (res_column.type || res_column.column)
throw Exception("Result column is not empty", ErrorCodes::LOGICAL_ERROR);
@ -412,18 +418,19 @@ static void executeAction(const ExpressionActions::Action & action, ExecutionCon
arguments[i] = columns[action.arguments[i].pos];
}
if (use_short_circuit_function_evaluation)
if (action.node->is_lazy_executed)
res_column.column = ColumnFunction::create(num_rows, action.node->function_base, std::move(arguments));
else
{
// LOG_DEBUG(&Poco::Logger::get("ExpressionActions"), "Execute function");
ProfileEvents::increment(ProfileEvents::FunctionExecute);
if (action.node->is_function_compiled)
ProfileEvents::increment(ProfileEvents::CompiledFunctionExecute);
if (action.node->function_base->isShortCircuit())
action.node->function_base->executeShortCircuitArguments(arguments);
res_column.column = action.node->function->execute(arguments, res_column.type, num_rows, dry_run);
}
break;
}
@ -462,8 +469,6 @@ static void executeAction(const ExpressionActions::Action & action, ExecutionCon
case ActionsDAG::ActionType::COLUMN:
{
// LOG_DEBUG(&Poco::Logger::get("ExpressionActions"), "Execute action COLUMN: {}", action.node->result_name);
auto & res_column = columns[action.result_position];
res_column.column = action.node->column->cloneResized(num_rows);
res_column.type = action.node->result_type;
@ -473,21 +478,11 @@ static void executeAction(const ExpressionActions::Action & action, ExecutionCon
case ActionsDAG::ActionType::ALIAS:
{
// LOG_DEBUG(&Poco::Logger::get("ExpressionActions"), "Execute action ALIAS: {}", action.node->result_name);
const auto & arg = action.arguments.front();
if (action.result_position != arg.pos)
{
auto & column = columns[arg.pos];
if (action.node->children.back()->type == ActionsDAG::ActionType::COLUMN_FUNCTION)
{
const ColumnFunction * column_function = typeid_cast<const ColumnFunction *>(column.column.get());
if (column_function)
column.column = column_function->reduce(true).column;
}
columns[action.result_position].column = column.column;
columns[action.result_position].type = column.type;
columns[action.result_position].column = columns[arg.pos].column;
columns[action.result_position].type = columns[arg.pos].type;
if (!arg.needed_later)
columns[arg.pos] = {};
@ -500,8 +495,6 @@ static void executeAction(const ExpressionActions::Action & action, ExecutionCon
case ActionsDAG::ActionType::INPUT:
{
// LOG_DEBUG(&Poco::Logger::get("ExpressionActions"), "Execute action INPUT: {}", action.node->result_name);
auto pos = execution_context.inputs_pos[action.arguments.front().pos];
if (pos < 0)
{
@ -513,15 +506,7 @@ static void executeAction(const ExpressionActions::Action & action, ExecutionCon
action.node->result_name);
}
else
{
auto & column = inputs[pos];
const ColumnFunction * column_function = typeid_cast<const ColumnFunction *>(column.column.get());
if (column_function && column.type->getTypeId() != TypeIndex::Function)
column.column = column_function->reduce(true).column;
columns[action.result_position] = std::move(column);
}
columns[action.result_position] = std::move(inputs[pos]);
break;
}
@ -561,7 +546,7 @@ void ExpressionActions::execute(Block & block, size_t & num_rows, bool dry_run)
{
try
{
executeAction(action, execution_context, dry_run, settings.use_short_circuit_function_evaluation);
executeAction(action, execution_context, dry_run);
checkLimits(execution_context.columns);
//std::cerr << "Action: " << action.toString() << std::endl;

View File

@ -68,7 +68,6 @@ public:
using NameToInputMap = std::unordered_map<std::string_view, std::list<size_t>>;
private:
ActionsDAGPtr actions_dag;
Actions actions;
size_t num_columns = 0;
@ -120,9 +119,27 @@ public:
ExpressionActionsPtr clone() const;
private:
/// Per-node bookkeeping record. A vector of these is passed to
/// rewriteArgumentsForShortCircuitFunctions(), which reads `parents` and `used_in_result`.
struct Data
{
/// The DAG node this record describes.
const Node * node = nullptr;
/// NOTE(review): presumably the number of this node's children already processed — confirm against linearizeActions().
size_t num_created_children = 0;
/// Nodes that use this node as a child.
std::vector<const Node *> parents;
/// NOTE(review): presumably the node's position in the columns list, with -1 meaning "not assigned" — confirm.
ssize_t position = -1;
/// NOTE(review): presumably the number of parents already processed — confirm against linearizeActions().
size_t num_created_parents = 0;
/// When true, rewriteArgumentsForShortCircuitFunctions() treats the node as needed outside.
bool used_in_result = false;
};
void checkLimits(const ColumnsWithTypeAndName & columns) const;
void linearizeActions();
bool rewriteShortCircuitArguments(
const ActionsDAG::NodeRawConstPtrs & children, const std::unordered_map<const ActionsDAG::Node *, bool> & need_outside, bool force_rewrite);
void rewriteArgumentsForShortCircuitFunctions(
const std::list<ActionsDAG::Node> & nodes,
const std::vector<Data> & data,
const std::unordered_map<const ActionsDAG::Node *, size_t> & reverse_index);
};

View File

@ -25,7 +25,7 @@ struct ExpressionActionsSettings
CompileExpressions compile_expressions = CompileExpressions::no;
bool use_short_circuit_function_evaluation = true;
bool use_short_circuit_function_evaluation = false;
static ExpressionActionsSettings fromSettings(const Settings & from, CompileExpressions compile_expressions = CompileExpressions::no);
static ExpressionActionsSettings fromContext(ContextPtr from, CompileExpressions compile_expressions = CompileExpressions::no);

View File

@ -88,7 +88,7 @@ bool allowEarlyConstantFolding(const ActionsDAG & actions, const Settings & sett
for (const auto & node : actions.getNodes())
{
if ((node.type == ActionsDAG::ActionType::FUNCTION || node.type == ActionsDAG::ActionType::COLUMN_FUNCTION) && node.function_base)
if ((node.type == ActionsDAG::ActionType::FUNCTION) && node.function_base)
{
if (!node.function_base->isSuitableForConstantFolding())
return false;