add postgres-like cast operator

This commit is contained in:
Anton Popov 2021-05-04 06:43:17 +03:00
parent 08f10dced0
commit 2b79bf838f
9 changed files with 277 additions and 21 deletions

View File

@ -1061,7 +1061,10 @@ void ActionsMatcher::visit(const ASTFunction & node, const ASTPtr & ast, Data &
void ActionsMatcher::visit(const ASTLiteral & literal, const ASTPtr & /* ast */,
Data & data)
{
DataTypePtr type = applyVisitor(FieldToDataType(), literal.value);
DataTypePtr type = literal.data_type_hint
? literal.data_type_hint
: applyVisitor(FieldToDataType(), literal.value);
const auto value = convertFieldToType(literal.value, *type);
// FIXME why do we have a second pass with a clean sample block over the same

View File

@ -76,6 +76,8 @@ void ASTLiteral::appendColumnNameImpl(WriteBuffer & ostr) const
void ASTLiteral::formatImplWithoutAlias(const FormatSettings & settings, IAST::FormatState &, IAST::FormatStateStacked) const
{
settings.ostr << applyVisitor(FieldVisitorToString(), value);
if (data_type_hint)
settings.ostr << "::" << data_type_hint->getName();
}
}

View File

@ -4,6 +4,7 @@
#include <Parsers/ASTWithAlias.h>
#include <Parsers/TokenIterator.h>
#include <Common/FieldVisitors.h>
#include <DataTypes/IDataType.h>
#include <optional>
@ -33,6 +34,9 @@ public:
*/
String unique_column_name;
/// Hint for data type of literal, that can be set by operator "::".
DataTypePtr data_type_hint;
/** Get the text that identifies this element. */
String getID(char delim) const override { return "Literal" + (delim + applyVisitor(FieldVisitorDump(), value)); }

View File

@ -40,6 +40,9 @@
#include "ASTColumnsMatcher.h"
#include <Interpreters/StorageID.h>
#include <DataTypes/DataTypeFactory.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeTuple.h>
namespace DB
{
@ -794,7 +797,118 @@ bool ParserCodec::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
return true;
}
bool ParserCastExpression::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
ASTPtr createFunctionCast(const ASTPtr & expr_ast, const ASTPtr & type_ast)
{
/// Convert to canonical representation in functional form: CAST(expr, 'type')
auto type_literal = std::make_shared<ASTLiteral>(queryToString(type_ast));
auto expr_list_args = std::make_shared<ASTExpressionList>();
expr_list_args->children.push_back(expr_ast);
expr_list_args->children.push_back(std::move(type_literal));
auto func_node = std::make_shared<ASTFunction>();
func_node->name = "CAST";
func_node->arguments = std::move(expr_list_args);
func_node->children.push_back(func_node->arguments);
return func_node;
}
bool ParserCastOperator::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
/// Numbers, strings, tuples and arrays of them.
/// Types, that doesn't have representation in Field, e.g.: Date, DateTime,
/// can't be read from text as literals.
auto is_good_token = [](const auto & token)
{
return token == TokenType::Number
|| token == TokenType::StringLiteral
|| token == TokenType::Comma
|| token == TokenType::OpeningSquareBracket
|| token == TokenType::ClosingSquareBracket
|| token == TokenType::OpeningRoundBracket
|| token == TokenType::ClosingRoundBracket;
};
auto is_number_or_string = [](const auto & type) { return isNumber(type) || isStringOrFixedString(type); };
auto is_good_type = [&is_number_or_string](const auto & type)
{
if (is_number_or_string(type))
return true;
if (const auto * type_array = typeid_cast<const DataTypeArray *>(type.get()))
return is_number_or_string(type_array->getNestedType());
if (const auto * type_tuple = typeid_cast<const DataTypeTuple *>(type.get()))
{
const auto & elems = type_tuple->getElements();
return std::all_of(elems.begin(), elems.end(), [&](const auto & elem) { return is_number_or_string(elem); });
}
return false;
};
const char * data_begin = pos->begin;
bool is_number_literal = pos->type == TokenType::Number;
bool is_string_literal = pos->type == TokenType::StringLiteral;
size_t skipped_tokens = 0;
while (pos.isValid() && is_good_token(pos->type))
{
++pos;
++skipped_tokens;
}
if (!pos.isValid())
return false;
if ((is_string_literal || is_number_literal) && skipped_tokens != 1)
return false;
ASTPtr type_ast;
ParserToken parser_colon(TokenType::Colon);
const char * data_end = pos->begin;
if (parser_colon.ignore(pos, expected)
&& parser_colon.ignore(pos, expected)
&& ParserDataType().parse(pos, type_ast, expected))
{
auto type = DataTypeFactory::instance().get(type_ast);
if (!is_good_type(type))
return false;
/// Allow to parse numbers only from number literals,
/// because SerializationNumber uses unsafe version of int deserialization
/// and it won't throw an exception in case of error.
if (isNumber(type) && !is_number_literal)
return false;
ReadBufferFromMemory buf(data_begin, data_end - data_begin);
auto column = type->createColumn();
try
{
if (is_string_literal)
type->getDefaultSerialization()->deserializeTextQuoted(*column, buf, {});
else
type->getDefaultSerialization()->deserializeTextEscaped(*column, buf, {});
}
catch (const Exception &)
{
expected.add(pos, "literal with operator ::");
return false;
}
auto literal = std::make_shared<ASTLiteral>((*column)[0]);
literal->data_type_hint = type;
node = std::move(literal);
return true;
}
return false;
}
bool ParserCastAsExpression::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
/// Either CAST(expr AS type) or CAST(expr, 'type')
/// The latter will be parsed normally as a function later.
@ -809,20 +923,7 @@ bool ParserCastExpression::parseImpl(Pos & pos, ASTPtr & node, Expected & expect
&& ParserDataType().parse(pos, type_node, expected)
&& ParserToken(TokenType::ClosingRoundBracket).ignore(pos, expected))
{
/// Convert to canonical representation in functional form: CAST(expr, 'type')
auto type_literal = std::make_shared<ASTLiteral>(queryToString(type_node));
auto expr_list_args = std::make_shared<ASTExpressionList>();
expr_list_args->children.push_back(expr_node);
expr_list_args->children.push_back(std::move(type_literal));
auto func_node = std::make_shared<ASTFunction>();
func_node->name = "CAST";
func_node->arguments = std::move(expr_list_args);
func_node->children.push_back(func_node->arguments);
node = std::move(func_node);
node = createFunctionCast(expr_node, type_node);
return true;
}
@ -1951,12 +2052,13 @@ bool ParserMySQLGlobalVariable::parseImpl(Pos & pos, ASTPtr & node, Expected & e
bool ParserExpressionElement::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
return ParserSubquery().parse(pos, node, expected)
|| ParserCastOperator().parse(pos, node, expected)
|| ParserTupleOfLiterals().parse(pos, node, expected)
|| ParserParenthesisExpression().parse(pos, node, expected)
|| ParserArrayOfLiterals().parse(pos, node, expected)
|| ParserArray().parse(pos, node, expected)
|| ParserLiteral().parse(pos, node, expected)
|| ParserCastExpression().parse(pos, node, expected)
|| ParserCastAsExpression().parse(pos, node, expected)
|| ParserExtractExpression().parse(pos, node, expected)
|| ParserDateAddExpression().parse(pos, node, expected)
|| ParserDateDiffExpression().parse(pos, node, expected)

View File

@ -209,10 +209,22 @@ protected:
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override;
};
class ParserCastExpression : public IParserBase
/// Fast path of cast operator "::".
/// It tries to read literal as text.
/// If it fails, later operator will be transformed to function CAST.
/// Examples: "0.1::Decimal(38, 38)", "[1, 2]::Array(UInt8)"
class ParserCastOperator : public IParserBase
{
protected:
const char * getName() const override { return "CAST expression"; }
const char * getName() const override { return "CAST operator"; }
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override;
};
ASTPtr createFunctionCast(const ASTPtr & expr_ast, const ASTPtr & type_ast);
class ParserCastAsExpression : public IParserBase
{
protected:
const char * getName() const override { return "CAST AS expression"; }
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override;
};

