Merge pull request #1879 from yandex/lambdas-without-prerequisites

Lambdas without prerequisites
This commit is contained in:
alexey-milovidov 2018-02-09 22:22:58 +03:00 committed by GitHub
commit 8fb9967903
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
40 changed files with 1192 additions and 695 deletions

View File

@ -18,13 +18,13 @@
#include <Common/FieldVisitors.h> #include <Common/FieldVisitors.h>
#include <DataTypes/FieldToDataType.h> #include <DataTypes/FieldToDataType.h>
#include <DataTypes/DataTypeTuple.h> #include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeExpression.h>
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
#include <AggregateFunctions/parseAggregateFunctionParameters.h> #include <AggregateFunctions/parseAggregateFunctionParameters.h>
#include <AggregateFunctions/AggregateFunctionFactory.h> #include <AggregateFunctions/AggregateFunctionFactory.h>
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Functions/IFunction.h> #include <Functions/IFunction.h>
#include <Common/typeid_cast.h> #include <Common/typeid_cast.h>
#include <DataTypes/DataTypeFunction.h>
namespace DB namespace DB
@ -127,20 +127,21 @@ void processFunction(const String & column_name, ASTPtr & ast, TypeAndConstantIn
{ {
ASTFunction * function = static_cast<ASTFunction *>(ast.get()); ASTFunction * function = static_cast<ASTFunction *>(ast.get());
/// Special case for lambda functions. Lambda function has special return type "Expression". /// Special case for lambda functions. Lambda function has special return type "Function".
/// We first create info with Expression of unspecified arguments, and will specify them later. /// We first create info with Function of unspecified arguments, and will specify them later.
if (function->name == "lambda") if (function->name == "lambda")
{ {
size_t number_of_lambda_parameters = AnalyzeLambdas::extractLambdaParameters(function->arguments->children.at(0)).size(); size_t number_of_lambda_parameters = AnalyzeLambdas::extractLambdaParameters(function->arguments->children.at(0)).size();
TypeAndConstantInference::ExpressionInfo expression_info; TypeAndConstantInference::ExpressionInfo expression_info;
expression_info.node = ast; expression_info.node = ast;
expression_info.data_type = std::make_unique<DataTypeExpression>(DataTypes(number_of_lambda_parameters)); expression_info.data_type = std::make_unique<DataTypeFunction>(DataTypes(number_of_lambda_parameters));
info.emplace(column_name, std::move(expression_info)); info.emplace(column_name, std::move(expression_info));
return; return;
} }
DataTypes argument_types; DataTypes argument_types;
ColumnsWithTypeAndName argument_columns;
if (function->arguments) if (function->arguments)
{ {
@ -151,6 +152,9 @@ void processFunction(const String & column_name, ASTPtr & ast, TypeAndConstantIn
throw Exception("Logical error: type of function argument was not inferred during depth-first search", ErrorCodes::LOGICAL_ERROR); throw Exception("Logical error: type of function argument was not inferred during depth-first search", ErrorCodes::LOGICAL_ERROR);
argument_types.emplace_back(it->second.data_type); argument_types.emplace_back(it->second.data_type);
argument_columns.emplace_back(ColumnWithTypeAndName(nullptr, it->second.data_type, ""));
if (it->second.is_constant_expression)
argument_columns.back().column = it->second.data_type->createColumnConst(1, it->second.value);
} }
} }
@ -203,7 +207,7 @@ void processFunction(const String & column_name, ASTPtr & ast, TypeAndConstantIn
return; return;
} }
const FunctionPtr & function_ptr = FunctionFactory::instance().get(function->name, context); const auto & function_builder_ptr = FunctionFactory::instance().get(function->name, context);
/// (?) Replace function name to canonical one. Because same function could be referenced by different names. /// (?) Replace function name to canonical one. Because same function could be referenced by different names.
// function->name = function_ptr->getName(); // function->name = function_ptr->getName();
@ -228,10 +232,12 @@ void processFunction(const String & column_name, ASTPtr & ast, TypeAndConstantIn
} }
} }
auto function_ptr = function_builder_ptr->build(argument_columns);
TypeAndConstantInference::ExpressionInfo expression_info; TypeAndConstantInference::ExpressionInfo expression_info;
expression_info.node = ast; expression_info.node = ast;
expression_info.function = function_ptr; expression_info.function = function_ptr;
expression_info.data_type = function_ptr->getReturnType(argument_types); expression_info.data_type = function_ptr->getReturnType();
if (all_consts && function_ptr->isSuitableForConstantFolding()) if (all_consts && function_ptr->isSuitableForConstantFolding())
{ {
@ -325,7 +331,7 @@ void processHigherOrderFunction(
{ {
ASTFunction * function = static_cast<ASTFunction *>(ast.get()); ASTFunction * function = static_cast<ASTFunction *>(ast.get());
const FunctionPtr & function_ptr = FunctionFactory::instance().get(function->name, context); const auto & function_builder_ptr = FunctionFactory::instance().get(function->name, context);
if (!function->arguments) if (!function->arguments)
throw Exception("Unexpected AST for higher-order function", ErrorCodes::UNEXPECTED_AST_STRUCTURE); throw Exception("Unexpected AST for higher-order function", ErrorCodes::UNEXPECTED_AST_STRUCTURE);
@ -339,7 +345,7 @@ void processHigherOrderFunction(
types.emplace_back(child_info.data_type); types.emplace_back(child_info.data_type);
} }
function_ptr->getLambdaArgumentTypes(types); function_builder_ptr->getLambdaArgumentTypes(types);
/// For every lambda expression, dive into it. /// For every lambda expression, dive into it.
@ -353,11 +359,11 @@ void processHigherOrderFunction(
const ASTFunction * lambda = typeid_cast<const ASTFunction *>(child.get()); const ASTFunction * lambda = typeid_cast<const ASTFunction *>(child.get());
if (lambda && lambda->name == "lambda") if (lambda && lambda->name == "lambda")
{ {
const DataTypeExpression * lambda_type = typeid_cast<const DataTypeExpression *>(types[i].get()); const auto * lambda_type = typeid_cast<const DataTypeFunction *>(types[i].get());
if (!lambda_type) if (!lambda_type)
throw Exception("Logical error: IFunction::getLambdaArgumentTypes returned data type for lambda expression," throw Exception("Logical error: IFunction::getLambdaArgumentTypes returned data type for lambda expression,"
" that is not DataTypeExpression", ErrorCodes::LOGICAL_ERROR); " that is not DataTypeFunction", ErrorCodes::LOGICAL_ERROR);
if (!lambda->arguments || lambda->arguments->children.size() != 2) if (!lambda->arguments || lambda->arguments->children.size() != 2)
throw Exception("Lambda function must have exactly two arguments (sides of arrow)", ErrorCodes::BAD_LAMBDA); throw Exception("Lambda function must have exactly two arguments (sides of arrow)", ErrorCodes::BAD_LAMBDA);
@ -390,7 +396,7 @@ void processHigherOrderFunction(
/// Update Expression type (expression signature). /// Update Expression type (expression signature).
info.at(lambda->getColumnName()).data_type = std::make_shared<DataTypeExpression>( info.at(lambda->getColumnName()).data_type = std::make_shared<DataTypeFunction>(
lambda_argument_types, info.at(lambda->arguments->children[1]->getColumnName()).data_type); lambda_argument_types, info.at(lambda->arguments->children[1]->getColumnName()).data_type);
} }
} }

View File

@ -15,7 +15,7 @@ struct CollectAliases;
struct AnalyzeColumns; struct AnalyzeColumns;
struct AnalyzeLambdas; struct AnalyzeLambdas;
struct ExecuteTableFunctions; struct ExecuteTableFunctions;
class IFunction; class IFunctionBase;
class IAggregateFunction; class IAggregateFunction;
@ -46,7 +46,7 @@ struct TypeAndConstantInference
DataTypePtr data_type; DataTypePtr data_type;
bool is_constant_expression = false; bool is_constant_expression = false;
Field value; /// Has meaning if is_constant_expression == true. Field value; /// Has meaning if is_constant_expression == true.
std::shared_ptr<IFunction> function; std::shared_ptr<IFunctionBase> function;
std::shared_ptr<IAggregateFunction> aggregate_function; std::shared_ptr<IAggregateFunction> aggregate_function;
}; };

View File

@ -69,9 +69,7 @@ MutableColumns ColumnConst::scatter(ColumnIndex num_columns, const Selector & se
throw Exception("Size of selector (" + toString(selector.size()) + ") doesn't match size of column (" + toString(s) + ")", throw Exception("Size of selector (" + toString(selector.size()) + ") doesn't match size of column (" + toString(s) + ")",
ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH); ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);
std::vector<size_t> counts(num_columns); std::vector<size_t> counts = countColumnsSizeInSelector(num_columns, selector);
for (auto idx : selector)
++counts[idx];
MutableColumns res(num_columns); MutableColumns res(num_columns);
for (size_t i = 0; i < num_columns; ++i) for (size_t i = 0; i < num_columns; ++i)

View File

@ -1,25 +0,0 @@
#include <Interpreters/ExpressionActions.h>
#include <Columns/ColumnExpression.h>
namespace DB
{
/// Stores the lambda's compiled expression, its argument list and its result name/type.
/// `s` is not declared in this file; presumably it is the row count inherited from IColumnDummy.
ColumnExpression::ColumnExpression(
    size_t s_,
    const ExpressionActionsPtr & expression_,
    const NamesAndTypesList & arguments_,
    const DataTypePtr & return_type_,
    const String & return_name_)
    : expression(expression_)
    , arguments(arguments_)
    , return_type(return_type_)
    , return_name(return_name_)
{
    s = s_;
}
/// Builds a column of `s_` rows that carries the same expression, arguments and return name/type.
MutableColumnPtr ColumnExpression::cloneDummy(size_t s_) const
{
    MutableColumnPtr cloned = ColumnExpression::create(s_, expression, arguments, return_type, return_name);
    return cloned;
}
/// Read-only accessors for the fields set by the constructor.

const ExpressionActionsPtr & ColumnExpression::getExpression() const
{
    return expression;
}

const DataTypePtr & ColumnExpression::getReturnType() const
{
    return return_type;
}

const std::string & ColumnExpression::getReturnName() const
{
    return return_name;
}

const NamesAndTypesList & ColumnExpression::getArguments() const
{
    return arguments;
}
}

View File

@ -1,41 +0,0 @@
#pragma once
#include <Core/NamesAndTypes.h>
#include <Columns/IColumnDummy.h>
namespace DB
{
class ExpressionActions;
/** Column whose payload is a lambda expression rather than data.
  * It keeps the expression together with the lambda's arguments and the name/type of its result;
  * it holds no input or output values and is built on top of IColumnDummy.
  */
class ColumnExpression final : public COWPtrHelper<IColumnDummy, ColumnExpression>
{
private:
    friend class COWPtrHelper<IColumnDummy, ColumnExpression>;

    using ExpressionActionsPtr = std::shared_ptr<ExpressionActions>;

    /// Data members are kept in this order: it is also the order of the constructor's initializer list.
    ExpressionActionsPtr expression;
    NamesAndTypesList arguments;
    DataTypePtr return_type;
    std::string return_name;

    /// Constructors are private; instances are created through COWPtrHelper (see the friend declaration).
    ColumnExpression(
        size_t s_,
        const ExpressionActionsPtr & expression_,
        const NamesAndTypesList & arguments_,
        const DataTypePtr & return_type_,
        const String & return_name_);

    ColumnExpression(const ColumnExpression &) = default;

public:
    const char * getFamilyName() const override { return "Expression"; }

    /// Creates a column of `s_` rows carrying the same expression.
    MutableColumnPtr cloneDummy(size_t s_) const override;

    const ExpressionActionsPtr & getExpression() const;
    const DataTypePtr & getReturnType() const;
    const std::string & getReturnName() const;
    const NamesAndTypesList & getArguments() const;
};
}

View File

@ -0,0 +1,202 @@
#include <Interpreters/ExpressionActions.h>
#include <Columns/ColumnFunction.h>
#include <Columns/ColumnsCommon.h>
#include <Functions/IFunction.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
/// Creates a function column of `size` rows around `function`.
/// `columns_to_capture` go through appendArguments(), which throws LOGICAL_ERROR when more
/// columns are supplied than the function has arguments or when a column type does not match.
ColumnFunction::ColumnFunction(size_t size, FunctionBasePtr function, const ColumnsWithTypeAndName & columns_to_capture)
    /// The parameter is taken by value, so move it into the member instead of copying the
    /// shared pointer a second time (a copy costs an extra atomic refcount increment/decrement).
    : size_(size), function(std::move(function))
{
    /// appendArguments() is a member function, so it reads the already-initialized member `function`,
    /// not the (now moved-from) constructor parameter of the same name.
    appendArguments(columns_to_capture);
}
/// Returns a ColumnFunction of `size` rows around the same function;
/// each captured column is replaced by its own cloneResized(size).
MutableColumnPtr ColumnFunction::cloneResized(size_t size) const
{
    ColumnsWithTypeAndName resized = captured_columns;

    for (size_t i = 0; i < resized.size(); ++i)
        resized[i].column = resized[i].column->cloneResized(size);

    return ColumnFunction::create(size, function, resized);
}
/// Replicates every captured column according to `offsets` and wraps the result in a new ColumnFunction.
/// Throws SIZES_OF_COLUMNS_DOESNT_MATCH if `offsets` has a different number of entries than this column has rows.
MutableColumnPtr ColumnFunction::replicate(const Offsets & offsets) const
{
    const bool sizes_match = (offsets.size() == size_);
    if (!sizes_match)
        throw Exception("Size of offsets (" + toString(offsets.size()) + ") doesn't match size of column ("
            + toString(size_) + ")", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);

    ColumnsWithTypeAndName replicated = captured_columns;
    for (size_t i = 0; i < replicated.size(); ++i)
        replicated[i].column = replicated[i].column->replicate(offsets);

    /// The last offset is the total number of resulting rows; an empty column has no offsets to read.
    size_t replicated_size = 0;
    if (size_ != 0)
        replicated_size = offsets.back();

    return ColumnFunction::create(replicated_size, function, replicated);
}
/// Returns a ColumnFunction of `length` rows whose captured columns are each cut to [start, start + length).
/// The range itself is not validated here; it is forwarded to the captured columns' cut().
MutableColumnPtr ColumnFunction::cut(size_t start, size_t length) const
{
    ColumnsWithTypeAndName sliced = captured_columns;

    for (size_t i = 0; i < sliced.size(); ++i)
        sliced[i].column = sliced[i].column->cut(start, length);

    return ColumnFunction::create(length, function, sliced);
}
/// Applies `filt` to every captured column and returns the filtered ColumnFunction.
/// Throws SIZES_OF_COLUMNS_DOESNT_MATCH if the filter length differs from the number of rows.
MutableColumnPtr ColumnFunction::filter(const Filter & filt, ssize_t result_size_hint) const
{
    if (filt.size() != size_)
        throw Exception("Size of filter (" + toString(filt.size()) + ") doesn't match size of column ("
            + toString(size_) + ")", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);

    ColumnsWithTypeAndName filtered = captured_columns;
    for (size_t i = 0; i < filtered.size(); ++i)
        filtered[i].column = filtered[i].column->filter(filt, result_size_hint);

    /// With nothing captured the row count has to be derived from the filter itself;
    /// otherwise the first filtered column already knows it.
    const size_t filtered_size = filtered.empty()
        ? countBytesInFilter(filt)
        : filtered.front().column->size();

    return ColumnFunction::create(filtered_size, function, filtered);
}
/// Permutes every captured column with `perm`, keeping at most `limit` rows (0 means "all rows").
/// Throws SIZES_OF_COLUMNS_DOESNT_MATCH if `perm` is shorter than the number of rows to produce.
MutableColumnPtr ColumnFunction::permute(const Permutation & perm, size_t limit) const
{
    /// Effective number of rows to produce: never more than the column holds.
    limit = (limit == 0) ? size_ : std::min(size_, limit);

    const bool permutation_too_short = perm.size() < limit;
    if (permutation_too_short)
        throw Exception("Size of permutation (" + toString(perm.size()) + ") is less than required ("
            + toString(limit) + ")", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);

    ColumnsWithTypeAndName permuted = captured_columns;
    for (size_t i = 0; i < permuted.size(); ++i)
        permuted[i].column = permuted[i].column->permute(perm, limit);

    return ColumnFunction::create(limit, function, permuted);
}
/// Splits the column into `num_columns` ColumnFunction parts according to `selector`:
/// every captured column is scattered and part i receives the i-th piece of each of them.
/// Throws SIZES_OF_COLUMNS_DOESNT_MATCH if the selector length differs from the number of rows.
std::vector<MutableColumnPtr> ColumnFunction::scatter(IColumn::ColumnIndex num_columns,
                                                      const IColumn::Selector & selector) const
{
    if (selector.size() != size_)
        throw Exception("Size of selector (" + toString(selector.size()) + ") doesn't match size of column ("
            + toString(size_) + ")", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);

    /// Row counts per part are only needed when there is no captured column to ask for its size.
    std::vector<size_t> rows_per_part;
    if (captured_columns.empty())
        rows_per_part = countColumnsSizeInSelector(num_columns, selector);

    /// Start every part with a copy of the captured columns, then overwrite the data with the scattered pieces.
    std::vector<ColumnsWithTypeAndName> part_captures(num_columns, captured_columns);

    for (size_t column_idx = 0; column_idx < captured_columns.size(); ++column_idx)
    {
        auto pieces = captured_columns[column_idx].column->scatter(num_columns, selector);

        for (IColumn::ColumnIndex part = 0; part < num_columns; ++part)
            part_captures[part][column_idx].column = std::move(pieces[part]);
    }

    std::vector<MutableColumnPtr> result;
    result.reserve(num_columns);

    for (IColumn::ColumnIndex part = 0; part < num_columns; ++part)
    {
        auto & part_capture = part_captures[part];

        size_t part_size;
        if (part_capture.empty())
            part_size = rows_per_part[part];
        else
            part_size = part_capture.front().column->size();

        result.emplace_back(ColumnFunction::create(part_size, function, std::move(part_capture)));
    }

    return result;
}
/// Appends one default row to every captured column and grows the row count by one.
void ColumnFunction::insertDefault()
{
    for (size_t i = 0; i < captured_columns.size(); ++i)
        captured_columns[i].column->assumeMutable()->insertDefault();

    ++size_;
}
/// Removes the last `n` rows from every captured column and shrinks the row count accordingly.
void ColumnFunction::popBack(size_t n)
{
    for (size_t i = 0; i < captured_columns.size(); ++i)
        captured_columns[i].column->assumeMutable()->popBack(n);

    size_ -= n;
}
/// Sum of byteSize() over all captured columns; the function object itself is not counted.
size_t ColumnFunction::byteSize() const
{
    size_t bytes = 0;

    for (size_t i = 0; i < captured_columns.size(); ++i)
        bytes += captured_columns[i].column->byteSize();

    return bytes;
}
/// Sum of allocatedBytes() over all captured columns; the function object itself is not counted.
size_t ColumnFunction::allocatedBytes() const
{
    size_t bytes = 0;

    for (size_t i = 0; i < captured_columns.size(); ++i)
        bytes += captured_columns[i].column->allocatedBytes();

    return bytes;
}
void ColumnFunction::appendArguments(const ColumnsWithTypeAndName & columns)
{
auto args = function->getArgumentTypes().size();
auto were_captured = captured_columns.size();
auto wanna_capture = columns.size();
if (were_captured + wanna_capture > args)
throw Exception("Cannot capture " + toString(wanna_capture) + " columns because function " + function->getName()
+ " has " + toString(args) + " arguments" +
(were_captured ? " and " + toString(were_captured) + " columns have already been captured" : "")
+ ".", ErrorCodes::LOGICAL_ERROR);
for (const auto & column : columns)
appendArgument(column);
}
/// Captures a single column as the next argument of the function.
/// The column type must equal the type the function declares for that argument position;
/// otherwise LOGICAL_ERROR is thrown and nothing is captured.
void ColumnFunction::appendArgument(const ColumnWithTypeAndName & column)
{
    const auto & argument_types = function->getArgumentTypes();

    /// Position of the argument this column is going to fill.
    const auto index = captured_columns.size();

    /// Guard the indexing below: without it, capturing past the last argument reads out of bounds.
    if (index >= argument_types.size())
        throw Exception("Cannot capture column " + std::to_string(index) + " because function " + function->getName()
                        + " has only " + std::to_string(argument_types.size()) + " arguments.", ErrorCodes::LOGICAL_ERROR);

    if (!column.type->equals(*argument_types[index]))
        throw Exception("Cannot capture column " + std::to_string(index) +
                        " because it has incompatible type: got " + column.type->getName() +
                        ", but " + argument_types[index]->getName() + " is expected.", ErrorCodes::LOGICAL_ERROR);

    captured_columns.push_back(column);
}
/// Executes the function over the captured columns and returns the result column.
/// Every argument of the function must have been captured; otherwise LOGICAL_ERROR is thrown.
ColumnWithTypeAndName ColumnFunction::reduce() const
{
    const auto args = function->getArgumentTypes().size();
    const auto captured = captured_columns.size();

    if (args != captured)
        throw Exception("Cannot call function " + function->getName() + " because it has " + toString(args) +
                        " arguments but " + toString(captured) + " columns were captured.", ErrorCodes::LOGICAL_ERROR);

    /// The block holds the captured columns at positions [0, captured) followed by one slot for the result.
    Block block(captured_columns);
    block.insert({nullptr, function->getReturnType(), ""});

    /// Argument i of the function is the block column at position i.
    ColumnNumbers arguments(captured_columns.size());
    for (size_t i = 0; i < captured_columns.size(); ++i)
        arguments[i] = i;

    function->execute(block, arguments, captured_columns.size());

    return block.getByPosition(captured_columns.size());
}
}

View File

@ -0,0 +1,116 @@
#pragma once
#include <Core/NamesAndTypes.h>
#include <Columns/IColumn.h>
class IFunctionBase;
using FunctionBasePtr = std::shared_ptr<IFunctionBase>;
namespace DB
{
/** A column that holds a function (for example, a lambda expression) together with
  * the argument columns captured for it so far.
  * Structural operations (cut, filter, permute, replicate, scatter, ...) are supported,
  * but the column cannot be read or written as ordinary data: value accessors, inserts,
  * (de)serialization, hashing and comparison throw NOT_IMPLEMENTED.
  * More arguments can be attached with appendArguments(); reduce() returns the result column.
  */
class ColumnFunction final : public COWPtrHelper<IColumn, ColumnFunction>
{
private:
    friend class COWPtrHelper<IColumn, ColumnFunction>;

    /// Private: instances are created through COWPtrHelper (see the friend declaration above).
    ColumnFunction(size_t size, FunctionBasePtr function, const ColumnsWithTypeAndName & columns_to_capture);

public:
    const char * getFamilyName() const override { return "Function"; }

    MutableColumnPtr cloneResized(size_t size) const override;

    size_t size() const override { return size_; }

    MutableColumnPtr cut(size_t start, size_t length) const override;
    MutableColumnPtr replicate(const Offsets & offsets) const override;
    MutableColumnPtr filter(const Filter & filt, ssize_t result_size_hint) const override;
    MutableColumnPtr permute(const Permutation & perm, size_t limit) const override;
    void insertDefault() override;
    void popBack(size_t n) override;
    std::vector<MutableColumnPtr> scatter(IColumn::ColumnIndex num_columns,
                                          const IColumn::Selector & selector) const override;

    /// Intentionally a no-op: both output fields are left untouched.
    void getExtremes(Field &, Field &) const override {}

    size_t byteSize() const override;
    size_t allocatedBytes() const override;

    /// Captures `columns` as the next arguments of the function.
    void appendArguments(const ColumnsWithTypeAndName & columns);

    /// Returns the result of the function as a column with type and name.
    ColumnWithTypeAndName reduce() const;

    /// The operations below are not supported for this column type and always throw NOT_IMPLEMENTED.

    Field operator[](size_t) const override
    {
        throw Exception("Cannot get value from " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    void get(size_t, Field &) const override
    {
        throw Exception("Cannot get value from " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    StringRef getDataAt(size_t) const override
    {
        throw Exception("Cannot get value from " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    void insert(const Field &) override
    {
        throw Exception("Cannot insert into " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    void insertRangeFrom(const IColumn &, size_t, size_t) override
    {
        throw Exception("Cannot insert into " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    void insertData(const char *, size_t) override
    {
        throw Exception("Cannot insert into " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    StringRef serializeValueIntoArena(size_t, Arena &, char const *&) const override
    {
        throw Exception("Cannot serialize from " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    const char * deserializeAndInsertFromArena(const char *) override
    {
        throw Exception("Cannot deserialize to " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    void updateHashWithValue(size_t, SipHash &) const override
    {
        throw Exception("updateHashWithValue is not implemented for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    int compareAt(size_t, size_t, const IColumn &, int) const override
    {
        throw Exception("compareAt is not implemented for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    void getPermutation(bool, size_t, int, Permutation &) const override
    {
        throw Exception("getPermutation is not implemented for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    void gather(ColumnGathererStream &) override
    {
        throw Exception("Method gather is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

private:
    /// Number of rows in the column.
    size_t size_;
    /// The wrapped function.
    FunctionBasePtr function;
    /// Argument columns captured so far, in argument order.
    ColumnsWithTypeAndName captured_columns;

    /// Captures one column as the next argument.
    void appendArgument(const ColumnWithTypeAndName & column);
};
}

View File

@ -48,6 +48,14 @@ size_t countBytesInFilter(const IColumn::Filter & filt)
return count; return count;
} }
/// Counts, for every value in [0, num_columns), how many times it occurs in `selector`.
/// Selector values are used as indexes without any range check.
std::vector<size_t> countColumnsSizeInSelector(IColumn::ColumnIndex num_columns, const IColumn::Selector & selector)
{
    std::vector<size_t> rows_per_column(num_columns);

    for (size_t row = 0; row < selector.size(); ++row)
        ++rows_per_column[selector[row]];

    return rows_per_column;
}
/** clang 4 generates better code than gcc 6. /** clang 4 generates better code than gcc 6.
* And both gcc and clang could not vectorize trivial loop by bytes automatically. * And both gcc and clang could not vectorize trivial loop by bytes automatically.

View File

@ -11,6 +11,10 @@ namespace DB
/// Counts how many bytes of `filt` are greater than zero. /// Counts how many bytes of `filt` are greater than zero.
size_t countBytesInFilter(const IColumn::Filter & filt); size_t countBytesInFilter(const IColumn::Filter & filt);
/// Returns vector with num_columns elements. vector[i] is the count of i values in selector.
/// Selector must contain values from 0 to num_columns - 1. NOTE: this is not checked.
std::vector<size_t> countColumnsSizeInSelector(IColumn::ColumnIndex num_columns, const IColumn::Selector & selector);
/// Returns true, if the memory contains only zeros. /// Returns true, if the memory contains only zeros.
bool memoryIsZero(const void * data, size_t size); bool memoryIsZero(const void * data, size_t size);

View File

@ -56,10 +56,10 @@ std::ostream & operator<<(std::ostream & stream, const TableStructureReadLock &)
return stream; return stream;
} }
std::ostream & operator<<(std::ostream & stream, const IFunction & what) std::ostream & operator<<(std::ostream & stream, const IFunctionBuilder & what)
{ {
stream << "IFunction(name = " << what.getName() << ", variadic = " << what.isVariadic() << ", args = " << what.getNumberOfArguments() stream << "IFunction(name = " << what.getName() << ", variadic = " << what.isVariadic() << ", args = "
<< ")"; << what.getNumberOfArguments() << ")";
return stream; return stream;
} }

View File

@ -25,8 +25,8 @@ std::ostream & operator<<(std::ostream & stream, const IStorage & what);
class TableStructureReadLock; class TableStructureReadLock;
std::ostream & operator<<(std::ostream & stream, const TableStructureReadLock & what); std::ostream & operator<<(std::ostream & stream, const TableStructureReadLock & what);
class IFunction; class IFunctionBase;
std::ostream & operator<<(std::ostream & stream, const IFunction & what); std::ostream & operator<<(std::ostream & stream, const IFunctionBase & what);
class Block; class Block;
std::ostream & operator<<(std::ostream & stream, const Block & what); std::ostream & operator<<(std::ostream & stream, const Block & what);

View File

@ -7,8 +7,6 @@
namespace DB namespace DB
{ {
class IFunction;
/// Implicitly converts string and numeric values to Enum, numeric types to other numeric types. /// Implicitly converts string and numeric values to Enum, numeric types to other numeric types.
class CastTypeBlockInputStream : public IProfilingBlockInputStream class CastTypeBlockInputStream : public IProfilingBlockInputStream
{ {

View File

@ -1,12 +1,12 @@
#include <DataTypes/DataTypeExpression.h> #include <DataTypes/DataTypeFunction.h>
namespace DB namespace DB
{ {
std::string DataTypeExpression::getName() const std::string DataTypeFunction::getName() const
{ {
std::string res = "Expression("; std::string res = "Function(";
if (argument_types.size() > 1) if (argument_types.size() > 1)
res += "("; res += "(";
for (size_t i = 0; i < argument_types.size(); ++i) for (size_t i = 0; i < argument_types.size(); ++i)
@ -24,7 +24,7 @@ std::string DataTypeExpression::getName() const
return res; return res;
} }
bool DataTypeExpression::equals(const IDataType & rhs) const bool DataTypeFunction::equals(const IDataType & rhs) const
{ {
return typeid(rhs) == typeid(*this) && getName() == rhs.getName(); return typeid(rhs) == typeid(*this) && getName() == rhs.getName();
} }

View File

@ -8,7 +8,7 @@ namespace DB
/** Special data type, representing lambda expression. /** Special data type, representing lambda expression.
*/ */
class DataTypeExpression final : public IDataTypeDummy class DataTypeFunction final : public IDataTypeDummy
{ {
private: private:
DataTypes argument_types; DataTypes argument_types;
@ -19,11 +19,11 @@ public:
bool isParametric() const override { return true; } bool isParametric() const override { return true; }
/// Some types could be still unknown. /// Some types could be still unknown.
DataTypeExpression(const DataTypes & argument_types_ = DataTypes(), const DataTypePtr & return_type_ = nullptr) DataTypeFunction(const DataTypes & argument_types_ = DataTypes(), const DataTypePtr & return_type_ = nullptr)
: argument_types(argument_types_), return_type(return_type_) {} : argument_types(argument_types_), return_type(return_type_) {}
std::string getName() const override; std::string getName() const override;
const char * getFamilyName() const override { return "Expression"; } const char * getFamilyName() const override { return "Function"; }
const DataTypes & getArgumentTypes() const const DataTypes & getArgumentTypes() const
{ {

View File

@ -33,7 +33,7 @@ void FunctionFactory::registerFunction(const
} }
FunctionPtr FunctionFactory::get( FunctionBuilderPtr FunctionFactory::get(
const std::string & name, const std::string & name,
const Context & context) const const Context & context) const
{ {
@ -44,7 +44,7 @@ FunctionPtr FunctionFactory::get(
} }
FunctionPtr FunctionFactory::tryGet( FunctionBuilderPtr FunctionFactory::tryGet(
const std::string & name, const std::string & name,
const Context & context) const const Context & context) const
{ {

View File

@ -25,7 +25,7 @@ class FunctionFactory : public ext::singleton<FunctionFactory>
friend class StorageSystemFunctions; friend class StorageSystemFunctions;
public: public:
using Creator = std::function<FunctionPtr(const Context &)>; using Creator = std::function<FunctionBuilderPtr(const Context &)>;
/// For compatibility with SQL, it's possible to specify that certain function name is case insensitive. /// For compatibility with SQL, it's possible to specify that certain function name is case insensitive.
enum CaseSensitiveness enum CaseSensitiveness
@ -34,30 +34,45 @@ public:
CaseInsensitive CaseInsensitive
}; };
/// Register a function by its name. template <typename Function>
/// No locking, you must register all functions before usage of get. void registerFunction(CaseSensitiveness case_sensitiveness = CaseSensitive)
void registerFunction( {
const std::string & name, registerFunction<Function>(Function::name, case_sensitiveness);
Creator creator, }
CaseSensitiveness case_sensitiveness = CaseSensitive);
template <typename Function> template <typename Function>
void registerFunction() void registerFunction(const std::string & name, CaseSensitiveness case_sensitiveness = CaseSensitive)
{ {
registerFunction(Function::name, &Function::create); if constexpr (std::is_base_of<IFunction, Function>::value)
registerFunction(name, &createDefaultFunction<Function>, case_sensitiveness);
else
registerFunction(name, &Function::create, case_sensitiveness);
} }
/// Throws an exception if not found. /// Throws an exception if not found.
FunctionPtr get(const std::string & name, const Context & context) const; FunctionBuilderPtr get(const std::string & name, const Context & context) const;
/// Returns nullptr if not found. /// Returns nullptr if not found.
FunctionPtr tryGet(const std::string & name, const Context & context) const; FunctionBuilderPtr tryGet(const std::string & name, const Context & context) const;
private: private:
using Functions = std::unordered_map<std::string, Creator>; using Functions = std::unordered_map<std::string, Creator>;
Functions functions; Functions functions;
Functions case_insensitive_functions; Functions case_insensitive_functions;
/// Creator for function classes registered via the IFunction branch of registerFunction():
/// the object returned by Function::create(context) is wrapped into a DefaultFunctionBuilder.
template <typename Function>
static FunctionBuilderPtr createDefaultFunction(const Context & context)
{
    FunctionBuilderPtr builder = std::make_shared<DefaultFunctionBuilder>(Function::create(context));
    return builder;
}
/// Register a function by its name.
/// No locking, you must register all functions before usage of get.
void registerFunction(
const std::string & name,
Creator creator,
CaseSensitiveness case_sensitiveness = CaseSensitive);
}; };
} }

View File

@ -778,7 +778,7 @@ private:
return false; return false;
} }
FunctionPtr getFunctionForIntervalArithmetic(const DataTypePtr & type0, const DataTypePtr & type1) const FunctionBuilderPtr getFunctionForIntervalArithmetic(const DataTypePtr & type0, const DataTypePtr & type1) const
{ {
/// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval. /// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
/// We construct another function (example: addMonths) and call it. /// We construct another function (example: addMonths) and call it.
@ -830,7 +830,7 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{ {
/// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval. /// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
if (FunctionPtr function = getFunctionForIntervalArithmetic(arguments[0], arguments[1])) if (auto function_builder = getFunctionForIntervalArithmetic(arguments[0], arguments[1]))
{ {
ColumnsWithTypeAndName new_arguments(2); ColumnsWithTypeAndName new_arguments(2);
@ -844,10 +844,8 @@ public:
/// Change interval argument to its representation /// Change interval argument to its representation
new_arguments[1].type = std::make_shared<DataTypeNumber<DataTypeInterval::FieldType>>(); new_arguments[1].type = std::make_shared<DataTypeNumber<DataTypeInterval::FieldType>>();
DataTypePtr res; auto function = function_builder->build(new_arguments);
std::vector<ExpressionAction> unused_prerequisites; return function->getReturnType();
function->getReturnTypeAndPrerequisites(new_arguments, res, unused_prerequisites);
return res;
} }
DataTypePtr type_res; DataTypePtr type_res;
@ -873,7 +871,7 @@ public:
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override
{ {
/// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval. /// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
if (FunctionPtr function = getFunctionForIntervalArithmetic(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type)) if (auto function_builder = getFunctionForIntervalArithmetic(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type))
{ {
ColumnNumbers new_arguments = arguments; ColumnNumbers new_arguments = arguments;
@ -885,7 +883,11 @@ public:
Block new_block = block; Block new_block = block;
new_block.getByPosition(new_arguments[1]).type = std::make_shared<DataTypeNumber<DataTypeInterval::FieldType>>(); new_block.getByPosition(new_arguments[1]).type = std::make_shared<DataTypeNumber<DataTypeInterval::FieldType>>();
function->executeImpl(new_block, new_arguments, result); ColumnsWithTypeAndName new_arguments_with_type_and_name =
{new_block.getByPosition(new_arguments[0]), new_block.getByPosition(new_arguments[1])};
auto function = function_builder->build(new_arguments_with_type_and_name);
function->execute(new_block, new_arguments, result);
block.getByPosition(result).column = new_block.getByPosition(result).column; block.getByPosition(result).column = new_block.getByPosition(result).column;
return; return;

View File

@ -2345,10 +2345,7 @@ String FunctionArrayReduce::getName() const
return name; return name;
} }
void FunctionArrayReduce::getReturnTypeAndPrerequisitesImpl( DataTypePtr FunctionArrayReduce::getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const
const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
std::vector<ExpressionAction> & /*out_prerequisites*/)
{ {
/// The first argument is a constant string with the name of the aggregate function /// The first argument is a constant string with the name of the aggregate function
/// (possibly with parameters in parentheses, for example: "quantile(0.99)"). /// (possibly with parameters in parentheses, for example: "quantile(0.99)").
@ -2390,7 +2387,7 @@ void FunctionArrayReduce::getReturnTypeAndPrerequisitesImpl(
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row); aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row);
} }
out_return_type = aggregate_function->getReturnType(); return aggregate_function->getReturnType();
} }

View File

@ -1402,14 +1402,13 @@ public:
bool isVariadic() const override { return true; } bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; } size_t getNumberOfArguments() const override { return 0; }
void getReturnTypeAndPrerequisitesImpl( DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override;
const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
std::vector<ExpressionAction> & out_prerequisites) override;
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override; void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override;
private: private:
AggregateFunctionPtr aggregate_function; /// lazy initialization in getReturnTypeImpl
/// TODO: init in FunctionBuilder
mutable AggregateFunctionPtr aggregate_function;
}; };

View File

@ -1084,7 +1084,11 @@ public:
{ {
size_t size = left_tuple->getElements().size(); size_t size = left_tuple->getElements().size();
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
getReturnType({ left_tuple->getElements()[i], right_tuple->getElements()[i] }); {
ColumnsWithTypeAndName args = {{nullptr, left_tuple->getElements()[i], ""},
{nullptr, right_tuple->getElements()[i], ""}};
getReturnType(args);
}
} }
return std::make_shared<DataTypeUInt8>(); return std::make_shared<DataTypeUInt8>();

View File

@ -21,9 +21,9 @@ void registerFunctionsConditional(FunctionFactory & factory)
factory.registerFunction<FunctionCaseWithExpression>(); factory.registerFunction<FunctionCaseWithExpression>();
/// These are obsolete function names. /// These are obsolete function names.
factory.registerFunction("caseWithExpr", FunctionCaseWithExpression::create); factory.registerFunction<FunctionCaseWithExpression>("caseWithExpr");
factory.registerFunction("caseWithoutExpr", FunctionMultiIf::create); factory.registerFunction<FunctionMultiIf>("caseWithoutExpr");
factory.registerFunction("caseWithoutExpression", FunctionMultiIf::create); factory.registerFunction<FunctionMultiIf>("caseWithoutExpression");
} }
@ -251,15 +251,15 @@ DataTypePtr FunctionCaseWithExpression::getReturnTypeImpl(const DataTypes & args
/// get the return type of a transform function. /// get the return type of a transform function.
/// Get the return types of the arrays that we pass to the transform function. /// Get the return types of the arrays that we pass to the transform function.
DataTypes src_array_types; ColumnsWithTypeAndName src_array_types;
DataTypes dst_array_types; ColumnsWithTypeAndName dst_array_types;
for (size_t i = 1; i < (args.size() - 1); ++i) for (size_t i = 1; i < (args.size() - 1); ++i)
{ {
if ((i % 2) != 0) if ((i % 2) != 0)
src_array_types.push_back(args[i]); src_array_types.push_back({nullptr, args[i], {}});
else else
dst_array_types.push_back(args[i]); dst_array_types.push_back({nullptr, args[i], {}});
} }
FunctionArray fun_array{context}; FunctionArray fun_array{context};
@ -269,7 +269,9 @@ DataTypePtr FunctionCaseWithExpression::getReturnTypeImpl(const DataTypes & args
/// Finally get the return type of the transform function. /// Finally get the return type of the transform function.
FunctionTransform fun_transform; FunctionTransform fun_transform;
return fun_transform.getReturnType({args.front(), src_array_type, dst_array_type, args.back()}); ColumnsWithTypeAndName transform_args = {{nullptr, args.front(), {}}, {nullptr, src_array_type, {}},
{nullptr, dst_array_type, {}}, {nullptr, args.back(), {}}};
return fun_transform.getReturnType(transform_args);
} }
void FunctionCaseWithExpression::executeImpl(Block & block, const ColumnNumbers & args, size_t result) void FunctionCaseWithExpression::executeImpl(Block & block, const ColumnNumbers & args, size_t result)
@ -288,22 +290,22 @@ void FunctionCaseWithExpression::executeImpl(Block & block, const ColumnNumbers
/// Create the arrays required by the transform function. /// Create the arrays required by the transform function.
ColumnNumbers src_array_args; ColumnNumbers src_array_args;
DataTypes src_array_types; ColumnsWithTypeAndName src_array_types;
ColumnNumbers dst_array_args; ColumnNumbers dst_array_args;
DataTypes dst_array_types; ColumnsWithTypeAndName dst_array_types;
for (size_t i = 1; i < (args.size() - 1); ++i) for (size_t i = 1; i < (args.size() - 1); ++i)
{ {
if ((i % 2) != 0) if ((i % 2) != 0)
{ {
src_array_args.push_back(args[i]); src_array_args.push_back(args[i]);
src_array_types.push_back(block.getByPosition(args[i]).type); src_array_types.push_back(block.getByPosition(args[i]));
} }
else else
{ {
dst_array_args.push_back(args[i]); dst_array_args.push_back(args[i]);
dst_array_types.push_back(block.getByPosition(args[i]).type); dst_array_types.push_back(block.getByPosition(args[i]));
} }
} }

View File

@ -47,7 +47,7 @@ void registerFunctionsConversion(FunctionFactory & factory)
factory.registerFunction<FunctionToFixedString>(); factory.registerFunction<FunctionToFixedString>();
factory.registerFunction<FunctionToUnixTimestamp>(); factory.registerFunction<FunctionToUnixTimestamp>();
factory.registerFunction<FunctionCast>(); factory.registerFunction<FunctionBuilderCast>();
factory.registerFunction<FunctionToUInt8OrZero>(); factory.registerFunction<FunctionToUInt8OrZero>();
factory.registerFunction<FunctionToUInt16OrZero>(); factory.registerFunction<FunctionToUInt16OrZero>();

View File

@ -625,14 +625,11 @@ public:
size_t getNumberOfArguments() const override { return 0; } size_t getNumberOfArguments() const override { return 0; }
bool isInjective(const Block &) override { return std::is_same_v<Name, NameToString>; } bool isInjective(const Block &) override { return std::is_same_v<Name, NameToString>; }
void getReturnTypeAndPrerequisitesImpl( DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
std::vector<ExpressionAction> &) override
{ {
if constexpr (std::is_same_v<ToDataType, DataTypeInterval>) if constexpr (std::is_same_v<ToDataType, DataTypeInterval>)
{ {
out_return_type = std::make_shared<DataTypeInterval>(DataTypeInterval::Kind(Name::kind)); return std::make_shared<DataTypeInterval>(DataTypeInterval::Kind(Name::kind));
} }
else else
{ {
@ -660,9 +657,9 @@ public:
} }
if (std::is_same_v<ToDataType, DataTypeDateTime>) if (std::is_same_v<ToDataType, DataTypeDateTime>)
out_return_type = std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, 1, 0)); return std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, 1, 0));
else else
out_return_type = std::make_shared<ToDataType>(); return std::make_shared<ToDataType>();
} }
} }
@ -845,9 +842,7 @@ public:
size_t getNumberOfArguments() const override { return 2; } size_t getNumberOfArguments() const override { return 2; }
bool isInjective(const Block &) override { return true; } bool isInjective(const Block &) override { return true; }
void getReturnTypeAndPrerequisitesImpl(const ColumnsWithTypeAndName & arguments, DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
DataTypePtr & out_return_type,
std::vector<ExpressionAction> & /*out_prerequisites*/) override
{ {
if (!arguments[1].type->isUnsignedInteger()) if (!arguments[1].type->isUnsignedInteger())
throw Exception("Second argument for function " + getName() + " must be unsigned integer", ErrorCodes::ILLEGAL_COLUMN); throw Exception("Second argument for function " + getName() + " must be unsigned integer", ErrorCodes::ILLEGAL_COLUMN);
@ -857,7 +852,7 @@ public:
throw Exception(getName() + " is only implemented for types String and FixedString", ErrorCodes::NOT_IMPLEMENTED); throw Exception(getName() + " is only implemented for types String and FixedString", ErrorCodes::NOT_IMPLEMENTED);
const size_t n = arguments[1].column->getUInt(0); const size_t n = arguments[1].column->getUInt(0);
out_return_type = std::make_shared<DataTypeFixedString>(n); return std::make_shared<DataTypeFixedString>(n);
} }
bool useDefaultImplementationForConstants() const override { return true; } bool useDefaultImplementationForConstants() const override { return true; }
@ -1130,19 +1125,80 @@ using FunctionToFloat32OrNull = FunctionConvertOrNull<DataTypeFloat32, NameToFlo
using FunctionToFloat64OrNull = FunctionConvertOrNull<DataTypeFloat64, NameToFloat64OrNull>; using FunctionToFloat64OrNull = FunctionConvertOrNull<DataTypeFloat64, NameToFloat64OrNull>;
class FunctionCast final : public IFunction class PreparedFunctionCast : public PreparedFunctionImpl
{ {
public: public:
FunctionCast(const Context & context) : context(context) {} using WrapperType = std::function<void(Block &, const ColumnNumbers &, size_t)>;
explicit PreparedFunctionCast(WrapperType && wrapper_function, const char * name)
: wrapper_function(std::move(wrapper_function)), name(name) {}
String getName() const override { return name; }
protected:
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override
{
/// drop second argument, pass others
ColumnNumbers new_arguments{arguments.front()};
if (arguments.size() > 2)
new_arguments.insert(std::end(new_arguments), std::next(std::begin(arguments), 2), std::end(arguments));
wrapper_function(block, new_arguments, result);
}
bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; }
private: private:
using WrapperType = std::function<void(Block &, const ColumnNumbers &, size_t)>;
const Context & context;
WrapperType wrapper_function; WrapperType wrapper_function;
std::function<Monotonicity(const IDataType &, const Field &, const Field &)> monotonicity_for_range; const char * name;
};
class FunctionCast final : public IFunctionBase
{
public:
using WrapperType = std::function<void(Block &, const ColumnNumbers &, size_t)>;
using MonotonicityForRange = std::function<Monotonicity(const IDataType &, const Field &, const Field &)>;
FunctionCast(const Context & context, const char * name, MonotonicityForRange && monotonicity_for_range
, const DataTypes & argument_types, const DataTypePtr & return_type)
: context(context), name(name), monotonicity_for_range(monotonicity_for_range)
, argument_types(argument_types), return_type(return_type)
{
}
const DataTypes & getArgumentTypes() const override { return argument_types; }
const DataTypePtr & getReturnType() const override { return return_type; }
PreparedFunctionPtr prepare(const Block & /*sample_block*/) const override
{
return std::make_shared<PreparedFunctionCast>(prepare(getArgumentTypes()[0], getReturnType().get()), name);
}
String getName() const override { return name; }
bool hasInformationAboutMonotonicity() const override
{
return static_cast<bool>(monotonicity_for_range);
}
Monotonicity getMonotonicityForRange(const IDataType & type, const Field & left, const Field & right) const override
{
return monotonicity_for_range(type, left, right);
}
private:
const Context & context;
const char * name;
MonotonicityForRange monotonicity_for_range;
DataTypes argument_types;
DataTypePtr return_type;
template <typename DataType> template <typename DataType>
WrapperType createWrapper(const DataTypePtr & from_type, const DataType * const) WrapperType createWrapper(const DataTypePtr & from_type, const DataType * const) const
{ {
using FunctionType = typename FunctionTo<DataType>::Type; using FunctionType = typename FunctionTo<DataType>::Type;
@ -1150,9 +1206,7 @@ private:
/// Check conversion using underlying function /// Check conversion using underlying function
{ {
DataTypePtr unused_data_type; function->getReturnType(ColumnsWithTypeAndName(1, { nullptr, from_type, "" }));
std::vector<ExpressionAction> unused_prerequisites;
function->getReturnTypeAndPrerequisites({{ nullptr, from_type, "" }}, unused_data_type, unused_prerequisites);
} }
return [function] (Block & block, const ColumnNumbers & arguments, const size_t result) return [function] (Block & block, const ColumnNumbers & arguments, const size_t result)
@ -1174,7 +1228,7 @@ private:
}; };
} }
WrapperType createArrayWrapper(const DataTypePtr & from_type_untyped, const DataTypeArray * to_type) WrapperType createArrayWrapper(const DataTypePtr & from_type_untyped, const DataTypeArray * to_type) const
{ {
/// Conversion from String through parsing. /// Conversion from String through parsing.
if (checkAndGetDataType<DataTypeString>(from_type_untyped.get())) if (checkAndGetDataType<DataTypeString>(from_type_untyped.get()))
@ -1235,7 +1289,7 @@ private:
}; };
} }
WrapperType createTupleWrapper(const DataTypePtr & from_type_untyped, const DataTypeTuple * to_type) WrapperType createTupleWrapper(const DataTypePtr & from_type_untyped, const DataTypeTuple * to_type) const
{ {
/// Conversion from String through parsing. /// Conversion from String through parsing.
if (checkAndGetDataType<DataTypeString>(from_type_untyped.get())) if (checkAndGetDataType<DataTypeString>(from_type_untyped.get()))
@ -1304,7 +1358,7 @@ private:
} }
template <typename FieldType> template <typename FieldType>
WrapperType createEnumWrapper(const DataTypePtr & from_type, const DataTypeEnum<FieldType> * to_type) WrapperType createEnumWrapper(const DataTypePtr & from_type, const DataTypeEnum<FieldType> * to_type) const
{ {
using EnumType = DataTypeEnum<FieldType>; using EnumType = DataTypeEnum<FieldType>;
using Function = typename FunctionTo<EnumType>::Type; using Function = typename FunctionTo<EnumType>::Type;
@ -1324,9 +1378,7 @@ private:
/// Check conversion using underlying function /// Check conversion using underlying function
{ {
DataTypePtr unused_data_type; function->getReturnType(ColumnsWithTypeAndName(1, { nullptr, from_type, "" }));
std::vector<ExpressionAction> unused_prerequisites;
function->getReturnTypeAndPrerequisites({{ nullptr, from_type, "" }}, unused_data_type, unused_prerequisites);
} }
return [function] (Block & block, const ColumnNumbers & arguments, const size_t result) return [function] (Block & block, const ColumnNumbers & arguments, const size_t result)
@ -1342,7 +1394,7 @@ private:
} }
template <typename EnumTypeFrom, typename EnumTypeTo> template <typename EnumTypeFrom, typename EnumTypeTo>
void checkEnumToEnumConversion(const EnumTypeFrom * from_type, const EnumTypeTo * to_type) void checkEnumToEnumConversion(const EnumTypeFrom * from_type, const EnumTypeTo * to_type) const
{ {
const auto & from_values = from_type->getValues(); const auto & from_values = from_type->getValues();
const auto & to_values = to_type->getValues(); const auto & to_values = to_type->getValues();
@ -1369,9 +1421,10 @@ private:
}; };
template <typename ColumnStringType, typename EnumType> template <typename ColumnStringType, typename EnumType>
WrapperType createStringToEnumWrapper() WrapperType createStringToEnumWrapper() const
{ {
return [] (Block & block, const ColumnNumbers & arguments, const size_t result) const char * function_name = name;
return [function_name] (Block & block, const ColumnNumbers & arguments, const size_t result)
{ {
const auto first_col = block.getByPosition(arguments.front()).column.get(); const auto first_col = block.getByPosition(arguments.front()).column.get();
@ -1393,13 +1446,12 @@ private:
} }
else else
throw Exception{ throw Exception{
"Unexpected column " + first_col->getName() + " as first argument of function " + "Unexpected column " + first_col->getName() + " as first argument of function " + function_name,
name,
ErrorCodes::LOGICAL_ERROR}; ErrorCodes::LOGICAL_ERROR};
}; };
} }
WrapperType createIdentityWrapper(const DataTypePtr &) WrapperType createIdentityWrapper(const DataTypePtr &) const
{ {
return [] (Block & block, const ColumnNumbers & arguments, const size_t result) return [] (Block & block, const ColumnNumbers & arguments, const size_t result)
{ {
@ -1407,7 +1459,7 @@ private:
}; };
} }
WrapperType createNothingWrapper(const IDataType * to_type) WrapperType createNothingWrapper(const IDataType * to_type) const
{ {
ColumnPtr res = to_type->createColumnConstWithDefaultValue(1); ColumnPtr res = to_type->createColumnConstWithDefaultValue(1);
return [res] (Block & block, const ColumnNumbers &, const size_t result) return [res] (Block & block, const ColumnNumbers &, const size_t result)
@ -1424,7 +1476,7 @@ private:
bool result_is_nullable = false; bool result_is_nullable = false;
}; };
WrapperType prepare(const DataTypePtr & from_type, const IDataType * to_type) WrapperType prepare(const DataTypePtr & from_type, const IDataType * to_type) const
{ {
/// Determine whether pre-processing and/or post-processing must take place during conversion. /// Determine whether pre-processing and/or post-processing must take place during conversion.
@ -1519,7 +1571,7 @@ private:
return wrapper; return wrapper;
} }
WrapperType prepareImpl(const DataTypePtr & from_type, const IDataType * to_type) WrapperType prepareImpl(const DataTypePtr & from_type, const IDataType * to_type) const
{ {
if (from_type->equals(*to_type)) if (from_type->equals(*to_type))
return createIdentityWrapper(from_type); return createIdentityWrapper(from_type);
@ -1569,99 +1621,95 @@ private:
"Conversion from " + from_type->getName() + " to " + to_type->getName() + " is not supported", "Conversion from " + from_type->getName() + " to " + to_type->getName() + " is not supported",
ErrorCodes::CANNOT_CONVERT_TYPE}; ErrorCodes::CANNOT_CONVERT_TYPE};
} }
};
class FunctionBuilderCast : public FunctionBuilderImpl
{
public:
using MonotonicityForRange = FunctionCast::MonotonicityForRange;
static constexpr auto name = "CAST";
static FunctionBuilderPtr create(const Context & context) { return std::make_shared<FunctionBuilderCast>(context); }
FunctionBuilderCast(const Context & context) : context(context) {}
String getName() const { return name; }
size_t getNumberOfArguments() const override { return 2; }
protected:
FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const override
{
DataTypes data_types(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i)
data_types[i] = arguments[i].type;
auto monotonicity = getMonotonicityInformation(arguments.front().type, return_type.get());
return std::make_shared<FunctionCast>(context, name, std::move(monotonicity), data_types, return_type);
}
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
const auto type_col = checkAndGetColumnConst<ColumnString>(arguments.back().column.get());
if (!type_col)
throw Exception("Second argument to " + getName() + " must be a constant string describing type",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return DataTypeFactory::instance().get(type_col->getValue<String>());
}
bool useDefaultImplementationForNulls() const override { return false; }
private:
template <typename DataType> template <typename DataType>
static auto monotonicityForType(const DataType * const) static auto monotonicityForType(const DataType * const)
{ {
return FunctionTo<DataType>::Type::Monotonic::get; return FunctionTo<DataType>::Type::Monotonic::get;
} }
void prepareMonotonicityInformation(const DataTypePtr & from_type, const IDataType * to_type) MonotonicityForRange getMonotonicityInformation(const DataTypePtr & from_type, const IDataType * to_type) const
{ {
if (const auto type = checkAndGetDataType<DataTypeUInt8>(to_type)) if (const auto type = checkAndGetDataType<DataTypeUInt8>(to_type))
monotonicity_for_range = monotonicityForType(type); return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeUInt16>(to_type)) else if (const auto type = checkAndGetDataType<DataTypeUInt16>(to_type))
monotonicity_for_range = monotonicityForType(type); return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeUInt32>(to_type)) else if (const auto type = checkAndGetDataType<DataTypeUInt32>(to_type))
monotonicity_for_range = monotonicityForType(type); return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeUInt64>(to_type)) else if (const auto type = checkAndGetDataType<DataTypeUInt64>(to_type))
monotonicity_for_range = monotonicityForType(type); return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeInt8>(to_type)) else if (const auto type = checkAndGetDataType<DataTypeInt8>(to_type))
monotonicity_for_range = monotonicityForType(type); return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeInt16>(to_type)) else if (const auto type = checkAndGetDataType<DataTypeInt16>(to_type))
monotonicity_for_range = monotonicityForType(type); return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeInt32>(to_type)) else if (const auto type = checkAndGetDataType<DataTypeInt32>(to_type))
monotonicity_for_range = monotonicityForType(type); return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeInt64>(to_type)) else if (const auto type = checkAndGetDataType<DataTypeInt64>(to_type))
monotonicity_for_range = monotonicityForType(type); return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeFloat32>(to_type)) else if (const auto type = checkAndGetDataType<DataTypeFloat32>(to_type))
monotonicity_for_range = monotonicityForType(type); return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeFloat64>(to_type)) else if (const auto type = checkAndGetDataType<DataTypeFloat64>(to_type))
monotonicity_for_range = monotonicityForType(type); return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeDate>(to_type)) else if (const auto type = checkAndGetDataType<DataTypeDate>(to_type))
monotonicity_for_range = monotonicityForType(type); return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeDateTime>(to_type)) else if (const auto type = checkAndGetDataType<DataTypeDateTime>(to_type))
monotonicity_for_range = monotonicityForType(type); return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeString>(to_type)) else if (const auto type = checkAndGetDataType<DataTypeString>(to_type))
monotonicity_for_range = monotonicityForType(type); return monotonicityForType(type);
else if (from_type->isEnum()) else if (from_type->isEnum())
{ {
if (const auto type = checkAndGetDataType<DataTypeEnum8>(to_type)) if (const auto type = checkAndGetDataType<DataTypeEnum8>(to_type))
monotonicity_for_range = monotonicityForType(type); return monotonicityForType(type);
else if (const auto type = checkAndGetDataType<DataTypeEnum16>(to_type)) else if (const auto type = checkAndGetDataType<DataTypeEnum16>(to_type))
monotonicity_for_range = monotonicityForType(type); return monotonicityForType(type);
} }
/// other types like Null, FixedString, Array and Tuple have no monotonicity defined /// other types like Null, FixedString, Array and Tuple have no monotonicity defined
return {};
} }
public: const Context & context;
static constexpr auto name = "CAST";
static FunctionPtr create(const Context & context) { return std::make_shared<FunctionCast>(context); }
String getName() const override { return name; }
bool useDefaultImplementationForNulls() const override { return false; }
size_t getNumberOfArguments() const override { return 2; }
void getReturnTypeAndPrerequisitesImpl(
const ColumnsWithTypeAndName & arguments, DataTypePtr & out_return_type,
std::vector<ExpressionAction> & /*out_prerequisites*/) override
{
const auto type_col = checkAndGetColumnConst<ColumnString>(arguments.back().column.get());
if (!type_col)
throw Exception("Second argument to " + getName() + " must be a constant string describing type",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
out_return_type = DataTypeFactory::instance().get(type_col->getValue<String>());
const DataTypePtr & from_type = arguments.front().type;
wrapper_function = prepare(from_type, out_return_type.get());
prepareMonotonicityInformation(from_type, out_return_type.get());
}
bool useDefaultImplementationForConstants() const override { return true; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; }
void executeImpl(Block & block, const ColumnNumbers & arguments, const size_t result) override
{
/// drop second argument, pass others
ColumnNumbers new_arguments{arguments.front()};
if (arguments.size() > 2)
new_arguments.insert(std::end(new_arguments), std::next(std::begin(arguments), 2), std::end(arguments));
wrapper_function(block, new_arguments, result);
}
bool hasInformationAboutMonotonicity() const override
{
return static_cast<bool>(monotonicity_for_range);
}
Monotonicity getMonotonicityForRange(const IDataType & type, const Field & left, const Field & right) const override
{
return monotonicity_for_range(type, left, right);
}
}; };
} }

