Fix tests

This commit is contained in:
Pavel Kruglov 2021-04-22 18:14:58 +03:00 committed by Pavel Kruglov
parent 50d4192126
commit 775d190fb3
20 changed files with 400 additions and 174 deletions

View File

@ -5,6 +5,7 @@
#include <IO/WriteHelpers.h>
#include <Functions/IFunction.h>
#include <common/logger_useful.h>
namespace DB
{

View File

@ -3,7 +3,6 @@
#include <Columns/IColumn.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnArray.h>
#include <Core/Field.h>

View File

@ -533,4 +533,6 @@ bool isColumnConst(const IColumn & column);
/// True if column's an ColumnNullable instance. It's just a syntax sugar for type check.
bool isColumnNullable(const IColumn & column);
bool isColumnFunction(const IColumn & column);
}

View File

@ -0,0 +1,153 @@
#include <Columns/MaskOperations.h>
#include <Columns/ColumnFunction.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnNothing.h>
#include <Columns/ColumnLowCardinality.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
ColumnPtr expandColumnByMask(const ColumnPtr & column, const PaddedPODArray<UInt8>& mask, Field * field)
{
MutableColumnPtr res = column->cloneEmpty();
res->reserve(mask.size());
size_t index = 0;
for (size_t i = 0; i != mask.size(); ++i)
{
if (mask[i])
{
if (index >= column->size())
throw Exception("Too many bits in mask", ErrorCodes::LOGICAL_ERROR);
res->insert((*column)[index]);
++index;
}
else if (field)
res->insert(*field);
else
res->insertDefault();
}
if (index < column->size())
throw Exception("Too less bits in mask", ErrorCodes::LOGICAL_ERROR);
return res;
}
template <typename ValueType>
PaddedPODArray<UInt8> copyMaskImpl(const PaddedPODArray<ValueType>& mask, bool reverse, const PaddedPODArray<UInt8> * null_bytemap, UInt8 null_value)
{
PaddedPODArray<UInt8> res;
res.reserve(mask.size());
for (size_t i = 0; i != mask.size(); ++i)
{
if (null_bytemap && (*null_bytemap)[i])
res.push_back(reverse ? !null_value : null_value);
else
res.push_back(reverse ? !mask[i]: !!mask[i]);
}
return res;
}
template <typename ValueType>
bool tryGetMaskFromColumn(const ColumnPtr column, PaddedPODArray<UInt8> & res, bool reverse, const PaddedPODArray<UInt8> * null_bytemap, UInt8 null_value)
{
if (const auto * col = checkAndGetColumn<ColumnVector<ValueType>>(*column))
{
res = copyMaskImpl(col->getData(), reverse, null_bytemap, null_value);
return true;
}
return false;
}
PaddedPODArray<UInt8> reverseMask(const PaddedPODArray<UInt8> & mask)
{
return copyMaskImpl(mask, true, nullptr, 1);
}
PaddedPODArray<UInt8> getMaskFromColumn(const ColumnPtr & column, bool reverse, const PaddedPODArray<UInt8> * null_bytemap, UInt8 null_value)
{
if (const auto * col = checkAndGetColumn<ColumnConst>(*column))
return getMaskFromColumn(col->convertToFullColumn(), reverse, null_bytemap, null_value);
if (const auto * col = checkAndGetColumn<ColumnNothing>(*column))
return PaddedPODArray<UInt8>(col->size(), reverse ? !null_value : null_value);
if (const auto * col = checkAndGetColumn<ColumnNullable>(*column))
{
const PaddedPODArray<UInt8> & null_map = checkAndGetColumn<ColumnUInt8>(*col->getNullMapColumnPtr())->getData();
return getMaskFromColumn(col->getNestedColumnPtr(), reverse, &null_map, null_value);
}
if (const auto * col = checkAndGetColumn<ColumnLowCardinality>(*column))
return getMaskFromColumn(col->convertToFullColumn(), reverse, null_bytemap, null_value);
PaddedPODArray<UInt8> res;
if (!tryGetMaskFromColumn<Int8>(column, res, reverse, null_bytemap, null_value) &&
!tryGetMaskFromColumn<Int16>(column, res, reverse, null_bytemap, null_value) &&
!tryGetMaskFromColumn<Int32>(column, res, reverse, null_bytemap, null_value) &&
!tryGetMaskFromColumn<Int64>(column, res, reverse, null_bytemap, null_value) &&
!tryGetMaskFromColumn<UInt8>(column, res, reverse, null_bytemap, null_value) &&
!tryGetMaskFromColumn<UInt16>(column, res, reverse, null_bytemap, null_value) &&
!tryGetMaskFromColumn<UInt32>(column, res, reverse, null_bytemap, null_value) &&
!tryGetMaskFromColumn<UInt64>(column, res, reverse, null_bytemap, null_value) &&
!tryGetMaskFromColumn<Float32>(column, res, reverse, null_bytemap, null_value) &&
!tryGetMaskFromColumn<Float64>(column, res, reverse, null_bytemap, null_value))
throw Exception("Cannot convert column " + column.get()->getName() + " to mask", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return res;
}
template <typename Op>
void binaryMasksOperationImpl(PaddedPODArray<UInt8> & mask1, const PaddedPODArray<UInt8> & mask2, Op operation)
{
if (mask1.size() != mask2.size())
throw Exception("Masks have different sizes", ErrorCodes::LOGICAL_ERROR);
for (size_t i = 0; i != mask1.size(); ++i)
mask1[i] = operation(mask1[i], mask2[i]);
}
void conjunctionMasks(PaddedPODArray<UInt8> & mask1, const PaddedPODArray<UInt8> & mask2)
{
binaryMasksOperationImpl(mask1, mask2, [](const auto & lhs, const auto & rhs){ return lhs & rhs; });
}
void disjunctionMasks(PaddedPODArray<UInt8> & mask1, const PaddedPODArray<UInt8> & mask2)
{
binaryMasksOperationImpl(mask1, mask2, [](const auto & lhs, const auto & rhs){ return lhs | rhs; });
}
void maskedExecute(ColumnWithTypeAndName & column, const PaddedPODArray<UInt8>& mask, Field * default_value)
{
const auto * column_function = checkAndGetColumn<ColumnFunction>(*column.column);
if (!column_function)
return;
auto filtered = column_function->filter(mask, -1);
auto result = typeid_cast<const ColumnFunction *>(filtered.get())->reduce(true);
result.column = expandColumnByMask(result.column, mask, default_value);
column = std::move(result);
}
void executeColumnIfNeeded(ColumnWithTypeAndName & column)
{
const auto * column_function = checkAndGetColumn<ColumnFunction>(*column.column);
if (!column_function)
return;
column = typeid_cast<const ColumnFunction *>(column_function)->reduce(true);
}
}

