Distinguish between regular ColumnFunction and short curcuit argument

This commit is contained in:
Pavel Kruglov 2021-04-29 22:59:30 +03:00
parent e792fa588f
commit 807c6afab1
3 changed files with 33 additions and 28 deletions

View File

@ -15,10 +15,10 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR; extern const int LOGICAL_ERROR;
} }
ColumnFunction::ColumnFunction(size_t size, FunctionBasePtr function_, const ColumnsWithTypeAndName & columns_to_capture, bool ignore_arguments_types) ColumnFunction::ColumnFunction(size_t size, FunctionBasePtr function_, const ColumnsWithTypeAndName & columns_to_capture, bool is_short_circuit_argument_)
: size_(size), function(function_) : size_(size), function(function_), is_short_circuit_argument(is_short_circuit_argument_)
{ {
appendArguments(columns_to_capture, ignore_arguments_types); appendArguments(columns_to_capture);
} }
MutableColumnPtr ColumnFunction::cloneResized(size_t size) const MutableColumnPtr ColumnFunction::cloneResized(size_t size) const
@ -27,7 +27,7 @@ MutableColumnPtr ColumnFunction::cloneResized(size_t size) const
for (auto & column : capture) for (auto & column : capture)
column.column = column.column->cloneResized(size); column.column = column.column->cloneResized(size);
return ColumnFunction::create(size, function, capture); return ColumnFunction::create(size, function, capture, is_short_circuit_argument);
} }
ColumnPtr ColumnFunction::replicate(const Offsets & offsets) const ColumnPtr ColumnFunction::replicate(const Offsets & offsets) const
@ -41,7 +41,7 @@ ColumnPtr ColumnFunction::replicate(const Offsets & offsets) const
column.column = column.column->replicate(offsets); column.column = column.column->replicate(offsets);
size_t replicated_size = 0 == size_ ? 0 : offsets.back(); size_t replicated_size = 0 == size_ ? 0 : offsets.back();
return ColumnFunction::create(replicated_size, function, capture); return ColumnFunction::create(replicated_size, function, capture, is_short_circuit_argument);
} }
ColumnPtr ColumnFunction::cut(size_t start, size_t length) const ColumnPtr ColumnFunction::cut(size_t start, size_t length) const
@ -50,7 +50,7 @@ ColumnPtr ColumnFunction::cut(size_t start, size_t length) const
for (auto & column : capture) for (auto & column : capture)
column.column = column.column->cut(start, length); column.column = column.column->cut(start, length);
return ColumnFunction::create(length, function, capture); return ColumnFunction::create(length, function, capture, is_short_circuit_argument);
} }
ColumnPtr ColumnFunction::filter(const Filter & filt, ssize_t result_size_hint, bool reverse) const ColumnPtr ColumnFunction::filter(const Filter & filt, ssize_t result_size_hint, bool reverse) const
@ -73,7 +73,7 @@ ColumnPtr ColumnFunction::filter(const Filter & filt, ssize_t result_size_hint,
else else
filtered_size = capture.front().column->size(); filtered_size = capture.front().column->size();
return ColumnFunction::create(filtered_size, function, capture); return ColumnFunction::create(filtered_size, function, capture, is_short_circuit_argument);
} }
void ColumnFunction::expand(const Filter & mask, bool reverse) void ColumnFunction::expand(const Filter & mask, bool reverse)
@ -102,7 +102,7 @@ ColumnPtr ColumnFunction::permute(const Permutation & perm, size_t limit) const
for (auto & column : capture) for (auto & column : capture)
column.column = column.column->permute(perm, limit); column.column = column.column->permute(perm, limit);
return ColumnFunction::create(limit, function, capture); return ColumnFunction::create(limit, function, capture, is_short_circuit_argument);
} }
ColumnPtr ColumnFunction::index(const IColumn & indexes, size_t limit) const ColumnPtr ColumnFunction::index(const IColumn & indexes, size_t limit) const
@ -111,7 +111,7 @@ ColumnPtr ColumnFunction::index(const IColumn & indexes, size_t limit) const
for (auto & column : capture) for (auto & column : capture)
column.column = column.column->index(indexes, limit); column.column = column.column->index(indexes, limit);
return ColumnFunction::create(limit, function, capture); return ColumnFunction::create(limit, function, capture, is_short_circuit_argument);
} }
std::vector<MutableColumnPtr> ColumnFunction::scatter(IColumn::ColumnIndex num_columns, std::vector<MutableColumnPtr> ColumnFunction::scatter(IColumn::ColumnIndex num_columns,
@ -140,7 +140,7 @@ std::vector<MutableColumnPtr> ColumnFunction::scatter(IColumn::ColumnIndex num_c
{ {
auto & capture = captures[part]; auto & capture = captures[part];
size_t capture_size = capture.empty() ? counts[part] : capture.front().column->size(); size_t capture_size = capture.empty() ? counts[part] : capture.front().column->size();
columns.emplace_back(ColumnFunction::create(capture_size, function, std::move(capture))); columns.emplace_back(ColumnFunction::create(capture_size, function, std::move(capture), is_short_circuit_argument));
} }
return columns; return columns;
@ -173,7 +173,7 @@ size_t ColumnFunction::allocatedBytes() const
return total_size; return total_size;
} }
void ColumnFunction::appendArguments(const ColumnsWithTypeAndName & columns, bool ignore_arguments_types) void ColumnFunction::appendArguments(const ColumnsWithTypeAndName & columns)
{ {
auto args = function->getArgumentTypes().size(); auto args = function->getArgumentTypes().size();
auto were_captured = captured_columns.size(); auto were_captured = captured_columns.size();
@ -186,15 +186,15 @@ void ColumnFunction::appendArguments(const ColumnsWithTypeAndName & columns, boo
+ ".", ErrorCodes::LOGICAL_ERROR); + ".", ErrorCodes::LOGICAL_ERROR);
for (const auto & column : columns) for (const auto & column : columns)
appendArgument(column, ignore_arguments_types); appendArgument(column);
} }
void ColumnFunction::appendArgument(const ColumnWithTypeAndName & column, bool ignore_argument_type) void ColumnFunction::appendArgument(const ColumnWithTypeAndName & column)
{ {
const auto & argumnet_types = function->getArgumentTypes(); const auto & argumnet_types = function->getArgumentTypes();
auto index = captured_columns.size(); auto index = captured_columns.size();
if (!ignore_argument_type && !column.type->equals(*argumnet_types[index])) if (!is_short_circuit_argument && !column.type->equals(*argumnet_types[index]))
throw Exception("Cannot capture column " + std::to_string(argumnet_types.size()) + throw Exception("Cannot capture column " + std::to_string(argumnet_types.size()) +
" because it has incompatible type: got " + column.type->getName() + " because it has incompatible type: got " + column.type->getName() +
", but " + argumnet_types[index]->getName() + " is expected.", ErrorCodes::LOGICAL_ERROR); ", but " + argumnet_types[index]->getName() + " is expected.", ErrorCodes::LOGICAL_ERROR);
@ -202,7 +202,7 @@ void ColumnFunction::appendArgument(const ColumnWithTypeAndName & column, bool i
captured_columns.push_back(column); captured_columns.push_back(column);
} }
ColumnWithTypeAndName ColumnFunction::reduce(bool reduce_arguments) const ColumnWithTypeAndName ColumnFunction::reduce() const
{ {
auto args = function->getArgumentTypes().size(); auto args = function->getArgumentTypes().size();
auto captured = captured_columns.size(); auto captured = captured_columns.size();
@ -212,14 +212,16 @@ ColumnWithTypeAndName ColumnFunction::reduce(bool reduce_arguments) const
"arguments but " + toString(captured) + " columns were captured.", ErrorCodes::LOGICAL_ERROR); "arguments but " + toString(captured) + " columns were captured.", ErrorCodes::LOGICAL_ERROR);
ColumnsWithTypeAndName columns = captured_columns; ColumnsWithTypeAndName columns = captured_columns;
if (function->isShortCircuit()) if (is_short_circuit_argument)
function->executeShortCircuitArguments(columns);
else if (reduce_arguments)
{ {
if (function->isShortCircuit())
function->executeShortCircuitArguments(columns);
const ColumnFunction * arg;
for (auto & col : columns) for (auto & col : columns)
{ {
if (const auto * column_function = typeid_cast<const ColumnFunction *>(col.column.get())) if ((arg = typeid_cast<const ColumnFunction *>(col.column.get())) && arg->isShortCircuitArgument())
col = column_function->reduce(true); col = arg->reduce();
} }
} }

View File

@ -25,7 +25,7 @@ class ColumnFunction final : public COWHelper<IColumn, ColumnFunction>
private: private:
friend class COWHelper<IColumn, ColumnFunction>; friend class COWHelper<IColumn, ColumnFunction>;
ColumnFunction(size_t size, FunctionBasePtr function_, const ColumnsWithTypeAndName & columns_to_capture, bool ignore_arguments_types = false); ColumnFunction(size_t size, FunctionBasePtr function_, const ColumnsWithTypeAndName & columns_to_capture, bool is_short_circuit_argument_ = false);
public: public:
const char * getFamilyName() const override { return "Function"; } const char * getFamilyName() const override { return "Function"; }
@ -51,8 +51,8 @@ public:
size_t byteSizeAt(size_t n) const override; size_t byteSizeAt(size_t n) const override;
size_t allocatedBytes() const override; size_t allocatedBytes() const override;
void appendArguments(const ColumnsWithTypeAndName & columns, bool ignore_arguments_types = false); void appendArguments(const ColumnsWithTypeAndName & columns);
ColumnWithTypeAndName reduce(bool reduce_arguments = false) const; ColumnWithTypeAndName reduce() const;
Field operator[](size_t) const override Field operator[](size_t) const override
{ {
@ -154,12 +154,15 @@ public:
throw Exception("Method gather is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); throw Exception("Method gather is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
} }
bool isShortCircuitArgument() const { return is_short_circuit_argument; }
private: private:
size_t size_; size_t size_;
FunctionBasePtr function; FunctionBasePtr function;
ColumnsWithTypeAndName captured_columns; ColumnsWithTypeAndName captured_columns;
bool is_short_circuit_argument;
void appendArgument(const ColumnWithTypeAndName & column, bool ignore_argument_type = false); void appendArgument(const ColumnWithTypeAndName & column);
}; };
} }

