opt if when input type is map

This commit is contained in:
taiyang-li 2024-01-31 17:24:51 +08:00
parent c339a74ac3
commit 2ad7607bad
4 changed files with 169 additions and 14 deletions

View File

@ -1,9 +1,19 @@
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <type_traits>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnMap.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnTuple.h>
#include <Columns/ColumnVector.h>
#include <Columns/MaskOperations.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeFixedString.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeMap.h>
#include <DataTypes/DataTypeNullable.h>
<<<<<<< HEAD
#include <DataTypes/DataTypeVariant.h>
#include <DataTypes/NumberTraits.h>
#include <DataTypes/getLeastSupertype.h>
@ -20,14 +30,28 @@
#include <Common/typeid_cast.h>
#include <Common/assert_cast.h>
#include <Functions/IFunction.h>
=======
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/NumberTraits.h>
#include <DataTypes/getLeastSupertype.h>
#include <Functions/FunctionFactory.h>
>>>>>>> 83434321f39... opt if when input type is map
#include <Functions/FunctionHelpers.h>
#include <Functions/GatherUtils/Algorithms.h>
#include <Functions/FunctionIfBase.h>
#include <Functions/GatherUtils/Algorithms.h>
#include <Functions/IFunction.h>
#include <Interpreters/castColumn.h>
<<<<<<< HEAD
#include <Interpreters/Context.h>
#include <Functions/FunctionFactory.h>
#include <type_traits>
=======
#include <Common/assert_cast.h>
#include <Common/typeid_cast.h>
>>>>>>> 83434321f39... opt if when input type is map
namespace DB
{
@ -36,6 +60,7 @@ namespace ErrorCodes
extern const int ILLEGAL_COLUMN;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NOT_IMPLEMENTED;
extern const int SIZES_OF_ARRAYS_DONT_MATCH;
}
namespace
@ -679,6 +704,87 @@ private:
return ColumnTuple::create(tuple_columns);
}
ColumnPtr executeMap(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
{
auto extract_kv_from_map = [](const ColumnMap * map)
{
const ColumnTuple & tuple = map->getNestedData();
const auto & keys = tuple.getColumnPtr(0);
const auto & values = tuple.getColumnPtr(1);
const auto & offsets = map->getNestedColumn().getOffsetsPtr();
return std::make_pair(ColumnArray::create(keys, offsets), ColumnArray::create(values, offsets));
};
/// Extract keys and values from both arguments
Columns key_cols(2);
Columns value_cols(2);
for (size_t i = 0; i < 2; ++i)
{
const auto & arg = arguments[i + 1];
if (const ColumnMap * map = checkAndGetColumn<ColumnMap>(arg.column.get()))
{
auto [key_col, value_col] = extract_kv_from_map(map);
key_cols[i] = std::move(key_col);
value_cols[i] = std::move(value_col);
}
else if (const ColumnConst * const_map = checkAndGetColumnConst<ColumnMap>(arg.column.get()))
{
const ColumnMap * map_data = assert_cast<const ColumnMap *>(&const_map->getDataColumn());
auto [key_col, value_col] = extract_kv_from_map(map_data);
size_t size = const_map->size();
key_cols[i] = ColumnConst::create(std::move(key_col), size);
value_cols[i] = ColumnConst::create(std::move(value_col), size);
}
else
return nullptr;
}
/// Compose temporary columns for keys and values
ColumnsWithTypeAndName key_columns(3);
key_columns[0] = arguments[0];
ColumnsWithTypeAndName value_columns(3);
value_columns[0] = arguments[0];
for (size_t i = 0; i < 2; ++i)
{
const auto & arg = arguments[i + 1];
const DataTypeMap & type = static_cast<const DataTypeMap &>(*arg.type);
const auto & key_type = type.getKeyType();
const auto & value_type = type.getValueType();
key_columns[i + 1] = {key_cols[i], key_type, {}};
value_columns[i + 1] = {value_cols[i], value_type, {}};
}
/// Calculate function corresponding keys and values in map
const DataTypeMap & map_result_type = static_cast<const DataTypeMap &>(*result_type);
auto key_result_type = std::make_shared<DataTypeArray>(map_result_type.getKeyType());
auto value_result_type = std::make_shared<DataTypeArray>(map_result_type.getValueType());
ColumnPtr key_result = executeImpl(key_columns, key_result_type, input_rows_count);
ColumnPtr value_result = executeImpl(value_columns, value_result_type, input_rows_count);
/// key_result and value_result are not constant columns otherwise we won't reach here in executeMap
const auto * key_array = assert_cast<const ColumnArray *>(key_result.get());
const auto * value_array = assert_cast<const ColumnArray *>(value_result.get());
if (!key_array)
throw Exception(
ErrorCodes::ILLEGAL_COLUMN, "Illegal key result column {} in executeMap for function {}", key_result->getName(), getName());
if (!value_array)
throw Exception(
ErrorCodes::ILLEGAL_COLUMN,
"Illegal value result column {} in executeMap for function {}",
value_result->getName(),
getName());
if (!key_array->hasEqualOffsets(*value_array))
throw Exception(
ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH,
"Key result column and value result column in executeMap for function {} must have equal sizes",
getName());
auto nested_column = ColumnArray::create(
ColumnTuple::create(Columns{key_array->getDataPtr(), value_array->getDataPtr()}), key_array->getOffsetsPtr());
return ColumnMap::create(std::move(nested_column));
}
static ColumnPtr executeGeneric(
const ColumnUInt8 * cond_col, const ColumnsWithTypeAndName & arguments, size_t input_rows_count, bool use_variant_when_no_common_type)
{
@ -1195,7 +1301,8 @@ public:
|| (res = executeTyped<UUID, UUID>(cond_col, arguments, result_type, input_rows_count))
|| (res = executeString(cond_col, arguments, result_type))
|| (res = executeGenericArray(cond_col, arguments, result_type))
|| (res = executeTuple(arguments, result_type, input_rows_count))))
|| (res = executeTuple(arguments, result_type, input_rows_count))
|| (res = executeMap(arguments, result_type, input_rows_count))))
{
return executeGeneric(cond_col, arguments, input_rows_count, use_variant_when_no_common_type);
}

