Merge branch 'master' into brotli

This commit is contained in:
Mikhail 2019-02-12 22:52:23 +03:00 committed by GitHub
commit 4fd289c1f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
142 changed files with 1170 additions and 579 deletions

View File

@ -102,9 +102,6 @@ add_headers_and_sources(dbms src/Interpreters/ClusterProxy)
add_headers_and_sources(dbms src/Columns)
add_headers_and_sources(dbms src/Storages)
add_headers_and_sources(dbms src/Storages/Distributed)
if(USE_RDKAFKA)
add_headers_and_sources(dbms src/Storages/Kafka)
endif()
add_headers_and_sources(dbms src/Storages/MergeTree)
add_headers_and_sources(dbms src/Client)
add_headers_and_sources(dbms src/Formats)
@ -297,11 +294,7 @@ if (USE_CAPNP)
endif ()
if (USE_RDKAFKA)
target_link_libraries (dbms PRIVATE ${RDKAFKA_LIBRARY})
target_link_libraries (dbms PRIVATE ${CPPKAFKA_LIBRARY})
if (NOT USE_INTERNAL_RDKAFKA_LIBRARY)
target_include_directories (dbms SYSTEM BEFORE PRIVATE ${RDKAFKA_INCLUDE_DIR})
endif ()
target_link_libraries (dbms PRIVATE clickhouse_storage_kafka)
endif ()
target_link_libraries(dbms PRIVATE ${OPENSSL_CRYPTO_LIBRARY} Threads::Threads)

View File

@ -2,10 +2,10 @@
set(VERSION_REVISION 54415)
set(VERSION_MAJOR 19)
set(VERSION_MINOR 3)
set(VERSION_PATCH 0)
set(VERSION_GITHASH 1db4bd8c2a1a0cd610c8a6564e8194dca5265562)
set(VERSION_DESCRIBE v19.3.0-testing)
set(VERSION_STRING 19.3.0)
set(VERSION_PATCH 1)
set(VERSION_GITHASH 48280074c4a9151ca010fb0a777efd82634460bd)
set(VERSION_DESCRIBE v19.3.1-testing)
set(VERSION_STRING 19.3.1)
# end of autochange
set(VERSION_EXTRA "" CACHE STRING "")

View File

@ -1,13 +1,6 @@
add_library (clickhouse-odbc-bridge-lib ${LINK_MODE}
PingHandler.cpp
MainHandler.cpp
ColumnInfoHandler.cpp
IdentifierQuoteHandler.cpp
HandlerFactory.cpp
ODBCBridge.cpp
getIdentifierQuote.cpp
validateODBCConnectionString.cpp
)
add_headers_and_sources(clickhouse_odbc_bridge .)
add_library (clickhouse-odbc-bridge-lib ${LINK_MODE} ${clickhouse_odbc_bridge_sources})
target_link_libraries (clickhouse-odbc-bridge-lib PRIVATE daemon dbms clickhouse_common_io)
target_include_directories (clickhouse-odbc-bridge-lib PUBLIC ${ClickHouse_SOURCE_DIR}/libs/libdaemon/include)

View File

@ -4,7 +4,7 @@
#include <memory>
#include <DataStreams/copyData.h>
#include <DataTypes/DataTypeFactory.h>
#include <Dictionaries/ODBCBlockInputStream.h>
#include "ODBCBlockInputStream.h"
#include <Formats/BinaryRowInputStream.h>
#include <Formats/FormatFactory.h>
#include <IO/WriteBufferFromHTTPServerResponse.h>

View File

@ -6,7 +6,7 @@
#include <Poco/Data/RecordSet.h>
#include <Poco/Data/Session.h>
#include <Poco/Data/Statement.h>
#include "ExternalResultDescription.h"
#include <Dictionaries/ExternalResultDescription.h>
namespace DB

View File

@ -240,7 +240,7 @@ void PerformanceTest::runQueries(
statistics.startWatches();
try
{
executeQuery(connection, query, statistics, stop_conditions, interrupt_listener, context);
executeQuery(connection, query, statistics, stop_conditions, interrupt_listener, context, test_info.settings);
if (test_info.exec_type == ExecutionType::Loop)
{
@ -254,7 +254,7 @@ void PerformanceTest::runQueries(
break;
}
executeQuery(connection, query, statistics, stop_conditions, interrupt_listener, context);
executeQuery(connection, query, statistics, stop_conditions, interrupt_listener, context, test_info.settings);
}
}
}

View File

@ -44,14 +44,14 @@ void executeQuery(
TestStats & statistics,
TestStopConditions & stop_conditions,
InterruptListener & interrupt_listener,
Context & context)
Context & context,
const Settings & settings)
{
statistics.watch_per_query.restart();
statistics.last_query_was_cancelled = false;
statistics.last_query_rows_read = 0;
statistics.last_query_bytes_read = 0;
Settings settings;
RemoteBlockInputStream stream(connection, query, {}, context, &settings);
stream.setProgressCallback(

View File

@ -4,6 +4,7 @@
#include "TestStopConditions.h"
#include <Common/InterruptListener.h>
#include <Interpreters/Context.h>
#include <Interpreters/Settings.h>
#include <Client/Connection.h>
namespace DB
@ -14,5 +15,6 @@ void executeQuery(
TestStats & statistics,
TestStopConditions & stop_conditions,
InterruptListener & interrupt_listener,
Context & context);
Context & context,
const Settings & settings);
}

View File

