Add ArrowStream input and output format

This commit is contained in:
hcz 2020-05-21 12:07:47 +08:00
parent 4a4914361c
commit e11fa03bdd
4 changed files with 91 additions and 32 deletions

View File

@ -16,12 +16,12 @@ namespace DB
namespace ErrorCodes namespace ErrorCodes
{ {
extern const int BAD_ARGUMENTS; extern const int UNKNOWN_EXCEPTION;
extern const int CANNOT_READ_ALL_DATA; extern const int CANNOT_READ_ALL_DATA;
} }
ArrowBlockInputFormat::ArrowBlockInputFormat(ReadBuffer & in_, const Block & header_) ArrowBlockInputFormat::ArrowBlockInputFormat(ReadBuffer & in_, const Block & header_, bool stream_)
: IInputFormat(header_, in_) : IInputFormat(header_, in_), stream{stream_}
{ {
prepareReader(); prepareReader();
} }
@ -31,11 +31,21 @@ Chunk ArrowBlockInputFormat::generate()
Chunk res; Chunk res;
const Block & header = getPort().getHeader(); const Block & header = getPort().getHeader();
if (record_batch_current >= record_batch_total) if (!stream && record_batch_current >= record_batch_total)
return res; return res;
std::vector<std::shared_ptr<arrow::RecordBatch>> single_batch(1); std::vector<std::shared_ptr<arrow::RecordBatch>> single_batch(1);
arrow::Status read_status = file_reader->ReadRecordBatch(record_batch_current, &single_batch[0]); arrow::Status read_status;
// Fetch the next record batch from the reader that matches the format flavour:
// sequential reads for ArrowStream, indexed reads for the Arrow file format.
if (stream)
    read_status = stream_reader->ReadNext(&single_batch[0]);
else
    read_status = file_reader->ReadRecordBatch(record_batch_current, &single_batch[0]);

// A successful read that produced no batch means there is nothing more to return.
// For the stream reader this is how end of stream is reported (OK status, null batch),
// and the batch-count guard above does not apply to streams, so it must be handled here.
// The check is skipped when the read failed so that the status check below can
// report the error instead of it being silently turned into an empty chunk.
if (read_status.ok() && !single_batch[0])
    return res;
if (!read_status.ok()) if (!read_status.ok())
throw Exception{"Error while reading batch of Arrow data: " + read_status.ToString(), throw Exception{"Error while reading batch of Arrow data: " + read_status.ToString(),
ErrorCodes::CANNOT_READ_ALL_DATA}; ErrorCodes::CANNOT_READ_ALL_DATA};
@ -57,16 +67,30 @@ void ArrowBlockInputFormat::resetParser()
{ {
IInputFormat::resetParser(); IInputFormat::resetParser();
if (stream)
stream_reader.reset();
else
file_reader.reset(); file_reader.reset();
prepareReader(); prepareReader();
} }
void ArrowBlockInputFormat::prepareReader() void ArrowBlockInputFormat::prepareReader()
{ {
arrow::Status open_status = arrow::ipc::RecordBatchFileReader::Open(asArrowFile(in), &file_reader); arrow::Status status;
if (!open_status.ok())
throw Exception(open_status.ToString(), ErrorCodes::BAD_ARGUMENTS); if (stream)
status = arrow::ipc::RecordBatchStreamReader::Open(asArrowFile(in), &stream_reader);
else
status = arrow::ipc::RecordBatchFileReader::Open(asArrowFile(in), &file_reader);
if (!status.ok())
throw Exception{"Error while opening a table: " + status.ToString(), ErrorCodes::UNKNOWN_EXCEPTION};
if (stream)
record_batch_total = -1;
else
record_batch_total = file_reader->num_record_batches(); record_batch_total = file_reader->num_record_batches();
record_batch_current = 0; record_batch_current = 0;
} }
@ -79,7 +103,17 @@ void registerInputFormatProcessorArrow(FormatFactory &factory)
const RowInputFormatParams & /* params */, const RowInputFormatParams & /* params */,
const FormatSettings & /* format_settings */) const FormatSettings & /* format_settings */)
{ {
return std::make_shared<ArrowBlockInputFormat>(buf, sample); return std::make_shared<ArrowBlockInputFormat>(buf, sample, false);
});
factory.registerInputFormatProcessor(
"ArrowStream",
[](ReadBuffer & buf,
const Block & sample,
const RowInputFormatParams & /* params */,
const FormatSettings & /* format_settings */)
{
return std::make_shared<ArrowBlockInputFormat>(buf, sample, true);
}); });
} }

View File

@ -4,6 +4,7 @@
#include <Processors/Formats/IInputFormat.h> #include <Processors/Formats/IInputFormat.h>
namespace arrow { class RecordBatchReader; }
namespace arrow::ipc { class RecordBatchFileReader; } namespace arrow::ipc { class RecordBatchFileReader; }
namespace DB namespace DB
@ -14,7 +15,7 @@ class ReadBuffer;
class ArrowBlockInputFormat : public IInputFormat class ArrowBlockInputFormat : public IInputFormat
{ {
public: public:
ArrowBlockInputFormat(ReadBuffer & in_, const Block & header_); ArrowBlockInputFormat(ReadBuffer & in_, const Block & header_, bool stream_);
void resetParser() override; void resetParser() override;
@ -24,12 +25,13 @@ protected:
Chunk generate() override; Chunk generate() override;
private: private:
void prepareReader(); bool stream;
std::shared_ptr<arrow::RecordBatchReader> stream_reader;
private:
std::shared_ptr<arrow::ipc::RecordBatchFileReader> file_reader; std::shared_ptr<arrow::ipc::RecordBatchFileReader> file_reader;
int record_batch_total = 0; int record_batch_total = 0;
int record_batch_current = 0; int record_batch_current = 0;
void prepareReader();
}; };
} }