View File

@ -83,7 +83,7 @@ void registerFunctionsDateTime(FunctionFactory & factory)
factory.registerFunction<FunctionToRelativeMinuteNum>(); factory.registerFunction<FunctionToRelativeMinuteNum>();
factory.registerFunction<FunctionToRelativeSecondNum>(); factory.registerFunction<FunctionToRelativeSecondNum>();
factory.registerFunction<FunctionToTime>(); factory.registerFunction<FunctionToTime>();
factory.registerFunction(FunctionNow::name, FunctionNow::create, FunctionFactory::CaseInsensitive); factory.registerFunction<FunctionNow>(FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionToday>(); factory.registerFunction<FunctionToday>();
factory.registerFunction<FunctionYesterday>(); factory.registerFunction<FunctionYesterday>();
factory.registerFunction<FunctionTimeSlot>(); factory.registerFunction<FunctionTimeSlot>();
@ -108,7 +108,7 @@ void registerFunctionsDateTime(FunctionFactory & factory)
factory.registerFunction<FunctionSubtractMonths>(); factory.registerFunction<FunctionSubtractMonths>();
factory.registerFunction<FunctionSubtractYears>(); factory.registerFunction<FunctionSubtractYears>();
factory.registerFunction(FunctionDateDiff::name, FunctionDateDiff::create, FunctionFactory::CaseInsensitive); factory.registerFunction<FunctionDateDiff>(FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionToTimeZone>(); factory.registerFunction<FunctionToTimeZone>();
} }

