diff --git a/src/Functions/tupleElement.cpp b/src/Functions/tupleElement.cpp index 3763d878b1d..68ee7d07152 100644 --- a/src/Functions/tupleElement.cpp +++ b/src/Functions/tupleElement.cpp @@ -20,6 +20,8 @@ namespace ErrorCodes extern const int ILLEGAL_INDEX; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int NOT_FOUND_COLUMN_IN_BLOCK; + extern const int NUMBER_OF_DIMENSIONS_MISMATHED; + extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; } namespace @@ -56,7 +58,7 @@ public: ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { - return {1, 2}; + return {1}; } bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; } @@ -71,7 +73,6 @@ public: ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); size_t count_arrays = 0; - const IDataType * tuple_col = arguments[0].type.get(); while (const DataTypeArray * array = checkAndGetDataType(tuple_col)) { @@ -83,21 +84,33 @@ public: if (!tuple) throw Exception("First argument for function " + getName() + " must be tuple or array of tuple.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - size_t index = 0; - if (!getElementNum(arguments[1].column, *tuple, index, number_of_arguments)) + auto index = getElementNum(arguments[1].column, *tuple, number_of_arguments); + if (index.has_value()) { + DataTypePtr out_return_type = tuple->getElements()[index.value()]; + + for (; count_arrays; --count_arrays) + out_return_type = std::make_shared(out_return_type); + + return out_return_type; + } else + { + const IDataType * default_col = arguments[2].type.get(); + size_t default_argument_count_arrays = 0; + if (const DataTypeArray * array = checkAndGetDataType(default_col)) + { + default_argument_count_arrays = array->getNumberOfDimensions(); + } + + if (count_arrays != default_argument_count_arrays) + { + throw Exception(ErrorCodes::NUMBER_OF_DIMENSIONS_MISMATHED, "Dimension of types mismatched between first argument and third argument. Dimension of 1st argument: {}. Dimension of 3rd argument: {}.",count_arrays, default_argument_count_arrays); + } return arguments[2].type; } - - DataTypePtr out_return_type = tuple->getElements()[index]; - - for (; count_arrays; --count_arrays) - out_return_type = std::make_shared(out_return_type); - - return out_return_type; } - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override { Columns array_offsets; @@ -105,6 +118,12 @@ public: const IDataType * tuple_type = first_arg.type.get(); const IColumn * tuple_col = first_arg.column.get(); + bool first_arg_is_const = false; + if (typeid_cast(tuple_col)) + { + tuple_col = assert_cast(tuple_col)->getDataColumnPtr().get(); + first_arg_is_const = true; + } while (const DataTypeArray * array_type = checkAndGetDataType(tuple_type)) { const ColumnArray * array_col = assert_cast(tuple_col); @@ -119,22 +138,74 @@ public: if (!tuple_type_concrete || !tuple_col_concrete) throw Exception("First argument for function " + getName() + " must be tuple or array of tuple.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - size_t index = 0; - if (!getElementNum(arguments[1].column, *tuple_type_concrete, index, arguments.size())) + auto index = getElementNum(arguments[1].column, *tuple_type_concrete, arguments.size()); + + if (!index.has_value()) { + if (!array_offsets.empty()) + { + CheckArrayOffsets(arguments[0].column, arguments[2].column); + } return arguments[2].column; } - ColumnPtr res = tuple_col_concrete->getColumns()[index]; + + ColumnPtr res = tuple_col_concrete->getColumns()[index.value()]; /// Wrap into Arrays for (auto it = array_offsets.rbegin(); it != array_offsets.rend(); ++it) res = ColumnArray::create(res, *it); + if (first_arg_is_const) + { + res = ColumnConst::create(res, input_rows_count); + } return res; } private: - bool getElementNum(const ColumnPtr & index_column, const DataTypeTuple & tuple, size_t & index, const size_t argument_size) const + + void CheckArrayOffsets(ColumnPtr col_x, ColumnPtr col_y) const + { + if (isColumnConst(*col_x)) + { + CheckArrayOffsetsWithFirstArgConst(col_x, col_y); + } else if (isColumnConst(*col_y)) + { + CheckArrayOffsetsWithFirstArgConst(col_y, col_x); + } else + { + const auto & array_x = *assert_cast(col_x.get()); + const auto & array_y = *assert_cast(col_y.get()); + if (!array_x.hasEqualOffsets(array_y)) + { + throw Exception("The argument 1 and argument 3 of function " + getName() + " have different array sizes", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH); + } + } + } + + void CheckArrayOffsetsWithFirstArgConst(ColumnPtr col_x, ColumnPtr col_y) const + { + col_x = assert_cast(col_x.get())->getDataColumnPtr(); + col_y = col_y->convertToFullColumnIfConst(); + const auto & array_x = *assert_cast(col_x.get()); + const auto & array_y = *assert_cast(col_y.get()); + + const auto & offsets_x = array_x.getOffsets(); + const auto & offsets_y = array_y.getOffsets(); + + ColumnArray::Offset prev_offset = 0; + size_t row_size = offsets_y.size(); + for (size_t row = 0; row < row_size; ++row) + { + if (unlikely(offsets_x[0] != offsets_y[row] - prev_offset)) + { + throw Exception("The argument 1 and argument 3 of function " + getName() + " have different array sizes", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH); + } + prev_offset = offsets_y[row]; + } + } + + std::optional getElementNum(const ColumnPtr & index_column, const DataTypeTuple & tuple, const size_t argument_size) const { if ( checkAndGetColumnConst(index_column.get()) @@ -143,29 +214,29 @@ private: || checkAndGetColumnConst(index_column.get()) ) { - index = index_column->getUInt(0); + size_t index = index_column->getUInt(0); if (index == 0) throw Exception("Indices in tuples are 1-based.", ErrorCodes::ILLEGAL_INDEX); if (index > tuple.getElements().size()) throw Exception("Index for tuple element is out of range.", ErrorCodes::ILLEGAL_INDEX); - index--; - return true; + + return std::optional(index - 1); } else if (const auto * name_col = checkAndGetColumnConst(index_column.get())) { - if (tuple.getPositionByName(name_col->getValue(), index)) + auto index = tuple.tryGetPositionByName(name_col->getValue()); + if (index.has_value()) { - return true; + return index; } if (argument_size == 2) { throw Exception("Tuple doesn't have element with name '" + name_col->getValue() + "'", ErrorCodes::NOT_FOUND_COLUMN_IN_BLOCK); } - - return false; + return std::nullopt; } else throw Exception("Second argument to " + getName() + " must be a constant UInt or String", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);