Added generic variants of least and greatest functions #4767

This commit is contained in:
Alexey Milovidov 2020-04-17 01:28:08 +03:00
parent 475ab6feef
commit 1df5c7cedf
4 changed files with 143 additions and 3 deletions

View File

@ -44,7 +44,7 @@ public:
/// Name of a Column kind, without parameters (example: FixedString, Array). /// Name of a Column kind, without parameters (example: FixedString, Array).
virtual const char * getFamilyName() const = 0; virtual const char * getFamilyName() const = 0;
/** If column isn't constant, returns nullptr (or itself). /** If column isn't constant, returns itself.
* If column is constant, transforms constant to full column (if column type allows such transform) and return it. * If column is constant, transforms constant to full column (if column type allows such transform) and return it.
*/ */
virtual Ptr convertToFullColumnIfConst() const { return getPtr(); } virtual Ptr convertToFullColumnIfConst() const { return getPtr(); }

View File

@ -0,0 +1,136 @@
#pragma once
#include <DataTypes/getLeastSupertype.h>
#include <Interpreters/castColumn.h>
#include <Columns/ColumnsNumber.h>
#include <Functions/IFunctionImpl.h>
#include <Functions/FunctionFactory.h>
#include <ext/map.h>
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
enum class LeastGreatest
{
Least,
Greatest
};
template <LeastGreatest kind>
class FunctionLeastGreatestGeneric : public IFunction
{
public:
static constexpr auto name = kind == LeastGreatest::Least ? "least" : "greatest";
static FunctionPtr create(const Context &) { return std::make_shared<FunctionLeastGreatestGeneric<kind>>(); }
private:
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 0; }
bool isVariadic() const override { return true; }
bool useDefaultImplementationForConstants() const override { return true; }
DataTypePtr getReturnTypeImpl(const DataTypes & types) const override
{
if (types.empty())
throw Exception("Function " + getName() + " cannot be called without arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
return getLeastSupertype(types);
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
{
size_t num_arguments = arguments.size();
if (1 == num_arguments)
{
block.getByPosition(result).column = block.getByPosition(arguments[0]).column;
return;
}
auto result_type = block.getByPosition(result).type;
Columns converted_columns(num_arguments);
for (size_t arg = 0; arg < num_arguments; ++arg)
converted_columns[arg] = castColumn(block.getByPosition(arguments[arg]), result_type)->convertToFullColumnIfConst();
auto result_column = result_type->createColumn();
result_column->reserve(input_rows_count);
for (size_t row_num = 0; row_num < input_rows_count; ++row_num)
{
size_t best_arg = 0;
for (size_t arg = 1; arg < num_arguments; ++arg)
{
auto cmp_result = converted_columns[arg]->compareAt(row_num, row_num, *converted_columns[best_arg], 1);
if constexpr (kind == LeastGreatest::Least)
{
if (cmp_result < 0)
best_arg = arg;
}
else
{
if (cmp_result > 0)
best_arg = arg;
}
}
result_column->insertFrom(*converted_columns[best_arg], row_num);
}
block.getByPosition(result).column = std::move(result_column);
}
};
template <LeastGreatest kind, typename SpecializedFunction>
class LeastGreatestOverloadResolver : public IFunctionOverloadResolverImpl
{
public:
static constexpr auto name = kind == LeastGreatest::Least ? "least" : "greatest";
static FunctionOverloadResolverImplPtr create(const Context & context)
{
return std::make_unique<LeastGreatestOverloadResolver<kind, SpecializedFunction>>(context);
}
explicit LeastGreatestOverloadResolver(const Context & context_) : context(context_) {}
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 0; }
bool isVariadic() const override { return true; }
FunctionBaseImplPtr build(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const override
{
DataTypes argument_types;
/// More efficient specialization for two numeric arguments.
if (arguments.size() == 2 && isNumber(arguments[0].type) && isNumber(arguments[1].type))
return std::make_unique<DefaultFunction>(SpecializedFunction::create(context), argument_types, return_type);
return std::make_unique<DefaultFunction>(
FunctionLeastGreatestGeneric<kind>::create(context), argument_types, return_type);
}
DataTypePtr getReturnType(const DataTypes & types) const override
{
if (types.empty())
throw Exception("Function " + getName() + " cannot be called without arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
return getLeastSupertype(types);
}
private:
const Context & context;
};
}

View File

@ -1,6 +1,8 @@
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Functions/FunctionBinaryArithmetic.h> #include <Functions/FunctionBinaryArithmetic.h>
#include <Core/AccurateComparison.h> #include <Core/AccurateComparison.h>
#include <Functions/LeastGreatestGeneric.h>
namespace DB namespace DB
{ {
@ -57,7 +59,7 @@ using FunctionGreatest = FunctionBinaryArithmetic<GreatestImpl, NameGreatest>;
void registerFunctionGreatest(FunctionFactory & factory) void registerFunctionGreatest(FunctionFactory & factory)
{ {
factory.registerFunction<FunctionGreatest>(FunctionFactory::CaseInsensitive); factory.registerFunction<LeastGreatestOverloadResolver<LeastGreatest::Greatest, FunctionGreatest>>(FunctionFactory::CaseInsensitive);
} }
} }

View File

@ -1,6 +1,8 @@
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Functions/FunctionBinaryArithmetic.h> #include <Functions/FunctionBinaryArithmetic.h>
#include <Core/AccurateComparison.h> #include <Core/AccurateComparison.h>
#include <Functions/LeastGreatestGeneric.h>
namespace DB namespace DB
{ {
@ -57,7 +59,7 @@ using FunctionLeast = FunctionBinaryArithmetic<LeastImpl, NameLeast>;
void registerFunctionLeast(FunctionFactory & factory) void registerFunctionLeast(FunctionFactory & factory)
{ {
factory.registerFunction<FunctionLeast>(FunctionFactory::CaseInsensitive); factory.registerFunction<LeastGreatestOverloadResolver<LeastGreatest::Least, FunctionLeast>>(FunctionFactory::CaseInsensitive);
} }
} }