View File

@ -3,7 +3,6 @@
#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTFunctionWithKeyValueArguments.h>
#include <Parsers/ExpressionElementParsers.h>
#include <Parsers/ParserCreateQuery.h>
#include <Parsers/parseIntervalKind.h>
#include <Common/StringUtils/StringUtils.h>
@ -558,11 +557,34 @@ bool ParserUnaryMinusExpression::parseImpl(Pos & pos, ASTPtr & node, Expected &
}
bool ParserCastExpression::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
ASTPtr expr_ast;
if (!elem_parser.parse(pos, expr_ast, expected))
return false;
ASTPtr type_ast;
ParserToken parser_colon(TokenType::Colon);
if (parser_colon.ignore(pos, expected)
&& parser_colon.ignore(pos, expected)
&& ParserDataType().parse(pos, type_ast, expected))
{
node = createFunctionCast(expr_ast, type_ast);
}
else
{
node = expr_ast;
}
return true;
}
bool ParserArrayElementExpression::parseImpl(Pos & pos, ASTPtr & node, Expected &expected)
{
return ParserLeftAssociativeBinaryOperatorList{
operators,
std::make_unique<ParserExpressionElement>(),
std::make_unique<ParserCastExpression>(),
std::make_unique<ParserExpressionWithOptionalAlias>(false)
}.parse(pos, node, expected);
}

View File