View File

@ -633,10 +633,7 @@ public:
bool isVariadic() const override { return true; } bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; } size_t getNumberOfArguments() const override { return 0; }
void getReturnTypeAndPrerequisitesImpl( DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
std::vector<ExpressionAction> &) override
{ {
if (arguments.size() == 1) if (arguments.size() == 1)
{ {
@ -663,9 +660,9 @@ public:
/// For DateTime, if time zone is specified, attach it to type. /// For DateTime, if time zone is specified, attach it to type.
if (std::is_same_v<ToDataType, DataTypeDateTime>) if (std::is_same_v<ToDataType, DataTypeDateTime>)
out_return_type = std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, 1, 0)); return std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, 1, 0));
else else
out_return_type = std::make_shared<ToDataType>(); return std::make_shared<ToDataType>();
} }
bool useDefaultImplementationForConstants() const override { return true; } bool useDefaultImplementationForConstants() const override { return true; }
@ -942,10 +939,7 @@ public:
bool isVariadic() const override { return true; } bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; } size_t getNumberOfArguments() const override { return 0; }
void getReturnTypeAndPrerequisitesImpl( DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
std::vector<ExpressionAction> &) override
{ {
if (arguments.size() != 2 && arguments.size() != 3) if (arguments.size() != 2 && arguments.size() != 3)
throw Exception("Number of arguments for function " + getName() + " doesn't match: passed " throw Exception("Number of arguments for function " + getName() + " doesn't match: passed "
@ -979,16 +973,16 @@ public:
if (checkDataType<DataTypeDate>(arguments[0].type.get())) if (checkDataType<DataTypeDate>(arguments[0].type.get()))
{ {
if (std::is_same_v<decltype(Transform::execute(DataTypeDate::FieldType(), 0, std::declval<DateLUTImpl>())), UInt16>) if (std::is_same_v<decltype(Transform::execute(DataTypeDate::FieldType(), 0, std::declval<DateLUTImpl>())), UInt16>)
out_return_type = std::make_shared<DataTypeDate>(); return std::make_shared<DataTypeDate>();
else else
out_return_type = std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, 2, 0)); return std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, 2, 0));
} }
else else
{ {
if (std::is_same_v<decltype(Transform::execute(DataTypeDateTime::FieldType(), 0, std::declval<DateLUTImpl>())), UInt16>) if (std::is_same_v<decltype(Transform::execute(DataTypeDateTime::FieldType(), 0, std::declval<DateLUTImpl>())), UInt16>)
out_return_type = std::make_shared<DataTypeDate>(); return std::make_shared<DataTypeDate>();
else else
out_return_type = std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, 2, 0)); return std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, 2, 0));
} }
} }
@ -1291,10 +1285,7 @@ public:
size_t getNumberOfArguments() const override { return 2; } size_t getNumberOfArguments() const override { return 2; }
void getReturnTypeAndPrerequisitesImpl( DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
std::vector<ExpressionAction> &) override
{ {
if (arguments.size() != 2) if (arguments.size() != 2)
throw Exception("Number of arguments for function " + getName() + " doesn't match: passed " throw Exception("Number of arguments for function " + getName() + " doesn't match: passed "
@ -1307,7 +1298,7 @@ public:
". Should be DateTime", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; ". Should be DateTime", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
String time_zone_name = extractTimeZoneNameFromFunctionArguments(arguments, 1, 0); String time_zone_name = extractTimeZoneNameFromFunctionArguments(arguments, 1, 0);
out_return_type = std::make_shared<DataTypeDateTime>(time_zone_name); return std::make_shared<DataTypeDateTime>(time_zone_name);
} }
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override

View File

@ -1,13 +1,11 @@
#pragma once #pragma once
#include <DataTypes/DataTypeArray.h> #include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeExpression.h>
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
#include <Interpreters/ExpressionActions.h> #include <Interpreters/ExpressionActions.h>
#include <Columns/ColumnsNumber.h> #include <Columns/ColumnsNumber.h>
#include <Columns/ColumnArray.h> #include <Columns/ColumnArray.h>
#include <Columns/ColumnConst.h> #include <Columns/ColumnConst.h>
#include <Columns/ColumnExpression.h>
#include <Common/typeid_cast.h> #include <Common/typeid_cast.h>
#include <Functions/IFunction.h> #include <Functions/IFunction.h>
#include <Functions/FunctionsMiscellaneous.h> #include <Functions/FunctionsMiscellaneous.h>
@ -635,15 +633,16 @@ public:
nested_types[i] = array_type->getNestedType(); nested_types[i] = array_type->getNestedType();
} }
const DataTypeExpression * expression_type = checkAndGetDataType<DataTypeExpression>(&*arguments[0]); const DataTypeFunction * function_type = checkAndGetDataType<DataTypeFunction>(&*arguments[0]);
if (!expression_type || expression_type->getArgumentTypes().size() != nested_types.size()) if (!function_type || function_type->getArgumentTypes().size() != nested_types.size())
throw Exception("First argument for this overload of " + getName() + " must be an expression with " throw Exception("First argument for this overload of " + getName() + " must be a function with "
+ toString(nested_types.size()) + " arguments. Found " + toString(nested_types.size()) + " arguments. Found "
+ arguments[0]->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + arguments[0]->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
arguments[0] = std::make_shared<DataTypeExpression>(nested_types); arguments[0] = std::make_shared<DataTypeFunction>(nested_types);
} }
/*
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{ {
size_t min_args = Impl::needExpression() ? 2 : 1; size_t min_args = Impl::needExpression() ? 2 : 1;
@ -693,10 +692,9 @@ public:
return Impl::getReturnType(return_type, first_array_type->getNestedType()); return Impl::getReturnType(return_type, first_array_type->getNestedType());
} }
} }
*/
void getReturnTypeAndPrerequisitesImpl(const ColumnsWithTypeAndName & arguments, DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
DataTypePtr & out_return_type,
ExpressionActions::Actions & out_prerequisites) override
{ {
size_t min_args = Impl::needExpression() ? 2 : 1; size_t min_args = Impl::needExpression() ? 2 : 1;
if (arguments.size() < min_args) if (arguments.size() < min_args)
@ -707,7 +705,7 @@ public:
if (arguments.size() == 1) if (arguments.size() == 1)
{ {
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[0].type); const auto array_type = checkAndGetDataType<DataTypeArray>(&*arguments[0].type);
if (!array_type) if (!array_type)
throw Exception("The only argument for function " + getName() + " must be array. Found " throw Exception("The only argument for function " + getName() + " must be array. Found "
@ -719,7 +717,7 @@ public:
throw Exception("The only argument for function " + getName() + " must be array of UInt8. Found " throw Exception("The only argument for function " + getName() + " must be array of UInt8. Found "
+ arguments[0].type->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + arguments[0].type->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
out_return_type = Impl::getReturnType(nested_type, nested_type); return Impl::getReturnType(nested_type, nested_type);
} }
else else
{ {
@ -727,59 +725,31 @@ public:
throw Exception("Function " + getName() + " needs one array argument.", throw Exception("Function " + getName() + " needs one array argument.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (!arguments[0].column) const auto data_type_function = checkAndGetDataType<DataTypeFunction>(arguments[0].type.get());
throw Exception("Type of first argument for function " + getName() + " must be an expression.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const ColumnExpression * column_expression = typeid_cast<const ColumnExpression *>(arguments[0].column.get()); if (!data_type_function)
throw Exception("First argument for function " + getName() + " must be a function.",
if (!column_expression)
throw Exception("Column of first argument for function " + getName() + " must be an expression.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
/// The types of the remaining arguments are already checked in getLambdaArgumentTypes. /// The types of the remaining arguments are already checked in getLambdaArgumentTypes.
/// Let's add to the block all the columns mentioned in the expression, multiplied into an array parallel to the one being processed. DataTypePtr return_type = data_type_function->getReturnType();
const ExpressionActions & expression = *column_expression->getExpression();
const NamesAndTypesList & required_columns = expression.getRequiredColumnsWithTypes();
const NamesAndTypesList expression_arguments = column_expression->getArguments();
NameSet argument_names;
for (const auto & expression_argument : expression_arguments)
argument_names.emplace(expression_argument.name);
for (const auto & required_column : required_columns)
{
if (argument_names.count(required_column.name))
continue;
Names replicate_arguments;
replicate_arguments.push_back(required_column.name);
replicate_arguments.push_back(arguments[1].name);
out_prerequisites.push_back(ExpressionAction::applyFunction(std::make_shared<FunctionReplicate>(), replicate_arguments));
}
DataTypePtr return_type = column_expression->getReturnType();
if (Impl::needBoolean() && !checkDataType<DataTypeUInt8>(&*return_type)) if (Impl::needBoolean() && !checkDataType<DataTypeUInt8>(&*return_type))
throw Exception("Expression for function " + getName() + " must return UInt8, found " throw Exception("Expression for function " + getName() + " must return UInt8, found "
+ return_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + return_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const DataTypeArray * first_array_type = checkAndGetDataType<DataTypeArray>(&*arguments[1].type); const auto first_array_type = checkAndGetDataType<DataTypeArray>(&*arguments[1].type);
out_return_type = Impl::getReturnType(return_type, first_array_type->getNestedType()); return Impl::getReturnType(return_type, first_array_type->getNestedType());
} }
} }
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override
{
executeImpl(block, arguments, {}, result);
}
void executeImpl(Block & block, const ColumnNumbers & arguments, const ColumnNumbers & prerequisites, size_t result) override
{ {
if (arguments.size() == 1) if (arguments.size() == 1)
{ {
ColumnPtr column_array_ptr = block.getByPosition(arguments[0]).column; ColumnPtr column_array_ptr = block.getByPosition(arguments[0]).column;
const ColumnArray * column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get()); const auto * column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
if (!column_array) if (!column_array)
{ {
@ -797,28 +767,28 @@ public:
const auto & column_with_type_and_name = block.getByPosition(arguments[0]); const auto & column_with_type_and_name = block.getByPosition(arguments[0]);
if (!column_with_type_and_name.column) if (!column_with_type_and_name.column)
throw Exception("First argument for function " + getName() + " must be an expression.", throw Exception("First argument for function " + getName() + " must be a function.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const ColumnExpression * column_expression = typeid_cast<const ColumnExpression *>(column_with_type_and_name.column.get()); const auto * column_function = typeid_cast<const ColumnFunction *>(column_with_type_and_name.column.get());
ColumnPtr offsets_column; ColumnPtr offsets_column;
Block temp_block;
const ExpressionActions & expression = *column_expression->getExpression();
const NamesAndTypesList & expression_arguments = column_expression->getArguments();
NameSet argument_names;
ColumnPtr column_first_array_ptr; ColumnPtr column_first_array_ptr;
const ColumnArray * column_first_array = nullptr; const ColumnArray * column_first_array = nullptr;
/// Put the expression arguments in the block. ColumnsWithTypeAndName arrays;
arrays.reserve(arguments.size() - 1);
size_t i = 0; for (size_t i = 1; i < arguments.size(); ++i)
for (const auto expression_argument : expression_arguments)
{ {
ColumnPtr column_array_ptr = block.getByPosition(arguments[i + 1]).column; const auto & array_with_type_and_name = block.getByPosition(arguments[i]);
const ColumnArray * column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
ColumnPtr column_array_ptr = array_with_type_and_name.column;
const auto * column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
const DataTypePtr & array_type_ptr = array_with_type_and_name.type;
const auto * array_type = checkAndGetDataType<DataTypeArray>(array_type_ptr.get());
if (!column_array) if (!column_array)
{ {
@ -829,6 +799,9 @@ public:
column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get()); column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
} }
/// array_type is null on this branch, so it must not be dereferenced to build the message;
/// the name is taken from the argument's original type pointer instead.
if (!array_type)
    throw Exception("Expected array type, found " + array_type_ptr->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!offsets_column) if (!offsets_column)
{ {
offsets_column = column_array->getOffsetsPtr(); offsets_column = column_array->getOffsetsPtr();
@ -841,46 +814,23 @@ public:
throw Exception("Arrays passed to " + getName() + " must have equal size", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH); throw Exception("Arrays passed to " + getName() + " must have equal size", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);
} }
if (i == 0) if (i == 1)
{ {
column_first_array_ptr = column_array_ptr; column_first_array_ptr = column_array_ptr;
column_first_array = column_array; column_first_array = column_array;
} }
temp_block.insert({ arrays.emplace_back(ColumnWithTypeAndName(column_array->getDataPtr(),
column_array->getDataPtr(), array_type->getNestedType(), array_with_type_and_name.name));
expression_argument.type,
expression_argument.name});
argument_names.insert(expression_argument.name);
++i;
} }
/// Put all the necessary columns multiplied by the sizes of arrays into the block. /// Put all the necessary columns multiplied by the sizes of arrays into the block.
auto replicated_column_function_ptr = column_function->replicate(column_first_array->getOffsets());
auto * replicated_column_function = typeid_cast<ColumnFunction *>(replicated_column_function_ptr.get());
replicated_column_function->appendArguments(arrays);
Names required_columns = expression.getRequiredColumns(); block.getByPosition(result).column = Impl::execute(*column_first_array,
size_t prerequisite_index = 0; replicated_column_function->reduce().column);
for (size_t i = 0; i < required_columns.size(); ++i)
{
const String & name = required_columns[i];
if (argument_names.count(name))
continue;
ColumnWithTypeAndName replicated_column = block.getByPosition(prerequisites[prerequisite_index]);
replicated_column.name = name;
replicated_column.column = typeid_cast<const ColumnArray &>(*replicated_column.column).getDataPtr();
replicated_column.type = typeid_cast<const DataTypeArray &>(*replicated_column.type).getNestedType(),
temp_block.insert(std::move(replicated_column));
++prerequisite_index;
}
expression.execute(temp_block);
block.getByPosition(result).column = Impl::execute(*column_first_array, temp_block.getByName(column_expression->getReturnName()).column);
} }
} }
}; };

View File

@ -1643,8 +1643,7 @@ public:
return name; return name;
} }
void getReturnTypeAndPrerequisitesImpl( DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override;
const ColumnsWithTypeAndName & arguments, DataTypePtr & out_return_type, ExpressionActions::Actions & out_prerequisites) override;
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override; void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override;
@ -1678,8 +1677,7 @@ void FunctionVisibleWidth::executeImpl(Block & block, const ColumnNumbers & argu
} }
void FunctionHasColumnInTable::getReturnTypeAndPrerequisitesImpl( DataTypePtr FunctionHasColumnInTable::getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const
const ColumnsWithTypeAndName & arguments, DataTypePtr & out_return_type, ExpressionActions::Actions & /*out_prerequisites*/)
{ {
if (arguments.size() < 3 || arguments.size() > 6) if (arguments.size() < 3 || arguments.size() > 6)
throw Exception{"Invalid number of arguments for function " + getName(), throw Exception{"Invalid number of arguments for function " + getName(),
@ -1697,7 +1695,7 @@ void FunctionHasColumnInTable::getReturnTypeAndPrerequisitesImpl(
} }
} }
out_return_type = std::make_shared<DataTypeUInt8>(); return std::make_shared<DataTypeUInt8>();
} }

View File

@ -1,13 +1,16 @@
#pragma once #pragma once
#include <Functions/IFunction.h> #include <Functions/IFunction.h>
#include <Interpreters/ExpressionActions.h>
#include <DataTypes/DataTypeFunction.h>
#include <IO/WriteBufferFromString.h>
#include <IO/Operators.h>
#include <Columns/ColumnFunction.h>
namespace DB namespace DB
{ {
/** Creates an array, multiplying the column (the first argument) by the number of elements in the array (the second argument). /** Creates an array, multiplying the column (the first argument) by the number of elements in the array (the second argument).
* Used only as prerequisites for higher-order functions.
*/ */
class FunctionReplicate : public IFunction class FunctionReplicate : public IFunction
{ {
@ -32,4 +35,181 @@ public:
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override; void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override;
}; };
/// Executes expression. Uses for lambda functions implementation. Can't be created from factory.
class FunctionExpression : public IFunctionBase, public IPreparedFunction,
public std::enable_shared_from_this<FunctionExpression>
{
public:
FunctionExpression(const ExpressionActionsPtr & expression_actions,
const DataTypes & argument_types, const Names & argument_names,
const DataTypePtr & return_type, const std::string & return_name)
: expression_actions(expression_actions), argument_types(argument_types),
argument_names(argument_names), return_type(return_type), return_name(return_name)
{
}
String getName() const override { return "FunctionExpression"; }
const DataTypes & getArgumentTypes() const override { return argument_types; }
const DataTypePtr & getReturnType() const override { return return_type; }
PreparedFunctionPtr prepare(const Block &) const override
{
return std::const_pointer_cast<FunctionExpression>(shared_from_this());
}
void execute(Block & block, const ColumnNumbers & arguments, size_t result) override
{
Block expr_block;
for (size_t i = 0; i < arguments.size(); ++i)
{
const auto & argument = block.getByPosition(arguments[i]);
/// Replace column name with value from argument_names.
expr_block.insert({argument.column, argument.type, argument_names[i]});
}
expression_actions->execute(expr_block);
block.getByPosition(result).column = expr_block.getByName(return_name).column;
}
private:
ExpressionActionsPtr expression_actions;
DataTypes argument_types;
Names argument_names;
DataTypePtr return_type;
std::string return_name;
};
/// Captures columns which are used by lambda function but not in argument list.
/// Returns ColumnFunction with captured columns.
/// For lambda(x, x + y) x is in lambda_arguments, y is in captured arguments, expression_actions is 'x + y'.
/// execute(y) returns ColumnFunction(FunctionExpression(x + y), y) with type Function(x) -> function_return_type.
class FunctionCapture : public IFunctionBase, public IPreparedFunction, public FunctionBuilderImpl,
public std::enable_shared_from_this<FunctionCapture>
{
public:
FunctionCapture(const ExpressionActionsPtr & expression_actions, const Names & captured,
const NamesAndTypesList & lambda_arguments,
const DataTypePtr & function_return_type, const std::string & expression_return_name)
: expression_actions(expression_actions), captured_names(captured), lambda_arguments(lambda_arguments)
, function_return_type(function_return_type), expression_return_name(expression_return_name)
{
const auto & all_arguments = expression_actions->getRequiredColumnsWithTypes();
std::unordered_map<std::string, DataTypePtr> arguments_map;
for (const auto & arg : all_arguments)
arguments_map[arg.name] = arg.type;
auto collect = [&arguments_map](const Names & names)
{
DataTypes types;
types.reserve(names.size());
for (const auto & name : names)
{
auto it = arguments_map.find(name);
if (it == arguments_map.end())
throw Exception("Lambda captured argument " + name + " not found in required columns.",
ErrorCodes::LOGICAL_ERROR);
types.push_back(it->second);
arguments_map.erase(it);
}
return types;
};
captured_types = collect(captured_names);
DataTypes argument_types;
argument_types.reserve(lambda_arguments.size());
for (const auto & lambda_argument : lambda_arguments)
argument_types.push_back(lambda_argument.type);
return_type = std::make_shared<DataTypeFunction>(argument_types, function_return_type);
name = "Capture[" + toString(captured_types) + "](" + toString(argument_types) + ") -> "
+ function_return_type->getName();
}
String getName() const override { return name; }
const DataTypes & getArgumentTypes() const override { return captured_types; }
const DataTypePtr & getReturnType() const override { return return_type; }
PreparedFunctionPtr prepare(const Block &) const override
{
return std::const_pointer_cast<FunctionCapture>(shared_from_this());
}
void execute(Block & block, const ColumnNumbers & arguments, size_t result) override
{
ColumnsWithTypeAndName columns;
columns.reserve(arguments.size());
Names names;
DataTypes types;
names.reserve(captured_names.size() + lambda_arguments.size());
names.insert(names.end(), captured_names.begin(), captured_names.end());
types.reserve(captured_types.size() + lambda_arguments.size());
types.insert(types.end(), captured_types.begin(), captured_types.end());
for (const auto & lambda_argument : lambda_arguments)
{
names.push_back(lambda_argument.name);
types.push_back(lambda_argument.type);
}
for (const auto & argument : arguments)
columns.push_back(block.getByPosition(argument));
auto function = std::make_shared<FunctionExpression>(expression_actions, types, names,
function_return_type, expression_return_name);
auto size = block.rows();
block.getByPosition(result).column = ColumnFunction::create(size, std::move(function), columns);
}
size_t getNumberOfArguments() const override { return captured_types.size(); }
protected:
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const override { return return_type; }
bool useDefaultImplementationForNulls() const override { return false; }
FunctionBasePtr buildImpl(const ColumnsWithTypeAndName &, const DataTypePtr &) const override
{
return std::const_pointer_cast<FunctionCapture>(shared_from_this());
}
private:
std::string toString(const DataTypes & data_types) const
{
std::string result;
{
WriteBufferFromString buffer(result);
bool first = true;
for (const auto & type : data_types)
{
if (!first)
buffer << ", ";
first = false;
buffer << type->getName();
}
}
return result;
}
ExpressionActionsPtr expression_actions;
DataTypes captured_types;
Names captured_names;
NamesAndTypesList lambda_arguments;
DataTypePtr function_return_type;
DataTypePtr return_type;
std::string expression_return_name;
std::string name;
};
} }

View File

@ -10,14 +10,14 @@ void registerFunctionsRound(FunctionFactory & factory)
factory.registerFunction<FunctionRoundDuration>(); factory.registerFunction<FunctionRoundDuration>();
factory.registerFunction<FunctionRoundAge>(); factory.registerFunction<FunctionRoundAge>();
factory.registerFunction("round", FunctionRound::create, FunctionFactory::CaseInsensitive); factory.registerFunction<FunctionRound>("round", FunctionFactory::CaseInsensitive);
factory.registerFunction("floor", FunctionFloor::create, FunctionFactory::CaseInsensitive); factory.registerFunction<FunctionFloor>("floor", FunctionFactory::CaseInsensitive);
factory.registerFunction("ceil", FunctionCeil::create, FunctionFactory::CaseInsensitive); factory.registerFunction<FunctionCeil>("ceil", FunctionFactory::CaseInsensitive);
factory.registerFunction("trunc", FunctionTrunc::create, FunctionFactory::CaseInsensitive); factory.registerFunction<FunctionTrunc>("trunc", FunctionFactory::CaseInsensitive);
/// Compatibility aliases. /// Compatibility aliases.
factory.registerFunction("ceiling", FunctionCeil::create, FunctionFactory::CaseInsensitive); factory.registerFunction<FunctionCeil>("ceiling", FunctionFactory::CaseInsensitive);
factory.registerFunction("truncate", FunctionTrunc::create, FunctionFactory::CaseInsensitive); factory.registerFunction<FunctionTrunc>("truncate", FunctionFactory::CaseInsensitive);
} }
} }

View File

@ -118,8 +118,7 @@ public:
return {1}; return {1};
} }
void getReturnTypeAndPrerequisitesImpl( DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
const ColumnsWithTypeAndName & arguments, DataTypePtr & out_return_type, ExpressionActions::Actions & /*out_prerequisites*/) override
{ {
size_t count_arrays = 0; size_t count_arrays = 0;
@ -135,10 +134,12 @@ public:
throw Exception("First argument for function " + getName() + " must be tuple or array of tuple.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); throw Exception("First argument for function " + getName() + " must be tuple or array of tuple.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
size_t index = getElementNum(arguments[1].column, *tuple); size_t index = getElementNum(arguments[1].column, *tuple);
out_return_type = tuple->getElements()[index]; DataTypePtr out_return_type = tuple->getElements()[index];
for (; count_arrays; --count_arrays) for (; count_arrays; --count_arrays)
out_return_type = std::make_shared<DataTypeArray>(out_return_type); out_return_type = std::make_shared<DataTypeArray>(out_return_type);
return out_return_type;
} }
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override
@ -174,7 +175,7 @@ public:
} }
private: private:
size_t getElementNum(const ColumnPtr & index_column, const DataTypeTuple & tuple) size_t getElementNum(const ColumnPtr & index_column, const DataTypeTuple & tuple) const
{ {
if (auto index_col = checkAndGetColumnConst<ColumnUInt8>(index_column.get())) if (auto index_col = checkAndGetColumnConst<ColumnUInt8>(index_column.get()))
{ {

View File

@ -15,7 +15,7 @@ namespace DB
namespace ErrorCodes namespace ErrorCodes
{ {
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
} }
namespace namespace
@ -86,15 +86,15 @@ ColumnPtr wrapInNullable(const ColumnPtr & src, Block & block, const ColumnNumbe
} }
struct NullPresense struct NullPresence
{ {
bool has_nullable = false; bool has_nullable = false;
bool has_null_constant = false; bool has_null_constant = false;
}; };
NullPresense getNullPresense(const Block & block, const ColumnNumbers & args) NullPresence getNullPresense(const Block & block, const ColumnNumbers & args)
{ {
NullPresense res; NullPresence res;
for (const auto & arg : args) for (const auto & arg : args)
{ {
@ -109,9 +109,9 @@ NullPresense getNullPresense(const Block & block, const ColumnNumbers & args)
return res; return res;
} }
NullPresense getNullPresense(const ColumnsWithTypeAndName & args) NullPresence getNullPresense(const ColumnsWithTypeAndName & args)
{ {
NullPresense res; NullPresence res;
for (const auto & elem : args) for (const auto & elem : args)
{ {
@ -124,43 +124,6 @@ NullPresense getNullPresense(const ColumnsWithTypeAndName & args)
return res; return res;
} }
NullPresense getNullPresense(const DataTypes & types)
{
NullPresense res;
for (const auto & type : types)
{
if (!res.has_nullable)
res.has_nullable = type->isNullable();
if (!res.has_null_constant)
res.has_null_constant = type->onlyNull();
}
return res;
}
/// Turn the specified set of data types into their respective nested data types.
DataTypes toNestedDataTypes(const DataTypes & args)
{
DataTypes new_args;
new_args.reserve(args.size());
for (const auto & arg : args)
{
if (arg->isNullable())
{
auto nullable_type = static_cast<const DataTypeNullable *>(arg.get());
const DataTypePtr & nested_type = nullable_type->getNestedType();
new_args.push_back(nested_type);
}
else
new_args.push_back(arg);
}
return new_args;
}
bool allArgumentsAreConstants(const Block & block, const ColumnNumbers & args) bool allArgumentsAreConstants(const Block & block, const ColumnNumbers & args)
{ {
for (auto arg : args) for (auto arg : args)
@ -168,14 +131,14 @@ bool allArgumentsAreConstants(const Block & block, const ColumnNumbers & args)
return false; return false;
return true; return true;
} }
}
bool defaultImplementationForConstantArguments( bool PreparedFunctionImpl::defaultImplementationForConstantArguments(Block & block, const ColumnNumbers & args, size_t result)
IFunction & func, Block & block, const ColumnNumbers & args, size_t result)
{ {
if (args.empty() || !func.useDefaultImplementationForConstants() || !allArgumentsAreConstants(block, args)) if (args.empty() || !useDefaultImplementationForConstants() || !allArgumentsAreConstants(block, args))
return false; return false;
ColumnNumbers arguments_to_remain_constants = func.getArgumentsThatAreAlwaysConstant(); ColumnNumbers arguments_to_remain_constants = getArgumentsThatAreAlwaysConstant();
Block temporary_block; Block temporary_block;
bool have_converted_columns = false; bool have_converted_columns = false;
@ -198,8 +161,8 @@ bool defaultImplementationForConstantArguments(
* not in "arguments_to_remain_constants" set. Otherwise we get infinite recursion. * not in "arguments_to_remain_constants" set. Otherwise we get infinite recursion.
*/ */
if (!have_converted_columns) if (!have_converted_columns)
throw Exception("Number of arguments for function " + func.getName() + " doesn't match: the function requires more arguments", throw Exception("Number of arguments for function " + getName() + " doesn't match: the function requires more arguments",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
temporary_block.insert(block.getByPosition(result)); temporary_block.insert(block.getByPosition(result));
@ -207,31 +170,30 @@ bool defaultImplementationForConstantArguments(
for (size_t i = 0; i < arguments_size; ++i) for (size_t i = 0; i < arguments_size; ++i)
temporary_argument_numbers[i] = i; temporary_argument_numbers[i] = i;
func.execute(temporary_block, temporary_argument_numbers, arguments_size); execute(temporary_block, temporary_argument_numbers, arguments_size);
block.getByPosition(result).column = ColumnConst::create(temporary_block.getByPosition(arguments_size).column, block.rows()); block.getByPosition(result).column = ColumnConst::create(temporary_block.getByPosition(arguments_size).column, block.rows());
return true; return true;
} }
bool defaultImplementationForNulls( bool PreparedFunctionImpl::defaultImplementationForNulls(Block & block, const ColumnNumbers & args, size_t result)
IFunction & func, Block & block, const ColumnNumbers & args, size_t result)
{ {
if (args.empty() || !func.useDefaultImplementationForNulls()) if (args.empty() || !useDefaultImplementationForNulls())
return false; return false;
NullPresense null_presense = getNullPresense(block, args); NullPresence null_presence = getNullPresense(block, args);
if (null_presense.has_null_constant) if (null_presence.has_null_constant)
{ {
block.getByPosition(result).column = block.getByPosition(result).type->createColumnConst(block.rows(), Null()); block.getByPosition(result).column = block.getByPosition(result).type->createColumnConst(block.rows(), Null());
return true; return true;
} }
if (null_presense.has_nullable) if (null_presence.has_nullable)
{ {
Block temporary_block = createBlockWithNestedColumns(block, args, result); Block temporary_block = createBlockWithNestedColumns(block, args, result);
func.execute(temporary_block, args, result); execute(temporary_block, args, result);
block.getByPosition(result).column = wrapInNullable(temporary_block.getByPosition(result).column, block, args, result); block.getByPosition(result).column = wrapInNullable(temporary_block.getByPosition(result).column, block, args, result);
return true; return true;
} }
@ -239,10 +201,18 @@ bool defaultImplementationForNulls(
return false; return false;
} }
void PreparedFunctionImpl::execute(Block & block, const ColumnNumbers & args, size_t result)
{
if (defaultImplementationForConstantArguments(block, args, result))
return;
if (defaultImplementationForNulls(block, args, result))
return;
executeImpl(block, args, result);
} }
void FunctionBuilderImpl::checkNumberOfArguments(size_t number_of_arguments) const
void IFunction::checkNumberOfArguments(size_t number_of_arguments) const
{ {
if (isVariadic()) if (isVariadic())
return; return;
@ -251,90 +221,31 @@ void IFunction::checkNumberOfArguments(size_t number_of_arguments) const
if (number_of_arguments != expected_number_of_arguments) if (number_of_arguments != expected_number_of_arguments)
throw Exception("Number of arguments for function " + getName() + " doesn't match: passed " throw Exception("Number of arguments for function " + getName() + " doesn't match: passed "
+ toString(number_of_arguments) + ", should be " + toString(expected_number_of_arguments), + toString(number_of_arguments) + ", should be " + toString(expected_number_of_arguments),
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
} }
DataTypePtr FunctionBuilderImpl::getReturnType(const ColumnsWithTypeAndName & arguments) const
DataTypePtr IFunction::getReturnType(const DataTypes & arguments) const
{ {
checkNumberOfArguments(arguments.size()); checkNumberOfArguments(arguments.size());
if (!arguments.empty() && useDefaultImplementationForNulls()) if (!arguments.empty() && useDefaultImplementationForNulls())
{ {
NullPresense null_presense = getNullPresense(arguments); NullPresence null_presense = getNullPresense(arguments);
if (null_presense.has_null_constant) if (null_presense.has_null_constant)
{ {
return makeNullable(std::make_shared<DataTypeNothing>()); return makeNullable(std::make_shared<DataTypeNothing>());
} }
if (null_presense.has_nullable) if (null_presense.has_nullable)
{ {
return makeNullable(getReturnTypeImpl(toNestedDataTypes(arguments))); Block nested_block = createBlockWithNestedColumns(Block(arguments), ext::collection_cast<ColumnNumbers>(ext::range(0, arguments.size())));
auto return_type = getReturnTypeImpl(ColumnsWithTypeAndName(nested_block.begin(), nested_block.end()));
return makeNullable(return_type);
} }
} }
return getReturnTypeImpl(arguments); return getReturnTypeImpl(arguments);
} }
void IFunction::getReturnTypeAndPrerequisites(
const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
std::vector<ExpressionAction> & out_prerequisites)
{
checkNumberOfArguments(arguments.size());
if (!arguments.empty() && useDefaultImplementationForNulls())
{
NullPresense null_presense = getNullPresense(arguments);
if (null_presense.has_null_constant)
{
out_return_type = makeNullable(std::make_shared<DataTypeNothing>());
return;
}
if (null_presense.has_nullable)
{
Block nested_block = createBlockWithNestedColumns(Block(arguments), ext::collection_cast<ColumnNumbers>(ext::range(0, arguments.size())));
getReturnTypeAndPrerequisitesImpl(ColumnsWithTypeAndName(nested_block.begin(), nested_block.end()), out_return_type, out_prerequisites);
out_return_type = makeNullable(out_return_type);
return;
}
}
getReturnTypeAndPrerequisitesImpl(arguments, out_return_type, out_prerequisites);
} }
void IFunction::getLambdaArgumentTypes(DataTypes & arguments) const
{
checkNumberOfArguments(arguments.size());
getLambdaArgumentTypesImpl(arguments);
}
void IFunction::execute(Block & block, const ColumnNumbers & args, size_t result)
{
if (defaultImplementationForConstantArguments(*this, block, args, result))
return;
if (defaultImplementationForNulls(*this, block, args, result))
return;
executeImpl(block, args, result);
}
void IFunction::execute(Block & block, const ColumnNumbers & args, const ColumnNumbers & prerequisites, size_t result)
{
if (!prerequisites.empty())
{
executeImpl(block, args, prerequisites, result);
return;
}
execute(block, args, result);
}
}

View File

@ -16,45 +16,79 @@ namespace ErrorCodes
{ {
extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NOT_IMPLEMENTED; extern const int NOT_IMPLEMENTED;
extern const int LOGICAL_ERROR;
} }
struct ExpressionAction; /// The simplest executable object.
/// Motivation:
/// * Prepare something heavy once before main execution loop instead of doing it for each block.
/** Interface for normal functions. /// * Provide const interface for IFunctionBase (later).
* Normal functions are functions that do not change the number of rows in the table, class IPreparedFunction
* and the result of which for each row does not depend on other rows.
*
* A function can take an arbitrary number of arguments; returns exactly one value.
* The type of the result depends on the type and number of arguments.
*
* The function is dispatched for the whole block. This allows you to perform all kinds of checks rarely,
* and do the main job as an efficient loop.
*
* The function is applied to one or more columns of the block, and writes its result,
* adding a new column to the block. The function does not modify its arguments.
*/
class IFunction
{ {
public: public:
/** The successor of IFunction must implement: virtual ~IPreparedFunction() = default;
* - getName
* - either getReturnType, or getReturnTypeAndPrerequisites
* - one of the overloads of `execute`.
*/
/// Get the main function name. /// Get the main function name.
virtual String getName() const = 0; virtual String getName() const = 0;
/// Override and return true if function could take different number of arguments. virtual void execute(Block & block, const ColumnNumbers & arguments, size_t result) = 0;
virtual bool isVariadic() const { return false; } };
/// For non-variadic functions, return number of arguments; otherwise return zero (that should be ignored). using PreparedFunctionPtr = std::shared_ptr<IPreparedFunction>;
virtual size_t getNumberOfArguments() const = 0;
/// Throw if number of arguments is incorrect. Default implementation will check only in non-variadic case. class PreparedFunctionImpl : public IPreparedFunction
/// It is called inside getReturnType. {
virtual void checkNumberOfArguments(size_t number_of_arguments) const; public:
void execute(Block & block, const ColumnNumbers & arguments, size_t result) final;
protected:
virtual void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) = 0;
/** Default implementation in presence of Nullable arguments or NULL constants as arguments is the following:
* if some of arguments are NULL constants then return NULL constant,
* if some of arguments are Nullable, then execute function as usual for block,
* where Nullable columns are substituted with nested columns (they have arbitrary values in rows corresponding to NULL value)
* and wrap result in Nullable column where NULLs are in all rows where any of arguments are NULL.
*/
virtual bool useDefaultImplementationForNulls() const { return true; }
/** If the function has a non-zero number of arguments,
* and if all arguments are constant, then we can automatically provide a default implementation:
* arguments are converted to ordinary columns with single value, then function is executed as usual,
* and then the result is converted to constant column.
*/
virtual bool useDefaultImplementationForConstants() const { return false; }
/** Some arguments could remain constant during this implementation.
*/
virtual ColumnNumbers getArgumentsThatAreAlwaysConstant() const { return {}; }
private:
bool defaultImplementationForNulls(Block & block, const ColumnNumbers & args, size_t result);
bool defaultImplementationForConstantArguments(Block & block, const ColumnNumbers & args, size_t result);
};
/// Function with known arguments and return type.
class IFunctionBase
{
public:
virtual ~IFunctionBase() = default;
/// Get the main function name.
virtual String getName() const = 0;
virtual const DataTypes & getArgumentTypes() const = 0;
virtual const DataTypePtr & getReturnType() const = 0;
/// Do preparations and return executable.
/// sample_block should contain data types of arguments and values of constants, if relevant.
virtual PreparedFunctionPtr prepare(const Block & sample_block) const = 0;
/// TODO: make const
virtual void execute(Block & block, const ColumnNumbers & arguments, size_t result)
{
return prepare(block)->execute(block, arguments, result);
}
/** Should we evaluate this function while constant folding, if arguments are constants? /** Should we evaluate this function while constant folding, if arguments are constants?
* Usually this is true. Notable counterexample is function 'sleep'. * Usually this is true. Notable counterexample is function 'sleep'.
@ -94,85 +128,6 @@ public:
*/ */
virtual bool isDeterministicInScopeOfQuery() { return true; } virtual bool isDeterministicInScopeOfQuery() { return true; }
/// Get the result type by argument type. If the function does not apply to these arguments, throw an exception.
/// Overloading for those who do not need prerequisites and values of constant arguments. Not called from outside.
DataTypePtr getReturnType(const DataTypes & arguments) const;
virtual DataTypePtr getReturnTypeImpl(const DataTypes & /*arguments*/) const
{
throw Exception("getReturnType is not implemented for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}
/** Get the result type by argument types and constant argument values.
* If the function does not apply to these arguments, throw an exception.
* You can also return a description of the additional columns that are required to perform the function.
* For non-constant columns `arguments[i].column = nullptr`.
* Meaningful element types in out_prerequisites: APPLY_FUNCTION, ADD_COLUMN.
*/
void getReturnTypeAndPrerequisites(
const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
std::vector<ExpressionAction> & out_prerequisites);
virtual void getReturnTypeAndPrerequisitesImpl(
const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
std::vector<ExpressionAction> & /*out_prerequisites*/)
{
DataTypes types(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i)
types[i] = arguments[i].type;
out_return_type = getReturnTypeImpl(types);
}
/// For higher-order functions (functions, that have lambda expression as at least one argument).
/// You pass data types with empty DataTypeExpression for lambda arguments.
/// This function will replace it with DataTypeExpression containing actual types.
void getLambdaArgumentTypes(DataTypes & arguments) const;
virtual void getLambdaArgumentTypesImpl(DataTypes & /*arguments*/) const
{
throw Exception("Function " + getName() + " can't have lambda-expressions as arguments", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
/// Execute the function on the block. Note: can be called simultaneously from several threads, for one object.
/// Overloading for those who do not need `prerequisites`. Not called from outside.
void execute(Block & block, const ColumnNumbers & arguments, size_t result);
/// Execute the function above the block. Note: can be called simultaneously from several threads, for one object.
/// `prerequisites` go in the same order as `out_prerequisites` obtained from getReturnTypeAndPrerequisites.
void execute(Block & block, const ColumnNumbers & arguments, const ColumnNumbers & prerequisites, size_t result);
virtual void executeImpl(Block & /*block*/, const ColumnNumbers & /*arguments*/, size_t /*result*/)
{
throw Exception("executeImpl is not implemented for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}
virtual void executeImpl(Block & block, const ColumnNumbers & arguments, const ColumnNumbers & /*prerequisites*/, size_t result)
{
executeImpl(block, arguments, result);
}
/** Default implementation in presense of Nullable arguments or NULL constants as arguments is the following:
* if some of arguments are NULL constants then return NULL constant,
* if some of arguments are Nullable, then execute function as usual for block,
* where Nullable columns are substituted with nested columns (they have arbitary values in rows corresponding to NULL value)
* and wrap result in Nullable column where NULLs are in all rows where any of arguments are NULL.
*/
virtual bool useDefaultImplementationForNulls() const { return true; }
/** If the function have non-zero number of arguments,
* and if all arguments are constant, that we could automatically provide default implementation:
* arguments are converted to ordinary columns with single value, then function is executed as usual,
* and then the result is converted to constant column.
*/
virtual bool useDefaultImplementationForConstants() const { return false; }
/** Some arguments could remain constant during this implementation.
*/
virtual ColumnNumbers getArgumentsThatAreAlwaysConstant() const { return {}; }
/** Lets you know if the function is monotonic in a range of values. /** Lets you know if the function is monotonic in a range of values.
* This is used to work with the index in a sorted chunk of data. * This is used to work with the index in a sorted chunk of data.
* And allows to use the index not only when it is written, for example `date >= const`, but also, for example, `toMonth(date) >= 11`. * And allows to use the index not only when it is written, for example `date >= const`, but also, for example, `toMonth(date) >= 11`.
@ -188,7 +143,7 @@ public:
bool is_always_monotonic = false; /// Is true if function is monotonic on the whole input range I bool is_always_monotonic = false; /// Is true if function is monotonic on the whole input range I
Monotonicity(bool is_monotonic_ = false, bool is_positive_ = true, bool is_always_monotonic_ = false) Monotonicity(bool is_monotonic_ = false, bool is_positive_ = true, bool is_always_monotonic_ = false)
: is_monotonic(is_monotonic_), is_positive(is_positive_), is_always_monotonic(is_always_monotonic_) {} : is_monotonic(is_monotonic_), is_positive(is_positive_), is_always_monotonic(is_always_monotonic_) {}
}; };
/** Get information about monotonicity on a range of values. Call only if hasInformationAboutMonotonicity. /** Get information about monotonicity on a range of values. Call only if hasInformationAboutMonotonicity.
@ -198,12 +153,221 @@ public:
{ {
throw Exception("Function " + getName() + " has no information about its monotonicity.", ErrorCodes::NOT_IMPLEMENTED); throw Exception("Function " + getName() + " has no information about its monotonicity.", ErrorCodes::NOT_IMPLEMENTED);
} }
virtual ~IFunction() {}
}; };
using FunctionBasePtr = std::shared_ptr<IFunctionBase>;
/** Factory interface: given a list of arguments, validates them and
  * produces an IFunctionBase (see build()).
  */
class IFunctionBuilder
{
public:
    virtual ~IFunctionBuilder() = default;

    /// Main name of the function.
    virtual String getName() const = 0;

    /// Return true from an override if the function accepts a varying number of arguments.
    virtual bool isVariadic() const { return false; }

    /// Number of arguments of a non-variadic function.
    /// For variadic functions return zero; the value should be ignored in that case.
    virtual size_t getNumberOfArguments() const = 0;

    /// Must throw if `number_of_arguments` is not acceptable for this function.
    virtual void checkNumberOfArguments(size_t number_of_arguments) const = 0;

    /// Check the arguments and return the corresponding IFunctionBase.
    virtual FunctionBasePtr build(const ColumnsWithTypeAndName & arguments) const = 0;

    /** For higher-order functions (functions that take a lambda expression as at least one argument).
      * The caller passes data types in which lambda arguments are represented by an empty DataTypeFunction;
      * the implementation replaces each of them with a DataTypeFunction containing the actual types.
      */
    virtual void getLambdaArgumentTypes(DataTypes & arguments) const = 0;
};
using FunctionBuilderPtr = std::shared_ptr<IFunctionBuilder>;
/// Partial implementation of IFunctionBuilder: build() is expressed through
/// getReturnType() and buildImpl(), and lambda type deduction through getLambdaArgumentTypesImpl().
class FunctionBuilderImpl : public IFunctionBuilder
{
public:
    /// Deduce the return type for `arguments`, then let buildImpl() construct the function.
    FunctionBasePtr build(const ColumnsWithTypeAndName & arguments) const final
    {
        const DataTypePtr deduced_type = getReturnType(arguments);
        return buildImpl(arguments, deduced_type);
    }

    /// Default implementation. Will check only in non-variadic case.
    void checkNumberOfArguments(size_t number_of_arguments) const override;

    /// Return type for the given arguments (defined out of line).
    DataTypePtr getReturnType(const ColumnsWithTypeAndName & arguments) const;

    /// Validates the argument count first, then forwards to getLambdaArgumentTypesImpl().
    void getLambdaArgumentTypes(DataTypes & arguments) const override
    {
        const size_t count = arguments.size();
        checkNumberOfArguments(count);
        getLambdaArgumentTypesImpl(arguments);
    }

protected:
    /// Get the result type by arguments. If the function does not apply to these arguments, throw an exception.
    /// By default only the types of the arguments are extracted and passed to the DataTypes overload.
    virtual DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const
    {
        const size_t count = arguments.size();
        DataTypes types(count);
        for (size_t pos = 0; pos != count; ++pos)
            types[pos] = arguments[pos].type;

        return getReturnTypeImpl(types);
    }

    /// Overload that receives only the argument types. Throws NOT_IMPLEMENTED unless overridden.
    virtual DataTypePtr getReturnTypeImpl(const DataTypes & /*arguments*/) const
    {
        throw Exception("getReturnType is not implemented for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    /** If useDefaultImplementationForNulls() is true, then the arguments for getReturnType() and buildImpl() are changed:
      *  if some of the arguments are Nullable(Nothing), getReturnType() is not called
      *   and buildImpl() is called with return_type = Nullable(Nothing);
      *  if some of the arguments are Nullable, then:
      *   - Nullable types are substituted with nested types for getReturnType();
      *   - the result of getReturnType() is wrapped in a Nullable type and passed to buildImpl().
      *
      * Otherwise build() returns buildImpl(arguments, getReturnType(arguments)).
      */
    virtual bool useDefaultImplementationForNulls() const { return true; }

    /// Construct the function object for the (already checked) arguments and the deduced return type.
    virtual FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const = 0;

    /// Throws ILLEGAL_TYPE_OF_ARGUMENT unless overridden: by default a function accepts no lambda arguments.
    virtual void getLambdaArgumentTypesImpl(DataTypes & /*arguments*/) const
    {
        throw Exception("Function " + getName() + " can't have lambda-expressions as arguments", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
    }
};
/** Previous function interface.
  * A single object plays all three roles at once: builder (FunctionBuilderImpl),
  * function with known types (IFunctionBase) and executable (PreparedFunctionImpl).
  * The IFunctionBase-specific methods and buildImpl() are stubs that throw NOT_IMPLEMENTED.
  */
class IFunction
    : public std::enable_shared_from_this<IFunction>
    , public FunctionBuilderImpl
    , public IFunctionBase
    , public PreparedFunctionImpl
{
public:
    String getName() const override = 0;

    /// TODO: make const
    void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override = 0;

    /// Override these functions to change the behavior of the default implementation.
    /// They are described in detail in PreparedFunctionImpl.
    bool useDefaultImplementationForNulls() const override { return true; }
    bool useDefaultImplementationForConstants() const override { return false; }
    ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {}; }

    /// Re-export the inherited entry points under the names used by existing code.
    using PreparedFunctionImpl::execute;
    using FunctionBuilderImpl::getReturnTypeImpl;
    using FunctionBuilderImpl::getLambdaArgumentTypesImpl;
    using FunctionBuilderImpl::getReturnType;

    /// Not supported for IFunction: always throws NOT_IMPLEMENTED.
    PreparedFunctionPtr prepare(const Block & /*sample_block*/) const final
    {
        throw Exception("prepare is not implemented for IFunction", ErrorCodes::NOT_IMPLEMENTED);
    }

    /// Not supported for IFunction: always throws NOT_IMPLEMENTED.
    const DataTypes & getArgumentTypes() const final
    {
        throw Exception("getArgumentTypes is not implemented for IFunction", ErrorCodes::NOT_IMPLEMENTED);
    }

    /// Not supported for IFunction: always throws NOT_IMPLEMENTED.
    const DataTypePtr & getReturnType() const override
    {
        throw Exception("getReturnType is not implemented for IFunction", ErrorCodes::NOT_IMPLEMENTED);
    }

protected:
    /// Not supported for IFunction: always throws NOT_IMPLEMENTED.
    FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & /*arguments*/, const DataTypePtr & /*return_type*/) const final
    {
        throw Exception("buildImpl is not implemented for IFunction", ErrorCodes::NOT_IMPLEMENTED);
    }
};
/// Wrappers over IFunction.

/// Executable that forwards executeImpl() and the default-implementation switches to the wrapped IFunction.
class DefaultExecutable final : public PreparedFunctionImpl
{
public:
    explicit DefaultExecutable(std::shared_ptr<IFunction> function_)
        : function(std::move(function_))
    {
    }

    String getName() const override
    {
        return function->getName();
    }

protected:
    void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) final
    {
        function->executeImpl(block, arguments, result);
    }

    bool useDefaultImplementationForNulls() const final
    {
        return function->useDefaultImplementationForNulls();
    }

    bool useDefaultImplementationForConstants() const final
    {
        return function->useDefaultImplementationForConstants();
    }

    ColumnNumbers getArgumentsThatAreAlwaysConstant() const final
    {
        return function->getArgumentsThatAreAlwaysConstant();
    }

private:
    /// The wrapped function; every call above is delegated to it.
    std::shared_ptr<IFunction> function;
};
/// IFunctionBase built on top of an IFunction: stores the argument types and the return type
/// it was constructed with, and delegates everything else to the wrapped function.
class DefaultFunction final : public IFunctionBase
{
public:
    DefaultFunction(std::shared_ptr<IFunction> function_, DataTypes arguments_, DataTypePtr return_type_)
        : function(std::move(function_))
        , arguments(std::move(arguments_))
        , return_type(std::move(return_type_))
    {
    }

    String getName() const override
    {
        return function->getName();
    }

    const DataTypes & getArgumentTypes() const override
    {
        return arguments;
    }

    const DataTypePtr & getReturnType() const override
    {
        return return_type;
    }

    /// The sample block is not used: the executable simply wraps the same IFunction.
    PreparedFunctionPtr prepare(const Block & /*sample_block*/) const override
    {
        return std::make_shared<DefaultExecutable>(function);
    }

    bool isSuitableForConstantFolding() const override
    {
        return function->isSuitableForConstantFolding();
    }

    bool isInjective(const Block & sample_block) override
    {
        return function->isInjective(sample_block);
    }

    bool isDeterministicInScopeOfQuery() override
    {
        return function->isDeterministicInScopeOfQuery();
    }

    bool hasInformationAboutMonotonicity() const override
    {
        return function->hasInformationAboutMonotonicity();
    }

    IFunctionBase::Monotonicity getMonotonicityForRange(const IDataType & type, const Field & left, const Field & right) const override
    {
        return function->getMonotonicityForRange(type, left, right);
    }

private:
    std::shared_ptr<IFunction> function;    /// Wrapped function that receives all delegated calls.
    DataTypes arguments;                    /// Argument types passed to the constructor.
    DataTypePtr return_type;                /// Return type passed to the constructor.
};
/// Builder on top of an IFunction: argument checks and return type deduction are delegated
/// to the wrapped function, and buildImpl() produces a DefaultFunction around it.
class DefaultFunctionBuilder : public FunctionBuilderImpl
{
public:
    explicit DefaultFunctionBuilder(std::shared_ptr<IFunction> function) : function(std::move(function)) {}

    /// Delegates the argument count check to the wrapped function.
    void checkNumberOfArguments(size_t number_of_arguments) const override
    {
        function->checkNumberOfArguments(number_of_arguments);
    }

    String getName() const override { return function->getName(); }
    bool isVariadic() const override { return function->isVariadic(); }
    size_t getNumberOfArguments() const override { return function->getNumberOfArguments(); }

protected:
    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { return function->getReturnTypeImpl(arguments); }
    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override { return function->getReturnTypeImpl(arguments); }

    bool useDefaultImplementationForNulls() const override { return function->useDefaultImplementationForNulls(); }

    /// Wraps the function into a DefaultFunction, remembering the argument types and the given return type.
    FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const override
    {
        DataTypes data_types(arguments.size());
        for (size_t i = 0; i < arguments.size(); ++i)
            data_types[i] = arguments[i].type;

        /// data_types is not used afterwards, so hand it over instead of copying it.
        return std::make_shared<DefaultFunction>(function, std::move(data_types), return_type);
    }

    /// Delegates lambda argument type deduction to the wrapped function.
    void getLambdaArgumentTypesImpl(DataTypes & arguments) const override
    {
        function->getLambdaArgumentTypesImpl(arguments);
    }

private:
    /// The wrapped function; every call above is delegated to it.
    std::shared_ptr<IFunction> function;
};
using FunctionPtr = std::shared_ptr<IFunction>; using FunctionPtr = std::shared_ptr<IFunction>;
} }

View File

@ -36,7 +36,6 @@ Names ExpressionAction::getNeededColumns() const
{ {
Names res = argument_names; Names res = argument_names;
res.insert(res.end(), prerequisite_names.begin(), prerequisite_names.end());
res.insert(res.end(), array_joined_columns.begin(), array_joined_columns.end()); res.insert(res.end(), array_joined_columns.begin(), array_joined_columns.end());
for (const auto & column : projection) for (const auto & column : projection)
@ -49,7 +48,7 @@ Names ExpressionAction::getNeededColumns() const
} }
ExpressionAction ExpressionAction::applyFunction(const FunctionPtr & function_, ExpressionAction ExpressionAction::applyFunction(const FunctionBuilderPtr & function_,
const std::vector<std::string> & argument_names_, const std::vector<std::string> & argument_names_,
std::string result_name_) std::string result_name_)
{ {
@ -68,7 +67,7 @@ ExpressionAction ExpressionAction::applyFunction(const FunctionPtr & function_,
ExpressionAction a; ExpressionAction a;
a.type = APPLY_FUNCTION; a.type = APPLY_FUNCTION;
a.result_name = result_name_; a.result_name = result_name_;
a.function = function_; a.function_builder = function_;
a.argument_names = argument_names_; a.argument_names = argument_names_;
return a; return a;
} }
@ -128,7 +127,7 @@ ExpressionAction ExpressionAction::arrayJoin(const NameSet & array_joined_column
a.array_join_is_left = array_join_is_left; a.array_join_is_left = array_join_is_left;
if (array_join_is_left) if (array_join_is_left)
a.function = FunctionFactory::instance().get("emptyArrayToSingle", context); a.function_builder = FunctionFactory::instance().get("emptyArrayToSingle", context);
return a; return a;
} }
@ -160,13 +159,8 @@ ExpressionActions::Actions ExpressionAction::getPrerequisites(Block & sample_blo
arguments[i] = sample_block.getByName(argument_names[i]); arguments[i] = sample_block.getByName(argument_names[i]);
} }
function->getReturnTypeAndPrerequisites(arguments, result_type, res); function = function_builder->build(arguments);
result_type = function->getReturnType();
for (size_t i = 0; i < res.size(); ++i)
{
if (res[i].result_name != "")
prerequisite_names.push_back(res[i].result_name);
}
} }
return res; return res;
@ -201,15 +195,6 @@ void ExpressionAction::prepare(Block & sample_block)
all_const = false; all_const = false;
} }
ColumnNumbers prerequisites(prerequisite_names.size());
for (size_t i = 0; i < prerequisite_names.size(); ++i)
{
prerequisites[i] = sample_block.getPositionByName(prerequisite_names[i]);
ColumnPtr col = sample_block.safeGetByPosition(prerequisites[i]).column;
if (!col || !col->isColumnConst())
all_const = false;
}
ColumnPtr new_column; ColumnPtr new_column;
/// If all arguments are constants, and function is suitable to be executed in 'prepare' stage - execute function. /// If all arguments are constants, and function is suitable to be executed in 'prepare' stage - execute function.
@ -222,7 +207,7 @@ void ExpressionAction::prepare(Block & sample_block)
new_column.type = result_type; new_column.type = result_type;
sample_block.insert(std::move(new_column)); sample_block.insert(std::move(new_column));
function->execute(sample_block, arguments, prerequisites, result_position); function->execute(sample_block, arguments, result_position);
/// If the result is not a constant, just in case, we will consider the result as unknown. /// If the result is not a constant, just in case, we will consider the result as unknown.
ColumnWithTypeAndName & col = sample_block.safeGetByPosition(result_position); ColumnWithTypeAndName & col = sample_block.safeGetByPosition(result_position);
@ -343,19 +328,11 @@ void ExpressionAction::execute(Block & block) const
arguments[i] = block.getPositionByName(argument_names[i]); arguments[i] = block.getPositionByName(argument_names[i]);
} }
ColumnNumbers prerequisites(prerequisite_names.size());
for (size_t i = 0; i < prerequisite_names.size(); ++i)
{
if (!block.has(prerequisite_names[i]))
throw Exception("Not found column: '" + prerequisite_names[i] + "'", ErrorCodes::NOT_FOUND_COLUMN_IN_BLOCK);
prerequisites[i] = block.getPositionByName(prerequisite_names[i]);
}
size_t num_columns_without_result = block.columns(); size_t num_columns_without_result = block.columns();
block.insert({ nullptr, result_type, result_name}); block.insert({ nullptr, result_type, result_name});
ProfileEvents::increment(ProfileEvents::FunctionExecute); ProfileEvents::increment(ProfileEvents::FunctionExecute);
function->execute(block, arguments, prerequisites, num_columns_without_result); function->execute(block, arguments, num_columns_without_result);
break; break;
} }
@ -383,7 +360,7 @@ void ExpressionAction::execute(Block & block) const
Block tmp_block{src_col, {{}, src_col.type, {}}}; Block tmp_block{src_col, {{}, src_col.type, {}}};
function->execute(tmp_block, {0}, 1); function_builder->build({src_col})->execute(tmp_block, {0}, 1);
non_empty_array_columns[name] = tmp_block.safeGetByPosition(1).column; non_empty_array_columns[name] = tmp_block.safeGetByPosition(1).column;
} }
@ -837,6 +814,7 @@ void ExpressionActions::finalize(const Names & output_columns)
action.type = ExpressionAction::ADD_COLUMN; action.type = ExpressionAction::ADD_COLUMN;
action.result_type = result.type; action.result_type = result.type;
action.added_column = result.column; action.added_column = result.column;
action.function_builder = nullptr;
action.function = nullptr; action.function = nullptr;
action.argument_names.clear(); action.argument_names.clear();
in.clear(); in.clear();
@ -889,9 +867,6 @@ void ExpressionActions::finalize(const Names & output_columns)
for (const auto & name : action.argument_names) for (const auto & name : action.argument_names)
++columns_refcount[name]; ++columns_refcount[name];
for (const auto & name : action.prerequisite_names)
++columns_refcount[name];
for (const auto & name_alias : action.projection) for (const auto & name_alias : action.projection)
++columns_refcount[name_alias.first]; ++columns_refcount[name_alias.first];
} }
@ -920,9 +895,6 @@ void ExpressionActions::finalize(const Names & output_columns)
for (const auto & name : action.argument_names) for (const auto & name : action.argument_names)
process(name); process(name);
for (const auto & name : action.prerequisite_names)
process(name);
/// For `projection`, there is no reduction in `refcount`, because the `project` action replaces the names of the columns, in effect, already deleting them under the old names. /// For `projection`, there is no reduction in `refcount`, because the `project` action replaces the names of the columns, in effect, already deleting them under the old names.
} }