View File

@ -15,6 +15,8 @@ void conjunctionMasks(PaddedPODArray<UInt8> & mask1, const PaddedPODArray<UInt8>
void disjunctionMasks(PaddedPODArray<UInt8> & mask1, const PaddedPODArray<UInt8> & mask2);
ColumnWithTypeAndName maskedExecute(const ColumnWithTypeAndName & column, const PaddedPODArray<UInt8> & mask, Field * default_value = nullptr);
void maskedExecute(ColumnWithTypeAndName & column, const PaddedPODArray<UInt8> & mask, Field * default_value = nullptr);
void executeColumnIfNeeded(ColumnWithTypeAndName & column);
}

View File

@ -35,6 +35,7 @@ SRCS(
ColumnsCommon.cpp
FilterDescription.cpp
IColumn.cpp
MaskOperations.cpp
getLeastSuperColumn.cpp
)

View File

@ -1,116 +0,0 @@
#include <Common/MasksOperation.h>
#include <Columns/ColumnFunction.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnNothing.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
ColumnPtr expandColumnByMask(const ColumnPtr & column, const PaddedPODArray<UInt8>& mask, Field * field)
{
MutableColumnPtr res = column->cloneEmpty();
res->reserve(mask.size());
size_t index = 0;
for (size_t i = 0; i != mask.size(); ++i)
{
if (mask[i])
{
if (index >= column->size())
throw Exception("Too many bits in mask", ErrorCodes::LOGICAL_ERROR);
res->insert((*column)[index]);
++index;
}
else if (field)
res->insert(*field);
else
res->insertDefault();
}
if (index < column->size())
throw Exception("Too less bits in mask", ErrorCodes::LOGICAL_ERROR);
return res;
}
PaddedPODArray<UInt8> copyMaskImpl(const PaddedPODArray<UInt8>& mask, bool reverse, const PaddedPODArray<UInt8> * null_bytemap, UInt8 null_value)
{
PaddedPODArray<UInt8> res;
res.reserve(mask.size());
for (size_t i = 0; i != mask.size(); ++i)
{
if (null_bytemap && (*null_bytemap)[i])
res.push_back(null_value);
else
res.push_back(reverse ? !mask[i] : mask[i]);
}
return res;
}
PaddedPODArray<UInt8> reverseMask(const PaddedPODArray<UInt8> & mask)
{
return copyMaskImpl(mask, true, nullptr, 1);
}
PaddedPODArray<UInt8> getMaskFromColumn(const ColumnPtr & column, bool reverse, const PaddedPODArray<UInt8> * null_bytemap, UInt8 null_value)
{
if (const auto * col = typeid_cast<const ColumnConst *>(column.get()))
return getMaskFromColumn(col->convertToFullColumn(), reverse, null_bytemap, null_value);
if (const auto * col = typeid_cast<const ColumnNothing *>(column.get()))
return PaddedPODArray<UInt8>(col->size(), null_value);
if (const auto * col = typeid_cast<const ColumnNullable *>(column.get()))
{
const PaddedPODArray<UInt8> & null_map = typeid_cast<const ColumnUInt8 *>(col->getNullMapColumnPtr().get())->getData();
return getMaskFromColumn(col->getNestedColumnPtr(), reverse, &null_map, null_value);
}
if (const auto * col = typeid_cast<const ColumnUInt8 *>(column.get()))
return copyMaskImpl(col->getData(), reverse, null_bytemap, null_value);
throw Exception("Cannot convert column " + column.get()->getName() + " to mask", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
template <typename Op>
void binaryMasksOperationImpl(PaddedPODArray<UInt8> & mask1, const PaddedPODArray<UInt8> & mask2, Op operation)
{
if (mask1.size() != mask2.size())
throw Exception("Masks have different sizes", ErrorCodes::LOGICAL_ERROR);
for (size_t i = 0; i != mask1.size(); ++i)
mask1[i] = operation(mask1[i], mask2[i]);
}
void conjunctionMasks(PaddedPODArray<UInt8> & mask1, const PaddedPODArray<UInt8> & mask2)
{
binaryMasksOperationImpl(mask1, mask2, [](const auto & lhs, const auto & rhs){ return lhs & rhs; });
}
void disjunctionMasks(PaddedPODArray<UInt8> & mask1, const PaddedPODArray<UInt8> & mask2)
{
binaryMasksOperationImpl(mask1, mask2, [](const auto & lhs, const auto & rhs){ return lhs | rhs; });
}
ColumnWithTypeAndName maskedExecute(const ColumnWithTypeAndName & column, const PaddedPODArray<UInt8>& mask, Field * default_value)
{
const auto * column_function = typeid_cast<const ColumnFunction *>(column.column.get());
if (!column_function)
return column;
auto filtered = column_function->filter(mask, -1);
auto result = typeid_cast<const ColumnFunction *>(filtered.get())->reduce(true);
result.column = expandColumnByMask(result.column, mask, default_value);
return result;
}
}

View File

@ -55,7 +55,6 @@ SRCS(
IntervalKind.cpp
JSONBuilder.cpp
Macros.cpp
MasksOperation.cpp
MemoryStatisticsOS.cpp
MemoryTracker.cpp
OpenSSLHelpers.cpp

View File

@ -90,11 +90,17 @@ FunctionBasePtr JoinGetOverloadResolver<or_null>::buildImpl(const ColumnsWithTyp
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
auto [storage_join, attr_name] = getJoin(arguments, getContext());
DataTypes data_types(arguments.size() - 2);
for (size_t i = 2; i < arguments.size(); ++i)
data_types[i - 2] = arguments[i].type;
DataTypes argument_types(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i)
{
if (i >= 2)
data_types[i - 2] = arguments[i].type;
argument_types[i] = arguments[i].type;
}
auto return_type = storage_join->joinGetCheckAndGetReturnType(data_types, attr_name, or_null);
auto table_lock = storage_join->lockForShare(getContext()->getInitialQueryId(), getContext()->getSettingsRef().lock_acquire_timeout);
return std::make_unique<FunctionJoinGet<or_null>>(table_lock, storage_join, attr_name, data_types, return_type);
return std::make_unique<FunctionJoinGet<or_null>>(table_lock, storage_join, attr_name, argument_types, return_type);
}
void registerFunctionJoinGet(FunctionFactory & factory)

View File

@ -14,8 +14,8 @@
#include <algorithm>
#include <Common/MasksOperation.h>
#include <Columns/MaskOperations.h>
#include <common/logger_useful.h>
namespace DB
{
@ -475,7 +475,7 @@ static ColumnPtr basicExecuteImpl(ColumnRawPtrs arguments, size_t input_rows_cou
}
template <typename Impl, typename Name>
DataTypePtr FunctionAnyArityLogical<Impl, Name>::getReturnTypeImpl(const DataTypes & arguments) const
DataTypePtr FunctionAnyArityLogical<Impl, Name>::getReturnTypeImpl(const DataTypes & arguments) const
{
if (arguments.size() < 2)
throw Exception("Number of arguments for function \"" + getName() + "\" should be at least 2: passed "
@ -510,31 +510,30 @@ DataTypePtr FunctionAnyArityLogical<Impl, Name>::getReturnTypeImpl(const DataTyp
}
template <typename Impl, typename Name>
ColumnsWithTypeAndName FunctionAnyArityLogical<Impl, Name>::checkForLazyArgumentsExecution(const ColumnsWithTypeAndName & args) const
void FunctionAnyArityLogical<Impl, Name>::executeShortCircuitArguments(ColumnsWithTypeAndName & arguments) const
{
if (Name::name != NameAnd::name && Name::name != NameOr::name)
return args;
throw Exception("Function " + getName() + " doesn't support short circuit execution", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
ColumnsWithTypeAndName executed_arguments;
Field default_value = Name::name == NameAnd::name ? 0 : 1;
bool reverse = Name::name == NameAnd::name ? false : true;
bool reverse = Name::name != NameAnd::name;
UInt8 null_value = Name::name == NameAnd::name ? 1 : 0;
executed_arguments.push_back(args[0]);
for (size_t i = 1; i < args.size(); ++i)
{
const IColumn::Filter & mask = getMaskFromColumn(executed_arguments[i - 1].column, reverse, nullptr, null_value);
auto column = maskedExecute(args[i], mask, &default_value);
executed_arguments.push_back(std::move(column));
}
executeColumnIfNeeded(arguments[0]);
return executed_arguments;
for (size_t i = 1; i < arguments.size(); ++i)
{
if (isColumnFunction(*arguments[i].column))
{
IColumn::Filter mask = getMaskFromColumn(arguments[i - 1].column, reverse, nullptr, null_value);
maskedExecute(arguments[i], mask, &default_value);
}
}
}
template <typename Impl, typename Name>
ColumnPtr FunctionAnyArityLogical<Impl, Name>::executeImpl(
const ColumnsWithTypeAndName & args, const DataTypePtr & result_type, size_t input_rows_count) const
const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
{
ColumnsWithTypeAndName arguments = checkForLazyArgumentsExecution(args);
ColumnRawPtrs args_in;
for (const auto & arg_index : arguments)
args_in.push_back(arg_index.column.get());

View File

@ -154,6 +154,7 @@ public:
bool isVariadic() const override { return true; }
bool isShortCircuit() const override { return name == NameAnd::name || name == NameOr::name; }
void executeShortCircuitArguments(ColumnsWithTypeAndName & arguments) const override;
size_t getNumberOfArguments() const override { return 0; }
bool useDefaultImplementationForNulls() const override { return !Impl::specialImplementationForNulls(); }
@ -204,9 +205,6 @@ public:
return phi;
}
#endif
private:
ColumnsWithTypeAndName checkForLazyArgumentsExecution(const ColumnsWithTypeAndName & args) const;
};

View File

@ -213,6 +213,11 @@ public:
virtual bool isShortCircuit() const { return false; }
virtual void executeShortCircuitArguments(ColumnsWithTypeAndName & /*arguments*/) const
{
throw Exception("Function " + getName() + " doesn't support short circuit execution", ErrorCodes::NOT_IMPLEMENTED);
}
/// The property of monotonicity for a certain range.
struct Monotonicity
{
@ -272,6 +277,11 @@ public:
virtual bool isShortCircuit() const { return false; }
virtual void executeShortCircuitArguments(ColumnsWithTypeAndName & /*arguments*/) const
{
throw Exception("Function " + getName() + " doesn't support short circuit execution", ErrorCodes::NOT_IMPLEMENTED);
}
/// For non-variadic functions, return number of arguments; otherwise return zero (that should be ignored).
/// For higher-order functions (functions, that have lambda expression as at least one argument).
/// You pass data types with empty DataTypeFunction for lambda arguments.

View File

@ -10,7 +10,154 @@ namespace DB
class FunctionToExecutableFunctionAdaptor final : public IExecutableFunction
{
public:
<<<<<<< HEAD
explicit FunctionToExecutableFunctionAdaptor(std::shared_ptr<IFunction> function_) : function(std::move(function_)) {}
=======
explicit ExecutableFunctionAdaptor(ExecutableFunctionImplPtr impl_) : impl(std::move(impl_)) {}
String getName() const final { return impl->getName(); }
ColumnPtr execute(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count, bool dry_run) const final;
void createLowCardinalityResultCache(size_t cache_size) override;
private:
ExecutableFunctionImplPtr impl;
/// Cache is created by function createLowCardinalityResultCache()
ExecutableFunctionLowCardinalityResultCachePtr low_cardinality_result_cache;
ColumnPtr defaultImplementationForConstantArguments(
const ColumnsWithTypeAndName & args, const DataTypePtr & result_type, size_t input_rows_count, bool dry_run) const;
ColumnPtr defaultImplementationForNulls(
const ColumnsWithTypeAndName & args, const DataTypePtr & result_type, size_t input_rows_count, bool dry_run) const;
ColumnPtr executeWithoutLowCardinalityColumns(
const ColumnsWithTypeAndName & args, const DataTypePtr & result_type, size_t input_rows_count, bool dry_run) const;
};
class FunctionBaseAdaptor final : public IFunctionBase
{
public:
explicit FunctionBaseAdaptor(FunctionBaseImplPtr impl_) : impl(std::move(impl_)) {}
String getName() const final { return impl->getName(); }
const DataTypes & getArgumentTypes() const final { return impl->getArgumentTypes(); }
const DataTypePtr & getResultType() const final { return impl->getResultType(); }
ExecutableFunctionPtr prepare(const ColumnsWithTypeAndName & arguments) const final
{
return std::make_shared<ExecutableFunctionAdaptor>(impl->prepare(arguments));
}
#if USE_EMBEDDED_COMPILER
bool isCompilable() const final { return impl->isCompilable(); }
llvm::Value * compile(llvm::IRBuilderBase & builder, Values values) const override
{
return impl->compile(builder, std::move(values));
}
#endif
bool isStateful() const final { return impl->isStateful(); }
bool isSuitableForConstantFolding() const final { return impl->isSuitableForConstantFolding(); }
ColumnPtr getResultIfAlwaysReturnsConstantAndHasArguments(const ColumnsWithTypeAndName & arguments) const final
{
return impl->getResultIfAlwaysReturnsConstantAndHasArguments(arguments);
}
bool isInjective(const ColumnsWithTypeAndName & sample_columns) const final { return impl->isInjective(sample_columns); }
bool isDeterministic() const final { return impl->isDeterministic(); }
bool isDeterministicInScopeOfQuery() const final { return impl->isDeterministicInScopeOfQuery(); }
bool isShortCircuit() const final { return impl->isShortCircuit(); }
void executeShortCircuitArguments(ColumnsWithTypeAndName & arguments) const override
{
impl->executeShortCircuitArguments(arguments);
}
bool hasInformationAboutMonotonicity() const final { return impl->hasInformationAboutMonotonicity(); }
Monotonicity getMonotonicityForRange(const IDataType & type, const Field & left, const Field & right) const final
{
return impl->getMonotonicityForRange(type, left, right);
}
const IFunctionBaseImpl * getImpl() const { return impl.get(); }
private:
FunctionBaseImplPtr impl;
};
class FunctionOverloadResolverAdaptor final : public IFunctionOverloadResolver
{
public:
explicit FunctionOverloadResolverAdaptor(FunctionOverloadResolverImplPtr impl_) : impl(std::move(impl_)) {}
String getName() const final { return impl->getName(); }
bool isDeterministic() const final { return impl->isDeterministic(); }
bool isDeterministicInScopeOfQuery() const final { return impl->isDeterministicInScopeOfQuery(); }
bool isInjective(const ColumnsWithTypeAndName & columns) const final { return impl->isInjective(columns); }
bool isStateful() const final { return impl->isStateful(); }
bool isVariadic() const final { return impl->isVariadic(); }
bool isShortCircuit() const final { return impl->isShortCircuit(); }
size_t getNumberOfArguments() const final { return impl->getNumberOfArguments(); }
void checkNumberOfArguments(size_t number_of_arguments) const final;
FunctionBaseImplPtr buildImpl(const ColumnsWithTypeAndName & arguments) const
{
return impl->build(arguments, getReturnType(arguments));
}
FunctionBasePtr build(const ColumnsWithTypeAndName & arguments) const final
{
return std::make_shared<FunctionBaseAdaptor>(buildImpl(arguments));
}
void getLambdaArgumentTypes(DataTypes & arguments) const final
{
checkNumberOfArguments(arguments.size());
impl->getLambdaArgumentTypes(arguments);
}
ColumnNumbers getArgumentsThatAreAlwaysConstant() const final { return impl->getArgumentsThatAreAlwaysConstant(); }
ColumnNumbers getArgumentsThatDontImplyNullableReturnType(size_t number_of_arguments) const final
{
return impl->getArgumentsThatDontImplyNullableReturnType(number_of_arguments);
}
using DefaultReturnTypeGetter = std::function<DataTypePtr(const ColumnsWithTypeAndName &)>;
static DataTypePtr getReturnTypeDefaultImplementationForNulls(const ColumnsWithTypeAndName & arguments, const DefaultReturnTypeGetter & getter);
private:
FunctionOverloadResolverImplPtr impl;
DataTypePtr getReturnTypeWithoutLowCardinality(const ColumnsWithTypeAndName & arguments) const;
DataTypePtr getReturnType(const ColumnsWithTypeAndName & arguments) const;
};
/// Following classes are implement IExecutableFunctionImpl, IFunctionBaseImpl and IFunctionOverloadResolverImpl via IFunction.
class DefaultExecutable final : public IExecutableFunctionImpl
{
public:
explicit DefaultExecutable(std::shared_ptr<IFunction> function_) : function(std::move(function_)) {}
>>>>>>> Fix tests
String getName() const override { return function->getName(); }
@ -82,8 +229,11 @@ public:
bool isShortCircuit() const override { return function->isShortCircuit(); }
<<<<<<< HEAD
bool isSuitableForShortCircuitArgumentsExecution() const override { return function->isSuitableForShortCircuitArgumentsExecution(); }
=======
>>>>>>> Fix tests
void executeShortCircuitArguments(ColumnsWithTypeAndName & args) const override
{
function->executeShortCircuitArguments(args);
@ -117,7 +267,16 @@ public:
bool isStateful() const override { return function->isStateful(); }
bool isVariadic() const override { return function->isVariadic(); }
bool isShortCircuit() const override { return function->isShortCircuit(); }
<<<<<<< HEAD
bool isSuitableForShortCircuitArgumentsExecution() const override { return function->isSuitableForShortCircuitArgumentsExecution(); }
=======
void executeShortCircuitArguments(ColumnsWithTypeAndName & arguments) const override
{
function->executeShortCircuitArguments(arguments);
}
>>>>>>> Fix tests
size_t getNumberOfArguments() const override { return function->getNumberOfArguments(); }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return function->getArgumentsThatAreAlwaysConstant(); }

View File

@ -26,7 +26,7 @@
#include <Functions/FunctionFactory.h>
#include <Interpreters/castColumn.h>
#include <Common/MasksOperation.h>
#include <Columns/MaskOperations.h>
namespace DB
@ -920,19 +920,19 @@ public:
return getLeastSupertype({arguments[1], arguments[2]});
}
ColumnsWithTypeAndName checkForLazyArgumentsExecution(const ColumnsWithTypeAndName & args) const
void executeShortCircuitArguments(ColumnsWithTypeAndName & arguments) const override
{
ColumnsWithTypeAndName executed_arguments;
IColumn::Filter mask = getMaskFromColumn(args[0].column);
executed_arguments.push_back(args[0]);
executed_arguments.push_back(maskedExecute(args[1], mask));
executed_arguments.push_back(maskedExecute(args[2], reverseMask(mask)));
return executed_arguments;
executeColumnIfNeeded(arguments[0]);
if (isColumnFunction(*arguments[1].column) || isColumnFunction(*arguments[2].column))
{
IColumn::Filter mask = getMaskFromColumn(arguments[0].column);
maskedExecute(arguments[1], mask);
maskedExecute(arguments[2], reverseMask(mask));
}
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & args, const DataTypePtr & result_type, size_t input_rows_count) const override
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override
{
ColumnsWithTypeAndName arguments = checkForLazyArgumentsExecution(args);
ColumnPtr res;
if ( (res = executeForConstAndNullableCondition(arguments, result_type, input_rows_count))
|| (res = executeForNullThenElse(arguments, result_type, input_rows_count))

View File

@ -8,7 +8,7 @@
#include <Common/assert_cast.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/getLeastSupertype.h>
#include <Common/MasksOperation.h>
#include <Columns/MaskOperations.h>
namespace DB
@ -108,30 +108,31 @@ public:
return getLeastSupertype(types_of_branches);
}
ColumnsWithTypeAndName checkForLazyArgumentsExecution(const ColumnsWithTypeAndName & args) const
void executeShortCircuitArguments(ColumnsWithTypeAndName & arguments) const override
{
ColumnsWithTypeAndName arguments;
arguments.push_back(args[0]);
IColumn::Filter mask = getMaskFromColumn(args[0].column);
executeColumnIfNeeded(arguments[0]);
IColumn::Filter mask = getMaskFromColumn(arguments[0].column);
Field default_value = 0;
size_t i = 1;
while (i < args.size())
while (i < arguments.size())
{
IColumn::Filter cond_mask = getMaskFromColumn(arguments[i - 1].column);
arguments.push_back(maskedExecute(args[i], cond_mask));
++i;
if (isColumnFunction(*arguments[i].column))
{
IColumn::Filter cond_mask = getMaskFromColumn(arguments[i - 1].column);
maskedExecute(arguments[i], cond_mask);
}
arguments.push_back(maskedExecute(args[i], reverseMask(mask), &default_value));
if (i != args.size() - 1)
disjunctionMasks(mask, getMaskFromColumn(arguments.back().column));
++i;
if (isColumnFunction(*arguments[i].column))
maskedExecute(arguments[i], reverseMask(mask), &default_value);
if (i != arguments.size() - 1)
disjunctionMasks(mask, getMaskFromColumn(arguments[i].column));
++i;
}
return arguments;
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override
ColumnPtr executeImpl(const ColumnsWithTypeAndName & args, const DataTypePtr & result_type, size_t input_rows_count) const override
{
ColumnsWithTypeAndName args = checkForLazyArgumentsExecution(arguments);
/** We will gather values from columns in branches to result column,
* depending on values of conditions.
*/

View File

@ -1578,6 +1578,9 @@ ActionsDAG::SplitResult ActionsDAG::splitActionsForFilter(const std::string & co
"Index for ActionsDAG does not contain filter column name {}. DAG:\n{}",
column_name, dumpDAG());
if (node->type == ActionType::COLUMN_FUNCTION)
const_cast<Node *>(node)->type = ActionType::FUNCTION;
std::unordered_set<const Node *> split_nodes = {node};
auto res = split(split_nodes);
res.second->project_input = project_input;

View File

@ -338,7 +338,7 @@ namespace
};
}
static void executeAction(const ExpressionActions::Action & action, ExecutionContext & execution_context, bool dry_run)
static void executeAction(const ExpressionActions::Action & action, ExecutionContext & execution_context, bool dry_run, bool use_short_circuit_function_evaluation)
{
auto & inputs = execution_context.inputs;
auto & columns = execution_context.columns;
@ -359,6 +359,7 @@ static void executeAction(const ExpressionActions::Action & action, ExecutionCon
for (size_t i = 0; i < arguments.size(); ++i)
{
auto & column = columns[action.arguments[i].pos];
if (action.node->children[i]->type == ActionsDAG::ActionType::COLUMN_FUNCTION)
{
const ColumnFunction * column_function = typeid_cast<const ColumnFunction *>(column.column.get());
@ -376,6 +377,8 @@ static void executeAction(const ExpressionActions::Action & action, ExecutionCon
if (action.node->is_function_compiled)
ProfileEvents::increment(ProfileEvents::CompiledFunctionExecute);
if (action.node->function_base->isShortCircuit() && use_short_circuit_function_evaluation)
action.node->function_base->executeShortCircuitArguments(arguments);
res_column.column = action.node->function->execute(arguments, res_column.type, num_rows, dry_run);
break;
}
@ -484,8 +487,12 @@ static void executeAction(const ExpressionActions::Action & action, ExecutionCon
else
{
auto & column = inputs[pos];
if (const auto * col = typeid_cast<const ColumnFunction *>(inputs[pos].column.get()))
column.column = col->reduce(true).column;
if (!action.node->children.empty() && action.node->children.back()->type == ActionsDAG::ActionType::COLUMN_FUNCTION)
{
const ColumnFunction * column_function = typeid_cast<const ColumnFunction *>(column.column.get());
if (column_function)
column.column = column_function->reduce(true).column;
}
columns[action.result_position] = std::move(column);
}
@ -528,7 +535,7 @@ void ExpressionActions::execute(Block & block, size_t & num_rows, bool dry_run)
{
try
{
executeAction(action, execution_context, dry_run);
executeAction(action, execution_context, dry_run, settings.use_short_circuit_function_evaluation);
checkLimits(execution_context.columns);
//std::cerr << "Action: " << action.toString() << std::endl;

View File

@ -14,6 +14,7 @@ ExpressionActionsSettings ExpressionActionsSettings::fromSettings(const Settings
settings.max_temporary_columns = from.max_temporary_columns;
settings.max_temporary_non_const_columns = from.max_temporary_non_const_columns;
settings.compile_expressions = compile_expressions;
settings.use_short_circuit_function_evaluation = from.use_short_circuit_function_evaluation;
return settings;
}

View File

@ -25,6 +25,8 @@ struct ExpressionActionsSettings
CompileExpressions compile_expressions = CompileExpressions::no;
bool use_short_circuit_function_evaluation = true;
static ExpressionActionsSettings fromSettings(const Settings & from, CompileExpressions compile_expressions = CompileExpressions::no);
static ExpressionActionsSettings fromContext(ContextPtr from, CompileExpressions compile_expressions = CompileExpressions::no);
};

View File

@ -88,7 +88,7 @@ bool allowEarlyConstantFolding(const ActionsDAG & actions, const Settings & sett
for (const auto & node : actions.getNodes())
{
if (node.type == ActionsDAG::ActionType::FUNCTION && node.function_base)
if ((node.type == ActionsDAG::ActionType::FUNCTION || node.type == ActionsDAG::ActionType::COLUMN_FUNCTION) && node.function_base)
{
if (!node.function_base->isSuitableForConstantFolding())
return false;
@ -1578,8 +1578,7 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
optimize_read_in_order =
settings.optimize_read_in_order
&& storage
&& query.orderBy()
&& storage && query.orderBy()
&& !query_analyzer.hasAggregation()
&& !query_analyzer.hasWindow()
&& !query.final()