Fix functions with Array(LowCardinality) arguments. #3004

This commit is contained in:
Nikolai Kochetov 2018-09-20 15:46:55 +03:00
parent 19d00cde7e
commit d66527ec15

View File

@ -2,11 +2,16 @@
#include <Common/typeid_cast.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnNullable.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeNothing.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/Native.h>
#include <DataTypes/DataTypeWithDictionary.h>
#include <DataTypes/getLeastSupertype.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnTuple.h>
#include <Columns/ColumnWithDictionary.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
@ -36,6 +41,55 @@ namespace ErrorCodes
}
static DataTypePtr recursiveRemoveLowCardinality(const DataTypePtr & type)
{
if (!type)
return type;
if (const auto * array_type = typeid_cast<const DataTypeArray *>(type.get()))
return std::make_shared<DataTypeArray>(recursiveRemoveLowCardinality(array_type->getNestedType()));
if (const auto * tuple_type = typeid_cast<const DataTypeTuple *>(type.get()))
{
DataTypes elements = tuple_type->getElements();
for (auto & element : elements)
element = recursiveRemoveLowCardinality(element);
return std::make_shared<DataTypeTuple>(elements);
}
if (const auto * low_cardinality_type = typeid_cast<const DataTypeWithDictionary *>(type.get()))
return low_cardinality_type->getDictionaryType();
return type;
}
static ColumnPtr recursiveRemoveLowCardinality(const ColumnPtr & column)
{
if (!column)
return column;
if (const auto * column_array = typeid_cast<const ColumnArray *>(column.get()))
return ColumnArray::create(recursiveRemoveLowCardinality(column_array->getDataPtr()), column_array->getOffsetsPtr());
if (const auto * column_const = typeid_cast<const ColumnConst *>(column.get()))
return ColumnConst::create(recursiveRemoveLowCardinality(column_const->getDataColumnPtr()), column_const->size());
if (const auto * column_tuple = typeid_cast<const ColumnTuple *>(column.get()))
{
Columns columns = column_tuple->getColumns();
for (auto & element : columns)
element = recursiveRemoveLowCardinality(element);
return ColumnTuple::create(columns);
}
if (const auto * column_low_cardinality = typeid_cast<const ColumnWithDictionary *>(column.get()))
return column_low_cardinality->convertToFullColumn();
return column;
}
ColumnPtr wrapInNullable(const ColumnPtr & src, const Block & block, const ColumnNumbers & args, size_t result, size_t input_rows_count)
{
ColumnPtr result_null_map_column;
@ -286,19 +340,9 @@ static void convertColumnsWithDictionaryToFull(Block & block, const ColumnNumber
for (auto arg : args)
{
ColumnWithTypeAndName & column = block.getByPosition(arg);
if (auto * column_const = checkAndGetColumn<ColumnConst>(column.column.get()))
column.column = column_const->removeLowCardinality();
else if (auto * column_with_dict = checkAndGetColumn<ColumnWithDictionary>(column.column.get()))
{
auto * type_with_dict = checkAndGetDataType<DataTypeWithDictionary>(column.type.get());
if (!type_with_dict)
throw Exception("Incompatible type for column with dictionary: " + column.type->getName(),
ErrorCodes::LOGICAL_ERROR);
column.column = column_with_dict->convertToFullColumn();
column.type = type_with_dict->getDictionaryType();
}
column.column = recursiveRemoveLowCardinality(column.column);
column.type = recursiveRemoveLowCardinality(column.type);
}
}
@ -451,7 +495,6 @@ llvm::Value * IFunction::compile(llvm::IRBuilderBase & builder, const DataTypes
#endif
DataTypePtr FunctionBuilderImpl::getReturnType(const ColumnsWithTypeAndName & arguments) const
{
if (useDefaultImplementationForColumnsWithDictionary())
@ -477,6 +520,12 @@ DataTypePtr FunctionBuilderImpl::getReturnType(const ColumnsWithTypeAndName & ar
}
}
for (auto & arg : args_without_dictionary)
{
arg.column = recursiveRemoveLowCardinality(arg.column);
arg.type = recursiveRemoveLowCardinality(arg.type);
}
if (canBeExecutedOnLowCardinalityDictionary() && has_low_cardinality && num_full_low_cardinality_columns <= 1)
return std::make_shared<DataTypeWithDictionary>(getReturnTypeWithoutDictionary(args_without_dictionary));
else