View File

@ -22,8 +22,11 @@ using NamesWithAliases = std::vector<NameWithAlias>;
class Join; class Join;
class IFunction; class IFunctionBase;
using FunctionPtr = std::shared_ptr<IFunction>; using FunctionBasePtr = std::shared_ptr<IFunctionBase>;
class IFunctionBuilder;
using FunctionBuilderPtr = std::shared_ptr<IFunctionBuilder>;
class IDataType; class IDataType;
using DataTypePtr = std::shared_ptr<const IDataType>; using DataTypePtr = std::shared_ptr<const IDataType>;
@ -68,9 +71,9 @@ public:
ColumnPtr added_column; ColumnPtr added_column;
/// For APPLY_FUNCTION and LEFT ARRAY JOIN. /// For APPLY_FUNCTION and LEFT ARRAY JOIN.
mutable FunctionPtr function; /// mutable - to allow execute. FunctionBuilderPtr function_builder;
FunctionBasePtr function;
Names argument_names; Names argument_names;
Names prerequisite_names;
/// For ARRAY_JOIN /// For ARRAY_JOIN
NameSet array_joined_columns; NameSet array_joined_columns;
@ -85,7 +88,7 @@ public:
/// If result_name_ == "", as name "function_name(arguments separated by commas) is used". /// If result_name_ == "", as name "function_name(arguments separated by commas) is used".
static ExpressionAction applyFunction( static ExpressionAction applyFunction(
const FunctionPtr & function_, const std::vector<std::string> & argument_names_, std::string result_name_ = ""); const FunctionBuilderPtr & function_, const std::vector<std::string> & argument_names_, std::string result_name_ = "");
static ExpressionAction addColumn(const ColumnWithTypeAndName & added_column_); static ExpressionAction addColumn(const ColumnWithTypeAndName & added_column_);
static ExpressionAction removeColumn(const std::string & removed_name); static ExpressionAction removeColumn(const std::string & removed_name);

