From 8f8f6f966b222e5a5f21734c906caa9ad514c1e5 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Wed, 21 Dec 2022 20:52:16 +0800 Subject: [PATCH] Optimization for reading struct fields in parquet/orc files --- .../Formats/Impl/ArrowFormatUtil.cpp | 135 ++++++++++++++++++ src/Processors/Formats/Impl/ArrowFormatUtil.h | 33 +++++ .../Formats/Impl/ORCBlockInputFormat.cpp | 50 +------ .../Formats/Impl/ParquetBlockInputFormat.cpp | 48 +------ 4 files changed, 177 insertions(+), 89 deletions(-) create mode 100644 src/Processors/Formats/Impl/ArrowFormatUtil.cpp create mode 100644 src/Processors/Formats/Impl/ArrowFormatUtil.h diff --git a/src/Processors/Formats/Impl/ArrowFormatUtil.cpp b/src/Processors/Formats/Impl/ArrowFormatUtil.cpp new file mode 100644 index 00000000000..dcb350917bb --- /dev/null +++ b/src/Processors/Formats/Impl/ArrowFormatUtil.cpp @@ -0,0 +1,135 @@ +#include "ArrowFormatUtil.h" +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} +size_t ArrowFormatUtil::countIndicesForType(std::shared_ptr type) +{ + if (type->id() == arrow::Type::LIST) + { + auto ret = countIndicesForType(static_cast(type.get())->value_type()); + if (nested_type_has_index) + return ret + 1; + } + + if (type->id() == arrow::Type::STRUCT) + { + int indices = nested_type_has_index ? 1 : 0; + auto * struct_type = static_cast(type.get()); + for (int i = 0; i != struct_type->num_fields(); ++i) + indices += countIndicesForType(struct_type->field(i)->type()); + return indices; + } + + if (type->id() == arrow::Type::MAP) + { + auto * map_type = static_cast(type.get()); + auto ret = countIndicesForType(map_type->key_type()) + countIndicesForType(map_type->item_type()); + if (nested_type_has_index) + return ret + 1; + } + + return 1; +} + +std::map> +ArrowFormatUtil::calculateFieldIndices(const arrow::Schema & schema) +{ + std::map> result; + int index_start = nested_type_has_index ? 1 : 0; + for (int i = 0; i < schema.num_fields(); ++i) + { + const auto & field = schema.field(i); + calculateFieldIndices(*field, index_start, result); + } + return result; +} + +void ArrowFormatUtil::calculateFieldIndices(const arrow::Field & field, + int & current_start_index, + std::map> & result, + const std::string & name_prefix) +{ + std::string field_name = field.name(); + const auto & field_type = field.type(); + if (field_name.empty()) + { + current_start_index += countIndicesForType(field_type); + return; + } + if (ignore_case) + { + boost::to_lower(field_name); + } + + std::string full_path_name = name_prefix.empty() ? field_name : name_prefix + "." + field_name; + auto & index_info = result[full_path_name]; + index_info.first = current_start_index; + if (field_type->id() == arrow::Type::STRUCT) + { + if (nested_type_has_index) + current_start_index += 1; + + auto * struct_type = static_cast(field_type.get()); + for (int i = 0, n = struct_type->num_fields(); i < n ; ++i) + { + const auto & sub_field = struct_type->field(i); + calculateFieldIndices(*sub_field, current_start_index, result, full_path_name); + } + } + else + { + current_start_index += countIndicesForType(field_type); + } + index_info.second = current_start_index - index_info.first; +} + +std::vector ArrowFormatUtil::findRequiredIndices(const Block & header, + const arrow::Schema & schema) +{ + std::vector required_indices; + std::set added_nested_table; + std::set added_indices; + auto fields_indices = calculateFieldIndices(schema); + for (size_t i = 0; i < header.columns(); ++i) + { + const auto & named_col = header.getByPosition(i); + std::string col_name = named_col.name; + if (ignore_case) + boost::to_lower(col_name); + if (!import_nested) + { + col_name = Nested::splitName(col_name).first; + if (added_nested_table.count(col_name)) + continue; + added_nested_table.insert(col_name); + } + auto it = fields_indices.find(col_name); + if (it == fields_indices.end()) + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "Not found field({}) in arrow schema:{}", + named_col.name, schema.ToString()); + } + for (int j = 0; j < it->second.second; ++j) + { + auto index = it->second.first + j; + if (!added_indices.count(index)) + { + required_indices.emplace_back(index); + added_indices.insert(index); + } + } + } + return required_indices; +} +} diff --git a/src/Processors/Formats/Impl/ArrowFormatUtil.h b/src/Processors/Formats/Impl/ArrowFormatUtil.h new file mode 100644 index 00000000000..027426a6486 --- /dev/null +++ b/src/Processors/Formats/Impl/ArrowFormatUtil.h @@ -0,0 +1,33 @@ +#pragma once +#include +#include +#include +#include +#include "DataTypes/Serializations/ISerialization.h" +namespace DB +{ +class ArrowFormatUtil +{ +public: + explicit ArrowFormatUtil(bool ignore_case_, bool import_nested_, bool nested_type_has_index_) + : ignore_case(ignore_case_) + , import_nested(import_nested_) + , nested_type_has_index(nested_type_has_index_){} + ~ArrowFormatUtil() = default; + + std::map> + calculateFieldIndices(const arrow::Schema & schema); + + std::vector findRequiredIndices(const Block & header, const arrow::Schema & schema); + + size_t countIndicesForType(std::shared_ptr type); + +private: + bool ignore_case; + bool import_nested; + bool nested_type_has_index; + void calculateFieldIndices(const arrow::Field & field, + int & current_start_index, + std::map> & result, const std::string & name_prefix = ""); +}; +} diff --git a/src/Processors/Formats/Impl/ORCBlockInputFormat.cpp b/src/Processors/Formats/Impl/ORCBlockInputFormat.cpp index 19a0b2eb23c..952e964a3bf 100644 --- a/src/Processors/Formats/Impl/ORCBlockInputFormat.cpp +++ b/src/Processors/Formats/Impl/ORCBlockInputFormat.cpp @@ -9,8 +9,12 @@ #include #include "ArrowBufferedStreams.h" #include "ArrowColumnToCHColumn.h" +#include "ArrowFormatUtil.h" #include +#include +#include + namespace DB { @@ -89,28 +93,6 @@ const BlockMissingValues & ORCBlockInputFormat::getMissingValues() const return block_missing_values; } -static size_t countIndicesForType(std::shared_ptr type) -{ - if (type->id() == arrow::Type::LIST) - return countIndicesForType(static_cast(type.get())->value_type()) + 1; - - if (type->id() == arrow::Type::STRUCT) - { - int indices = 1; - auto * struct_type = static_cast(type.get()); - for (int i = 0; i != struct_type->num_fields(); ++i) - indices += countIndicesForType(struct_type->field(i)->type()); - return indices; - } - - if (type->id() == arrow::Type::MAP) - { - auto * map_type = static_cast(type.get()); - return countIndicesForType(map_type->key_type()) + countIndicesForType(map_type->item_type()) + 1; - } - - return 1; -} static void getFileReaderAndSchema( ReadBuffer & in, @@ -152,28 +134,8 @@ void ORCBlockInputFormat::prepareReader() format_settings.orc.case_insensitive_column_matching); missing_columns = arrow_column_to_ch_column->getMissingColumns(*schema); - const bool ignore_case = format_settings.orc.case_insensitive_column_matching; - std::unordered_set nested_table_names; - if (format_settings.orc.import_nested) - nested_table_names = Nested::getAllTableNames(getPort().getHeader(), ignore_case); - - /// In ReadStripe column indices should be started from 1, - /// because 0 indicates to select all columns. - int index = 1; - for (int i = 0; i < schema->num_fields(); ++i) - { - /// LIST type require 2 indices, STRUCT - the number of elements + 1, - /// so we should recursively count the number of indices we need for this type. - int indexes_count = static_cast(countIndicesForType(schema->field(i)->type())); - const auto & name = schema->field(i)->name(); - if (getPort().getHeader().has(name, ignore_case) || nested_table_names.contains(ignore_case ? boost::to_lower_copy(name) : name)) - { - for (int j = 0; j != indexes_count; ++j) - include_indices.push_back(index + j); - } - - index += indexes_count; - } + ArrowFormatUtil format_util(format_settings.orc.case_insensitive_column_matching, format_settings.orc.import_nested, true); + include_indices = format_util.findRequiredIndices(getPort().getHeader(), *schema); } ORCSchemaReader::ORCSchemaReader(ReadBuffer & in_, const FormatSettings & format_settings_) diff --git a/src/Processors/Formats/Impl/ParquetBlockInputFormat.cpp b/src/Processors/Formats/Impl/ParquetBlockInputFormat.cpp index ad1d1ba85b9..f70066d07ac 100644 --- a/src/Processors/Formats/Impl/ParquetBlockInputFormat.cpp +++ b/src/Processors/Formats/Impl/ParquetBlockInputFormat.cpp @@ -14,6 +14,7 @@ #include #include "ArrowBufferedStreams.h" #include "ArrowColumnToCHColumn.h" +#include "ArrowFormatUtil.h" #include namespace DB @@ -95,29 +96,6 @@ const BlockMissingValues & ParquetBlockInputFormat::getMissingValues() const return block_missing_values; } -static size_t countIndicesForType(std::shared_ptr type) -{ - if (type->id() == arrow::Type::LIST) - return countIndicesForType(static_cast(type.get())->value_type()); - - if (type->id() == arrow::Type::STRUCT) - { - int indices = 0; - auto * struct_type = static_cast(type.get()); - for (int i = 0; i != struct_type->num_fields(); ++i) - indices += countIndicesForType(struct_type->field(i)->type()); - return indices; - } - - if (type->id() == arrow::Type::MAP) - { - auto * map_type = static_cast(type.get()); - return countIndicesForType(map_type->key_type()) + countIndicesForType(map_type->item_type()); - } - - return 1; -} - static void getFileReaderAndSchema( ReadBuffer & in, std::unique_ptr & file_reader, @@ -150,28 +128,8 @@ void ParquetBlockInputFormat::prepareReader() format_settings.parquet.case_insensitive_column_matching); missing_columns = arrow_column_to_ch_column->getMissingColumns(*schema); - const bool ignore_case = format_settings.parquet.case_insensitive_column_matching; - std::unordered_set nested_table_names; - if (format_settings.parquet.import_nested) - nested_table_names = Nested::getAllTableNames(getPort().getHeader(), ignore_case); - - int index = 0; - for (int i = 0; i < schema->num_fields(); ++i) - { - /// STRUCT type require the number of indexes equal to the number of - /// nested elements, so we should recursively - /// count the number of indices we need for this type. - int indexes_count = static_cast(countIndicesForType(schema->field(i)->type())); - const auto & name = schema->field(i)->name(); - - if (getPort().getHeader().has(name, ignore_case) || nested_table_names.contains(ignore_case ? boost::to_lower_copy(name) : name)) - { - for (int j = 0; j != indexes_count; ++j) - column_indices.push_back(index + j); - } - - index += indexes_count; - } + ArrowFormatUtil format_util(format_settings.parquet.case_insensitive_column_matching, format_settings.parquet.import_nested, false); + column_indices = format_util.findRequiredIndices(getPort().getHeader(), *schema); } ParquetSchemaReader::ParquetSchemaReader(ReadBuffer & in_, const FormatSettings & format_settings_)