Draft: very simple variant

This commit is contained in:
Dmitry Krylov 2021-03-12 17:24:30 +10:00
parent 7c47832405
commit 2e8a296cc9
19 changed files with 176 additions and 32 deletions

View File

@ -42,14 +42,6 @@ namespace ErrorCodes
template <typename Impl, typename Name>
class FunctionArrayMapped : public IFunction
{
private:
size_t
my_min(size_t a, size_t b) const
{
return (a < b) ? a : b;
}
public:
static constexpr auto name = Name::name;
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionArrayMapped>(); }
@ -75,8 +67,9 @@ public:
throw Exception("Function " + getName() + " needs at least one array argument.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
size_t arguments_to_skip = Impl::isFolding() ? 1 : 0;
DataTypes nested_types(arguments.size() - 1);
for (size_t i = 0; i < nested_types.size(); ++i)
for (size_t i = 0; i < nested_types.size() - arguments_to_skip; ++i)
{
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[i + 1]);
if (!array_type)
@ -84,6 +77,8 @@ public:
+ arguments[i + 1]->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
nested_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType());
}
if (Impl::isFolding())
nested_types[nested_types.size() - 1] = arguments[arguments.size() - 1];
const DataTypeFunction * function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get());
if (!function_type || function_type->getArgumentTypes().size() != nested_types.size())
@ -138,14 +133,25 @@ public:
throw Exception("Expression for function " + getName() + " must return UInt8, found "
+ return_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const auto * first_array_type = checkAndGetDataType<DataTypeArray>(arguments[1].type.get());
return Impl::getReturnType(return_type, first_array_type->getNestedType());
if (Impl::isFolding())
{
const auto accum_type = arguments.back().type;
return Impl::getReturnType(return_type, accum_type);
}
else
{
const auto * first_array_type = checkAndGetDataType<DataTypeArray>(arguments[1].type.get());
return Impl::getReturnType(return_type, first_array_type->getNestedType());
}
}
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override
{
std::cerr << " *** FOLDING" << std::endl;
std::cerr << " isFolding(): " << (Impl::isFolding() ? "yes" : "no-") << std::endl;
std::cerr << " arguments.size() = " << arguments.size() << std::endl;
if (arguments.size() == 1)
{
ColumnPtr column_array_ptr = arguments[0].column;
@ -155,7 +161,7 @@ public:
{
const ColumnConst * column_const_array = checkAndGetColumnConst<ColumnArray>(column_array_ptr.get());
if (!column_const_array)
throw Exception("Expected array column, found " + column_array_ptr->getName(), ErrorCodes::ILLEGAL_COLUMN);
throw Exception("X1 Expected array column, found " + column_array_ptr->getName(), ErrorCodes::ILLEGAL_COLUMN);
column_array_ptr = column_const_array->convertToFullColumn();
column_array = assert_cast<const ColumnArray *>(column_array_ptr.get());
}
@ -181,10 +187,12 @@ public:
ColumnPtr column_first_array_ptr;
const ColumnArray * column_first_array = nullptr;
size_t arguments_to_skip = Impl::isFolding() ? 1 : 0;
ColumnsWithTypeAndName arrays;
arrays.reserve(arguments.size() - 1);
for (size_t i = 1; i < arguments.size(); ++i)
for (size_t i = 1; i < arguments.size() - arguments_to_skip; ++i)
{
const auto & array_with_type_and_name = arguments[i];
@ -198,21 +206,13 @@ public:
{
const ColumnConst * column_const_array = checkAndGetColumnConst<ColumnArray>(column_array_ptr.get());
if (!column_const_array)
throw Exception("Expected array column, found " + column_array_ptr->getName(), ErrorCodes::ILLEGAL_COLUMN);
throw Exception("X2 Expected array column, found " + column_array_ptr->getName(), ErrorCodes::ILLEGAL_COLUMN);
column_array_ptr = recursiveRemoveLowCardinality(column_const_array->convertToFullColumn());
column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
}
ColumnPtr int_column = column_array->getDataPtr();
ColumnPtr ca_ptr = ColumnArray::create( int_column->cut(0, my_min(1, int_column->size())),
column_array->getOffsetsPtr());
const auto * ca = checkAndGetColumn<ColumnArray>(ca_ptr.get());
column_array_ptr = ca_ptr;
column_array = ca;
if (!array_type)
throw Exception("Expected array type, found " + array_type_ptr->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
throw Exception("X3 Expected array type, found " + array_type_ptr->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!offsets_column)
{
@ -236,17 +236,84 @@ public:
recursiveRemoveLowCardinality(array_type->getNestedType()),
array_with_type_and_name.name));
}
if (Impl::isFolding())
arrays.emplace_back(arguments[arguments.size() - 1]); // TODO .last()
std::cerr << " arrays.size() = " << arrays.size() << std::endl;
std::cerr << " column_first_array->getData().size() = " << column_first_array->getData().size() << std::endl;
/// Put all the necessary columns multiplied by the sizes of arrays into the columns.
auto replicated_column_function_ptr = IColumn::mutate(column_function->replicate(column_first_array->getOffsets()));
auto * replicated_column_function = typeid_cast<ColumnFunction *>(replicated_column_function_ptr.get());
replicated_column_function->appendArguments(arrays);
if (Impl::isFolding() && (column_first_array->getData().size() > 0)) // TODO .size() -> .empty()
{
auto lambda_result = replicated_column_function->reduce().column;
if (lambda_result->lowCardinality())
lambda_result = lambda_result->convertToFullColumnIfLowCardinality();
return Impl::execute(*column_first_array, lambda_result);
ColumnWithTypeAndName accumulator = arguments.back();
ColumnPtr res;
for(size_t i = 0; i < column_first_array->getData().size(); ++i)
{
std::cerr << " ----- iteration " << i << " ------" << std::endl;
// Make slice of input arrays and accumulator for lambda
ColumnsWithTypeAndName iter_arrays;
iter_arrays.reserve(arrays.size());
for(size_t j = 0; j < arrays.size() - 1; ++j)
{
auto const & arr = arrays[j];
std::cerr << " " << j << ") " << 1 << std::endl;
/*
const ColumnArray * arr_array = checkAndGetColumn<ColumnArray>(arr.column.get());
std::cerr << " " << j << ") " << 1 << " " << arr_array << std::endl;
std::cerr << " " << j << ") " << 2 << std::endl;
const ColumnPtr & nested_column_x = arr_array->getData().cut(i, 1);
std::cerr << " " << j << ") " << 3 << std::endl;
const ColumnPtr & offsets_column_x = ColumnArray::ColumnOffsets::create(1, 1);
std::cerr << " " << j << ") " << 4 << std::endl;
auto new_arr_array = ColumnArray::create(nested_column_x, offsets_column_x);
std::cerr << " " << j << ") " << 5 << std::endl;
*/
iter_arrays.emplace_back(ColumnWithTypeAndName(arr.column->cut(i, 1),
arr.type,
arr.name));
std::cerr << " " << j << ") " << 6 << std::endl;
}
iter_arrays.emplace_back(accumulator);
// ----
std::cerr << " formed" << std::endl;
auto replicated_column_function_ptr = IColumn::mutate(column_function->replicate(IColumn::Offsets(1, 1)));
auto * replicated_column_function = typeid_cast<ColumnFunction *>(replicated_column_function_ptr.get());
std::cerr << " pre append" << std::endl;
replicated_column_function->appendArguments(iter_arrays);
std::cerr << " post append" << std::endl;
auto lambda_result = replicated_column_function->reduce().column;
if (lambda_result->lowCardinality())
lambda_result = lambda_result->convertToFullColumnIfLowCardinality();
std::cerr << " pre execute" << std::endl;
res = Impl::execute(*column_first_array, lambda_result); // TODO column_first_array
std::cerr << " post execute : res " << res->dumpStructure() << std::endl;
std::cerr << " post execute : res[0] " << (*res)[0].dump() << std::endl;
// ~~~
// ~~~
// ~~~
accumulator.column = res;
}
return res;
}
else
{
/// Put all the necessary columns multiplied by the sizes of arrays into the columns.
auto replicated_column_function_ptr = IColumn::mutate(column_function->replicate(column_first_array->getOffsets()));
auto * replicated_column_function = typeid_cast<ColumnFunction *>(replicated_column_function_ptr.get());
replicated_column_function->appendArguments(arrays);
auto lambda_result = replicated_column_function->reduce().column;
if (lambda_result->lowCardinality())
lambda_result = lambda_result->convertToFullColumnIfLowCardinality();
ColumnPtr res = Impl::execute(*column_first_array, lambda_result);
std::cerr << " ^^^ FOLDING" << std::endl;
return res;
//return Impl::execute(*column_first_array, lambda_result);
}
}
}
};

