Merge pull request #33698 from hexiaoting/dev-map-funciton

This commit is contained in:
Vladimir C 2022-03-04 11:51:17 +01:00 committed by GitHub
commit 79b21c80b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 574 additions and 93 deletions

View File

@ -36,8 +36,8 @@ public:
static Ptr create(const ColumnPtr & column) { return ColumnMap::create(column->assumeMutable()); }
static Ptr create(ColumnPtr && arg) { return create(arg); }
template <typename Arg, typename = typename std::enable_if<std::is_rvalue_reference<Arg &&>::value>::type>
static MutablePtr create(Arg && arg) { return Base::create(std::forward<Arg>(arg)); }
template <typename ... Args, typename = typename std::enable_if<IsMutableColumns<Args ...>::value>::type>
static MutablePtr create(Args &&... args) { return Base::create(std::forward<Args>(args)...); }
std::string getName() const override;
const char * getFamilyName() const override { return "Map"; }

View File

@ -1,18 +1,29 @@
#pragma once
#include <type_traits>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnFunction.h>
#include <Columns/ColumnMap.h>
#include <Columns/IColumn.h>
#include <Common/Exception.h>
#include <Common/assert_cast.h>
#include <Common/typeid_cast.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeFunction.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnFunction.h>
#include <Common/typeid_cast.h>
#include <Common/assert_cast.h>
#include <Functions/IFunction.h>
#include <DataTypes/DataTypeMap.h>
#include <Functions/FunctionHelpers.h>
#include <IO/WriteHelpers.h>
#include <Functions/IFunction.h>
#include <Interpreters/Context_fwd.h>
#include <IO/WriteHelpers.h>
namespace DB
{
@ -21,11 +32,38 @@ namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int LOGICAL_ERROR;
extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
template <typename T>
ColumnPtr getOffsetsPtr(const T & column)
{
if constexpr (std::is_same_v<T, ColumnArray>)
{
return column.getOffsetsPtr();
}
else // ColumnMap
{
return column.getNestedColumn().getOffsetsPtr();
}
}
template <typename T>
const IColumn::Offsets & getOffsets(const T & column)
{
if constexpr (std::is_same_v<T, ColumnArray>)
{
return column.getOffsets();
}
else // ColumnMap
{
return column.getNestedColumn().getOffsets();
}
}
/** Higher-order functions for arrays.
* These functions optionally apply a map (transform) to array (or multiple arrays of identical size) by lambda function,
* and return some result based on that transformation.
@ -60,29 +98,42 @@ public:
void getLambdaArgumentTypes(DataTypes & arguments) const override
{
if (arguments.empty())
throw Exception("Function " + getName() + " needs at least one argument; passed "
+ toString(arguments.size()) + ".",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Function {} needs at least one argument, passed {}", getName(), arguments.size());
if (arguments.size() == 1)
throw Exception("Function " + getName() + " needs at least one array argument.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Function {} needs at least one argument with data", getName());
DataTypes nested_types(arguments.size() - 1);
for (size_t i = 0; i < nested_types.size(); ++i)
if (arguments.size() > 2 && Impl::needOneArray())
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Function {} needs one argument with data", getName());
size_t nested_types_count = std::is_same_v<typename Impl::data_type, DataTypeMap> ? (arguments.size() - 1) * 2 : (arguments.size() - 1);
DataTypes nested_types(nested_types_count);
for (size_t i = 0; i < arguments.size() - 1; ++i)
{
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[i + 1]);
const auto * array_type = checkAndGetDataType<typename Impl::data_type>(&*arguments[i + 1]);
if (!array_type)
throw Exception("Argument " + toString(i + 2) + " of function " + getName() + " must be array. Found "
+ arguments[i + 1]->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
nested_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType());
if constexpr (std::is_same_v<typename Impl::data_type, DataTypeMap>)
{
nested_types[2 * i] = recursiveRemoveLowCardinality(array_type->getKeyType());
nested_types[2 * i + 1] = recursiveRemoveLowCardinality(array_type->getValueType());
}
else if constexpr (std::is_same_v<typename Impl::data_type, DataTypeArray>)
{
nested_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType());
}
}
const DataTypeFunction * function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get());
if (!function_type || function_type->getArgumentTypes().size() != nested_types.size())
throw Exception("First argument for this overload of " + getName() + " must be a function with "
+ toString(nested_types.size()) + " arguments. Found "
+ arguments[0]->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"First argument for this overload of {} must be a function with {} arguments, found {} instead",
getName(), nested_types.size(), arguments[0]->getName());
arguments[0] = std::make_shared<DataTypeFunction>(nested_types);
}
@ -91,37 +142,39 @@ public:
{
size_t min_args = Impl::needExpression() ? 2 : 1;
if (arguments.size() < min_args)
throw Exception("Function " + getName() + " needs at least "
+ toString(min_args) + " argument; passed "
+ toString(arguments.size()) + ".",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Function {} needs at least {} argument, passed {}",
getName(), min_args, arguments.size());
if (arguments.size() == 1)
if ((arguments.size() == 1) && std::is_same_v<typename Impl::data_type, DataTypeArray>)
{
const auto * array_type = checkAndGetDataType<DataTypeArray>(arguments[0].type.get());
const auto * data_type = checkAndGetDataType<typename Impl::data_type>(arguments[0].type.get());
if (!array_type)
if (!data_type)
throw Exception("The only argument for function " + getName() + " must be array. Found "
+ arguments[0].type->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
+ arguments[0].type->getName() + " instead", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
DataTypePtr nested_type = array_type->getNestedType();
DataTypePtr nested_type = data_type->getNestedType();
if (Impl::needBoolean() && !WhichDataType(nested_type).isUInt8())
throw Exception("The only argument for function " + getName() + " must be array of UInt8. Found "
+ arguments[0].type->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
+ arguments[0].type->getName() + " instead", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return Impl::getReturnType(nested_type, nested_type);
if constexpr (std::is_same_v<typename Impl::data_type, DataTypeArray>)
return Impl::getReturnType(nested_type, nested_type);
else
throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Unreachable code reached");
}
else
{
if (arguments.size() > 2 && Impl::needOneArray())
throw Exception("Function " + getName() + " needs one array argument.",
throw Exception("Function " + getName() + " needs one argument with data",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
const auto * data_type_function = checkAndGetDataType<DataTypeFunction>(arguments[0].type.get());
if (!data_type_function)
throw Exception("First argument for function " + getName() + " must be a function.",
throw Exception("First argument for function " + getName() + " must be a function",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
/// The types of the remaining arguments are already checked in getLambdaArgumentTypes.
@ -131,9 +184,28 @@ public:
throw Exception("Expression for function " + getName() + " must return UInt8, found "
+ return_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const auto * first_array_type = checkAndGetDataType<DataTypeArray>(arguments[1].type.get());
static_assert(
std::is_same_v<typename Impl::data_type, DataTypeMap> ||
std::is_same_v<typename Impl::data_type, DataTypeArray>,
"unsupported type");
return Impl::getReturnType(return_type, first_array_type->getNestedType());
if (arguments.size() < 2)
{
throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "{}", arguments.size());
}
const auto * first_array_type = checkAndGetDataType<typename Impl::data_type>(arguments[1].type.get());
if (!first_array_type)
throw DB::Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Unsupported type {}", arguments[1].type->getName());
if constexpr (std::is_same_v<typename Impl::data_type, DataTypeArray>)
return Impl::getReturnType(return_type, first_array_type->getNestedType());
if constexpr (std::is_same_v<typename Impl::data_type, DataTypeMap>)
return Impl::getReturnType(return_type, first_array_type->getKeyValueTypes());
throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Unreachable code reached");
}
}
@ -142,18 +214,25 @@ public:
if (arguments.size() == 1)
{
ColumnPtr column_array_ptr = arguments[0].column;
const auto * column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
const auto * column_array = checkAndGetColumn<typename Impl::column_type>(column_array_ptr.get());
if (!column_array)
{
const ColumnConst * column_const_array = checkAndGetColumnConst<ColumnArray>(column_array_ptr.get());
const ColumnConst * column_const_array = checkAndGetColumnConst<typename Impl::column_type>(column_array_ptr.get());
if (!column_const_array)
throw Exception("Expected array column, found " + column_array_ptr->getName(), ErrorCodes::ILLEGAL_COLUMN);
column_array_ptr = column_const_array->convertToFullColumn();
column_array = assert_cast<const ColumnArray *>(column_array_ptr.get());
column_array = assert_cast<const typename Impl::column_type *>(column_array_ptr.get());
}
return Impl::execute(*column_array, column_array->getDataPtr());
if constexpr (std::is_same_v<typename Impl::column_type, ColumnMap>)
{
return Impl::execute(*column_array, column_array->getNestedColumn().getDataPtr());
}
else
{
return Impl::execute(*column_array, column_array->getDataPtr());
}
}
else
{
@ -172,7 +251,7 @@ public:
ColumnPtr offsets_column;
ColumnPtr column_first_array_ptr;
const ColumnArray * column_first_array = nullptr;
const typename Impl::column_type * column_first_array = nullptr;
ColumnsWithTypeAndName arrays;
arrays.reserve(arguments.size() - 1);
@ -182,18 +261,18 @@ public:
const auto & array_with_type_and_name = arguments[i];
ColumnPtr column_array_ptr = array_with_type_and_name.column;
const auto * column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
const auto * column_array = checkAndGetColumn<typename Impl::column_type>(column_array_ptr.get());
const DataTypePtr & array_type_ptr = array_with_type_and_name.type;
const auto * array_type = checkAndGetDataType<DataTypeArray>(array_type_ptr.get());
const auto * array_type = checkAndGetDataType<typename Impl::data_type>(array_type_ptr.get());
if (!column_array)
{
const ColumnConst * column_const_array = checkAndGetColumnConst<ColumnArray>(column_array_ptr.get());
const ColumnConst * column_const_array = checkAndGetColumnConst<typename Impl::column_type>(column_array_ptr.get());
if (!column_const_array)
throw Exception("Expected array column, found " + column_array_ptr->getName(), ErrorCodes::ILLEGAL_COLUMN);
column_array_ptr = recursiveRemoveLowCardinality(column_const_array->convertToFullColumn());
column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
column_array = checkAndGetColumn<typename Impl::column_type>(column_array_ptr.get());
}
if (!array_type)
@ -201,13 +280,13 @@ public:
if (!offsets_column)
{
offsets_column = column_array->getOffsetsPtr();
offsets_column = getOffsetsPtr(*column_array);
}
else
{
/// The first condition is optimization: do not compare data if the pointers are equal.
if (column_array->getOffsetsPtr() != offsets_column
&& column_array->getOffsets() != typeid_cast<const ColumnArray::ColumnOffsets &>(*offsets_column).getData())
if (getOffsetsPtr(*column_array) != offsets_column
&& getOffsets(*column_array) != typeid_cast<const ColumnArray::ColumnOffsets &>(*offsets_column).getData())
throw Exception("Arrays passed to " + getName() + " must have equal size", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);
}
@ -217,13 +296,23 @@ public:
column_first_array = column_array;
}
arrays.emplace_back(ColumnWithTypeAndName(column_array->getDataPtr(),
recursiveRemoveLowCardinality(array_type->getNestedType()),
array_with_type_and_name.name));
if constexpr (std::is_same_v<DataTypeMap, typename Impl::data_type>)
{
arrays.emplace_back(ColumnWithTypeAndName(
column_array->getNestedData().getColumnPtr(0), recursiveRemoveLowCardinality(array_type->getKeyType()), array_with_type_and_name.name+".key"));
arrays.emplace_back(ColumnWithTypeAndName(
column_array->getNestedData().getColumnPtr(1), recursiveRemoveLowCardinality(array_type->getValueType()), array_with_type_and_name.name+".value"));
}
else
{
arrays.emplace_back(ColumnWithTypeAndName(column_array->getDataPtr(),
recursiveRemoveLowCardinality(array_type->getNestedType()),
array_with_type_and_name.name));
}
}
/// Put all the necessary columns multiplied by the sizes of arrays into the columns.
auto replicated_column_function_ptr = IColumn::mutate(column_function->replicate(column_first_array->getOffsets()));
auto replicated_column_function_ptr = IColumn::mutate(column_function->replicate(getOffsets(*column_first_array)));
auto * replicated_column_function = typeid_cast<ColumnFunction *>(replicated_column_function_ptr.get());
replicated_column_function->appendArguments(arrays);

View File

@ -1,12 +1,18 @@
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypeDateTime64.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnDecimal.h>
#include "FunctionArrayMapped.h"
#include <Functions/FunctionFactory.h>
#include <base/defines.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeDateTime64.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include "FunctionArrayMapped.h"
namespace DB
{
@ -83,6 +89,9 @@ using ArrayAggregateResult = typename ArrayAggregateResultImpl<ArrayElement, ope
template<AggregateOperation aggregate_operation>
struct ArrayAggregateImpl
{
using column_type = ColumnArray;
using data_type = DataTypeArray;
static bool needBoolean() { return false; }
static bool needExpression() { return false; }
static bool needOneArray() { return false; }

View File

@ -1,8 +1,8 @@
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnsNumber.h>
#include "FunctionArrayMapped.h"
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include "FunctionArrayMapped.h"
namespace DB
{
@ -16,6 +16,9 @@ namespace ErrorCodes
*/
struct ArrayAllImpl
{
using column_type = ColumnArray;
using data_type = DataTypeArray;
static bool needBoolean() { return true; }
static bool needExpression() { return false; }
static bool needOneArray() { return false; }

View File

@ -1,10 +1,13 @@
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnsNumber.h>
#include <Common/HashTable/HashTable.h>
#include <Functions/array/FunctionArrayMapped.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Functions/array/FunctionArrayMapped.h>
namespace DB
@ -16,6 +19,9 @@ namespace ErrorCodes
struct ArrayCompactImpl
{
using column_type = ColumnArray;
using data_type = DataTypeArray;
static bool needBoolean() { return false; }
static bool needExpression() { return false; }
static bool needOneArray() { return false; }

View File

@ -1,8 +1,9 @@
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnsNumber.h>
#include "FunctionArrayMapped.h"
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include "FunctionArrayMapped.h"
namespace DB
{
@ -16,6 +17,9 @@ namespace ErrorCodes
*/
struct ArrayCountImpl
{
using column_type = ColumnArray;
using data_type = DataTypeArray;
static bool needBoolean() { return true; }
static bool needExpression() { return false; }
static bool needOneArray() { return false; }

View File

@ -1,10 +1,11 @@
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnDecimal.h>
#include "FunctionArrayMapped.h"
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include "FunctionArrayMapped.h"
namespace DB
{
@ -17,6 +18,9 @@ namespace ErrorCodes
struct ArrayCumSumImpl
{
using column_type = ColumnArray;
using data_type = DataTypeArray;
static bool needBoolean() { return false; }
static bool needExpression() { return false; }
static bool needOneArray() { return false; }

View File

@ -1,10 +1,10 @@
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnDecimal.h>
#include "FunctionArrayMapped.h"
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include "FunctionArrayMapped.h"
namespace DB
{
@ -19,6 +19,9 @@ namespace ErrorCodes
*/
struct ArrayCumSumNonNegativeImpl
{
using column_type = ColumnArray;
using data_type = DataTypeArray;
static bool needBoolean() { return false; }
static bool needExpression() { return false; }
static bool needOneArray() { return false; }

View File

@ -1,10 +1,11 @@
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnDecimal.h>
#include "FunctionArrayMapped.h"
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include "FunctionArrayMapped.h"
namespace DB
{
@ -20,6 +21,9 @@ namespace ErrorCodes
*/
struct ArrayDifferenceImpl
{
using column_type = ColumnArray;
using data_type = DataTypeArray;
static bool needBoolean() { return false; }
static bool needExpression() { return false; }
static bool needOneArray() { return false; }

View File

@ -1,8 +1,9 @@
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnsNumber.h>
#include "FunctionArrayMapped.h"
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include "FunctionArrayMapped.h"
namespace DB
{
@ -16,6 +17,9 @@ namespace ErrorCodes
*/
struct ArrayExistsImpl
{
using column_type = ColumnArray;
using data_type = DataTypeArray;
static bool needBoolean() { return true; }
static bool needExpression() { return false; }
static bool needOneArray() { return false; }

View File

@ -1,8 +1,9 @@
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnsNumber.h>
#include "FunctionArrayMapped.h"
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include "FunctionArrayMapped.h"
namespace DB
{
@ -19,6 +20,9 @@ namespace ErrorCodes
template <bool reverse>
struct ArrayFillImpl
{
using column_type = ColumnArray;
using data_type = DataTypeArray;
static bool needBoolean() { return true; }
static bool needExpression() { return true; }
static bool needOneArray() { return false; }

View File

@ -1,8 +1,9 @@
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnsNumber.h>
#include "FunctionArrayMapped.h"
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include "FunctionArrayMapped.h"
namespace DB
{
@ -15,6 +16,9 @@ namespace ErrorCodes
*/
struct ArrayFilterImpl
{
using column_type = ColumnArray;
using data_type = DataTypeArray;
static bool needBoolean() { return true; }
static bool needExpression() { return true; }
static bool needOneArray() { return false; }

View File

@ -1,8 +1,9 @@
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnsNumber.h>
#include "FunctionArrayMapped.h"
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include "FunctionArrayMapped.h"
namespace DB
{
@ -20,6 +21,9 @@ enum class ArrayFirstLastStrategy
template <ArrayFirstLastStrategy strategy>
struct ArrayFirstLastImpl
{
using column_type = ColumnArray;
using data_type = DataTypeArray;
static bool needBoolean() { return false; }
static bool needExpression() { return true; }
static bool needOneArray() { return false; }

View File

@ -1,8 +1,9 @@
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnsNumber.h>
#include "FunctionArrayMapped.h"
#include <Functions/FunctionFactory.h>
#include "FunctionArrayMapped.h"
namespace DB
{
@ -20,6 +21,9 @@ enum class ArrayFirstLastIndexStrategy
template <ArrayFirstLastIndexStrategy strategy>
struct ArrayFirstLastIndexImpl
{
using column_type = ColumnArray;
using data_type = DataTypeArray;
static bool needBoolean() { return false; }
static bool needExpression() { return true; }
static bool needOneArray() { return false; }

View File

@ -1,14 +1,18 @@
#include "FunctionArrayMapped.h"
#include <Functions/FunctionFactory.h>
#include "FunctionArrayMapped.h"
namespace DB
{
/** arrayMap(x1,...,xn -> expression, array1,...,arrayn) - apply the expression to each element of the array (or set of parallel arrays).
/** arrayMap(x1, ..., xn -> expression, array1, ..., arrayn) - apply the expression to each element of the array (or set of parallel arrays).
*/
struct ArrayMapImpl
{
using column_type = ColumnArray;
using data_type = DataTypeArray;
/// true if the expression (for an overload of f(expression, arrays)) or an array (for f(array)) should be boolean.
static bool needBoolean() { return false; }
/// true if the f(array) overload is unavailable.

View File

@ -1,8 +1,8 @@
#include "FunctionArrayMapped.h"
#include <base/sort.h>
#include <Functions/FunctionFactory.h>
namespace DB
{
@ -11,6 +11,9 @@ namespace DB
template <bool positive>
struct ArraySortImpl
{
using column_type = ColumnArray;
using data_type = DataTypeArray;
static bool needBoolean() { return false; }
static bool needExpression() { return false; }
static bool needOneArray() { return false; }

View File

@ -1,8 +1,9 @@
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnsNumber.h>
#include "FunctionArrayMapped.h"
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include "FunctionArrayMapped.h"
namespace DB
{
@ -14,6 +15,9 @@ namespace ErrorCodes
template <bool reverse>
struct ArraySplitImpl
{
using column_type = ColumnArray;
using data_type = DataTypeArray;
static bool needBoolean() { return true; }
static bool needExpression() { return true; }
static bool needOneArray() { return false; }

View File

@ -518,6 +518,115 @@ public:
}
};
class FunctionMapUpdate : public IFunction
{
public:
static constexpr auto name = "mapUpdate";
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionMapUpdate>(); }
String getName() const override
{
return name;
}
size_t getNumberOfArguments() const override { return 2; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if (arguments.size() != 2)
throw Exception("Number of arguments for function " + getName() + " doesn't match: passed "
+ toString(arguments.size()) + ", should be 2",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
const DataTypeMap * left = checkAndGetDataType<DataTypeMap>(arguments[0].type.get());
const DataTypeMap * right = checkAndGetDataType<DataTypeMap>(arguments[1].type.get());
if (!left || !right)
throw Exception{"The two arguments for function " + getName() + " must be both Map type",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!left->getKeyType()->equals(*right->getKeyType()) || !left->getValueType()->equals(*right->getValueType()))
throw Exception{"The Key And Value type of Map for function " + getName() + " must be the same",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
return std::make_shared<DataTypeMap>(left->getKeyType(), left->getValueType());
}
bool useDefaultImplementationForConstants() const override { return true; }
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override
{
const ColumnMap * col_map_left = typeid_cast<const ColumnMap *>(arguments[0].column.get());
const auto * col_const_map_left = checkAndGetColumnConst<ColumnMap>(arguments[0].column.get());
if (col_const_map_left)
col_map_left = typeid_cast<const ColumnMap *>(&col_const_map_left->getDataColumn());
if (!col_map_left)
return nullptr;
const ColumnMap * col_map_right = typeid_cast<const ColumnMap *>(arguments[1].column.get());
const auto * col_const_map_right = checkAndGetColumnConst<ColumnMap>(arguments[1].column.get());
if (col_const_map_right)
col_map_right = typeid_cast<const ColumnMap *>(&col_const_map_right->getDataColumn());
if (!col_map_right)
return nullptr;
const auto & nested_column_left = col_map_left->getNestedColumn();
const auto & keys_data_left = col_map_left->getNestedData().getColumn(0);
const auto & values_data_left = col_map_left->getNestedData().getColumn(1);
const auto & offsets_left = nested_column_left.getOffsets();
const auto & nested_column_right = col_map_right->getNestedColumn();
const auto & keys_data_right = col_map_right->getNestedData().getColumn(0);
const auto & values_data_right = col_map_right->getNestedData().getColumn(1);
const auto & offsets_right = nested_column_right.getOffsets();
const auto & result_type_map = static_cast<const DataTypeMap &>(*result_type);
const DataTypePtr & key_type = result_type_map.getKeyType();
const DataTypePtr & value_type = result_type_map.getValueType();
MutableColumnPtr keys_data = key_type->createColumn();
MutableColumnPtr values_data = value_type->createColumn();
MutableColumnPtr offsets = DataTypeNumber<IColumn::Offset>().createColumn();
IColumn::Offset current_offset = 0;
for (size_t idx = 0; idx < input_rows_count; ++idx)
{
for (size_t i = offsets_left[idx - 1]; i < offsets_left[idx]; ++i)
{
bool matched = false;
auto key = keys_data_left.getDataAt(i);
for (size_t j = offsets_right[idx - 1]; j < offsets_right[idx]; ++j)
{
if (keys_data_right.getDataAt(j).toString() == key.toString())
{
matched = true;
break;
}
}
if (!matched)
{
keys_data->insertFrom(keys_data_left, i);
values_data->insertFrom(values_data_left, i);
++current_offset;
}
}
for (size_t j = offsets_right[idx - 1]; j < offsets_right[idx]; ++j)
{
keys_data->insertFrom(keys_data_right, j);
values_data->insertFrom(values_data_right, j);
++current_offset;
}
offsets->insert(current_offset);
}
auto nested_column = ColumnArray::create(
ColumnTuple::create(Columns{std::move(keys_data), std::move(values_data)}),
std::move(offsets));
return ColumnMap::create(nested_column);
}
};
}
void registerFunctionsMap(FunctionFactory & factory)
@ -528,6 +637,7 @@ void registerFunctionsMap(FunctionFactory & factory)
factory.registerFunction<FunctionMapValues>();
factory.registerFunction<FunctionMapContainsKeyLike>();
factory.registerFunction<FunctionExtractKeyLike>();
factory.registerFunction<FunctionMapUpdate>();
}
}

144
src/Functions/mapFilter.cpp Normal file
View File

@ -0,0 +1,144 @@
#include <Columns/ColumnMap.h>
#include <Columns/ColumnTuple.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypeTuple.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/array/FunctionArrayMapped.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
/** Higher-order functions for map.
* These functions optionally apply a map by lambda function,
* and return some result based on that transformation.
*/
/** mapFilter((k, v) -> predicate, map) - leave in the map only the kv elements for which the expression is true.
*/
struct MapFilterImpl
{
using data_type = DataTypeMap;
using column_type = ColumnMap;
static constexpr auto name = "mapFilter";
static bool needBoolean() { return true; }
static bool needExpression() { return true; }
static bool needOneArray() { return true; }
static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypes & elems)
{
return std::make_shared<DataTypeMap>(elems);
}
/// If there are several arrays, the first one is passed here.
static ColumnPtr execute(const ColumnMap & map_column, ColumnPtr mapped)
{
const ColumnUInt8 * column_filter = typeid_cast<const ColumnUInt8 *>(&*mapped);
if (!column_filter)
{
const auto * column_filter_const = checkAndGetColumnConst<ColumnUInt8>(&*mapped);
if (!column_filter_const)
throw Exception("Unexpected type of filter column", ErrorCodes::ILLEGAL_COLUMN);
if (column_filter_const->getValue<UInt8>())
return map_column.clone();
else
{
const auto * column_array = typeid_cast<const ColumnArray *>(map_column.getNestedColumnPtr().get());
const auto * column_tuple = typeid_cast<const ColumnTuple *>(column_array->getDataPtr().get());
ColumnPtr keys = column_tuple->getColumnPtr(0)->cloneEmpty();
ColumnPtr values = column_tuple->getColumnPtr(1)->cloneEmpty();
return ColumnMap::create(keys, values, ColumnArray::ColumnOffsets::create(map_column.size(), 0));
}
}
const IColumn::Filter & filter = column_filter->getData();
ColumnPtr filtered = map_column.getNestedColumn().getData().filter(filter, -1);
const IColumn::Offsets & in_offsets = map_column.getNestedColumn().getOffsets();
auto column_offsets = ColumnArray::ColumnOffsets::create(in_offsets.size());
IColumn::Offsets & out_offsets = column_offsets->getData();
size_t in_pos = 0;
size_t out_pos = 0;
for (size_t i = 0; i < in_offsets.size(); ++i)
{
for (; in_pos < in_offsets[i]; ++in_pos)
{
if (filter[in_pos])
++out_pos;
}
out_offsets[i] = out_pos;
}
return ColumnMap::create(ColumnArray::create(filtered, std::move(column_offsets)));
}
};
/** mapApply((k,v) -> expression, map) - apply the expression to the map.
*/
struct MapApplyImpl
{
using data_type = DataTypeMap;
using column_type = ColumnMap;
static constexpr auto name = "mapApply";
/// true if the expression (for an overload of f(expression, maps)) or a map (for f(map)) should be boolean.
static bool needBoolean() { return false; }
static bool needExpression() { return true; }
static bool needOneArray() { return true; }
static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypes & /*elems*/)
{
const auto * tuple_types = typeid_cast<const DataTypeTuple *>(expression_return.get());
if (!tuple_types)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Expected return type is tuple, got {}", expression_return->getName());
if (tuple_types->getElements().size() != 2)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Expected 2 columns as map's key and value, but found {}", tuple_types->getElements().size());
return std::make_shared<DataTypeMap>(tuple_types->getElements());
}
static ColumnPtr execute(const ColumnMap & map, ColumnPtr mapped)
{
const auto * column_tuple = checkAndGetColumn<ColumnTuple>(mapped.get());
if (!column_tuple)
{
const ColumnConst * column_const_tuple = checkAndGetColumnConst<ColumnTuple>(mapped.get());
if (!column_const_tuple)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Expected tuple column, found {}", mapped->getName());
auto cols = convertConstTupleToConstantElements(*column_const_tuple);
return ColumnMap::create(cols[0]->convertToFullColumnIfConst(), cols[1]->convertToFullColumnIfConst(), map.getNestedColumn().getOffsetsPtr());
}
return ColumnMap::create(column_tuple->getColumnPtr(0), column_tuple->getColumnPtr(1),
map.getNestedColumn().getOffsetsPtr());
}
};
void registerFunctionMapApply(FunctionFactory & factory)
{
factory.registerFunction<FunctionArrayMapped<MapFilterImpl, MapFilterImpl>>();
factory.registerFunction<FunctionArrayMapped<MapApplyImpl, MapApplyImpl>>();
}
}

View File

@ -18,6 +18,7 @@ void registerFunctionsArraySort(FunctionFactory & factory);
void registerFunctionArrayCumSum(FunctionFactory & factory);
void registerFunctionArrayCumSumNonNegative(FunctionFactory & factory);
void registerFunctionArrayDifference(FunctionFactory & factory);
void registerFunctionMapApply(FunctionFactory & factory);
void registerFunctionsHigherOrder(FunctionFactory & factory)
{
@ -36,6 +37,7 @@ void registerFunctionsHigherOrder(FunctionFactory & factory)
registerFunctionArrayCumSum(factory);
registerFunctionArrayCumSumNonNegative(factory);
registerFunctionArrayDifference(factory);
registerFunctionMapApply(factory);
}
}

View File

@ -0,0 +1,33 @@
{}
{}
{}
{'key3':103}
{}
{}
{}
{'key3':100,'key2':101,'key4':102} {'key4':102}
{'key3':101,'key2':102,'key4':103} {'key2':102,'key4':103}
{'key3':102,'key2':103,'key4':104} {'key3':102,'key2':103,'key4':104}
{'key3':103,'key2':104,'key4':105} {'key3':103,'key2':104,'key4':105}
{'key1':1111,'key2':2222} {'key2':2222}
{'key1':1112,'key2':2224} {'key1':1112,'key2':2224}
{'key1':1113,'key2':2226} {'key1':1113,'key2':2226}
{'key3':101,'key2':102,'key4':103}
{'key3':102,'key2':103,'key4':104}
{'key3':103,'key2':104,'key4':105}
{'key3':104,'key2':105,'key4':106}
{'key1':1112,'key2':2223}
{'key1':1113,'key2':2225}
{'key1':1114,'key2':2227}
{}
{}
{}
{}
{}
{}
{}
{3:2,1:0,2:0}
{1:2,2:3}
{1:2,2:3}
{'x':'y','x':'y'}
{'x':'y','x':'y'}

View File

@ -0,0 +1,39 @@
DROP TABLE IF EXISTS table_map;
CREATE TABLE table_map (id UInt32, col Map(String, UInt64)) engine = MergeTree() ORDER BY tuple();
INSERT INTO table_map SELECT number, map('key1', number, 'key2', number * 2) FROM numbers(1111, 3);
INSERT INTO table_map SELECT number, map('key3', number, 'key2', number + 1, 'key4', number + 2) FROM numbers(100, 4);
SELECT mapFilter((k, v) -> k like '%3' and v > 102, col) FROM table_map ORDER BY id;
SELECT col, mapFilter((k, v) -> ((v % 10) > 1), col) FROM table_map ORDER BY id ASC;
SELECT mapApply((k, v) -> (k, v + 1), col) FROM table_map ORDER BY id;
SELECT mapFilter((k, v) -> 0, col) from table_map;
SELECT mapApply((k, v) -> tuple(v + 9223372036854775806), col) FROM table_map; -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT mapUpdate(map(1, 3, 3, 2), map(1, 0, 2, 0));
SELECT mapApply((x, y) -> (x, x + 1), map(1, 0, 2, 0));
SELECT mapApply((x, y) -> (x, x + 1), materialize(map(1, 0, 2, 0)));
SELECT mapApply((x, y) -> ('x', 'y'), map(1, 0, 2, 0));
SELECT mapApply((x, y) -> ('x', 'y'), materialize(map(1, 0, 2, 0)));
SELECT mapApply(); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT mapApply((x, y) -> (x), map(1, 0, 2, 0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT mapApply((x, y) -> ('x'), map(1, 0, 2, 0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT mapApply((x) -> (x, x), map(1, 0, 2, 0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT mapApply((x, y) -> (x, 1, 2), map(1, 0, 2, 0)); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT mapApply((x, y) -> (x, x + 1)); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT mapApply(map(1, 0, 2, 0), (x, y) -> (x, x + 1)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT mapApply((x, y) -> (x, x+1), map(1, 0, 2, 0), map(1, 0, 2, 0)); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT mapFilter(); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT mapFilter((x, y) -> (toInt32(x)), map(1, 0, 2, 0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT mapFilter((x, y) -> ('x'), map(1, 0, 2, 0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT mapFilter((x) -> (x, x), map(1, 0, 2, 0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT mapFilter((x, y) -> (x, 1, 2), map(1, 0, 2, 0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT mapFilter((x, y) -> (x, x + 1)); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT mapFilter(map(1, 0, 2, 0), (x, y) -> (x > 0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT mapFilter((x, y) -> (x, x + 1), map(1, 0, 2, 0), map(1, 0, 2, 0)); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT mapUpdate(); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT mapUpdate(map(1, 3, 3, 2), map(1, 0, 2, 0), map(1, 0, 2, 0)); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
DROP TABLE table_map;