View File

@ -9,4 +9,7 @@
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, 1, zero + 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, 1, 2)) ]]></query>
<!-- Tests when branches are maps -->
<query>with rand32() % 2 as x select if(x, map(1,2,3,4), map(3,4,5,6)) from numbers(1000000) format Null</query>
<query>with rand32() % 2 as x select if(x, materialize(map(1,2,3,4)), materialize(map(3,4,5,6))) from numbers(1000000) format Null</query>
</test>

View File

@ -0,0 +1,30 @@
{1:2,3:4}
{3:4,5:6}
{1:2,3:4}
{3:4,5:6}
{3:4,5:6}
{1:2,3:4}
{1:2,3:4}
{1:2,3:4}
{3:4,5:6}
{3:4,5:6}
{3:4,5:6}
{3:4,5:6}
{1:2,3:4}
{1:2,3:4}
{3:4,5:6}
{3:4,5:6}
{3:4,5:6}
{3:4,5:6}
{3:4,5:6}
{3:4,5:6}
{1:2,3:4}
{1:2,3:4}
{1:2,3:4}
{1:2,3:4}
{3:4,5:6}
{3:4,5:6}
{1:2,3:4}
{1:2,3:4}
{1:2,3:4}
{1:2,3:4}

View File

@ -0,0 +1,15 @@
select if(number % 2 = 0, map(1,2,3,4), map(3,4,5,6)) from numbers(2);
select if(number % 2 = 0, materialize(map(1,2,3,4)), map(3,4,5,6)) from numbers(2);
select if(number % 2 = 0, map(3,4,5,6), materialize(map(1,2,3,4))) from numbers(2);
select if(1, map(1,2,3,4), map(3,4,5,6)) from numbers(2);
select if(0, map(1,2,3,4), map(3,4,5,6)) from numbers(2);
select if(null, map(1,2,3,4), map(3,4,5,6)) from numbers(2);
select if(1, materialize(map(1,2,3,4)), map(3,4,5,6)) from numbers(2);
select if(0, materialize(map(1,2,3,4)), map(3,4,5,6)) from numbers(2);
select if(null, materialize(map(1,2,3,4)), map(3,4,5,6)) from numbers(2);
select if(1, map(3,4,5,6), materialize(map(1,2,3,4))) from numbers(2);
select if(0, map(3,4,5,6), materialize(map(1,2,3,4))) from numbers(2);
select if(null, map(3,4,5,6), materialize(map(1,2,3,4))) from numbers(2);
select if(1, materialize(map(3,4,5,6)), materialize(map(1,2,3,4))) from numbers(2);
select if(0, materialize(map(3,4,5,6)), materialize(map(1,2,3,4))) from numbers(2);
select if(null, materialize(map(3,4,5,6)), materialize(map(1,2,3,4))) from numbers(2);