mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-17 13:13:36 +00:00
Draft: very simple variant
This commit is contained in:
parent
7c47832405
commit
2e8a296cc9
@ -42,14 +42,6 @@ namespace ErrorCodes
|
||||
template <typename Impl, typename Name>
|
||||
class FunctionArrayMapped : public IFunction
|
||||
{
|
||||
private:
|
||||
size_t
|
||||
my_min(size_t a, size_t b) const
|
||||
{
|
||||
return (a < b) ? a : b;
|
||||
}
|
||||
|
||||
|
||||
public:
|
||||
static constexpr auto name = Name::name;
|
||||
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionArrayMapped>(); }
|
||||
@ -75,8 +67,9 @@ public:
|
||||
throw Exception("Function " + getName() + " needs at least one array argument.",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
size_t arguments_to_skip = Impl::isFolding() ? 1 : 0;
|
||||
DataTypes nested_types(arguments.size() - 1);
|
||||
for (size_t i = 0; i < nested_types.size(); ++i)
|
||||
for (size_t i = 0; i < nested_types.size() - arguments_to_skip; ++i)
|
||||
{
|
||||
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[i + 1]);
|
||||
if (!array_type)
|
||||
@ -84,6 +77,8 @@ public:
|
||||
+ arguments[i + 1]->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
nested_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType());
|
||||
}
|
||||
if (Impl::isFolding())
|
||||
nested_types[nested_types.size() - 1] = arguments[arguments.size() - 1];
|
||||
|
||||
const DataTypeFunction * function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get());
|
||||
if (!function_type || function_type->getArgumentTypes().size() != nested_types.size())
|
||||
@ -138,14 +133,25 @@ public:
|
||||
throw Exception("Expression for function " + getName() + " must return UInt8, found "
|
||||
+ return_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
const auto * first_array_type = checkAndGetDataType<DataTypeArray>(arguments[1].type.get());
|
||||
|
||||
return Impl::getReturnType(return_type, first_array_type->getNestedType());
|
||||
if (Impl::isFolding())
|
||||
{
|
||||
const auto accum_type = arguments.back().type;
|
||||
return Impl::getReturnType(return_type, accum_type);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto * first_array_type = checkAndGetDataType<DataTypeArray>(arguments[1].type.get());
|
||||
return Impl::getReturnType(return_type, first_array_type->getNestedType());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override
|
||||
{
|
||||
std::cerr << " *** FOLDING" << std::endl;
|
||||
std::cerr << " isFolding(): " << (Impl::isFolding() ? "yes" : "no-") << std::endl;
|
||||
std::cerr << " arguments.size() = " << arguments.size() << std::endl;
|
||||
|
||||
if (arguments.size() == 1)
|
||||
{
|
||||
ColumnPtr column_array_ptr = arguments[0].column;
|
||||
@ -155,7 +161,7 @@ public:
|
||||
{
|
||||
const ColumnConst * column_const_array = checkAndGetColumnConst<ColumnArray>(column_array_ptr.get());
|
||||
if (!column_const_array)
|
||||
throw Exception("Expected array column, found " + column_array_ptr->getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
throw Exception("X1 Expected array column, found " + column_array_ptr->getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
column_array_ptr = column_const_array->convertToFullColumn();
|
||||
column_array = assert_cast<const ColumnArray *>(column_array_ptr.get());
|
||||
}
|
||||
@ -181,10 +187,12 @@ public:
|
||||
ColumnPtr column_first_array_ptr;
|
||||
const ColumnArray * column_first_array = nullptr;
|
||||
|
||||
size_t arguments_to_skip = Impl::isFolding() ? 1 : 0;
|
||||
|
||||
ColumnsWithTypeAndName arrays;
|
||||
arrays.reserve(arguments.size() - 1);
|
||||
|
||||
for (size_t i = 1; i < arguments.size(); ++i)
|
||||
for (size_t i = 1; i < arguments.size() - arguments_to_skip; ++i)
|
||||
{
|
||||
const auto & array_with_type_and_name = arguments[i];
|
||||
|
||||
@ -198,21 +206,13 @@ public:
|
||||
{
|
||||
const ColumnConst * column_const_array = checkAndGetColumnConst<ColumnArray>(column_array_ptr.get());
|
||||
if (!column_const_array)
|
||||
throw Exception("Expected array column, found " + column_array_ptr->getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
throw Exception("X2 Expected array column, found " + column_array_ptr->getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
column_array_ptr = recursiveRemoveLowCardinality(column_const_array->convertToFullColumn());
|
||||
column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
|
||||
}
|
||||
|
||||
ColumnPtr int_column = column_array->getDataPtr();
|
||||
ColumnPtr ca_ptr = ColumnArray::create( int_column->cut(0, my_min(1, int_column->size())),
|
||||
column_array->getOffsetsPtr());
|
||||
const auto * ca = checkAndGetColumn<ColumnArray>(ca_ptr.get());
|
||||
|
||||
column_array_ptr = ca_ptr;
|
||||
column_array = ca;
|
||||
|
||||
if (!array_type)
|
||||
throw Exception("Expected array type, found " + array_type_ptr->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
throw Exception("X3 Expected array type, found " + array_type_ptr->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
if (!offsets_column)
|
||||
{
|
||||
@ -236,17 +236,84 @@ public:
|
||||
recursiveRemoveLowCardinality(array_type->getNestedType()),
|
||||
array_with_type_and_name.name));
|
||||
}
|
||||
if (Impl::isFolding())
|
||||
arrays.emplace_back(arguments[arguments.size() - 1]); // TODO .last()
|
||||
std::cerr << " arrays.size() = " << arrays.size() << std::endl;
|
||||
std::cerr << " column_first_array->getData().size() = " << column_first_array->getData().size() << std::endl;
|
||||
|
||||
/// Put all the necessary columns multiplied by the sizes of arrays into the columns.
|
||||
auto replicated_column_function_ptr = IColumn::mutate(column_function->replicate(column_first_array->getOffsets()));
|
||||
auto * replicated_column_function = typeid_cast<ColumnFunction *>(replicated_column_function_ptr.get());
|
||||
replicated_column_function->appendArguments(arrays);
|
||||
if (Impl::isFolding() && (column_first_array->getData().size() > 0)) // TODO .size() -> .empty()
|
||||
{
|
||||
|
||||
auto lambda_result = replicated_column_function->reduce().column;
|
||||
if (lambda_result->lowCardinality())
|
||||
lambda_result = lambda_result->convertToFullColumnIfLowCardinality();
|
||||
|
||||
return Impl::execute(*column_first_array, lambda_result);
|
||||
ColumnWithTypeAndName accumulator = arguments.back();
|
||||
ColumnPtr res;
|
||||
|
||||
for(size_t i = 0; i < column_first_array->getData().size(); ++i)
|
||||
{
|
||||
std::cerr << " ----- iteration " << i << " ------" << std::endl;
|
||||
// Make slice of input arrays and accumulator for lambda
|
||||
ColumnsWithTypeAndName iter_arrays;
|
||||
iter_arrays.reserve(arrays.size());
|
||||
for(size_t j = 0; j < arrays.size() - 1; ++j)
|
||||
{
|
||||
auto const & arr = arrays[j];
|
||||
std::cerr << " " << j << ") " << 1 << std::endl;
|
||||
/*
|
||||
const ColumnArray * arr_array = checkAndGetColumn<ColumnArray>(arr.column.get());
|
||||
std::cerr << " " << j << ") " << 1 << " " << arr_array << std::endl;
|
||||
std::cerr << " " << j << ") " << 2 << std::endl;
|
||||
const ColumnPtr & nested_column_x = arr_array->getData().cut(i, 1);
|
||||
std::cerr << " " << j << ") " << 3 << std::endl;
|
||||
const ColumnPtr & offsets_column_x = ColumnArray::ColumnOffsets::create(1, 1);
|
||||
std::cerr << " " << j << ") " << 4 << std::endl;
|
||||
auto new_arr_array = ColumnArray::create(nested_column_x, offsets_column_x);
|
||||
std::cerr << " " << j << ") " << 5 << std::endl;
|
||||
*/
|
||||
iter_arrays.emplace_back(ColumnWithTypeAndName(arr.column->cut(i, 1),
|
||||
arr.type,
|
||||
arr.name));
|
||||
std::cerr << " " << j << ") " << 6 << std::endl;
|
||||
}
|
||||
iter_arrays.emplace_back(accumulator);
|
||||
// ----
|
||||
std::cerr << " formed" << std::endl;
|
||||
auto replicated_column_function_ptr = IColumn::mutate(column_function->replicate(IColumn::Offsets(1, 1)));
|
||||
auto * replicated_column_function = typeid_cast<ColumnFunction *>(replicated_column_function_ptr.get());
|
||||
std::cerr << " pre append" << std::endl;
|
||||
replicated_column_function->appendArguments(iter_arrays);
|
||||
std::cerr << " post append" << std::endl;
|
||||
auto lambda_result = replicated_column_function->reduce().column;
|
||||
if (lambda_result->lowCardinality())
|
||||
lambda_result = lambda_result->convertToFullColumnIfLowCardinality();
|
||||
std::cerr << " pre execute" << std::endl;
|
||||
res = Impl::execute(*column_first_array, lambda_result); // TODO column_first_array
|
||||
std::cerr << " post execute : res " << res->dumpStructure() << std::endl;
|
||||
std::cerr << " post execute : res[0] " << (*res)[0].dump() << std::endl;
|
||||
// ~~~
|
||||
// ~~~
|
||||
// ~~~
|
||||
accumulator.column = res;
|
||||
}
|
||||
|
||||
return res;
|
||||
|
||||
}
|
||||
else
|
||||
{
|
||||
/// Put all the necessary columns multiplied by the sizes of arrays into the columns.
|
||||
auto replicated_column_function_ptr = IColumn::mutate(column_function->replicate(column_first_array->getOffsets()));
|
||||
auto * replicated_column_function = typeid_cast<ColumnFunction *>(replicated_column_function_ptr.get());
|
||||
replicated_column_function->appendArguments(arrays);
|
||||
|
||||
auto lambda_result = replicated_column_function->reduce().column;
|
||||
if (lambda_result->lowCardinality())
|
||||
lambda_result = lambda_result->convertToFullColumnIfLowCardinality();
|
||||
|
||||
ColumnPtr res = Impl::execute(*column_first_array, lambda_result);
|
||||
std::cerr << " ^^^ FOLDING" << std::endl;
|
||||
return res;
|
||||
//return Impl::execute(*column_first_array, lambda_result);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -76,6 +76,7 @@ struct ArrayAggregateImpl
|
||||
static bool needBoolean() { return false; }
|
||||
static bool needExpression() { return false; }
|
||||
static bool needOneArray() { return false; }
|
||||
static bool isFolding() { return false; }
|
||||
|
||||
static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & /*array_element*/)
|
||||
{
|
||||
|
@ -19,6 +19,7 @@ struct ArrayAllImpl
|
||||
static bool needBoolean() { return true; }
|
||||
static bool needExpression() { return false; }
|
||||
static bool needOneArray() { return false; }
|
||||
static bool isFolding() { return false; }
|
||||
|
||||
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & /*array_element*/)
|
||||
{
|
||||
|
@ -20,6 +20,7 @@ struct ArrayCompactImpl
|
||||
static bool needBoolean() { return false; }
|
||||
static bool needExpression() { return false; }
|
||||
static bool needOneArray() { return false; }
|
||||
static bool isFolding() { return false; }
|
||||
|
||||
static DataTypePtr getReturnType(const DataTypePtr & nested_type, const DataTypePtr &)
|
||||
{
|
||||
|
@ -19,6 +19,7 @@ struct ArrayCountImpl
|
||||
static bool needBoolean() { return true; }
|
||||
static bool needExpression() { return false; }
|
||||
static bool needOneArray() { return false; }
|
||||
static bool isFolding() { return false; }
|
||||
|
||||
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & /*array_element*/)
|
||||
{
|
||||
|
@ -20,6 +20,7 @@ struct ArrayCumSumImpl
|
||||
static bool needBoolean() { return false; }
|
||||
static bool needExpression() { return false; }
|
||||
static bool needOneArray() { return false; }
|
||||
static bool isFolding() { return false; }
|
||||
|
||||
static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & /*array_element*/)
|
||||
{
|
||||
|
@ -23,6 +23,7 @@ struct ArrayCumSumNonNegativeImpl
|
||||
static bool needBoolean() { return false; }
|
||||
static bool needExpression() { return false; }
|
||||
static bool needOneArray() { return false; }
|
||||
static bool isFolding() { return false; }
|
||||
|
||||
static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & /*array_element*/)
|
||||
{
|
||||
|
@ -23,6 +23,7 @@ struct ArrayDifferenceImpl
|
||||
static bool needBoolean() { return false; }
|
||||
static bool needExpression() { return false; }
|
||||
static bool needOneArray() { return false; }
|
||||
static bool isFolding() { return false; }
|
||||
|
||||
static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & /*array_element*/)
|
||||
{
|
||||
|
@ -19,6 +19,7 @@ struct ArrayExistsImpl
|
||||
static bool needBoolean() { return true; }
|
||||
static bool needExpression() { return false; }
|
||||
static bool needOneArray() { return false; }
|
||||
static bool isFolding() { return false; }
|
||||
|
||||
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & /*array_element*/)
|
||||
{
|
||||
|
@ -22,6 +22,7 @@ struct ArrayFillImpl
|
||||
static bool needBoolean() { return true; }
|
||||
static bool needExpression() { return true; }
|
||||
static bool needOneArray() { return false; }
|
||||
static bool isFolding() { return false; }
|
||||
|
||||
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & array_element)
|
||||
{
|
||||
|
@ -18,6 +18,7 @@ struct ArrayFilterImpl
|
||||
static bool needBoolean() { return true; }
|
||||
static bool needExpression() { return true; }
|
||||
static bool needOneArray() { return false; }
|
||||
static bool isFolding() { return false; }
|
||||
|
||||
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & array_element)
|
||||
{
|
||||
|
@ -16,6 +16,7 @@ struct ArrayFirstImpl
|
||||
static bool needBoolean() { return false; }
|
||||
static bool needExpression() { return true; }
|
||||
static bool needOneArray() { return false; }
|
||||
static bool isFolding() { return false; }
|
||||
|
||||
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & array_element)
|
||||
{
|
||||
|
@ -16,6 +16,7 @@ struct ArrayFirstIndexImpl
|
||||
static bool needBoolean() { return false; }
|
||||
static bool needExpression() { return true; }
|
||||
static bool needOneArray() { return false; }
|
||||
static bool isFolding() { return false; }
|
||||
|
||||
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & /*array_element*/)
|
||||
{
|
||||
|
58
src/Functions/array/arrayFold.cpp
Normal file
58
src/Functions/array/arrayFold.cpp
Normal file
@ -0,0 +1,58 @@
|
||||
#include "FunctionArrayMapped.h"
|
||||
#include <Functions/FunctionFactory.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/** arrayFold(x1,...,xn,accum -> expression, array1,...,arrayn, init_accum) - apply the expression to each element of the array (or set of parallel arrays).
|
||||
*/
|
||||
struct ArrayFoldImpl
|
||||
{
|
||||
static bool needBoolean() { return false; }
|
||||
static bool needExpression() { return true; }
|
||||
static bool needOneArray() { return false; }
|
||||
static bool isFolding() { return true; }
|
||||
|
||||
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & accum_type)
|
||||
{
|
||||
return accum_type;
|
||||
}
|
||||
|
||||
static ColumnPtr execute(const ColumnArray & array, ColumnPtr mapped)
|
||||
{
|
||||
std::cerr << " **** ArrayFoldImpl **** " << std::endl;
|
||||
std::cerr << " array: " << array.dumpStructure() << std::endl;
|
||||
std::cerr << " mapped: " << mapped->dumpStructure() << std::endl;
|
||||
std::cerr << " mapped[0]: " << (*mapped)[0].dump() << std::endl;
|
||||
// std::cerr << " mapped[1]: " << (*mapped)[1].dump() << std::endl;
|
||||
//
|
||||
|
||||
ColumnPtr res;
|
||||
if (mapped->size() == 0)
|
||||
{
|
||||
res = mapped;
|
||||
}
|
||||
else
|
||||
{
|
||||
res = mapped->cut(0, 1);
|
||||
}
|
||||
std::cerr << " ^^^^ ArrayFoldImpl ^^^^" << std::endl;
|
||||
return res;
|
||||
|
||||
// return ColumnArray::create(mapped->convertToFullColumnIfConst(), array.getOffsetsPtr());
|
||||
// return ColumnArray::create(mapped->convertToFullColumnIfConst(), array.getOffsetsPtr());
|
||||
}
|
||||
};
|
||||
|
||||
struct NameArrayFold { static constexpr auto name = "arrayFold"; };
|
||||
using FunctionArrayFold = FunctionArrayMapped<ArrayFoldImpl, NameArrayFold>;
|
||||
|
||||
void registerFunctionArrayFold(FunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction<FunctionArrayFold>();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
@ -15,6 +15,8 @@ struct ArrayMapImpl
|
||||
static bool needExpression() { return true; }
|
||||
/// true if the array must be exactly one.
|
||||
static bool needOneArray() { return false; }
|
||||
/// true if function do folding
|
||||
static bool isFolding() { return false; }
|
||||
|
||||
static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & /*array_element*/)
|
||||
{
|
||||
|
@ -13,6 +13,7 @@ struct ArraySortImpl
|
||||
static bool needBoolean() { return false; }
|
||||
static bool needExpression() { return false; }
|
||||
static bool needOneArray() { return false; }
|
||||
static bool isFolding() { return false; }
|
||||
|
||||
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & array_element)
|
||||
{
|
||||
|
@ -17,6 +17,7 @@ struct ArraySplitImpl
|
||||
static bool needBoolean() { return true; }
|
||||
static bool needExpression() { return true; }
|
||||
static bool needOneArray() { return false; }
|
||||
static bool isFolding() { return false; }
|
||||
|
||||
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & array_element)
|
||||
{
|
||||
|
@ -4,6 +4,7 @@ namespace DB
|
||||
class FunctionFactory;
|
||||
|
||||
void registerFunctionArrayMap(FunctionFactory & factory);
|
||||
void registerFunctionArrayFold(FunctionFactory & factory);
|
||||
void registerFunctionArrayFilter(FunctionFactory & factory);
|
||||
void registerFunctionArrayCount(FunctionFactory & factory);
|
||||
void registerFunctionArrayExists(FunctionFactory & factory);
|
||||
@ -22,6 +23,7 @@ void registerFunctionArrayDifference(FunctionFactory & factory);
|
||||
void registerFunctionsHigherOrder(FunctionFactory & factory)
|
||||
{
|
||||
registerFunctionArrayMap(factory);
|
||||
registerFunctionArrayFold(factory);
|
||||
registerFunctionArrayFilter(factory);
|
||||
registerFunctionArrayCount(factory);
|
||||
registerFunctionArrayExists(factory);
|
||||
|
@ -144,6 +144,7 @@ SRCS(
|
||||
array/arrayFirst.cpp
|
||||
array/arrayFirstIndex.cpp
|
||||
array/arrayFlatten.cpp
|
||||
array/arrayFold.cpp
|
||||
array/arrayIntersect.cpp
|
||||
array/arrayJoin.cpp
|
||||
array/arrayMap.cpp
|
||||
|
Loading…
Reference in New Issue
Block a user