Merge pull request #1879 from yandex/lambdas-without-prerequisites

Lambdas without prerequisites
This commit is contained in:
alexey-milovidov 2018-02-09 22:22:58 +03:00 committed by GitHub
commit 8fb9967903
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
40 changed files with 1192 additions and 695 deletions

View File

@ -18,13 +18,13 @@
#include <Common/FieldVisitors.h>
#include <DataTypes/FieldToDataType.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeExpression.h>
#include <DataTypes/DataTypesNumber.h>
#include <AggregateFunctions/parseAggregateFunctionParameters.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <Functions/FunctionFactory.h>
#include <Functions/IFunction.h>
#include <Common/typeid_cast.h>
#include <DataTypes/DataTypeFunction.h>
namespace DB
@ -127,20 +127,21 @@ void processFunction(const String & column_name, ASTPtr & ast, TypeAndConstantIn
{
ASTFunction * function = static_cast<ASTFunction *>(ast.get());
/// Special case for lambda functions. Lambda function has special return type "Expression".
/// We first create info with Expression of unspecified arguments, and will specify them later.
/// Special case for lambda functions. Lambda function has special return type "Function".
/// We first create info with Function of unspecified arguments, and will specify them later.
if (function->name == "lambda")
{
size_t number_of_lambda_parameters = AnalyzeLambdas::extractLambdaParameters(function->arguments->children.at(0)).size();
TypeAndConstantInference::ExpressionInfo expression_info;
expression_info.node = ast;
expression_info.data_type = std::make_unique<DataTypeExpression>(DataTypes(number_of_lambda_parameters));
expression_info.data_type = std::make_unique<DataTypeFunction>(DataTypes(number_of_lambda_parameters));
info.emplace(column_name, std::move(expression_info));
return;
}
DataTypes argument_types;
ColumnsWithTypeAndName argument_columns;
if (function->arguments)
{
@ -151,6 +152,9 @@ void processFunction(const String & column_name, ASTPtr & ast, TypeAndConstantIn
throw Exception("Logical error: type of function argument was not inferred during depth-first search", ErrorCodes::LOGICAL_ERROR);
argument_types.emplace_back(it->second.data_type);
argument_columns.emplace_back(ColumnWithTypeAndName(nullptr, it->second.data_type, ""));
if (it->second.is_constant_expression)
argument_columns.back().column = it->second.data_type->createColumnConst(1, it->second.value);
}
}
@ -203,7 +207,7 @@ void processFunction(const String & column_name, ASTPtr & ast, TypeAndConstantIn
return;
}
const FunctionPtr & function_ptr = FunctionFactory::instance().get(function->name, context);
const auto & function_builder_ptr = FunctionFactory::instance().get(function->name, context);
/// (?) Replace function name to canonical one. Because same function could be referenced by different names.
// function->name = function_ptr->getName();
@ -228,10 +232,12 @@ void processFunction(const String & column_name, ASTPtr & ast, TypeAndConstantIn
}
}
auto function_ptr = function_builder_ptr->build(argument_columns);
TypeAndConstantInference::ExpressionInfo expression_info;
expression_info.node = ast;
expression_info.function = function_ptr;
expression_info.data_type = function_ptr->getReturnType(argument_types);
expression_info.data_type = function_ptr->getReturnType();
if (all_consts && function_ptr->isSuitableForConstantFolding())
{
@ -325,7 +331,7 @@ void processHigherOrderFunction(
{
ASTFunction * function = static_cast<ASTFunction *>(ast.get());
const FunctionPtr & function_ptr = FunctionFactory::instance().get(function->name, context);
const auto & function_builder_ptr = FunctionFactory::instance().get(function->name, context);
if (!function->arguments)
throw Exception("Unexpected AST for higher-order function", ErrorCodes::UNEXPECTED_AST_STRUCTURE);
@ -339,7 +345,7 @@ void processHigherOrderFunction(
types.emplace_back(child_info.data_type);
}
function_ptr->getLambdaArgumentTypes(types);
function_builder_ptr->getLambdaArgumentTypes(types);
/// For every lambda expression, dive into it.
@ -353,11 +359,11 @@ void processHigherOrderFunction(
const ASTFunction * lambda = typeid_cast<const ASTFunction *>(child.get());
if (lambda && lambda->name == "lambda")
{
const DataTypeExpression * lambda_type = typeid_cast<const DataTypeExpression *>(types[i].get());
const auto * lambda_type = typeid_cast<const DataTypeFunction *>(types[i].get());
if (!lambda_type)
throw Exception("Logical error: IFunction::getLambdaArgumentTypes returned data type for lambda expression,"
" that is not DataTypeExpression", ErrorCodes::LOGICAL_ERROR);
" that is not DataTypeFunction", ErrorCodes::LOGICAL_ERROR);
if (!lambda->arguments || lambda->arguments->children.size() != 2)
throw Exception("Lambda function must have exactly two arguments (sides of arrow)", ErrorCodes::BAD_LAMBDA);
@ -390,7 +396,7 @@ void processHigherOrderFunction(
/// Update Expression type (expression signature).
info.at(lambda->getColumnName()).data_type = std::make_shared<DataTypeExpression>(
info.at(lambda->getColumnName()).data_type = std::make_shared<DataTypeFunction>(
lambda_argument_types, info.at(lambda->arguments->children[1]->getColumnName()).data_type);
}
}

View File

@ -15,7 +15,7 @@ struct CollectAliases;
struct AnalyzeColumns;
struct AnalyzeLambdas;
struct ExecuteTableFunctions;
class IFunction;
class IFunctionBase;
class IAggregateFunction;
@ -46,7 +46,7 @@ struct TypeAndConstantInference
DataTypePtr data_type;
bool is_constant_expression = false;
Field value; /// Has meaning if is_constant_expression == true.
std::shared_ptr<IFunction> function;
std::shared_ptr<IFunctionBase> function;
std::shared_ptr<IAggregateFunction> aggregate_function;
};

View File

@ -69,9 +69,7 @@ MutableColumns ColumnConst::scatter(ColumnIndex num_columns, const Selector & se
throw Exception("Size of selector (" + toString(selector.size()) + ") doesn't match size of column (" + toString(s) + ")",
ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);
std::vector<size_t> counts(num_columns);
for (auto idx : selector)
++counts[idx];
std::vector<size_t> counts = countColumnsSizeInSelector(num_columns, selector);
MutableColumns res(num_columns);
for (size_t i = 0; i < num_columns; ++i)

View File

@ -1,25 +0,0 @@
#include <Interpreters/ExpressionActions.h>
#include <Columns/ColumnExpression.h>
namespace DB
{
/// Stores the lambda's expression, its argument list and its result type/name.
/// The row count lives in the `s` member, which is assigned in the body rather than in the initializer list.
ColumnExpression::ColumnExpression(
    size_t s_,
    const ExpressionActionsPtr & expression_,
    const NamesAndTypesList & arguments_,
    const DataTypePtr & return_type_,
    const String & return_name_)
    : expression(expression_)
    , arguments(arguments_)
    , return_type(return_type_)
    , return_name(return_name_)
{
    this->s = s_;
}
/// Builds a column of `s_` rows that carries the same expression, arguments, return type and return name.
MutableColumnPtr ColumnExpression::cloneDummy(size_t s_) const
{
    MutableColumnPtr cloned = ColumnExpression::create(s_, expression, arguments, return_type, return_name);
    return cloned;
}
/// Read-only access to the stored expression actions.
const ExpressionActionsPtr & ColumnExpression::getExpression() const
{
    return this->expression;
}
/// Read-only access to the stored return type.
const DataTypePtr & ColumnExpression::getReturnType() const
{
    return this->return_type;
}
const std::string & ColumnExpression::getReturnName() const { return return_name; }
/// Read-only access to the stored list of argument names and types.
const NamesAndTypesList & ColumnExpression::getArguments() const
{
    return this->arguments;
}
}

View File

@ -1,41 +0,0 @@
#pragma once
#include <Core/NamesAndTypes.h>
#include <Columns/IColumnDummy.h>
namespace DB
{
class ExpressionActions;
/** A column containing a lambda expression.
* Behaves like a constant-column. Contains an expression, but not input or output data.
* Besides the expression it stores the lambda's argument list and the type and name of its result.
*/
class ColumnExpression final : public COWPtrHelper<IColumnDummy, ColumnExpression>
{
private:
/// The helper base is befriended because the constructors below are private
/// (presumably instances are meant to be created through the helper's create()).
friend class COWPtrHelper<IColumnDummy, ColumnExpression>;
using ExpressionActionsPtr = std::shared_ptr<ExpressionActions>;
/// s_ is the number of rows; the remaining parameters are stored as-is in the members declared at the bottom.
ColumnExpression(size_t s_, const ExpressionActionsPtr & expression_, const NamesAndTypesList & arguments_, const DataTypePtr & return_type_, const String & return_name_);
ColumnExpression(const ColumnExpression &) = default;
public:
const char * getFamilyName() const override { return "Expression"; }
/// Creates a column of s_ rows carrying the same expression, arguments, return type and return name.
MutableColumnPtr cloneDummy(size_t s_) const override;
/// Plain read-only accessors for the stored members.
const ExpressionActionsPtr & getExpression() const;
const DataTypePtr & getReturnType() const;
const std::string & getReturnName() const;
const NamesAndTypesList & getArguments() const;
private:
/// NOTE: the constructor initializes these members in exactly this order.
ExpressionActionsPtr expression;
NamesAndTypesList arguments;
DataTypePtr return_type;
std::string return_name;
};
}

View File

@ -0,0 +1,202 @@
#include <Interpreters/ExpressionActions.h>
#include <Columns/ColumnFunction.h>
#include <Columns/ColumnsCommon.h>
#include <Functions/IFunction.h>
namespace DB
{
/// Error codes used in this translation unit.
/// SIZES_OF_COLUMNS_DOESNT_MATCH is thrown by replicate/filter/permute/scatter below, so it is
/// declared here explicitly instead of relying on some included header to declare it.
namespace ErrorCodes
{
    extern const int LOGICAL_ERROR;
    extern const int SIZES_OF_COLUMNS_DOESNT_MATCH;
}
ColumnFunction::ColumnFunction(size_t size, FunctionBasePtr function, const ColumnsWithTypeAndName & columns_to_capture)
: size_(size), function(function)
{
appendArguments(columns_to_capture);
}
/// Returns a function column of `size` rows in which every captured column has been cloneResized to `size`.
MutableColumnPtr ColumnFunction::cloneResized(size_t size) const
{
    ColumnsWithTypeAndName resized_captures = captured_columns;

    for (size_t i = 0; i < resized_captures.size(); ++i)
    {
        auto & captured = resized_captures[i];
        captured.column = captured.column->cloneResized(size);
    }

    return ColumnFunction::create(size, function, resized_captures);
}
/// Replicates every captured column according to `offsets`.
/// Throws SIZES_OF_COLUMNS_DOESNT_MATCH if `offsets` does not have one entry per row.
/// The resulting row count is the last offset (or 0 for an empty column).
MutableColumnPtr ColumnFunction::replicate(const Offsets & offsets) const
{
    const bool sizes_match = (size_ == offsets.size());
    if (!sizes_match)
        throw Exception("Size of offsets (" + toString(offsets.size()) + ") doesn't match size of column ("
            + toString(size_) + ")", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);

    ColumnsWithTypeAndName replicated_captures = captured_columns;
    for (auto & entry : replicated_captures)
        entry.column = entry.column->replicate(offsets);

    size_t replicated_size = 0;
    if (size_ != 0)
        replicated_size = offsets.back();

    return ColumnFunction::create(replicated_size, function, replicated_captures);
}
/// Returns a function column of `length` rows whose captured columns are the [start, start + length) cuts
/// of the current ones.
MutableColumnPtr ColumnFunction::cut(size_t start, size_t length) const
{
    ColumnsWithTypeAndName pieces = captured_columns;

    for (size_t i = 0, count = pieces.size(); i != count; ++i)
        pieces[i].column = pieces[i].column->cut(start, length);

    return ColumnFunction::create(length, function, pieces);
}
/// Applies `filt` to every captured column.
/// Throws SIZES_OF_COLUMNS_DOESNT_MATCH if the filter does not have one entry per row.
MutableColumnPtr ColumnFunction::filter(const Filter & filt, ssize_t result_size_hint) const
{
    if (filt.size() != size_)
        throw Exception("Size of filter (" + toString(filt.size()) + ") doesn't match size of column ("
            + toString(size_) + ")", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);

    ColumnsWithTypeAndName filtered_captures = captured_columns;
    for (auto & entry : filtered_captures)
        entry.column = entry.column->filter(filt, result_size_hint);

    /// Without captured columns there is nothing to read the new row count from, so count it from the filter.
    const size_t filtered_size = filtered_captures.empty()
        ? countBytesInFilter(filt)
        : filtered_captures.front().column->size();

    return ColumnFunction::create(filtered_size, function, filtered_captures);
}
/// Permutes every captured column with `perm`, keeping at most `limit` rows (`limit` == 0 means all rows).
/// Throws SIZES_OF_COLUMNS_DOESNT_MATCH if the permutation is shorter than the number of rows to produce.
MutableColumnPtr ColumnFunction::permute(const Permutation & perm, size_t limit) const
{
    limit = (limit == 0) ? size_ : std::min(size_, limit);

    const bool permutation_too_short = perm.size() < limit;
    if (permutation_too_short)
        throw Exception("Size of permutation (" + toString(perm.size()) + ") is less than required ("
            + toString(limit) + ")", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);

    ColumnsWithTypeAndName permuted_captures = captured_columns;
    for (size_t i = 0; i < permuted_captures.size(); ++i)
        permuted_captures[i].column = permuted_captures[i].column->permute(perm, limit);

    return ColumnFunction::create(limit, function, permuted_captures);
}
/// Splits the column into `num_columns` function columns by scattering every captured column with `selector`.
/// Throws SIZES_OF_COLUMNS_DOESNT_MATCH if the selector does not have one entry per row.
std::vector<MutableColumnPtr> ColumnFunction::scatter(IColumn::ColumnIndex num_columns,
                                                      const IColumn::Selector & selector) const
{
    if (selector.size() != size_)
        throw Exception("Size of selector (" + toString(selector.size()) + ") doesn't match size of column ("
            + toString(size_) + ")", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);

    /// Row counts per part are only needed when there is no captured column to take the size from.
    std::vector<size_t> rows_per_part;
    if (captured_columns.empty())
        rows_per_part = countColumnsSizeInSelector(num_columns, selector);

    /// Start every part from a copy of the captured columns, then replace each column with its scattered piece.
    std::vector<ColumnsWithTypeAndName> captures_per_part(num_columns, captured_columns);
    for (size_t idx = 0; idx < captured_columns.size(); ++idx)
    {
        auto pieces = captured_columns[idx].column->scatter(num_columns, selector);
        for (IColumn::ColumnIndex part = 0; part < num_columns; ++part)
            captures_per_part[part][idx].column = std::move(pieces[part]);
    }

    std::vector<MutableColumnPtr> result;
    result.reserve(num_columns);
    for (IColumn::ColumnIndex part = 0; part < num_columns; ++part)
    {
        auto & part_captures = captures_per_part[part];

        size_t part_size;
        if (part_captures.empty())
            part_size = rows_per_part[part];
        else
            part_size = part_captures.front().column->size();

        result.emplace_back(ColumnFunction::create(part_size, function, std::move(part_captures)));
    }

    return result;
}
/// Appends one default row to every captured column and grows the row count by one.
void ColumnFunction::insertDefault()
{
    for (size_t i = 0; i < captured_columns.size(); ++i)
        captured_columns[i].column->assumeMutable()->insertDefault();

    size_ += 1;
}
/// Removes the last `n` rows from every captured column and shrinks the row count by `n`.
void ColumnFunction::popBack(size_t n)
{
    for (size_t i = 0; i < captured_columns.size(); ++i)
        captured_columns[i].column->assumeMutable()->popBack(n);

    size_ = size_ - n;
}
/// Sum of byteSize() over all captured columns.
size_t ColumnFunction::byteSize() const
{
    size_t bytes = 0;
    for (size_t i = 0; i < captured_columns.size(); ++i)
        bytes += captured_columns[i].column->byteSize();
    return bytes;
}
/// Sum of allocatedBytes() over all captured columns.
size_t ColumnFunction::allocatedBytes() const
{
    size_t bytes = 0;
    for (size_t i = 0; i < captured_columns.size(); ++i)
        bytes += captured_columns[i].column->allocatedBytes();
    return bytes;
}
/// Captures `columns` as the next arguments of the function, in order.
/// Throws LOGICAL_ERROR if that would capture more columns than the function has arguments.
void ColumnFunction::appendArguments(const ColumnsWithTypeAndName & columns)
{
    const auto expected_count = function->getArgumentTypes().size();
    const auto already_captured = captured_columns.size();
    const auto incoming = columns.size();

    if (already_captured + incoming > expected_count)
    {
        auto already_captured_note = already_captured
            ? " and " + toString(already_captured) + " columns have already been captured"
            : "";

        throw Exception("Cannot capture " + toString(incoming) + " columns because function " + function->getName()
            + " has " + toString(expected_count) + " arguments" + already_captured_note + ".",
            ErrorCodes::LOGICAL_ERROR);
    }

    for (const auto & entry : columns)
        appendArgument(entry);
}
/// Captures a single column as the next argument of the function.
/// The column's type must equal the type of the argument it is captured for; otherwise LOGICAL_ERROR is thrown.
/// The caller (appendArguments) has already checked that there is still an argument slot left.
void ColumnFunction::appendArgument(const ColumnWithTypeAndName & column)
{
    const auto & argument_types = function->getArgumentTypes();

    /// Columns are captured left to right, so the next argument position is the number captured so far.
    auto index = captured_columns.size();

    if (!column.type->equals(*argument_types[index]))
        throw Exception("Cannot capture column " + std::to_string(index) +
            " because it has incompatible type: got " + column.type->getName() +
            ", but " + argument_types[index]->getName() + " is expected.", ErrorCodes::LOGICAL_ERROR);

    captured_columns.push_back(column);
}
/// Executes the function on the captured columns and returns the resulting column.
/// Throws LOGICAL_ERROR unless exactly as many columns were captured as the function has arguments.
ColumnWithTypeAndName ColumnFunction::reduce() const
{
    auto args = function->getArgumentTypes().size();
    auto captured = captured_columns.size();

    if (args != captured)
        throw Exception("Cannot call function " + function->getName() + " because it has " + toString(args) +
            " arguments but " + toString(captured) + " columns were captured.", ErrorCodes::LOGICAL_ERROR);

    /// The block holds the captured columns at positions [0, captured) followed by one slot for the result.
    Block block(captured_columns);
    block.insert({nullptr, function->getReturnType(), ""});

    ColumnNumbers arguments(captured_columns.size());
    for (size_t i = 0; i < captured_columns.size(); ++i)
        arguments[i] = i;

    function->execute(block, arguments, captured_columns.size());

    return block.getByPosition(captured_columns.size());
}
}

View File

@ -0,0 +1,116 @@
#pragma once
#include <Core/NamesAndTypes.h>
#include <Columns/IColumn.h>
namespace DB
{

namespace ErrorCodes
{
    extern const int NOT_IMPLEMENTED;
}

/// Forward declarations live inside namespace DB so that they name the same entities that the rest
/// of the code in this namespace refers to as IFunctionBase / FunctionBasePtr.
class IFunctionBase;
using FunctionBasePtr = std::shared_ptr<IFunctionBase>;

/** A column containing a function (for example, a lambda expression) together with the columns
  * that have been captured as its arguments so far.
  * Row-wise transformations (cut, filter, permute, replicate, scatter, ...) are forwarded to every
  * captured column. Values cannot be read from or inserted into the column directly; once all
  * arguments have been captured, reduce() executes the function and returns the result column.
  */
class ColumnFunction final : public COWPtrHelper<IColumn, ColumnFunction>
{
private:
    friend class COWPtrHelper<IColumn, ColumnFunction>;

    /// `size` is the number of rows; `columns_to_capture` are captured immediately as the first arguments.
    ColumnFunction(size_t size, FunctionBasePtr function, const ColumnsWithTypeAndName & columns_to_capture);

public:
    const char * getFamilyName() const override { return "Function"; }

    MutableColumnPtr cloneResized(size_t size) const override;

    size_t size() const override { return size_; }

    /// These operations are applied to each captured column; the function itself is shared with the result.
    MutableColumnPtr cut(size_t start, size_t length) const override;
    MutableColumnPtr replicate(const Offsets & offsets) const override;
    MutableColumnPtr filter(const Filter & filt, ssize_t result_size_hint) const override;
    MutableColumnPtr permute(const Permutation & perm, size_t limit) const override;
    void insertDefault() override;
    void popBack(size_t n) override;
    std::vector<MutableColumnPtr> scatter(IColumn::ColumnIndex num_columns,
                                          const IColumn::Selector & selector) const override;

    /// Intentionally does nothing: both fields are left untouched.
    void getExtremes(Field &, Field &) const override {}

    /// Both are sums over the captured columns.
    size_t byteSize() const override;
    size_t allocatedBytes() const override;

    /// Captures more columns as the next arguments of the function.
    void appendArguments(const ColumnsWithTypeAndName & columns);

    /// Executes the function on the captured columns and returns the result.
    ColumnWithTypeAndName reduce() const;

    /// Everything below is not supported for this column type and throws NOT_IMPLEMENTED.

    Field operator[](size_t) const override
    {
        throw Exception("Cannot get value from " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    void get(size_t, Field &) const override
    {
        throw Exception("Cannot get value from " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    StringRef getDataAt(size_t) const override
    {
        throw Exception("Cannot get value from " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    void insert(const Field &) override
    {
        throw Exception("Cannot insert into " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    void insertRangeFrom(const IColumn &, size_t, size_t) override
    {
        throw Exception("Cannot insert into " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    void insertData(const char *, size_t) override
    {
        throw Exception("Cannot insert into " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    StringRef serializeValueIntoArena(size_t, Arena &, char const *&) const override
    {
        throw Exception("Cannot serialize from " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    const char * deserializeAndInsertFromArena(const char *) override
    {
        throw Exception("Cannot deserialize to " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    void updateHashWithValue(size_t, SipHash &) const override
    {
        throw Exception("updateHashWithValue is not implemented for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    int compareAt(size_t, size_t, const IColumn &, int) const override
    {
        throw Exception("compareAt is not implemented for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    void getPermutation(bool, size_t, int, Permutation &) const override
    {
        throw Exception("getPermutation is not implemented for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    void gather(ColumnGathererStream &) override
    {
        throw Exception("Method gather is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

private:
    /// Number of rows.
    size_t size_;
    FunctionBasePtr function;
    /// Columns captured so far, in argument order.
    ColumnsWithTypeAndName captured_columns;

    /// Captures one column after checking its type against the corresponding argument type.
    void appendArgument(const ColumnWithTypeAndName & column);
};

}

View File

@ -48,6 +48,14 @@ size_t countBytesInFilter(const IColumn::Filter & filt)
return count;
}
/// For every target column index in [0, num_columns) counts how many selector entries point to it.
/// Selector values are not range-checked here.
std::vector<size_t> countColumnsSizeInSelector(IColumn::ColumnIndex num_columns, const IColumn::Selector & selector)
{
    std::vector<size_t> rows_per_column(num_columns);

    for (auto it = selector.begin(); it != selector.end(); ++it)
        ++rows_per_column[*it];

    return rows_per_column;
}
/** clang 4 generates better code than gcc 6.
* And both gcc and clang could not vectorize trivial loop by bytes automatically.

View File

@ -11,6 +11,10 @@ namespace DB
/// Counts how many bytes of `filt` are greater than zero.
size_t countBytesInFilter(const IColumn::Filter & filt);
/// Returns vector with num_columns elements. vector[i] is the count of i values in selector.
/// Selector must contain values from 0 to num_columns - 1. NOTE: this is not checked.
std::vector<size_t> countColumnsSizeInSelector(IColumn::ColumnIndex num_columns, const IColumn::Selector & selector);
/// Returns true, if the memory contains only zeros.
bool memoryIsZero(const void * data, size_t size);

View File

@ -56,10 +56,10 @@ std::ostream & operator<<(std::ostream & stream, const TableStructureReadLock &)
return stream;
}
std::ostream & operator<<(std::ostream & stream, const IFunction & what)
std::ostream & operator<<(std::ostream & stream, const IFunctionBuilder & what)
{
stream << "IFunction(name = " << what.getName() << ", variadic = " << what.isVariadic() << ", args = " << what.getNumberOfArguments()
<< ")";
stream << "IFunction(name = " << what.getName() << ", variadic = " << what.isVariadic() << ", args = "
<< what.getNumberOfArguments() << ")";
return stream;
}

View File

@ -25,8 +25,8 @@ std::ostream & operator<<(std::ostream & stream, const IStorage & what);
class TableStructureReadLock;
std::ostream & operator<<(std::ostream & stream, const TableStructureReadLock & what);
class IFunction;
std::ostream & operator<<(std::ostream & stream, const IFunction & what);
/// Must name the same type as the definition in the .cpp file, which takes `const IFunctionBuilder &`
/// (it prints isVariadic() and getNumberOfArguments() of the builder).
class IFunctionBuilder;
std::ostream & operator<<(std::ostream & stream, const IFunctionBuilder & what);
class Block;
std::ostream & operator<<(std::ostream & stream, const Block & what);

View File

@ -7,8 +7,6 @@
namespace DB
{
class IFunction;
/// Implicitly converts string and numeric values to Enum, numeric types to other numeric types.
class CastTypeBlockInputStream : public IProfilingBlockInputStream
{

View File

@ -1,12 +1,12 @@
#include <DataTypes/DataTypeExpression.h>
#include <DataTypes/DataTypeFunction.h>
namespace DB
{
std::string DataTypeExpression::getName() const
std::string DataTypeFunction::getName() const
{
std::string res = "Expression(";
std::string res = "Function(";
if (argument_types.size() > 1)
res += "(";
for (size_t i = 0; i < argument_types.size(); ++i)
@ -24,7 +24,7 @@ std::string DataTypeExpression::getName() const
return res;
}
bool DataTypeExpression::equals(const IDataType & rhs) const
bool DataTypeFunction::equals(const IDataType & rhs) const
{
return typeid(rhs) == typeid(*this) && getName() == rhs.getName();
}

View File

@ -8,7 +8,7 @@ namespace DB
/** Special data type, representing lambda expression.
*/
class DataTypeExpression final : public IDataTypeDummy
class DataTypeFunction final : public IDataTypeDummy
{
private:
DataTypes argument_types;
@ -19,11 +19,11 @@ public:
bool isParametric() const override { return true; }
/// Some types could be still unknown.
DataTypeExpression(const DataTypes & argument_types_ = DataTypes(), const DataTypePtr & return_type_ = nullptr)
: argument_types(argument_types_), return_type(return_type_) {}
DataTypeFunction(const DataTypes & argument_types_ = DataTypes(), const DataTypePtr & return_type_ = nullptr)
: argument_types(argument_types_), return_type(return_type_) {}
std::string getName() const override;
const char * getFamilyName() const override { return "Expression"; }
const char * getFamilyName() const override { return "Function"; }
const DataTypes & getArgumentTypes() const
{

View File

@ -33,7 +33,7 @@ void FunctionFactory::registerFunction(const
}
FunctionPtr FunctionFactory::get(
FunctionBuilderPtr FunctionFactory::get(
const std::string & name,
const Context & context) const
{
@ -44,7 +44,7 @@ FunctionPtr FunctionFactory::get(
}
FunctionPtr FunctionFactory::tryGet(
FunctionBuilderPtr FunctionFactory::tryGet(
const std::string & name,
const Context & context) const
{

View File

@ -25,7 +25,7 @@ class FunctionFactory : public ext::singleton<FunctionFactory>
friend class StorageSystemFunctions;
public:
using Creator = std::function<FunctionPtr(const Context &)>;
using Creator = std::function<FunctionBuilderPtr(const Context &)>;
/// For compatibility with SQL, it's possible to specify that certain function name is case insensitive.
enum CaseSensitiveness
@ -34,30 +34,45 @@ public:
CaseInsensitive
};
/// Register a function by its name.
/// No locking, you must register all functions before usage of get.
void registerFunction(
const std::string & name,
Creator creator,
CaseSensitiveness case_sensitiveness = CaseSensitive);
template <typename Function>
void registerFunction(CaseSensitiveness case_sensitiveness = CaseSensitive)
{
registerFunction<Function>(Function::name, case_sensitiveness);
}
template <typename Function>
void registerFunction()
void registerFunction(const std::string & name, CaseSensitiveness case_sensitiveness = CaseSensitive)
{
registerFunction(Function::name, &Function::create);
if constexpr (std::is_base_of<IFunction, Function>::value)
registerFunction(name, &createDefaultFunction<Function>, case_sensitiveness);
else
registerFunction(name, &Function::create, case_sensitiveness);
}
/// Throws an exception if not found.
FunctionPtr get(const std::string & name, const Context & context) const;
FunctionBuilderPtr get(const std::string & name, const Context & context) const;
/// Returns nullptr if not found.
FunctionPtr tryGet(const std::string & name, const Context & context) const;
FunctionBuilderPtr tryGet(const std::string & name, const Context & context) const;
private:
using Functions = std::unordered_map<std::string, Creator>;
Functions functions;
Functions case_insensitive_functions;
template <typename Function>
static FunctionBuilderPtr createDefaultFunction(const Context & context)
{
return std::make_shared<DefaultFunctionBuilder>(Function::create(context));
}
/// Register a function by its name.
/// No locking, you must register all functions before usage of get.
void registerFunction(
const std::string & name,
Creator creator,
CaseSensitiveness case_sensitiveness = CaseSensitive);
};
}

View File

@ -778,7 +778,7 @@ private:
return false;
}
FunctionPtr getFunctionForIntervalArithmetic(const DataTypePtr & type0, const DataTypePtr & type1) const
FunctionBuilderPtr getFunctionForIntervalArithmetic(const DataTypePtr & type0, const DataTypePtr & type1) const
{
/// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
/// We construct another function (example: addMonths) and call it.
@ -830,7 +830,7 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
/// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
if (FunctionPtr function = getFunctionForIntervalArithmetic(arguments[0], arguments[1]))
if (auto function_builder = getFunctionForIntervalArithmetic(arguments[0], arguments[1]))
{
ColumnsWithTypeAndName new_arguments(2);
@ -844,10 +844,8 @@ public:
/// Change interval argument to its representation
new_arguments[1].type = std::make_shared<DataTypeNumber<DataTypeInterval::FieldType>>();
DataTypePtr res;
std::vector<ExpressionAction> unused_prerequisites;
function->getReturnTypeAndPrerequisites(new_arguments, res, unused_prerequisites);
return res;
auto function = function_builder->build(new_arguments);
return function->getReturnType();
}
DataTypePtr type_res;
@ -873,7 +871,7 @@ public:
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override
{
/// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
if (FunctionPtr function = getFunctionForIntervalArithmetic(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type))
if (auto function_builder = getFunctionForIntervalArithmetic(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type))
{
ColumnNumbers new_arguments = arguments;
@ -885,7 +883,11 @@ public:
Block new_block = block;
new_block.getByPosition(new_arguments[1]).type = std::make_shared<DataTypeNumber<DataTypeInterval::FieldType>>();
function->executeImpl(new_block, new_arguments, result);
ColumnsWithTypeAndName new_arguments_with_type_and_name =
{new_block.getByPosition(new_arguments[0]), new_block.getByPosition(new_arguments[1])};
auto function = function_builder->build(new_arguments_with_type_and_name);
function->execute(new_block, new_arguments, result);
block.getByPosition(result).column = new_block.getByPosition(result).column;
return;

View File

@ -2345,10 +2345,7 @@ String FunctionArrayReduce::getName() const
return name;
}
void FunctionArrayReduce::getReturnTypeAndPrerequisitesImpl(
const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
std::vector<ExpressionAction> & /*out_prerequisites*/)
DataTypePtr FunctionArrayReduce::getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const
{
/// The first argument is a constant string with the name of the aggregate function
/// (possibly with parameters in parentheses, for example: "quantile(0.99)").
@ -2390,7 +2387,7 @@ void FunctionArrayReduce::getReturnTypeAndPrerequisitesImpl(
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row);
}
out_return_type = aggregate_function->getReturnType();
return aggregate_function->getReturnType();
}

View File

@ -1402,14 +1402,13 @@ public:
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
void getReturnTypeAndPrerequisitesImpl(
const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
std::vector<ExpressionAction> & out_prerequisites) override;
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override;
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override;
private:
AggregateFunctionPtr aggregate_function;
/// lazy initialization in getReturnTypeImpl
/// TODO: init in FunctionBuilder
mutable AggregateFunctionPtr aggregate_function;
};

View File

@ -1084,7 +1084,11 @@ public:
{
size_t size = left_tuple->getElements().size();
for (size_t i = 0; i < size; ++i)
getReturnType({ left_tuple->getElements()[i], right_tuple->getElements()[i] });
{
ColumnsWithTypeAndName args = {{nullptr, left_tuple->getElements()[i], ""},
{nullptr, right_tuple->getElements()[i], ""}};
getReturnType(args);
}
}
return std::make_shared<DataTypeUInt8>();

View File

@ -21,9 +21,9 @@ void registerFunctionsConditional(FunctionFactory & factory)
factory.registerFunction<FunctionCaseWithExpression>();
/// These are obsolete function names.
factory.registerFunction("caseWithExpr", FunctionCaseWithExpression::create);
factory.registerFunction("caseWithoutExpr", FunctionMultiIf::create);
factory.registerFunction("caseWithoutExpression", FunctionMultiIf::create);
factory.registerFunction<FunctionCaseWithExpression>("caseWithExpr");
factory.registerFunction<FunctionMultiIf>("caseWithoutExpr");
factory.registerFunction<FunctionMultiIf>("caseWithoutExpression");
}
@ -251,15 +251,15 @@ DataTypePtr FunctionCaseWithExpression::getReturnTypeImpl(const DataTypes & args
/// get the return type of a transform function.
/// Get the return types of the arrays that we pass to the transform function.
DataTypes src_array_types;
DataTypes dst_array_types;
ColumnsWithTypeAndName src_array_types;
ColumnsWithTypeAndName dst_array_types;
for (size_t i = 1; i < (args.size() - 1); ++i)
{
if ((i % 2) != 0)
src_array_types.push_back(args[i]);
src_array_types.push_back({nullptr, args[i], {}});
else
dst_array_types.push_back(args[i]);
dst_array_types.push_back({nullptr, args[i], {}});
}
FunctionArray fun_array{context};
@ -269,7 +269,9 @@ DataTypePtr FunctionCaseWithExpression::getReturnTypeImpl(const DataTypes & args
/// Finally get the return type of the transform function.
FunctionTransform fun_transform;
return fun_transform.getReturnType({args.front(), src_array_type, dst_array_type, args.back()});
ColumnsWithTypeAndName transform_args = {{nullptr, args.front(), {}}, {nullptr, src_array_type, {}},
{nullptr, dst_array_type, {}}, {nullptr, args.back(), {}}};
return fun_transform.getReturnType(transform_args);
}
void FunctionCaseWithExpression::executeImpl(Block & block, const ColumnNumbers & args, size_t result)
@ -288,22 +290,22 @@ void FunctionCaseWithExpression::executeImpl(Block & block, const ColumnNumbers
/// Create the arrays required by the transform function.
ColumnNumbers src_array_args;
DataTypes src_array_types;
ColumnsWithTypeAndName src_array_types;
ColumnNumbers dst_array_args;
DataTypes dst_array_types;
ColumnsWithTypeAndName dst_array_types;
for (size_t i = 1; i < (args.size() - 1); ++i)
{
if ((i % 2) != 0)
{
src_array_args.push_back(args[i]);
src_array_types.push_back(block.getByPosition(args[i]).type);
src_array_types.push_back(block.getByPosition(args[i]));
}
else
{
dst_array_args.push_back(args[i]);
dst_array_types.push_back(block.getByPosition(args[i]).type);
dst_array_types.push_back(block.getByPosition(args[i]));
}
}

View File

@ -47,7 +47,7 @@ void registerFunctionsConversion(FunctionFactory & factory)
factory.registerFunction<FunctionToFixedString>();
factory.registerFunction<FunctionToUnixTimestamp>();
factory.registerFunction<FunctionCast>();
factory.registerFunction<FunctionBuilderCast>();
factory.registerFunction<FunctionToUInt8OrZero>();
factory.registerFunction<FunctionToUInt16OrZero>();

View File

@ -625,14 +625,11 @@ public:
size_t getNumberOfArguments() const override { return 0; }
bool isInjective(const Block &) override { return std::is_same_v<Name, NameToString>; }
void getReturnTypeAndPrerequisitesImpl(
const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
std::vector<ExpressionAction> &) override
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if constexpr (std::is_same_v<ToDataType, DataTypeInterval>)
{
out_return_type = std::make_shared<DataTypeInterval>(DataTypeInterval::Kind(Name::kind));
return std::make_shared<DataTypeInterval>(DataTypeInterval::Kind(Name::kind));
}
else
{
@ -660,9 +657,9 @@ public:
}
if (std::is_same_v<ToDataType, DataTypeDateTime>)
out_return_type = std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, 1, 0));
return std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, 1, 0));
else
out_return_type = std::make_shared<ToDataType>();
return std::make_shared<ToDataType>();
}
}
@ -845,9 +842,7 @@ public:
size_t getNumberOfArguments() const override { return 2; }
bool isInjective(const Block &) override { return true; }
void getReturnTypeAndPrerequisitesImpl(const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
std::vector<ExpressionAction> & /*out_prerequisites*/) override
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if (!arguments[1].type->isUnsignedInteger())
throw Exception("Second argument for function " + getName() + " must be unsigned integer", ErrorCodes::ILLEGAL_COLUMN);
@ -857,7 +852,7 @@ public:
throw Exception(getName() + " is only implemented for types String and FixedString", ErrorCodes::NOT_IMPLEMENTED);
const size_t n = arguments[1].column->getUInt(0);
out_return_type = std::make_shared<DataTypeFixedString>(n);
return std::make_shared<DataTypeFixedString>(n);
}
bool useDefaultImplementationForConstants() const override { return true; }
@ -1130,19 +1125,80 @@ using FunctionToFloat32OrNull = FunctionConvertOrNull<DataTypeFloat32, NameToFlo
using FunctionToFloat64OrNull = FunctionConvertOrNull<DataTypeFloat64, NameToFloat64OrNull>;
class FunctionCast final : public IFunction
class PreparedFunctionCast : public PreparedFunctionImpl
{
public:
FunctionCast(const Context & context) : context(context) {}
using WrapperType = std::function<void(Block &, const ColumnNumbers &, size_t)>;
/// Takes over the conversion wrapper (moved into the member) and stores the name pointer as-is.
/// The pointed-to string is not copied, so it has to stay alive for as long as this object is used.
explicit PreparedFunctionCast(WrapperType && wrapper_function, const char * name)
: wrapper_function(std::move(wrapper_function)), name(name) {}
String getName() const override { return name; }
protected:
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override
{
    /// The wrapper is not given the second argument: forward the first one,
    /// then everything that follows the second one (if anything).
    ColumnNumbers forwarded_arguments;
    forwarded_arguments.push_back(arguments.front());
    for (size_t pos = 2; pos < arguments.size(); ++pos)
        forwarded_arguments.push_back(arguments[pos]);

    wrapper_function(block, forwarded_arguments, result);
}
bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; }
private:
using WrapperType = std::function<void(Block &, const ColumnNumbers &, size_t)>;
const Context & context;
WrapperType wrapper_function;
std::function<Monotonicity(const IDataType &, const Field &, const Field &)> monotonicity_for_range;
const char * name;
};
class FunctionCast final : public IFunctionBase
{
public:
using WrapperType = std::function<void(Block &, const ColumnNumbers &, size_t)>;
using MonotonicityForRange = std::function<Monotonicity(const IDataType &, const Field &, const Field &)>;
/// Stores everything needed to build the executable conversion later on.
/// `monotonicity_for_range` is accepted by rvalue reference; inside the constructor the
/// parameter is a named lvalue, so it has to be moved explicitly - otherwise the
/// std::function would be copied instead of taken over.
FunctionCast(const Context & context, const char * name, MonotonicityForRange && monotonicity_for_range
    , const DataTypes & argument_types, const DataTypePtr & return_type)
    : context(context), name(name), monotonicity_for_range(std::move(monotonicity_for_range))
    , argument_types(argument_types), return_type(return_type)
{
}
const DataTypes & getArgumentTypes() const override { return argument_types; }
const DataTypePtr & getReturnType() const override { return return_type; }
PreparedFunctionPtr prepare(const Block & /*sample_block*/) const override
{
    /// Build the conversion wrapper from the first argument's type to the result type
    /// and hand it over to the executable part together with the function name.
    const DataTypePtr & source_type = getArgumentTypes()[0];
    const IDataType * target_type = getReturnType().get();
    return std::make_shared<PreparedFunctionCast>(prepare(source_type, target_type), name);
}
String getName() const override { return name; }
bool hasInformationAboutMonotonicity() const override
{
    /// An empty callback means that no monotonicity rule was supplied for this conversion.
    return monotonicity_for_range != nullptr;
}
/// Forwards the query to the stored monotonicity callback.
/// The callback may be empty (see hasInformationAboutMonotonicity); invoking an empty
/// std::function throws std::bad_function_call, so callers are expected to check first.
Monotonicity getMonotonicityForRange(const IDataType & type, const Field & left, const Field & right) const override
{
return monotonicity_for_range(type, left, right);
}
private:
const Context & context;
const char * name;
MonotonicityForRange monotonicity_for_range;
DataTypes argument_types;
DataTypePtr return_type;
template <typename DataType>
WrapperType createWrapper(const DataTypePtr & from_type, const DataType * const)
WrapperType createWrapper(const DataTypePtr & from_type, const DataType * const) const
{
using FunctionType = typename FunctionTo<DataType>::Type;
@ -1150,9 +1206,7 @@ private:
/// Check conversion using underlying function
{
DataTypePtr unused_data_type;
std::vector<ExpressionAction> unused_prerequisites;
function->getReturnTypeAndPrerequisites({{ nullptr, from_type, "" }}, unused_data_type, unused_prerequisites);
function->getReturnType(ColumnsWithTypeAndName(1, { nullptr, from_type, "" }));
}
return [function] (Block & block, const ColumnNumbers & arguments, const size_t result)
@ -1174,7 +1228,7 @@ private:
};
}
WrapperType createArrayWrapper(const DataTypePtr & from_type_untyped, const DataTypeArray * to_type)
WrapperType createArrayWrapper(const DataTypePtr & from_type_untyped, const DataTypeArray * to_type) const
{
/// Conversion from String through parsing.
if (checkAndGetDataType<DataTypeString>(from_type_untyped.get()))
@ -1235,7 +1289,7 @@ private:
};
}
WrapperType createTupleWrapper(const DataTypePtr & from_type_untyped, const DataTypeTuple * to_type)
WrapperType createTupleWrapper(const DataTypePtr & from_type_untyped, const DataTypeTuple * to_type) const
{
/// Conversion from String through parsing.
if (checkAndGetDataType<DataTypeString>(from_type_untyped.get()))
@ -1304,7 +1358,7 @@ private:
}
template <typename FieldType>
WrapperType createEnumWrapper(const DataTypePtr & from_type, const DataTypeEnum<FieldType> * to_type)
WrapperType createEnumWrapper(const DataTypePtr & from_type, const DataTypeEnum<FieldType> * to_type) const
{
using EnumType = DataTypeEnum<FieldType>;
using Function = typename FunctionTo<EnumType>::Type;
@ -1324,9 +1378,7 @@ private:
/// Check conversion using underlying function
{
DataTypePtr unused_data_type;
std::vector<ExpressionAction> unused_prerequisites;
function->getReturnTypeAndPrerequisites({{ nullptr, from_type, "" }}, unused_data_type, unused_prerequisites);
function->getReturnType(ColumnsWithTypeAndName(1, { nullptr, from_type, "" }));
}
return [function] (Block & block, const ColumnNumbers & arguments, const size_t result)
@ -1342,7 +1394,7 @@ private:
}
template <typename EnumTypeFrom, typename EnumTypeTo>
void checkEnumToEnumConversion(const EnumTypeFrom * from_type, const EnumTypeTo * to_type)
void checkEnumToEnumConversion(const EnumTypeFrom * from_type, const EnumTypeTo * to_type) const
{
const auto & from_values = from_type->getValues();
const auto & to_values = to_type->getValues();
@ -1369,9 +1421,10 @@ private:
};
template <typename ColumnStringType, typename EnumType>
WrapperType createStringToEnumWrapper()
WrapperType createStringToEnumWrapper() const
{
return [] (Block & block, const ColumnNumbers & arguments, const size_t result)
const char * function_name = name;
return [function_name] (Block & block, const ColumnNumbers & arguments, const size_t result)
{
const auto first_col = block.getByPosition(arguments.front()).column.get();
@ -1393,13 +1446,12 @@ private:
}
else
throw Exception{
"Unexpected column " + first_col->getName() + " as first argument of function " +
name,
"Unexpected column " + first_col->getName() + " as first argument of function " + function_name,
ErrorCodes::LOGICAL_ERROR};
};
}
WrapperType createIdentityWrapper(const DataTypePtr &)
WrapperType createIdentityWrapper(const DataTypePtr &) const
{
return [] (Block & block, const ColumnNumbers & arguments, const size_t result)
{
@ -1407,7 +1459,7 @@ private:
};
}
WrapperType createNothingWrapper(const IDataType * to_type)
WrapperType createNothingWrapper(const IDataType * to_type) const
{
ColumnPtr res = to_type->createColumnConstWithDefaultValue(1);
return [res] (Block & block, const ColumnNumbers &, const size_t result)
@ -1424,7 +1476,7 @@ private:
bool result_is_nullable = false;
};
WrapperType prepare(const DataTypePtr & from_type, const IDataType * to_type)
WrapperType prepare(const DataTypePtr & from_type, const IDataType * to_type) const
{
/// Determine whether pre-processing and/or post-processing must take place during conversion.
@ -1519,7 +1571,7 @@ private:
return wrapper;
}
WrapperType prepareImpl(const DataTypePtr & from_type, const IDataType * to_type)
WrapperType prepareImpl(const DataTypePtr & from_type, const IDataType * to_type) const
{
if (from_type->equals(*to_type))
return createIdentityWrapper(from_type);
@ -1569,99 +1621,95 @@ private:
"Conversion from " + from_type->getName() + " to " + to_type->getName() + " is not supported",
ErrorCodes::CANNOT_CONVERT_TYPE};
}
};
class FunctionBuilderCast : public FunctionBuilderImpl
{
public:
using MonotonicityForRange = FunctionCast::MonotonicityForRange;
static constexpr auto name = "CAST";
static FunctionBuilderPtr create(const Context & context) { return std::make_shared<FunctionBuilderCast>(context); }
FunctionBuilderCast(const Context & context) : context(context) {}
/// Returns the function name ("CAST"). Marked `override` like the other virtuals of this class,
/// so that a signature mismatch with the base is caught at compile time.
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 2; }
protected:
FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const override
{
    /// The built function keeps only the types of its arguments.
    DataTypes argument_types;
    argument_types.reserve(arguments.size());
    for (const auto & argument : arguments)
        argument_types.push_back(argument.type);

    /// Monotonicity depends on the source type (first argument) and the target type.
    auto monotonicity_info = getMonotonicityInformation(arguments.front().type, return_type.get());

    return std::make_shared<FunctionCast>(context, name, std::move(monotonicity_info), argument_types, return_type);
}
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
    /// The target type is named by the last argument, which has to be a constant String column.
    const auto type_name_column = checkAndGetColumnConst<ColumnString>(arguments.back().column.get());

    if (!type_name_column)
        throw Exception("Second argument to " + getName() + " must be a constant string describing type",
            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

    const String type_name = type_name_column->getValue<String>();
    return DataTypeFactory::instance().get(type_name);
}
bool useDefaultImplementationForNulls() const override { return false; }
private:
/// Returns `Monotonic::get` of the conversion function selected by FunctionTo<DataType>.
/// The pointer argument carries no value of interest; it is used only to deduce DataType.
template <typename DataType>
static auto monotonicityForType(const DataType * const)
{
return FunctionTo<DataType>::Type::Monotonic::get;
}
void prepareMonotonicityInformation(const DataTypePtr & from_type, const IDataType * to_type)
MonotonicityForRange getMonotonicityInformation(const DataTypePtr & from_type, const IDataType * to_type) const
{
if (const auto type = checkAndGetDataType<DataTypeUInt8>(to_type))
monotonicity_for_range = monotonicityForType(type);
return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeUInt16>(to_type))
monotonicity_for_range = monotonicityForType(type);
return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeUInt32>(to_type))
monotonicity_for_range = monotonicityForType(type);
return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeUInt64>(to_type))
monotonicity_for_range = monotonicityForType(type);
return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeInt8>(to_type))
monotonicity_for_range = monotonicityForType(type);
return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeInt16>(to_type))
monotonicity_for_range = monotonicityForType(type);
return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeInt32>(to_type))
monotonicity_for_range = monotonicityForType(type);
return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeInt64>(to_type))
monotonicity_for_range = monotonicityForType(type);
return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeFloat32>(to_type))
monotonicity_for_range = monotonicityForType(type);
return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeFloat64>(to_type))
monotonicity_for_range = monotonicityForType(type);
return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeDate>(to_type))
monotonicity_for_range = monotonicityForType(type);
return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeDateTime>(to_type))
monotonicity_for_range = monotonicityForType(type);
return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeString>(to_type))
monotonicity_for_range = monotonicityForType(type);
return monotonicityForType(type);
else if (from_type->isEnum())
{
if (const auto type = checkAndGetDataType<DataTypeEnum8>(to_type))
monotonicity_for_range = monotonicityForType(type);
return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeEnum16>(to_type))
monotonicity_for_range = monotonicityForType(type);
return monotonicityForType(type);
}
/// other types like Null, FixedString, Array and Tuple have no monotonicity defined
return {};
}
public:
static constexpr auto name = "CAST";
static FunctionPtr create(const Context & context) { return std::make_shared<FunctionCast>(context); }
String getName() const override { return name; }
bool useDefaultImplementationForNulls() const override { return false; }
size_t getNumberOfArguments() const override { return 2; }
void getReturnTypeAndPrerequisitesImpl(
const ColumnsWithTypeAndName & arguments, DataTypePtr & out_return_type,
std::vector<ExpressionAction> & /*out_prerequisites*/) override
{
const auto type_col = checkAndGetColumnConst<ColumnString>(arguments.back().column.get());
if (!type_col)
throw Exception("Second argument to " + getName() + " must be a constant string describing type",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
out_return_type = DataTypeFactory::instance().get(type_col->getValue<String>());
const DataTypePtr & from_type = arguments.front().type;
wrapper_function = prepare(from_type, out_return_type.get());
prepareMonotonicityInformation(from_type, out_return_type.get());
}
bool useDefaultImplementationForConstants() const override { return true; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; }
void executeImpl(Block & block, const ColumnNumbers & arguments, const size_t result) override
{
/// drop second argument, pass others
ColumnNumbers new_arguments{arguments.front()};
if (arguments.size() > 2)
new_arguments.insert(std::end(new_arguments), std::next(std::begin(arguments), 2), std::end(arguments));
wrapper_function(block, new_arguments, result);
}
bool hasInformationAboutMonotonicity() const override
{
return static_cast<bool>(monotonicity_for_range);
}
Monotonicity getMonotonicityForRange(const IDataType & type, const Field & left, const Field & right) const override
{
return monotonicity_for_range(type, left, right);
}
const Context & context;
};
}

View File

@ -83,7 +83,7 @@ void registerFunctionsDateTime(FunctionFactory & factory)
factory.registerFunction<FunctionToRelativeMinuteNum>();
factory.registerFunction<FunctionToRelativeSecondNum>();
factory.registerFunction<FunctionToTime>();
factory.registerFunction(FunctionNow::name, FunctionNow::create, FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionNow>(FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionToday>();
factory.registerFunction<FunctionYesterday>();
factory.registerFunction<FunctionTimeSlot>();
@ -108,7 +108,7 @@ void registerFunctionsDateTime(FunctionFactory & factory)
factory.registerFunction<FunctionSubtractMonths>();
factory.registerFunction<FunctionSubtractYears>();
factory.registerFunction(FunctionDateDiff::name, FunctionDateDiff::create, FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionDateDiff>(FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionToTimeZone>();
}

View File

@ -633,10 +633,7 @@ public:
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
void getReturnTypeAndPrerequisitesImpl(
const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
std::vector<ExpressionAction> &) override
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if (arguments.size() == 1)
{
@ -663,9 +660,9 @@ public:
/// For DateTime, if time zone is specified, attach it to type.
if (std::is_same_v<ToDataType, DataTypeDateTime>)
out_return_type = std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, 1, 0));
return std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, 1, 0));
else
out_return_type = std::make_shared<ToDataType>();
return std::make_shared<ToDataType>();
}
bool useDefaultImplementationForConstants() const override { return true; }
@ -942,10 +939,7 @@ public:
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
void getReturnTypeAndPrerequisitesImpl(
const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
std::vector<ExpressionAction> &) override
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if (arguments.size() != 2 && arguments.size() != 3)
throw Exception("Number of arguments for function " + getName() + " doesn't match: passed "
@ -979,16 +973,16 @@ public:
if (checkDataType<DataTypeDate>(arguments[0].type.get()))
{
if (std::is_same_v<decltype(Transform::execute(DataTypeDate::FieldType(), 0, std::declval<DateLUTImpl>())), UInt16>)
out_return_type = std::make_shared<DataTypeDate>();
return std::make_shared<DataTypeDate>();
else
out_return_type = std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, 2, 0));
return std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, 2, 0));
}
else
{
if (std::is_same_v<decltype(Transform::execute(DataTypeDateTime::FieldType(), 0, std::declval<DateLUTImpl>())), UInt16>)
out_return_type = std::make_shared<DataTypeDate>();
return std::make_shared<DataTypeDate>();
else
out_return_type = std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, 2, 0));
return std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, 2, 0));
}
}
@ -1291,10 +1285,7 @@ public:
size_t getNumberOfArguments() const override { return 2; }
void getReturnTypeAndPrerequisitesImpl(
const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
std::vector<ExpressionAction> &) override
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if (arguments.size() != 2)
throw Exception("Number of arguments for function " + getName() + " doesn't match: passed "
@ -1307,7 +1298,7 @@ public:
". Should be DateTime", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
String time_zone_name = extractTimeZoneNameFromFunctionArguments(arguments, 1, 0);
out_return_type = std::make_shared<DataTypeDateTime>(time_zone_name);
return std::make_shared<DataTypeDateTime>(time_zone_name);
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override

View File

@ -1,13 +1,11 @@
#pragma once
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeExpression.h>
#include <DataTypes/DataTypesNumber.h>
#include <Interpreters/ExpressionActions.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnExpression.h>
#include <Common/typeid_cast.h>
#include <Functions/IFunction.h>
#include <Functions/FunctionsMiscellaneous.h>
@ -635,15 +633,16 @@ public:
nested_types[i] = array_type->getNestedType();
}
const DataTypeExpression * expression_type = checkAndGetDataType<DataTypeExpression>(&*arguments[0]);
if (!expression_type || expression_type->getArgumentTypes().size() != nested_types.size())
throw Exception("First argument for this overload of " + getName() + " must be an expression with "
const DataTypeFunction * function_type = checkAndGetDataType<DataTypeFunction>(&*arguments[0]);
if (!function_type || function_type->getArgumentTypes().size() != nested_types.size())
throw Exception("First argument for this overload of " + getName() + " must be a function with "
+ toString(nested_types.size()) + " arguments. Found "
+ arguments[0]->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
arguments[0] = std::make_shared<DataTypeExpression>(nested_types);
arguments[0] = std::make_shared<DataTypeFunction>(nested_types);
}
/*
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
size_t min_args = Impl::needExpression() ? 2 : 1;
@ -693,10 +692,9 @@ public:
return Impl::getReturnType(return_type, first_array_type->getNestedType());
}
}
*/
void getReturnTypeAndPrerequisitesImpl(const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
ExpressionActions::Actions & out_prerequisites) override
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
size_t min_args = Impl::needExpression() ? 2 : 1;
if (arguments.size() < min_args)
@ -707,7 +705,7 @@ public:
if (arguments.size() == 1)
{
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[0].type);
const auto array_type = checkAndGetDataType<DataTypeArray>(&*arguments[0].type);
if (!array_type)
throw Exception("The only argument for function " + getName() + " must be array. Found "
@ -719,7 +717,7 @@ public:
throw Exception("The only argument for function " + getName() + " must be array of UInt8. Found "
+ arguments[0].type->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
out_return_type = Impl::getReturnType(nested_type, nested_type);
return Impl::getReturnType(nested_type, nested_type);
}
else
{
@ -727,59 +725,31 @@ public:
throw Exception("Function " + getName() + " needs one array argument.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (!arguments[0].column)
throw Exception("Type of first argument for function " + getName() + " must be an expression.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const auto data_type_function = checkAndGetDataType<DataTypeFunction>(arguments[0].type.get());
const ColumnExpression * column_expression = typeid_cast<const ColumnExpression *>(arguments[0].column.get());
if (!column_expression)
throw Exception("Column of first argument for function " + getName() + " must be an expression.",
if (!data_type_function)
throw Exception("First argument for function " + getName() + " must be a function.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
/// The types of the remaining arguments are already checked in getLambdaArgumentTypes.
/// Let's add to the block all the columns mentioned in the expression, multiplied into an array parallel to the one being processed.
const ExpressionActions & expression = *column_expression->getExpression();
const NamesAndTypesList & required_columns = expression.getRequiredColumnsWithTypes();
const NamesAndTypesList expression_arguments = column_expression->getArguments();
NameSet argument_names;
for (const auto & expression_argument : expression_arguments)
argument_names.emplace(expression_argument.name);
for (const auto & required_column : required_columns)
{
if (argument_names.count(required_column.name))
continue;
Names replicate_arguments;
replicate_arguments.push_back(required_column.name);
replicate_arguments.push_back(arguments[1].name);
out_prerequisites.push_back(ExpressionAction::applyFunction(std::make_shared<FunctionReplicate>(), replicate_arguments));
}
DataTypePtr return_type = column_expression->getReturnType();
DataTypePtr return_type = data_type_function->getReturnType();
if (Impl::needBoolean() && !checkDataType<DataTypeUInt8>(&*return_type))
throw Exception("Expression for function " + getName() + " must return UInt8, found "
+ return_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const DataTypeArray * first_array_type = checkAndGetDataType<DataTypeArray>(&*arguments[1].type);
const auto first_array_type = checkAndGetDataType<DataTypeArray>(&*arguments[1].type);
out_return_type = Impl::getReturnType(return_type, first_array_type->getNestedType());
return Impl::getReturnType(return_type, first_array_type->getNestedType());
}
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override
{
executeImpl(block, arguments, {}, result);
}
void executeImpl(Block & block, const ColumnNumbers & arguments, const ColumnNumbers & prerequisites, size_t result) override
{
if (arguments.size() == 1)
{
ColumnPtr column_array_ptr = block.getByPosition(arguments[0]).column;
const ColumnArray * column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
const auto * column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
if (!column_array)
{
@ -797,28 +767,28 @@ public:
const auto & column_with_type_and_name = block.getByPosition(arguments[0]);
if (!column_with_type_and_name.column)
throw Exception("First argument for function " + getName() + " must be an expression.",
throw Exception("First argument for function " + getName() + " must be a function.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const ColumnExpression * column_expression = typeid_cast<const ColumnExpression *>(column_with_type_and_name.column.get());
const auto * column_function = typeid_cast<const ColumnFunction *>(column_with_type_and_name.column.get());
ColumnPtr offsets_column;
Block temp_block;
const ExpressionActions & expression = *column_expression->getExpression();
const NamesAndTypesList & expression_arguments = column_expression->getArguments();
NameSet argument_names;
ColumnPtr column_first_array_ptr;
const ColumnArray * column_first_array = nullptr;
/// Put the expression arguments in the block.
ColumnsWithTypeAndName arrays;
arrays.reserve(arguments.size() - 1);
size_t i = 0;
for (const auto expression_argument : expression_arguments)
for (size_t i = 1; i < arguments.size(); ++i)
{
ColumnPtr column_array_ptr = block.getByPosition(arguments[i + 1]).column;
const ColumnArray * column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
const auto & array_with_type_and_name = block.getByPosition(arguments[i]);
ColumnPtr column_array_ptr = array_with_type_and_name.column;
const auto * column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
const DataTypePtr & array_type_ptr = array_with_type_and_name.type;
const auto * array_type = checkAndGetDataType<DataTypeArray>(array_type_ptr.get());
if (!column_array)
{
@ -829,6 +799,9 @@ public:
column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
}
/// array_type is null on this path, so the name of the offending type has to be taken
/// from the type pointer itself - going through array_type would dereference a null pointer.
if (!array_type)
    throw Exception("Expected array type, found " + array_type_ptr->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!offsets_column)
{
offsets_column = column_array->getOffsetsPtr();
@ -841,46 +814,23 @@ public:
throw Exception("Arrays passed to " + getName() + " must have equal size", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);
}
if (i == 0)
if (i == 1)
{
column_first_array_ptr = column_array_ptr;
column_first_array = column_array;
}
temp_block.insert({
column_array->getDataPtr(),
expression_argument.type,
expression_argument.name});
argument_names.insert(expression_argument.name);
++i;
arrays.emplace_back(ColumnWithTypeAndName(column_array->getDataPtr(),
array_type->getNestedType(), array_with_type_and_name.name));
}
/// Put all the necessary columns multiplied by the sizes of arrays into the block.
auto replicated_column_function_ptr = column_function->replicate(column_first_array->getOffsets());
auto * replicated_column_function = typeid_cast<ColumnFunction *>(replicated_column_function_ptr.get());
replicated_column_function->appendArguments(arrays);
Names required_columns = expression.getRequiredColumns();
size_t prerequisite_index = 0;
for (size_t i = 0; i < required_columns.size(); ++i)
{
const String & name = required_columns[i];
if (argument_names.count(name))
continue;
ColumnWithTypeAndName replicated_column = block.getByPosition(prerequisites[prerequisite_index]);
replicated_column.name = name;
replicated_column.column = typeid_cast<const ColumnArray &>(*replicated_column.column).getDataPtr();
replicated_column.type = typeid_cast<const DataTypeArray &>(*replicated_column.type).getNestedType(),
temp_block.insert(std::move(replicated_column));
++prerequisite_index;
}
expression.execute(temp_block);
block.getByPosition(result).column = Impl::execute(*column_first_array, temp_block.getByName(column_expression->getReturnName()).column);
block.getByPosition(result).column = Impl::execute(*column_first_array,
replicated_column_function->reduce().column);
}
}
};

View File

@ -1643,8 +1643,7 @@ public:
return name;
}
void getReturnTypeAndPrerequisitesImpl(
const ColumnsWithTypeAndName & arguments, DataTypePtr & out_return_type, ExpressionActions::Actions & out_prerequisites) override;
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override;
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override;
@ -1678,8 +1677,7 @@ void FunctionVisibleWidth::executeImpl(Block & block, const ColumnNumbers & argu
}
void FunctionHasColumnInTable::getReturnTypeAndPrerequisitesImpl(
const ColumnsWithTypeAndName & arguments, DataTypePtr & out_return_type, ExpressionActions::Actions & /*out_prerequisites*/)
DataTypePtr FunctionHasColumnInTable::getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const
{
if (arguments.size() < 3 || arguments.size() > 6)
throw Exception{"Invalid number of arguments for function " + getName(),
@ -1697,7 +1695,7 @@ void FunctionHasColumnInTable::getReturnTypeAndPrerequisitesImpl(
}
}
out_return_type = std::make_shared<DataTypeUInt8>();
return std::make_shared<DataTypeUInt8>();
}

View File

@ -1,13 +1,16 @@
#pragma once
#include <Functions/IFunction.h>
#include <Interpreters/ExpressionActions.h>
#include <DataTypes/DataTypeFunction.h>
#include <IO/WriteBufferFromString.h>
#include <IO/Operators.h>
#include <Columns/ColumnFunction.h>
namespace DB
{
/** Creates an array, multiplying the column (the first argument) by the number of elements in the array (the second argument).
* Used only as prerequisites for higher-order functions.
*/
class FunctionReplicate : public IFunction
{
@ -32,4 +35,181 @@ public:
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override;
};
/// Executes expression. Uses for lambda functions implementation. Can't be created from factory.
class FunctionExpression : public IFunctionBase, public IPreparedFunction,
public std::enable_shared_from_this<FunctionExpression>
{
public:
FunctionExpression(const ExpressionActionsPtr & expression_actions,
const DataTypes & argument_types, const Names & argument_names,
const DataTypePtr & return_type, const std::string & return_name)
: expression_actions(expression_actions), argument_types(argument_types),
argument_names(argument_names), return_type(return_type), return_name(return_name)
{
}
String getName() const override { return "FunctionExpression"; }
const DataTypes & getArgumentTypes() const override { return argument_types; }
const DataTypePtr & getReturnType() const override { return return_type; }
PreparedFunctionPtr prepare(const Block &) const override
{
return std::const_pointer_cast<FunctionExpression>(shared_from_this());
}
void execute(Block & block, const ColumnNumbers & arguments, size_t result) override
{
Block expr_block;
for (size_t i = 0; i < arguments.size(); ++i)
{
const auto & argument = block.getByPosition(arguments[i]);
/// Replace column name with value from argument_names.
expr_block.insert({argument.column, argument.type, argument_names[i]});
}
expression_actions->execute(expr_block);
block.getByPosition(result).column = expr_block.getByName(return_name).column;
}
private:
ExpressionActionsPtr expression_actions;
DataTypes argument_types;
Names argument_names;
DataTypePtr return_type;
std::string return_name;
};
/// Captures columns which are used by lambda function but not in argument list.
/// Returns ColumnFunction with captured columns.
/// For lambda(x, x + y) x is in lambda_arguments, y is in captured arguments, expression_actions is 'x + y'.
/// execute(y) returns ColumnFunction(FunctionExpression(x + y), y) with type Function(x) -> function_return_type.
class FunctionCapture : public IFunctionBase, public IPreparedFunction, public FunctionBuilderImpl,
public std::enable_shared_from_this<FunctionCapture>
{
public:
FunctionCapture(const ExpressionActionsPtr & expression_actions, const Names & captured,
const NamesAndTypesList & lambda_arguments,
const DataTypePtr & function_return_type, const std::string & expression_return_name)
: expression_actions(expression_actions), captured_names(captured), lambda_arguments(lambda_arguments)
, function_return_type(function_return_type), expression_return_name(expression_return_name)
{
const auto & all_arguments = expression_actions->getRequiredColumnsWithTypes();
std::unordered_map<std::string, DataTypePtr> arguments_map;
for (const auto & arg : all_arguments)
arguments_map[arg.name] = arg.type;
auto collect = [&arguments_map](const Names & names)
{
DataTypes types;
types.reserve(names.size());
for (const auto & name : names)
{
auto it = arguments_map.find(name);
if (it == arguments_map.end())
throw Exception("Lambda captured argument " + name + " not found in required columns.",
ErrorCodes::LOGICAL_ERROR);
types.push_back(it->second);
arguments_map.erase(it);
}
return types;
};
captured_types = collect(captured_names);
DataTypes argument_types;
argument_types.reserve(lambda_arguments.size());
for (const auto & lambda_argument : lambda_arguments)
argument_types.push_back(lambda_argument.type);
return_type = std::make_shared<DataTypeFunction>(argument_types, function_return_type);
name = "Capture[" + toString(captured_types) + "](" + toString(argument_types) + ") -> "
+ function_return_type->getName();
}
String getName() const override { return name; }
const DataTypes & getArgumentTypes() const override { return captured_types; }
const DataTypePtr & getReturnType() const override { return return_type; }
PreparedFunctionPtr prepare(const Block &) const override
{
return std::const_pointer_cast<FunctionCapture>(shared_from_this());
}
void execute(Block & block, const ColumnNumbers & arguments, size_t result) override
{
ColumnsWithTypeAndName columns;
columns.reserve(arguments.size());
Names names;
DataTypes types;
names.reserve(captured_names.size() + lambda_arguments.size());
names.insert(names.end(), captured_names.begin(), captured_names.end());
types.reserve(captured_types.size() + lambda_arguments.size());
types.insert(types.end(), captured_types.begin(), captured_types.end());
for (const auto & lambda_argument : lambda_arguments)
{
names.push_back(lambda_argument.name);
types.push_back(lambda_argument.type);
}
for (const auto & argument : arguments)
columns.push_back(block.getByPosition(argument));
auto function = std::make_shared<FunctionExpression>(expression_actions, types, names,
function_return_type, expression_return_name);
auto size = block.rows();
block.getByPosition(result).column = ColumnFunction::create(size, std::move(function), columns);
}
size_t getNumberOfArguments() const override { return captured_types.size(); }
protected:
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const override { return return_type; }
bool useDefaultImplementationForNulls() const override { return false; }
FunctionBasePtr buildImpl(const ColumnsWithTypeAndName &, const DataTypePtr &) const override
{
return std::const_pointer_cast<FunctionCapture>(shared_from_this());
}
private:
std::string toString(const DataTypes & data_types) const
{
std::string result;
{
WriteBufferFromString buffer(result);
bool first = true;
for (const auto & type : data_types)
{
if (!first)
buffer << ", ";
first = false;
buffer << type->getName();
}
}
return result;
}
ExpressionActionsPtr expression_actions;
DataTypes captured_types;
Names captured_names;
NamesAndTypesList lambda_arguments;
DataTypePtr function_return_type;
DataTypePtr return_type;
std::string expression_return_name;
std::string name;
};
}

View File

@ -10,14 +10,14 @@ void registerFunctionsRound(FunctionFactory & factory)
factory.registerFunction<FunctionRoundDuration>();
factory.registerFunction<FunctionRoundAge>();
factory.registerFunction("round", FunctionRound::create, FunctionFactory::CaseInsensitive);
factory.registerFunction("floor", FunctionFloor::create, FunctionFactory::CaseInsensitive);
factory.registerFunction("ceil", FunctionCeil::create, FunctionFactory::CaseInsensitive);
factory.registerFunction("trunc", FunctionTrunc::create, FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionRound>("round", FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionFloor>("floor", FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionCeil>("ceil", FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionTrunc>("trunc", FunctionFactory::CaseInsensitive);
/// Compatibility aliases.
factory.registerFunction("ceiling", FunctionCeil::create, FunctionFactory::CaseInsensitive);
factory.registerFunction("truncate", FunctionTrunc::create, FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionCeil>("ceiling", FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionTrunc>("truncate", FunctionFactory::CaseInsensitive);
}
}

View File

@ -118,8 +118,7 @@ public:
return {1};
}
void getReturnTypeAndPrerequisitesImpl(
const ColumnsWithTypeAndName & arguments, DataTypePtr & out_return_type, ExpressionActions::Actions & /*out_prerequisites*/) override
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
size_t count_arrays = 0;
@ -135,10 +134,12 @@ public:
throw Exception("First argument for function " + getName() + " must be tuple or array of tuple.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
size_t index = getElementNum(arguments[1].column, *tuple);
out_return_type = tuple->getElements()[index];
DataTypePtr out_return_type = tuple->getElements()[index];
for (; count_arrays; --count_arrays)
out_return_type = std::make_shared<DataTypeArray>(out_return_type);
return out_return_type;
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override
@ -174,7 +175,7 @@ public:
}
private:
size_t getElementNum(const ColumnPtr & index_column, const DataTypeTuple & tuple)
size_t getElementNum(const ColumnPtr & index_column, const DataTypeTuple & tuple) const
{
if (auto index_col = checkAndGetColumnConst<ColumnUInt8>(index_column.get()))
{

View File

@ -15,7 +15,7 @@ namespace DB
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
namespace
@ -86,15 +86,15 @@ ColumnPtr wrapInNullable(const ColumnPtr & src, Block & block, const ColumnNumbe
}
struct NullPresense
struct NullPresence
{
bool has_nullable = false;
bool has_null_constant = false;
};
NullPresense getNullPresense(const Block & block, const ColumnNumbers & args)
NullPresence getNullPresense(const Block & block, const ColumnNumbers & args)
{
NullPresense res;
NullPresence res;
for (const auto & arg : args)
{
@ -109,9 +109,9 @@ NullPresense getNullPresense(const Block & block, const ColumnNumbers & args)
return res;
}
NullPresense getNullPresense(const ColumnsWithTypeAndName & args)
NullPresence getNullPresense(const ColumnsWithTypeAndName & args)
{
NullPresense res;
NullPresence res;
for (const auto & elem : args)
{
@ -124,43 +124,6 @@ NullPresense getNullPresense(const ColumnsWithTypeAndName & args)
return res;
}
NullPresense getNullPresense(const DataTypes & types)
{
NullPresense res;
for (const auto & type : types)
{
if (!res.has_nullable)
res.has_nullable = type->isNullable();
if (!res.has_null_constant)
res.has_null_constant = type->onlyNull();
}
return res;
}
/// Turn the specified set of data types into their respective nested data types.
DataTypes toNestedDataTypes(const DataTypes & args)
{
DataTypes new_args;
new_args.reserve(args.size());
for (const auto & arg : args)
{
if (arg->isNullable())
{
auto nullable_type = static_cast<const DataTypeNullable *>(arg.get());
const DataTypePtr & nested_type = nullable_type->getNestedType();
new_args.push_back(nested_type);
}
else
new_args.push_back(arg);
}
return new_args;
}
bool allArgumentsAreConstants(const Block & block, const ColumnNumbers & args)
{
for (auto arg : args)
@ -168,14 +131,14 @@ bool allArgumentsAreConstants(const Block & block, const ColumnNumbers & args)
return false;
return true;
}
}
bool defaultImplementationForConstantArguments(
IFunction & func, Block & block, const ColumnNumbers & args, size_t result)
bool PreparedFunctionImpl::defaultImplementationForConstantArguments(Block & block, const ColumnNumbers & args, size_t result)
{
if (args.empty() || !func.useDefaultImplementationForConstants() || !allArgumentsAreConstants(block, args))
if (args.empty() || !useDefaultImplementationForConstants() || !allArgumentsAreConstants(block, args))
return false;
ColumnNumbers arguments_to_remain_constants = func.getArgumentsThatAreAlwaysConstant();
ColumnNumbers arguments_to_remain_constants = getArgumentsThatAreAlwaysConstant();
Block temporary_block;
bool have_converted_columns = false;
@ -198,8 +161,8 @@ bool defaultImplementationForConstantArguments(
* not in "arguments_to_remain_constants" set. Otherwise we get infinite recursion.
*/
if (!have_converted_columns)
throw Exception("Number of arguments for function " + func.getName() + " doesn't match: the function requires more arguments",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
throw Exception("Number of arguments for function " + getName() + " doesn't match: the function requires more arguments",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
temporary_block.insert(block.getByPosition(result));
@ -207,31 +170,30 @@ bool defaultImplementationForConstantArguments(
for (size_t i = 0; i < arguments_size; ++i)
temporary_argument_numbers[i] = i;
func.execute(temporary_block, temporary_argument_numbers, arguments_size);
execute(temporary_block, temporary_argument_numbers, arguments_size);
block.getByPosition(result).column = ColumnConst::create(temporary_block.getByPosition(arguments_size).column, block.rows());
return true;
}
bool defaultImplementationForNulls(
IFunction & func, Block & block, const ColumnNumbers & args, size_t result)
bool PreparedFunctionImpl::defaultImplementationForNulls(Block & block, const ColumnNumbers & args, size_t result)
{
if (args.empty() || !func.useDefaultImplementationForNulls())
if (args.empty() || !useDefaultImplementationForNulls())
return false;
NullPresense null_presense = getNullPresense(block, args);
NullPresence null_presence = getNullPresense(block, args);
if (null_presense.has_null_constant)
if (null_presence.has_null_constant)
{
block.getByPosition(result).column = block.getByPosition(result).type->createColumnConst(block.rows(), Null());
return true;
}
if (null_presense.has_nullable)
if (null_presence.has_nullable)
{
Block temporary_block = createBlockWithNestedColumns(block, args, result);
func.execute(temporary_block, args, result);
execute(temporary_block, args, result);
block.getByPosition(result).column = wrapInNullable(temporary_block.getByPosition(result).column, block, args, result);
return true;
}
@ -239,10 +201,18 @@ bool defaultImplementationForNulls(
return false;
}
/// Runs the function on the block. The default handling of all-constant arguments is tried first,
/// then the default handling of Nullable arguments; executeImpl is called only if neither applied.
void PreparedFunctionImpl::execute(Block & block, const ColumnNumbers & args, size_t result)
{
    /// Short-circuit evaluation keeps the order: the NULL handling is not attempted
    /// when the constant handling has already produced the result.
    const bool handled_by_default = defaultImplementationForConstantArguments(block, args, result)
        || defaultImplementationForNulls(block, args, result);

    if (!handled_by_default)
        executeImpl(block, args, result);
}
void IFunction::checkNumberOfArguments(size_t number_of_arguments) const
void FunctionBuilderImpl::checkNumberOfArguments(size_t number_of_arguments) const
{
if (isVariadic())
return;
@ -251,90 +221,31 @@ void IFunction::checkNumberOfArguments(size_t number_of_arguments) const
if (number_of_arguments != expected_number_of_arguments)
throw Exception("Number of arguments for function " + getName() + " doesn't match: passed "
+ toString(number_of_arguments) + ", should be " + toString(expected_number_of_arguments),
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
+ toString(number_of_arguments) + ", should be " + toString(expected_number_of_arguments),
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
}
DataTypePtr IFunction::getReturnType(const DataTypes & arguments) const
DataTypePtr FunctionBuilderImpl::getReturnType(const ColumnsWithTypeAndName & arguments) const
{
checkNumberOfArguments(arguments.size());
if (!arguments.empty() && useDefaultImplementationForNulls())
{
NullPresense null_presense = getNullPresense(arguments);
NullPresence null_presense = getNullPresense(arguments);
if (null_presense.has_null_constant)
{
return makeNullable(std::make_shared<DataTypeNothing>());
}
if (null_presense.has_nullable)
{
return makeNullable(getReturnTypeImpl(toNestedDataTypes(arguments)));
Block nested_block = createBlockWithNestedColumns(Block(arguments), ext::collection_cast<ColumnNumbers>(ext::range(0, arguments.size())));
auto return_type = getReturnTypeImpl(ColumnsWithTypeAndName(nested_block.begin(), nested_block.end()));
return makeNullable(return_type);
}
}
return getReturnTypeImpl(arguments);
}
void IFunction::getReturnTypeAndPrerequisites(
const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
std::vector<ExpressionAction> & out_prerequisites)
{
checkNumberOfArguments(arguments.size());
if (!arguments.empty() && useDefaultImplementationForNulls())
{
NullPresense null_presense = getNullPresense(arguments);
if (null_presense.has_null_constant)
{
out_return_type = makeNullable(std::make_shared<DataTypeNothing>());
return;
}
if (null_presense.has_nullable)
{
Block nested_block = createBlockWithNestedColumns(Block(arguments), ext::collection_cast<ColumnNumbers>(ext::range(0, arguments.size())));
getReturnTypeAndPrerequisitesImpl(ColumnsWithTypeAndName(nested_block.begin(), nested_block.end()), out_return_type, out_prerequisites);
out_return_type = makeNullable(out_return_type);
return;
}
}
getReturnTypeAndPrerequisitesImpl(arguments, out_return_type, out_prerequisites);
}
void IFunction::getLambdaArgumentTypes(DataTypes & arguments) const
{
checkNumberOfArguments(arguments.size());
getLambdaArgumentTypesImpl(arguments);
}
void IFunction::execute(Block & block, const ColumnNumbers & args, size_t result)
{
if (defaultImplementationForConstantArguments(*this, block, args, result))
return;
if (defaultImplementationForNulls(*this, block, args, result))
return;
executeImpl(block, args, result);
}
void IFunction::execute(Block & block, const ColumnNumbers & args, const ColumnNumbers & prerequisites, size_t result)
{
if (!prerequisites.empty())
{
executeImpl(block, args, prerequisites, result);
return;
}
execute(block, args, result);
}
}

View File

@ -16,45 +16,79 @@ namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NOT_IMPLEMENTED;
extern const int LOGICAL_ERROR;
}
struct ExpressionAction;
/** Interface for normal functions.
* Normal functions are functions that do not change the number of rows in the table,
* and the result of which for each row does not depend on other rows.
*
* A function can take an arbitrary number of arguments; returns exactly one value.
* The type of the result depends on the type and number of arguments.
*
* The function is dispatched for the whole block. This allows you to perform all kinds of checks rarely,
* and do the main job as an efficient loop.
*
* The function is applied to one or more columns of the block, and writes its result,
* adding a new column to the block. The function does not modify its arguments.
*/
class IFunction
/// The simplest executable object.
/// Motivation:
/// * Prepare something heavy once before main execution loop instead of doing it for each block.
/// * Provide const interface for IFunctionBase (later).
class IPreparedFunction
{
public:
/** The successor of IFunction must implement:
* - getName
* - either getReturnType, or getReturnTypeAndPrerequisites
* - one of the overloads of `execute`.
*/
virtual ~IPreparedFunction() = default;
/// Get the main function name.
virtual String getName() const = 0;
/// Override and return true if function could take different number of arguments.
virtual bool isVariadic() const { return false; }
virtual void execute(Block & block, const ColumnNumbers & arguments, size_t result) = 0;
};
/// For non-variadic functions, return number of arguments; otherwise return zero (that should be ignored).
virtual size_t getNumberOfArguments() const = 0;
using PreparedFunctionPtr = std::shared_ptr<IPreparedFunction>;
/// Throw if number of arguments is incorrect. Default implementation will check only in non-variadic case.
/// It is called inside getReturnType.
virtual void checkNumberOfArguments(size_t number_of_arguments) const;
/// Base class for executable functions that adds default handling of constant and Nullable arguments.
/// execute() is final: derived classes implement executeImpl() and adjust the defaults
/// through the virtual hooks declared below.
class PreparedFunctionImpl : public IPreparedFunction
{
public:
/// Tries the default implementation for constant arguments, then the one for NULLs,
/// and calls executeImpl() only if neither of them handled the call.
void execute(Block & block, const ColumnNumbers & arguments, size_t result) final;
protected:
/// The actual computation. Must be implemented by every derived class.
virtual void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) = 0;
/** Default implementation in presence of Nullable arguments or NULL constants as arguments is the following:
* if some of arguments are NULL constants then return NULL constant,
* if some of arguments are Nullable, then execute function as usual for block,
* where Nullable columns are substituted with nested columns (they have arbitrary values in rows corresponding to NULL value)
* and wrap result in Nullable column where NULLs are in all rows where any of arguments are NULL.
*/
virtual bool useDefaultImplementationForNulls() const { return true; }
/** If the function has a non-zero number of arguments and all of them are constant,
* a default implementation can be provided automatically:
* arguments are converted to ordinary columns with single value, then function is executed as usual,
* and then the result is converted to constant column.
*/
virtual bool useDefaultImplementationForConstants() const { return false; }
/** Some arguments could remain constant during this implementation.
* Returns the positions of such arguments; empty by default.
*/
virtual ColumnNumbers getArgumentsThatAreAlwaysConstant() const { return {}; }
private:
/// Both helpers return true if they handled the call and filled the result column, false otherwise.
bool defaultImplementationForNulls(Block & block, const ColumnNumbers & args, size_t result);
bool defaultImplementationForConstantArguments(Block & block, const ColumnNumbers & args, size_t result);
};
/// Function with known arguments and return type.
class IFunctionBase
{
public:
virtual ~IFunctionBase() = default;
/// Get the main function name.
virtual String getName() const = 0;
virtual const DataTypes & getArgumentTypes() const = 0;
virtual const DataTypePtr & getReturnType() const = 0;
/// Do preparations and return executable.
/// sample_block should contain data types of arguments and values of constants, if relevant.
virtual PreparedFunctionPtr prepare(const Block & sample_block) const = 0;
/// TODO: make const
virtual void execute(Block & block, const ColumnNumbers & arguments, size_t result)
{
return prepare(block)->execute(block, arguments, result);
}
/** Should we evaluate this function while constant folding, if arguments are constants?
* Usually this is true. Notable counterexample is function 'sleep'.
@ -94,85 +128,6 @@ public:
*/
virtual bool isDeterministicInScopeOfQuery() { return true; }
/// Get the result type by argument type. If the function does not apply to these arguments, throw an exception.
/// Overloading for those who do not need prerequisites and values of constant arguments. Not called from outside.
DataTypePtr getReturnType(const DataTypes & arguments) const;
virtual DataTypePtr getReturnTypeImpl(const DataTypes & /*arguments*/) const
{
throw Exception("getReturnType is not implemented for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}
/** Get the result type by argument types and constant argument values.
* If the function does not apply to these arguments, throw an exception.
* You can also return a description of the additional columns that are required to perform the function.
* For non-constant columns `arguments[i].column = nullptr`.
* Meaningful element types in out_prerequisites: APPLY_FUNCTION, ADD_COLUMN.
*/
void getReturnTypeAndPrerequisites(
const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
std::vector<ExpressionAction> & out_prerequisites);
virtual void getReturnTypeAndPrerequisitesImpl(
const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
std::vector<ExpressionAction> & /*out_prerequisites*/)
{
DataTypes types(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i)
types[i] = arguments[i].type;
out_return_type = getReturnTypeImpl(types);
}
/// For higher-order functions (functions, that have lambda expression as at least one argument).
/// You pass data types with empty DataTypeExpression for lambda arguments.
/// This function will replace it with DataTypeExpression containing actual types.
void getLambdaArgumentTypes(DataTypes & arguments) const;
virtual void getLambdaArgumentTypesImpl(DataTypes & /*arguments*/) const
{
throw Exception("Function " + getName() + " can't have lambda-expressions as arguments", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
/// Execute the function on the block. Note: can be called simultaneously from several threads, for one object.
/// Overloading for those who do not need `prerequisites`. Not called from outside.
void execute(Block & block, const ColumnNumbers & arguments, size_t result);
/// Execute the function above the block. Note: can be called simultaneously from several threads, for one object.
/// `prerequisites` go in the same order as `out_prerequisites` obtained from getReturnTypeAndPrerequisites.
void execute(Block & block, const ColumnNumbers & arguments, const ColumnNumbers & prerequisites, size_t result);
virtual void executeImpl(Block & /*block*/, const ColumnNumbers & /*arguments*/, size_t /*result*/)
{
throw Exception("executeImpl is not implemented for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}
virtual void executeImpl(Block & block, const ColumnNumbers & arguments, const ColumnNumbers & /*prerequisites*/, size_t result)
{
executeImpl(block, arguments, result);
}
/** Default implementation in presense of Nullable arguments or NULL constants as arguments is the following:
* if some of arguments are NULL constants then return NULL constant,
* if some of arguments are Nullable, then execute function as usual for block,
* where Nullable columns are substituted with nested columns (they have arbitary values in rows corresponding to NULL value)
* and wrap result in Nullable column where NULLs are in all rows where any of arguments are NULL.
*/
virtual bool useDefaultImplementationForNulls() const { return true; }
/** If the function have non-zero number of arguments,
* and if all arguments are constant, that we could automatically provide default implementation:
* arguments are converted to ordinary columns with single value, then function is executed as usual,
* and then the result is converted to constant column.
*/
virtual bool useDefaultImplementationForConstants() const { return false; }
/** Some arguments could remain constant during this implementation.
*/
virtual ColumnNumbers getArgumentsThatAreAlwaysConstant() const { return {}; }
/** Lets you know if the function is monotonic in a range of values.
* This is used to work with the index in a sorted chunk of data.
* And allows to use the index not only when it is written, for example `date >= const`, but also, for example, `toMonth(date) >= 11`.
@ -188,7 +143,7 @@ public:
bool is_always_monotonic = false; /// Is true if function is monotonic on the whole input range I
Monotonicity(bool is_monotonic_ = false, bool is_positive_ = true, bool is_always_monotonic_ = false)
: is_monotonic(is_monotonic_), is_positive(is_positive_), is_always_monotonic(is_always_monotonic_) {}
: is_monotonic(is_monotonic_), is_positive(is_positive_), is_always_monotonic(is_always_monotonic_) {}
};
/** Get information about monotonicity on a range of values. Call only if hasInformationAboutMonotonicity.
@ -198,12 +153,221 @@ public:
{
throw Exception("Function " + getName() + " has no information about its monotonicity.", ErrorCodes::NOT_IMPLEMENTED);
}
virtual ~IFunction() {}
};
using FunctionBasePtr = std::shared_ptr<IFunctionBase>;
/// Creates IFunctionBase from argument types list.
/// Pure interface: it describes a function before its argument types are fixed
/// (name, number of arguments, lambda argument types) and produces an IFunctionBase via build().
class IFunctionBuilder
{
public:
virtual ~IFunctionBuilder() = default;
/// Get the main function name.
virtual String getName() const = 0;
/// Override and return true if function could take different number of arguments.
virtual bool isVariadic() const { return false; }
/// For non-variadic functions, return number of arguments; otherwise return zero (that should be ignored).
virtual size_t getNumberOfArguments() const = 0;
/// Throw if number of arguments is incorrect. Default implementation will check only in non-variadic case.
virtual void checkNumberOfArguments(size_t number_of_arguments) const = 0;
/// Check arguments and return IFunctionBase.
virtual FunctionBasePtr build(const ColumnsWithTypeAndName & arguments) const = 0;
/// For higher-order functions (functions, that have lambda expression as at least one argument).
/// You pass data types with empty DataTypeFunction for lambda arguments.
/// This function will replace it with DataTypeFunction containing actual types.
/// Note that `arguments` is modified in place.
virtual void getLambdaArgumentTypes(DataTypes & arguments) const = 0;
};
using FunctionBuilderPtr = std::shared_ptr<IFunctionBuilder>;
/// Common part of function builders: the argument count check, return type deduction
/// and construction of the resulting IFunctionBase through buildImpl().
class FunctionBuilderImpl : public IFunctionBuilder
{
public:
    /// Deduces the return type and passes it, together with the arguments, to buildImpl().
    FunctionBasePtr build(const ColumnsWithTypeAndName & arguments) const final
    {
        DataTypePtr deduced_type = getReturnType(arguments);
        return buildImpl(arguments, deduced_type);
    }

    /// Default implementation. Will check only in non-variadic case.
    void checkNumberOfArguments(size_t number_of_arguments) const override;

    DataTypePtr getReturnType(const ColumnsWithTypeAndName & arguments) const;

    /// Validates the number of arguments before delegating to getLambdaArgumentTypesImpl().
    void getLambdaArgumentTypes(DataTypes & arguments) const override
    {
        checkNumberOfArguments(arguments.size());
        getLambdaArgumentTypesImpl(arguments);
    }

protected:
    /// Get the result type by argument type. If the function does not apply to these arguments, throw an exception.
    /// This overload drops everything except the types and forwards to the DataTypes overload.
    virtual DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const
    {
        DataTypes argument_types;
        argument_types.reserve(arguments.size());
        for (const auto & argument : arguments)
            argument_types.push_back(argument.type);

        return getReturnTypeImpl(argument_types);
    }

    virtual DataTypePtr getReturnTypeImpl(const DataTypes & /*arguments*/) const
    {
        throw Exception("getReturnType is not implemented for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    /** If useDefaultImplementationForNulls() is true, then change arguments for getReturnType() and buildImpl():
      *  if some of arguments are Nullable(Nothing) then don't call getReturnType(), call buildImpl() with return_type = Nullable(Nothing),
      *  if some of arguments are Nullable, then:
      *   - Nullable types are substituted with nested types for getReturnType() function
      *   - wrap getReturnType() result in Nullable type and pass to buildImpl
      *
      * Otherwise build returns buildImpl(arguments, getReturnType(arguments));
      */
    virtual bool useDefaultImplementationForNulls() const { return true; }

    virtual FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const = 0;

    /// Lambda arguments are rejected unless a derived class overrides this method.
    virtual void getLambdaArgumentTypesImpl(DataTypes & /*arguments*/) const
    {
        throw Exception("Function " + getName() + " can't have lambda-expressions as arguments", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
    }
};
/// Previous function interface: a single class that is a builder, a function base
/// and a prepared function at the same time.
/// The members inherited from IFunctionBase and the builder's buildImpl() are not supported here
/// and throw NOT_IMPLEMENTED; DefaultExecutable, DefaultFunction and DefaultFunctionBuilder
/// wrap an IFunction to provide the separate interfaces.
class IFunction : public std::enable_shared_from_this<IFunction>,
public FunctionBuilderImpl, public IFunctionBase, public PreparedFunctionImpl
{
public:
String getName() const override = 0;
/// TODO: make const
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override = 0;
/// Override these functions to change the default implementation behavior.
/// See the comments on the corresponding methods of PreparedFunctionImpl and FunctionBuilderImpl for details.
bool useDefaultImplementationForNulls() const override { return true; }
bool useDefaultImplementationForConstants() const override { return false; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {}; }
/// Both PreparedFunctionImpl and IFunctionBase declare execute(); select the PreparedFunctionImpl one.
using PreparedFunctionImpl::execute;
/// These two are protected in FunctionBuilderImpl; the using-declarations in this public section
/// make them accessible to the wrapper classes.
using FunctionBuilderImpl::getReturnTypeImpl;
using FunctionBuilderImpl::getLambdaArgumentTypesImpl;
/// Keeps the builder's getReturnType(arguments) visible next to the argument-less override below.
using FunctionBuilderImpl::getReturnType;
PreparedFunctionPtr prepare(const Block & /*sample_block*/) const final
{
throw Exception("prepare is not implemented for IFunction", ErrorCodes::NOT_IMPLEMENTED);
}
const DataTypes & getArgumentTypes() const final
{
throw Exception("getArgumentTypes is not implemented for IFunction", ErrorCodes::NOT_IMPLEMENTED);
}
const DataTypePtr & getReturnType() const override
{
throw Exception("getReturnType is not implemented for IFunction", ErrorCodes::NOT_IMPLEMENTED);
}
protected:
FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & /*arguments*/, const DataTypePtr & /*return_type*/) const final
{
throw Exception("buildImpl is not implemented for IFunction", ErrorCodes::NOT_IMPLEMENTED);
}
};
/// Wrappers over IFunction.
/// DefaultExecutable exposes an IFunction as a prepared function: executeImpl() and the hooks
/// that control the default implementations are all forwarded to the wrapped object.
class DefaultExecutable final : public PreparedFunctionImpl
{
public:
explicit DefaultExecutable(std::shared_ptr<IFunction> function) : function(std::move(function)) {}
String getName() const override { return function->getName(); }
protected:
/// Called from PreparedFunctionImpl::execute() of this wrapper after the default implementations were tried.
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) final
{
return function->executeImpl(block, arguments, result);
}
bool useDefaultImplementationForNulls() const final { return function->useDefaultImplementationForNulls(); }
bool useDefaultImplementationForConstants() const final { return function->useDefaultImplementationForConstants(); }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const final { return function->getArgumentsThatAreAlwaysConstant(); }
private:
/// The wrapped function; shared with whoever created this wrapper.
std::shared_ptr<IFunction> function;
};
/// Exposes an IFunction as an IFunctionBase with fixed argument types and return type.
/// The types are stored in the wrapper; the properties of the function
/// (constant folding, injectivity, determinism, monotonicity) are forwarded to the wrapped object.
class DefaultFunction final : public IFunctionBase
{
public:
DefaultFunction(std::shared_ptr<IFunction> function, DataTypes arguments, DataTypePtr return_type)
: function(std::move(function)), arguments(std::move(arguments)), return_type(std::move(return_type)) {}
String getName() const override { return function->getName(); }
const DataTypes & getArgumentTypes() const override { return arguments; }
const DataTypePtr & getReturnType() const override { return return_type; }
/// A new DefaultExecutable over the same wrapped function is created on every call; the sample block is not used.
PreparedFunctionPtr prepare(const Block & /*sample_block*/) const override { return std::make_shared<DefaultExecutable>(function); }
bool isSuitableForConstantFolding() const override { return function->isSuitableForConstantFolding(); }
bool isInjective(const Block & sample_block) override { return function->isInjective(sample_block); }
bool isDeterministicInScopeOfQuery() override { return function->isDeterministicInScopeOfQuery(); }
bool hasInformationAboutMonotonicity() const override { return function->hasInformationAboutMonotonicity(); }
IFunctionBase::Monotonicity getMonotonicityForRange(const IDataType & type, const Field & left, const Field & right) const override
{
return function->getMonotonicityForRange(type, left, right);
}
private:
std::shared_ptr<IFunction> function;
/// Argument types and return type this instance was built for.
DataTypes arguments;
DataTypePtr return_type;
};
/// Adapter that presents an IFunction through the FunctionBuilderImpl interface.
/// Argument-count checks, return-type deduction and lambda-argument-type deduction are forwarded
/// to the wrapped function; buildImpl produces a DefaultFunction bound to the given argument types
/// and return type.
class DefaultFunctionBuilder : public FunctionBuilderImpl
{
public:
    explicit DefaultFunctionBuilder(std::shared_ptr<IFunction> function_) : function(std::move(function_)) {}

    void checkNumberOfArguments(size_t number_of_arguments) const override
    {
        function->checkNumberOfArguments(number_of_arguments);
    }

    String getName() const override { return function->getName(); }
    bool isVariadic() const override { return function->isVariadic(); }
    size_t getNumberOfArguments() const override { return function->getNumberOfArguments(); }

protected:
    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { return function->getReturnTypeImpl(arguments); }
    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override { return function->getReturnTypeImpl(arguments); }
    bool useDefaultImplementationForNulls() const override { return function->useDefaultImplementationForNulls(); }

    /// Collects the type of every argument column and wraps the function, those types and
    /// the given return type into a DefaultFunction.
    FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const override
    {
        DataTypes data_types;
        data_types.reserve(arguments.size());
        for (const auto & argument : arguments)
            data_types.push_back(argument.type);

        /// The local list is not used afterwards, so hand it over instead of copying it.
        return std::make_shared<DefaultFunction>(function, std::move(data_types), return_type);
    }

    void getLambdaArgumentTypesImpl(DataTypes & arguments) const override
    {
        function->getLambdaArgumentTypesImpl(arguments);
    }

private:
    std::shared_ptr<IFunction> function;
};
using FunctionPtr = std::shared_ptr<IFunction>;
}

View File

@ -36,7 +36,6 @@ Names ExpressionAction::getNeededColumns() const
{
Names res = argument_names;
res.insert(res.end(), prerequisite_names.begin(), prerequisite_names.end());
res.insert(res.end(), array_joined_columns.begin(), array_joined_columns.end());
for (const auto & column : projection)
@ -49,7 +48,7 @@ Names ExpressionAction::getNeededColumns() const
}
ExpressionAction ExpressionAction::applyFunction(const FunctionPtr & function_,
ExpressionAction ExpressionAction::applyFunction(const FunctionBuilderPtr & function_,
const std::vector<std::string> & argument_names_,
std::string result_name_)
{
@ -68,7 +67,7 @@ ExpressionAction ExpressionAction::applyFunction(const FunctionPtr & function_,
ExpressionAction a;
a.type = APPLY_FUNCTION;
a.result_name = result_name_;
a.function = function_;
a.function_builder = function_;
a.argument_names = argument_names_;
return a;
}
@ -128,7 +127,7 @@ ExpressionAction ExpressionAction::arrayJoin(const NameSet & array_joined_column
a.array_join_is_left = array_join_is_left;
if (array_join_is_left)
a.function = FunctionFactory::instance().get("emptyArrayToSingle", context);
a.function_builder = FunctionFactory::instance().get("emptyArrayToSingle", context);
return a;
}
@ -160,13 +159,8 @@ ExpressionActions::Actions ExpressionAction::getPrerequisites(Block & sample_blo
arguments[i] = sample_block.getByName(argument_names[i]);
}
function->getReturnTypeAndPrerequisites(arguments, result_type, res);
for (size_t i = 0; i < res.size(); ++i)
{
if (res[i].result_name != "")
prerequisite_names.push_back(res[i].result_name);
}
function = function_builder->build(arguments);
result_type = function->getReturnType();
}
return res;
@ -201,15 +195,6 @@ void ExpressionAction::prepare(Block & sample_block)
all_const = false;
}
ColumnNumbers prerequisites(prerequisite_names.size());
for (size_t i = 0; i < prerequisite_names.size(); ++i)
{
prerequisites[i] = sample_block.getPositionByName(prerequisite_names[i]);
ColumnPtr col = sample_block.safeGetByPosition(prerequisites[i]).column;
if (!col || !col->isColumnConst())
all_const = false;
}
ColumnPtr new_column;
/// If all arguments are constants, and function is suitable to be executed in 'prepare' stage - execute function.
@ -222,7 +207,7 @@ void ExpressionAction::prepare(Block & sample_block)
new_column.type = result_type;
sample_block.insert(std::move(new_column));
function->execute(sample_block, arguments, prerequisites, result_position);
function->execute(sample_block, arguments, result_position);
/// If the result is not a constant, just in case, we will consider the result as unknown.
ColumnWithTypeAndName & col = sample_block.safeGetByPosition(result_position);
@ -343,19 +328,11 @@ void ExpressionAction::execute(Block & block) const
arguments[i] = block.getPositionByName(argument_names[i]);
}
ColumnNumbers prerequisites(prerequisite_names.size());
for (size_t i = 0; i < prerequisite_names.size(); ++i)
{
if (!block.has(prerequisite_names[i]))
throw Exception("Not found column: '" + prerequisite_names[i] + "'", ErrorCodes::NOT_FOUND_COLUMN_IN_BLOCK);
prerequisites[i] = block.getPositionByName(prerequisite_names[i]);
}
size_t num_columns_without_result = block.columns();
block.insert({ nullptr, result_type, result_name});
ProfileEvents::increment(ProfileEvents::FunctionExecute);
function->execute(block, arguments, prerequisites, num_columns_without_result);
function->execute(block, arguments, num_columns_without_result);
break;
}
@ -383,7 +360,7 @@ void ExpressionAction::execute(Block & block) const
Block tmp_block{src_col, {{}, src_col.type, {}}};
function->execute(tmp_block, {0}, 1);
function_builder->build({src_col})->execute(tmp_block, {0}, 1);
non_empty_array_columns[name] = tmp_block.safeGetByPosition(1).column;
}
@ -837,6 +814,7 @@ void ExpressionActions::finalize(const Names & output_columns)
action.type = ExpressionAction::ADD_COLUMN;
action.result_type = result.type;
action.added_column = result.column;
action.function_builder = nullptr;
action.function = nullptr;
action.argument_names.clear();
in.clear();
@ -889,9 +867,6 @@ void ExpressionActions::finalize(const Names & output_columns)
for (const auto & name : action.argument_names)
++columns_refcount[name];
for (const auto & name : action.prerequisite_names)
++columns_refcount[name];
for (const auto & name_alias : action.projection)
++columns_refcount[name_alias.first];
}
@ -920,9 +895,6 @@ void ExpressionActions::finalize(const Names & output_columns)
for (const auto & name : action.argument_names)
process(name);
for (const auto & name : action.prerequisite_names)
process(name);
/// For `projection`, there is no reduction in `refcount`, because the `project` action replaces the names of the columns, in effect, already deleting them under the old names.
}

View File

@ -22,8 +22,11 @@ using NamesWithAliases = std::vector<NameWithAlias>;
class Join;
class IFunction;
using FunctionPtr = std::shared_ptr<IFunction>;
class IFunctionBase;
using FunctionBasePtr = std::shared_ptr<IFunctionBase>;
class IFunctionBuilder;
using FunctionBuilderPtr = std::shared_ptr<IFunctionBuilder>;
class IDataType;
using DataTypePtr = std::shared_ptr<const IDataType>;
@ -68,9 +71,9 @@ public:
ColumnPtr added_column;
/// For APPLY_FUNCTION and LEFT ARRAY JOIN.
mutable FunctionPtr function; /// mutable - to allow execute.
FunctionBuilderPtr function_builder;
FunctionBasePtr function;
Names argument_names;
Names prerequisite_names;
/// For ARRAY_JOIN
NameSet array_joined_columns;
@ -85,7 +88,7 @@ public:
/// If result_name_ == "", as name "function_name(arguments separated by commas) is used".
static ExpressionAction applyFunction(
const FunctionPtr & function_, const std::vector<std::string> & argument_names_, std::string result_name_ = "");
const FunctionBuilderPtr & function_, const std::vector<std::string> & argument_names_, std::string result_name_ = "");
static ExpressionAction addColumn(const ColumnWithTypeAndName & added_column_);
static ExpressionAction removeColumn(const std::string & removed_name);

View File

@ -18,12 +18,10 @@
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeExpression.h>
#include <DataTypes/NestedUtils.h>
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnSet.h>
#include <Columns/ColumnExpression.h>
#include <Columns/ColumnConst.h>
#include <Interpreters/InterpreterSelectQuery.h>
@ -58,6 +56,8 @@
#include <ext/range.h>
#include <DataTypes/DataTypeFactory.h>
#include <DataTypes/DataTypeFunction.h>
#include <Functions/FunctionsMiscellaneous.h>
namespace DB
@ -1867,11 +1867,9 @@ struct ExpressionAnalyzer::ScopeStack
throw Exception("Unknown identifier: " + name, ErrorCodes::UNKNOWN_IDENTIFIER);
}
void addAction(const ExpressionAction & action, const Names & additional_required_columns = Names())
void addAction(const ExpressionAction & action)
{
size_t level = 0;
for (size_t i = 0; i < additional_required_columns.size(); ++i)
level = std::max(level, getColumnLevel(additional_required_columns[i]));
Names required = action.getNeededColumns();
for (size_t i = 0; i < required.size(); ++i)
level = std::max(level, getColumnLevel(required[i]));
@ -2104,7 +2102,7 @@ void ExpressionAnalyzer::getActionsImpl(const ASTPtr & ast, bool no_subqueries,
return;
}
const FunctionPtr & function = FunctionFactory::instance().get(node->name, context);
const FunctionBuilderPtr & function_builder = FunctionFactory::instance().get(node->name, context);
Names argument_names;
DataTypes argument_types;
@ -2128,7 +2126,7 @@ void ExpressionAnalyzer::getActionsImpl(const ASTPtr & ast, bool no_subqueries,
throw Exception("First argument of lambda must be a tuple", ErrorCodes::TYPE_MISMATCH);
has_lambda_arguments = true;
argument_types.emplace_back(std::make_shared<DataTypeExpression>(DataTypes(lambda_args_tuple->arguments->children.size())));
argument_types.emplace_back(std::make_shared<DataTypeFunction>(DataTypes(lambda_args_tuple->arguments->children.size())));
/// Select the name in the next cycle.
argument_names.emplace_back();
}
@ -2183,11 +2181,9 @@ void ExpressionAnalyzer::getActionsImpl(const ASTPtr & ast, bool no_subqueries,
if (only_consts && !arguments_present)
return;
Names additional_requirements;
if (has_lambda_arguments && !only_consts)
{
function->getLambdaArgumentTypes(argument_types);
function_builder->getLambdaArgumentTypes(argument_types);
/// Call recursively for lambda expressions.
for (size_t i = 0; i < node->arguments->children.size(); ++i)
@ -2197,7 +2193,7 @@ void ExpressionAnalyzer::getActionsImpl(const ASTPtr & ast, bool no_subqueries,
ASTFunction * lambda = typeid_cast<ASTFunction *>(child.get());
if (lambda && lambda->name == "lambda")
{
const DataTypeExpression * lambda_type = typeid_cast<const DataTypeExpression *>(argument_types[i].get());
const DataTypeFunction * lambda_type = typeid_cast<const DataTypeFunction *>(argument_types[i].get());
ASTFunction * lambda_args_tuple = typeid_cast<ASTFunction *>(lambda->arguments->children.at(0).get());
ASTs lambda_arg_asts = lambda_args_tuple->arguments->children;
NamesAndTypesList lambda_arguments;
@ -2220,22 +2216,23 @@ void ExpressionAnalyzer::getActionsImpl(const ASTPtr & ast, bool no_subqueries,
String result_name = lambda->arguments->children.at(1)->getColumnName();
lambda_actions->finalize(Names(1, result_name));
DataTypePtr result_type = lambda_actions->getSampleBlock().getByName(result_name).type;
argument_types[i] = std::make_shared<DataTypeExpression>(lambda_type->getArgumentTypes(), result_type);
Names captured = lambda_actions->getRequiredColumns();
for (size_t j = 0; j < captured.size(); ++j)
if (findColumn(captured[j], lambda_arguments) == lambda_arguments.end())
additional_requirements.push_back(captured[j]);
Names captured;
Names required = lambda_actions->getRequiredColumns();
for (size_t j = 0; j < required.size(); ++j)
if (findColumn(required[j], lambda_arguments) == lambda_arguments.end())
captured.push_back(required[j]);
/// We cannot use `getColumnName()` as the name,
/// because it does not uniquely identify the expression (the types of arguments can be different).
argument_names[i] = getUniqueName(actions_stack.getSampleBlock(), "__lambda");
String lambda_name = getUniqueName(actions_stack.getSampleBlock(), "__lambda");
ColumnWithTypeAndName lambda_column;
lambda_column.column = ColumnExpression::create(1, lambda_actions, lambda_arguments, result_type, result_name);
lambda_column.type = argument_types[i];
lambda_column.name = argument_names[i];
actions_stack.addAction(ExpressionAction::addColumn(lambda_column));
auto function_capture = std::make_shared<FunctionCapture>(
lambda_actions, captured, lambda_arguments, result_type, result_name);
actions_stack.addAction(ExpressionAction::applyFunction(function_capture, captured, lambda_name));
argument_types[i] = std::make_shared<DataTypeFunction>(lambda_type->getArgumentTypes(), result_type);
argument_names[i] = lambda_name;
}
}
}
@ -2253,8 +2250,7 @@ void ExpressionAnalyzer::getActionsImpl(const ASTPtr & ast, bool no_subqueries,
}
if (arguments_present)
actions_stack.addAction(ExpressionAction::applyFunction(function, argument_names, node->getColumnName()),
additional_requirements);
actions_stack.addAction(ExpressionAction::applyFunction(function_builder, argument_names, node->getColumnName()));
}
}
else if (ASTLiteral * node = typeid_cast<ASTLiteral *>(ast.get()))

View File

@ -20,6 +20,9 @@ class FieldWithInfinity;
using SetElements = std::vector<std::vector<Field>>;
using SetElementsPtr = std::unique_ptr<SetElements>;
class IFunctionBase;
using FunctionBasePtr = std::shared_ptr<IFunctionBase>;
/** Data structure for implementation of IN expression.
*/
class Set
@ -175,7 +178,7 @@ public:
{
size_t tuple_index;
size_t pk_index;
std::vector<FunctionPtr> functions;
std::vector<FunctionBasePtr> functions;
DataTypePtr data_type;
};

View File

@ -24,16 +24,10 @@ ColumnPtr castColumn(const ColumnWithTypeAndName & arg, const DataTypePtr & type
}
};
FunctionPtr func_cast = FunctionFactory::instance().get("CAST", context);
FunctionBuilderPtr func_builder_cast = FunctionFactory::instance().get("CAST", context);
{
DataTypePtr unused_return_type;
ColumnsWithTypeAndName arguments{ temporary_block.getByPosition(0), temporary_block.getByPosition(1) };
std::vector<ExpressionAction> unused_prerequisites;
/// Prepares function to execution. TODO It is not obvious.
func_cast->getReturnTypeAndPrerequisites(arguments, unused_return_type, unused_prerequisites);
}
ColumnsWithTypeAndName arguments{ temporary_block.getByPosition(0), temporary_block.getByPosition(1) };
auto func_cast = func_builder_cast->build(arguments);
func_cast->execute(temporary_block, {0, 1}, 2);
return temporary_block.getByPosition(2).column;

View File

@ -950,7 +950,7 @@ void MergeTreeData::createConvertExpression(const DataPartPtr & part, const Name
out_expression->add(ExpressionAction::addColumn(
{ DataTypeString().createColumnConst(1, new_type_name), std::make_shared<DataTypeString>(), new_type_name_column }));
const FunctionPtr & function = FunctionFactory::instance().get("CAST", context);
const auto & function = FunctionFactory::instance().get("CAST", context);
out_expression->add(ExpressionAction::applyFunction(
function, Names{column.name, new_type_name_column}), out_names);

View File

@ -342,17 +342,15 @@ static bool getConstant(const ASTPtr & expr, Block & block_with_constants, Field
static void applyFunction(
FunctionPtr & func,
const FunctionBasePtr & func,
const DataTypePtr & arg_type, const Field & arg_value,
DataTypePtr & res_type, Field & res_value)
{
std::vector<ExpressionAction> unused_prerequisites;
ColumnsWithTypeAndName arguments{{ arg_type->createColumnConst(1, arg_value), arg_type, "x" }};
func->getReturnTypeAndPrerequisites(arguments, res_type, unused_prerequisites);
res_type = func->getReturnType();
Block block
{
arguments[0],
{ arg_type->createColumnConst(1, arg_value), arg_type, "x" },
{ nullptr, res_type, "y" }
};
@ -526,14 +524,14 @@ bool PKCondition::isPrimaryKeyPossiblyWrappedByMonotonicFunctions(
for (auto it = chain_not_tested_for_monotonicity.rbegin(); it != chain_not_tested_for_monotonicity.rend(); ++it)
{
FunctionPtr func = FunctionFactory::instance().tryGet((*it)->name, context);
auto func_builder = FunctionFactory::instance().tryGet((*it)->name, context);
ColumnsWithTypeAndName arguments{{ nullptr, primary_key_column_type, "" }};
auto func = func_builder->build(arguments);
if (!func || !func->hasInformationAboutMonotonicity())
return false;
std::vector<ExpressionAction> unused_prerequisites;
ColumnsWithTypeAndName arguments{{ nullptr, primary_key_column_type, "" }};
func->getReturnTypeAndPrerequisites(arguments, primary_key_column_type, unused_prerequisites);
primary_key_column_type = func->getReturnType();
out_functions_chain.push_back(func);
}

View File

@ -17,6 +17,9 @@
namespace DB
{
/// Forward declaration of the class named by the alias below; only a shared_ptr to it is needed here.
class IFunctionBase;
using FunctionBasePtr = std::shared_ptr<IFunctionBase>;
/** Range with open or closed ends; possibly unbounded.
*/
struct Range
@ -296,7 +299,7 @@ public:
* If the primary key column is wrapped in functions that can be monotonic in some value ranges
* (for example: -toFloat64(toDayOfWeek(date))), then here the functions will be located: toDayOfWeek, toFloat64, negate.
*/
using MonotonicFunctionsChain = std::vector<FunctionPtr>;
using MonotonicFunctionsChain = std::vector<FunctionBasePtr>;
mutable MonotonicFunctionsChain monotonic_functions_chain; /// The function execution does not violate the constancy.
};