@ -31,12 +31,13 @@ template <typename Data>
class AggregateFunctionArgMinMax final : public IAggregateFunctionDataHelper<Data, AggregateFunctionArgMinMax<Data>>
{
private:
DataTypePtr type_res;
DataTypePtr type_val;
const DataTypePtr & type_res;
const DataTypePtr & type_val;
public:
AggregateFunctionArgMinMax(const DataTypePtr & type_res, const DataTypePtr & type_val)
: type_res(type_res), type_val(type_val)
: IAggregateFunctionDataHelper<Data, AggregateFunctionArgMinMax<Data>>({type_res, type_val}, {}),
type_res(this->argument_types[0]), type_val(this->argument_types[1])
{
if (!type_val->isComparable())
throw Exception("Illegal type " + type_val->getName() + " of second argument of aggregate function " + getName()

View File

@ -28,7 +28,8 @@ private:
public:
AggregateFunctionArray(AggregateFunctionPtr nested_, const DataTypes & arguments)
: nested_func(nested_), num_arguments(arguments.size())
: IAggregateFunctionHelper<AggregateFunctionArray>(arguments, {})
, nested_func(nested_), num_arguments(arguments.size())
{
for (const auto & type : arguments)
if (!isArray(type))

View File

@ -27,9 +27,9 @@ AggregateFunctionPtr createAggregateFunctionAvg(const std::string & name, const
AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
if (isDecimal(data_type))
res.reset(createWithDecimalType<AggregateFuncAvg>(*data_type, *data_type));
res.reset(createWithDecimalType<AggregateFuncAvg>(*data_type, *data_type, argument_types));
else
res.reset(createWithNumericType<AggregateFuncAvg>(*data_type));
res.reset(createWithNumericType<AggregateFuncAvg>(*data_type, argument_types));
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,

View File

@ -49,13 +49,15 @@ public:
using ColVecResult = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128>, ColumnVector<Float64>>;
/// ctor for native types
AggregateFunctionAvg()
: scale(0)
AggregateFunctionAvg(const DataTypes & argument_types)
: IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>(argument_types, {})
, scale(0)
{}
/// ctor for Decimals
AggregateFunctionAvg(const IDataType & data_type)
: scale(getDecimalScale(data_type))
AggregateFunctionAvg(const IDataType & data_type, const DataTypes & argument_types)
: IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>(argument_types, {})
, scale(getDecimalScale(data_type))
{}
String getName() const override { return "avg"; }

View File

@ -21,7 +21,7 @@ AggregateFunctionPtr createAggregateFunctionBitwise(const std::string & name, co
+ " is illegal, because it cannot be used in bitwise operations",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
AggregateFunctionPtr res(createWithUnsignedIntegerType<AggregateFunctionBitwise, Data>(*argument_types[0]));
AggregateFunctionPtr res(createWithUnsignedIntegerType<AggregateFunctionBitwise, Data>(*argument_types[0], argument_types[0]));
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -43,6 +43,9 @@ template <typename T, typename Data>
class AggregateFunctionBitwise final : public IAggregateFunctionDataHelper<Data, AggregateFunctionBitwise<T, Data>>
{
public:
AggregateFunctionBitwise(const DataTypePtr & type)
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitwise<T, Data>>({type}, {}) {}
String getName() const override { return Data::name(); }
DataTypePtr getReturnType() const override

View File

@ -111,6 +111,7 @@ public:
}
AggregateFunctionBoundingRatio(const DataTypes & arguments)
: IAggregateFunctionDataHelper<AggregateFunctionBoundingRatioData, AggregateFunctionBoundingRatio>(arguments, {})
{
const auto x_arg = arguments.at(0).get();
const auto y_arg = arguments.at(0).get();

View File

@ -9,12 +9,12 @@ namespace DB
namespace
{
AggregateFunctionPtr createAggregateFunctionCount(const std::string & name, const DataTypes & /*argument_types*/, const Array & parameters)
AggregateFunctionPtr createAggregateFunctionCount(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertNoParameters(name, parameters);
/// 'count' accept any number of arguments and (in this case of non-Nullable types) simply ignore them.
return std::make_shared<AggregateFunctionCount>();
return std::make_shared<AggregateFunctionCount>(argument_types);
}
}

View File

@ -28,6 +28,8 @@ namespace ErrorCodes
class AggregateFunctionCount final : public IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCount>
{
public:
AggregateFunctionCount(const DataTypes & argument_types) : IAggregateFunctionDataHelper(argument_types, {}) {}
String getName() const override { return "count"; }
DataTypePtr getReturnType() const override
@ -74,7 +76,8 @@ public:
class AggregateFunctionCountNotNullUnary final : public IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullUnary>
{
public:
AggregateFunctionCountNotNullUnary(const DataTypePtr & argument)
AggregateFunctionCountNotNullUnary(const DataTypePtr & argument, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullUnary>({argument}, params)
{
if (!argument->isNullable())
throw Exception("Logical error: not Nullable data type passed to AggregateFunctionCountNotNullUnary", ErrorCodes::LOGICAL_ERROR);
@ -120,7 +123,8 @@ public:
class AggregateFunctionCountNotNullVariadic final : public IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullVariadic>
{
public:
AggregateFunctionCountNotNullVariadic(const DataTypes & arguments)
AggregateFunctionCountNotNullVariadic(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullVariadic>(arguments, params)
{
number_of_arguments = arguments.size();

View File

@ -26,12 +26,12 @@ AggregateFunctionPtr createAggregateFunctionEntropy(const std::string & name, co
if (num_args == 1)
{
/// Specialized implementation for single argument of numeric type.
if (auto res = createWithNumericBasedType<AggregateFunctionEntropy>(*argument_types[0], num_args))
if (auto res = createWithNumericBasedType<AggregateFunctionEntropy>(*argument_types[0], argument_types))
return AggregateFunctionPtr(res);
}
/// Generic implementation for other types or for multiple arguments.
return std::make_shared<AggregateFunctionEntropy<UInt128>>(num_args);
return std::make_shared<AggregateFunctionEntropy<UInt128>>(argument_types);
}
}

View File

@ -97,7 +97,9 @@ private:
size_t num_args;
public:
AggregateFunctionEntropy(size_t num_args) : num_args(num_args)
AggregateFunctionEntropy(const DataTypes & argument_types)
: IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>>(argument_types, {})
, num_args(argument_types.size())
{
}

View File

@ -97,7 +97,8 @@ private:
public:
AggregateFunctionForEach(AggregateFunctionPtr nested_, const DataTypes & arguments)
: nested_func(nested_), num_arguments(arguments.size())
: IAggregateFunctionDataHelper<AggregateFunctionForEachData, AggregateFunctionForEach>(arguments, {})
, nested_func(nested_), num_arguments(arguments.size())
{
nested_size_of_data = nested_func->sizeOfData();

View File

@ -48,12 +48,13 @@ class GroupArrayNumericImpl final
: public IAggregateFunctionDataHelper<GroupArrayNumericData<T>, GroupArrayNumericImpl<T, Tlimit_num_elems>>
{
static constexpr bool limit_num_elems = Tlimit_num_elems::value;
DataTypePtr data_type;
DataTypePtr & data_type;
UInt64 max_elems;
public:
explicit GroupArrayNumericImpl(const DataTypePtr & data_type_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: data_type(data_type_), max_elems(max_elems_) {}
: IAggregateFunctionDataHelper<GroupArrayNumericData<T>, GroupArrayNumericImpl<T, Tlimit_num_elems>>({data_type_}, {})
, data_type(this->argument_types[0]), max_elems(max_elems_) {}
String getName() const override { return "groupArray"; }
@ -248,12 +249,13 @@ class GroupArrayGeneralListImpl final
static Data & data(AggregateDataPtr place) { return *reinterpret_cast<Data*>(place); }
static const Data & data(ConstAggregateDataPtr place) { return *reinterpret_cast<const Data*>(place); }
DataTypePtr data_type;
DataTypePtr & data_type;
UInt64 max_elems;
public:
GroupArrayGeneralListImpl(const DataTypePtr & data_type, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: data_type(data_type), max_elems(max_elems_) {}
: IAggregateFunctionDataHelper<GroupArrayGeneralListData<Node>, GroupArrayGeneralListImpl<Node, limit_num_elems>>({data_type}, {})
, data_type(this->argument_types[0]), max_elems(max_elems_) {}
String getName() const override { return "groupArray"; }

View File

@ -13,6 +13,10 @@ namespace
AggregateFunctionPtr createAggregateFunctionGroupArrayInsertAt(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertBinary(name, argument_types);
if (argument_types.size() != 2)
throw Exception("Aggregate function groupArrayInsertAt requires two arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
return std::make_shared<AggregateFunctionGroupArrayInsertAtGeneric>(argument_types, parameters);
}

View File

@ -54,12 +54,14 @@ class AggregateFunctionGroupArrayInsertAtGeneric final
: public IAggregateFunctionDataHelper<AggregateFunctionGroupArrayInsertAtDataGeneric, AggregateFunctionGroupArrayInsertAtGeneric>
{
private:
DataTypePtr type;
DataTypePtr & type;
Field default_value;
UInt64 length_to_resize = 0; /// zero means - do not do resizing.
public:
AggregateFunctionGroupArrayInsertAtGeneric(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionGroupArrayInsertAtDataGeneric, AggregateFunctionGroupArrayInsertAtGeneric>(arguments, params)
, type(argument_types[0])
{
if (!params.empty())
{
@ -76,14 +78,9 @@ public:
}
}
if (arguments.size() != 2)
throw Exception("Aggregate function " + getName() + " requires two arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (!isUnsignedInteger(arguments[1]))
throw Exception("Second argument of aggregate function " + getName() + " must be integer.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
type = arguments.front();
if (default_value.isNull())
default_value = type->getDefault();
else

View File

@ -15,11 +15,15 @@ namespace
/// Substitute return type for Date and DateTime
class AggregateFunctionGroupUniqArrayDate : public AggregateFunctionGroupUniqArray<DataTypeDate::FieldType>
{
public:
AggregateFunctionGroupUniqArrayDate(const DataTypePtr & argument_type) : AggregateFunctionGroupUniqArray<DataTypeDate::FieldType>(argument_type) {}
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDate>()); }
};
class AggregateFunctionGroupUniqArrayDateTime : public AggregateFunctionGroupUniqArray<DataTypeDateTime::FieldType>
{
public:
AggregateFunctionGroupUniqArrayDateTime(const DataTypePtr & argument_type) : AggregateFunctionGroupUniqArray<DataTypeDateTime::FieldType>(argument_type) {}
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDateTime>()); }
};
@ -27,8 +31,8 @@ class AggregateFunctionGroupUniqArrayDateTime : public AggregateFunctionGroupUni
static IAggregateFunction * createWithExtraTypes(const DataTypePtr & argument_type)
{
WhichDataType which(argument_type);
if (which.idx == TypeIndex::Date) return new AggregateFunctionGroupUniqArrayDate;
else if (which.idx == TypeIndex::DateTime) return new AggregateFunctionGroupUniqArrayDateTime;
if (which.idx == TypeIndex::Date) return new AggregateFunctionGroupUniqArrayDate(argument_type);
else if (which.idx == TypeIndex::DateTime) return new AggregateFunctionGroupUniqArrayDateTime(argument_type);
else
{
/// Check that we can use plain version of AggreagteFunctionGroupUniqArrayGeneric
@ -44,7 +48,7 @@ AggregateFunctionPtr createAggregateFunctionGroupUniqArray(const std::string & n
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionGroupUniqArray>(*argument_types[0]));
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionGroupUniqArray>(*argument_types[0], argument_types[0]));
if (!res)
res = AggregateFunctionPtr(createWithExtraTypes(argument_types[0]));

View File

@ -44,6 +44,9 @@ private:
using State = AggregateFunctionGroupUniqArrayData<T>;
public:
AggregateFunctionGroupUniqArray(const DataTypePtr & argument_type)
: IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayData<T>, AggregateFunctionGroupUniqArray<T>>({argument_type}, {}) {}
String getName() const override { return "groupUniqArray"; }
DataTypePtr getReturnType() const override
@ -115,7 +118,7 @@ template <bool is_plain_column = false>
class AggreagteFunctionGroupUniqArrayGeneric
: public IAggregateFunctionDataHelper<AggreagteFunctionGroupUniqArrayGenericData, AggreagteFunctionGroupUniqArrayGeneric<is_plain_column>>
{
DataTypePtr input_data_type;
DataTypePtr & input_data_type;
using State = AggreagteFunctionGroupUniqArrayGenericData;
@ -125,7 +128,8 @@ class AggreagteFunctionGroupUniqArrayGeneric
public:
AggreagteFunctionGroupUniqArrayGeneric(const DataTypePtr & input_data_type)
: input_data_type(input_data_type) {}
: IAggregateFunctionDataHelper<AggreagteFunctionGroupUniqArrayGenericData, AggreagteFunctionGroupUniqArrayGeneric<is_plain_column>>({input_data_type}, {})
, input_data_type(this->argument_types[0]) {}
String getName() const override { return "groupUniqArray"; }

View File

@ -39,7 +39,7 @@ AggregateFunctionPtr createAggregateFunctionHistogram(const std::string & name,
throw Exception("Bin count should be positive", ErrorCodes::BAD_ARGUMENTS);
assertUnary(name, arguments);
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionHistogram>(*arguments[0], bins_count));
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionHistogram>(*arguments[0], bins_count, arguments, params));
if (!res)
throw Exception("Illegal type " + arguments[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -304,8 +304,9 @@ private:
const UInt32 max_bins;
public:
AggregateFunctionHistogram(UInt32 max_bins)
: max_bins(max_bins)
AggregateFunctionHistogram(UInt32 max_bins, const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionHistogramData, AggregateFunctionHistogram<T>>(arguments, params)
, max_bins(max_bins)
{
}

View File

@ -28,7 +28,8 @@ private:
public:
AggregateFunctionIf(AggregateFunctionPtr nested, const DataTypes & types)
: nested_func(nested), num_arguments(types.size())
: IAggregateFunctionHelper<AggregateFunctionIf>(types, nested->getParameters())
, nested_func(nested), num_arguments(types.size())
{
if (num_arguments == 0)
throw Exception("Aggregate function " + getName() + " require at least one argument", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

View File

@ -59,7 +59,7 @@ private:
public:
AggregateFunctionIntersectionsMax(AggregateFunctionIntersectionsKind kind_, const DataTypes & arguments)
: kind(kind_)
: IAggregateFunctionDataHelper<MaxIntersectionsData<PointType>, AggregateFunctionIntersectionsMax<PointType>>(arguments, {}), kind(kind_)
{
if (!isNumber(arguments[0]))
throw Exception{getName() + ": first argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};

View File

@ -47,7 +47,7 @@ public:
+ ", because it corresponds to different aggregate function: " + function->getFunctionName() + " instead of " + nested_function->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<AggregateFunctionMerge>(nested_function, *argument);
return std::make_shared<AggregateFunctionMerge>(nested_function, argument);
}
};

View File

@ -22,13 +22,14 @@ private:
AggregateFunctionPtr nested_func;
public:
AggregateFunctionMerge(const AggregateFunctionPtr & nested_, const IDataType & argument)
: nested_func(nested_)
AggregateFunctionMerge(const AggregateFunctionPtr & nested_, const DataTypePtr & argument)
: IAggregateFunctionHelper<AggregateFunctionMerge>({argument}, nested_->getParameters())
, nested_func(nested_)
{
const DataTypeAggregateFunction * data_type = typeid_cast<const DataTypeAggregateFunction *>(&argument);
const DataTypeAggregateFunction * data_type = typeid_cast<const DataTypeAggregateFunction *>(argument.get());
if (!data_type || data_type->getFunctionName() != nested_func->getName())
throw Exception("Illegal type " + argument.getName() + " of argument for aggregate function " + getName(),
throw Exception("Illegal type " + argument->getName() + " of argument for aggregate function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}

View File

@ -676,10 +676,12 @@ template <typename Data>
class AggregateFunctionsSingleValue final : public IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data>>
{
private:
DataTypePtr type;
DataTypePtr & type;
public:
AggregateFunctionsSingleValue(const DataTypePtr & type) : type(type)
AggregateFunctionsSingleValue(const DataTypePtr & type)
: IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data>>({type}, {})
, type(this->argument_types[0])
{
if (StringRef(Data::name()) == StringRef("min")
|| StringRef(Data::name()) == StringRef("max"))

View File

@ -15,6 +15,9 @@ namespace DB
class AggregateFunctionNothing final : public IAggregateFunctionHelper<AggregateFunctionNothing>
{
public:
AggregateFunctionNothing(const DataTypes & arguments, const Array & params)
: IAggregateFunctionHelper<AggregateFunctionNothing>(arguments, params) {}
String getName() const override
{
return "nothing";

View File

@ -30,7 +30,7 @@ public:
}
AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override
{
bool has_nullable_types = false;
bool has_null_types = false;
@ -55,29 +55,29 @@ public:
if (nested_function && nested_function->getName() == "count")
{
if (arguments.size() == 1)
return std::make_shared<AggregateFunctionCountNotNullUnary>(arguments[0]);
return std::make_shared<AggregateFunctionCountNotNullUnary>(arguments[0], params);
else
return std::make_shared<AggregateFunctionCountNotNullVariadic>(arguments);
return std::make_shared<AggregateFunctionCountNotNullVariadic>(arguments, params);
}
if (has_null_types)
return std::make_shared<AggregateFunctionNothing>();
return std::make_shared<AggregateFunctionNothing>(arguments, params);
bool return_type_is_nullable = nested_function->getReturnType()->canBeInsideNullable();
if (arguments.size() == 1)
{
if (return_type_is_nullable)
return std::make_shared<AggregateFunctionNullUnary<true>>(nested_function);
return std::make_shared<AggregateFunctionNullUnary<true>>(nested_function, arguments, params);
else
return std::make_shared<AggregateFunctionNullUnary<false>>(nested_function);
return std::make_shared<AggregateFunctionNullUnary<false>>(nested_function, arguments, params);
}
else
{
if (return_type_is_nullable)
return std::make_shared<AggregateFunctionNullVariadic<true>>(nested_function, arguments);
return std::make_shared<AggregateFunctionNullVariadic<true>>(nested_function, arguments, params);
else
return std::make_shared<AggregateFunctionNullVariadic<false>>(nested_function, arguments);
return std::make_shared<AggregateFunctionNullVariadic<false>>(nested_function, arguments, params);
}
}
};

View File

@ -68,8 +68,8 @@ protected:
}
public:
AggregateFunctionNullBase(AggregateFunctionPtr nested_function_)
: nested_function{nested_function_}
AggregateFunctionNullBase(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: IAggregateFunctionHelper<Derived>(arguments, params), nested_function{nested_function_}
{
if (result_is_nullable)
prefix_size = nested_function->alignOfData();
@ -187,8 +187,8 @@ template <bool result_is_nullable>
class AggregateFunctionNullUnary final : public AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullUnary<result_is_nullable>>
{
public:
AggregateFunctionNullUnary(AggregateFunctionPtr nested_function_)
: AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullUnary<result_is_nullable>>(std::move(nested_function_))
AggregateFunctionNullUnary(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullUnary<result_is_nullable>>(std::move(nested_function_), arguments, params)
{
}
@ -209,8 +209,8 @@ template <bool result_is_nullable>
class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullVariadic<result_is_nullable>>
{
public:
AggregateFunctionNullVariadic(AggregateFunctionPtr nested_function_, const DataTypes & arguments)
: AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullVariadic<result_is_nullable>>(std::move(nested_function_)),
AggregateFunctionNullVariadic(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullVariadic<result_is_nullable>>(std::move(nested_function_), arguments, params),
number_of_arguments(arguments.size())
{
if (number_of_arguments == 1)

View File

@ -73,11 +73,12 @@ private:
/// Used when there are single level to get.
Float64 level = 0.5;
DataTypePtr argument_type;
DataTypePtr & argument_type;
public:
AggregateFunctionQuantile(const DataTypePtr & argument_type, const Array & params)
: levels(params, returns_many), level(levels.levels[0]), argument_type(argument_type)
: IAggregateFunctionDataHelper<Data, AggregateFunctionQuantile<Value, Data, Name, has_second_arg, FloatReturnType, returns_many>>({argument_type}, params)
, levels(params, returns_many), level(levels.levels[0]), argument_type(this->argument_types[0])
{
if (!returns_many && levels.size() > 1)
throw Exception("Aggregate function " + getName() + " require one parameter or less", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

View File

@ -76,6 +76,7 @@ public:
}
AggregateFunctionRetention(const DataTypes & arguments)
: IAggregateFunctionDataHelper<AggregateFunctionRetentionData, AggregateFunctionRetention>(arguments, {})
{
for (const auto i : ext::range(0, arguments.size()))
{

View File

@ -19,7 +19,7 @@ AggregateFunctionPtr createAggregateFunctionSequenceCount(const std::string & na
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
String pattern = params.front().safeGet<std::string>();
return std::make_shared<AggregateFunctionSequenceCount>(argument_types, pattern);
return std::make_shared<AggregateFunctionSequenceCount>(argument_types, params, pattern);
}
AggregateFunctionPtr createAggregateFunctionSequenceMatch(const std::string & name, const DataTypes & argument_types, const Array & params)
@ -29,7 +29,7 @@ AggregateFunctionPtr createAggregateFunctionSequenceMatch(const std::string & na
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
String pattern = params.front().safeGet<std::string>();
return std::make_shared<AggregateFunctionSequenceMatch>(argument_types, pattern);
return std::make_shared<AggregateFunctionSequenceMatch>(argument_types, params, pattern);
}
}

View File

@ -139,8 +139,9 @@ template <typename Derived>
class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<AggregateFunctionSequenceMatchData, Derived>
{
public:
AggregateFunctionSequenceBase(const DataTypes & arguments, const String & pattern)
: pattern(pattern)
AggregateFunctionSequenceBase(const DataTypes & arguments, const Array & params, const String & pattern)
: IAggregateFunctionDataHelper<AggregateFunctionSequenceMatchData, Derived>(arguments, params)
, pattern(pattern)
{
arg_count = arguments.size();
@ -578,6 +579,9 @@ private:
class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>
{
public:
AggregateFunctionSequenceMatch(const DataTypes & arguments, const Array & params, const String & pattern)
: AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>(arguments, params, pattern) {}
using AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>::AggregateFunctionSequenceBase;
String getName() const override { return "sequenceMatch"; }
@ -603,6 +607,9 @@ public:
class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>
{
public:
AggregateFunctionSequenceCount(const DataTypes & arguments, const Array & params, const String & pattern)
: AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>(arguments, params, pattern) {}
using AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>::AggregateFunctionSequenceBase;
String getName() const override { return "sequenceCount"; }

View File

@ -24,7 +24,8 @@ private:
public:
AggregateFunctionState(AggregateFunctionPtr nested, const DataTypes & arguments, const Array & params)
: nested_func(nested), arguments(arguments), params(params) {}
: IAggregateFunctionHelper<AggregateFunctionState>(arguments, params)
, nested_func(nested), arguments(arguments), params(params) {}
String getName() const override
{

View File

@ -21,7 +21,7 @@ AggregateFunctionPtr createAggregateFunctionStatisticsUnary(const std::string &
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
AggregateFunctionPtr res(createWithNumericType<FunctionTemplate>(*argument_types[0]));
AggregateFunctionPtr res(createWithNumericType<FunctionTemplate>(*argument_types[0], argument_types[0]));
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -35,7 +35,7 @@ AggregateFunctionPtr createAggregateFunctionStatisticsBinary(const std::string &
assertNoParameters(name, parameters);
assertBinary(name, argument_types);
AggregateFunctionPtr res(createWithTwoNumericTypes<FunctionTemplate>(*argument_types[0], *argument_types[1]));
AggregateFunctionPtr res(createWithTwoNumericTypes<FunctionTemplate>(*argument_types[0], *argument_types[1], argument_types));
if (!res)
throw Exception("Illegal types " + argument_types[0]->getName() + " and " + argument_types[1]->getName()
+ " of arguments for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -111,6 +111,9 @@ class AggregateFunctionVariance final
: public IAggregateFunctionDataHelper<AggregateFunctionVarianceData<T, Op>, AggregateFunctionVariance<T, Op>>
{
public:
AggregateFunctionVariance(const DataTypePtr & arg)
: IAggregateFunctionDataHelper<AggregateFunctionVarianceData<T, Op>, AggregateFunctionVariance<T, Op>>({arg}, {}) {}
String getName() const override { return Op::name; }
DataTypePtr getReturnType() const override
@ -361,6 +364,10 @@ class AggregateFunctionCovariance final
AggregateFunctionCovariance<T, U, Op, compute_marginal_moments>>
{
public:
AggregateFunctionCovariance(const DataTypes & args) : IAggregateFunctionDataHelper<
CovarianceData<T, U, Op, compute_marginal_moments>,
AggregateFunctionCovariance<T, U, Op, compute_marginal_moments>>(args, {}) {}
String getName() const override { return Op::name; }
DataTypePtr getReturnType() const override

View File

@ -288,12 +288,14 @@ public:
using ResultType = typename StatFunc::ResultType;
using ColVecResult = ColumnVector<ResultType>;
AggregateFunctionVarianceSimple()
: src_scale(0)
AggregateFunctionVarianceSimple(const DataTypes & argument_types)
: IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>(argument_types, {})
, src_scale(0)
{}
AggregateFunctionVarianceSimple(const IDataType & data_type)
: src_scale(getDecimalScale(data_type))
AggregateFunctionVarianceSimple(const IDataType & data_type, const DataTypes & argument_types)
: IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>(argument_types, {})
, src_scale(getDecimalScale(data_type))
{}
String getName() const override

View File

@ -50,9 +50,9 @@ AggregateFunctionPtr createAggregateFunctionSum(const std::string & name, const
AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
if (isDecimal(data_type))
res.reset(createWithDecimalType<Function>(*data_type, *data_type));
res.reset(createWithDecimalType<Function>(*data_type, *data_type, argument_types));
else
res.reset(createWithNumericType<Function>(*data_type));
res.reset(createWithNumericType<Function>(*data_type, argument_types));
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,

View File

@ -102,12 +102,14 @@ public:
String getName() const override { return "sum"; }
AggregateFunctionSum()
: scale(0)
AggregateFunctionSum(const DataTypes & argument_types)
: IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data>>(argument_types, {})
, scale(0)
{}
AggregateFunctionSum(const IDataType & data_type)
: scale(getDecimalScale(data_type))
AggregateFunctionSum(const IDataType & data_type, const DataTypes & argument_types)
: IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data>>(argument_types, {})
, scale(getDecimalScale(data_type))
{}
DataTypePtr getReturnType() const override

View File

@ -80,9 +80,9 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
auto [keys_type, values_types] = parseArguments(name, arguments);
AggregateFunctionPtr res(createWithNumericBasedType<Function>(*keys_type, keys_type, values_types));
AggregateFunctionPtr res(createWithNumericBasedType<Function>(*keys_type, keys_type, values_types, arguments));
if (!res)
res.reset(createWithDecimalType<Function>(*keys_type, keys_type, values_types));
res.reset(createWithDecimalType<Function>(*keys_type, keys_type, values_types, arguments));
if (!res)
throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -103,9 +103,9 @@ AggregateFunctionPtr createAggregateFunctionSumMapFiltered(const std::string & n
auto [keys_type, values_types] = parseArguments(name, arguments);
AggregateFunctionPtr res(createWithNumericBasedType<Function>(*keys_type, keys_type, values_types, keys_to_keep));
AggregateFunctionPtr res(createWithNumericBasedType<Function>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
if (!res)
res.reset(createWithDecimalType<Function>(*keys_type, keys_type, values_types, keys_to_keep));
res.reset(createWithDecimalType<Function>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
if (!res)
throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -61,8 +61,11 @@ private:
DataTypes values_types;
public:
AggregateFunctionSumMapBase(const DataTypePtr & keys_type, const DataTypes & values_types)
: keys_type(keys_type), values_types(values_types) {}
AggregateFunctionSumMapBase(
const DataTypePtr & keys_type, const DataTypes & values_types,
const DataTypes & argument_types, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionSumMapData<NearestFieldType<T>>, Derived>(argument_types, params)
, keys_type(keys_type), values_types(values_types) {}
String getName() const override { return "sumMap"; }
@ -271,8 +274,8 @@ private:
using Base = AggregateFunctionSumMapBase<T, Self, OverflowPolicy>;
public:
AggregateFunctionSumMap(const DataTypePtr & keys_type, DataTypes & values_types)
: Base{keys_type, values_types}
AggregateFunctionSumMap(const DataTypePtr & keys_type, DataTypes & values_types, const DataTypes & argument_types)
: Base{keys_type, values_types, argument_types, {}}
{}
String getName() const override { return "sumMap"; }
@ -291,8 +294,10 @@ private:
std::unordered_set<T> keys_to_keep;
public:
AggregateFunctionSumMapFiltered(const DataTypePtr & keys_type, const DataTypes & values_types, const Array & keys_to_keep_)
: Base{keys_type, values_types}
AggregateFunctionSumMapFiltered(
const DataTypePtr & keys_type, const DataTypes & values_types, const Array & keys_to_keep_,
const DataTypes & argument_types, const Array & params)
: Base{keys_type, values_types, argument_types, params}
{
keys_to_keep.reserve(keys_to_keep_.size());
for (const Field & f : keys_to_keep_)

View File

@ -39,19 +39,19 @@ class AggregateFunctionTopKDateTime : public AggregateFunctionTopK<DataTypeDateT
template <bool is_weighted>
static IAggregateFunction * createWithExtraTypes(const DataTypePtr & argument_type, UInt64 threshold)
static IAggregateFunction * createWithExtraTypes(const DataTypePtr & argument_type, UInt64 threshold, const Array & params)
{
WhichDataType which(argument_type);
if (which.idx == TypeIndex::Date)
return new AggregateFunctionTopKDate<is_weighted>(threshold);
return new AggregateFunctionTopKDate<is_weighted>(threshold, {argument_type}, params);
if (which.idx == TypeIndex::DateTime)
return new AggregateFunctionTopKDateTime<is_weighted>(threshold);
return new AggregateFunctionTopKDateTime<is_weighted>(threshold, {argument_type}, params);
/// Check that we can use plain version of AggregateFunctionTopKGeneric
if (argument_type->isValueUnambiguouslyRepresentedInContiguousMemoryRegion())
return new AggregateFunctionTopKGeneric<true, is_weighted>(threshold, argument_type);
return new AggregateFunctionTopKGeneric<true, is_weighted>(threshold, argument_type, params);
else
return new AggregateFunctionTopKGeneric<false, is_weighted>(threshold, argument_type);
return new AggregateFunctionTopKGeneric<false, is_weighted>(threshold, argument_type, params);
}
@ -90,10 +90,10 @@ AggregateFunctionPtr createAggregateFunctionTopK(const std::string & name, const
threshold = k;
}
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionTopK, is_weighted>(*argument_types[0], threshold));
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionTopK, is_weighted>(*argument_types[0], threshold, argument_types, params));
if (!res)
res = AggregateFunctionPtr(createWithExtraTypes<is_weighted>(argument_types[0], threshold));
res = AggregateFunctionPtr(createWithExtraTypes<is_weighted>(argument_types[0], threshold, params));
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() +

View File

@ -48,8 +48,9 @@ protected:
UInt64 reserved;
public:
AggregateFunctionTopK(UInt64 threshold)
: threshold(threshold), reserved(TOP_K_LOAD_FACTOR * threshold) {}
AggregateFunctionTopK(UInt64 threshold, const DataTypes & argument_types, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>(argument_types, params)
, threshold(threshold), reserved(TOP_K_LOAD_FACTOR * threshold) {}
String getName() const override { return is_weighted ? "topKWeighted" : "topK"; }
@ -136,13 +137,15 @@ private:
UInt64 threshold;
UInt64 reserved;
DataTypePtr input_data_type;
DataTypePtr & input_data_type;
static void deserializeAndInsert(StringRef str, IColumn & data_to);
public:
AggregateFunctionTopKGeneric(UInt64 threshold, const DataTypePtr & input_data_type)
: threshold(threshold), reserved(TOP_K_LOAD_FACTOR * threshold), input_data_type(input_data_type) {}
AggregateFunctionTopKGeneric(
UInt64 threshold, const DataTypePtr & input_data_type, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column, is_weighted>>({input_data_type}, params)
, threshold(threshold), reserved(TOP_K_LOAD_FACTOR * threshold), input_data_type(this->argument_types[0]) {}
String getName() const override { return is_weighted ? "topKWeighted" : "topK"; }

View File

@ -43,19 +43,19 @@ AggregateFunctionPtr createAggregateFunctionUniq(const std::string & name, const
{
const IDataType & argument_type = *argument_types[0];
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionUniq, Data>(*argument_types[0]));
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionUniq, Data>(*argument_types[0], argument_types));
WhichDataType which(argument_type);
if (res)
return res;
else if (which.isDate())
return std::make_shared<AggregateFunctionUniq<DataTypeDate::FieldType, Data>>();
return std::make_shared<AggregateFunctionUniq<DataTypeDate::FieldType, Data>>(argument_types);
else if (which.isDateTime())
return std::make_shared<AggregateFunctionUniq<DataTypeDateTime::FieldType, Data>>();
return std::make_shared<AggregateFunctionUniq<DataTypeDateTime::FieldType, Data>>(argument_types);
else if (which.isStringOrFixedString())
return std::make_shared<AggregateFunctionUniq<String, Data>>();
return std::make_shared<AggregateFunctionUniq<String, Data>>(argument_types);
else if (which.isUUID())
return std::make_shared<AggregateFunctionUniq<DataTypeUUID::FieldType, Data>>();
return std::make_shared<AggregateFunctionUniq<DataTypeUUID::FieldType, Data>>(argument_types);
else if (which.isTuple())
{
if (use_exact_hash_function)
@ -89,19 +89,19 @@ AggregateFunctionPtr createAggregateFunctionUniq(const std::string & name, const
{
const IDataType & argument_type = *argument_types[0];
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionUniq, Data>(*argument_types[0]));
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionUniq, Data>(*argument_types[0], argument_types));
WhichDataType which(argument_type);
if (res)
return res;
else if (which.isDate())
return std::make_shared<AggregateFunctionUniq<DataTypeDate::FieldType, Data<DataTypeDate::FieldType>>>();
return std::make_shared<AggregateFunctionUniq<DataTypeDate::FieldType, Data<DataTypeDate::FieldType>>>(argument_types);
else if (which.isDateTime())
return std::make_shared<AggregateFunctionUniq<DataTypeDateTime::FieldType, Data<DataTypeDateTime::FieldType>>>();
return std::make_shared<AggregateFunctionUniq<DataTypeDateTime::FieldType, Data<DataTypeDateTime::FieldType>>>(argument_types);
else if (which.isStringOrFixedString())
return std::make_shared<AggregateFunctionUniq<String, Data<String>>>();
return std::make_shared<AggregateFunctionUniq<String, Data<String>>>(argument_types);
else if (which.isUUID())
return std::make_shared<AggregateFunctionUniq<DataTypeUUID::FieldType, Data<DataTypeUUID::FieldType>>>();
return std::make_shared<AggregateFunctionUniq<DataTypeUUID::FieldType, Data<DataTypeUUID::FieldType>>>(argument_types);
else if (which.isTuple())
{
if (use_exact_hash_function)

View File

@ -209,6 +209,9 @@ template <typename T, typename Data>
class AggregateFunctionUniq final : public IAggregateFunctionDataHelper<Data, AggregateFunctionUniq<T, Data>>
{
public:
AggregateFunctionUniq(const DataTypes & argument_types)
: IAggregateFunctionDataHelper<Data, AggregateFunctionUniq<T, Data>>(argument_types, {}) {}
String getName() const override { return Data::getName(); }
DataTypePtr getReturnType() const override
@ -257,6 +260,7 @@ private:
public:
AggregateFunctionUniqVariadic(const DataTypes & arguments)
: IAggregateFunctionDataHelper<Data, AggregateFunctionUniqVariadic<Data, is_exact, argument_is_tuple>>(arguments, {})
{
if (argument_is_tuple)
num_args = typeid_cast<const DataTypeTuple &>(*arguments[0]).getElements().size();

View File

@ -28,7 +28,7 @@ namespace
};
template <UInt8 K>
AggregateFunctionPtr createAggregateFunctionWithK(const DataTypes & argument_types)
AggregateFunctionPtr createAggregateFunctionWithK(const DataTypes & argument_types, const Array & params)
{
/// We use exact hash function if the arguments are not contiguous in memory, because only exact hash function has support for this case.
bool use_exact_hash_function = !isAllArgumentsContiguousInMemory(argument_types);
@ -37,33 +37,33 @@ namespace
{
const IDataType & argument_type = *argument_types[0];
AggregateFunctionPtr res(createWithNumericType<WithK<K>::template AggregateFunction>(*argument_types[0]));
AggregateFunctionPtr res(createWithNumericType<WithK<K>::template AggregateFunction>(*argument_types[0], argument_types, params));
WhichDataType which(argument_type);
if (res)
return res;
else if (which.isDate())
return std::make_shared<typename WithK<K>::template AggregateFunction<DataTypeDate::FieldType>>();
return std::make_shared<typename WithK<K>::template AggregateFunction<DataTypeDate::FieldType>>(argument_types, params);
else if (which.isDateTime())
return std::make_shared<typename WithK<K>::template AggregateFunction<DataTypeDateTime::FieldType>>();
return std::make_shared<typename WithK<K>::template AggregateFunction<DataTypeDateTime::FieldType>>(argument_types, params);
else if (which.isStringOrFixedString())
return std::make_shared<typename WithK<K>::template AggregateFunction<String>>();
return std::make_shared<typename WithK<K>::template AggregateFunction<String>>(argument_types, params);
else if (which.isUUID())
return std::make_shared<typename WithK<K>::template AggregateFunction<DataTypeUUID::FieldType>>();
return std::make_shared<typename WithK<K>::template AggregateFunction<DataTypeUUID::FieldType>>(argument_types, params);
else if (which.isTuple())
{
if (use_exact_hash_function)
return std::make_shared<typename WithK<K>::template AggregateFunctionVariadic<true, true>>(argument_types);
return std::make_shared<typename WithK<K>::template AggregateFunctionVariadic<true, true>>(argument_types, params);
else
return std::make_shared<typename WithK<K>::template AggregateFunctionVariadic<false, true>>(argument_types);
return std::make_shared<typename WithK<K>::template AggregateFunctionVariadic<false, true>>(argument_types, params);
}
}
/// "Variadic" method also works as a fallback generic case for a single argument.
if (use_exact_hash_function)
return std::make_shared<typename WithK<K>::template AggregateFunctionVariadic<true, false>>(argument_types);
return std::make_shared<typename WithK<K>::template AggregateFunctionVariadic<true, false>>(argument_types, params);
else
return std::make_shared<typename WithK<K>::template AggregateFunctionVariadic<false, false>>(argument_types);
return std::make_shared<typename WithK<K>::template AggregateFunctionVariadic<false, false>>(argument_types, params);
}
AggregateFunctionPtr createAggregateFunctionUniqCombined(
@ -95,23 +95,23 @@ namespace
switch (precision)
{
case 12:
return createAggregateFunctionWithK<12>(argument_types);
return createAggregateFunctionWithK<12>(argument_types, params);
case 13:
return createAggregateFunctionWithK<13>(argument_types);
return createAggregateFunctionWithK<13>(argument_types, params);
case 14:
return createAggregateFunctionWithK<14>(argument_types);
return createAggregateFunctionWithK<14>(argument_types, params);
case 15:
return createAggregateFunctionWithK<15>(argument_types);
return createAggregateFunctionWithK<15>(argument_types, params);
case 16:
return createAggregateFunctionWithK<16>(argument_types);
return createAggregateFunctionWithK<16>(argument_types, params);
case 17:
return createAggregateFunctionWithK<17>(argument_types);
return createAggregateFunctionWithK<17>(argument_types, params);
case 18:
return createAggregateFunctionWithK<18>(argument_types);
return createAggregateFunctionWithK<18>(argument_types, params);
case 19:
return createAggregateFunctionWithK<19>(argument_types);
return createAggregateFunctionWithK<19>(argument_types, params);
case 20:
return createAggregateFunctionWithK<20>(argument_types);
return createAggregateFunctionWithK<20>(argument_types, params);
}
__builtin_unreachable();

View File

@ -114,6 +114,9 @@ class AggregateFunctionUniqCombined final
: public IAggregateFunctionDataHelper<AggregateFunctionUniqCombinedData<T, K>, AggregateFunctionUniqCombined<T, K>>
{
public:
AggregateFunctionUniqCombined(const DataTypes & argument_types, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionUniqCombinedData<T, K>, AggregateFunctionUniqCombined<T, K>>(argument_types, params) {}
String getName() const override
{
return "uniqCombined";
@ -176,7 +179,9 @@ private:
size_t num_args = 0;
public:
explicit AggregateFunctionUniqCombinedVariadic(const DataTypes & arguments)
explicit AggregateFunctionUniqCombinedVariadic(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionUniqCombinedData<UInt64, K>,
AggregateFunctionUniqCombinedVariadic<is_exact, argument_is_tuple, K>>(arguments, params)
{
if (argument_is_tuple)
num_args = typeid_cast<const DataTypeTuple &>(*arguments[0]).getElements().size();

View File

@ -52,33 +52,33 @@ AggregateFunctionPtr createAggregateFunctionUniqUpTo(const std::string & name, c
{
const IDataType & argument_type = *argument_types[0];
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionUniqUpTo>(*argument_types[0], threshold));
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionUniqUpTo>(*argument_types[0], threshold, argument_types, params));
WhichDataType which(argument_type);
if (res)
return res;
else if (which.isDate())
return std::make_shared<AggregateFunctionUniqUpTo<DataTypeDate::FieldType>>(threshold);
return std::make_shared<AggregateFunctionUniqUpTo<DataTypeDate::FieldType>>(threshold, argument_types, params);
else if (which.isDateTime())
return std::make_shared<AggregateFunctionUniqUpTo<DataTypeDateTime::FieldType>>(threshold);
return std::make_shared<AggregateFunctionUniqUpTo<DataTypeDateTime::FieldType>>(threshold, argument_types, params);
else if (which.isStringOrFixedString())
return std::make_shared<AggregateFunctionUniqUpTo<String>>(threshold);
return std::make_shared<AggregateFunctionUniqUpTo<String>>(threshold, argument_types, params);
else if (which.isUUID())
return std::make_shared<AggregateFunctionUniqUpTo<DataTypeUUID::FieldType>>(threshold);
return std::make_shared<AggregateFunctionUniqUpTo<DataTypeUUID::FieldType>>(threshold, argument_types, params);
else if (which.isTuple())
{
if (use_exact_hash_function)
return std::make_shared<AggregateFunctionUniqUpToVariadic<true, true>>(argument_types, threshold);
return std::make_shared<AggregateFunctionUniqUpToVariadic<true, true>>(argument_types, params, threshold);
else
return std::make_shared<AggregateFunctionUniqUpToVariadic<false, true>>(argument_types, threshold);
return std::make_shared<AggregateFunctionUniqUpToVariadic<false, true>>(argument_types, params, threshold);
}
}
/// "Variadic" method also works as a fallback generic case for single argument.
if (use_exact_hash_function)
return std::make_shared<AggregateFunctionUniqUpToVariadic<true, false>>(argument_types, threshold);
return std::make_shared<AggregateFunctionUniqUpToVariadic<true, false>>(argument_types, params, threshold);
else
return std::make_shared<AggregateFunctionUniqUpToVariadic<false, false>>(argument_types, threshold);
return std::make_shared<AggregateFunctionUniqUpToVariadic<false, false>>(argument_types, params, threshold);
}
}

View File

@ -136,8 +136,9 @@ private:
UInt8 threshold;
public:
AggregateFunctionUniqUpTo(UInt8 threshold)
: threshold(threshold)
AggregateFunctionUniqUpTo(UInt8 threshold, const DataTypes & argument_types, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionUniqUpToData<T>, AggregateFunctionUniqUpTo<T>>(argument_types, params)
, threshold(threshold)
{
}
@ -195,8 +196,9 @@ private:
UInt8 threshold;
public:
AggregateFunctionUniqUpToVariadic(const DataTypes & arguments, UInt8 threshold)
: threshold(threshold)
AggregateFunctionUniqUpToVariadic(const DataTypes & arguments, const Array & params, UInt8 threshold)
: IAggregateFunctionDataHelper<AggregateFunctionUniqUpToData<UInt64>, AggregateFunctionUniqUpToVariadic<is_exact, argument_is_tuple>>(arguments, params)
, threshold(threshold)
{
if (argument_is_tuple)
num_args = typeid_cast<const DataTypeTuple &>(*arguments[0]).getElements().size();

View File

@ -189,6 +189,7 @@ public:
}
AggregateFunctionWindowFunnel(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionWindowFunnelData, AggregateFunctionWindowFunnel>(arguments, params)
{
const auto time_arg = arguments.front().get();
if (!WhichDataType(time_arg).isDateTime() && !WhichDataType(time_arg).isUInt32())

View File

@ -24,9 +24,9 @@ AggregateFunctionPtr createAggregateFunctionStatisticsUnary(const std::string &
AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
if (isDecimal(data_type))
res.reset(createWithDecimalType<FunctionTemplate>(*data_type, *data_type));
res.reset(createWithDecimalType<FunctionTemplate>(*data_type, *data_type, argument_types));
else
res.reset(createWithNumericType<FunctionTemplate>(*data_type));
res.reset(createWithNumericType<FunctionTemplate>(*data_type, argument_types));
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,
@ -40,7 +40,7 @@ AggregateFunctionPtr createAggregateFunctionStatisticsBinary(const std::string &
assertNoParameters(name, parameters);
assertBinary(name, argument_types);
AggregateFunctionPtr res(createWithTwoNumericTypes<FunctionTemplate>(*argument_types[0], *argument_types[1]));
AggregateFunctionPtr res(createWithTwoNumericTypes<FunctionTemplate>(*argument_types[0], *argument_types[1], argument_types));
if (!res)
throw Exception("Illegal types " + argument_types[0]->getName() + " and " + argument_types[1]->getName()
+ " of arguments for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -37,6 +37,9 @@ using ConstAggregateDataPtr = const char *;
class IAggregateFunction
{
public:
IAggregateFunction(const DataTypes & argument_types_, const Array & parameters_)
: argument_types(argument_types_), parameters(parameters_) {}
/// Get main function name.
virtual String getName() const = 0;
@ -108,6 +111,13 @@ public:
* const char * getHeaderFilePath() const override { return __FILE__; }
*/
virtual const char * getHeaderFilePath() const = 0;
const DataTypes & getArgumentTypes() const { return argument_types; }
const Array & getParameters() const { return parameters; }
protected:
DataTypes argument_types;
Array parameters;
};
@ -122,6 +132,8 @@ private:
}
public:
IAggregateFunctionHelper(const DataTypes & argument_types_, const Array & parameters_)
: IAggregateFunction(argument_types_, parameters_) {}
AddFunc getAddressOfAddFunction() const override { return &addFree; }
};
@ -137,6 +149,10 @@ protected:
static const Data & data(ConstAggregateDataPtr place) { return *reinterpret_cast<const Data*>(place); }
public:
IAggregateFunctionDataHelper(const DataTypes & argument_types_, const Array & parameters_)
: IAggregateFunctionHelper<Derived>(argument_types_, parameters_) {}
void create(AggregateDataPtr place) const override
{
new (place) Data;

View File

@ -3,6 +3,8 @@
#include <AggregateFunctions/AggregateFunctionState.h>
#include <DataStreams/ColumnGathererStream.h>
#include <IO/WriteBufferFromArena.h>
#include <IO/WriteBufferFromString.h>
#include <IO/Operators.h>
#include <Common/SipHash.h>
#include <Common/AlignedBuffer.h>
#include <Common/typeid_cast.h>
@ -258,11 +260,17 @@ MutableColumnPtr ColumnAggregateFunction::cloneEmpty() const
return create(func, Arenas(1, std::make_shared<Arena>()));
}
String ColumnAggregateFunction::getTypeString() const
{
return DataTypeAggregateFunction(func, func->getArgumentTypes(), func->getParameters()).getName();
}
Field ColumnAggregateFunction::operator[](size_t n) const
{
Field field = String();
Field field = AggregateFunctionStateData();
field.get<AggregateFunctionStateData &>().name = getTypeString();
{
WriteBufferFromString buffer(field.get<String &>());
WriteBufferFromString buffer(field.get<AggregateFunctionStateData &>().data);
func->serialize(data[n], buffer);
}
return field;
@ -270,9 +278,10 @@ Field ColumnAggregateFunction::operator[](size_t n) const
void ColumnAggregateFunction::get(size_t n, Field & res) const
{
res = String();
res = AggregateFunctionStateData();
res.get<AggregateFunctionStateData &>().name = getTypeString();
{
WriteBufferFromString buffer(res.get<String &>());
WriteBufferFromString buffer(res.get<AggregateFunctionStateData &>().data);
func->serialize(data[n], buffer);
}
}
@ -337,13 +346,23 @@ static void pushBackAndCreateState(ColumnAggregateFunction::Container & data, Ar
}
}
void ColumnAggregateFunction::insert(const Field & x)
{
String type_string = getTypeString();
if (x.getType() != Field::Types::AggregateFunctionState)
throw Exception(String("Inserting field of type ") + x.getTypeName() + " into ColumnAggregateFunction. "
"Expected " + Field::Types::toString(Field::Types::AggregateFunctionState), ErrorCodes::LOGICAL_ERROR);
auto & field_name = x.get<const AggregateFunctionStateData &>().name;
if (type_string != field_name)
throw Exception("Cannot insert filed with type " + field_name + " into column with type " + type_string,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
ensureOwnership();
Arena & arena = createOrGetArena();
pushBackAndCreateState(data, arena, func.get());
ReadBufferFromString read_buffer(x.get<const String &>());
ReadBufferFromString read_buffer(x.get<const AggregateFunctionStateData &>().data);
func->deserialize(data.back(), read_buffer, &arena);
}
@ -465,12 +484,13 @@ void ColumnAggregateFunction::getExtremes(Field & min, Field & max) const
AlignedBuffer place_buffer(func->sizeOfData(), func->alignOfData());
AggregateDataPtr place = place_buffer.data();
String serialized;
AggregateFunctionStateData serialized;
serialized.name = getTypeString();
func->create(place);
try
{
WriteBufferFromString buffer(serialized);
WriteBufferFromString buffer(serialized.data);
func->serialize(place, buffer);
}
catch (...)

View File

@ -94,6 +94,8 @@ private:
{
}
String getTypeString() const;
public:
~ColumnAggregateFunction() override;

View File

@ -166,7 +166,7 @@ template <typename T>
template <typename Type>
ColumnPtr ColumnDecimal<T>::indexImpl(const PaddedPODArray<Type> & indexes, UInt64 limit) const
{
size_t size = indexes.size();
UInt64 size = indexes.size();
if (limit == 0)
limit = size;

View File

@ -275,7 +275,7 @@ template <typename T>
template <typename Type>
ColumnPtr ColumnVector<T>::indexImpl(const PaddedPODArray<Type> & indexes, UInt64 limit) const
{
size_t size = indexes.size();
UInt64 size = indexes.size();
if (limit == 0)
limit = size;

View File

@ -99,7 +99,7 @@ public:
throw Exception("Method deserializeAndInsertFromArena is not supported for ColumnUnique.", ErrorCodes::NOT_IMPLEMENTED);
}
ColumnPtr index(const IColumn &, size_t) const override
ColumnPtr index(const IColumn &, UInt64) const override
{
throw Exception("Method index is not supported for ColumnUnique.", ErrorCodes::NOT_IMPLEMENTED);
}
@ -114,7 +114,7 @@ public:
throw Exception("Method filter is not supported for ColumnUnique.", ErrorCodes::NOT_IMPLEMENTED);
}
ColumnPtr permute(const IColumn::Permutation &, size_t) const override
ColumnPtr permute(const IColumn::Permutation &, UInt64) const override
{
throw Exception("Method permute is not supported for ColumnUnique.", ErrorCodes::NOT_IMPLEMENTED);
}
@ -124,7 +124,7 @@ public:
throw Exception("Method replicate is not supported for ColumnUnique.", ErrorCodes::NOT_IMPLEMENTED);
}
void getPermutation(bool, size_t, int, IColumn::Permutation &) const override
void getPermutation(bool, UInt64, int, IColumn::Permutation &) const override
{
throw Exception("Method getPermutation is not supported for ColumnUnique.", ErrorCodes::NOT_IMPLEMENTED);
}

View File

@ -364,7 +364,10 @@ struct HashMethodSingleLowCardinalityColumn : public SingleColumnMethod
}
if constexpr (has_mapped)
{
mapped_cache[row] = it->second;
return EmplaceResult(it->second, mapped_cache[row], inserted);
}
else
return EmplaceResult(inserted);
}

View File

@ -89,6 +89,13 @@ String FieldVisitorDump::operator() (const Tuple & x_def) const
return wb.str();
}
String FieldVisitorDump::operator() (const AggregateFunctionStateData & x) const
{
WriteBufferFromOwnString wb;
writeQuoted(x.name, wb);
writeQuoted(x.data, wb);
return wb.str();
}
/** In contrast to writeFloatText (and writeQuoted),
* even if number looks like integer after formatting, prints decimal point nevertheless (for example, Float64(1) is printed as 1.).
@ -121,6 +128,10 @@ String FieldVisitorToString::operator() (const DecimalField<Decimal32> & x) cons
String FieldVisitorToString::operator() (const DecimalField<Decimal64> & x) const { return formatQuoted(x); }
String FieldVisitorToString::operator() (const DecimalField<Decimal128> & x) const { return formatQuoted(x); }
String FieldVisitorToString::operator() (const UInt128 & x) const { return formatQuoted(UUID(x)); }
String FieldVisitorToString::operator() (const AggregateFunctionStateData & x) const
{
return "(" + formatQuoted(x.name) + ")" + formatQuoted(x.data);
}
String FieldVisitorToString::operator() (const Array & x) const
{
@ -231,5 +242,15 @@ void FieldVisitorHash::operator() (const DecimalField<Decimal128> & x) const
hash.update(x);
}
void FieldVisitorHash::operator() (const AggregateFunctionStateData & x) const
{
UInt8 type = Field::Types::AggregateFunctionState;
hash.update(type);
hash.update(x.name.size());
hash.update(x.name.data(), x.name.size());
hash.update(x.data.size());
hash.update(x.data.data(), x.data.size());
}
}

View File

@ -49,6 +49,7 @@ typename std::decay_t<Visitor>::ResultType applyVisitor(Visitor && visitor, F &&
case Field::Types::Decimal32: return visitor(field.template get<DecimalField<Decimal32>>());
case Field::Types::Decimal64: return visitor(field.template get<DecimalField<Decimal64>>());
case Field::Types::Decimal128: return visitor(field.template get<DecimalField<Decimal128>>());
case Field::Types::AggregateFunctionState: return visitor(field.template get<AggregateFunctionStateData>());
default:
throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD);
@ -72,6 +73,7 @@ static typename std::decay_t<Visitor>::ResultType applyBinaryVisitorImpl(Visitor
case Field::Types::Decimal32: return visitor(field1, field2.template get<DecimalField<Decimal32>>());
case Field::Types::Decimal64: return visitor(field1, field2.template get<DecimalField<Decimal64>>());
case Field::Types::Decimal128: return visitor(field1, field2.template get<DecimalField<Decimal128>>());
case Field::Types::AggregateFunctionState: return visitor(field1, field2.template get<AggregateFunctionStateData>());
default:
throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD);
@ -116,6 +118,9 @@ typename std::decay_t<Visitor>::ResultType applyVisitor(Visitor && visitor, F1 &
case Field::Types::Decimal128:
return applyBinaryVisitorImpl(
std::forward<Visitor>(visitor), field1.template get<DecimalField<Decimal128>>(), std::forward<F2>(field2));
case Field::Types::AggregateFunctionState:
return applyBinaryVisitorImpl(
std::forward<Visitor>(visitor), field1.template get<AggregateFunctionStateData>(), std::forward<F2>(field2));
default:
throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD);
@ -138,6 +143,7 @@ public:
String operator() (const DecimalField<Decimal32> & x) const;
String operator() (const DecimalField<Decimal64> & x) const;
String operator() (const DecimalField<Decimal128> & x) const;
String operator() (const AggregateFunctionStateData & x) const;
};
@ -156,6 +162,7 @@ public:
String operator() (const DecimalField<Decimal32> & x) const;
String operator() (const DecimalField<Decimal64> & x) const;
String operator() (const DecimalField<Decimal128> & x) const;
String operator() (const AggregateFunctionStateData & x) const;
};
@ -201,6 +208,11 @@ public:
else
return x.getValue() / x.getScaleMultiplier();
}
T operator() (const AggregateFunctionStateData &) const
{
throw Exception("Cannot convert AggregateFunctionStateData to " + demangle(typeid(T).name()), ErrorCodes::CANNOT_CONVERT_TYPE);
}
};
@ -222,6 +234,7 @@ public:
void operator() (const DecimalField<Decimal32> & x) const;
void operator() (const DecimalField<Decimal64> & x) const;
void operator() (const DecimalField<Decimal128> & x) const;
void operator() (const AggregateFunctionStateData & x) const;
};
@ -246,6 +259,7 @@ public:
bool operator() (const UInt64 & l, const String & r) const { return cantCompare(l, r); }
bool operator() (const UInt64 & l, const Array & r) const { return cantCompare(l, r); }
bool operator() (const UInt64 & l, const Tuple & r) const { return cantCompare(l, r); }
bool operator() (const UInt64 & l, const AggregateFunctionStateData & r) const { return cantCompare(l, r); }
bool operator() (const Int64 & l, const Null & r) const { return cantCompare(l, r); }
bool operator() (const Int64 & l, const UInt64 & r) const { return accurate::equalsOp(l, r); }
@ -255,6 +269,7 @@ public:
bool operator() (const Int64 & l, const String & r) const { return cantCompare(l, r); }
bool operator() (const Int64 & l, const Array & r) const { return cantCompare(l, r); }
bool operator() (const Int64 & l, const Tuple & r) const { return cantCompare(l, r); }
bool operator() (const Int64 & l, const AggregateFunctionStateData & r) const { return cantCompare(l, r); }
bool operator() (const Float64 & l, const Null & r) const { return cantCompare(l, r); }
bool operator() (const Float64 & l, const UInt64 & r) const { return accurate::equalsOp(l, r); }
@ -264,6 +279,7 @@ public:
bool operator() (const Float64 & l, const String & r) const { return cantCompare(l, r); }
bool operator() (const Float64 & l, const Array & r) const { return cantCompare(l, r); }
bool operator() (const Float64 & l, const Tuple & r) const { return cantCompare(l, r); }
bool operator() (const Float64 & l, const AggregateFunctionStateData & r) const { return cantCompare(l, r); }
template <typename T>
bool operator() (const Null &, const T &) const
@ -321,6 +337,14 @@ public:
template <typename T> bool operator() (const Int64 & l, const DecimalField<T> & r) const { return DecimalField<Decimal128>(l, 0) == r; }
template <typename T> bool operator() (const Float64 & l, const DecimalField<T> & r) const { return cantCompare(l, r); }
template <typename T>
bool operator() (const AggregateFunctionStateData & l, const T & r) const
{
if constexpr (std::is_same_v<T, AggregateFunctionStateData>)
return l == r;
return cantCompare(l, r);
}
private:
template <typename T, typename U>
bool cantCompare(const T &, const U &) const
@ -344,6 +368,7 @@ public:
bool operator() (const UInt64 & l, const String & r) const { return cantCompare(l, r); }
bool operator() (const UInt64 & l, const Array & r) const { return cantCompare(l, r); }
bool operator() (const UInt64 & l, const Tuple & r) const { return cantCompare(l, r); }
bool operator() (const UInt64 & l, const AggregateFunctionStateData & r) const { return cantCompare(l, r); }
bool operator() (const Int64 & l, const Null & r) const { return cantCompare(l, r); }
bool operator() (const Int64 & l, const UInt64 & r) const { return accurate::lessOp(l, r); }
@ -353,6 +378,7 @@ public:
bool operator() (const Int64 & l, const String & r) const { return cantCompare(l, r); }
bool operator() (const Int64 & l, const Array & r) const { return cantCompare(l, r); }
bool operator() (const Int64 & l, const Tuple & r) const { return cantCompare(l, r); }
bool operator() (const Int64 & l, const AggregateFunctionStateData & r) const { return cantCompare(l, r); }
bool operator() (const Float64 & l, const Null & r) const { return cantCompare(l, r); }
bool operator() (const Float64 & l, const UInt64 & r) const { return accurate::lessOp(l, r); }
@ -362,6 +388,7 @@ public:
bool operator() (const Float64 & l, const String & r) const { return cantCompare(l, r); }
bool operator() (const Float64 & l, const Array & r) const { return cantCompare(l, r); }
bool operator() (const Float64 & l, const Tuple & r) const { return cantCompare(l, r); }
bool operator() (const Float64 & l, const AggregateFunctionStateData & r) const { return cantCompare(l, r); }
template <typename T>
bool operator() (const Null &, const T &) const
@ -419,6 +446,12 @@ public:
template <typename T> bool operator() (const Int64 & l, const DecimalField<T> & r) const { return DecimalField<Decimal128>(l, 0) < r; }
template <typename T> bool operator() (const Float64 &, const DecimalField<T> &) const { return false; }
template <typename T>
bool operator() (const AggregateFunctionStateData & l, const T & r) const
{
return cantCompare(l, r);
}
private:
template <typename T, typename U>
bool cantCompare(const T &, const U &) const
@ -447,6 +480,7 @@ public:
bool operator() (String &) const { throw Exception("Cannot sum Strings", ErrorCodes::LOGICAL_ERROR); }
bool operator() (Array &) const { throw Exception("Cannot sum Arrays", ErrorCodes::LOGICAL_ERROR); }
bool operator() (UInt128 &) const { throw Exception("Cannot sum UUIDs", ErrorCodes::LOGICAL_ERROR); }
bool operator() (AggregateFunctionStateData &) const { throw Exception("Cannot sum AggregateFunctionStates", ErrorCodes::LOGICAL_ERROR); }
template <typename T>
bool operator() (DecimalField<T> & x) const

View File

@ -75,6 +75,14 @@ namespace DB
x.push_back(value);
break;
}
case Field::Types::AggregateFunctionState:
{
AggregateFunctionStateData value;
DB::readStringBinary(value.name, buf);
DB::readStringBinary(value.data, buf);
x.push_back(value);
break;
}
}
}
}
@ -128,6 +136,12 @@ namespace DB
DB::writeBinary(get<Tuple>(*it), buf);
break;
}
case Field::Types::AggregateFunctionState:
{
DB::writeStringBinary(it->get<AggregateFunctionStateData>().name, buf);
DB::writeStringBinary(it->get<AggregateFunctionStateData>().data, buf);
break;
}
}
}
}
@ -209,6 +223,14 @@ namespace DB
x.push_back(value);
break;
}
case Field::Types::AggregateFunctionState:
{
AggregateFunctionStateData value;
DB::readStringBinary(value.name, buf);
DB::readStringBinary(value.data, buf);
x.push_back(value);
break;
}
}
}
}
@ -262,6 +284,12 @@ namespace DB
DB::writeBinary(get<Tuple>(*it), buf);
break;
}
case Field::Types::AggregateFunctionState:
{
DB::writeStringBinary(it->get<AggregateFunctionStateData>().name, buf);
DB::writeStringBinary(it->get<AggregateFunctionStateData>().data, buf);
break;
}
}
}
}

View File

@ -23,6 +23,7 @@ namespace ErrorCodes
extern const int BAD_GET;
extern const int NOT_IMPLEMENTED;
extern const int LOGICAL_ERROR;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
class Field;
@ -30,6 +31,41 @@ using Array = std::vector<Field>;
using TupleBackend = std::vector<Field>;
STRONG_TYPEDEF(TupleBackend, Tuple) /// Array and Tuple are different types with equal representation inside Field.
struct AggregateFunctionStateData
{
String name; /// Name with arguments.
String data;
bool operator < (const AggregateFunctionStateData &) const
{
throw Exception("Operator < is not implemented for AggregateFunctionStateData.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
bool operator <= (const AggregateFunctionStateData &) const
{
throw Exception("Operator <= is not implemented for AggregateFunctionStateData.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
bool operator > (const AggregateFunctionStateData &) const
{
throw Exception("Operator > is not implemented for AggregateFunctionStateData.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
bool operator >= (const AggregateFunctionStateData &) const
{
throw Exception("Operator >= is not implemented for AggregateFunctionStateData.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
bool operator == (const AggregateFunctionStateData & rhs) const
{
if (name != rhs.name)
throw Exception("Comparing aggregate functions with different types: " + name + " and " + rhs.name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return data == rhs.data;
}
};
template <typename T> bool decimalEqual(T x, T y, UInt32 x_scale, UInt32 y_scale);
template <typename T> bool decimalLess(T x, T y, UInt32 x_scale, UInt32 y_scale);
template <typename T> bool decimalLessOrEqual(T x, T y, UInt32 x_scale, UInt32 y_scale);
@ -131,6 +167,7 @@ public:
Decimal32 = 19,
Decimal64 = 20,
Decimal128 = 21,
AggregateFunctionState = 22,
};
static const int MIN_NON_POD = 16;
@ -151,6 +188,7 @@ public:
case Decimal32: return "Decimal32";
case Decimal64: return "Decimal64";
case Decimal128: return "Decimal128";
case AggregateFunctionState: return "AggregateFunctionState";
}
throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD);
@ -325,6 +363,7 @@ public:
case Types::Decimal32: return get<DecimalField<Decimal32>>() < rhs.get<DecimalField<Decimal32>>();
case Types::Decimal64: return get<DecimalField<Decimal64>>() < rhs.get<DecimalField<Decimal64>>();
case Types::Decimal128: return get<DecimalField<Decimal128>>() < rhs.get<DecimalField<Decimal128>>();
case Types::AggregateFunctionState: return get<AggregateFunctionStateData>() < rhs.get<AggregateFunctionStateData>();
}
throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD);
@ -356,6 +395,7 @@ public:
case Types::Decimal32: return get<DecimalField<Decimal32>>() <= rhs.get<DecimalField<Decimal32>>();
case Types::Decimal64: return get<DecimalField<Decimal64>>() <= rhs.get<DecimalField<Decimal64>>();
case Types::Decimal128: return get<DecimalField<Decimal128>>() <= rhs.get<DecimalField<Decimal128>>();
case Types::AggregateFunctionState: return get<AggregateFunctionStateData>() <= rhs.get<AggregateFunctionStateData>();
}
throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD);
@ -385,6 +425,7 @@ public:
case Types::Decimal32: return get<DecimalField<Decimal32>>() == rhs.get<DecimalField<Decimal32>>();
case Types::Decimal64: return get<DecimalField<Decimal64>>() == rhs.get<DecimalField<Decimal64>>();
case Types::Decimal128: return get<DecimalField<Decimal128>>() == rhs.get<DecimalField<Decimal128>>();
case Types::AggregateFunctionState: return get<AggregateFunctionStateData>() == rhs.get<AggregateFunctionStateData>();
}
throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD);
@ -398,7 +439,7 @@ public:
private:
std::aligned_union_t<DBMS_MIN_FIELD_SIZE - sizeof(Types::Which),
Null, UInt64, UInt128, Int64, Int128, Float64, String, Array, Tuple,
DecimalField<Decimal32>, DecimalField<Decimal64>, DecimalField<Decimal128>
DecimalField<Decimal32>, DecimalField<Decimal64>, DecimalField<Decimal128>, AggregateFunctionStateData
> storage;
Types::Which which;
@ -449,6 +490,7 @@ private:
case Types::Decimal32: f(field.template get<DecimalField<Decimal32>>()); return;
case Types::Decimal64: f(field.template get<DecimalField<Decimal64>>()); return;
case Types::Decimal128: f(field.template get<DecimalField<Decimal128>>()); return;
case Types::AggregateFunctionState: f(field.template get<AggregateFunctionStateData>()); return;
}
}
@ -501,6 +543,9 @@ private:
case Types::Tuple:
destroy<Tuple>();
break;
case Types::AggregateFunctionState:
destroy<AggregateFunctionStateData>();
break;
default:
break;
}
@ -531,6 +576,7 @@ template <> struct Field::TypeToEnum<Tuple> { static const Types::Which value
template <> struct Field::TypeToEnum<DecimalField<Decimal32>>{ static const Types::Which value = Types::Decimal32; };
template <> struct Field::TypeToEnum<DecimalField<Decimal64>>{ static const Types::Which value = Types::Decimal64; };
template <> struct Field::TypeToEnum<DecimalField<Decimal128>>{ static const Types::Which value = Types::Decimal128; };
template <> struct Field::TypeToEnum<AggregateFunctionStateData>{ static const Types::Which value = Types::AggregateFunctionState; };
template <> struct Field::EnumToType<Field::Types::Null> { using Type = Null; };
template <> struct Field::EnumToType<Field::Types::UInt64> { using Type = UInt64; };
@ -544,6 +590,7 @@ template <> struct Field::EnumToType<Field::Types::Tuple> { using Type = Tuple
template <> struct Field::EnumToType<Field::Types::Decimal32> { using Type = DecimalField<Decimal32>; };
template <> struct Field::EnumToType<Field::Types::Decimal64> { using Type = DecimalField<Decimal64>; };
template <> struct Field::EnumToType<Field::Types::Decimal128> { using Type = DecimalField<Decimal128>; };
template <> struct Field::EnumToType<Field::Types::AggregateFunctionState> { using Type = DecimalField<AggregateFunctionStateData>; };
template <typename T>
@ -573,6 +620,7 @@ T safeGet(Field & field)
template <> struct TypeName<Array> { static std::string get() { return "Array"; } };
template <> struct TypeName<Tuple> { static std::string get() { return "Tuple"; } };
template <> struct TypeName<AggregateFunctionStateData> { static std::string get() { return "AggregateFunctionState"; } };
template <typename T> struct NearestFieldTypeImpl;
@ -616,6 +664,8 @@ template <> struct NearestFieldTypeImpl<Tuple> { using Type = Tuple; };
template <> struct NearestFieldTypeImpl<bool> { using Type = UInt64; };
template <> struct NearestFieldTypeImpl<Null> { using Type = Null; };
template <> struct NearestFieldTypeImpl<AggregateFunctionStateData> { using Type = AggregateFunctionStateData; };
template <typename T>
using NearestFieldType = typename NearestFieldTypeImpl<T>::Type;

View File

@ -264,7 +264,8 @@ MutableColumnPtr DataTypeAggregateFunction::createColumn() const
/// Create empty state
Field DataTypeAggregateFunction::getDefault() const
{
Field field = String();
Field field = AggregateFunctionStateData();
field.get<AggregateFunctionStateData &>().name = getName();
AlignedBuffer place_buffer(function->sizeOfData(), function->alignOfData());
AggregateDataPtr place = place_buffer.data();
@ -273,7 +274,7 @@ Field DataTypeAggregateFunction::getDefault() const
try
{
WriteBufferFromString buffer_from_field(field.get<String &>());
WriteBufferFromString buffer_from_field(field.get<AggregateFunctionStateData &>().data);
function->serialize(place, buffer_from_field);
}
catch (...)

View File

@ -8,6 +8,7 @@
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeNothing.h>
#include <DataTypes/getLeastSupertype.h>
#include <DataTypes/DataTypeFactory.h>
#include <Common/Exception.h>
#include <ext/size.h>
@ -104,5 +105,10 @@ DataTypePtr FieldToDataType::operator() (const Tuple & x) const
return std::make_shared<DataTypeTuple>(element_types);
}
DataTypePtr FieldToDataType::operator() (const AggregateFunctionStateData & x) const
{
auto & name = static_cast<const AggregateFunctionStateData &>(x).name;
return DataTypeFactory::instance().get(name);
}
}

View File

@ -28,6 +28,7 @@ public:
DataTypePtr operator() (const DecimalField<Decimal32> & x) const;
DataTypePtr operator() (const DecimalField<Decimal64> & x) const;
DataTypePtr operator() (const DecimalField<Decimal128> & x) const;
DataTypePtr operator() (const AggregateFunctionStateData & x) const;
};
}

View File

@ -65,9 +65,9 @@ struct ToStartOfDayImpl
{
return time_zone.toDate(t);
}
static inline UInt32 execute(UInt16, const DateLUTImpl &)
static inline UInt32 execute(UInt16 d, const DateLUTImpl & time_zone)
{
return dateIsNotSupported(name);
return time_zone.toDate(DayNum(d));
}
using FactorTransform = ZeroTransform;

View File

@ -37,23 +37,33 @@ public:
if (arguments.size() == 1)
{
if (!isDateOrDateTime(arguments[0].type))
throw Exception("Illegal type " + arguments[0].type->getName() + " of argument of function " + getName() +
". Should be a date or a date with time", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
throw Exception(
"Illegal type " + arguments[0].type->getName() + " of argument of function " + getName()
+ ". Should be a date or a date with time",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
else if (arguments.size() == 2)
{
if (!WhichDataType(arguments[0].type).isDateTime()
|| !WhichDataType(arguments[1].type).isString())
if (!isDateOrDateTime(arguments[0].type))
throw Exception(
"Illegal type " + arguments[0].type->getName() + " of argument of function " + getName()
+ ". Should be a date or a date with time",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!isString(arguments[1].type))
throw Exception(
"Function " + getName() + " supports 1 or 2 arguments. The 1st argument "
"must be of type Date or DateTime. The 2nd argument (optional) must be "
"a constant string with timezone name. The timezone argument is allowed "
"only when the 1st argument has the type DateTime",
"must be of type Date or DateTime. The 2nd argument (optional) must be "
"a constant string with timezone name",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (isDate(arguments[0].type) && std::is_same_v<ToDataType, DataTypeDate>)
throw Exception(
"The timezone argument of function " + getName() + " is allowed only when the 1st argument has the type DateTime",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
else
throw Exception("Number of arguments for function " + getName() + " doesn't match: passed "
+ toString(arguments.size()) + ", should be 1 or 2",
throw Exception(
"Number of arguments for function " + getName() + " doesn't match: passed " + toString(arguments.size())
+ ", should be 1 or 2",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
/// For DateTime, if time zone is specified, attach it to type.

View File

@ -332,9 +332,9 @@ static const ColumnLowCardinality * findLowCardinalityArgument(const Block & blo
}
static ColumnPtr replaceLowCardinalityColumnsByNestedAndGetDictionaryIndexes(
Block & block, const ColumnNumbers & args, bool can_be_executed_on_default_arguments)
Block & block, const ColumnNumbers & args, bool can_be_executed_on_default_arguments, size_t input_rows_count)
{
size_t num_rows = 0;
size_t num_rows = input_rows_count;
ColumnPtr indexes;
for (auto arg : args)
@ -354,7 +354,10 @@ static ColumnPtr replaceLowCardinalityColumnsByNestedAndGetDictionaryIndexes(
{
ColumnWithTypeAndName & column = block.getByPosition(arg);
if (auto * column_const = checkAndGetColumn<ColumnConst>(column.column.get()))
{
column.column = column_const->removeLowCardinality()->cloneResized(num_rows);
column.type = removeLowCardinality(column.type);
}
else if (auto * low_cardinality_column = checkAndGetColumn<ColumnLowCardinality>(column.column.get()))
{
auto * low_cardinality_type = checkAndGetDataType<DataTypeLowCardinality>(column.type.get());
@ -423,7 +426,7 @@ void PreparedFunctionImpl::execute(Block & block, const ColumnNumbers & args, si
block_without_low_cardinality.safeGetByPosition(result).type = res_low_cardinality_type->getDictionaryType();
ColumnPtr indexes = replaceLowCardinalityColumnsByNestedAndGetDictionaryIndexes(
block_without_low_cardinality, args, can_be_executed_on_default_arguments);
block_without_low_cardinality, args, can_be_executed_on_default_arguments, input_rows_count);
executeWithoutLowCardinalityColumns(block_without_low_cardinality, args, result, block_without_low_cardinality.rows(), dry_run);

View File

@ -59,9 +59,16 @@ public:
ColumnArray::Offset offset = 0;
for (size_t i = 0; i < num_rows; ++i)
{
offset += col_num->getUInt(i);
auto array_size = col_num->getInt(i);
if (unlikely(array_size) < 0)
throw Exception("Array size cannot be negative: while executing function " + getName(), ErrorCodes::TOO_LARGE_ARRAY_SIZE);
offset += array_size;
if (unlikely(offset > max_arrays_size_in_block))
throw Exception("Too large array size while executing function " + getName(), ErrorCodes::TOO_LARGE_ARRAY_SIZE);
offsets.push_back(offset);
}

View File

@ -142,41 +142,54 @@ public:
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
auto check_date_time_argument = [&] {
bool first_argument_is_date = false;
auto check_first_argument = [&]
{
if (!isDateOrDateTime(arguments[0].type))
throw Exception(
"Illegal type " + arguments[0].type->getName() + " of argument of function " + getName()
+ ". Should be a date or a date with time",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
first_argument_is_date = isDate(arguments[0].type);
};
const DataTypeInterval * interval_type = nullptr;
auto check_interval_argument = [&] {
bool result_type_is_date = false;
auto check_interval_argument = [&]
{
interval_type = checkAndGetDataType<DataTypeInterval>(arguments[1].type.get());
if (!interval_type)
throw Exception(
"Illegal type " + arguments[1].type->getName() + " of argument of function " + getName()
+ ". Should be an interval of time",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
result_type_is_date = (interval_type->getKind() == DataTypeInterval::Year)
|| (interval_type->getKind() == DataTypeInterval::Quarter) || (interval_type->getKind() == DataTypeInterval::Month)
|| (interval_type->getKind() == DataTypeInterval::Week);
};
auto check_timezone_argument = [&] {
auto check_timezone_argument = [&]
{
if (!WhichDataType(arguments[2].type).isString())
throw Exception(
"Illegal type " + arguments[2].type->getName() + " of argument of function " + getName()
+ ". This argument is optional and must be a constant string with timezone name"
". This argument is allowed only when the 1st argument has the type DateTime",
+ ". This argument is optional and must be a constant string with timezone name",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (first_argument_is_date && result_type_is_date)
throw Exception(
"The timezone argument of function " + getName() + " with interval type " + interval_type->kindToString()
+ " is allowed only when the 1st argument has the type DateTime",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
};
if (arguments.size() == 2)
{
check_date_time_argument();
check_first_argument();
check_interval_argument();
}
else if (arguments.size() == 3)
{
check_date_time_argument();
check_first_argument();
check_interval_argument();
check_timezone_argument();
}
@ -188,11 +201,10 @@ public:
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
}
if ((interval_type->getKind() == DataTypeInterval::Second) || (interval_type->getKind() == DataTypeInterval::Minute)
|| (interval_type->getKind() == DataTypeInterval::Hour) || (interval_type->getKind() == DataTypeInterval::Day))
return std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, 2, 0));
else
if (result_type_is_date)
return std::make_shared<DataTypeDate>();
else
return std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, 2, 0));
}
bool useDefaultImplementationForConstants() const override { return true; }

View File

@ -35,6 +35,7 @@
#include <Interpreters/evaluateConstantExpression.h>
#include <Interpreters/convertFieldToType.h>
#include <Interpreters/interpretSubquery.h>
#include <Interpreters/DatabaseAndTableWithAlias.h>
namespace DB
{

View File

@ -109,11 +109,6 @@ void ExecuteScalarSubqueriesMatcher::visit(const ASTSubquery & subquery, ASTPtr
size_t columns = block.columns();
if (columns == 1)
{
if (typeid_cast<const DataTypeAggregateFunction*>(block.safeGetByPosition(0).type.get()))
{
throw Exception("Scalar subquery can't return an aggregate function state", ErrorCodes::INCORRECT_RESULT_OF_SCALAR_SUBQUERY);
}
auto lit = std::make_unique<ASTLiteral>((*block.safeGetByPosition(0).column)[0]);
lit->alias = subquery.alias;
lit->prefer_alias_to_column_name = subquery.prefer_alias_to_column_name;
@ -132,11 +127,6 @@ void ExecuteScalarSubqueriesMatcher::visit(const ASTSubquery & subquery, ASTPtr
exp_list->children.resize(columns);
for (size_t i = 0; i < columns; ++i)
{
if (typeid_cast<const DataTypeAggregateFunction*>(block.safeGetByPosition(i).type.get()))
{
throw Exception("Scalar subquery can't return an aggregate function state", ErrorCodes::INCORRECT_RESULT_OF_SCALAR_SUBQUERY);
}
exp_list->children[i] = addTypeConversion(
std::make_unique<ASTLiteral>((*block.safeGetByPosition(i).column)[0]),
block.safeGetByPosition(i).type->getName());

View File

@ -1019,9 +1019,10 @@ void ExpressionAnalyzer::collectUsedColumns()
for (NamesAndTypesList::iterator it = source_columns.begin(); it != source_columns.end();)
{
unknown_required_source_columns.erase(it->name);
const String & column_name = it->name;
unknown_required_source_columns.erase(column_name);
if (!required.count(it->name))
if (!required.count(column_name))
source_columns.erase(it++);
else
++it;

View File

@ -37,12 +37,17 @@ std::optional<String> IdentifierSemantic::getTableName(const ASTPtr & ast)
return {};
}
void IdentifierSemantic::setNeedLongName(ASTIdentifier & identifier, bool value)
{
identifier.semantic->need_long_name = value;
}
bool IdentifierSemantic::canBeAlias(const ASTIdentifier & identifier)
{
return identifier.semantic->can_be_alias;
}
std::pair<String, String> IdentifierSemantic::extractDatabaseAndTable(const ASTIdentifier & identifier)
{
if (identifier.name_parts.size() > 2)
@ -108,6 +113,8 @@ void IdentifierSemantic::setColumnNormalName(ASTIdentifier & identifier, const D
size_t match = IdentifierSemantic::canReferColumnToTable(identifier, db_and_table);
setColumnShortName(identifier, match);
if (match)
identifier.semantic->can_be_alias = false;
if (identifier.semantic->need_long_name)
{

View File

@ -10,6 +10,7 @@ struct IdentifierSemanticImpl
{
bool special = false;
bool need_long_name = false;
bool can_be_alias = true;
};
/// Static calss to manipulate IdentifierSemanticImpl via ASTIdentifier
@ -28,6 +29,7 @@ struct IdentifierSemantic
static String columnNormalName(const ASTIdentifier & identifier, const DatabaseAndTableWithAlias & db_and_table);
static void setColumnNormalName(ASTIdentifier & identifier, const DatabaseAndTableWithAlias & db_and_table);
static void setNeedLongName(ASTIdentifier & identifier, bool); /// if set setColumnNormalName makes qualified name
static bool canBeAlias(const ASTIdentifier & identifier);
private:
static bool doesIdentifierBelongTo(const ASTIdentifier & identifier, const String & database, const String & table);

View File

@ -1207,7 +1207,8 @@ private:
for (size_t i = 0; i < right_sample_block.columns(); ++i)
{
const ColumnWithTypeAndName & src_column = right_sample_block.getByPosition(i);
result_sample_block.insert(src_column.cloneEmpty());
if (!result_sample_block.has(src_column.name))
result_sample_block.insert(src_column.cloneEmpty());
}
const auto & key_names_right = parent.key_names_right;

View File

@ -250,7 +250,6 @@ void PredicateExpressionsOptimizer::setNewAliasesForInnerPredicate(
name = ast->getAliasOrColumnName();
}
IdentifierSemantic::setNeedLongName(*identifier, false);
identifier->setShortName(name);
}
}
@ -338,9 +337,9 @@ ASTs PredicateExpressionsOptimizer::getSelectQueryProjectionColumns(ASTPtr & ast
std::unordered_map<String, ASTPtr> aliases;
std::vector<DatabaseAndTableWithAlias> tables = getDatabaseAndTables(*select_query, context.getCurrentDatabase());
std::vector<TableWithColumnNames> tables_with_columns;
TranslateQualifiedNamesVisitor::Data::setTablesOnly(tables, tables_with_columns);
TranslateQualifiedNamesVisitor::Data qn_visitor_data{{}, tables_with_columns};
/// TODO: get tables from evaluateAsterisk instead of tablesOnly() to extract asterisks in general way
std::vector<TableWithColumnNames> tables_with_columns = TranslateQualifiedNamesVisitor::Data::tablesOnly(tables);
TranslateQualifiedNamesVisitor::Data qn_visitor_data({}, tables_with_columns, false);
TranslateQualifiedNamesVisitor(qn_visitor_data).visit(ast);
QueryAliasesVisitor::Data query_aliases_data{aliases};

View File

@ -56,14 +56,16 @@ std::vector<ASTPtr *> QueryAliasesMatcher::visit(const ASTArrayJoin &, const AST
{
visitOther(ast, data);
/// @warning It breaks botom-to-top order (childs processed after node here), could lead to some effects.
/// It's possible to add ast back to result vec to save order. It will need two phase ASTArrayJoin visit (setting phase in data).
std::vector<ASTPtr *> out;
std::vector<ASTPtr> grand_children;
for (auto & child1 : ast->children)
for (auto & child2 : child1->children)
for (auto & child3 : child2->children)
out.push_back(&child3);
return out;
grand_children.push_back(child3);
/// create own visitor to run bottom to top
for (auto & child : grand_children)
QueryAliasesVisitor(data).visit(child);
return {};
}
/// set unique aliases for all subqueries. this is needed, because:

View File

@ -7,12 +7,10 @@
#include <Parsers/ASTAsterisk.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Common/StringUtils/StringUtils.h>
#include <Common/typeid_cast.h>
#include <Parsers/ASTQualifiedAsterisk.h>
#include <IO/WriteHelpers.h>
namespace DB
@ -91,13 +89,6 @@ void QueryNormalizer::visit(ASTFunction & node, const ASTPtr &, Data & data)
/// and on all remote servers, function implementation will be same.
if (endsWith(func_name, "Distinct") && func_name_lowercase == "countdistinct")
func_name = data.settings.count_distinct_implementation;
/// As special case, treat count(*) as count(), not as count(list of all columns).
if (func_name_lowercase == "count" && func_arguments->children.size() == 1
&& typeid_cast<const ASTAsterisk *>(func_arguments->children[0].get()))
{
func_arguments->children.clear();
}
}
}
@ -111,7 +102,7 @@ void QueryNormalizer::visit(ASTIdentifier & node, ASTPtr & ast, Data & data)
/// If it is an alias, but not a parent alias (for constructs like "SELECT column + 1 AS column").
auto it_alias = data.aliases.find(node.name);
if (it_alias != data.aliases.end() && current_alias != node.name)
if (IdentifierSemantic::canBeAlias(node) && it_alias != data.aliases.end() && current_alias != node.name)
{
auto & alias_node = it_alias->second;
@ -138,84 +129,6 @@ void QueryNormalizer::visit(ASTIdentifier & node, ASTPtr & ast, Data & data)
}
}
/// Replace *, alias.*, database.table.* with a list of columns.
void QueryNormalizer::visit(ASTExpressionList & node, const ASTPtr &, Data & data)
{
if (!data.tables_with_columns)
return;
const auto & tables_with_columns = *data.tables_with_columns;
const auto & source_columns_set = data.source_columns_set;
ASTs old_children;
if (data.processAsterisks())
{
bool has_asterisk = false;
for (const auto & child : node.children)
{
if (typeid_cast<const ASTAsterisk *>(child.get()) ||
typeid_cast<const ASTQualifiedAsterisk *>(child.get()))
{
has_asterisk = true;
break;
}
}
if (has_asterisk)
{
old_children.swap(node.children);
node.children.reserve(old_children.size());
}
}
for (const auto & child : old_children)
{
if (typeid_cast<const ASTAsterisk *>(child.get()))
{
bool first_table = true;
for (const auto & [table_name, table_columns] : tables_with_columns)
{
for (const auto & column_name : table_columns)
if (first_table || !data.join_using_columns.count(column_name))
{
/// qualifed names for duplicates
if (!first_table && source_columns_set && source_columns_set->count(column_name))
node.children.emplace_back(std::make_shared<ASTIdentifier>(table_name.getQualifiedNamePrefix() + column_name));
else
node.children.emplace_back(std::make_shared<ASTIdentifier>(column_name));
}
first_table = false;
}
}
else if (const auto * qualified_asterisk = typeid_cast<const ASTQualifiedAsterisk *>(child.get()))
{
DatabaseAndTableWithAlias ident_db_and_name(qualified_asterisk->children[0]);
bool first_table = true;
for (const auto & [table_name, table_columns] : tables_with_columns)
{
if (ident_db_and_name.satisfies(table_name, true))
{
for (const auto & column_name : table_columns)
{
/// qualifed names for duplicates
if (!first_table && source_columns_set && source_columns_set->count(column_name))
node.children.emplace_back(std::make_shared<ASTIdentifier>(table_name.getQualifiedNamePrefix() + column_name));
else
node.children.emplace_back(std::make_shared<ASTIdentifier>(column_name));
}
break;
}
first_table = false;
}
}
else
node.children.emplace_back(child);
}
}
/// mark table identifiers as 'not columns'
void QueryNormalizer::visit(ASTTablesInSelectQueryElement & node, const ASTPtr &, Data &)
{
@ -229,9 +142,6 @@ void QueryNormalizer::visit(ASTTablesInSelectQueryElement & node, const ASTPtr &
/// special visitChildren() for ASTSelectQuery
void QueryNormalizer::visit(ASTSelectQuery & select, const ASTPtr & ast, Data & data)
{
if (auto join = select.join())
extractJoinUsingColumns(join->table_join, data);
for (auto & child : ast->children)
{
if (typeid_cast<const ASTSelectQuery *>(child.get()) ||
@ -312,8 +222,6 @@ void QueryNormalizer::visit(ASTPtr & ast, Data & data)
visit(*node, ast, data);
if (auto * node = typeid_cast<ASTIdentifier *>(ast.get()))
visit(*node, ast, data);
if (auto * node = typeid_cast<ASTExpressionList *>(ast.get()))
visit(*node, ast, data);
if (auto * node = typeid_cast<ASTTablesInSelectQueryElement *>(ast.get()))
visit(*node, ast, data);
if (auto * node = typeid_cast<ASTSelectQuery *>(ast.get()))
@ -344,27 +252,4 @@ void QueryNormalizer::visit(ASTPtr & ast, Data & data)
}
}
/// 'select * from a join b using id' should result one 'id' column
void QueryNormalizer::extractJoinUsingColumns(const ASTPtr ast, Data & data)
{
const auto & table_join = typeid_cast<const ASTTableJoin &>(*ast);
if (table_join.using_expression_list)
{
auto & keys = typeid_cast<ASTExpressionList &>(*table_join.using_expression_list);
for (const auto & key : keys.children)
if (auto opt_column = getIdentifierName(key))
data.join_using_columns.insert(*opt_column);
else if (typeid_cast<const ASTLiteral *>(key.get()))
data.join_using_columns.insert(key->getColumnName());
else
{
String alias = key->tryGetAlias();
if (alias.empty())
throw Exception("Logical error: expected identifier or alias, got: " + key->getID(), ErrorCodes::LOGICAL_ERROR);
data.join_using_columns.insert(alias);
}
}
}
}

View File

@ -1,11 +1,9 @@
#pragma once
#include <unordered_set>
#include <map>
#include <Core/Names.h>
#include <Parsers/IAST.h>
#include <Interpreters/DatabaseAndTableWithAlias.h>
#include <Interpreters/Aliases.h>
namespace DB
@ -21,9 +19,9 @@ inline bool functionIsInOrGlobalInOperator(const String & name)
return functionIsInOperator(name) || name == "globalIn" || name == "globalNotIn";
}
class ASTSelectQuery;
class ASTFunction;
class ASTIdentifier;
class ASTExpressionList;
struct ASTTablesInSelectQueryElement;
class Context;
@ -53,10 +51,6 @@ public:
const Aliases & aliases;
const ExtractedSettings settings;
const Context * context;
const NameSet * source_columns_set;
const std::vector<TableWithColumnNames> * tables_with_columns;
std::unordered_set<String> join_using_columns;
/// tmp data
size_t level;
@ -64,26 +58,11 @@ public:
SetOfASTs current_asts; /// vertices in the current call stack of this method
std::string current_alias; /// the alias referencing to the ancestor of ast (the deepest ancestor with aliases)
Data(const Aliases & aliases_, ExtractedSettings && settings_, const Context & context_,
const NameSet & source_columns_set, const std::vector<TableWithColumnNames> & tables_with_columns_)
: aliases(aliases_)
, settings(settings_)
, context(&context_)
, source_columns_set(&source_columns_set)
, tables_with_columns(&tables_with_columns_)
, level(0)
{}
Data(const Aliases & aliases_, ExtractedSettings && settings_)
: aliases(aliases_)
, settings(settings_)
, context(nullptr)
, source_columns_set(nullptr)
, tables_with_columns(nullptr)
, level(0)
{}
bool processAsterisks() const { return tables_with_columns && !tables_with_columns->empty(); }
};
QueryNormalizer(Data & data)
@ -102,13 +81,10 @@ private:
static void visit(ASTIdentifier &, ASTPtr &, Data &);
static void visit(ASTFunction &, const ASTPtr &, Data &);
static void visit(ASTExpressionList &, const ASTPtr &, Data &);
static void visit(ASTTablesInSelectQueryElement &, const ASTPtr &, Data &);
static void visit(ASTSelectQuery &, const ASTPtr &, Data &);
static void visitChildren(const ASTPtr &, Data & data);
static void extractJoinUsingColumns(const ASTPtr ast, Data & data);
};
}

View File

@ -301,6 +301,7 @@ struct Settings
M(SettingBool, allow_experimental_cross_to_join_conversion, false, "Convert CROSS JOIN to INNER JOIN if possible") \
M(SettingBool, cancel_http_readonly_queries_on_client_close, false, "Cancel HTTP readonly queries when a client closes the connection without waiting for response.") \
M(SettingBool, external_table_functions_use_nulls, true, "If it is set to true, external table functions will implicitly use Nullable type if needed. Otherwise NULLs will be substituted with default values. Currently supported only for 'mysql' table function.") \
M(SettingBool, allow_experimental_data_skipping_indices, false, "If it is set to true, data skipping indices can be used in CREATE TABLE/ALTER TABLE queries.")\
#define DECLARE(TYPE, NAME, DEFAULT, DESCRIPTION) \
TYPE NAME {DEFAULT};

View File

@ -78,49 +78,36 @@ void collectSourceColumns(ASTSelectQuery * select_query, StoragePtr storage, Nam
}
}
/// Translate qualified names such as db.table.column, table.column, table_alias.column to unqualified names.
void translateQualifiedNames(ASTPtr & query, ASTSelectQuery * select_query, const NameSet & source_columns,
const std::vector<TableWithColumnNames> & tables_with_columns)
/// Translate qualified names such as db.table.column, table.column, table_alias.column to names' normal form.
/// Expand asterisks and qualified asterisks with column names.
/// There would be columns in normal form & column aliases after translation. Column & column alias would be normalized in QueryNormalizer.
void translateQualifiedNames(ASTPtr & query, ASTSelectQuery * select_query, const Context & context, SyntaxAnalyzerResult & result,
const Names & source_columns_list, const NameSet & source_columns_set)
{
if (!select_query->tables || select_query->tables->children.empty())
return;
std::vector<TableWithColumnNames> tables_with_columns = getDatabaseAndTablesWithColumnNames(*select_query, context);
if (tables_with_columns.empty())
{
Names all_columns_name = source_columns_list;
/// TODO: asterisk_left_columns_only probably does not work in some cases
if (!context.getSettingsRef().asterisk_left_columns_only)
{
auto columns_from_joined_table = result.analyzed_join.getColumnsFromJoinedTable(source_columns_set, context, select_query);
for (auto & column : columns_from_joined_table)
all_columns_name.emplace_back(column.name_and_type.name);
}
tables_with_columns.emplace_back(DatabaseAndTableWithAlias{}, std::move(all_columns_name));
}
LogAST log;
TranslateQualifiedNamesVisitor::Data visitor_data{source_columns, tables_with_columns};
TranslateQualifiedNamesVisitor::Data visitor_data(source_columns_set, tables_with_columns);
TranslateQualifiedNamesVisitor visitor(visitor_data, log.stream());
visitor.visit(query);
}
/// For star nodes(`*`), expand them to a list of all columns. For literal nodes, substitute aliases.
void normalizeTree(
ASTPtr & query,
SyntaxAnalyzerResult & result,
const Names & source_columns,
const NameSet & source_columns_set,
const Context & context,
const ASTSelectQuery * select_query,
std::vector<TableWithColumnNames> & tables_with_columns)
{
const auto & settings = context.getSettingsRef();
Names all_columns_name = source_columns;
if (!settings.asterisk_left_columns_only)
{
auto columns_from_joined_table = result.analyzed_join.getColumnsFromJoinedTable(source_columns_set, context, select_query);
for (auto & column : columns_from_joined_table)
all_columns_name.emplace_back(column.name_and_type.name);
}
if (all_columns_name.empty())
throw Exception("An asterisk cannot be replaced with empty columns.", ErrorCodes::LOGICAL_ERROR);
if (tables_with_columns.empty())
tables_with_columns.emplace_back(DatabaseAndTableWithAlias{}, std::move(all_columns_name));
QueryNormalizer::Data normalizer_data(result.aliases, settings, context, source_columns_set, tables_with_columns);
QueryNormalizer(normalizer_data).visit(query);
}
bool hasArrayJoin(const ASTPtr & ast)
{
if (const ASTFunction * function = typeid_cast<const ASTFunction *>(&*ast))
@ -646,12 +633,10 @@ SyntaxAnalyzerResultPtr SyntaxAnalyzer::analyze(
if (source_columns_set.size() != source_columns_list.size())
throw Exception("Unexpected duplicates in source columns list.", ErrorCodes::LOGICAL_ERROR);
std::vector<TableWithColumnNames> tables_with_columns;
if (select_query)
{
tables_with_columns = getDatabaseAndTablesWithColumnNames(*select_query, context);
translateQualifiedNames(query, select_query, source_columns_set, tables_with_columns);
translateQualifiedNames(query, select_query, context, result,
(storage ? storage->getColumns().ordinary.getNames() : source_columns_list), source_columns_set);
/// Depending on the user's profile, check for the execution rights
/// distributed subqueries inside the IN or JOIN sections and process these subqueries.
@ -669,8 +654,10 @@ SyntaxAnalyzerResultPtr SyntaxAnalyzer::analyze(
}
/// Common subexpression elimination. Rewrite rules.
normalizeTree(query, result, (storage ? storage->getColumns().ordinary.getNames() : source_columns_list), source_columns_set,
context, select_query, tables_with_columns);
{
QueryNormalizer::Data normalizer_data(result.aliases, context.getSettingsRef());
QueryNormalizer(normalizer_data).visit(query);
}
/// Remove unneeded columns according to 'required_result_columns'.
/// Leave all selected columns in case of DISTINCT; columns that contain arrayJoin function inside.

View File

@ -1,3 +1,5 @@
#include <Poco/String.h>
#include <Interpreters/TranslateQualifiedNamesVisitor.h>
#include <Interpreters/IdentifierSemantic.h>
@ -5,10 +7,14 @@
#include <Core/Names.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTAsterisk.h>
#include <Parsers/ASTQualifiedAsterisk.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTSelectWithUnionQuery.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTFunction.h>
namespace DB
@ -41,12 +47,14 @@ std::vector<ASTPtr *> TranslateQualifiedNamesMatcher::visit(ASTPtr & ast, Data &
{
if (auto * t = typeid_cast<ASTIdentifier *>(ast.get()))
return visit(*t, ast, data);
if (auto * t = typeid_cast<ASTQualifiedAsterisk *>(ast.get()))
return visit(*t, ast, data);
if (auto * t = typeid_cast<ASTTableJoin *>(ast.get()))
return visit(*t, ast, data);
if (auto * t = typeid_cast<ASTSelectQuery *>(ast.get()))
return visit(*t, ast, data);
if (auto * node = typeid_cast<ASTExpressionList *>(ast.get()))
visit(*node, ast, data);
if (auto * node = typeid_cast<ASTFunction *>(ast.get()))
visit(*node, ast, data);
return {};
}
@ -73,6 +81,18 @@ std::vector<ASTPtr *> TranslateQualifiedNamesMatcher::visit(ASTIdentifier & iden
return {};
}
/// As special case, treat count(*) as count(), not as count(list of all columns).
void TranslateQualifiedNamesMatcher::visit(ASTFunction & node, const ASTPtr &, Data &)
{
ASTPtr & func_arguments = node.arguments;
String func_name_lowercase = Poco::toLower(node.name);
if (func_name_lowercase == "count" &&
func_arguments->children.size() == 1 &&
typeid_cast<const ASTAsterisk *>(func_arguments->children[0].get()))
func_arguments->children.clear();
}
std::vector<ASTPtr *> TranslateQualifiedNamesMatcher::visit(const ASTQualifiedAsterisk & , const ASTPtr & ast, Data & data)
{
if (ast->children.size() != 1)
@ -100,8 +120,11 @@ std::vector<ASTPtr *> TranslateQualifiedNamesMatcher::visit(ASTTableJoin & join,
return out;
}
std::vector<ASTPtr *> TranslateQualifiedNamesMatcher::visit(ASTSelectQuery & select, const ASTPtr & , Data &)
std::vector<ASTPtr *> TranslateQualifiedNamesMatcher::visit(ASTSelectQuery & select, const ASTPtr & , Data & data)
{
if (auto join = select.join())
extractJoinUsingColumns(join->table_join, data);
/// If the WHERE clause or HAVING consists of a single qualified column, the reference must be translated not only in children,
/// but also in where_expression and having_expression.
std::vector<ASTPtr *> out;
@ -114,4 +137,109 @@ std::vector<ASTPtr *> TranslateQualifiedNamesMatcher::visit(ASTSelectQuery & sel
return out;
}
/// Replace *, alias.*, database.table.* with a list of columns.
void TranslateQualifiedNamesMatcher::visit(ASTExpressionList & node, const ASTPtr &, Data & data)
{
const auto & tables_with_columns = data.tables;
const auto & source_columns = data.source_columns;
ASTs old_children;
if (data.processAsterisks())
{
bool has_asterisk = false;
for (const auto & child : node.children)
{
if (typeid_cast<const ASTAsterisk *>(child.get()))
{
if (tables_with_columns.empty())
throw Exception("An asterisk cannot be replaced with empty columns.", ErrorCodes::LOGICAL_ERROR);
has_asterisk = true;
break;
}
else if (auto qa = typeid_cast<const ASTQualifiedAsterisk *>(child.get()))
{
visit(*qa, child, data); /// check if it's OK before rewrite
has_asterisk = true;
break;
}
}
if (has_asterisk)
{
old_children.swap(node.children);
node.children.reserve(old_children.size());
}
}
for (const auto & child : old_children)
{
if (typeid_cast<const ASTAsterisk *>(child.get()))
{
bool first_table = true;
for (const auto & [table_name, table_columns] : tables_with_columns)
{
for (const auto & column_name : table_columns)
if (first_table || !data.join_using_columns.count(column_name))
{
/// qualifed names for duplicates
if (!first_table && source_columns.count(column_name))
node.children.emplace_back(std::make_shared<ASTIdentifier>(table_name.getQualifiedNamePrefix() + column_name));
else
node.children.emplace_back(std::make_shared<ASTIdentifier>(column_name));
}
first_table = false;
}
}
else if (const auto * qualified_asterisk = typeid_cast<const ASTQualifiedAsterisk *>(child.get()))
{
DatabaseAndTableWithAlias ident_db_and_name(qualified_asterisk->children[0]);
bool first_table = true;
for (const auto & [table_name, table_columns] : tables_with_columns)
{
if (ident_db_and_name.satisfies(table_name, true))
{
for (const auto & column_name : table_columns)
{
/// qualifed names for duplicates
if (!first_table && source_columns.count(column_name))
node.children.emplace_back(std::make_shared<ASTIdentifier>(table_name.getQualifiedNamePrefix() + column_name));
else
node.children.emplace_back(std::make_shared<ASTIdentifier>(column_name));
}
break;
}
first_table = false;
}
}
else
node.children.emplace_back(child);
}
}
/// 'select * from a join b using id' should result one 'id' column
void TranslateQualifiedNamesMatcher::extractJoinUsingColumns(const ASTPtr ast, Data & data)
{
const auto & table_join = typeid_cast<const ASTTableJoin &>(*ast);
if (table_join.using_expression_list)
{
auto & keys = typeid_cast<ASTExpressionList &>(*table_join.using_expression_list);
for (const auto & key : keys.children)
if (auto opt_column = getIdentifierName(key))
data.join_using_columns.insert(*opt_column);
else if (typeid_cast<const ASTLiteral *>(key.get()))
data.join_using_columns.insert(key->getColumnName());
else
{
String alias = key->tryGetAlias();
if (alias.empty())
throw Exception("Logical error: expected identifier or alias, got: " + key->getID(), ErrorCodes::LOGICAL_ERROR);
data.join_using_columns.insert(alias);
}
}
}
}

View File

@ -13,6 +13,8 @@ class ASTIdentifier;
class ASTQualifiedAsterisk;
struct ASTTableJoin;
class ASTSelectQuery;
class ASTExpressionList;
class ASTFunction;
/// Visit one node for names qualification. @sa InDepthNodeVisitor.
class TranslateQualifiedNamesMatcher
@ -22,15 +24,26 @@ public:
{
const NameSet & source_columns;
const std::vector<TableWithColumnNames> & tables;
std::unordered_set<String> join_using_columns;
bool has_columns;
static void setTablesOnly(const std::vector<DatabaseAndTableWithAlias> & tables,
std::vector<TableWithColumnNames> & tables_with_columns)
Data(const NameSet & source_columns_, const std::vector<TableWithColumnNames> & tables_, bool has_columns_ = true)
: source_columns(source_columns_)
, tables(tables_)
, has_columns(has_columns_)
{}
static std::vector<TableWithColumnNames> tablesOnly(const std::vector<DatabaseAndTableWithAlias> & tables)
{
tables_with_columns.clear();
std::vector<TableWithColumnNames> tables_with_columns;
tables_with_columns.reserve(tables.size());
for (const auto & table : tables)
tables_with_columns.emplace_back(TableWithColumnNames{table, {}});
return tables_with_columns;
}
bool processAsterisks() const { return !tables.empty() && has_columns; }
};
static constexpr const char * label = "TranslateQualifiedNames";
@ -43,10 +56,14 @@ private:
static std::vector<ASTPtr *> visit(const ASTQualifiedAsterisk & node, const ASTPtr & ast, Data &);
static std::vector<ASTPtr *> visit(ASTTableJoin & node, const ASTPtr & ast, Data &);
static std::vector<ASTPtr *> visit(ASTSelectQuery & node, const ASTPtr & ast, Data &);
static void visit(ASTExpressionList &, const ASTPtr &, Data &);
static void visit(ASTFunction &, const ASTPtr &, Data &);
static void extractJoinUsingColumns(const ASTPtr ast, Data & data);
};
/// Visits AST for names qualification.
/// It finds columns (general identifiers and asterisks) and translate their names according to tables' names.
/// It finds columns and translate their names to the normal form. Expand asterisks and qualified asterisks with column names.
using TranslateQualifiedNamesVisitor = InDepthNodeVisitor<TranslateQualifiedNamesMatcher, true>;
}

View File

@ -22,6 +22,7 @@
#include <DataTypes/DataTypeLowCardinality.h>
#include <common/DateLUT.h>
#include <DataTypes/DataTypeAggregateFunction.h>
namespace DB
@ -248,6 +249,18 @@ Field convertFieldToTypeImpl(const Field & src, const IDataType & type, const ID
return res;
}
}
else if (const DataTypeAggregateFunction * agg_func_type = typeid_cast<const DataTypeAggregateFunction *>(&type))
{
if (src.getType() != Field::Types::AggregateFunctionState)
throw Exception(String("Cannot convert ") + src.getTypeName() + " to " + agg_func_type->getName(),
ErrorCodes::TYPE_MISMATCH);
auto & name = src.get<AggregateFunctionStateData>().name;
if (agg_func_type->getName() != name)
throw Exception("Cannot convert " + name + " to " + agg_func_type->getName(), ErrorCodes::TYPE_MISMATCH);
return src;
}
if (src.getType() == Field::Types::String)
{
@ -257,6 +270,8 @@ Field convertFieldToTypeImpl(const Field & src, const IDataType & type, const ID
return (*col)[0];
}
// TODO (nemkov): should we attempt to parse value using or `type.deserializeAsTextEscaped()` type.deserializeAsTextEscaped() ?
throw Exception("Type mismatch in IN or VALUES section. Expected: " + type.getName() + ". Got: "
+ Field::Types::toString(src.getType()), ErrorCodes::TYPE_MISMATCH);

View File

@ -22,6 +22,15 @@ ASTIdentifier::ASTIdentifier(const String & name_, std::vector<String> && name_p
{
}
void ASTIdentifier::setShortName(const String & new_name)
{
name = new_name;
name_parts.clear();
semantic->need_long_name = false;
semantic->can_be_alias = true;
}
void ASTIdentifier::formatImplWithoutAlias(const FormatSettings & settings, FormatState &, FormatStateStacked) const
{
auto format_element = [&](const String & elem_name)

View File

@ -36,11 +36,7 @@ public:
bool compound() const { return !name_parts.empty(); }
bool isShort() const { return name_parts.empty() || name == name_parts.back(); }
void setShortName(const String & new_name)
{
name = new_name;
name_parts.clear();
}
void setShortName(const String & new_name);
const String & shortName() const
{

View File

@ -1,5 +1,6 @@
add_subdirectory (System)
add_subdirectory(System)
add_subdirectory(Kafka)
if (ENABLE_TESTS)
add_subdirectory (tests)
endif ()
if(ENABLE_TESTS)
add_subdirectory(tests)
endif()

View File

@ -185,7 +185,7 @@ public:
const SelectQueryInfo & /*query_info*/,
const Context & /*context*/,
QueryProcessingStage::Enum /*processed_stage*/,
size_t /*max_block_size*/,
UInt64 /*max_block_size*/,
unsigned /*num_streams*/)
{
throw Exception("Method read is not supported by storage " + getName(), ErrorCodes::NOT_IMPLEMENTED);

View File

@ -14,6 +14,7 @@ struct IndicesDescription
IndicesDescription() = default;
bool empty() const { return indices.empty(); }
String toString() const;
static IndicesDescription parse(const String & str);

View File

@ -0,0 +1,9 @@
if(USE_RDKAFKA)
include(${ClickHouse_SOURCE_DIR}/cmake/dbms_glob_sources.cmake)
add_headers_and_sources(clickhouse_storage_kafka .)
add_library(clickhouse_storage_kafka ${LINK_MODE} ${clickhouse_storage_kafka_sources})
target_link_libraries(clickhouse_storage_kafka PRIVATE clickhouse_common_io ${CPPKAFKA_LIBRARY} ${RDKAFKA_LIBRARY})
if(NOT USE_INTERNAL_RDKAFKA_LIBRARY)
target_include_directories(clickhouse_storage_kafka SYSTEM BEFORE PRIVATE ${RDKAFKA_INCLUDE_DIR})
endif()
endif()

View File

@ -53,7 +53,9 @@ public:
return size;
}
/// The task is started immediately.
TaskHandle addTask(const Task & task);
void removeTask(const TaskHandle & task);
~BackgroundProcessingPool();

View File

@ -72,6 +72,7 @@ namespace DB
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int MEMORY_LIMIT_EXCEEDED;
extern const int SYNTAX_ERROR;
extern const int INVALID_PARTITION_VALUE;
@ -831,7 +832,7 @@ void MergeTreeData::clearOldTemporaryDirectories(ssize_t custom_directories_life
Poco::DirectoryIterator end;
for (Poco::DirectoryIterator it{full_path}; it != end; ++it)
{
if (startsWith(it.name(), "tmp"))
if (startsWith(it.name(), "tmp_"))
{
Poco::File tmp_dir(full_path + it.name());
@ -1051,7 +1052,7 @@ bool isMetadataOnlyConversion(const IDataType * from, const IDataType * to)
}
void MergeTreeData::checkAlter(const AlterCommands & commands)
void MergeTreeData::checkAlter(const AlterCommands & commands, const Context & context)
{
/// Check that needed transformations can be applied to the list of columns without considering type conversions.
auto new_columns = getColumns();
@ -1060,6 +1061,11 @@ void MergeTreeData::checkAlter(const AlterCommands & commands)
ASTPtr new_primary_key_ast = primary_key_ast;
commands.apply(new_columns, new_indices, new_order_by_ast, new_primary_key_ast);
if (getIndicesDescription().empty() && !new_indices.empty() &&
!context.getSettingsRef().allow_experimental_data_skipping_indices)
throw Exception("You must set the setting `allow_experimental_data_skipping_indices` to 1 " \
"before using data skipping indices.", ErrorCodes::BAD_ARGUMENTS);
/// Set of columns that shouldn't be altered.
NameSet columns_alter_forbidden;

Some files were not shown because too many files have changed in this diff Show More