Distinguish between regular ColumnFunction and short curcuit argument

This commit is contained in:
Pavel Kruglov 2021-04-29 22:59:30 +03:00
parent e792fa588f
commit 807c6afab1
3 changed files with 33 additions and 28 deletions

View File

@ -15,10 +15,10 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
}
ColumnFunction::ColumnFunction(size_t size, FunctionBasePtr function_, const ColumnsWithTypeAndName & columns_to_capture, bool ignore_arguments_types)
: size_(size), function(function_)
ColumnFunction::ColumnFunction(size_t size, FunctionBasePtr function_, const ColumnsWithTypeAndName & columns_to_capture, bool is_short_circuit_argument_)
: size_(size), function(function_), is_short_circuit_argument(is_short_circuit_argument_)
{
appendArguments(columns_to_capture, ignore_arguments_types);
appendArguments(columns_to_capture);
}
MutableColumnPtr ColumnFunction::cloneResized(size_t size) const
@ -27,7 +27,7 @@ MutableColumnPtr ColumnFunction::cloneResized(size_t size) const
for (auto & column : capture)
column.column = column.column->cloneResized(size);
return ColumnFunction::create(size, function, capture);
return ColumnFunction::create(size, function, capture, is_short_circuit_argument);
}
ColumnPtr ColumnFunction::replicate(const Offsets & offsets) const
@ -41,7 +41,7 @@ ColumnPtr ColumnFunction::replicate(const Offsets & offsets) const
column.column = column.column->replicate(offsets);
size_t replicated_size = 0 == size_ ? 0 : offsets.back();
return ColumnFunction::create(replicated_size, function, capture);
return ColumnFunction::create(replicated_size, function, capture, is_short_circuit_argument);
}
ColumnPtr ColumnFunction::cut(size_t start, size_t length) const
@ -50,7 +50,7 @@ ColumnPtr ColumnFunction::cut(size_t start, size_t length) const
for (auto & column : capture)
column.column = column.column->cut(start, length);
return ColumnFunction::create(length, function, capture);
return ColumnFunction::create(length, function, capture, is_short_circuit_argument);
}
ColumnPtr ColumnFunction::filter(const Filter & filt, ssize_t result_size_hint, bool reverse) const
@ -73,7 +73,7 @@ ColumnPtr ColumnFunction::filter(const Filter & filt, ssize_t result_size_hint,
else
filtered_size = capture.front().column->size();
return ColumnFunction::create(filtered_size, function, capture);
return ColumnFunction::create(filtered_size, function, capture, is_short_circuit_argument);
}
void ColumnFunction::expand(const Filter & mask, bool reverse)
@ -102,7 +102,7 @@ ColumnPtr ColumnFunction::permute(const Permutation & perm, size_t limit) const
for (auto & column : capture)
column.column = column.column->permute(perm, limit);
return ColumnFunction::create(limit, function, capture);
return ColumnFunction::create(limit, function, capture, is_short_circuit_argument);
}
ColumnPtr ColumnFunction::index(const IColumn & indexes, size_t limit) const
@ -111,7 +111,7 @@ ColumnPtr ColumnFunction::index(const IColumn & indexes, size_t limit) const
for (auto & column : capture)
column.column = column.column->index(indexes, limit);
return ColumnFunction::create(limit, function, capture);
return ColumnFunction::create(limit, function, capture, is_short_circuit_argument);
}
std::vector<MutableColumnPtr> ColumnFunction::scatter(IColumn::ColumnIndex num_columns,
@ -140,7 +140,7 @@ std::vector<MutableColumnPtr> ColumnFunction::scatter(IColumn::ColumnIndex num_c
{
auto & capture = captures[part];
size_t capture_size = capture.empty() ? counts[part] : capture.front().column->size();
columns.emplace_back(ColumnFunction::create(capture_size, function, std::move(capture)));
columns.emplace_back(ColumnFunction::create(capture_size, function, std::move(capture), is_short_circuit_argument));
}
return columns;
@ -173,7 +173,7 @@ size_t ColumnFunction::allocatedBytes() const
return total_size;
}
void ColumnFunction::appendArguments(const ColumnsWithTypeAndName & columns, bool ignore_arguments_types)
void ColumnFunction::appendArguments(const ColumnsWithTypeAndName & columns)
{
auto args = function->getArgumentTypes().size();
auto were_captured = captured_columns.size();
@ -186,15 +186,15 @@ void ColumnFunction::appendArguments(const ColumnsWithTypeAndName & columns, boo
+ ".", ErrorCodes::LOGICAL_ERROR);
for (const auto & column : columns)
appendArgument(column, ignore_arguments_types);
appendArgument(column);
}
void ColumnFunction::appendArgument(const ColumnWithTypeAndName & column, bool ignore_argument_type)
void ColumnFunction::appendArgument(const ColumnWithTypeAndName & column)
{
const auto & argumnet_types = function->getArgumentTypes();
auto index = captured_columns.size();
if (!ignore_argument_type && !column.type->equals(*argumnet_types[index]))
if (!is_short_circuit_argument && !column.type->equals(*argumnet_types[index]))
throw Exception("Cannot capture column " + std::to_string(argumnet_types.size()) +
" because it has incompatible type: got " + column.type->getName() +
", but " + argumnet_types[index]->getName() + " is expected.", ErrorCodes::LOGICAL_ERROR);
@ -202,7 +202,7 @@ void ColumnFunction::appendArgument(const ColumnWithTypeAndName & column, bool i
captured_columns.push_back(column);
}
ColumnWithTypeAndName ColumnFunction::reduce(bool reduce_arguments) const
ColumnWithTypeAndName ColumnFunction::reduce() const
{
auto args = function->getArgumentTypes().size();
auto captured = captured_columns.size();
@ -212,14 +212,16 @@ ColumnWithTypeAndName ColumnFunction::reduce(bool reduce_arguments) const
"arguments but " + toString(captured) + " columns were captured.", ErrorCodes::LOGICAL_ERROR);
ColumnsWithTypeAndName columns = captured_columns;
if (function->isShortCircuit())
function->executeShortCircuitArguments(columns);
else if (reduce_arguments)
if (is_short_circuit_argument)
{
if (function->isShortCircuit())
function->executeShortCircuitArguments(columns);
const ColumnFunction * arg;
for (auto & col : columns)
{
if (const auto * column_function = typeid_cast<const ColumnFunction *>(col.column.get()))
col = column_function->reduce(true);
if ((arg = typeid_cast<const ColumnFunction *>(col.column.get())) && arg->isShortCircuitArgument())
col = arg->reduce();
}
}

View File

@ -25,7 +25,7 @@ class ColumnFunction final : public COWHelper<IColumn, ColumnFunction>
private:
friend class COWHelper<IColumn, ColumnFunction>;
ColumnFunction(size_t size, FunctionBasePtr function_, const ColumnsWithTypeAndName & columns_to_capture, bool ignore_arguments_types = false);
ColumnFunction(size_t size, FunctionBasePtr function_, const ColumnsWithTypeAndName & columns_to_capture, bool is_short_circuit_argument_ = false);
public:
const char * getFamilyName() const override { return "Function"; }
@ -51,8 +51,8 @@ public:
size_t byteSizeAt(size_t n) const override;
size_t allocatedBytes() const override;
void appendArguments(const ColumnsWithTypeAndName & columns, bool ignore_arguments_types = false);
ColumnWithTypeAndName reduce(bool reduce_arguments = false) const;
void appendArguments(const ColumnsWithTypeAndName & columns);
ColumnWithTypeAndName reduce() const;
Field operator[](size_t) const override
{
@ -154,12 +154,15 @@ public:
throw Exception("Method gather is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}
bool isShortCircuitArgument() const { return is_short_circuit_argument; }
private:
size_t size_;
FunctionBasePtr function;
ColumnsWithTypeAndName captured_columns;
bool is_short_circuit_argument;
void appendArgument(const ColumnWithTypeAndName & column, bool ignore_argument_type = false);
void appendArgument(const ColumnWithTypeAndName & column);
};
}

View File

@ -226,11 +226,11 @@ void disjunctionMasks(PaddedPODArray<UInt8> & mask1, const PaddedPODArray<UInt8>
void maskedExecute(ColumnWithTypeAndName & column, const PaddedPODArray<UInt8> & mask, bool reverse, const UInt8 * default_value_for_expanding_mask)
{
const auto * column_function = checkAndGetColumn<ColumnFunction>(*column.column);
if (!column_function)
if (!column_function || !column_function->isShortCircuitArgument())
return;
auto filtered = column_function->filter(mask, -1, reverse);
auto result = typeid_cast<const ColumnFunction *>(filtered.get())->reduce(true);
auto result = typeid_cast<const ColumnFunction *>(filtered.get())->reduce();
if (default_value_for_expanding_mask)
{
result.column = result.column->convertToFullColumnIfLowCardinality();
@ -245,10 +245,10 @@ void maskedExecute(ColumnWithTypeAndName & column, const PaddedPODArray<UInt8> &
void executeColumnIfNeeded(ColumnWithTypeAndName & column)
{
const auto * column_function = checkAndGetColumn<ColumnFunction>(*column.column);
if (!column_function)
if (!column_function || !column_function->isShortCircuitArgument())
return;
column = typeid_cast<const ColumnFunction *>(column_function)->reduce(true);
column = typeid_cast<const ColumnFunction *>(column_function)->reduce();
}
bool checkArgumentsForColumnFunction(const ColumnsWithTypeAndName & arguments)