View File

@ -76,6 +76,7 @@ struct ArrayAggregateImpl
static bool needBoolean() { return false; }
static bool needExpression() { return false; }
static bool needOneArray() { return false; }
static bool isFolding() { return false; }
static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & /*array_element*/)
{

View File

@ -19,6 +19,7 @@ struct ArrayAllImpl
static bool needBoolean() { return true; }
static bool needExpression() { return false; }
static bool needOneArray() { return false; }
static bool isFolding() { return false; }
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & /*array_element*/)
{

View File

@ -20,6 +20,7 @@ struct ArrayCompactImpl
static bool needBoolean() { return false; }
static bool needExpression() { return false; }
static bool needOneArray() { return false; }
static bool isFolding() { return false; }
static DataTypePtr getReturnType(const DataTypePtr & nested_type, const DataTypePtr &)
{

View File

@ -19,6 +19,7 @@ struct ArrayCountImpl
static bool needBoolean() { return true; }
static bool needExpression() { return false; }
static bool needOneArray() { return false; }
static bool isFolding() { return false; }
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & /*array_element*/)
{

View File

@ -20,6 +20,7 @@ struct ArrayCumSumImpl
static bool needBoolean() { return false; }
static bool needExpression() { return false; }
static bool needOneArray() { return false; }
static bool isFolding() { return false; }
static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & /*array_element*/)
{

View File

@ -23,6 +23,7 @@ struct ArrayCumSumNonNegativeImpl
static bool needBoolean() { return false; }
static bool needExpression() { return false; }
static bool needOneArray() { return false; }
static bool isFolding() { return false; }
static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & /*array_element*/)
{

View File

@ -23,6 +23,7 @@ struct ArrayDifferenceImpl
static bool needBoolean() { return false; }
static bool needExpression() { return false; }
static bool needOneArray() { return false; }
static bool isFolding() { return false; }
static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & /*array_element*/)
{

View File

@ -19,6 +19,7 @@ struct ArrayExistsImpl
static bool needBoolean() { return true; }
static bool needExpression() { return false; }
static bool needOneArray() { return false; }
static bool isFolding() { return false; }
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & /*array_element*/)
{

View File

@ -22,6 +22,7 @@ struct ArrayFillImpl
static bool needBoolean() { return true; }
static bool needExpression() { return true; }
static bool needOneArray() { return false; }
static bool isFolding() { return false; }
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & array_element)
{

View File

@ -18,6 +18,7 @@ struct ArrayFilterImpl
static bool needBoolean() { return true; }
static bool needExpression() { return true; }
static bool needOneArray() { return false; }
static bool isFolding() { return false; }
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & array_element)
{

View File

@ -16,6 +16,7 @@ struct ArrayFirstImpl
static bool needBoolean() { return false; }
static bool needExpression() { return true; }
static bool needOneArray() { return false; }
static bool isFolding() { return false; }
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & array_element)
{

View File

@ -16,6 +16,7 @@ struct ArrayFirstIndexImpl
static bool needBoolean() { return false; }
static bool needExpression() { return true; }
static bool needOneArray() { return false; }
static bool isFolding() { return false; }
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & /*array_element*/)
{

View File

@ -0,0 +1,58 @@
#include "FunctionArrayMapped.h"
#include <Functions/FunctionFactory.h>
namespace DB
{
/** arrayFold(x1,...,xn,accum -> expression, array1,...,arrayn, init_accum) - apply the expression to each element of the array (or set of parallel arrays).
*/
struct ArrayFoldImpl
{
static bool needBoolean() { return false; }
static bool needExpression() { return true; }
static bool needOneArray() { return false; }
static bool isFolding() { return true; }
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & accum_type)
{
return accum_type;
}
static ColumnPtr execute(const ColumnArray & array, ColumnPtr mapped)
{
std::cerr << " **** ArrayFoldImpl **** " << std::endl;
std::cerr << " array: " << array.dumpStructure() << std::endl;
std::cerr << " mapped: " << mapped->dumpStructure() << std::endl;
std::cerr << " mapped[0]: " << (*mapped)[0].dump() << std::endl;
// std::cerr << " mapped[1]: " << (*mapped)[1].dump() << std::endl;
//
ColumnPtr res;
if (mapped->size() == 0)
{
res = mapped;
}
else
{
res = mapped->cut(0, 1);
}
std::cerr << " ^^^^ ArrayFoldImpl ^^^^" << std::endl;
return res;
// return ColumnArray::create(mapped->convertToFullColumnIfConst(), array.getOffsetsPtr());
// return ColumnArray::create(mapped->convertToFullColumnIfConst(), array.getOffsetsPtr());
}
};
struct NameArrayFold { static constexpr auto name = "arrayFold"; };
using FunctionArrayFold = FunctionArrayMapped<ArrayFoldImpl, NameArrayFold>;
void registerFunctionArrayFold(FunctionFactory & factory)
{
factory.registerFunction<FunctionArrayFold>();
}
}

View File

@ -15,6 +15,8 @@ struct ArrayMapImpl
static bool needExpression() { return true; }
/// true if the array must be exactly one.
static bool needOneArray() { return false; }
/// true if function do folding
static bool isFolding() { return false; }
static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & /*array_element*/)
{

View File

@ -13,6 +13,7 @@ struct ArraySortImpl
static bool needBoolean() { return false; }
static bool needExpression() { return false; }
static bool needOneArray() { return false; }
static bool isFolding() { return false; }
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & array_element)
{

View File

@ -17,6 +17,7 @@ struct ArraySplitImpl
static bool needBoolean() { return true; }
static bool needExpression() { return true; }
static bool needOneArray() { return false; }
static bool isFolding() { return false; }
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & array_element)
{

View File

@ -4,6 +4,7 @@ namespace DB
class FunctionFactory;
void registerFunctionArrayMap(FunctionFactory & factory);
void registerFunctionArrayFold(FunctionFactory & factory);
void registerFunctionArrayFilter(FunctionFactory & factory);
void registerFunctionArrayCount(FunctionFactory & factory);
void registerFunctionArrayExists(FunctionFactory & factory);
@ -22,6 +23,7 @@ void registerFunctionArrayDifference(FunctionFactory & factory);
void registerFunctionsHigherOrder(FunctionFactory & factory)
{
registerFunctionArrayMap(factory);
registerFunctionArrayFold(factory);
registerFunctionArrayFilter(factory);
registerFunctionArrayCount(factory);
registerFunctionArrayExists(factory);

View File

@ -144,6 +144,7 @@ SRCS(
array/arrayFirst.cpp
array/arrayFirstIndex.cpp
array/arrayFlatten.cpp
array/arrayFold.cpp
array/arrayIntersect.cpp
array/arrayJoin.cpp
array/arrayMap.cpp