View File

@ -18,12 +18,10 @@
#include <DataTypes/DataTypeArray.h> #include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeNullable.h> #include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeTuple.h> #include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeExpression.h>
#include <DataTypes/NestedUtils.h> #include <DataTypes/NestedUtils.h>
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnSet.h> #include <Columns/ColumnSet.h>
#include <Columns/ColumnExpression.h>
#include <Columns/ColumnConst.h> #include <Columns/ColumnConst.h>
#include <Interpreters/InterpreterSelectQuery.h> #include <Interpreters/InterpreterSelectQuery.h>
@ -58,6 +56,8 @@
#include <ext/range.h> #include <ext/range.h>
#include <DataTypes/DataTypeFactory.h> #include <DataTypes/DataTypeFactory.h>
#include <DataTypes/DataTypeFunction.h>
#include <Functions/FunctionsMiscellaneous.h>
namespace DB namespace DB
@ -1867,11 +1867,9 @@ struct ExpressionAnalyzer::ScopeStack
throw Exception("Unknown identifier: " + name, ErrorCodes::UNKNOWN_IDENTIFIER); throw Exception("Unknown identifier: " + name, ErrorCodes::UNKNOWN_IDENTIFIER);
} }
void addAction(const ExpressionAction & action, const Names & additional_required_columns = Names()) void addAction(const ExpressionAction & action)
{ {
size_t level = 0; size_t level = 0;
for (size_t i = 0; i < additional_required_columns.size(); ++i)
level = std::max(level, getColumnLevel(additional_required_columns[i]));
Names required = action.getNeededColumns(); Names required = action.getNeededColumns();
for (size_t i = 0; i < required.size(); ++i) for (size_t i = 0; i < required.size(); ++i)
level = std::max(level, getColumnLevel(required[i])); level = std::max(level, getColumnLevel(required[i]));
@ -2104,7 +2102,7 @@ void ExpressionAnalyzer::getActionsImpl(const ASTPtr & ast, bool no_subqueries,
return; return;
} }
const FunctionPtr & function = FunctionFactory::instance().get(node->name, context); const FunctionBuilderPtr & function_builder = FunctionFactory::instance().get(node->name, context);
Names argument_names; Names argument_names;
DataTypes argument_types; DataTypes argument_types;
@ -2128,7 +2126,7 @@ void ExpressionAnalyzer::getActionsImpl(const ASTPtr & ast, bool no_subqueries,
throw Exception("First argument of lambda must be a tuple", ErrorCodes::TYPE_MISMATCH); throw Exception("First argument of lambda must be a tuple", ErrorCodes::TYPE_MISMATCH);
has_lambda_arguments = true; has_lambda_arguments = true;
argument_types.emplace_back(std::make_shared<DataTypeExpression>(DataTypes(lambda_args_tuple->arguments->children.size()))); argument_types.emplace_back(std::make_shared<DataTypeFunction>(DataTypes(lambda_args_tuple->arguments->children.size())));
/// Select the name in the next cycle. /// Select the name in the next cycle.
argument_names.emplace_back(); argument_names.emplace_back();
} }
@ -2183,11 +2181,9 @@ void ExpressionAnalyzer::getActionsImpl(const ASTPtr & ast, bool no_subqueries,
if (only_consts && !arguments_present) if (only_consts && !arguments_present)
return; return;
Names additional_requirements;
if (has_lambda_arguments && !only_consts) if (has_lambda_arguments && !only_consts)
{ {
function->getLambdaArgumentTypes(argument_types); function_builder->getLambdaArgumentTypes(argument_types);
/// Call recursively for lambda expressions. /// Call recursively for lambda expressions.
for (size_t i = 0; i < node->arguments->children.size(); ++i) for (size_t i = 0; i < node->arguments->children.size(); ++i)
@ -2197,7 +2193,7 @@ void ExpressionAnalyzer::getActionsImpl(const ASTPtr & ast, bool no_subqueries,
ASTFunction * lambda = typeid_cast<ASTFunction *>(child.get()); ASTFunction * lambda = typeid_cast<ASTFunction *>(child.get());
if (lambda && lambda->name == "lambda") if (lambda && lambda->name == "lambda")
{ {
const DataTypeExpression * lambda_type = typeid_cast<const DataTypeExpression *>(argument_types[i].get()); const DataTypeFunction * lambda_type = typeid_cast<const DataTypeFunction *>(argument_types[i].get());
ASTFunction * lambda_args_tuple = typeid_cast<ASTFunction *>(lambda->arguments->children.at(0).get()); ASTFunction * lambda_args_tuple = typeid_cast<ASTFunction *>(lambda->arguments->children.at(0).get());
ASTs lambda_arg_asts = lambda_args_tuple->arguments->children; ASTs lambda_arg_asts = lambda_args_tuple->arguments->children;
NamesAndTypesList lambda_arguments; NamesAndTypesList lambda_arguments;
@ -2220,22 +2216,23 @@ void ExpressionAnalyzer::getActionsImpl(const ASTPtr & ast, bool no_subqueries,
String result_name = lambda->arguments->children.at(1)->getColumnName(); String result_name = lambda->arguments->children.at(1)->getColumnName();
lambda_actions->finalize(Names(1, result_name)); lambda_actions->finalize(Names(1, result_name));
DataTypePtr result_type = lambda_actions->getSampleBlock().getByName(result_name).type; DataTypePtr result_type = lambda_actions->getSampleBlock().getByName(result_name).type;
argument_types[i] = std::make_shared<DataTypeExpression>(lambda_type->getArgumentTypes(), result_type);
Names captured = lambda_actions->getRequiredColumns(); Names captured;
for (size_t j = 0; j < captured.size(); ++j) Names required = lambda_actions->getRequiredColumns();
if (findColumn(captured[j], lambda_arguments) == lambda_arguments.end()) for (size_t j = 0; j < required.size(); ++j)
additional_requirements.push_back(captured[j]); if (findColumn(required[j], lambda_arguments) == lambda_arguments.end())
captured.push_back(required[j]);
/// We cannot use `getColumnName()` as the name, /// We cannot use `getColumnName()` as the name,
/// because it does not uniquely define the expression (the types of arguments can be different). /// because it does not uniquely define the expression (the types of arguments can be different).
argument_names[i] = getUniqueName(actions_stack.getSampleBlock(), "__lambda"); String lambda_name = getUniqueName(actions_stack.getSampleBlock(), "__lambda");
ColumnWithTypeAndName lambda_column; auto function_capture = std::make_shared<FunctionCapture>(
lambda_column.column = ColumnExpression::create(1, lambda_actions, lambda_arguments, result_type, result_name); lambda_actions, captured, lambda_arguments, result_type, result_name);
lambda_column.type = argument_types[i]; actions_stack.addAction(ExpressionAction::applyFunction(function_capture, captured, lambda_name));
lambda_column.name = argument_names[i];
actions_stack.addAction(ExpressionAction::addColumn(lambda_column)); argument_types[i] = std::make_shared<DataTypeFunction>(lambda_type->getArgumentTypes(), result_type);
argument_names[i] = lambda_name;
} }
} }
} }
@ -2253,8 +2250,7 @@ void ExpressionAnalyzer::getActionsImpl(const ASTPtr & ast, bool no_subqueries,
} }
if (arguments_present) if (arguments_present)
actions_stack.addAction(ExpressionAction::applyFunction(function, argument_names, node->getColumnName()), actions_stack.addAction(ExpressionAction::applyFunction(function_builder, argument_names, node->getColumnName()));
additional_requirements);
} }
} }
else if (ASTLiteral * node = typeid_cast<ASTLiteral *>(ast.get())) else if (ASTLiteral * node = typeid_cast<ASTLiteral *>(ast.get()))