View File

@ -226,11 +226,11 @@ void disjunctionMasks(PaddedPODArray<UInt8> & mask1, const PaddedPODArray<UInt8>
void maskedExecute(ColumnWithTypeAndName & column, const PaddedPODArray<UInt8> & mask, bool reverse, const UInt8 * default_value_for_expanding_mask) void maskedExecute(ColumnWithTypeAndName & column, const PaddedPODArray<UInt8> & mask, bool reverse, const UInt8 * default_value_for_expanding_mask)
{ {
const auto * column_function = checkAndGetColumn<ColumnFunction>(*column.column); const auto * column_function = checkAndGetColumn<ColumnFunction>(*column.column);
if (!column_function) if (!column_function || !column_function->isShortCircuitArgument())
return; return;
auto filtered = column_function->filter(mask, -1, reverse); auto filtered = column_function->filter(mask, -1, reverse);
auto result = typeid_cast<const ColumnFunction *>(filtered.get())->reduce(true); auto result = typeid_cast<const ColumnFunction *>(filtered.get())->reduce();
if (default_value_for_expanding_mask) if (default_value_for_expanding_mask)
{ {
result.column = result.column->convertToFullColumnIfLowCardinality(); result.column = result.column->convertToFullColumnIfLowCardinality();
@ -245,10 +245,10 @@ void maskedExecute(ColumnWithTypeAndName & column, const PaddedPODArray<UInt8> &
void executeColumnIfNeeded(ColumnWithTypeAndName & column) void executeColumnIfNeeded(ColumnWithTypeAndName & column)
{ {
const auto * column_function = checkAndGetColumn<ColumnFunction>(*column.column); const auto * column_function = checkAndGetColumn<ColumnFunction>(*column.column);
if (!column_function) if (!column_function || !column_function->isShortCircuitArgument())
return; return;
column = typeid_cast<const ColumnFunction *>(column_function)->reduce(true); column = typeid_cast<const ColumnFunction *>(column_function)->reduce();
} }
bool checkArgumentsForColumnFunction(const ColumnsWithTypeAndName & arguments) bool checkArgumentsForColumnFunction(const ColumnsWithTypeAndName & arguments)