@ -6,6 +6,7 @@
#include <Parsers/CommonParsers.h>
#include <Parsers/ASTSelectWithUnionQuery.h>
#include <Parsers/ExpressionElementParsers.h>
#include <Common/IntervalKind.h>
namespace DB
@ -205,6 +206,20 @@ protected:
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override;
};
/// CAST operator "::". This parser is used if left argument
/// of operator cannot be read as simple literal from text of query.
/// Example: "[1, 1 + 1, 1 + 2]::Array(UInt8)"
class ParserCastExpression : public IParserBase
{
private:
ParserExpressionElement elem_parser;
protected:
const char * getName() const override { return "CAST expression"; }
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override;
};
class ParserArrayElementExpression : public IParserBase
{

View File

@ -0,0 +1,55 @@
0.10000000000000000000000000000000000000 Decimal(38, 38)
SELECT
\'0.10000000000000000000000000000000000000\'::Decimal(38, 38) AS c,
toTypeName(c)
[1,2,3] Array(UInt32)
SELECT
[1, 2, 3]::Array(UInt32) AS c,
toTypeName(c)
abc FixedString(3)
SELECT
\'abc\'::FixedString(3) AS c,
toTypeName(c)
123 String
SELECT
\'123\'::String AS c,
toTypeName(c)
1 Int8
SELECT
1::Int8 AS c,
toTypeName(c)
[1,2,3] Array(UInt32)
SELECT
CAST([1, 1 + 1, 1 + 2], \'Array(UInt32)\') AS c,
toTypeName(c)
2010-10-10 Date
SELECT
CAST(\'2010-10-10\', \'Date\') AS c,
toTypeName(c)
2010-10-10 00:00:00 DateTime
SELECT
CAST(\'2010-10-10\', \'DateTime\') AS c,
toTypeName(c)
['2010-10-10','2010-10-10'] Array(Date)
SELECT CAST([\'2010-10-10\', \'2010-10-10\'], \'Array(Date)\')
3 UInt32
SELECT
CAST(1 + 2, \'UInt32\') AS c,
toTypeName(c)
0.5 Float64
SELECT
CAST(\'0.1000\'::Decimal(4, 4) * 5, \'Float64\') AS c,
toTypeName(c)
0 UInt8
SELECT
CAST(number, \'UInt8\') AS c,
toTypeName(c)
FROM numbers(1)
1970-01-11 Date
SELECT
CAST((((0 + 1) + 2) + 3) + 4, \'Date\') AS c,
toTypeName(c)
0.6000 Decimal(4, 4)
SELECT
CAST((\'0.1000\'::Decimal(4, 4) + \'0.2000\'::Decimal(4, 4)) + \'0.3000\'::Decimal(4, 4), \'Decimal(4, 4)\') AS c,
toTypeName(c)

View File

@ -0,0 +1,41 @@
SELECT 0.1::Decimal(38, 38) AS c, toTypeName(c);
EXPLAIN SYNTAX SELECT 0.1::Decimal(38, 38) AS c, toTypeName(c);
SELECT [1, 2, 3]::Array(UInt32) AS c, toTypeName(c);
EXPLAIN SYNTAX SELECT [1, 2, 3]::Array(UInt32) AS c, toTypeName(c);
SELECT 'abc'::FixedString(3) AS c, toTypeName(c);
EXPLAIN SYNTAX SELECT 'abc'::FixedString(3) AS c, toTypeName(c);
SELECT 123::String AS c, toTypeName(c);
EXPLAIN SYNTAX SELECT 123::String AS c, toTypeName(c);
SELECT 1::Int8 AS c, toTypeName(c);
EXPLAIN SYNTAX SELECT 1::Int8 AS c, toTypeName(c);
SELECT [1, 1 + 1, 1 + 2]::Array(UInt32) AS c, toTypeName(c);
EXPLAIN SYNTAX SELECT [1, 1 + 1, 1 + 2]::Array(UInt32) AS c, toTypeName(c);
SELECT '2010-10-10'::Date AS c, toTypeName(c);
EXPLAIN SYNTAX SELECT '2010-10-10'::Date AS c, toTypeName(c);
SELECT '2010-10-10'::DateTime AS c, toTypeName(c);
EXPLAIN SYNTAX SELECT '2010-10-10'::DateTime AS c, toTypeName(c);
SELECT ['2010-10-10', '2010-10-10']::Array(Date) AS c, toTypeName(c);
EXPLAIN SYNTAX SELECT ['2010-10-10', '2010-10-10']::Array(Date);
SELECT (1 + 2)::UInt32 AS c, toTypeName(c);
EXPLAIN SYNTAX SELECT (1 + 2)::UInt32 AS c, toTypeName(c);
SELECT (0.1::Decimal(4, 4) * 5)::Float64 AS c, toTypeName(c);
EXPLAIN SYNTAX SELECT (0.1::Decimal(4, 4) * 5)::Float64 AS c, toTypeName(c);
SELECT number::UInt8 AS c, toTypeName(c) FROM numbers(1);
EXPLAIN SYNTAX SELECT number::UInt8 AS c, toTypeName(c) FROM numbers(1);
SELECT (0 + 1 + 2 + 3 + 4)::Date AS c, toTypeName(c);
EXPLAIN SYNTAX SELECT (0 + 1 + 2 + 3 + 4)::Date AS c, toTypeName(c);
SELECT (0.1::Decimal(4, 4) + 0.2::Decimal(4, 4) + 0.3::Decimal(4, 4))::Decimal(4, 4) AS c, toTypeName(c);
EXPLAIN SYNTAX SELECT (0.1::Decimal(4, 4) + 0.2::Decimal(4, 4) + 0.3::Decimal(4, 4))::Decimal(4, 4) AS c, toTypeName(c);