View File

@ -20,6 +20,9 @@ class FieldWithInfinity;
using SetElements = std::vector<std::vector<Field>>; using SetElements = std::vector<std::vector<Field>>;
using SetElementsPtr = std::unique_ptr<SetElements>; using SetElementsPtr = std::unique_ptr<SetElements>;
class IFunctionBase;
using FunctionBasePtr = std::shared_ptr<IFunctionBase>;
/** Data structure for implementation of IN expression. /** Data structure for implementation of IN expression.
*/ */
class Set class Set
@ -175,7 +178,7 @@ public:
{ {
size_t tuple_index; size_t tuple_index;
size_t pk_index; size_t pk_index;
std::vector<FunctionPtr> functions; std::vector<FunctionBasePtr> functions;
DataTypePtr data_type; DataTypePtr data_type;
}; };

View File

@ -24,16 +24,10 @@ ColumnPtr castColumn(const ColumnWithTypeAndName & arg, const DataTypePtr & type
} }
}; };
FunctionPtr func_cast = FunctionFactory::instance().get("CAST", context); FunctionBuilderPtr func_builder_cast = FunctionFactory::instance().get("CAST", context);
{ ColumnsWithTypeAndName arguments{ temporary_block.getByPosition(0), temporary_block.getByPosition(1) };
DataTypePtr unused_return_type; auto func_cast = func_builder_cast->build(arguments);
ColumnsWithTypeAndName arguments{ temporary_block.getByPosition(0), temporary_block.getByPosition(1) };
std::vector<ExpressionAction> unused_prerequisites;
/// Prepares function to execution. TODO It is not obvious.
func_cast->getReturnTypeAndPrerequisites(arguments, unused_return_type, unused_prerequisites);
}
func_cast->execute(temporary_block, {0, 1}, 2); func_cast->execute(temporary_block, {0, 1}, 2);
return temporary_block.getByPosition(2).column; return temporary_block.getByPosition(2).column;

