This commit is contained in:
kssenii 2020-11-22 21:05:54 +03:00
parent 2f6cb7f2f5
commit ea817862ba
2 changed files with 123 additions and 108 deletions

View File

@ -17,16 +17,25 @@
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <IO/ReadBufferFromString.h>
#include <Common/assert_cast.h>
#include <ext/range.h>
#include "PostgreSQLBlockInputStream.h"
#include <common/logger_useful.h>
#include <Core/Field.h>
#include "PostgreSQLBlockInputStream.h"
namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int UNKNOWN_TYPE;
}
PostgreSQLBlockInputStream::PostgreSQLBlockInputStream(
std::shared_ptr<pqxx::connection> connection_,
const std::string & query_str_,
@ -35,13 +44,18 @@ PostgreSQLBlockInputStream::PostgreSQLBlockInputStream(
: query_str(query_str_)
, max_block_size(max_block_size_)
, connection(connection_)
, work(std::make_unique<pqxx::work>(*connection))
, stream(std::make_unique<pqxx::stream_from>(*work, pqxx::from_query, std::string_view(query_str)))
{
description.init(sample_block);
}
void PostgreSQLBlockInputStream::readPrefix()
{
work = std::make_unique<pqxx::work>(*connection);
stream = std::make_unique<pqxx::stream_from>(*work, pqxx::from_query, std::string_view(query_str));
}
Block PostgreSQLBlockInputStream::readImpl()
{
/// Check if pqxx::stream_from is finished
@ -50,6 +64,7 @@ Block PostgreSQLBlockInputStream::readImpl()
MutableColumns columns = description.sample_block.cloneEmptyColumns();
size_t num_rows = 0;
std::string value;
while (true)
{
@ -57,32 +72,32 @@ Block PostgreSQLBlockInputStream::readImpl()
if (!row)
{
/// row is nullptr if pqxx::stream_from is finished
stream->complete();
work->commit();
break;
}
if (row->empty())
break;
std::string value;
for (const auto idx : ext::range(0, row->size()))
{
value = std::string((*row)[idx]);
LOG_DEBUG((&Poco::Logger::get("PostgreSQL")), "GOT {}", value);
const auto & sample = description.sample_block.getByPosition(idx);
if (!num_rows && description.types[idx].first == ValueType::vtArray)
getArrayInfo(idx, sample.type);
if (value.data())
/// if got NULL type, then pqxx::zview will return nullptr in c_str()
if ((*row)[idx].c_str())
{
value = std::string((*row)[idx]);
if (description.types[idx].second)
{
ColumnNullable & column_nullable = assert_cast<ColumnNullable &>(*columns[idx]);
const auto & data_type = assert_cast<const DataTypeNullable &>(*sample.type);
insertValue(column_nullable.getNestedColumn(), value, description.types[idx].first, data_type.getNestedType());
insertValue(column_nullable.getNestedColumn(), value, description.types[idx].first, data_type.getNestedType(), idx);
column_nullable.getNullMapData().emplace_back(0);
}
else
{
insertValue(*columns[idx], value, description.types[idx].first, sample.type);
insertValue(*columns[idx], value, description.types[idx].first, sample.type, idx);
}
}
else
@ -101,7 +116,7 @@ Block PostgreSQLBlockInputStream::readImpl()
void PostgreSQLBlockInputStream::insertValue(IColumn & column, const std::string & value,
const ExternalResultDescription::ValueType type, const DataTypePtr data_type)
const ExternalResultDescription::ValueType type, const DataTypePtr data_type, size_t idx)
{
switch (type)
{
@ -138,125 +153,114 @@ void PostgreSQLBlockInputStream::insertValue(IColumn & column, const std::string
case ValueType::vtString:
assert_cast<ColumnString &>(column).insertData(value.data(), value.size());
break;
case ValueType::vtDate:
//assert_cast<ColumnUInt16 &>(column).insertValue(UInt16(value.getDate().getDayNum()));
break;
case ValueType::vtDateTime:
//assert_cast<ColumnUInt32 &>(column).insertValue(UInt32(value.getDateTime()));
break;
case ValueType::vtUUID:
assert_cast<ColumnUInt128 &>(column).insert(parse<UUID>(value.data(), value.size()));
break;
case ValueType::vtDateTime64:[[fallthrough]];
case ValueType::vtDecimal32: [[fallthrough]];
case ValueType::vtDecimal64: [[fallthrough]];
case ValueType::vtDecimal128:[[fallthrough]];
case ValueType::vtDecimal256:
case ValueType::vtDate:
{
ReadBuffer buffer(const_cast<char *>(value.data()), value.size(), 0);
data_type->deserializeAsWholeText(column, buffer, FormatSettings{});
ReadBufferFromString istr(value);
data_type->deserializeAsWholeText(column, istr, FormatSettings{});
break;
}
case ValueType::vtDateTime:
{
ReadBufferFromString istr(value);
data_type->deserializeAsWholeText(column, istr, FormatSettings{});
break;
}
case ValueType::vtArray:
{
const auto * array_type = typeid_cast<const DataTypeArray *>(data_type.get());
auto nested = array_type->getNestedType();
pqxx::array_parser parser{value};
std::pair<pqxx::array_parser::juncture, std::string> parsed = parser.get_next();
size_t expected_dimensions = 1;
while (isArray(nested))
size_t dimension = 0, max_dimension = 0, expected_dimensions = array_info[idx].num_dimensions;
const auto parse_value = array_info[idx].pqxx_parser;
std::vector<std::vector<Field>> dimensions(expected_dimensions + 1);
while (parsed.first != pqxx::array_parser::juncture::done)
{
++expected_dimensions;
nested = typeid_cast<const DataTypeArray *>(nested.get())->getNestedType();
}
auto which = WhichDataType(nested);
if ((parsed.first == pqxx::array_parser::juncture::row_start) && (++dimension > expected_dimensions))
throw Exception("Got more dimensions than expected", ErrorCodes::BAD_ARGUMENTS);
auto get_array([&]() -> Field
{
pqxx::array_parser parser{value};
std::pair<pqxx::array_parser::juncture, std::string> parsed = parser.get_next();
else if (parsed.first == pqxx::array_parser::juncture::string_value)
dimensions[dimension].emplace_back(parse_value(parsed.second));
std::vector<std::vector<Field>> dimensions(expected_dimensions + 1);
size_t dimension = 0, max_dimension = 0;
bool new_row = false, null_value = false;
else if (parsed.first == pqxx::array_parser::juncture::null_value)
dimensions[dimension].emplace_back(array_info[idx].default_value);
while (parsed.first != pqxx::array_parser::juncture::done)
else if (parsed.first == pqxx::array_parser::juncture::row_end)
{
while (parsed.first == pqxx::array_parser::juncture::row_start)
{
++dimension;
if (dimension > expected_dimensions)
throw Exception("Got more dimensions than expected", ErrorCodes::BAD_ARGUMENTS);
max_dimension = std::max(max_dimension, dimension);
parsed = parser.get_next();
new_row = true;
}
if (--dimension == 0)
break;
/// TODO: dont forget to add test with null type
std::vector<Field> current_dimension_row;
while (parsed.first != pqxx::array_parser::juncture::row_end)
{
if (parsed.first == pqxx::array_parser::juncture::null_value)
null_value = true;
if (which.isUInt8() || which.isUInt16())
current_dimension_row.emplace_back(!null_value ? pqxx::from_string<uint16_t>(parsed.second) : UInt16());
else if (which.isUInt32())
current_dimension_row.emplace_back(!null_value ? pqxx::from_string<uint32_t>(parsed.second) : UInt32());
else if (which.isUInt64())
current_dimension_row.emplace_back(!null_value ? pqxx::from_string<uint64_t>(parsed.second) : UInt64());
else if (which.isInt8() || which.isInt16())
current_dimension_row.emplace_back(!null_value ? pqxx::from_string<int16_t>(parsed.second) : Int16());
else if (which.isInt32())
current_dimension_row.emplace_back(!null_value ? pqxx::from_string<int32_t>(parsed.second) : Int32());
else if (which.isInt64())
current_dimension_row.emplace_back(!null_value ? pqxx::from_string<int64_t>(parsed.second) : Int64());
//else if (which.isDate())
//else if (which.isDateTime())
else if (which.isFloat32())
current_dimension_row.emplace_back(!null_value ? pqxx::from_string<float>(parsed.second) : Float32());
else if (which.isFloat64())
current_dimension_row.emplace_back(!null_value ? pqxx::from_string<double>(parsed.second) : Float64());
else if (which.isString() || which.isFixedString())
current_dimension_row.emplace_back(!null_value ? parsed.second : String());
else throw Exception("Unexpected type " + nested->getName(), ErrorCodes::BAD_ARGUMENTS);
parsed = parser.get_next();
null_value = false;
}
while (parsed.first == pqxx::array_parser::juncture::row_end)
{
--dimension;
if (std::exchange(new_row, false))
{
if (dimension + 1 > max_dimension)
max_dimension = dimension + 1;
if (dimension)
dimensions[dimension].emplace_back(Array(current_dimension_row.begin(), current_dimension_row.end()));
else
return Array(current_dimension_row.begin(), current_dimension_row.end());
}
else if (dimension)
{
dimensions[dimension].emplace_back(Array(dimensions[dimension + 1].begin(), dimensions[dimension + 1].end()));
dimensions[dimension + 1].clear();
}
parsed = parser.get_next();
}
dimensions[dimension].emplace_back(Array(dimensions[dimension + 1].begin(), dimensions[dimension + 1].end()));
dimensions[dimension + 1].clear();
}
if (max_dimension < expected_dimensions)
throw Exception("Got less dimensions than expected", ErrorCodes::BAD_ARGUMENTS);
parsed = parser.get_next();
}
return Array(dimensions[1].begin(), dimensions[1].end());
});
if (max_dimension < expected_dimensions)
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"Got less dimensions than expected. ({} instead of {})", max_dimension, expected_dimensions);
assert_cast<ColumnArray &>(column).insert(get_array());
assert_cast<ColumnArray &>(column).insert(Array(dimensions[1].begin(), dimensions[1].end()));
break;
}
default:
throw Exception("Value of unsupported type:" + column.getName(), ErrorCodes::UNKNOWN_TYPE);
}
}
void PostgreSQLBlockInputStream::getArrayInfo(size_t column_idx, const DataTypePtr data_type)
{
const auto * array_type = typeid_cast<const DataTypeArray *>(data_type.get());
auto nested = array_type->getNestedType();
size_t count_dimensions = 1;
while (isArray(nested))
{
++count_dimensions;
nested = typeid_cast<const DataTypeArray *>(nested.get())->getNestedType();
}
Field default_value = nested->getDefault();
if (nested->isNullable())
nested = typeid_cast<const DataTypeNullable *>(nested.get())->getNestedType();
WhichDataType which(nested);
std::function<Field(std::string & fields)> parser;
if (which.isUInt8() || which.isUInt16())
parser = [&](std::string & field) -> Field { return pqxx::from_string<uint16_t>(field); };
else if (which.isUInt32())
parser = [&](std::string & field) -> Field { return pqxx::from_string<uint16_t>(field); };
else if (which.isUInt64())
parser = [&](std::string & field) -> Field { return pqxx::from_string<uint64_t>(field); };
else if (which.isInt8() || which.isInt16())
parser = [&](std::string & field) -> Field { return pqxx::from_string<int16_t>(field); };
else if (which.isInt32())
parser = [&](std::string & field) -> Field { return pqxx::from_string<int32_t>(field); };
else if (which.isInt64())
parser = [&](std::string & field) -> Field { return pqxx::from_string<uint16_t>(field); };
else if (which.isFloat32())
parser = [&](std::string & field) -> Field { return pqxx::from_string<float>(field); };
else if (which.isFloat64())
parser = [&](std::string & field) -> Field { return pqxx::from_string<double>(field); };
else if (which.isString() || which.isFixedString())
parser = [&](std::string & field) -> Field { return field; };
else if (which.isDate())
parser = [&](std::string & field) -> Field { return UInt16{LocalDate{field}.getDayNum()}; };
else if (which.isDateTime())
parser = [&](std::string & field) -> Field { return time_t{LocalDateTime{field}}; };
else throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unsupported type {} for pgArray", nested->getName());
array_info[column_idx] = {count_dimensions, default_value, parser};
}
}
#endif

View File

@ -24,13 +24,16 @@ public:
private:
using ValueType = ExternalResultDescription::ValueType;
void readPrefix() override;
Block readImpl() override;
void insertValue(IColumn & column, const std::string & value,
const ExternalResultDescription::ValueType type, const DataTypePtr data_type);
const ExternalResultDescription::ValueType type, const DataTypePtr data_type, size_t idx);
/// Inserts into `column` the element at row 0 of `sample_column`.
/// NOTE(review): presumably used for values that arrive as NULL, with row 0 of
/// the sample column holding the default — confirm at the call site in readImpl().
void insertDefaultValue(IColumn & column, const IColumn & sample_column)
{
column.insertFrom(sample_column, 0);
}
/// Fills `array_info[column_idx]` for an array column of type `data_type`:
/// number of dimensions, default element value and element parser.
void getArrayInfo(size_t column_idx, const DataTypePtr data_type);
const String query_str;
const UInt64 max_block_size;
@ -39,6 +42,14 @@ private:
std::shared_ptr<pqxx::connection> connection;
std::unique_ptr<pqxx::work> work;
std::unique_ptr<pqxx::stream_from> stream;
/// Per-column data used when converting a PostgreSQL array literal into a
/// ClickHouse Array; filled by getArrayInfo().
struct ArrayInfo
{
/// Number of nested Array levels in the column type.
size_t num_dimensions;
/// Value inserted in place of NULL array elements.
Field default_value;
/// Converts the textual form of one array element into a Field.
std::function<Field(std::string & field)> pqxx_parser;
};
/// ArrayInfo for array columns, keyed by column position in the sample block.
std::unordered_map<size_t, ArrayInfo> array_info;
};
}