added impl

This commit is contained in:
yariks5s 2023-07-26 15:00:25 +00:00
parent 6205218e2b
commit f8200e50cb

View File

@ -42,6 +42,15 @@
#include <Common/assert_cast.h>
#include <Common/typeid_cast.h>
#include <Common/Arena.h>
#include <Core/ColumnWithTypeAndName.h>
#include <base/types.h>
#include <Columns/ColumnArray.h>
#include <Columns/IColumn.h>
#include <Core/ColumnsWithTypeAndName.h>
#include <DataTypes/IDataType.h>
#include <DataTypes/getMostSubtype.h>
#include <base/TypeLists.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <Interpreters/Context.h>
@ -102,6 +111,9 @@ template <typename DataType> constexpr bool IsFloatingPoint = false;
template <> inline constexpr bool IsFloatingPoint<DataTypeFloat32> = true;
template <> inline constexpr bool IsFloatingPoint<DataTypeFloat64> = true;
template <typename DataType> constexpr bool IsArray = false;
template <> inline constexpr bool IsArray<DataTypeArray> = true;
template <typename DataType> constexpr bool IsDateOrDateTime = false;
template <> inline constexpr bool IsDateOrDateTime<DataTypeDate> = true;
template <> inline constexpr bool IsDateOrDateTime<DataTypeDateTime> = true;
@ -1125,6 +1137,92 @@ class FunctionBinaryArithmetic : public IFunction
return function->execute(arguments, result_type, input_rows_count);
}
template <typename ColumnType>
ColumnPtr executeArrayPlusMinus(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type,
size_t input_rows_count, const FunctionOverloadResolverPtr & function_builder) const
{
auto function = function_builder->build(arguments);
return function->execute(arguments, result_type, input_rows_count);
}
static ColumnPtr callFunctionNotEquals(ColumnWithTypeAndName first, ColumnWithTypeAndName second, ContextPtr context)
{
ColumnsWithTypeAndName args{first, second};
auto eq_func = FunctionFactory::instance().get("notEquals", context)->build(args);
return eq_func->execute(args, eq_func->getResultType(), args.front().column->size());
}
ColumnPtr executeArrayImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
{
bool is_const = false;
const auto * return_type_array = checkAndGetDataType<DataTypeArray>(result_type.get());
if (!return_type_array)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Return type for function {} must be array.", getName());
ColumnPtr result_column = executeArray<IColumn>(arguments, result_type, input_rows_count);
if (arguments[0].dumpStructure().contains("Const"))
is_const = true;
if (is_const)
return result_column;
else
return ColumnArray::create(result_column, typeid_cast<const ColumnArray *>(arguments[0].column.get())->getOffsetsPtr());
}
template <typename ColumnType>
ColumnPtr executeArray(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
{
if constexpr (is_multiply || is_division)
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Cannot use multiplication or division on arrays");
auto num_args = arguments.size();
DataTypes data_types;
ColumnsWithTypeAndName new_arguments {num_args};
DataTypePtr t;
const auto * left_const = typeid_cast<const ColumnConst *>(arguments[0].column.get());
const auto * right_const = typeid_cast<const ColumnConst *>(arguments[1].column.get());
/// Unpacking arrays if both are constants.
if (left_const && right_const)
{
new_arguments[0] = {left_const->getDataColumnPtr(), arguments[0].type, arguments[0].name};
new_arguments[1] = {right_const->getDataColumnPtr(), arguments[1].type, arguments[1].name};
auto col = executeImpl(new_arguments, result_type, 1);
return ColumnConst::create(std::move(col), input_rows_count);
}
/// Unpacking arrays if at least one column is constant.
if (left_const || right_const)
{
new_arguments[0] = {arguments[0].column->convertToFullColumnIfConst(), arguments[0].type, arguments[0].name};
new_arguments[1] = {arguments[1].column->convertToFullColumnIfConst(), arguments[1].type, arguments[1].name};
return executeImpl(new_arguments, result_type, input_rows_count);
}
/// Unpacking non-const arrays and checking sizes of them.
UInt64 data = 0;
for (size_t i = 0; i < num_args; ++i)
{
auto a = typeid_cast<const ColumnArray *>(arguments[i].column.get())->getData().getPtr();
if (i == 0)
data = *typeid_cast<const ColumnArray *>(arguments[i].column.get())->getOffsets().data();
else
{
if (*typeid_cast<const ColumnArray *>(arguments[i].column.get())->getOffsets().data() != data)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Arguments must be one size");
}
t = typeid_cast<const DataTypeArray *>(arguments[i].type.get())->getNestedType();
new_arguments[i] = {a, t, arguments[i].name};
}
return executeImpl(new_arguments, t, input_rows_count);
}
ColumnPtr executeTupleNumberOperator(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type,
size_t input_rows_count, const FunctionOverloadResolverPtr & function_builder) const
{
@ -1326,6 +1424,16 @@ public:
return getReturnTypeImplStatic(new_arguments, context);
}
if (isArray(arguments[0]) || isArray(arguments[1]))
{
DataTypes new_arguments {
static_cast<const DataTypeArray &>(*arguments[0]).getNestedType(),
static_cast<const DataTypeArray &>(*arguments[1]).getNestedType(),
};
return std::make_shared<DataTypeArray>(getReturnTypeImplStatic(new_arguments, context));
}
/// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
if (auto function_builder = getFunctionForIntervalArithmetic(arguments[0], arguments[1], context))
{
@ -2031,6 +2139,9 @@ ColumnPtr executeStringInteger(const ColumnsWithTypeAndName & arguments, const A
return (res = executeNumeric(arguments, left, right, right_nullmap)) != nullptr;
});
if (isArray(result_type))
return executeArrayImpl(arguments, result_type, input_rows_count);
if (!valid)
{
// This is a logical error, because the types should have been checked