View File

@ -950,7 +950,7 @@ void MergeTreeData::createConvertExpression(const DataPartPtr & part, const Name
out_expression->add(ExpressionAction::addColumn( out_expression->add(ExpressionAction::addColumn(
{ DataTypeString().createColumnConst(1, new_type_name), std::make_shared<DataTypeString>(), new_type_name_column })); { DataTypeString().createColumnConst(1, new_type_name), std::make_shared<DataTypeString>(), new_type_name_column }));
const FunctionPtr & function = FunctionFactory::instance().get("CAST", context); const auto & function = FunctionFactory::instance().get("CAST", context);
out_expression->add(ExpressionAction::applyFunction( out_expression->add(ExpressionAction::applyFunction(
function, Names{column.name, new_type_name_column}), out_names); function, Names{column.name, new_type_name_column}), out_names);

View File

@ -342,17 +342,15 @@ static bool getConstant(const ASTPtr & expr, Block & block_with_constants, Field
static void applyFunction( static void applyFunction(
FunctionPtr & func, const FunctionBasePtr & func,
const DataTypePtr & arg_type, const Field & arg_value, const DataTypePtr & arg_type, const Field & arg_value,
DataTypePtr & res_type, Field & res_value) DataTypePtr & res_type, Field & res_value)
{ {
std::vector<ExpressionAction> unused_prerequisites; res_type = func->getReturnType();
ColumnsWithTypeAndName arguments{{ arg_type->createColumnConst(1, arg_value), arg_type, "x" }};
func->getReturnTypeAndPrerequisites(arguments, res_type, unused_prerequisites);
Block block Block block
{ {
arguments[0], { arg_type->createColumnConst(1, arg_value), arg_type, "x" },
{ nullptr, res_type, "y" } { nullptr, res_type, "y" }
}; };
@ -526,14 +524,14 @@ bool PKCondition::isPrimaryKeyPossiblyWrappedByMonotonicFunctions(
for (auto it = chain_not_tested_for_monotonicity.rbegin(); it != chain_not_tested_for_monotonicity.rend(); ++it) for (auto it = chain_not_tested_for_monotonicity.rbegin(); it != chain_not_tested_for_monotonicity.rend(); ++it)
{ {
FunctionPtr func = FunctionFactory::instance().tryGet((*it)->name, context); auto func_builder = FunctionFactory::instance().tryGet((*it)->name, context);
ColumnsWithTypeAndName arguments{{ nullptr, primary_key_column_type, "" }};
auto func = func_builder->build(arguments);
if (!func || !func->hasInformationAboutMonotonicity()) if (!func || !func->hasInformationAboutMonotonicity())
return false; return false;
std::vector<ExpressionAction> unused_prerequisites; primary_key_column_type = func->getReturnType();
ColumnsWithTypeAndName arguments{{ nullptr, primary_key_column_type, "" }};
func->getReturnTypeAndPrerequisites(arguments, primary_key_column_type, unused_prerequisites);
out_functions_chain.push_back(func); out_functions_chain.push_back(func);
} }

View File

@ -17,6 +17,9 @@
namespace DB namespace DB
{ {
/// Forward declarations: only shared pointers to these types are used here,
/// so the full definitions are not required.
class IFunction;
/// IFunctionBase must be declared before the alias below refers to it;
/// previously only IFunction was declared, so the alias relied on a transitive include.
class IFunctionBase;
/// Shared-ownership handle to a function object.
using FunctionBasePtr = std::shared_ptr<IFunctionBase>;
/** Range with open or closed ends; possibly unbounded. /** Range with open or closed ends; possibly unbounded.
*/ */
struct Range struct Range
@ -296,7 +299,7 @@ public:
* If the primary key column is wrapped in functions that can be monotonic in some value ranges * If the primary key column is wrapped in functions that can be monotonic in some value ranges
* (for example: -toFloat64(toDayOfWeek(date))), then here the functions will be located: toDayOfWeek, toFloat64, negate. * (for example: -toFloat64(toDayOfWeek(date))), then here the functions will be located: toDayOfWeek, toFloat64, negate.
*/ */
using MonotonicFunctionsChain = std::vector<FunctionPtr>; using MonotonicFunctionsChain = std::vector<FunctionBasePtr>;
mutable MonotonicFunctionsChain monotonic_functions_chain; /// The function execution does not violate the constancy. mutable MonotonicFunctionsChain monotonic_functions_chain; /// The function execution does not violate the constancy.
}; };