View File

@ -15,8 +15,8 @@ namespace ErrorCodes
extern const int UNKNOWN_EXCEPTION; extern const int UNKNOWN_EXCEPTION;
} }
ArrowBlockOutputFormat::ArrowBlockOutputFormat(WriteBuffer & out_, const Block & header_, const FormatSettings & format_settings_) ArrowBlockOutputFormat::ArrowBlockOutputFormat(WriteBuffer & out_, const Block & header_, bool stream_, const FormatSettings & format_settings_)
: IOutputFormat(header_, out_), format_settings{format_settings_}, arrow_ostream{std::make_shared<ArrowBufferedOutputStream>(out_)} : IOutputFormat(header_, out_), stream{stream_}, format_settings{format_settings_}, arrow_ostream{std::make_shared<ArrowBufferedOutputStream>(out_)}
{ {
} }
@ -29,12 +29,7 @@ void ArrowBlockOutputFormat::consume(Chunk chunk)
CHColumnToArrowColumn::chChunkToArrowTable(arrow_table, header, chunk, columns_num, "Arrow"); CHColumnToArrowColumn::chChunkToArrowTable(arrow_table, header, chunk, columns_num, "Arrow");
if (!writer) if (!writer)
{ prepareWriter(arrow_table->schema());
// TODO: should we use arrow::ipc::IpcOptions::alignment?
auto status = arrow::ipc::RecordBatchFileWriter::Open(arrow_ostream.get(), arrow_table->schema(), &writer);
if (!status.ok())
throw Exception{"Error while opening a table: " + status.ToString(), ErrorCodes::UNKNOWN_EXCEPTION};
}
// TODO: calculate row_group_size depending on a number of rows and table size // TODO: calculate row_group_size depending on a number of rows and table size
auto status = writer->WriteTable(*arrow_table, format_settings.arrow.row_group_size); auto status = writer->WriteTable(*arrow_table, format_settings.arrow.row_group_size);
@ -53,6 +48,20 @@ void ArrowBlockOutputFormat::finalize()
} }
} }
/// Opens the Arrow IPC writer over `arrow_ostream` for the given schema and stores it in `writer`.
/// The `stream` flag selects RecordBatchStreamWriter; otherwise RecordBatchFileWriter is used.
/// Throws Exception with ErrorCodes::UNKNOWN_EXCEPTION if the writer cannot be opened.
void ArrowBlockOutputFormat::prepareWriter(const std::shared_ptr<arrow::Schema> & schema)
{
    // TODO: should we use arrow::ipc::IpcOptions::alignment?
    const arrow::Status open_result = stream
        ? arrow::ipc::RecordBatchStreamWriter::Open(arrow_ostream.get(), schema, &writer)
        : arrow::ipc::RecordBatchFileWriter::Open(arrow_ostream.get(), schema, &writer);

    if (open_result.ok())
        return;

    throw Exception{"Error while opening a table writer: " + open_result.ToString(), ErrorCodes::UNKNOWN_EXCEPTION};
}
void registerOutputFormatProcessorArrow(FormatFactory & factory) void registerOutputFormatProcessorArrow(FormatFactory & factory)
{ {
factory.registerOutputFormatProcessor( factory.registerOutputFormatProcessor(
@ -62,7 +71,17 @@ void registerOutputFormatProcessorArrow(FormatFactory & factory)
FormatFactory::WriteCallback, FormatFactory::WriteCallback,
const FormatSettings & format_settings) const FormatSettings & format_settings)
{ {
return std::make_shared<ArrowBlockOutputFormat>(buf, sample, format_settings); return std::make_shared<ArrowBlockOutputFormat>(buf, sample, false, format_settings);
});
factory.registerOutputFormatProcessor(
"ArrowStream",
[](WriteBuffer & buf,
const Block & sample,
FormatFactory::WriteCallback,
const FormatSettings & format_settings)
{
return std::make_shared<ArrowBlockOutputFormat>(buf, sample, true, format_settings);
}); });
} }

View File

@ -6,6 +6,7 @@
#include <Processors/Formats/IOutputFormat.h> #include <Processors/Formats/IOutputFormat.h>
#include "ArrowBufferedStreams.h" #include "ArrowBufferedStreams.h"
namespace arrow { class Schema; }
namespace arrow::ipc { class RecordBatchWriter; } namespace arrow::ipc { class RecordBatchWriter; }
namespace DB namespace DB
@ -14,7 +15,7 @@ namespace DB
class ArrowBlockOutputFormat : public IOutputFormat class ArrowBlockOutputFormat : public IOutputFormat
{ {
public: public:
ArrowBlockOutputFormat(WriteBuffer & out_, const Block & header_, const FormatSettings & format_settings_); ArrowBlockOutputFormat(WriteBuffer & out_, const Block & header_, bool stream_, const FormatSettings & format_settings_);
String getName() const override { return "ArrowBlockOutputFormat"; } String getName() const override { return "ArrowBlockOutputFormat"; }
void consume(Chunk) override; void consume(Chunk) override;
@ -23,9 +24,12 @@ public:
String getContentType() const override { return "application/octet-stream"; } String getContentType() const override { return "application/octet-stream"; }
private: private:
bool stream;
const FormatSettings format_settings; const FormatSettings format_settings;
std::shared_ptr<ArrowBufferedOutputStream> arrow_ostream; std::shared_ptr<ArrowBufferedOutputStream> arrow_ostream;
std::shared_ptr<arrow::ipc::RecordBatchWriter> writer; std::shared_ptr<arrow::ipc::RecordBatchWriter> writer;
void prepareWriter(const std::shared_ptr<arrow::Schema> & schema);
}; };
} }