Merge pull request #54608 from yariks5s/arr_scalar_mult_div_modulo

Added support for array&scalar operations
This commit is contained in:
robot-clickhouse 2023-09-18 17:26:42 +02:00 committed by GitHub
commit 6564743794
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 150 additions and 3 deletions

View File

@ -1156,12 +1156,12 @@ class FunctionBinaryArithmetic : public IFunction
return function->execute(arguments, result_type, input_rows_count);
}
ColumnPtr executeArrayImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
ColumnPtr executeArraysImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
{
const auto * return_type_array = checkAndGetDataType<DataTypeArray>(result_type.get());
if (!return_type_array)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Return type for function {} must be array.", getName());
throw Exception(ErrorCodes::LOGICAL_ERROR, "Return type for function {} must be array", getName());
auto num_args = arguments.size();
DataTypes data_types;
@ -1211,6 +1211,72 @@ class FunctionBinaryArithmetic : public IFunction
return ColumnArray::create(res, typeid_cast<const ColumnArray *>(arguments[0].column.get())->getOffsetsPtr());
}
ColumnPtr executeArrayWithNumericImpl(const ColumnsWithTypeAndName & args, const DataTypePtr & result_type, size_t input_rows_count) const
{
ColumnsWithTypeAndName arguments = args;
bool is_swapped = isNumber(args[0].type); /// Defines the order of arguments (If array is first argument - is_swapped = false)
const auto * return_type_array = checkAndGetDataType<DataTypeArray>(result_type.get());
if (!return_type_array)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Return type for function {} must be array", getName());
auto num_args = arguments.size();
DataTypes data_types;
ColumnsWithTypeAndName new_arguments {num_args};
DataTypePtr result_array_type;
const auto * left_const = typeid_cast<const ColumnConst *>(arguments[0].column.get());
const auto * right_const = typeid_cast<const ColumnConst *>(arguments[1].column.get());
if (left_const && right_const)
{
new_arguments[0] = {left_const->getDataColumnPtr(), arguments[0].type, arguments[0].name};
new_arguments[1] = {right_const->getDataColumnPtr(), arguments[1].type, arguments[1].name};
auto col = executeImpl(new_arguments, result_type, 1);
return ColumnConst::create(std::move(col), input_rows_count);
}
if (right_const && is_swapped)
{
new_arguments[0] = {arguments[0].column.get()->getPtr(), arguments[0].type, arguments[0].name};
new_arguments[1] = {right_const->convertToFullColumnIfConst(), arguments[1].type, arguments[1].name};
return executeImpl(new_arguments, result_type, input_rows_count);
}
else if (left_const && !is_swapped)
{
new_arguments[0] = {left_const->convertToFullColumnIfConst(), arguments[0].type, arguments[0].name};
new_arguments[1] = {arguments[1].column.get()->getPtr(), arguments[1].type, arguments[1].name};
return executeImpl(new_arguments, result_type, input_rows_count);
}
if (is_swapped)
std::swap(arguments[1], arguments[0]);
const auto * left_array_col = typeid_cast<const ColumnArray *>(arguments[0].column.get());
const auto & left_array_elements_type = typeid_cast<const DataTypeArray *>(arguments[0].type.get())->getNestedType();
const auto & right_col = arguments[1].column.get()->cloneResized(left_array_col->size());
size_t rows_count = 0;
const auto & left_offsets = left_array_col->getOffsets();
if (!left_offsets.empty())
rows_count = left_offsets.back();
new_arguments[0] = {left_array_col->getDataPtr(), left_array_elements_type, arguments[0].name};
if (right_const)
new_arguments[1] = {right_col->cloneResized(rows_count), arguments[1].type, arguments[1].name};
else
new_arguments[1] = {right_col->replicate(left_array_col->getOffsets()), arguments[1].type, arguments[1].name};
result_array_type = left_array_elements_type;
if (is_swapped)
std::swap(new_arguments[1], new_arguments[0]);
auto res = executeImpl(new_arguments, result_array_type, rows_count);
return ColumnArray::create(res, left_array_col->getOffsetsPtr());
}
ColumnPtr executeTupleNumberOperator(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type,
size_t input_rows_count, const FunctionOverloadResolverPtr & function_builder) const
{
@ -1425,6 +1491,25 @@ public:
}
}
if constexpr (is_multiply || is_division)
{
if (isArray(arguments[0]) && isNumber(arguments[1]))
{
DataTypes new_arguments {
static_cast<const DataTypeArray &>(*arguments[0]).getNestedType(),
arguments[1],
};
return std::make_shared<DataTypeArray>(getReturnTypeImplStatic(new_arguments, context));
}
if (isNumber(arguments[0]) && isArray(arguments[1]))
{
DataTypes new_arguments {
arguments[0],
static_cast<const DataTypeArray &>(*arguments[1]).getNestedType(),
};
return std::make_shared<DataTypeArray>(getReturnTypeImplStatic(new_arguments, context));
}
}
/// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
if (auto function_builder = getFunctionForIntervalArithmetic(arguments[0], arguments[1], context))
@ -2132,7 +2217,11 @@ ColumnPtr executeStringInteger(const ColumnsWithTypeAndName & arguments, const A
});
if (isArray(result_type))
return executeArrayImpl(arguments, result_type, input_rows_count);
{
if (!isArray(arguments[0].type) || !isArray(arguments[1].type))
return executeArrayWithNumericImpl(arguments, result_type, input_rows_count);
return executeArraysImpl(arguments, result_type, input_rows_count);
}
if (!valid)
{

View File

@ -0,0 +1,33 @@
[14,21,35]
[14,21,35]
[14,21,35]
[14,21,35]
[14,21,35]
[14,21,35]
[14,21,35]
[[[14,21,35,35]]]
[[[14,21,35,35]]]
[[[1,1.5,2.5,2.5]]]
[[[1,0.6666666666666666,0.4,0.4]]]
[(7,14),(14,14)]
[(NULL,14),(14,NULL)]
[(NULL,2),(2,NULL)]
[(7,700000000000000000000),(NULL,7340039)]
[14,0]
[14,7]
[14,14]
[14,21]
[14,28]
[0,0,0]
[2,3,5]
[4,6,10]
[6,9,15]
[8,12,20]
[]
[0]
[0,42]
[0,42,84]
[0,42,84,126]
[60,15,5]
[0,0,1]
[2.4,2.4,1.2]

View File

@ -0,0 +1,25 @@
SELECT materialize([2, 3, 5]) * materialize(7);
SELECT materialize(7) * materialize([2, 3, 5]);
SELECT [2, 3, 5] * materialize(7);
SELECT materialize(7) * [2, 3, 5];
SELECT materialize([2, 3, 5]) * 7;
SELECT 7 * materialize([2, 3, 5]);
SELECT [2, 3, 5] * 7;
SELECT [[[2, 3, 5, 5]]] * 7;
SELECT 7 * [[[2, 3, 5, 5]]];
SELECT [[[2, 3, 5, 5]]] / 2;
SELECT 2 / [[[2, 3, 5, 5]]];
SELECT [(1, 2), (2, 2)] * 7;
SELECT [(NULL, 2), (2, NULL)] * 7;
SELECT [(NULL, 2), (2, NULL)] / 1;
SELECT [(1., 100000000000000000000.), (NULL, 1048577)] * 7;
SELECT [CAST('2', 'UInt64'), number] * 7 FROM numbers(5);
SELECT [2, 3, 5] * number FROM numbers(5);
SELECT range(number) * 42 FROM numbers(5);
CREATE TABLE my_table (values Array(Int32)) ENGINE = MergeTree() ORDER BY values;
INSERT INTO my_table (values) VALUES ([12, 3, 1]);
SELECT values * 5 FROM my_table WHERE arrayExists(x -> x > 5, values);
DROP TABLE my_table;
SELECT [6, 6, 3] % 2;
SELECT [6, 6, 3] / 2.5::Decimal(1, 1);
SELECT [1] / 'a'; -- { serverError 43 }