Compare commits

...

42 Commits

Author SHA1 Message Date
Pavel Kruglov
f3c4fedd35
Merge 2c0f9dbfc5 into 44b4bd38b9 2024-11-20 15:25:05 -08:00
Mikhail Artemenko
44b4bd38b9
Merge pull request #72045 from ClickHouse/issues/70174/cluster_versions
Enable cluster table functions for DataLake Storages
2024-11-20 21:22:37 +00:00
Shichao Jin
40c7d5fd1a
Merge pull request #71894 from udiz/fix-arrayWithConstant-size-estimation
Fix: arrayWithConstant size estimation using row's element size
2024-11-20 19:56:27 +00:00
Vitaly Baranov
4e56c026cd
Merge pull request #72103 from vitlibar/get-rid-of-code-duplication-after-check-grant
Get rid of code duplication after adding CHECK GRANT
2024-11-20 17:30:12 +00:00
Raúl Marín
2e776256e8
Merge pull request #72046 from Algunenano/decimal_trash
Save several minutes of build time
2024-11-20 17:13:01 +00:00
Raúl Marín
f84083d174 Clang tidy gives one error at a time 2024-11-20 12:45:46 +01:00
Mikhail Artemenko
4ccebd9a24 fix syntax for iceberg in docs 2024-11-20 11:15:39 +00:00
Mikhail Artemenko
99177c0daf remove icebergCluster alias 2024-11-20 11:15:12 +00:00
Raúl Marín
17fdd2bd37 Loving tidy 2024-11-20 02:06:29 +01:00
Raúl Marín
1c414b9987 OSX fix 2024-11-19 21:48:37 +01:00
Raúl Marín
dd90fbe13b Fix clang tidy after moving implementation to cpp 2024-11-19 20:04:52 +01:00
Vitaly Baranov
ecedbcc763 Allow test 03234_check_grant.sh to run in parallel. 2024-11-19 17:48:12 +01:00
Vitaly Baranov
8551162dcb Get rid of code duplucation after adding CHECK GRANT. 2024-11-19 17:48:06 +01:00
Raúl Marín
5286fa65c4 Fix 2024-11-19 15:07:17 +01:00
Raúl Marín
e6f4afe569 Move things to implementation file 2024-11-19 14:42:45 +01:00
Raúl Marín
2146ab4e4e Move more things to private 2024-11-19 14:37:06 +01:00
Mikhail Artemenko
0951991c1d update aspell-dict.txt 2024-11-19 13:10:42 +00:00
Mikhail Artemenko
19aec5e572 Merge branch 'issues/70174/cluster_versions' of github.com:ClickHouse/ClickHouse into issues/70174/cluster_versions 2024-11-19 12:51:56 +00:00
Mikhail Artemenko
a367de9977 add docs 2024-11-19 12:49:59 +00:00
Mikhail Artemenko
6894e280b2 fix pr issues 2024-11-19 12:34:42 +00:00
Raúl Marín
59f73a2053 Add back declaration 2024-11-19 13:27:19 +01:00
Mikhail Artemenko
39ebe113d9 Merge branch 'master' into issues/70174/cluster_versions 2024-11-19 11:28:46 +00:00
Raúl Marín
514c1f7215 Add missing type 2024-11-19 11:22:43 +01:00
udiz
239bbaa133 use length 2024-11-19 00:00:43 +00:00
udiz
07fac5808d format null on test 2024-11-18 23:08:48 +00:00
udiz
ed95e0781f test uses less memory 2024-11-18 22:48:38 +00:00
Raúl Marín
445a5e9c9e Style 2024-11-18 20:03:35 +01:00
Raúl Marín
47bed13b42 Remove extra instantiations of classes 2024-11-18 19:51:42 +01:00
Raúl Marín
557b3e370d Remove code bloat from ColumnVector.h 2024-11-18 19:08:42 +01:00
robot-clickhouse
014608fb6b Automatic style fix 2024-11-18 17:51:51 +00:00
Mikhail Artemenko
a29ded4941 add test for iceberg 2024-11-18 17:39:46 +00:00
Mikhail Artemenko
d2efae7511 enable cluster versions for datalake storages 2024-11-18 17:35:21 +00:00
Raúl Marín
6b55754bc9 Remove some nested includes in IFunction usage 2024-11-18 18:25:37 +01:00
Raúl Marín
e33f5bb4e9 Remove unused leftovers
Usage was removed in 23.6
https://github.com/ClickHouse/ClickHouse/pull/50531
2024-11-18 17:45:50 +01:00
Raúl Marín
a258b6d0f2 Prevent magic_enum in Field.h 2024-11-18 17:36:45 +01:00
Raúl Marín
1c308f970b Try to remove more decimal instantiations 2024-11-18 17:21:06 +01:00
Raúl Marín
ec776fe8db Remove wasteful template instatiations 2024-11-18 17:08:43 +01:00
Raúl Marín
fb552dd2c0 Remove unused trash 2024-11-18 13:49:11 +01:00
avogar
2c0f9dbfc5 Fix Dynamic serialization in Pretty JSON formats 2024-11-14 13:22:06 +00:00
udiz
6879aa130a newline 2024-11-13 22:47:54 +00:00
udiz
43f3c886a2 add test 2024-11-13 22:46:36 +00:00
udiz
c383a743f7 arrayWithConstant size estimation using single value size 2024-11-13 20:02:31 +00:00
116 changed files with 1468 additions and 1123 deletions

View File

@ -49,4 +49,4 @@ LIMIT 2
**See Also** **See Also**
- [DeltaLake engine](/docs/en/engines/table-engines/integrations/deltalake.md) - [DeltaLake engine](/docs/en/engines/table-engines/integrations/deltalake.md)
- [DeltaLake cluster table function](/docs/en/sql-reference/table-functions/deltalakeCluster.md)

View File

@ -0,0 +1,30 @@
---
slug: /en/sql-reference/table-functions/deltalakeCluster
sidebar_position: 46
sidebar_label: deltaLakeCluster
title: "deltaLakeCluster Table Function"
---
This is an extension to the [deltaLake](/docs/en/sql-reference/table-functions/deltalake.md) table function.
Allows processing files from [Delta Lake](https://github.com/delta-io/delta) tables in Amazon S3 in parallel from many nodes in a specified cluster. On initiator it creates a connection to all nodes in the cluster and dispatches each file dynamically. On the worker node it asks the initiator about the next task to process and processes it. This is repeated until all tasks are finished.
**Syntax**
``` sql
deltaLakeCluster(cluster_name, url [,aws_access_key_id, aws_secret_access_key] [,format] [,structure] [,compression])
```
**Arguments**
- `cluster_name` — Name of a cluster that is used to build a set of addresses and connection parameters to remote and local servers.
- Description of all other arguments coincides with description of arguments in equivalent [deltaLake](/docs/en/sql-reference/table-functions/deltalake.md) table function.
**Returned value**
A table with the specified structure for reading data from cluster in the specified Delta Lake table in S3.
**See Also**
- [deltaLake engine](/docs/en/engines/table-engines/integrations/deltalake.md)
- [deltaLake table function](/docs/en/sql-reference/table-functions/deltalake.md)

View File

@ -29,4 +29,4 @@ A table with the specified structure for reading data in the specified Hudi tabl
**See Also** **See Also**
- [Hudi engine](/docs/en/engines/table-engines/integrations/hudi.md) - [Hudi engine](/docs/en/engines/table-engines/integrations/hudi.md)
- [Hudi cluster table function](/docs/en/sql-reference/table-functions/hudiCluster.md)

View File

@ -0,0 +1,30 @@
---
slug: /en/sql-reference/table-functions/hudiCluster
sidebar_position: 86
sidebar_label: hudiCluster
title: "hudiCluster Table Function"
---
This is an extension to the [hudi](/docs/en/sql-reference/table-functions/hudi.md) table function.
Allows processing files from Apache [Hudi](https://hudi.apache.org/) tables in Amazon S3 in parallel from many nodes in a specified cluster. On initiator it creates a connection to all nodes in the cluster and dispatches each file dynamically. On the worker node it asks the initiator about the next task to process and processes it. This is repeated until all tasks are finished.
**Syntax**
``` sql
hudiCluster(cluster_name, url [,aws_access_key_id, aws_secret_access_key] [,format] [,structure] [,compression])
```
**Arguments**
- `cluster_name` — Name of a cluster that is used to build a set of addresses and connection parameters to remote and local servers.
- Description of all other arguments coincides with description of arguments in equivalent [hudi](/docs/en/sql-reference/table-functions/hudi.md) table function.
**Returned value**
A table with the specified structure for reading data from cluster in the specified Hudi table in S3.
**See Also**
- [Hudi engine](/docs/en/engines/table-engines/integrations/hudi.md)
- [Hudi table function](/docs/en/sql-reference/table-functions/hudi.md)

View File

@ -72,3 +72,4 @@ Table function `iceberg` is an alias to `icebergS3` now.
**See Also** **See Also**
- [Iceberg engine](/docs/en/engines/table-engines/integrations/iceberg.md) - [Iceberg engine](/docs/en/engines/table-engines/integrations/iceberg.md)
- [Iceberg cluster table function](/docs/en/sql-reference/table-functions/icebergCluster.md)

View File

@ -0,0 +1,43 @@
---
slug: /en/sql-reference/table-functions/icebergCluster
sidebar_position: 91
sidebar_label: icebergCluster
title: "icebergCluster Table Function"
---
This is an extension to the [iceberg](/docs/en/sql-reference/table-functions/iceberg.md) table function.
Allows processing files from Apache [Iceberg](https://iceberg.apache.org/) in parallel from many nodes in a specified cluster. On initiator it creates a connection to all nodes in the cluster and dispatches each file dynamically. On the worker node it asks the initiator about the next task to process and processes it. This is repeated until all tasks are finished.
**Syntax**
``` sql
icebergS3Cluster(cluster_name, url [, NOSIGN | access_key_id, secret_access_key, [session_token]] [,format] [,compression_method])
icebergS3Cluster(cluster_name, named_collection[, option=value [,..]])
icebergAzureCluster(cluster_name, connection_string|storage_account_url, container_name, blobpath, [,account_name], [,account_key] [,format] [,compression_method])
icebergAzureCluster(cluster_name, named_collection[, option=value [,..]])
icebergHDFSCluster(cluster_name, path_to_table, [,format] [,compression_method])
icebergHDFSCluster(cluster_name, named_collection[, option=value [,..]])
```
**Arguments**
- `cluster_name` — Name of a cluster that is used to build a set of addresses and connection parameters to remote and local servers.
- Description of all other arguments coincides with description of arguments in equivalent [iceberg](/docs/en/sql-reference/table-functions/iceberg.md) table function.
**Returned value**
A table with the specified structure for reading data from cluster in the specified Iceberg table.
**Examples**
```sql
SELECT * FROM icebergS3Cluster('cluster_simple', 'http://test.s3.amazonaws.com/clickhouse-bucket/test_table', 'test', 'test')
```
**See Also**
- [Iceberg engine](/docs/en/engines/table-engines/integrations/iceberg.md)
- [Iceberg table function](/docs/en/sql-reference/table-functions/iceberg.md)

View File

@ -9,6 +9,12 @@
namespace DB namespace DB
{ {
namespace ErrorCodes
{
extern const int INVALID_GRANT;
}
namespace namespace
{ {
void formatOptions(bool grant_option, bool is_partial_revoke, String & result) void formatOptions(bool grant_option, bool is_partial_revoke, String & result)
@ -211,18 +217,43 @@ AccessRightsElement::AccessRightsElement(
{ {
} }
void AccessRightsElement::eraseNonGrantable() AccessFlags AccessRightsElement::getGrantableFlags() const
{ {
if (isGlobalWithParameter() && !anyParameter()) if (isGlobalWithParameter() && !anyParameter())
access_flags &= AccessFlags::allFlagsGrantableOnGlobalWithParameterLevel(); return access_flags & AccessFlags::allFlagsGrantableOnGlobalWithParameterLevel();
else if (!anyColumn()) else if (!anyColumn())
access_flags &= AccessFlags::allFlagsGrantableOnColumnLevel(); return access_flags & AccessFlags::allFlagsGrantableOnColumnLevel();
else if (!anyTable()) else if (!anyTable())
access_flags &= AccessFlags::allFlagsGrantableOnTableLevel(); return access_flags & AccessFlags::allFlagsGrantableOnTableLevel();
else if (!anyDatabase()) else if (!anyDatabase())
access_flags &= AccessFlags::allFlagsGrantableOnDatabaseLevel(); return access_flags & AccessFlags::allFlagsGrantableOnDatabaseLevel();
else else
access_flags &= AccessFlags::allFlagsGrantableOnGlobalLevel(); return access_flags & AccessFlags::allFlagsGrantableOnGlobalLevel();
}
void AccessRightsElement::throwIfNotGrantable() const
{
if (empty())
return;
auto grantable_flags = getGrantableFlags();
if (grantable_flags)
return;
if (!anyColumn())
throw Exception(ErrorCodes::INVALID_GRANT, "{} cannot be granted on the column level", access_flags.toString());
if (!anyTable())
throw Exception(ErrorCodes::INVALID_GRANT, "{} cannot be granted on the table level", access_flags.toString());
if (!anyDatabase())
throw Exception(ErrorCodes::INVALID_GRANT, "{} cannot be granted on the database level", access_flags.toString());
if (!anyParameter())
throw Exception(ErrorCodes::INVALID_GRANT, "{} cannot be granted on the global with parameter level", access_flags.toString());
throw Exception(ErrorCodes::INVALID_GRANT, "{} cannot be granted", access_flags.toString());
}
void AccessRightsElement::eraseNotGrantable()
{
access_flags = getGrantableFlags();
} }
void AccessRightsElement::replaceEmptyDatabase(const String & current_database) void AccessRightsElement::replaceEmptyDatabase(const String & current_database)
@ -251,11 +282,17 @@ bool AccessRightsElements::sameOptions() const
return (size() < 2) || std::all_of(std::next(begin()), end(), [this](const AccessRightsElement & e) { return e.sameOptions(front()); }); return (size() < 2) || std::all_of(std::next(begin()), end(), [this](const AccessRightsElement & e) { return e.sameOptions(front()); });
} }
void AccessRightsElements::eraseNonGrantable() void AccessRightsElements::throwIfNotGrantable() const
{
for (const auto & element : *this)
element.throwIfNotGrantable();
}
void AccessRightsElements::eraseNotGrantable()
{ {
std::erase_if(*this, [](AccessRightsElement & element) std::erase_if(*this, [](AccessRightsElement & element)
{ {
element.eraseNonGrantable(); element.eraseNotGrantable();
return element.empty(); return element.empty();
}); });
} }
@ -269,4 +306,45 @@ void AccessRightsElements::replaceEmptyDatabase(const String & current_database)
String AccessRightsElements::toString() const { return toStringImpl(*this, true); } String AccessRightsElements::toString() const { return toStringImpl(*this, true); }
String AccessRightsElements::toStringWithoutOptions() const { return toStringImpl(*this, false); } String AccessRightsElements::toStringWithoutOptions() const { return toStringImpl(*this, false); }
void AccessRightsElements::formatElementsWithoutOptions(WriteBuffer & buffer, bool hilite) const
{
bool no_output = true;
for (size_t i = 0; i != size(); ++i)
{
const auto & element = (*this)[i];
auto keywords = element.access_flags.toKeywords();
if (keywords.empty() || (!element.anyColumn() && element.columns.empty()))
continue;
for (const auto & keyword : keywords)
{
if (!std::exchange(no_output, false))
buffer << ", ";
buffer << (hilite ? IAST::hilite_keyword : "") << keyword << (hilite ? IAST::hilite_none : "");
if (!element.anyColumn())
element.formatColumnNames(buffer);
}
bool next_element_on_same_db_and_table = false;
if (i != size() - 1)
{
const auto & next_element = (*this)[i + 1];
if (element.sameDatabaseAndTableAndParameter(next_element))
{
next_element_on_same_db_and_table = true;
}
}
if (!next_element_on_same_db_and_table)
{
buffer << " ";
element.formatONClause(buffer, hilite);
}
}
if (no_output)
buffer << (hilite ? IAST::hilite_keyword : "") << "USAGE ON " << (hilite ? IAST::hilite_none : "") << "*.*";
}
} }

View File

@ -79,8 +79,14 @@ struct AccessRightsElement
return (grant_option == other.grant_option) && (is_partial_revoke == other.is_partial_revoke); return (grant_option == other.grant_option) && (is_partial_revoke == other.is_partial_revoke);
} }
/// Returns only those flags which can be granted.
AccessFlags getGrantableFlags() const;
/// Throws an exception if some flags can't be granted.
void throwIfNotGrantable() const;
/// Resets flags which cannot be granted. /// Resets flags which cannot be granted.
void eraseNonGrantable(); void eraseNotGrantable();
bool isEmptyDatabase() const { return database.empty() and !anyDatabase(); } bool isEmptyDatabase() const { return database.empty() and !anyDatabase(); }
@ -110,8 +116,11 @@ public:
bool sameDatabaseAndTable() const; bool sameDatabaseAndTable() const;
bool sameOptions() const; bool sameOptions() const;
/// Throws an exception if some flags can't be granted.
void throwIfNotGrantable() const;
/// Resets flags which cannot be granted. /// Resets flags which cannot be granted.
void eraseNonGrantable(); void eraseNotGrantable();
/// If the database is empty, replaces it with `current_database`. Otherwise does nothing. /// If the database is empty, replaces it with `current_database`. Otherwise does nothing.
void replaceEmptyDatabase(const String & current_database); void replaceEmptyDatabase(const String & current_database);
@ -119,6 +128,7 @@ public:
/// Returns a human-readable representation like "GRANT SELECT, UPDATE(x, y) ON db.table". /// Returns a human-readable representation like "GRANT SELECT, UPDATE(x, y) ON db.table".
String toString() const; String toString() const;
String toStringWithoutOptions() const; String toStringWithoutOptions() const;
void formatElementsWithoutOptions(WriteBuffer & buffer, bool hilite) const;
}; };
} }

View File

@ -4,6 +4,7 @@
#include <Core/Settings.h> #include <Core/Settings.h>
#include <Functions/grouping.h> #include <Functions/grouping.h>
#include <Functions/IFunctionAdaptors.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>

View File

@ -13,6 +13,7 @@
#include <DataTypes/getLeastSupertype.h> #include <DataTypes/getLeastSupertype.h>
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Functions/IFunctionAdaptors.h>
#include <Functions/UserDefined/UserDefinedExecutableFunctionFactory.h> #include <Functions/UserDefined/UserDefinedExecutableFunctionFactory.h>
#include <Functions/UserDefined/UserDefinedSQLFunctionFactory.h> #include <Functions/UserDefined/UserDefinedSQLFunctionFactory.h>
#include <Functions/grouping.h> #include <Functions/grouping.h>

View File

@ -8,6 +8,10 @@
#include <Common/assert_cast.h> #include <Common/assert_cast.h>
#include <Common/iota.h> #include <Common/iota.h>
#include <Core/DecimalFunctions.h>
#include <Core/TypeId.h>
#include <base/TypeName.h>
#include <base/sort.h> #include <base/sort.h>
#include <IO/WriteHelpers.h> #include <IO/WriteHelpers.h>
@ -30,6 +34,19 @@ namespace ErrorCodes
extern const int NOT_IMPLEMENTED; extern const int NOT_IMPLEMENTED;
} }
template <is_decimal T>
const char * ColumnDecimal<T>::getFamilyName() const
{
return TypeName<T>.data();
}
template <is_decimal T>
TypeIndex ColumnDecimal<T>::getDataType() const
{
return TypeToTypeIndex<T>;
}
template <is_decimal T> template <is_decimal T>
#if !defined(DEBUG_OR_SANITIZER_BUILD) #if !defined(DEBUG_OR_SANITIZER_BUILD)
int ColumnDecimal<T>::compareAt(size_t n, size_t m, const IColumn & rhs_, int) const int ColumnDecimal<T>::compareAt(size_t n, size_t m, const IColumn & rhs_, int) const
@ -46,6 +63,12 @@ int ColumnDecimal<T>::doCompareAt(size_t n, size_t m, const IColumn & rhs_, int)
return decimalLess<T>(b, a, other.scale, scale) ? 1 : (decimalLess<T>(a, b, scale, other.scale) ? -1 : 0); return decimalLess<T>(b, a, other.scale, scale) ? 1 : (decimalLess<T>(a, b, scale, other.scale) ? -1 : 0);
} }
template <is_decimal T>
Float64 ColumnDecimal<T>::getFloat64(size_t n) const
{
return DecimalUtils::convertTo<Float64>(data[n], scale);
}
template <is_decimal T> template <is_decimal T>
const char * ColumnDecimal<T>::deserializeAndInsertFromArena(const char * pos) const char * ColumnDecimal<T>::deserializeAndInsertFromArena(const char * pos)
{ {

View File

@ -1,14 +1,9 @@
#pragma once #pragma once
#include <base/sort.h>
#include <base/TypeName.h>
#include <Core/Field.h>
#include <Core/DecimalFunctions.h>
#include <Core/TypeId.h>
#include <Common/typeid_cast.h>
#include <Columns/ColumnFixedSizeHelper.h> #include <Columns/ColumnFixedSizeHelper.h>
#include <Columns/IColumn.h> #include <Columns/IColumn.h>
#include <Columns/IColumnImpl.h> #include <Columns/IColumnImpl.h>
#include <Core/Field.h>
namespace DB namespace DB
@ -39,8 +34,8 @@ private:
{} {}
public: public:
const char * getFamilyName() const override { return TypeName<T>.data(); } const char * getFamilyName() const override;
TypeIndex getDataType() const override { return TypeToTypeIndex<T>; } TypeIndex getDataType() const override;
bool isNumeric() const override { return false; } bool isNumeric() const override { return false; }
bool canBeInsideNullable() const override { return true; } bool canBeInsideNullable() const override { return true; }
@ -98,7 +93,7 @@ public:
return StringRef(reinterpret_cast<const char *>(&data[n]), sizeof(data[n])); return StringRef(reinterpret_cast<const char *>(&data[n]), sizeof(data[n]));
} }
Float64 getFloat64(size_t n) const final { return DecimalUtils::convertTo<Float64>(data[n], scale); } Float64 getFloat64(size_t n) const final;
const char * deserializeAndInsertFromArena(const char * pos) override; const char * deserializeAndInsertFromArena(const char * pos) override;
const char * skipSerializedInArena(const char * pos) const override; const char * skipSerializedInArena(const char * pos) const override;

View File

@ -347,7 +347,7 @@ ColumnWithTypeAndName ColumnFunction::reduce() const
if (is_function_compiled) if (is_function_compiled)
ProfileEvents::increment(ProfileEvents::CompiledFunctionExecute); ProfileEvents::increment(ProfileEvents::CompiledFunctionExecute);
res.column = function->execute(columns, res.type, elements_size); res.column = function->execute(columns, res.type, elements_size, /* dry_run = */ false);
if (res.column->getDataType() != res.type->getColumnType()) if (res.column->getDataType() != res.type->getColumnType())
throw Exception( throw Exception(
ErrorCodes::LOGICAL_ERROR, ErrorCodes::LOGICAL_ERROR,

View File

@ -32,6 +32,8 @@
# include <emmintrin.h> # include <emmintrin.h>
#endif #endif
#include "config.h"
#if USE_MULTITARGET_CODE #if USE_MULTITARGET_CODE
# include <immintrin.h> # include <immintrin.h>
#endif #endif
@ -658,7 +660,7 @@ inline void doFilterAligned(const UInt8 *& filt_pos, const UInt8 *& filt_end_ali
reinterpret_cast<void *>(&res_data[current_offset]), mask & KMASK); reinterpret_cast<void *>(&res_data[current_offset]), mask & KMASK);
current_offset += std::popcount(mask & KMASK); current_offset += std::popcount(mask & KMASK);
/// prepare mask for next iter, if ELEMENTS_PER_VEC = 64, no next iter /// prepare mask for next iter, if ELEMENTS_PER_VEC = 64, no next iter
if (ELEMENTS_PER_VEC < 64) if constexpr (ELEMENTS_PER_VEC < 64)
{ {
mask >>= ELEMENTS_PER_VEC; mask >>= ELEMENTS_PER_VEC;
} }
@ -992,6 +994,151 @@ ColumnPtr ColumnVector<T>::createWithOffsets(const IColumn::Offsets & offsets, c
return res; return res;
} }
DECLARE_DEFAULT_CODE(
template <typename Container, typename Type> void vectorIndexImpl(
const Container & data, const PaddedPODArray<Type> & indexes, size_t limit, Container & res_data)
{
for (size_t i = 0; i < limit; ++i)
res_data[i] = data[indexes[i]];
}
);
DECLARE_AVX512VBMI_SPECIFIC_CODE(
template <typename Container, typename Type>
void vectorIndexImpl(const Container & data, const PaddedPODArray<Type> & indexes, size_t limit, Container & res_data)
{
static constexpr UInt64 MASK64 = 0xffffffffffffffff;
const size_t limit64 = limit & ~63;
size_t pos = 0;
size_t data_size = data.size();
auto data_pos = reinterpret_cast<const UInt8 *>(data.data());
auto indexes_pos = reinterpret_cast<const UInt8 *>(indexes.data());
auto res_pos = reinterpret_cast<UInt8 *>(res_data.data());
if (limit == 0)
return; /// nothing to do, just return
if (data_size <= 64)
{
/// one single mask load for table size <= 64
__mmask64 last_mask = MASK64 >> (64 - data_size);
__m512i table1 = _mm512_maskz_loadu_epi8(last_mask, data_pos);
/// 64 bytes table lookup using one single permutexvar_epi8
while (pos < limit64)
{
__m512i vidx = _mm512_loadu_epi8(indexes_pos + pos);
__m512i out = _mm512_permutexvar_epi8(vidx, table1);
_mm512_storeu_epi8(res_pos + pos, out);
pos += 64;
}
/// tail handling
if (limit > limit64)
{
__mmask64 tail_mask = MASK64 >> (limit64 + 64 - limit);
__m512i vidx = _mm512_maskz_loadu_epi8(tail_mask, indexes_pos + pos);
__m512i out = _mm512_permutexvar_epi8(vidx, table1);
_mm512_mask_storeu_epi8(res_pos + pos, tail_mask, out);
}
}
else if (data_size <= 128)
{
/// table size (64, 128] requires 2 zmm load
__mmask64 last_mask = MASK64 >> (128 - data_size);
__m512i table1 = _mm512_loadu_epi8(data_pos);
__m512i table2 = _mm512_maskz_loadu_epi8(last_mask, data_pos + 64);
/// 128 bytes table lookup using one single permute2xvar_epi8
while (pos < limit64)
{
__m512i vidx = _mm512_loadu_epi8(indexes_pos + pos);
__m512i out = _mm512_permutex2var_epi8(table1, vidx, table2);
_mm512_storeu_epi8(res_pos + pos, out);
pos += 64;
}
if (limit > limit64)
{
__mmask64 tail_mask = MASK64 >> (limit64 + 64 - limit);
__m512i vidx = _mm512_maskz_loadu_epi8(tail_mask, indexes_pos + pos);
__m512i out = _mm512_permutex2var_epi8(table1, vidx, table2);
_mm512_mask_storeu_epi8(res_pos + pos, tail_mask, out);
}
}
else
{
if (data_size > 256)
{
/// byte index will not exceed 256 boundary.
data_size = 256;
}
__m512i table1 = _mm512_loadu_epi8(data_pos);
__m512i table2 = _mm512_loadu_epi8(data_pos + 64);
__m512i table3, table4;
if (data_size <= 192)
{
/// only 3 tables need to load if size <= 192
__mmask64 last_mask = MASK64 >> (192 - data_size);
table3 = _mm512_maskz_loadu_epi8(last_mask, data_pos + 128);
table4 = _mm512_setzero_si512();
}
else
{
__mmask64 last_mask = MASK64 >> (256 - data_size);
table3 = _mm512_loadu_epi8(data_pos + 128);
table4 = _mm512_maskz_loadu_epi8(last_mask, data_pos + 192);
}
/// 256 bytes table lookup can use: 2 permute2xvar_epi8 plus 1 blender with MSB
while (pos < limit64)
{
__m512i vidx = _mm512_loadu_epi8(indexes_pos + pos);
__m512i tmp1 = _mm512_permutex2var_epi8(table1, vidx, table2);
__m512i tmp2 = _mm512_permutex2var_epi8(table3, vidx, table4);
__mmask64 msb = _mm512_movepi8_mask(vidx);
__m512i out = _mm512_mask_blend_epi8(msb, tmp1, tmp2);
_mm512_storeu_epi8(res_pos + pos, out);
pos += 64;
}
if (limit > limit64)
{
__mmask64 tail_mask = MASK64 >> (limit64 + 64 - limit);
__m512i vidx = _mm512_maskz_loadu_epi8(tail_mask, indexes_pos + pos);
__m512i tmp1 = _mm512_permutex2var_epi8(table1, vidx, table2);
__m512i tmp2 = _mm512_permutex2var_epi8(table3, vidx, table4);
__mmask64 msb = _mm512_movepi8_mask(vidx);
__m512i out = _mm512_mask_blend_epi8(msb, tmp1, tmp2);
_mm512_mask_storeu_epi8(res_pos + pos, tail_mask, out);
}
}
}
);
template <typename T>
template <typename Type>
ColumnPtr ColumnVector<T>::indexImpl(const PaddedPODArray<Type> & indexes, size_t limit) const
{
chassert(limit <= indexes.size());
auto res = this->create(limit);
typename Self::Container & res_data = res->getData();
#if USE_MULTITARGET_CODE
if constexpr (sizeof(T) == 1 && sizeof(Type) == 1)
{
/// VBMI optimization only applicable for (U)Int8 types
if (isArchSupported(TargetArch::AVX512VBMI))
{
TargetSpecific::AVX512VBMI::vectorIndexImpl<Container, Type>(data, indexes, limit, res_data);
return res;
}
}
#endif
TargetSpecific::Default::vectorIndexImpl<Container, Type>(data, indexes, limit, res_data);
return res;
}
/// Explicit template instantiations - to avoid code bloat in headers. /// Explicit template instantiations - to avoid code bloat in headers.
template class ColumnVector<UInt8>; template class ColumnVector<UInt8>;
template class ColumnVector<UInt16>; template class ColumnVector<UInt16>;
@ -1012,4 +1159,17 @@ template class ColumnVector<UUID>;
template class ColumnVector<IPv4>; template class ColumnVector<IPv4>;
template class ColumnVector<IPv6>; template class ColumnVector<IPv6>;
INSTANTIATE_INDEX_TEMPLATE_IMPL(ColumnVector)
/// Used by ColumnVariant.cpp
template ColumnPtr ColumnVector<UInt8>::indexImpl<UInt16>(const PaddedPODArray<UInt16> & indexes, size_t limit) const;
template ColumnPtr ColumnVector<UInt8>::indexImpl<UInt32>(const PaddedPODArray<UInt32> & indexes, size_t limit) const;
template ColumnPtr ColumnVector<UInt8>::indexImpl<UInt64>(const PaddedPODArray<UInt64> & indexes, size_t limit) const;
template ColumnPtr ColumnVector<UInt64>::indexImpl<UInt8>(const PaddedPODArray<UInt8> & indexes, size_t limit) const;
template ColumnPtr ColumnVector<UInt64>::indexImpl<UInt16>(const PaddedPODArray<UInt16> & indexes, size_t limit) const;
template ColumnPtr ColumnVector<UInt64>::indexImpl<UInt32>(const PaddedPODArray<UInt32> & indexes, size_t limit) const;
#if defined(OS_DARWIN)
template ColumnPtr ColumnVector<UInt8>::indexImpl<size_t>(const PaddedPODArray<size_t> & indexes, size_t limit) const;
template ColumnPtr ColumnVector<UInt64>::indexImpl<size_t>(const PaddedPODArray<size_t> & indexes, size_t limit) const;
#endif
} }

View File

@ -13,10 +13,6 @@
#include "config.h" #include "config.h"
#if USE_MULTITARGET_CODE
# include <immintrin.h>
#endif
namespace DB namespace DB
{ {
@ -320,151 +316,6 @@ protected:
Container data; Container data;
}; };
DECLARE_DEFAULT_CODE(
template <typename Container, typename Type>
inline void vectorIndexImpl(const Container & data, const PaddedPODArray<Type> & indexes, size_t limit, Container & res_data)
{
for (size_t i = 0; i < limit; ++i)
res_data[i] = data[indexes[i]];
}
);
DECLARE_AVX512VBMI_SPECIFIC_CODE(
template <typename Container, typename Type>
inline void vectorIndexImpl(const Container & data, const PaddedPODArray<Type> & indexes, size_t limit, Container & res_data)
{
static constexpr UInt64 MASK64 = 0xffffffffffffffff;
const size_t limit64 = limit & ~63;
size_t pos = 0;
size_t data_size = data.size();
auto data_pos = reinterpret_cast<const UInt8 *>(data.data());
auto indexes_pos = reinterpret_cast<const UInt8 *>(indexes.data());
auto res_pos = reinterpret_cast<UInt8 *>(res_data.data());
if (limit == 0)
return; /// nothing to do, just return
if (data_size <= 64)
{
/// one single mask load for table size <= 64
__mmask64 last_mask = MASK64 >> (64 - data_size);
__m512i table1 = _mm512_maskz_loadu_epi8(last_mask, data_pos);
/// 64 bytes table lookup using one single permutexvar_epi8
while (pos < limit64)
{
__m512i vidx = _mm512_loadu_epi8(indexes_pos + pos);
__m512i out = _mm512_permutexvar_epi8(vidx, table1);
_mm512_storeu_epi8(res_pos + pos, out);
pos += 64;
}
/// tail handling
if (limit > limit64)
{
__mmask64 tail_mask = MASK64 >> (limit64 + 64 - limit);
__m512i vidx = _mm512_maskz_loadu_epi8(tail_mask, indexes_pos + pos);
__m512i out = _mm512_permutexvar_epi8(vidx, table1);
_mm512_mask_storeu_epi8(res_pos + pos, tail_mask, out);
}
}
else if (data_size <= 128)
{
/// table size (64, 128] requires 2 zmm load
__mmask64 last_mask = MASK64 >> (128 - data_size);
__m512i table1 = _mm512_loadu_epi8(data_pos);
__m512i table2 = _mm512_maskz_loadu_epi8(last_mask, data_pos + 64);
/// 128 bytes table lookup using one single permute2xvar_epi8
while (pos < limit64)
{
__m512i vidx = _mm512_loadu_epi8(indexes_pos + pos);
__m512i out = _mm512_permutex2var_epi8(table1, vidx, table2);
_mm512_storeu_epi8(res_pos + pos, out);
pos += 64;
}
if (limit > limit64)
{
__mmask64 tail_mask = MASK64 >> (limit64 + 64 - limit);
__m512i vidx = _mm512_maskz_loadu_epi8(tail_mask, indexes_pos + pos);
__m512i out = _mm512_permutex2var_epi8(table1, vidx, table2);
_mm512_mask_storeu_epi8(res_pos + pos, tail_mask, out);
}
}
else
{
if (data_size > 256)
{
/// byte index will not exceed 256 boundary.
data_size = 256;
}
__m512i table1 = _mm512_loadu_epi8(data_pos);
__m512i table2 = _mm512_loadu_epi8(data_pos + 64);
__m512i table3, table4;
if (data_size <= 192)
{
/// only 3 tables need to load if size <= 192
__mmask64 last_mask = MASK64 >> (192 - data_size);
table3 = _mm512_maskz_loadu_epi8(last_mask, data_pos + 128);
table4 = _mm512_setzero_si512();
}
else
{
__mmask64 last_mask = MASK64 >> (256 - data_size);
table3 = _mm512_loadu_epi8(data_pos + 128);
table4 = _mm512_maskz_loadu_epi8(last_mask, data_pos + 192);
}
/// 256 bytes table lookup can use: 2 permute2xvar_epi8 plus 1 blender with MSB
while (pos < limit64)
{
__m512i vidx = _mm512_loadu_epi8(indexes_pos + pos);
__m512i tmp1 = _mm512_permutex2var_epi8(table1, vidx, table2);
__m512i tmp2 = _mm512_permutex2var_epi8(table3, vidx, table4);
__mmask64 msb = _mm512_movepi8_mask(vidx);
__m512i out = _mm512_mask_blend_epi8(msb, tmp1, tmp2);
_mm512_storeu_epi8(res_pos + pos, out);
pos += 64;
}
if (limit > limit64)
{
__mmask64 tail_mask = MASK64 >> (limit64 + 64 - limit);
__m512i vidx = _mm512_maskz_loadu_epi8(tail_mask, indexes_pos + pos);
__m512i tmp1 = _mm512_permutex2var_epi8(table1, vidx, table2);
__m512i tmp2 = _mm512_permutex2var_epi8(table3, vidx, table4);
__mmask64 msb = _mm512_movepi8_mask(vidx);
__m512i out = _mm512_mask_blend_epi8(msb, tmp1, tmp2);
_mm512_mask_storeu_epi8(res_pos + pos, tail_mask, out);
}
}
}
);
template <typename T>
template <typename Type>
ColumnPtr ColumnVector<T>::indexImpl(const PaddedPODArray<Type> & indexes, size_t limit) const
{
assert(limit <= indexes.size());
auto res = this->create(limit);
typename Self::Container & res_data = res->getData();
#if USE_MULTITARGET_CODE
if constexpr (sizeof(T) == 1 && sizeof(Type) == 1)
{
/// VBMI optimization only applicable for (U)Int8 types
if (isArchSupported(TargetArch::AVX512VBMI))
{
TargetSpecific::AVX512VBMI::vectorIndexImpl<Container, Type>(data, indexes, limit, res_data);
return res;
}
}
#endif
TargetSpecific::Default::vectorIndexImpl<Container, Type>(data, indexes, limit, res_data);
return res;
}
template <class TCol> template <class TCol>
concept is_col_vector = std::is_same_v<TCol, ColumnVector<typename TCol::ValueType>>; concept is_col_vector = std::is_same_v<TCol, ColumnVector<typename TCol::ValueType>>;

View File

@ -142,4 +142,10 @@ ColumnPtr permuteImpl(const Column & column, const IColumn::Permutation & perm,
template ColumnPtr Column::indexImpl<UInt16>(const PaddedPODArray<UInt16> & indexes, size_t limit) const; \ template ColumnPtr Column::indexImpl<UInt16>(const PaddedPODArray<UInt16> & indexes, size_t limit) const; \
template ColumnPtr Column::indexImpl<UInt32>(const PaddedPODArray<UInt32> & indexes, size_t limit) const; \ template ColumnPtr Column::indexImpl<UInt32>(const PaddedPODArray<UInt32> & indexes, size_t limit) const; \
template ColumnPtr Column::indexImpl<UInt64>(const PaddedPODArray<UInt64> & indexes, size_t limit) const; template ColumnPtr Column::indexImpl<UInt64>(const PaddedPODArray<UInt64> & indexes, size_t limit) const;
#define INSTANTIATE_INDEX_TEMPLATE_IMPL(ColumnTemplate) \
template ColumnPtr ColumnTemplate<UInt8>::indexImpl<UInt8>(const PaddedPODArray<UInt8> & indexes, size_t limit) const; \
template ColumnPtr ColumnTemplate<UInt16>::indexImpl<UInt16>(const PaddedPODArray<UInt16> & indexes, size_t limit) const; \
template ColumnPtr ColumnTemplate<UInt32>::indexImpl<UInt32>(const PaddedPODArray<UInt32> & indexes, size_t limit) const; \
template ColumnPtr ColumnTemplate<UInt64>::indexImpl<UInt64>(const PaddedPODArray<UInt64> & indexes, size_t limit) const;
} }

View File

@ -2,6 +2,7 @@
#include <base/arithmeticOverflow.h> #include <base/arithmeticOverflow.h>
#include <Core/Block.h> #include <Core/Block.h>
#include <Core/DecimalFunctions.h>
#include <Core/AccurateComparison.h> #include <Core/AccurateComparison.h>
#include <Core/callOnTypeIndex.h> #include <Core/callOnTypeIndex.h>
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
@ -52,8 +53,8 @@ struct DecCompareInt
using TypeB = Type; using TypeB = Type;
}; };
template <typename A, typename B, template <typename, typename> typename Operation, bool _check_overflow = true, template <typename A, typename B, template <typename, typename> typename Operation>
bool _actual = is_decimal<A> || is_decimal<B>> requires is_decimal<A> || is_decimal<B>
class DecimalComparison class DecimalComparison
{ {
public: public:
@ -65,20 +66,17 @@ public:
using ArrayA = typename ColVecA::Container; using ArrayA = typename ColVecA::Container;
using ArrayB = typename ColVecB::Container; using ArrayB = typename ColVecB::Container;
static ColumnPtr apply(const ColumnWithTypeAndName & col_left, const ColumnWithTypeAndName & col_right) static ColumnPtr apply(const ColumnWithTypeAndName & col_left, const ColumnWithTypeAndName & col_right, bool check_overflow)
{ {
if constexpr (_actual) ColumnPtr c_res;
{ Shift shift = getScales<A, B>(col_left.type, col_right.type);
ColumnPtr c_res;
Shift shift = getScales<A, B>(col_left.type, col_right.type);
return applyWithScale(col_left.column, col_right.column, shift); if (check_overflow)
} return applyWithScale<true>(col_left.column, col_right.column, shift);
else return applyWithScale<false>(col_left.column, col_right.column, shift);
return nullptr;
} }
static bool compare(A a, B b, UInt32 scale_a, UInt32 scale_b) static bool compare(A a, B b, UInt32 scale_a, UInt32 scale_b, bool check_overflow)
{ {
static const UInt32 max_scale = DecimalUtils::max_precision<Decimal256>; static const UInt32 max_scale = DecimalUtils::max_precision<Decimal256>;
if (scale_a > max_scale || scale_b > max_scale) if (scale_a > max_scale || scale_b > max_scale)
@ -90,7 +88,9 @@ public:
if (scale_a > scale_b) if (scale_a > scale_b)
shift.b = static_cast<CompareInt>(DecimalUtils::scaleMultiplier<A>(scale_a - scale_b)); shift.b = static_cast<CompareInt>(DecimalUtils::scaleMultiplier<A>(scale_a - scale_b));
return applyWithScale(a, b, shift); if (check_overflow)
return applyWithScale<true>(a, b, shift);
return applyWithScale<false>(a, b, shift);
} }
private: private:
@ -104,14 +104,14 @@ private:
bool right() const { return b != 1; } bool right() const { return b != 1; }
}; };
template <typename T, typename U> template <bool check_overflow, typename T, typename U>
static auto applyWithScale(T a, U b, const Shift & shift) static auto applyWithScale(T a, U b, const Shift & shift)
{ {
if (shift.left()) if (shift.left())
return apply<true, false>(a, b, shift.a); return apply<check_overflow, true, false>(a, b, shift.a);
if (shift.right()) if (shift.right())
return apply<false, true>(a, b, shift.b); return apply<check_overflow, false, true>(a, b, shift.b);
return apply<false, false>(a, b, 1); return apply<check_overflow, false, false>(a, b, 1);
} }
template <typename T, typename U> template <typename T, typename U>
@ -125,8 +125,8 @@ private:
if (decimal0 && decimal1) if (decimal0 && decimal1)
{ {
auto result_type = DecimalUtils::binaryOpResult<false, false>(*decimal0, *decimal1); auto result_type = DecimalUtils::binaryOpResult<false, false>(*decimal0, *decimal1);
shift.a = static_cast<CompareInt>(result_type.scaleFactorFor(decimal0->getTrait(), false).value); shift.a = static_cast<CompareInt>(result_type.scaleFactorFor(DecimalUtils::DataTypeDecimalTrait<T>{decimal0->getPrecision(), decimal0->getScale()}, false).value);
shift.b = static_cast<CompareInt>(result_type.scaleFactorFor(decimal1->getTrait(), false).value); shift.b = static_cast<CompareInt>(result_type.scaleFactorFor(DecimalUtils::DataTypeDecimalTrait<U>{decimal1->getPrecision(), decimal1->getScale()}, false).value);
} }
else if (decimal0) else if (decimal0)
shift.b = static_cast<CompareInt>(decimal0->getScaleMultiplier().value); shift.b = static_cast<CompareInt>(decimal0->getScaleMultiplier().value);
@ -158,66 +158,63 @@ private:
return shift; return shift;
} }
template <bool scale_left, bool scale_right> template <bool check_overflow, bool scale_left, bool scale_right>
static ColumnPtr apply(const ColumnPtr & c0, const ColumnPtr & c1, CompareInt scale) static ColumnPtr apply(const ColumnPtr & c0, const ColumnPtr & c1, CompareInt scale)
{ {
auto c_res = ColumnUInt8::create(); auto c_res = ColumnUInt8::create();
if constexpr (_actual) bool c0_is_const = isColumnConst(*c0);
bool c1_is_const = isColumnConst(*c1);
if (c0_is_const && c1_is_const)
{ {
bool c0_is_const = isColumnConst(*c0); const ColumnConst & c0_const = checkAndGetColumnConst<ColVecA>(*c0);
bool c1_is_const = isColumnConst(*c1); const ColumnConst & c1_const = checkAndGetColumnConst<ColVecB>(*c1);
if (c0_is_const && c1_is_const) A a = c0_const.template getValue<A>();
B b = c1_const.template getValue<B>();
UInt8 res = apply<check_overflow, scale_left, scale_right>(a, b, scale);
return DataTypeUInt8().createColumnConst(c0->size(), toField(res));
}
ColumnUInt8::Container & vec_res = c_res->getData();
vec_res.resize(c0->size());
if (c0_is_const)
{
const ColumnConst & c0_const = checkAndGetColumnConst<ColVecA>(*c0);
A a = c0_const.template getValue<A>();
if (const ColVecB * c1_vec = checkAndGetColumn<ColVecB>(c1.get()))
constantVector<check_overflow, scale_left, scale_right>(a, c1_vec->getData(), vec_res, scale);
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong column in Decimal comparison");
}
else if (c1_is_const)
{
const ColumnConst & c1_const = checkAndGetColumnConst<ColVecB>(*c1);
B b = c1_const.template getValue<B>();
if (const ColVecA * c0_vec = checkAndGetColumn<ColVecA>(c0.get()))
vectorConstant<check_overflow, scale_left, scale_right>(c0_vec->getData(), b, vec_res, scale);
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong column in Decimal comparison");
}
else
{
if (const ColVecA * c0_vec = checkAndGetColumn<ColVecA>(c0.get()))
{ {
const ColumnConst & c0_const = checkAndGetColumnConst<ColVecA>(*c0);
const ColumnConst & c1_const = checkAndGetColumnConst<ColVecB>(*c1);
A a = c0_const.template getValue<A>();
B b = c1_const.template getValue<B>();
UInt8 res = apply<scale_left, scale_right>(a, b, scale);
return DataTypeUInt8().createColumnConst(c0->size(), toField(res));
}
ColumnUInt8::Container & vec_res = c_res->getData();
vec_res.resize(c0->size());
if (c0_is_const)
{
const ColumnConst & c0_const = checkAndGetColumnConst<ColVecA>(*c0);
A a = c0_const.template getValue<A>();
if (const ColVecB * c1_vec = checkAndGetColumn<ColVecB>(c1.get())) if (const ColVecB * c1_vec = checkAndGetColumn<ColVecB>(c1.get()))
constantVector<scale_left, scale_right>(a, c1_vec->getData(), vec_res, scale); vectorVector<check_overflow, scale_left, scale_right>(c0_vec->getData(), c1_vec->getData(), vec_res, scale);
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong column in Decimal comparison");
}
else if (c1_is_const)
{
const ColumnConst & c1_const = checkAndGetColumnConst<ColVecB>(*c1);
B b = c1_const.template getValue<B>();
if (const ColVecA * c0_vec = checkAndGetColumn<ColVecA>(c0.get()))
vectorConstant<scale_left, scale_right>(c0_vec->getData(), b, vec_res, scale);
else else
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong column in Decimal comparison"); throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong column in Decimal comparison");
} }
else else
{ throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong column in Decimal comparison");
if (const ColVecA * c0_vec = checkAndGetColumn<ColVecA>(c0.get()))
{
if (const ColVecB * c1_vec = checkAndGetColumn<ColVecB>(c1.get()))
vectorVector<scale_left, scale_right>(c0_vec->getData(), c1_vec->getData(), vec_res, scale);
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong column in Decimal comparison");
}
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong column in Decimal comparison");
}
} }
return c_res; return c_res;
} }
template <bool scale_left, bool scale_right> template <bool check_overflow, bool scale_left, bool scale_right>
static NO_INLINE UInt8 apply(A a, B b, CompareInt scale [[maybe_unused]]) static NO_INLINE UInt8 apply(A a, B b, CompareInt scale [[maybe_unused]])
{ {
CompareInt x; CompareInt x;
@ -232,7 +229,7 @@ private:
else else
y = static_cast<CompareInt>(b); y = static_cast<CompareInt>(b);
if constexpr (_check_overflow) if constexpr (check_overflow)
{ {
bool overflow = false; bool overflow = false;
@ -264,9 +261,8 @@ private:
return Op::apply(x, y); return Op::apply(x, y);
} }
template <bool scale_left, bool scale_right> template <bool check_overflow, bool scale_left, bool scale_right>
static void NO_INLINE vectorVector(const ArrayA & a, const ArrayB & b, PaddedPODArray<UInt8> & c, static void NO_INLINE vectorVector(const ArrayA & a, const ArrayB & b, PaddedPODArray<UInt8> & c, CompareInt scale)
CompareInt scale)
{ {
size_t size = a.size(); size_t size = a.size();
const A * a_pos = a.data(); const A * a_pos = a.data();
@ -276,14 +272,14 @@ private:
while (a_pos < a_end) while (a_pos < a_end)
{ {
*c_pos = apply<scale_left, scale_right>(*a_pos, *b_pos, scale); *c_pos = apply<check_overflow, scale_left, scale_right>(*a_pos, *b_pos, scale);
++a_pos; ++a_pos;
++b_pos; ++b_pos;
++c_pos; ++c_pos;
} }
} }
template <bool scale_left, bool scale_right> template <bool check_overflow, bool scale_left, bool scale_right>
static void NO_INLINE vectorConstant(const ArrayA & a, B b, PaddedPODArray<UInt8> & c, CompareInt scale) static void NO_INLINE vectorConstant(const ArrayA & a, B b, PaddedPODArray<UInt8> & c, CompareInt scale)
{ {
size_t size = a.size(); size_t size = a.size();
@ -293,13 +289,13 @@ private:
while (a_pos < a_end) while (a_pos < a_end)
{ {
*c_pos = apply<scale_left, scale_right>(*a_pos, b, scale); *c_pos = apply<check_overflow, scale_left, scale_right>(*a_pos, b, scale);
++a_pos; ++a_pos;
++c_pos; ++c_pos;
} }
} }
template <bool scale_left, bool scale_right> template <bool check_overflow, bool scale_left, bool scale_right>
static void NO_INLINE constantVector(A a, const ArrayB & b, PaddedPODArray<UInt8> & c, CompareInt scale) static void NO_INLINE constantVector(A a, const ArrayB & b, PaddedPODArray<UInt8> & c, CompareInt scale)
{ {
size_t size = b.size(); size_t size = b.size();
@ -309,7 +305,7 @@ private:
while (b_pos < b_end) while (b_pos < b_end)
{ {
*c_pos = apply<scale_left, scale_right>(a, *b_pos, scale); *c_pos = apply<check_overflow, scale_left, scale_right>(a, *b_pos, scale);
++b_pos; ++b_pos;
++c_pos; ++c_pos;
} }

View File

@ -529,22 +529,25 @@ Field Field::restoreFromDump(std::string_view dump_)
template <typename T> template <typename T>
bool decimalEqual(T x, T y, UInt32 x_scale, UInt32 y_scale) bool decimalEqual(T x, T y, UInt32 x_scale, UInt32 y_scale)
{ {
bool check_overflow = true;
using Comparator = DecimalComparison<T, T, EqualsOp>; using Comparator = DecimalComparison<T, T, EqualsOp>;
return Comparator::compare(x, y, x_scale, y_scale); return Comparator::compare(x, y, x_scale, y_scale, check_overflow);
} }
template <typename T> template <typename T>
bool decimalLess(T x, T y, UInt32 x_scale, UInt32 y_scale) bool decimalLess(T x, T y, UInt32 x_scale, UInt32 y_scale)
{ {
bool check_overflow = true;
using Comparator = DecimalComparison<T, T, LessOp>; using Comparator = DecimalComparison<T, T, LessOp>;
return Comparator::compare(x, y, x_scale, y_scale); return Comparator::compare(x, y, x_scale, y_scale, check_overflow);
} }
template <typename T> template <typename T>
bool decimalLessOrEqual(T x, T y, UInt32 x_scale, UInt32 y_scale) bool decimalLessOrEqual(T x, T y, UInt32 x_scale, UInt32 y_scale)
{ {
bool check_overflow = true;
using Comparator = DecimalComparison<T, T, LessOrEqualsOp>; using Comparator = DecimalComparison<T, T, LessOrEqualsOp>;
return Comparator::compare(x, y, x_scale, y_scale); return Comparator::compare(x, y, x_scale, y_scale, check_overflow);
} }

View File

@ -863,6 +863,9 @@ template <> struct Field::EnumToType<Field::Types::AggregateFunctionState> { usi
template <> struct Field::EnumToType<Field::Types::CustomType> { using Type = CustomType; }; template <> struct Field::EnumToType<Field::Types::CustomType> { using Type = CustomType; };
template <> struct Field::EnumToType<Field::Types::Bool> { using Type = UInt64; }; template <> struct Field::EnumToType<Field::Types::Bool> { using Type = UInt64; };
/// Use it to prevent inclusion of magic_enum in headers, which is very expensive for the compiler
std::string_view fieldTypeToString(Field::Types::Which type);
constexpr bool isInt64OrUInt64FieldType(Field::Types::Which t) constexpr bool isInt64OrUInt64FieldType(Field::Types::Which t)
{ {
return t == Field::Types::Int64 return t == Field::Types::Int64
@ -886,7 +889,7 @@ auto & Field::safeGet()
if (target != which && if (target != which &&
!(which == Field::Types::Bool && (target == Field::Types::UInt64 || target == Field::Types::Int64)) && !(which == Field::Types::Bool && (target == Field::Types::UInt64 || target == Field::Types::Int64)) &&
!(isInt64OrUInt64FieldType(which) && isInt64OrUInt64FieldType(target))) !(isInt64OrUInt64FieldType(which) && isInt64OrUInt64FieldType(target)))
throw Exception(ErrorCodes::BAD_GET, "Bad get: has {}, requested {}", getTypeName(), target); throw Exception(ErrorCodes::BAD_GET, "Bad get: has {}, requested {}", getTypeName(), fieldTypeToString(target));
return get<T>(); return get<T>();
} }
@ -1002,8 +1005,6 @@ void readQuoted(DecimalField<T> & x, ReadBuffer & buf);
void writeFieldText(const Field & x, WriteBuffer & buf); void writeFieldText(const Field & x, WriteBuffer & buf);
String toString(const Field & x); String toString(const Field & x);
std::string_view fieldTypeToString(Field::Types::Which type);
} }
template <> template <>

View File

@ -87,6 +87,77 @@ static bool callOnBasicType(TypeIndex number, F && f)
return false; return false;
} }
template <typename T, bool _int, bool _float, bool _decimal, bool _datetime, typename F>
static bool callOnBasicTypeSecondArg(TypeIndex number, F && f)
{
if constexpr (_int)
{
switch (number)
{
case TypeIndex::UInt8: return f(TypePair<UInt8, T>());
case TypeIndex::UInt16: return f(TypePair<UInt16, T>());
case TypeIndex::UInt32: return f(TypePair<UInt32, T>());
case TypeIndex::UInt64: return f(TypePair<UInt64, T>());
case TypeIndex::UInt128: return f(TypePair<UInt128, T>());
case TypeIndex::UInt256: return f(TypePair<UInt256, T>());
case TypeIndex::Int8: return f(TypePair<Int8, T>());
case TypeIndex::Int16: return f(TypePair<Int16, T>());
case TypeIndex::Int32: return f(TypePair<Int32, T>());
case TypeIndex::Int64: return f(TypePair<Int64, T>());
case TypeIndex::Int128: return f(TypePair<Int128, T>());
case TypeIndex::Int256: return f(TypePair<Int256, T>());
case TypeIndex::Enum8: return f(TypePair<Int8, T>());
case TypeIndex::Enum16: return f(TypePair<Int16, T>());
default:
break;
}
}
if constexpr (_decimal)
{
switch (number)
{
case TypeIndex::Decimal32: return f(TypePair<Decimal32, T>());
case TypeIndex::Decimal64: return f(TypePair<Decimal64, T>());
case TypeIndex::Decimal128: return f(TypePair<Decimal128, T>());
case TypeIndex::Decimal256: return f(TypePair<Decimal256, T>());
default:
break;
}
}
if constexpr (_float)
{
switch (number)
{
case TypeIndex::BFloat16: return f(TypePair<BFloat16, T>());
case TypeIndex::Float32: return f(TypePair<Float32, T>());
case TypeIndex::Float64: return f(TypePair<Float64, T>());
default:
break;
}
}
if constexpr (_datetime)
{
switch (number)
{
case TypeIndex::Date: return f(TypePair<UInt16, T>());
case TypeIndex::Date32: return f(TypePair<Int32, T>());
case TypeIndex::DateTime: return f(TypePair<UInt32, T>());
case TypeIndex::DateTime64: return f(TypePair<DateTime64, T>());
default:
break;
}
}
return false;
}
/// Unroll template using TypeIndex /// Unroll template using TypeIndex
template <bool _int, bool _float, bool _decimal, bool _datetime, typename F> template <bool _int, bool _float, bool _decimal, bool _datetime, typename F>
static inline bool callOnBasicTypes(TypeIndex type_num1, TypeIndex type_num2, F && f) static inline bool callOnBasicTypes(TypeIndex type_num1, TypeIndex type_num2, F && f)

View File

@ -1,7 +1,8 @@
#include <type_traits>
#include <Core/DecimalFunctions.h>
#include <Core/Settings.h> #include <Core/Settings.h>
#include <DataTypes/DataTypeDecimalBase.h> #include <DataTypes/DataTypeDecimalBase.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <type_traits>
namespace DB namespace DB
{ {
@ -14,6 +15,12 @@ namespace ErrorCodes
{ {
} }
template <is_decimal T>
constexpr size_t DataTypeDecimalBase<T>::maxPrecision()
{
return DecimalUtils::max_precision<T>;
}
bool decimalCheckComparisonOverflow(ContextPtr context) bool decimalCheckComparisonOverflow(ContextPtr context)
{ {
return context->getSettingsRef()[Setting::decimal_check_overflow]; return context->getSettingsRef()[Setting::decimal_check_overflow];
@ -41,6 +48,18 @@ T DataTypeDecimalBase<T>::getScaleMultiplier(UInt32 scale_)
return DecimalUtils::scaleMultiplier<typename T::NativeType>(scale_); return DecimalUtils::scaleMultiplier<typename T::NativeType>(scale_);
} }
template <is_decimal T>
T DataTypeDecimalBase<T>::wholePart(T x) const
{
return DecimalUtils::getWholePart(x, scale);
}
template <is_decimal T>
T DataTypeDecimalBase<T>::fractionalPart(T x) const
{
return DecimalUtils::getFractionalPart(x, scale);
}
/// Explicit template instantiations. /// Explicit template instantiations.
template class DataTypeDecimalBase<Decimal32>; template class DataTypeDecimalBase<Decimal32>;

View File

@ -3,11 +3,10 @@
#include <cmath> #include <cmath>
#include <type_traits> #include <type_traits>
#include <Core/TypeId.h>
#include <Core/DecimalFunctions.h>
#include <Columns/ColumnDecimal.h> #include <Columns/ColumnDecimal.h>
#include <DataTypes/IDataType.h> #include <Core/TypeId.h>
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
#include <DataTypes/IDataType.h>
#include <Interpreters/Context_fwd.h> #include <Interpreters/Context_fwd.h>
@ -64,7 +63,7 @@ public:
static constexpr bool is_parametric = true; static constexpr bool is_parametric = true;
static constexpr size_t maxPrecision() { return DecimalUtils::max_precision<T>; } static constexpr size_t maxPrecision();
DataTypeDecimalBase(UInt32 precision_, UInt32 scale_) DataTypeDecimalBase(UInt32 precision_, UInt32 scale_)
: precision(precision_), : precision(precision_),
@ -104,15 +103,8 @@ public:
UInt32 getScale() const { return scale; } UInt32 getScale() const { return scale; }
T getScaleMultiplier() const { return getScaleMultiplier(scale); } T getScaleMultiplier() const { return getScaleMultiplier(scale); }
T wholePart(T x) const T wholePart(T x) const;
{ T fractionalPart(T x) const;
return DecimalUtils::getWholePart(x, scale);
}
T fractionalPart(T x) const
{
return DecimalUtils::getFractionalPart(x, scale);
}
T maxWholeValue() const { return getScaleMultiplier(precision - scale) - T(1); } T maxWholeValue() const { return getScaleMultiplier(precision - scale) - T(1); }
@ -147,11 +139,6 @@ public:
static T getScaleMultiplier(UInt32 scale); static T getScaleMultiplier(UInt32 scale);
DecimalUtils::DataTypeDecimalTrait<T> getTrait() const
{
return {precision, scale};
}
protected: protected:
const UInt32 precision; const UInt32 precision;
const UInt32 scale; const UInt32 scale;
@ -167,50 +154,35 @@ inline const DataTypeDecimalBase<T> * checkDecimalBase(const IDataType & data_ty
return nullptr; return nullptr;
} }
template <bool is_multiply, bool is_division, typename T, typename U, template <typename> typename DecimalType> template <> constexpr size_t DataTypeDecimalBase<Decimal32>::maxPrecision() { return 9; };
inline auto decimalResultType(const DecimalType<T> & tx, const DecimalType<U> & ty) template <> constexpr size_t DataTypeDecimalBase<Decimal64>::maxPrecision() { return 18; };
{ template <> constexpr size_t DataTypeDecimalBase<DateTime64>::maxPrecision() { return 18; };
const auto result_trait = DecimalUtils::binaryOpResult<is_multiply, is_division>(tx, ty); template <> constexpr size_t DataTypeDecimalBase<Decimal128>::maxPrecision() { return 38; };
return DecimalType<typename decltype(result_trait)::FieldType>(result_trait.precision, result_trait.scale); template <> constexpr size_t DataTypeDecimalBase<Decimal256>::maxPrecision() { return 76; };
}
template <bool is_multiply, bool is_division, typename T, typename U, template <typename> typename DecimalType> extern template class DataTypeDecimalBase<Decimal32>;
inline DecimalType<T> decimalResultType(const DecimalType<T> & tx, const DataTypeNumber<U> & ty) extern template class DataTypeDecimalBase<Decimal64>;
{ extern template class DataTypeDecimalBase<DateTime64>;
const auto result_trait = DecimalUtils::binaryOpResult<is_multiply, is_division>(tx, ty); extern template class DataTypeDecimalBase<Decimal128>;
return DecimalType<typename decltype(result_trait)::FieldType>(result_trait.precision, result_trait.scale); extern template class DataTypeDecimalBase<Decimal256>;
}
template <bool is_multiply, bool is_division, typename T, typename U, template <typename> typename DecimalType>
inline DecimalType<U> decimalResultType(const DataTypeNumber<T> & tx, const DecimalType<U> & ty)
{
const auto result_trait = DecimalUtils::binaryOpResult<is_multiply, is_division>(tx, ty);
return DecimalType<typename decltype(result_trait)::FieldType>(result_trait.precision, result_trait.scale);
}
template <template <typename> typename DecimalType> template <template <typename> typename DecimalType>
inline DataTypePtr createDecimal(UInt64 precision_value, UInt64 scale_value) inline DataTypePtr createDecimal(UInt64 precision_value, UInt64 scale_value)
{ {
if (precision_value < DecimalUtils::min_precision || precision_value > DecimalUtils::max_precision<Decimal256>) if (precision_value < 1 || precision_value > DataTypeDecimalBase<Decimal256>::maxPrecision())
throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Wrong precision: it must be between {} and {}, got {}", throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Wrong precision: it must be between {} and {}, got {}",
DecimalUtils::min_precision, DecimalUtils::max_precision<Decimal256>, precision_value); 1, DataTypeDecimalBase<Decimal256>::maxPrecision(), precision_value);
if (scale_value > precision_value) if (scale_value > precision_value)
throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Negative scales and scales larger than precision are not supported"); throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Negative scales and scales larger than precision are not supported");
if (precision_value <= DecimalUtils::max_precision<Decimal32>) if (precision_value <= DataTypeDecimalBase<Decimal32>::maxPrecision())
return std::make_shared<DecimalType<Decimal32>>(precision_value, scale_value); return std::make_shared<DecimalType<Decimal32>>(precision_value, scale_value);
if (precision_value <= DecimalUtils::max_precision<Decimal64>) if (precision_value <= DataTypeDecimalBase<Decimal64>::maxPrecision())
return std::make_shared<DecimalType<Decimal64>>(precision_value, scale_value); return std::make_shared<DecimalType<Decimal64>>(precision_value, scale_value);
if (precision_value <= DecimalUtils::max_precision<Decimal128>) if (precision_value <= DataTypeDecimalBase<Decimal128>::maxPrecision())
return std::make_shared<DecimalType<Decimal128>>(precision_value, scale_value); return std::make_shared<DecimalType<Decimal128>>(precision_value, scale_value);
return std::make_shared<DecimalType<Decimal256>>(precision_value, scale_value); return std::make_shared<DecimalType<Decimal256>>(precision_value, scale_value);
} }
extern template class DataTypeDecimalBase<Decimal32>;
extern template class DataTypeDecimalBase<Decimal64>;
extern template class DataTypeDecimalBase<Decimal128>;
extern template class DataTypeDecimalBase<Decimal256>;
extern template class DataTypeDecimalBase<DateTime64>;
} }

View File

@ -762,8 +762,12 @@ void SerializationDynamic::serializeTextJSON(const IColumn & column, size_t row_
void SerializationDynamic::serializeTextJSONPretty(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings, size_t indent) const void SerializationDynamic::serializeTextJSONPretty(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings, size_t indent) const
{ {
const auto & dynamic_column = assert_cast<const ColumnDynamic &>(column); auto nested_serialize = [&settings, indent](const ISerialization & serialization, const IColumn & col, size_t row, WriteBuffer & buf)
dynamic_column.getVariantInfo().variant_type->getDefaultSerialization()->serializeTextJSONPretty(dynamic_column.getVariantColumn(), row_num, ostr, settings, indent); {
serialization.serializeTextJSONPretty(col, row, buf, settings, indent);
};
serializeTextImpl(column, row_num, ostr, settings, nested_serialize);
} }
void SerializationDynamic::deserializeTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const void SerializationDynamic::deserializeTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const

View File

@ -111,13 +111,13 @@ DataTypePtr convertMySQLDataType(MultiEnum<MySQLDataTypesSupport> type_support,
} }
else if (type_support.isSet(MySQLDataTypesSupport::DECIMAL) && (type_name == "numeric" || type_name == "decimal")) else if (type_support.isSet(MySQLDataTypesSupport::DECIMAL) && (type_name == "numeric" || type_name == "decimal"))
{ {
if (precision <= DecimalUtils::max_precision<Decimal32>) if (precision <= DataTypeDecimalBase<Decimal32>::maxPrecision())
res = std::make_shared<DataTypeDecimal<Decimal32>>(precision, scale); res = std::make_shared<DataTypeDecimal<Decimal32>>(precision, scale);
else if (precision <= DecimalUtils::max_precision<Decimal64>) else if (precision <= DataTypeDecimalBase<Decimal64>::maxPrecision())
res = std::make_shared<DataTypeDecimal<Decimal64>>(precision, scale); res = std::make_shared<DataTypeDecimal<Decimal64>>(precision, scale);
else if (precision <= DecimalUtils::max_precision<Decimal128>) else if (precision <= DataTypeDecimalBase<Decimal128>::maxPrecision())
res = std::make_shared<DataTypeDecimal<Decimal128>>(precision, scale); res = std::make_shared<DataTypeDecimal<Decimal128>>(precision, scale);
else if (precision <= DecimalUtils::max_precision<Decimal256>) else if (precision <= DataTypeDecimalBase<Decimal256>::maxPrecision())
res = std::make_shared<DataTypeDecimal<Decimal256>>(precision, scale); res = std::make_shared<DataTypeDecimal<Decimal256>>(precision, scale);
} }
else if (type_name == "point") else if (type_name == "point")

View File

@ -493,7 +493,7 @@ void buildConfigurationFromFunctionWithKeyValueArguments(
/// We assume that function will not take arguments and will return constant value like tcpPort or hostName /// We assume that function will not take arguments and will return constant value like tcpPort or hostName
/// Such functions will return column with size equal to input_rows_count. /// Such functions will return column with size equal to input_rows_count.
size_t input_rows_count = 1; size_t input_rows_count = 1;
auto result = function->execute({}, function->getResultType(), input_rows_count); auto result = function->execute({}, function->getResultType(), input_rows_count, /* dry_run = */ false);
Field value; Field value;
result->get(0, value); result->get(0, value);

View File

@ -1308,7 +1308,8 @@ namespace
{ {
if (decimal.value == 0) if (decimal.value == 0)
writeInt(0); writeInt(0);
else if (DecimalComparison<DecimalType, int, EqualsOp>::compare(decimal, 1, scale, 0)) else if (DecimalComparison<DecimalType, int, EqualsOp>::compare(
decimal, 1, scale, 0, /* check overflow */ true))
writeInt(1); writeInt(1);
else else
{ {

View File

@ -7,6 +7,7 @@ add_headers_and_sources(clickhouse_functions .)
# This allows less dependency and linker work (specially important when building many example executables) # This allows less dependency and linker work (specially important when building many example executables)
set(DBMS_FUNCTIONS set(DBMS_FUNCTIONS
IFunction.cpp # IFunctionOverloadResolver::getLambdaArgumentTypes, IExecutableFunction::execute... (Many AST visitors, analyzer passes, some storages...) IFunction.cpp # IFunctionOverloadResolver::getLambdaArgumentTypes, IExecutableFunction::execute... (Many AST visitors, analyzer passes, some storages...)
IFunctionAdaptors.cpp # FunctionToFunctionBaseAdaptor (Used by FunctionFactory.cpp)
FunctionDynamicAdaptor.cpp # IFunctionOverloadResolver::getLambdaArgumentTypes, IExecutableFunction::execute... (Many AST visitors, analyzer passes, some storages...) FunctionDynamicAdaptor.cpp # IFunctionOverloadResolver::getLambdaArgumentTypes, IExecutableFunction::execute... (Many AST visitors, analyzer passes, some storages...)
FunctionFactory.cpp # FunctionFactory::instance() (Many AST visitors, analyzer passes, some storages...) FunctionFactory.cpp # FunctionFactory::instance() (Many AST visitors, analyzer passes, some storages...)
FunctionHelpers.cpp # convertConstTupleToConstantElements, checkAndGetColumnConstStringOrFixedString, checkAndGetNestedArrayOffset ...) FunctionHelpers.cpp # convertConstTupleToConstantElements, checkAndGetColumnConstStringOrFixedString, checkAndGetNestedArrayOffset ...)

View File

@ -41,6 +41,7 @@
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h> #include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h> #include <Functions/IFunction.h>
#include <Functions/IFunctionAdaptors.h>
#include <Functions/IsOperation.h> #include <Functions/IsOperation.h>
#include <Functions/castTypeToEither.h> #include <Functions/castTypeToEither.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
@ -231,6 +232,27 @@ public:
namespace impl_ namespace impl_
{ {
template <bool is_multiply, bool is_division, typename T, typename U, template <typename> typename DecimalType>
inline auto decimalResultType(const DecimalType<T> & tx, const DecimalType<U> & ty)
{
const auto result_trait = DecimalUtils::binaryOpResult<is_multiply, is_division>(tx, ty);
return DecimalType<typename decltype(result_trait)::FieldType>(result_trait.precision, result_trait.scale);
}
template <bool is_multiply, bool is_division, typename T, typename U, template <typename> typename DecimalType>
inline DecimalType<T> decimalResultType(const DecimalType<T> & tx, const DataTypeNumber<U> & ty)
{
const auto result_trait = DecimalUtils::binaryOpResult<is_multiply, is_division>(tx, ty);
return DecimalType<typename decltype(result_trait)::FieldType>(result_trait.precision, result_trait.scale);
}
template <bool is_multiply, bool is_division, typename T, typename U, template <typename> typename DecimalType>
inline DecimalType<U> decimalResultType(const DataTypeNumber<T> & tx, const DecimalType<U> & ty)
{
const auto result_trait = DecimalUtils::binaryOpResult<is_multiply, is_division>(tx, ty);
return DecimalType<typename decltype(result_trait)::FieldType>(result_trait.precision, result_trait.scale);
}
/** Arithmetic operations: +, -, *, /, %, /** Arithmetic operations: +, -, *, /, %,
* intDiv (integer division) * intDiv (integer division)
* Bitwise operations: |, &, ^, ~. * Bitwise operations: |, &, ^, ~.
@ -1166,7 +1188,7 @@ class FunctionBinaryArithmetic : public IFunction
new_arguments[1].type = std::make_shared<DataTypeNumber<DataTypeInterval::FieldType>>(); new_arguments[1].type = std::make_shared<DataTypeNumber<DataTypeInterval::FieldType>>();
auto function = function_builder->build(new_arguments); auto function = function_builder->build(new_arguments);
return function->execute(new_arguments, result_type, input_rows_count); return function->execute(new_arguments, result_type, input_rows_count, /* dry_run = */ false);
} }
ColumnPtr executeDateTimeTupleOfIntervalsPlusMinus(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, ColumnPtr executeDateTimeTupleOfIntervalsPlusMinus(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type,
@ -1180,7 +1202,7 @@ class FunctionBinaryArithmetic : public IFunction
auto function = function_builder->build(new_arguments); auto function = function_builder->build(new_arguments);
return function->execute(new_arguments, result_type, input_rows_count); return function->execute(new_arguments, result_type, input_rows_count, /* dry_run = */ false);
} }
ColumnPtr executeIntervalTupleOfIntervalsPlusMinus(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, ColumnPtr executeIntervalTupleOfIntervalsPlusMinus(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type,
@ -1188,7 +1210,7 @@ class FunctionBinaryArithmetic : public IFunction
{ {
auto function = function_builder->build(arguments); auto function = function_builder->build(arguments);
return function->execute(arguments, result_type, input_rows_count); return function->execute(arguments, result_type, input_rows_count, /* dry_run = */ false);
} }
ColumnPtr executeArraysImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const ColumnPtr executeArraysImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
@ -1323,7 +1345,7 @@ class FunctionBinaryArithmetic : public IFunction
auto function = function_builder->build(new_arguments); auto function = function_builder->build(new_arguments);
return function->execute(new_arguments, result_type, input_rows_count); return function->execute(new_arguments, result_type, input_rows_count, /* dry_run = */ false);
} }
template <typename T, typename ResultDataType> template <typename T, typename ResultDataType>
@ -2225,7 +2247,7 @@ ColumnPtr executeStringInteger(const ColumnsWithTypeAndName & arguments, const A
/// Special case when the function is plus, minus or multiply, both arguments are tuples. /// Special case when the function is plus, minus or multiply, both arguments are tuples.
if (auto function_builder = getFunctionForTupleArithmetic(arguments[0].type, arguments[1].type, context)) if (auto function_builder = getFunctionForTupleArithmetic(arguments[0].type, arguments[1].type, context))
{ {
return function_builder->build(arguments)->execute(arguments, result_type, input_rows_count); return function_builder->build(arguments)->execute(arguments, result_type, input_rows_count, /* dry_run = */ false);
} }
/// Special case when the function is multiply or divide, one of arguments is Tuple and another is Number. /// Special case when the function is multiply or divide, one of arguments is Tuple and another is Number.

View File

@ -1,4 +1,5 @@
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Functions/IFunctionAdaptors.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>

View File

@ -5,7 +5,6 @@
#include <Common/IFactoryWithAliases.h> #include <Common/IFactoryWithAliases.h>
#include <Common/FunctionDocumentation.h> #include <Common/FunctionDocumentation.h>
#include <Functions/IFunction.h> #include <Functions/IFunction.h>
#include <Functions/IFunctionAdaptors.h>
#include <functional> #include <functional>
#include <memory> #include <memory>

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <Core/callOnTypeIndex.h> #include <Core/callOnTypeIndex.h>
#include <Core/DecimalFunctions.h>
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypesDecimal.h> #include <DataTypes/DataTypesDecimal.h>
#include <Columns/ColumnsNumber.h> #include <Columns/ColumnsNumber.h>

View File

@ -339,7 +339,7 @@ public:
/// Special case when the function is negate, argument is tuple. /// Special case when the function is negate, argument is tuple.
if (auto function_builder = getFunctionForTupleArithmetic(arguments[0].type, context)) if (auto function_builder = getFunctionForTupleArithmetic(arguments[0].type, context))
{ {
return function_builder->build(arguments)->execute(arguments, result_type, input_rows_count); return function_builder->build(arguments)->execute(arguments, result_type, input_rows_count, /* dry_run = */ false);
} }
ColumnPtr result_column; ColumnPtr result_column;

View File

@ -5,6 +5,7 @@
#include <Functions/FunctionHelpers.h> #include <Functions/FunctionHelpers.h>
#include <DataTypes/DataTypeDateTime64.h> #include <DataTypes/DataTypeDateTime64.h>
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
#include <Core/DecimalFunctions.h>
#include <Columns/ColumnsNumber.h> #include <Columns/ColumnsNumber.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>

View File

@ -2,6 +2,7 @@
#if USE_ULID #if USE_ULID
#include <Core/DecimalFunctions.h>
#include <Columns/ColumnFixedString.h> #include <Columns/ColumnFixedString.h>
#include <Columns/ColumnString.h> #include <Columns/ColumnString.h>
#include <Columns/ColumnsDateTime.h> #include <Columns/ColumnsDateTime.h>

View File

@ -41,12 +41,6 @@
#include <limits> #include <limits>
#include <type_traits> #include <type_traits>
#if USE_EMBEDDED_COMPILER
# include <DataTypes/Native.h>
# include <llvm/IR/IRBuilder.h>
#endif
namespace DB namespace DB
{ {
@ -59,6 +53,68 @@ namespace ErrorCodes
extern const int BAD_ARGUMENTS; extern const int BAD_ARGUMENTS;
} }
template <bool _int, bool _float, bool _decimal, bool _datetime, typename F>
static inline bool callOnAtLeastOneDecimalType(TypeIndex type_num1, TypeIndex type_num2, F && f)
{
switch (type_num1)
{
case TypeIndex::DateTime64:
return callOnBasicType<DateTime64, _int, _float, _decimal, _datetime>(type_num2, std::forward<F>(f));
case TypeIndex::Decimal32:
return callOnBasicType<Decimal32, _int, _float, _decimal, _datetime>(type_num2, std::forward<F>(f));
case TypeIndex::Decimal64:
return callOnBasicType<Decimal64, _int, _float, _decimal, _datetime>(type_num2, std::forward<F>(f));
case TypeIndex::Decimal128:
return callOnBasicType<Decimal128, _int, _float, _decimal, _datetime>(type_num2, std::forward<F>(f));
case TypeIndex::Decimal256:
return callOnBasicType<Decimal256, _int, _float, _decimal, _datetime>(type_num2, std::forward<F>(f));
default:
break;
}
switch (type_num2)
{
case TypeIndex::DateTime64:
return callOnBasicTypeSecondArg<DateTime64, _int, _float, _decimal, _datetime>(type_num1, std::forward<F>(f));
case TypeIndex::Decimal32:
return callOnBasicTypeSecondArg<Decimal32, _int, _float, _decimal, _datetime>(type_num1, std::forward<F>(f));
case TypeIndex::Decimal64:
return callOnBasicTypeSecondArg<Decimal64, _int, _float, _decimal, _datetime>(type_num1, std::forward<F>(f));
case TypeIndex::Decimal128:
return callOnBasicTypeSecondArg<Decimal128, _int, _float, _decimal, _datetime>(type_num1, std::forward<F>(f));
case TypeIndex::Decimal256:
return callOnBasicTypeSecondArg<Decimal256, _int, _float, _decimal, _datetime>(type_num1, std::forward<F>(f));
default:
break;
}
return false;
}
template <template <typename, typename> class Operation, typename Name>
ColumnPtr executeDecimal(const ColumnWithTypeAndName & col_left, const ColumnWithTypeAndName & col_right, bool check_decimal_overflow)
{
TypeIndex left_number = col_left.type->getTypeId();
TypeIndex right_number = col_right.type->getTypeId();
ColumnPtr res;
auto call = [&](const auto & types) -> bool
{
using Types = std::decay_t<decltype(types)>;
using LeftDataType = typename Types::LeftType;
using RightDataType = typename Types::RightType;
return (res = DecimalComparison<LeftDataType, RightDataType, Operation>::apply(col_left, col_right, check_decimal_overflow))
!= nullptr;
};
if (!callOnAtLeastOneDecimalType<true, false, true, true>(left_number, right_number, call))
throw Exception(
ErrorCodes::LOGICAL_ERROR, "Wrong call for {} with {} and {}", Name::name, col_left.type->getName(), col_right.type->getName());
return res;
}
/** Comparison functions: ==, !=, <, >, <=, >=. /** Comparison functions: ==, !=, <, >, <=, >=.
* The comparison functions always return 0 or 1 (UInt8). * The comparison functions always return 0 or 1 (UInt8).
@ -574,62 +630,6 @@ struct GenericComparisonImpl
} }
}; };
#if USE_EMBEDDED_COMPILER
template <template <typename, typename> typename Op> struct CompileOp;
template <> struct CompileOp<EqualsOp>
{
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * x, llvm::Value * y, bool /*is_signed*/)
{
return x->getType()->isIntegerTy() ? b.CreateICmpEQ(x, y) : b.CreateFCmpOEQ(x, y); /// qNaNs always compare false
}
};
template <> struct CompileOp<NotEqualsOp>
{
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * x, llvm::Value * y, bool /*is_signed*/)
{
return x->getType()->isIntegerTy() ? b.CreateICmpNE(x, y) : b.CreateFCmpUNE(x, y);
}
};
template <> struct CompileOp<LessOp>
{
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * x, llvm::Value * y, bool is_signed)
{
return x->getType()->isIntegerTy() ? (is_signed ? b.CreateICmpSLT(x, y) : b.CreateICmpULT(x, y)) : b.CreateFCmpOLT(x, y);
}
};
template <> struct CompileOp<GreaterOp>
{
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * x, llvm::Value * y, bool is_signed)
{
return x->getType()->isIntegerTy() ? (is_signed ? b.CreateICmpSGT(x, y) : b.CreateICmpUGT(x, y)) : b.CreateFCmpOGT(x, y);
}
};
template <> struct CompileOp<LessOrEqualsOp>
{
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * x, llvm::Value * y, bool is_signed)
{
return x->getType()->isIntegerTy() ? (is_signed ? b.CreateICmpSLE(x, y) : b.CreateICmpULE(x, y)) : b.CreateFCmpOLE(x, y);
}
};
template <> struct CompileOp<GreaterOrEqualsOp>
{
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * x, llvm::Value * y, bool is_signed)
{
return x->getType()->isIntegerTy() ? (is_signed ? b.CreateICmpSGE(x, y) : b.CreateICmpUGE(x, y)) : b.CreateFCmpOGE(x, y);
}
};
#endif
struct NameEquals { static constexpr auto name = "equals"; }; struct NameEquals { static constexpr auto name = "equals"; };
struct NameNotEquals { static constexpr auto name = "notEquals"; }; struct NameNotEquals { static constexpr auto name = "notEquals"; };
struct NameLess { static constexpr auto name = "less"; }; struct NameLess { static constexpr auto name = "less"; };
@ -753,30 +753,6 @@ private:
return nullptr; return nullptr;
} }
ColumnPtr executeDecimal(const ColumnWithTypeAndName & col_left, const ColumnWithTypeAndName & col_right) const
{
TypeIndex left_number = col_left.type->getTypeId();
TypeIndex right_number = col_right.type->getTypeId();
ColumnPtr res;
auto call = [&](const auto & types) -> bool
{
using Types = std::decay_t<decltype(types)>;
using LeftDataType = typename Types::LeftType;
using RightDataType = typename Types::RightType;
if (check_decimal_overflow)
return (res = DecimalComparison<LeftDataType, RightDataType, Op, true>::apply(col_left, col_right)) != nullptr;
return (res = DecimalComparison<LeftDataType, RightDataType, Op, false>::apply(col_left, col_right)) != nullptr;
};
if (!callOnBasicTypes<true, false, true, true>(left_number, right_number, call))
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong call for {} with {} and {}",
getName(), col_left.type->getName(), col_right.type->getName());
return res;
}
ColumnPtr executeString(const IColumn * c0, const IColumn * c1) const ColumnPtr executeString(const IColumn * c0, const IColumn * c1) const
{ {
const ColumnString * c0_string = checkAndGetColumn<ColumnString>(c0); const ColumnString * c0_string = checkAndGetColumn<ColumnString>(c0);
@ -1010,7 +986,7 @@ private:
convolution_columns[i].type = impl->getResultType(); convolution_columns[i].type = impl->getResultType();
/// Comparison of the elements. /// Comparison of the elements.
convolution_columns[i].column = impl->execute(tmp_columns, impl->getResultType(), input_rows_count); convolution_columns[i].column = impl->execute(tmp_columns, impl->getResultType(), input_rows_count, /* dry_run = */ false);
} }
if (tuple_size == 1) if (tuple_size == 1)
@ -1021,7 +997,7 @@ private:
/// Logical convolution. /// Logical convolution.
auto impl = func_convolution->build(convolution_columns); auto impl = func_convolution->build(convolution_columns);
return impl->execute(convolution_columns, impl->getResultType(), input_rows_count); return impl->execute(convolution_columns, impl->getResultType(), input_rows_count, /* dry_run = */ false);
} }
ColumnPtr executeTupleLessGreaterImpl( ColumnPtr executeTupleLessGreaterImpl(
@ -1053,18 +1029,18 @@ private:
{ {
auto impl_head = func_compare_head->build(tmp_columns); auto impl_head = func_compare_head->build(tmp_columns);
less_columns[i].type = impl_head->getResultType(); less_columns[i].type = impl_head->getResultType();
less_columns[i].column = impl_head->execute(tmp_columns, less_columns[i].type, input_rows_count); less_columns[i].column = impl_head->execute(tmp_columns, less_columns[i].type, input_rows_count, /* dry_run = */ false);
auto impl_equals = func_equals->build(tmp_columns); auto impl_equals = func_equals->build(tmp_columns);
equal_columns[i].type = impl_equals->getResultType(); equal_columns[i].type = impl_equals->getResultType();
equal_columns[i].column = impl_equals->execute(tmp_columns, equal_columns[i].type, input_rows_count); equal_columns[i].column = impl_equals->execute(tmp_columns, equal_columns[i].type, input_rows_count, /* dry_run = */ false);
} }
else else
{ {
auto impl_tail = func_compare_tail->build(tmp_columns); auto impl_tail = func_compare_tail->build(tmp_columns);
less_columns[i].type = impl_tail->getResultType(); less_columns[i].type = impl_tail->getResultType();
less_columns[i].column = impl_tail->execute(tmp_columns, less_columns[i].type, input_rows_count); less_columns[i].column = impl_tail->execute(tmp_columns, less_columns[i].type, input_rows_count, /* dry_run = */ false);
} }
} }
@ -1083,13 +1059,13 @@ private:
tmp_columns[1] = equal_columns[i]; tmp_columns[1] = equal_columns[i];
auto func_and_adaptor = func_and->build(tmp_columns); auto func_and_adaptor = func_and->build(tmp_columns);
tmp_columns[0].column = func_and_adaptor->execute(tmp_columns, func_and_adaptor->getResultType(), input_rows_count); tmp_columns[0].column = func_and_adaptor->execute(tmp_columns, func_and_adaptor->getResultType(), input_rows_count, /* dry_run = */ false);
tmp_columns[0].type = func_and_adaptor->getResultType(); tmp_columns[0].type = func_and_adaptor->getResultType();
tmp_columns[1] = less_columns[i]; tmp_columns[1] = less_columns[i];
auto func_or_adaptor = func_or->build(tmp_columns); auto func_or_adaptor = func_or->build(tmp_columns);
tmp_columns[0].column = func_or_adaptor->execute(tmp_columns, func_or_adaptor->getResultType(), input_rows_count); tmp_columns[0].column = func_or_adaptor->execute(tmp_columns, func_or_adaptor->getResultType(), input_rows_count, /* dry_run = */ false);
tmp_columns[tmp_columns.size() - 1].type = func_or_adaptor->getResultType(); tmp_columns[tmp_columns.size() - 1].type = func_or_adaptor->getResultType();
} }
@ -1334,7 +1310,8 @@ public:
DataTypePtr common_type = getLeastSupertype(DataTypes{left_type, right_type}); DataTypePtr common_type = getLeastSupertype(DataTypes{left_type, right_type});
ColumnPtr c0_converted = castColumn(col_with_type_and_name_left, common_type); ColumnPtr c0_converted = castColumn(col_with_type_and_name_left, common_type);
ColumnPtr c1_converted = castColumn(col_with_type_and_name_right, common_type); ColumnPtr c1_converted = castColumn(col_with_type_and_name_right, common_type);
return executeDecimal({c0_converted, common_type, "left"}, {c1_converted, common_type, "right"}); return executeDecimal<Op, Name>(
{c0_converted, common_type, "left"}, {c1_converted, common_type, "right"}, check_decimal_overflow);
} }
/// Check does another data type is comparable to Decimal, includes Int and Float. /// Check does another data type is comparable to Decimal, includes Int and Float.
@ -1357,7 +1334,7 @@ public:
= ColumnsWithTypeAndName{{c0_converted, converted_type, "left"}, {c1_converted, converted_type, "right"}}; = ColumnsWithTypeAndName{{c0_converted, converted_type, "left"}, {c1_converted, converted_type, "right"}};
return executeImpl(new_arguments, result_type, input_rows_count); return executeImpl(new_arguments, result_type, input_rows_count);
} }
return executeDecimal(col_with_type_and_name_left, col_with_type_and_name_right); return executeDecimal<Op, Name>(col_with_type_and_name_left, col_with_type_and_name_right, check_decimal_overflow);
} }
if (date_and_datetime) if (date_and_datetime)
{ {
@ -1367,7 +1344,8 @@ public:
if (!((res = executeNumLeftType<UInt32>(c0_converted.get(), c1_converted.get())) if (!((res = executeNumLeftType<UInt32>(c0_converted.get(), c1_converted.get()))
|| (res = executeNumLeftType<UInt64>(c0_converted.get(), c1_converted.get())) || (res = executeNumLeftType<UInt64>(c0_converted.get(), c1_converted.get()))
|| (res = executeNumLeftType<Int32>(c0_converted.get(), c1_converted.get())) || (res = executeNumLeftType<Int32>(c0_converted.get(), c1_converted.get()))
|| (res = executeDecimal({c0_converted, common_type, "left"}, {c1_converted, common_type, "right"})))) || (res = executeDecimal<Op, Name>(
{c0_converted, common_type, "left"}, {c1_converted, common_type, "right"}, check_decimal_overflow))))
throw Exception(ErrorCodes::LOGICAL_ERROR, "Date related common types can only be UInt32/UInt64/Int32/Decimal"); throw Exception(ErrorCodes::LOGICAL_ERROR, "Date related common types can only be UInt32/UInt64/Int32/Decimal");
return res; return res;
} }

View File

@ -3243,11 +3243,9 @@ private:
{ {
auto function_adaptor = std::make_unique<FunctionToOverloadResolverAdaptor>(function)->build({ColumnWithTypeAndName{nullptr, from_type, ""}}); auto function_adaptor = std::make_unique<FunctionToOverloadResolverAdaptor>(function)->build({ColumnWithTypeAndName{nullptr, from_type, ""}});
return [function_adaptor] return [function_adaptor](
(ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, const ColumnNullable *, size_t input_rows_count) ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, const ColumnNullable *, size_t input_rows_count)
{ { return function_adaptor->execute(arguments, result_type, input_rows_count, /* dry_run = */ false); };
return function_adaptor->execute(arguments, result_type, input_rows_count);
};
} }
static WrapperType createToNullableColumnWrapper() static WrapperType createToNullableColumnWrapper()

View File

@ -644,7 +644,7 @@ private:
}; };
auto rows = mask_column->size(); auto rows = mask_column->size();
result_column = if_func->build(if_args)->execute(if_args, result_type, rows); result_column = if_func->build(if_args)->execute(if_args, result_type, rows, /* dry_run = */ false);
} }

View File

@ -1,5 +1,6 @@
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Functions/FunctionsLogical.h> #include <Functions/FunctionsLogical.h>
#include <Functions/IFunctionAdaptors.h>
#include <Functions/logical.h> #include <Functions/logical.h>
#include <Columns/ColumnConst.h> #include <Columns/ColumnConst.h>

View File

@ -73,7 +73,7 @@ public:
auto op_build = op->build(arguments); auto op_build = op->build(arguments);
auto res_type = op_build->getResultType(); auto res_type = op_build->getResultType();
return op_build->execute(arguments, res_type, input_rows_count); return op_build->execute(arguments, res_type, input_rows_count, /* dry_run = */ false);
} }
private: private:

View File

@ -486,6 +486,13 @@ ColumnPtr IExecutableFunction::execute(
return executeWithoutSparseColumns(arguments, result_type, input_rows_count, dry_run); return executeWithoutSparseColumns(arguments, result_type, input_rows_count, dry_run);
} }
ColumnPtr IFunctionBase::execute(const DB::ColumnsWithTypeAndName& arguments, const DB::DataTypePtr& result_type,
size_t input_rows_count, bool dry_run) const
{
checkFunctionArgumentSizes(arguments, input_rows_count);
return prepare(arguments)->execute(arguments, result_type, input_rows_count, dry_run);
}
void IFunctionOverloadResolver::checkNumberOfArguments(size_t number_of_arguments) const void IFunctionOverloadResolver::checkNumberOfArguments(size_t number_of_arguments) const
{ {
if (isVariadic()) if (isVariadic())

View File

@ -7,7 +7,6 @@
#include <Core/Names.h> #include <Core/Names.h>
#include <Core/ValuesWithType.h> #include <Core/ValuesWithType.h>
#include <DataTypes/IDataType.h> #include <DataTypes/IDataType.h>
#include <Functions/FunctionHelpers.h>
#include <Common/Exception.h> #include <Common/Exception.h>
#include "config.h" #include "config.h"
@ -141,11 +140,7 @@ public:
const ColumnsWithTypeAndName & arguments, const ColumnsWithTypeAndName & arguments,
const DataTypePtr & result_type, const DataTypePtr & result_type,
size_t input_rows_count, size_t input_rows_count,
bool dry_run = false) const bool dry_run) const;
{
checkFunctionArgumentSizes(arguments, input_rows_count);
return prepare(arguments)->execute(arguments, result_type, input_rows_count, dry_run);
}
/// Get the main function name. /// Get the main function name.
virtual String getName() const = 0; virtual String getName() const = 0;

View File

@ -0,0 +1,19 @@
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunctionAdaptors.h>
namespace DB
{
ColumnPtr FunctionToExecutableFunctionAdaptor::executeImpl(const ColumnsWithTypeAndName& arguments,
const DataTypePtr& result_type, size_t input_rows_count) const
{
checkFunctionArgumentSizes(arguments, input_rows_count);
return function->executeImpl(arguments, result_type, input_rows_count);
}
ColumnPtr FunctionToExecutableFunctionAdaptor::executeDryRunImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
{
checkFunctionArgumentSizes(arguments, input_rows_count);
return function->executeImplDryRun(arguments, result_type, input_rows_count);
}
}

View File

@ -16,17 +16,8 @@ public:
protected: protected:
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const final ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const final;
{ ColumnPtr executeDryRunImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const final;
checkFunctionArgumentSizes(arguments, input_rows_count);
return function->executeImpl(arguments, result_type, input_rows_count);
}
ColumnPtr executeDryRunImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const final
{
checkFunctionArgumentSizes(arguments, input_rows_count);
return function->executeImplDryRun(arguments, result_type, input_rows_count);
}
bool useDefaultImplementationForNulls() const final { return function->useDefaultImplementationForNulls(); } bool useDefaultImplementationForNulls() const final { return function->useDefaultImplementationForNulls(); }
bool useDefaultImplementationForNothing() const final { return function->useDefaultImplementationForNothing(); } bool useDefaultImplementationForNothing() const final { return function->useDefaultImplementationForNothing(); }

View File

@ -117,11 +117,11 @@ public:
} }
auto zipped auto zipped
= FunctionFactory::instance().get("arrayZip", context)->build(new_args)->execute(new_args, result_type, input_rows_count); = FunctionFactory::instance().get("arrayZip", context)->build(new_args)->execute(new_args, result_type, input_rows_count, /* dry_run = */ false);
ColumnsWithTypeAndName sort_arg({{zipped, std::make_shared<DataTypeArray>(result_type), "zipped"}}); ColumnsWithTypeAndName sort_arg({{zipped, std::make_shared<DataTypeArray>(result_type), "zipped"}});
auto sorted_tuple auto sorted_tuple
= FunctionFactory::instance().get(sort_function, context)->build(sort_arg)->execute(sort_arg, result_type, input_rows_count); = FunctionFactory::instance().get(sort_function, context)->build(sort_arg)->execute(sort_arg, result_type, input_rows_count, /* dry_run = */ false);
auto null_type = std::make_shared<DataTypeNullable>(std::make_shared<DataTypeInt8>()); auto null_type = std::make_shared<DataTypeNullable>(std::make_shared<DataTypeInt8>());
@ -140,7 +140,7 @@ public:
{null_type->createColumnConstWithDefaultValue(input_rows_count), null_type, "NULL"}, {null_type->createColumnConstWithDefaultValue(input_rows_count), null_type, "NULL"},
}); });
tuple_columns[i] = fun_array->build(null_array_arg)->execute(null_array_arg, arg_type, input_rows_count); tuple_columns[i] = fun_array->build(null_array_arg)->execute(null_array_arg, arg_type, input_rows_count, /* dry_run = */ false);
tuple_columns[i] = tuple_columns[i]->convertToFullColumnIfConst(); tuple_columns[i] = tuple_columns[i]->convertToFullColumnIfConst();
} }
else else
@ -151,7 +151,7 @@ public:
auto tuple_coulmn = FunctionFactory::instance() auto tuple_coulmn = FunctionFactory::instance()
.get("tupleElement", context) .get("tupleElement", context)
->build(untuple_args) ->build(untuple_args)
->execute(untuple_args, result_type, input_rows_count); ->execute(untuple_args, result_type, input_rows_count, /* dry_run = */ false);
auto out_tmp = ColumnArray::create(nested_types[i]->createColumn()); auto out_tmp = ColumnArray::create(nested_types[i]->createColumn());
@ -190,7 +190,7 @@ public:
slice_index.column = FunctionFactory::instance() slice_index.column = FunctionFactory::instance()
.get("indexOf", context) .get("indexOf", context)
->build(indexof_args) ->build(indexof_args)
->execute(indexof_args, result_type, input_rows_count); ->execute(indexof_args, result_type, input_rows_count, /* dry_run = */ false);
auto null_index_in_array = slice_index.column->get64(0); auto null_index_in_array = slice_index.column->get64(0);
if (null_index_in_array > 0) if (null_index_in_array > 0)
@ -218,15 +218,15 @@ public:
ColumnsWithTypeAndName slice_args_right( ColumnsWithTypeAndName slice_args_right(
{{ColumnWithTypeAndName(tuple_columns[i], arg_type, "array")}, slice_index}); {{ColumnWithTypeAndName(tuple_columns[i], arg_type, "array")}, slice_index});
ColumnWithTypeAndName arr_left{ ColumnWithTypeAndName arr_left{
fun_slice->build(slice_args_left)->execute(slice_args_left, arg_type, input_rows_count), arg_type, ""}; fun_slice->build(slice_args_left)->execute(slice_args_left, arg_type, input_rows_count, /* dry_run = */ false), arg_type, ""};
ColumnWithTypeAndName arr_right{ ColumnWithTypeAndName arr_right{
fun_slice->build(slice_args_right)->execute(slice_args_right, arg_type, input_rows_count), arg_type, ""}; fun_slice->build(slice_args_right)->execute(slice_args_right, arg_type, input_rows_count, /* dry_run = */ false), arg_type, ""};
ColumnsWithTypeAndName arr_cancat({arr_right, arr_left}); ColumnsWithTypeAndName arr_cancat({arr_right, arr_left});
auto out_tmp = FunctionFactory::instance() auto out_tmp = FunctionFactory::instance()
.get("arrayConcat", context) .get("arrayConcat", context)
->build(arr_cancat) ->build(arr_cancat)
->execute(arr_cancat, arg_type, input_rows_count); ->execute(arr_cancat, arg_type, input_rows_count, /* dry_run = */ false);
adjusted_columns[i] = std::move(out_tmp); adjusted_columns[i] = std::move(out_tmp);
} }
} }

View File

@ -23,7 +23,7 @@ public:
{DataTypeUInt8().createColumnConst(1, toField(UInt8(1))), std::make_shared<DataTypeUInt8>(), ""}, {DataTypeUInt8().createColumnConst(1, toField(UInt8(1))), std::make_shared<DataTypeUInt8>(), ""},
{DataTypeUInt8().createColumnConst(1, toField(UInt8(2))), std::make_shared<DataTypeUInt8>(), ""} {DataTypeUInt8().createColumnConst(1, toField(UInt8(2))), std::make_shared<DataTypeUInt8>(), ""}
}); });
auto if_res = FunctionFactory::instance().get("if", context)->build(if_columns)->execute(if_columns, std::make_shared<DataTypeUInt8>(), input_rows_count); auto if_res = FunctionFactory::instance().get("if", context)->build(if_columns)->execute(if_columns, std::make_shared<DataTypeUInt8>(), input_rows_count, /* dry_run = */ false);
auto result = if_res->getUInt(0); auto result = if_res->getUInt(0);
return (result == 1); return (result == 1);
} }

View File

@ -10,10 +10,9 @@
#include <Processors/Sources/SourceFromSingleChunk.h> #include <Processors/Sources/SourceFromSingleChunk.h>
#include <Formats/formatBlock.h> #include <Formats/formatBlock.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h> #include <Functions/FunctionHelpers.h>
#include <Functions/IFunctionAdaptors.h>
#include <Functions/UserDefined/ExternalUserDefinedExecutableFunctionsLoader.h> #include <Functions/UserDefined/ExternalUserDefinedExecutableFunctionsLoader.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <Interpreters/convertFieldToType.h> #include <Interpreters/convertFieldToType.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Interpreters/castColumn.h> #include <Interpreters/castColumn.h>

View File

@ -10,6 +10,7 @@
#include <DataTypes/DataTypeTuple.h> #include <DataTypes/DataTypeTuple.h>
#include <Functions/FunctionHelpers.h> #include <Functions/FunctionHelpers.h>
#include <Functions/IFunctionAdaptors.h>
#include <Functions/like.h> #include <Functions/like.h>
#include <Functions/array/arrayConcat.h> #include <Functions/array/arrayConcat.h>
#include <Functions/array/arrayFilter.h> #include <Functions/array/arrayFilter.h>

View File

@ -339,7 +339,7 @@ static ColumnPtr callFunctionNotEquals(ColumnWithTypeAndName first, ColumnWithTy
{ {
ColumnsWithTypeAndName args{first, second}; ColumnsWithTypeAndName args{first, second};
auto eq_func = FunctionFactory::instance().get("notEquals", context)->build(args); auto eq_func = FunctionFactory::instance().get("notEquals", context)->build(args);
return eq_func->execute(args, eq_func->getResultType(), args.front().column->size()); return eq_func->execute(args, eq_func->getResultType(), args.front().column->size(), /* dry_run = */ false);
} }
template <typename Mode> template <typename Mode>

View File

@ -113,7 +113,7 @@ public:
ColumnWithTypeAndName intersect_column; ColumnWithTypeAndName intersect_column;
intersect_column.type = intersect_array->getResultType(); intersect_column.type = intersect_array->getResultType();
intersect_column.column = intersect_array->execute(arguments, intersect_column.type, input_rows_count); intersect_column.column = intersect_array->execute(arguments, intersect_column.type, input_rows_count, /* dry_run = */ false);
const auto * intersect_column_type = checkAndGetDataType<DataTypeArray>(intersect_column.type.get()); const auto * intersect_column_type = checkAndGetDataType<DataTypeArray>(intersect_column.type.get());
if (!intersect_column_type) if (!intersect_column_type)

View File

@ -62,16 +62,17 @@ public:
for (size_t i = 0; i < num_rows; ++i) for (size_t i = 0; i < num_rows; ++i)
{ {
auto array_size = col_num->getInt(i); auto array_size = col_num->getInt(i);
auto element_size = col_value->byteSizeAt(i);
if (unlikely(array_size < 0)) if (unlikely(array_size < 0))
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Array size {} cannot be negative: while executing function {}", array_size, getName()); throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Array size {} cannot be negative: while executing function {}", array_size, getName());
Int64 estimated_size = 0; Int64 estimated_size = 0;
if (unlikely(common::mulOverflow(array_size, col_value->byteSize(), estimated_size))) if (unlikely(common::mulOverflow(array_size, element_size, estimated_size)))
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Array size {} with element size {} bytes is too large: while executing function {}", array_size, col_value->byteSize(), getName()); throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Array size {} with element size {} bytes is too large: while executing function {}", array_size, element_size, getName());
if (unlikely(estimated_size > max_array_size_in_columns_bytes)) if (unlikely(estimated_size > max_array_size_in_columns_bytes))
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Array size {} with element size {} bytes is too large: while executing function {}", array_size, col_value->byteSize(), getName()); throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Array size {} with element size {} bytes is too large: while executing function {}", array_size, element_size, getName());
offset += array_size; offset += array_size;

View File

@ -1,5 +1,6 @@
#include "arrayIndex.h" #include "arrayIndex.h"
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Functions/IFunctionAdaptors.h>
namespace DB namespace DB
{ {

View File

@ -93,12 +93,12 @@ public:
auto fun_array = FunctionFactory::instance().get("array", context); auto fun_array = FunctionFactory::instance().get("array", context);
src_array_col.column = fun_array->build(src_array_elems)->execute(src_array_elems, src_array_type, input_rows_count); src_array_col.column = fun_array->build(src_array_elems)->execute(src_array_elems, src_array_type, input_rows_count, /* dry_run = */ false);
dst_array_col.column = fun_array->build(dst_array_elems)->execute(dst_array_elems, dst_array_type, input_rows_count); dst_array_col.column = fun_array->build(dst_array_elems)->execute(dst_array_elems, dst_array_type, input_rows_count, /* dry_run = */ false);
/// Execute transform. /// Execute transform.
ColumnsWithTypeAndName transform_args{args.front(), src_array_col, dst_array_col, args.back()}; ColumnsWithTypeAndName transform_args{args.front(), src_array_col, dst_array_col, args.back()};
return FunctionFactory::instance().get("transform", context)->build(transform_args)->execute(transform_args, result_type, input_rows_count); return FunctionFactory::instance().get("transform", context)->build(transform_args)->execute(transform_args, result_type, input_rows_count, /* dry_run = */ false);
} }
private: private:

View File

@ -6,6 +6,7 @@
#include <Columns/IColumn.h> #include <Columns/IColumn.h>
#include <Common/DateLUT.h> #include <Common/DateLUT.h>
#include <Common/typeid_cast.h> #include <Common/typeid_cast.h>
#include <Core/DecimalFunctions.h>
#include <DataTypes/DataTypeDate.h> #include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDate32.h> #include <DataTypes/DataTypeDate32.h>
#include <DataTypes/DataTypeDateTime.h> #include <DataTypes/DataTypeDateTime.h>

View File

@ -133,11 +133,11 @@ public:
{ {
tmp_args[0] = filtered_args[i]; tmp_args[0] = filtered_args[i];
auto & cond = multi_if_args.emplace_back(ColumnWithTypeAndName{nullptr, std::make_shared<DataTypeUInt8>(), ""}); auto & cond = multi_if_args.emplace_back(ColumnWithTypeAndName{nullptr, std::make_shared<DataTypeUInt8>(), ""});
cond.column = is_not_null->build(tmp_args)->execute(tmp_args, cond.type, input_rows_count); cond.column = is_not_null->build(tmp_args)->execute(tmp_args, cond.type, input_rows_count, /* dry_run = */ false);
tmp_args[0] = filtered_args[i]; tmp_args[0] = filtered_args[i];
auto & val = multi_if_args.emplace_back(ColumnWithTypeAndName{nullptr, removeNullable(filtered_args[i].type), ""}); auto & val = multi_if_args.emplace_back(ColumnWithTypeAndName{nullptr, removeNullable(filtered_args[i].type), ""});
val.column = assume_not_null->build(tmp_args)->execute(tmp_args, val.type, input_rows_count); val.column = assume_not_null->build(tmp_args)->execute(tmp_args, val.type, input_rows_count, /* dry_run = */ false);
} }
} }
@ -152,7 +152,7 @@ public:
/// use function "if" instead, because it's implemented more efficient. /// use function "if" instead, because it's implemented more efficient.
/// TODO: make "multiIf" the same efficient. /// TODO: make "multiIf" the same efficient.
FunctionOverloadResolverPtr if_or_multi_if = multi_if_args.size() == 3 ? if_function : multi_if_function; FunctionOverloadResolverPtr if_or_multi_if = multi_if_args.size() == 3 ? if_function : multi_if_function;
ColumnPtr res = if_or_multi_if->build(multi_if_args)->execute(multi_if_args, result_type, input_rows_count); ColumnPtr res = if_or_multi_if->build(multi_if_args)->execute(multi_if_args, result_type, input_rows_count, /* dry_run = */ false);
/// if last argument is not nullable, result should be also not nullable /// if last argument is not nullable, result should be also not nullable
if (!multi_if_args.back().column->isNullable() && res->isNullable()) if (!multi_if_args.back().column->isNullable() && res->isNullable())

View File

@ -3,6 +3,7 @@
#include <DataTypes/DataTypeString.h> #include <DataTypes/DataTypeString.h>
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h> #include <Functions/FunctionHelpers.h>
#include <Functions/IFunctionAdaptors.h>
#include <Functions/GatherUtils/Algorithms.h> #include <Functions/GatherUtils/Algorithms.h>
#include <Functions/GatherUtils/Sinks.h> #include <Functions/GatherUtils/Sinks.h>
#include <Functions/GatherUtils/Sources.h> #include <Functions/GatherUtils/Sources.h>

View File

@ -151,10 +151,10 @@ public:
auto to_start_of_interval = FunctionFactory::instance().get("toStartOfInterval", context); auto to_start_of_interval = FunctionFactory::instance().get("toStartOfInterval", context);
if (arguments.size() == 2) if (arguments.size() == 2)
return to_start_of_interval->build(temp_columns)->execute(temp_columns, result_type, input_rows_count); return to_start_of_interval->build(temp_columns)->execute(temp_columns, result_type, input_rows_count, /* dry_run = */ false);
temp_columns[2] = arguments[2]; temp_columns[2] = arguments[2];
return to_start_of_interval->build(temp_columns)->execute(temp_columns, result_type, input_rows_count); return to_start_of_interval->build(temp_columns)->execute(temp_columns, result_type, input_rows_count, /* dry_run = */ false);
} }
bool hasInformationAboutMonotonicity() const override bool hasInformationAboutMonotonicity() const override

View File

@ -1,6 +1,7 @@
#include <DataTypes/DataTypeString.h> #include <DataTypes/DataTypeString.h>
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Functions/FunctionStringOrArrayToT.h> #include <Functions/FunctionStringOrArrayToT.h>
#include <Functions/IFunctionAdaptors.h>
#include <Functions/EmptyImpl.h> #include <Functions/EmptyImpl.h>
#include <Columns/ColumnObject.h> #include <Columns/ColumnObject.h>

View File

@ -5,6 +5,7 @@
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h> #include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h> #include <Functions/IFunction.h>
#include <Functions/IFunctionAdaptors.h>
#include <IO/WriteBufferFromVector.h> #include <IO/WriteBufferFromVector.h>
#include <IO/WriteHelpers.h> #include <IO/WriteHelpers.h>
#include <Processors/Formats/IOutputFormat.h> #include <Processors/Formats/IOutputFormat.h>

View File

@ -7,6 +7,7 @@ namespace DB
using FunctionGreater = FunctionComparison<GreaterOp, NameGreater>; using FunctionGreater = FunctionComparison<GreaterOp, NameGreater>;
using FunctionEquals = FunctionComparison<EqualsOp, NameEquals>; using FunctionEquals = FunctionComparison<EqualsOp, NameEquals>;
extern template class FunctionComparison<EqualsOp, NameEquals>;
REGISTER_FUNCTION(Greater) REGISTER_FUNCTION(Greater)
{ {

View File

@ -2,13 +2,14 @@
#include <Functions/FunctionsComparison.h> #include <Functions/FunctionsComparison.h>
#include <Functions/FunctionsLogical.h> #include <Functions/FunctionsLogical.h>
namespace DB namespace DB
{ {
using FunctionGreaterOrEquals = FunctionComparison<GreaterOrEqualsOp, NameGreaterOrEquals>; using FunctionGreaterOrEquals = FunctionComparison<GreaterOrEqualsOp, NameGreaterOrEquals>;
using FunctionGreater = FunctionComparison<GreaterOp, NameGreater>; using FunctionGreater = FunctionComparison<GreaterOp, NameGreater>;
extern template class FunctionComparison<GreaterOp, NameGreater>;
using FunctionEquals = FunctionComparison<EqualsOp, NameEquals>; using FunctionEquals = FunctionComparison<EqualsOp, NameEquals>;
extern template class FunctionComparison<EqualsOp, NameEquals>;
REGISTER_FUNCTION(GreaterOrEquals) REGISTER_FUNCTION(GreaterOrEquals)
{ {

View File

@ -26,6 +26,7 @@
#include <Functions/FunctionIfBase.h> #include <Functions/FunctionIfBase.h>
#include <Functions/GatherUtils/Algorithms.h> #include <Functions/GatherUtils/Algorithms.h>
#include <Functions/IFunction.h> #include <Functions/IFunction.h>
#include <Functions/IFunctionAdaptors.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Interpreters/castColumn.h> #include <Interpreters/castColumn.h>
#include <Common/assert_cast.h> #include <Common/assert_cast.h>

View File

@ -45,7 +45,7 @@ public:
{ {
ColumnsWithTypeAndName is_finite_columns{arguments[0]}; ColumnsWithTypeAndName is_finite_columns{arguments[0]};
auto is_finite = FunctionFactory::instance().get("isFinite", context)->build(is_finite_columns); auto is_finite = FunctionFactory::instance().get("isFinite", context)->build(is_finite_columns);
auto res = is_finite->execute(is_finite_columns, is_finite->getResultType(), input_rows_count); auto res = is_finite->execute(is_finite_columns, is_finite->getResultType(), input_rows_count, /* dry_run = */ false);
ColumnsWithTypeAndName if_columns ColumnsWithTypeAndName if_columns
{ {
@ -55,7 +55,7 @@ public:
}; };
auto func_if = FunctionFactory::instance().get("if", context)->build(if_columns); auto func_if = FunctionFactory::instance().get("if", context)->build(if_columns);
return func_if->execute(if_columns, result_type, input_rows_count); return func_if->execute(if_columns, result_type, input_rows_count, /* dry_run = */ false);
} }
private: private:

View File

@ -66,11 +66,11 @@ public:
auto is_not_null = FunctionFactory::instance().get("isNotNull", context)->build(columns); auto is_not_null = FunctionFactory::instance().get("isNotNull", context)->build(columns);
auto is_not_null_type = std::make_shared<DataTypeUInt8>(); auto is_not_null_type = std::make_shared<DataTypeUInt8>();
auto is_not_null_res = is_not_null->execute(columns, is_not_null_type, input_rows_count); auto is_not_null_res = is_not_null->execute(columns, is_not_null_type, input_rows_count, /* dry_run = */ false);
auto assume_not_null = FunctionFactory::instance().get("assumeNotNull", context)->build(columns); auto assume_not_null = FunctionFactory::instance().get("assumeNotNull", context)->build(columns);
auto assume_not_null_type = removeNullable(arguments[0].type); auto assume_not_null_type = removeNullable(arguments[0].type);
auto assume_nut_null_res = assume_not_null->execute(columns, assume_not_null_type, input_rows_count); auto assume_nut_null_res = assume_not_null->execute(columns, assume_not_null_type, input_rows_count, /* dry_run = */ false);
ColumnsWithTypeAndName if_columns ColumnsWithTypeAndName if_columns
{ {
@ -80,7 +80,7 @@ public:
}; };
auto func_if = FunctionFactory::instance().get("if", context)->build(if_columns); auto func_if = FunctionFactory::instance().get("if", context)->build(if_columns);
return func_if->execute(if_columns, result_type, input_rows_count); return func_if->execute(if_columns, result_type, input_rows_count, /* dry_run = */ false);
} }
private: private:

View File

@ -7,6 +7,7 @@
#include <Columns/ColumnDecimal.h> #include <Columns/ColumnDecimal.h>
#include <Columns/ColumnConst.h> #include <Columns/ColumnConst.h>
#include <Common/intExp.h> #include <Common/intExp.h>
#include <Core/DecimalFunctions.h>
namespace DB namespace DB

View File

@ -8,6 +8,7 @@ namespace DB
using FunctionLess = FunctionComparison<LessOp, NameLess>; using FunctionLess = FunctionComparison<LessOp, NameLess>;
using FunctionEquals = FunctionComparison<EqualsOp, NameEquals>; using FunctionEquals = FunctionComparison<EqualsOp, NameEquals>;
extern template class FunctionComparison<EqualsOp, NameEquals>;
REGISTER_FUNCTION(Less) REGISTER_FUNCTION(Less)
{ {

View File

@ -8,7 +8,9 @@ namespace DB
using FunctionLessOrEquals = FunctionComparison<LessOrEqualsOp, NameLessOrEquals>; using FunctionLessOrEquals = FunctionComparison<LessOrEqualsOp, NameLessOrEquals>;
using FunctionLess = FunctionComparison<LessOp, NameLess>; using FunctionLess = FunctionComparison<LessOp, NameLess>;
extern template class FunctionComparison<LessOp, NameLess>;
using FunctionEquals = FunctionComparison<EqualsOp, NameEquals>; using FunctionEquals = FunctionComparison<EqualsOp, NameEquals>;
extern template class FunctionComparison<EqualsOp, NameEquals>;
REGISTER_FUNCTION(LessOrEquals) REGISTER_FUNCTION(LessOrEquals)
{ {

View File

@ -10,6 +10,7 @@
#include <Columns/ColumnDecimal.h> #include <Columns/ColumnDecimal.h>
#include <Columns/ColumnsDateTime.h> #include <Columns/ColumnsDateTime.h>
#include <Columns/ColumnsNumber.h> #include <Columns/ColumnsNumber.h>
#include <Core/DecimalFunctions.h>
#include <Interpreters/castColumn.h> #include <Interpreters/castColumn.h>
#include <Common/DateLUT.h> #include <Common/DateLUT.h>

View File

@ -133,13 +133,13 @@ public:
const DataTypePtr & value_array_type = std::make_shared<DataTypeArray>(value_type); const DataTypePtr & value_array_type = std::make_shared<DataTypeArray>(value_type);
/// key_array = array(args[0], args[2]...) /// key_array = array(args[0], args[2]...)
ColumnPtr key_array = function_array->build(key_args)->execute(key_args, key_array_type, input_rows_count); ColumnPtr key_array = function_array->build(key_args)->execute(key_args, key_array_type, input_rows_count, /* dry_run = */ false);
/// value_array = array(args[1], args[3]...) /// value_array = array(args[1], args[3]...)
ColumnPtr value_array = function_array->build(value_args)->execute(value_args, value_array_type, input_rows_count); ColumnPtr value_array = function_array->build(value_args)->execute(value_args, value_array_type, input_rows_count, /* dry_run = */ false);
/// result = mapFromArrays(key_array, value_array) /// result = mapFromArrays(key_array, value_array)
ColumnsWithTypeAndName map_args{{key_array, key_array_type, ""}, {value_array, value_array_type, ""}}; ColumnsWithTypeAndName map_args{{key_array, key_array_type, ""}, {value_array, value_array_type, ""}};
return function_map_from_arrays->build(map_args)->execute(map_args, result_type, input_rows_count); return function_map_from_arrays->build(map_args)->execute(map_args, result_type, input_rows_count, /* dry_run = */ false);
} }
private: private:

View File

@ -70,7 +70,7 @@ public:
}; };
auto date_name_func = function_resolver->build(temporary_columns); auto date_name_func = function_resolver->build(temporary_columns);
return date_name_func->execute(temporary_columns, result_type, input_rows_count); return date_name_func->execute(temporary_columns, result_type, input_rows_count, /* dry_run = */ false);
} }
private: private:

View File

@ -1,5 +1,6 @@
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Functions/FunctionIfBase.h> #include <Functions/FunctionIfBase.h>
#include <Functions/IFunctionAdaptors.h>
#include <Columns/ColumnNullable.h> #include <Columns/ColumnNullable.h>
#include <Columns/ColumnConst.h> #include <Columns/ColumnConst.h>
#include <Columns/ColumnsNumber.h> #include <Columns/ColumnsNumber.h>

View File

@ -1,6 +1,7 @@
#include "FunctionsMultiStringSearch.h" #include <Functions/FunctionsMultiStringSearch.h>
#include "FunctionFactory.h" #include <Functions/FunctionFactory.h>
#include "MultiMatchAnyImpl.h" #include <Functions/MultiMatchAnyImpl.h>
#include <Functions/IFunctionAdaptors.h>
namespace DB namespace DB

View File

@ -49,7 +49,7 @@ public:
/// nullIf(col1, col2) == if(col1 = col2, NULL, col1) /// nullIf(col1, col2) == if(col1 = col2, NULL, col1)
auto equals_func = FunctionFactory::instance().get("equals", context)->build(arguments); auto equals_func = FunctionFactory::instance().get("equals", context)->build(arguments);
auto eq_res = equals_func->execute(arguments, equals_func->getResultType(), input_rows_count); auto eq_res = equals_func->execute(arguments, equals_func->getResultType(), input_rows_count, /* dry_run = */ false);
ColumnsWithTypeAndName if_columns ColumnsWithTypeAndName if_columns
{ {
@ -59,7 +59,7 @@ public:
}; };
auto func_if = FunctionFactory::instance().get("if", context)->build(if_columns); auto func_if = FunctionFactory::instance().get("if", context)->build(if_columns);
auto if_res = func_if->execute(if_columns, result_type, input_rows_count); auto if_res = func_if->execute(if_columns, result_type, input_rows_count, /* dry_run = */ false);
return makeNullable(if_res); return makeNullable(if_res);
} }

View File

@ -242,7 +242,7 @@ public:
} }
} }
auto res = function_concat->build(concat_args)->execute(concat_args, std::make_shared<DataTypeString>(), input_rows_count); auto res = function_concat->build(concat_args)->execute(concat_args, std::make_shared<DataTypeString>(), input_rows_count, /* dry_run = */ false);
return res; return res;
} }

View File

@ -3,6 +3,7 @@
#include <Columns/ColumnArray.h> #include <Columns/ColumnArray.h>
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h> #include <Functions/FunctionHelpers.h>
#include <Functions/IFunctionAdaptors.h>
#include <base/map.h> #include <base/map.h>
#include "reverse.h" #include "reverse.h"

View File

@ -3,6 +3,7 @@
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h> #include <Functions/FunctionHelpers.h>
#include <Functions/FunctionTokens.h> #include <Functions/FunctionTokens.h>
#include <Functions/IFunctionAdaptors.h>
#include <Functions/Regexps.h> #include <Functions/Regexps.h>
#include <Common/StringUtils.h> #include <Common/StringUtils.h>
#include <base/map.h> #include <base/map.h>

View File

@ -55,7 +55,7 @@ namespace
}; };
auto func_cast = createInternalCast(arguments[0], result_type, CastType::nonAccurate, {}); auto func_cast = createInternalCast(arguments[0], result_type, CastType::nonAccurate, {});
return func_cast->execute(cast_args, result_type, arguments[0].column->size()); return func_cast->execute(cast_args, result_type, arguments[0].column->size(), /* dry_run = */ false);
} }
}; };
} }

View File

@ -12,6 +12,7 @@
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h> #include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h> #include <Functions/IFunction.h>
#include <Functions/IFunctionAdaptors.h>
#include <Interpreters/castColumn.h> #include <Interpreters/castColumn.h>
#include <Interpreters/convertFieldToType.h> #include <Interpreters/convertFieldToType.h>
#include <Common/HashTable/HashMap.h> #include <Common/HashTable/HashMap.h>
@ -213,7 +214,7 @@ namespace
auto impl = std::make_shared<FunctionToOverloadResolverAdaptor>(std::make_shared<FunctionTransform>())->build(args); auto impl = std::make_shared<FunctionToOverloadResolverAdaptor>(std::make_shared<FunctionTransform>())->build(args);
return impl->execute(args, result_type, input_rows_count); return impl->execute(args, result_type, input_rows_count, /* dry_run = */ false);
} }
void executeAnything(const IColumn * in, IColumn & column_result, const ColumnPtr default_non_const, const IColumn & in_casted, size_t input_rows_count) const void executeAnything(const IColumn * in, IColumn & column_result, const ColumnPtr default_non_const, const IColumn & in_casted, size_t input_rows_count) const

View File

@ -119,7 +119,7 @@ public:
ColumnWithTypeAndName column; ColumnWithTypeAndName column;
column.type = elem_compare->getResultType(); column.type = elem_compare->getResultType();
column.column = elem_compare->execute({left, right}, column.type, input_rows_count); column.column = elem_compare->execute({left, right}, column.type, input_rows_count, /* dry_run = */ false);
if (i == 0) if (i == 0)
{ {
@ -129,7 +129,7 @@ public:
{ {
auto plus_elem = plus->build({res, column}); auto plus_elem = plus->build({res, column});
auto res_type = plus_elem->getResultType(); auto res_type = plus_elem->getResultType();
res.column = plus_elem->execute({res, column}, res_type, input_rows_count); res.column = plus_elem->execute({res, column}, res_type, input_rows_count, /* dry_run = */ false);
res.type = res_type; res.type = res_type;
} }
} }

View File

@ -4,6 +4,7 @@
#include <DataTypes/DataTypeString.h> #include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypeTuple.h> #include <DataTypes/DataTypeTuple.h>
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
namespace DB namespace DB
{ {

View File

@ -136,7 +136,7 @@ public:
ColumnWithTypeAndName left{left_elements[i], left_types[i], {}}; ColumnWithTypeAndName left{left_elements[i], left_types[i], {}};
ColumnWithTypeAndName right{right_elements[i], right_types[i], {}}; ColumnWithTypeAndName right{right_elements[i], right_types[i], {}};
auto elem_func = func->build(ColumnsWithTypeAndName{left, right}); auto elem_func = func->build(ColumnsWithTypeAndName{left, right});
columns[i] = elem_func->execute({left, right}, elem_func->getResultType(), input_rows_count) columns[i] = elem_func->execute({left, right}, elem_func->getResultType(), input_rows_count, /* dry_run = */ false)
->convertToFullColumnIfConst(); ->convertToFullColumnIfConst();
} }
@ -221,7 +221,7 @@ public:
{ {
ColumnWithTypeAndName cur{cur_elements[i], cur_types[i], {}}; ColumnWithTypeAndName cur{cur_elements[i], cur_types[i], {}};
auto elem_negate = negate->build(ColumnsWithTypeAndName{cur}); auto elem_negate = negate->build(ColumnsWithTypeAndName{cur});
columns[i] = elem_negate->execute({cur}, elem_negate->getResultType(), input_rows_count) columns[i] = elem_negate->execute({cur}, elem_negate->getResultType(), input_rows_count, /* dry_run = */ false)
->convertToFullColumnIfConst(); ->convertToFullColumnIfConst();
} }
@ -295,7 +295,7 @@ public:
{ {
ColumnWithTypeAndName cur{cur_elements[i], cur_types[i], {}}; ColumnWithTypeAndName cur{cur_elements[i], cur_types[i], {}};
auto elem_func = func->build(ColumnsWithTypeAndName{cur, p_column}); auto elem_func = func->build(ColumnsWithTypeAndName{cur, p_column});
columns[i] = elem_func->execute({cur, p_column}, elem_func->getResultType(), input_rows_count) columns[i] = elem_func->execute({cur, p_column}, elem_func->getResultType(), input_rows_count, /* dry_run = */ false)
->convertToFullColumnIfConst(); ->convertToFullColumnIfConst();
} }
@ -413,7 +413,7 @@ public:
ColumnWithTypeAndName column; ColumnWithTypeAndName column;
column.type = elem_multiply->getResultType(); column.type = elem_multiply->getResultType();
column.column = elem_multiply->execute({left, right}, column.type, input_rows_count); column.column = elem_multiply->execute({left, right}, column.type, input_rows_count, /* dry_run = */ false);
if (i == 0) if (i == 0)
{ {
@ -423,7 +423,7 @@ public:
{ {
auto plus_elem = plus->build({res, column}); auto plus_elem = plus->build({res, column});
auto res_type = plus_elem->getResultType(); auto res_type = plus_elem->getResultType();
res.column = plus_elem->execute({res, column}, res_type, input_rows_count); res.column = plus_elem->execute({res, column}, res_type, input_rows_count, /* dry_run = */ false);
res.type = res_type; res.type = res_type;
} }
} }
@ -510,7 +510,7 @@ public:
ColumnWithTypeAndName column{cur_elements[i], cur_types[i], {}}; ColumnWithTypeAndName column{cur_elements[i], cur_types[i], {}};
auto elem_plus = plus->build(ColumnsWithTypeAndName{i == 0 ? arguments[0] : res, column}); auto elem_plus = plus->build(ColumnsWithTypeAndName{i == 0 ? arguments[0] : res, column});
auto res_type = elem_plus->getResultType(); auto res_type = elem_plus->getResultType();
res.column = elem_plus->execute({i == 0 ? arguments[0] : res, column}, res_type, input_rows_count); res.column = elem_plus->execute({i == 0 ? arguments[0] : res, column}, res_type, input_rows_count, /* dry_run = */ false);
res.type = res_type; res.type = res_type;
} }
@ -665,14 +665,14 @@ public:
{ {
auto minus = FunctionFactory::instance().get("minus", context); auto minus = FunctionFactory::instance().get("minus", context);
auto elem_minus = minus->build({left, arguments[1]}); auto elem_minus = minus->build({left, arguments[1]});
last_column = elem_minus->execute({left, arguments[1]}, arguments[1].type, input_rows_count) last_column = elem_minus->execute({left, arguments[1]}, arguments[1].type, input_rows_count, /* dry_run = */ false)
->convertToFullColumnIfConst(); ->convertToFullColumnIfConst();
} }
else else
{ {
auto plus = FunctionFactory::instance().get("plus", context); auto plus = FunctionFactory::instance().get("plus", context);
auto elem_plus = plus->build({left, arguments[1]}); auto elem_plus = plus->build({left, arguments[1]});
last_column = elem_plus->execute({left, arguments[1]}, arguments[1].type, input_rows_count) last_column = elem_plus->execute({left, arguments[1]}, arguments[1].type, input_rows_count, /* dry_run = */ false)
->convertToFullColumnIfConst(); ->convertToFullColumnIfConst();
} }
} }
@ -682,7 +682,7 @@ public:
{ {
auto negate = FunctionFactory::instance().get("negate", context); auto negate = FunctionFactory::instance().get("negate", context);
auto elem_negate = negate->build({arguments[1]}); auto elem_negate = negate->build({arguments[1]});
last_column = elem_negate->execute({arguments[1]}, arguments[1].type, input_rows_count); last_column = elem_negate->execute({arguments[1]}, arguments[1].type, input_rows_count, /* dry_run = */ false);
} }
else else
{ {
@ -783,7 +783,7 @@ public:
ColumnWithTypeAndName column; ColumnWithTypeAndName column;
column.type = elem_abs->getResultType(); column.type = elem_abs->getResultType();
column.column = elem_abs->execute({cur}, column.type, input_rows_count); column.column = elem_abs->execute({cur}, column.type, input_rows_count, /* dry_run = */ false);
if (i == 0) if (i == 0)
{ {
@ -793,7 +793,7 @@ public:
{ {
auto plus_elem = plus->build({res, column}); auto plus_elem = plus->build({res, column});
auto res_type = plus_elem->getResultType(); auto res_type = plus_elem->getResultType();
res.column = plus_elem->execute({res, column}, res_type, input_rows_count); res.column = plus_elem->execute({res, column}, res_type, input_rows_count, /* dry_run = */ false);
res.type = res_type; res.type = res_type;
} }
} }
@ -885,7 +885,7 @@ public:
ColumnWithTypeAndName column; ColumnWithTypeAndName column;
column.type = elem_multiply->getResultType(); column.type = elem_multiply->getResultType();
column.column = elem_multiply->execute({cur, cur}, column.type, input_rows_count); column.column = elem_multiply->execute({cur, cur}, column.type, input_rows_count, /* dry_run = */ false);
if (i == 0) if (i == 0)
{ {
@ -895,7 +895,7 @@ public:
{ {
auto plus_elem = plus->build({res, column}); auto plus_elem = plus->build({res, column});
auto res_type = plus_elem->getResultType(); auto res_type = plus_elem->getResultType();
res.column = plus_elem->execute({res, column}, res_type, input_rows_count); res.column = plus_elem->execute({res, column}, res_type, input_rows_count, /* dry_run = */ false);
res.type = res_type; res.type = res_type;
} }
} }
@ -949,7 +949,7 @@ public:
auto sqrt = FunctionFactory::instance().get("sqrt", context); auto sqrt = FunctionFactory::instance().get("sqrt", context);
auto sqrt_elem = sqrt->build({squared_res}); auto sqrt_elem = sqrt->build({squared_res});
return sqrt_elem->execute({squared_res}, sqrt_elem->getResultType(), input_rows_count); return sqrt_elem->execute({squared_res}, sqrt_elem->getResultType(), input_rows_count, /* dry_run = */ false);
} }
}; };
using FunctionL2Norm = FunctionLNorm<L2Label>; using FunctionL2Norm = FunctionLNorm<L2Label>;
@ -1036,7 +1036,7 @@ public:
ColumnWithTypeAndName column; ColumnWithTypeAndName column;
column.type = elem_abs->getResultType(); column.type = elem_abs->getResultType();
column.column = elem_abs->execute({cur}, column.type, input_rows_count); column.column = elem_abs->execute({cur}, column.type, input_rows_count, /* dry_run = */ false);
if (i == 0) if (i == 0)
{ {
@ -1046,7 +1046,7 @@ public:
{ {
auto max_elem = max->build({res, column}); auto max_elem = max->build({res, column});
auto res_type = max_elem->getResultType(); auto res_type = max_elem->getResultType();
res.column = max_elem->execute({res, column}, res_type, input_rows_count); res.column = max_elem->execute({res, column}, res_type, input_rows_count, /* dry_run = */ false);
res.type = res_type; res.type = res_type;
} }
} }
@ -1163,14 +1163,14 @@ public:
{ {
ColumnWithTypeAndName cur{cur_elements[i], cur_types[i], {}}; ColumnWithTypeAndName cur{cur_elements[i], cur_types[i], {}};
auto elem_abs = abs->build(ColumnsWithTypeAndName{cur}); auto elem_abs = abs->build(ColumnsWithTypeAndName{cur});
cur.column = elem_abs->execute({cur}, elem_abs->getResultType(), input_rows_count); cur.column = elem_abs->execute({cur}, elem_abs->getResultType(), input_rows_count, /* dry_run = */ false);
cur.type = elem_abs->getResultType(); cur.type = elem_abs->getResultType();
auto elem_pow = pow->build(ColumnsWithTypeAndName{cur, p_column}); auto elem_pow = pow->build(ColumnsWithTypeAndName{cur, p_column});
ColumnWithTypeAndName column; ColumnWithTypeAndName column;
column.type = elem_pow->getResultType(); column.type = elem_pow->getResultType();
column.column = elem_pow->execute({cur, p_column}, column.type, input_rows_count); column.column = elem_pow->execute({cur, p_column}, column.type, input_rows_count, /* dry_run = */ false);
if (i == 0) if (i == 0)
{ {
@ -1180,7 +1180,7 @@ public:
{ {
auto plus_elem = plus->build({res, column}); auto plus_elem = plus->build({res, column});
auto res_type = plus_elem->getResultType(); auto res_type = plus_elem->getResultType();
res.column = plus_elem->execute({res, column}, res_type, input_rows_count); res.column = plus_elem->execute({res, column}, res_type, input_rows_count, /* dry_run = */ false);
res.type = res_type; res.type = res_type;
} }
} }
@ -1188,7 +1188,7 @@ public:
ColumnWithTypeAndName inv_p_column{DataTypeFloat64().createColumnConst(input_rows_count, 1 / p), ColumnWithTypeAndName inv_p_column{DataTypeFloat64().createColumnConst(input_rows_count, 1 / p),
std::make_shared<DataTypeFloat64>(), {}}; std::make_shared<DataTypeFloat64>(), {}};
auto pow_elem = pow->build({res, inv_p_column}); auto pow_elem = pow->build({res, inv_p_column});
return pow_elem->execute({res, inv_p_column}, pow_elem->getResultType(), input_rows_count); return pow_elem->execute({res, inv_p_column}, pow_elem->getResultType(), input_rows_count, /* dry_run = */ false);
} }
}; };
using FunctionLpNorm = FunctionLNorm<LpLabel>; using FunctionLpNorm = FunctionLNorm<LpLabel>;
@ -1247,12 +1247,12 @@ public:
if constexpr (FuncLabel::name[0] == 'p') if constexpr (FuncLabel::name[0] == 'p')
{ {
auto func_elem = func->build({minus_res, arguments[2]}); auto func_elem = func->build({minus_res, arguments[2]});
return func_elem->execute({minus_res, arguments[2]}, func_elem->getResultType(), input_rows_count); return func_elem->execute({minus_res, arguments[2]}, func_elem->getResultType(), input_rows_count, /* dry_run = */ false);
} }
else else
{ {
auto func_elem = func->build({minus_res}); auto func_elem = func->build({minus_res});
return func_elem->execute({minus_res}, func_elem->getResultType(), input_rows_count); return func_elem->execute({minus_res}, func_elem->getResultType(), input_rows_count, /* dry_run = */ false);
} }
} }
}; };
@ -1394,16 +1394,16 @@ public:
ColumnWithTypeAndName multiply_result; ColumnWithTypeAndName multiply_result;
multiply_result.type = multiply_elem->getResultType(); multiply_result.type = multiply_elem->getResultType();
multiply_result.column = multiply_elem->execute({first_norm, second_norm}, multiply_result.column = multiply_elem->execute({first_norm, second_norm},
multiply_result.type, input_rows_count); multiply_result.type, input_rows_count, /* dry_run = */ false);
auto divide_elem = divide->build({dot_result, multiply_result}); auto divide_elem = divide->build({dot_result, multiply_result});
ColumnWithTypeAndName divide_result; ColumnWithTypeAndName divide_result;
divide_result.type = divide_elem->getResultType(); divide_result.type = divide_elem->getResultType();
divide_result.column = divide_elem->execute({dot_result, multiply_result}, divide_result.column = divide_elem->execute({dot_result, multiply_result},
divide_result.type, input_rows_count); divide_result.type, input_rows_count, /* dry_run = */ false);
auto minus_elem = minus->build({one, divide_result}); auto minus_elem = minus->build({one, divide_result});
return minus_elem->execute({one, divide_result}, minus_elem->getResultType(), input_rows_count); return minus_elem->execute({one, divide_result}, minus_elem->getResultType(), input_rows_count, /* dry_run = */ false);
} }
}; };

View File

@ -1,25 +1,13 @@
#include <Interpreters/InterpreterFactory.h>
#include <Interpreters/Access/InterpreterCheckGrantQuery.h> #include <Interpreters/Access/InterpreterCheckGrantQuery.h>
#include <Interpreters/executeQuery.h>
#include <Parsers/Access/ASTCheckGrantQuery.h>
#include <Parsers/Access/ASTRolesOrUsersSet.h>
#include <Access/AccessControl.h>
#include <Access/ContextAccess.h> #include <Access/ContextAccess.h>
#include <Access/Role.h>
#include <Access/RolesOrUsersSet.h>
#include <Access/User.h>
#include <Interpreters/Context.h>
#include <Interpreters/removeOnClusterClauseIfNeeded.h>
#include <Interpreters/QueryLog.h>
#include <Interpreters/executeDDLQueryOnCluster.h>
#include <boost/range/algorithm/copy.hpp>
#include <boost/range/algorithm/set_algorithm.hpp>
#include <boost/range/algorithm_ext/erase.hpp>
#include <Interpreters/DatabaseCatalog.h>
#include <Columns/ColumnsNumber.h> #include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
#include <Interpreters/Context.h>
#include <Interpreters/InterpreterFactory.h>
#include <Parsers/Access/ASTCheckGrantQuery.h>
#include <Processors/Sources/SourceFromSingleChunk.h> #include <Processors/Sources/SourceFromSingleChunk.h>
#include "Storages/IStorage.h"
namespace DB namespace DB
{ {
@ -27,21 +15,19 @@ namespace DB
BlockIO InterpreterCheckGrantQuery::execute() BlockIO InterpreterCheckGrantQuery::execute()
{ {
auto & query = query_ptr->as<ASTCheckGrantQuery &>(); auto & query = query_ptr->as<ASTCheckGrantQuery &>();
query.access_rights_elements.eraseNonGrantable();
auto current_user_access = getContext()->getAccess();
/// Collect access rights elements which will be checked. /// Collect access rights elements which will be checked.
AccessRightsElements & elements_to_check_grant = query.access_rights_elements; AccessRightsElements & elements_to_check_grant = query.access_rights_elements;
/// Replacing empty database with the default. This step must be done before replication to avoid privilege escalation.
String current_database = getContext()->getCurrentDatabase(); String current_database = getContext()->getCurrentDatabase();
elements_to_check_grant.replaceEmptyDatabase(current_database); elements_to_check_grant.replaceEmptyDatabase(current_database);
query.access_rights_elements.replaceEmptyDatabase(current_database);
bool user_is_granted = current_user_access->isGranted(elements_to_check_grant); auto current_user_access = getContext()->getAccess();
bool is_granted = current_user_access->isGranted(elements_to_check_grant);
BlockIO res; BlockIO res;
res.pipeline = QueryPipeline( res.pipeline = QueryPipeline(
std::make_shared<SourceFromSingleChunk>(Block{{ColumnUInt8::create(1, user_is_granted), std::make_shared<DataTypeUInt8>(), "result"}})); std::make_shared<SourceFromSingleChunk>(Block{{ColumnUInt8::create(1, is_granted), std::make_shared<DataTypeUInt8>(), "result"}}));
return res; return res;
} }

View File

@ -1,6 +1,5 @@
#pragma once #pragma once
#include <Core/UUID.h>
#include <Interpreters/IInterpreter.h> #include <Interpreters/IInterpreter.h>
#include <Parsers/IAST_fwd.h> #include <Parsers/IAST_fwd.h>
@ -8,10 +7,6 @@
namespace DB namespace DB
{ {
class ASTCheckGrantQuery;
struct User;
struct Role;
class InterpreterCheckGrantQuery : public IInterpreter, WithMutableContext class InterpreterCheckGrantQuery : public IInterpreter, WithMutableContext
{ {
public: public:

View File

@ -418,7 +418,7 @@ BlockIO InterpreterGrantQuery::execute()
auto & query = updated_query->as<ASTGrantQuery &>(); auto & query = updated_query->as<ASTGrantQuery &>();
query.replaceCurrentUserTag(getContext()->getUserName()); query.replaceCurrentUserTag(getContext()->getUserName());
query.access_rights_elements.eraseNonGrantable(); query.access_rights_elements.eraseNotGrantable();
if (!query.access_rights_elements.sameOptions()) if (!query.access_rights_elements.sameOptions())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Elements of an ASTGrantQuery are expected to have the same options"); throw Exception(ErrorCodes::LOGICAL_ERROR, "Elements of an ASTGrantQuery are expected to have the same options");

View File

@ -5,6 +5,7 @@
#include <Columns/ColumnMap.h> #include <Columns/ColumnMap.h>
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Functions/IFunctionAdaptors.h>
#include <Functions/array/length.h> #include <Functions/array/length.h>
#include <Functions/array/arrayResize.h> #include <Functions/array/arrayResize.h>
#include <Functions/array/emptyArrayToSingle.h> #include <Functions/array/emptyArrayToSingle.h>
@ -166,7 +167,7 @@ ArrayJoinResultIterator::ArrayJoinResultIterator(const ArrayJoinAction * array_j
ColumnWithTypeAndName array_col = convertArrayJoinColumn(src_col); ColumnWithTypeAndName array_col = convertArrayJoinColumn(src_col);
ColumnsWithTypeAndName tmp_block{array_col}; //, {{}, uint64, {}}}; ColumnsWithTypeAndName tmp_block{array_col}; //, {{}, uint64, {}}};
auto len_col = function_length->build(tmp_block)->execute(tmp_block, uint64, rows); auto len_col = function_length->build(tmp_block)->execute(tmp_block, uint64, rows, /* dry_run = */ false);
updateMaxLength(*max_length, *len_col); updateMaxLength(*max_length, *len_col);
} }
@ -177,7 +178,7 @@ ArrayJoinResultIterator::ArrayJoinResultIterator(const ArrayJoinAction * array_j
ColumnWithTypeAndName array_col = convertArrayJoinColumn(src_col); ColumnWithTypeAndName array_col = convertArrayJoinColumn(src_col);
ColumnsWithTypeAndName tmp_block{array_col, column_of_max_length}; ColumnsWithTypeAndName tmp_block{array_col, column_of_max_length};
array_col.column = function_array_resize->build(tmp_block)->execute(tmp_block, array_col.type, rows); array_col.column = function_array_resize->build(tmp_block)->execute(tmp_block, array_col.type, rows, /* dry_run = */ false);
src_col = std::move(array_col); src_col = std::move(array_col);
any_array_map_ptr = src_col.column->convertToFullColumnIfConst(); any_array_map_ptr = src_col.column->convertToFullColumnIfConst();
@ -194,7 +195,7 @@ ArrayJoinResultIterator::ArrayJoinResultIterator(const ArrayJoinAction * array_j
const auto & src_col = block.getByName(name); const auto & src_col = block.getByName(name);
ColumnWithTypeAndName array_col = convertArrayJoinColumn(src_col); ColumnWithTypeAndName array_col = convertArrayJoinColumn(src_col);
ColumnsWithTypeAndName tmp_block{array_col}; ColumnsWithTypeAndName tmp_block{array_col};
non_empty_array_columns[name] = function_builder->build(tmp_block)->execute(tmp_block, array_col.type, array_col.column->size()); non_empty_array_columns[name] = function_builder->build(tmp_block)->execute(tmp_block, array_col.type, array_col.column->size(), /* dry_run = */ false);
} }
any_array_map_ptr = non_empty_array_columns.begin()->second->convertToFullColumnIfConst(); any_array_map_ptr = non_empty_array_columns.begin()->second->convertToFullColumnIfConst();

View File

@ -266,7 +266,7 @@ public:
{ {
const auto & type = function.getArgumentTypes().at(0); const auto & type = function.getArgumentTypes().at(0);
ColumnsWithTypeAndName args{{type->createColumnConst(1, value), type, "x" }}; ColumnsWithTypeAndName args{{type->createColumnConst(1, value), type, "x" }};
auto col = function.execute(args, function.getResultType(), 1); auto col = function.execute(args, function.getResultType(), 1, /* dry_run = */ false);
col->get(0, value); col->get(0, value);
} }

View File

@ -1,10 +1,11 @@
#include <Interpreters/RowRefs.h> #include <Interpreters/RowRefs.h>
#include <Common/RadixSort.h> #include <Columns/ColumnDecimal.h>
#include <Columns/IColumn.h> #include <Columns/IColumn.h>
#include <DataTypes/IDataType.h>
#include <Core/Joins.h> #include <Core/Joins.h>
#include <DataTypes/IDataType.h>
#include <base/types.h> #include <base/types.h>
#include <Common/RadixSort.h>
namespace DB namespace DB

View File

@ -7,7 +7,6 @@
#include <optional> #include <optional>
#include <variant> #include <variant>
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnVector.h> #include <Columns/ColumnVector.h>
#include <Columns/IColumn.h> #include <Columns/IColumn.h>
#include <Core/Joins.h> #include <Core/Joins.h>

View File

@ -34,8 +34,8 @@ static ColumnPtr castColumn(CastType cast_type, const ColumnWithTypeAndName & ar
FunctionBasePtr func_cast = cache ? cache->getOrSet(cast_type, from_name, to_name, std::move(get_cast_func)) : get_cast_func(); FunctionBasePtr func_cast = cache ? cache->getOrSet(cast_type, from_name, to_name, std::move(get_cast_func)) : get_cast_func();
if (cast_type == CastType::accurateOrNull) if (cast_type == CastType::accurateOrNull)
return func_cast->execute(arguments, makeNullable(type), arg.column->size()); return func_cast->execute(arguments, makeNullable(type), arg.column->size(), /* dry_run = */ false);
return func_cast->execute(arguments, type, arg.column->size()); return func_cast->execute(arguments, type, arg.column->size(), /* dry_run = */ false);
} }
ColumnPtr castColumn(const ColumnWithTypeAndName & arg, const DataTypePtr & type, InternalCastFunctionCache * cache) ColumnPtr castColumn(const ColumnWithTypeAndName & arg, const DataTypePtr & type, InternalCastFunctionCache * cache)

View File

@ -1,98 +1,11 @@
#include <Parsers/Access/ASTCheckGrantQuery.h> #include <Parsers/Access/ASTCheckGrantQuery.h>
#include <Parsers/Access/ASTRolesOrUsersSet.h>
#include <Common/quoteString.h>
#include <IO/Operators.h> #include <IO/Operators.h>
namespace DB namespace DB
{ {
namespace
{
void formatColumnNames(const Strings & columns, const IAST::FormatSettings & settings)
{
settings.ostr << "(";
bool need_comma = false;
for (const auto & column : columns)
{
if (std::exchange(need_comma, true))
settings.ostr << ", ";
settings.ostr << backQuoteIfNeed(column);
}
settings.ostr << ")";
}
void formatONClause(const AccessRightsElement & element, const IAST::FormatSettings & settings)
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "ON " << (settings.hilite ? IAST::hilite_none : "");
if (element.isGlobalWithParameter())
{
if (element.anyParameter())
settings.ostr << "*";
else
settings.ostr << backQuoteIfNeed(element.parameter);
}
else if (element.anyDatabase())
{
settings.ostr << "*.*";
}
else
{
if (!element.database.empty())
settings.ostr << backQuoteIfNeed(element.database) << ".";
if (element.anyDatabase())
settings.ostr << "*";
else
settings.ostr << backQuoteIfNeed(element.table);
}
}
void formatElementsWithoutOptions(const AccessRightsElements & elements, const IAST::FormatSettings & settings)
{
bool no_output = true;
for (size_t i = 0; i != elements.size(); ++i)
{
const auto & element = elements[i];
auto keywords = element.access_flags.toKeywords();
if (keywords.empty() || (!element.anyColumn() && element.columns.empty()))
continue;
for (const auto & keyword : keywords)
{
if (!std::exchange(no_output, false))
settings.ostr << ", ";
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << keyword << (settings.hilite ? IAST::hilite_none : "");
if (!element.anyColumn())
formatColumnNames(element.columns, settings);
}
bool next_element_on_same_db_and_table = false;
if (i != elements.size() - 1)
{
const auto & next_element = elements[i + 1];
if (element.sameDatabaseAndTableAndParameter(next_element))
{
next_element_on_same_db_and_table = true;
}
}
if (!next_element_on_same_db_and_table)
{
settings.ostr << " ";
formatONClause(element, settings);
}
}
if (no_output)
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "USAGE ON " << (settings.hilite ? IAST::hilite_none : "") << "*.*";
}
}
String ASTCheckGrantQuery::getID(char) const String ASTCheckGrantQuery::getID(char) const
{ {
return "CheckGrantQuery"; return "CheckGrantQuery";
@ -113,8 +26,7 @@ void ASTCheckGrantQuery::formatImpl(const FormatSettings & settings, FormatState
<< (settings.hilite ? IAST::hilite_none : ""); << (settings.hilite ? IAST::hilite_none : "");
settings.ostr << " "; settings.ostr << " ";
access_rights_elements.formatElementsWithoutOptions(settings.ostr, settings.hilite);
formatElementsWithoutOptions(access_rights_elements, settings);
} }

View File

@ -2,13 +2,10 @@
#include <Parsers/IAST.h> #include <Parsers/IAST.h>
#include <Access/Common/AccessRightsElement.h> #include <Access/Common/AccessRightsElement.h>
#include <Parsers/ASTQueryWithOnCluster.h>
namespace DB namespace DB
{ {
class ASTRolesOrUsersSet;
/** Parses queries like /** Parses queries like
* CHECK GRANT access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} * CHECK GRANT access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*}
@ -24,4 +21,5 @@ public:
void replaceEmptyDatabase(const String & current_database); void replaceEmptyDatabase(const String & current_database);
QueryKind getQueryKind() const override { return QueryKind::Check; } QueryKind getQueryKind() const override { return QueryKind::Check; }
}; };
} }

View File

@ -1,6 +1,5 @@
#include <Parsers/Access/ASTGrantQuery.h> #include <Parsers/Access/ASTGrantQuery.h>
#include <Parsers/Access/ASTRolesOrUsersSet.h> #include <Parsers/Access/ASTRolesOrUsersSet.h>
#include <Common/quoteString.h>
#include <IO/Operators.h> #include <IO/Operators.h>
@ -13,52 +12,10 @@ namespace ErrorCodes
namespace namespace
{ {
void formatElementsWithoutOptions(const AccessRightsElements & elements, const IAST::FormatSettings & settings)
{
bool no_output = true;
for (size_t i = 0; i != elements.size(); ++i)
{
const auto & element = elements[i];
auto keywords = element.access_flags.toKeywords();
if (keywords.empty() || (!element.anyColumn() && element.columns.empty()))
continue;
for (const auto & keyword : keywords)
{
if (!std::exchange(no_output, false))
settings.ostr << ", ";
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << keyword << (settings.hilite ? IAST::hilite_none : "");
if (!element.anyColumn())
element.formatColumnNames(settings.ostr);
}
bool next_element_on_same_db_and_table = false;
if (i != elements.size() - 1)
{
const auto & next_element = elements[i + 1];
if (element.sameDatabaseAndTableAndParameter(next_element))
{
next_element_on_same_db_and_table = true;
}
}
if (!next_element_on_same_db_and_table)
{
settings.ostr << " ";
element.formatONClause(settings.ostr, settings.hilite);
}
}
if (no_output)
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "USAGE ON " << (settings.hilite ? IAST::hilite_none : "") << "*.*";
}
void formatCurrentGrantsElements(const AccessRightsElements & elements, const IAST::FormatSettings & settings) void formatCurrentGrantsElements(const AccessRightsElements & elements, const IAST::FormatSettings & settings)
{ {
settings.ostr << "("; settings.ostr << "(";
formatElementsWithoutOptions(elements, settings); elements.formatElementsWithoutOptions(settings.ostr, settings.hilite);
settings.ostr << ")"; settings.ostr << ")";
} }
} }
@ -122,7 +79,7 @@ void ASTGrantQuery::formatImpl(const FormatSettings & settings, FormatState &, F
} }
else else
{ {
formatElementsWithoutOptions(access_rights_elements, settings); access_rights_elements.formatElementsWithoutOptions(settings.ostr, settings.hilite);
} }
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << (is_revoke ? " FROM " : " TO ") settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << (is_revoke ? " FROM " : " TO ")

View File

@ -1,219 +1,24 @@
#include <Parsers/ASTIdentifier_fwd.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/Access/ASTCheckGrantQuery.h>
#include <Parsers/Access/ASTRolesOrUsersSet.h>
#include <Parsers/Access/ParserCheckGrantQuery.h> #include <Parsers/Access/ParserCheckGrantQuery.h>
#include <Parsers/Access/ParserRolesOrUsersSet.h>
#include <Parsers/ExpressionElementParsers.h> #include <Access/Common/AccessRightsElement.h>
#include <Parsers/ExpressionListParsers.h> #include <Parsers/Access/ASTCheckGrantQuery.h>
#include <Parsers/parseDatabaseAndTableName.h> #include <Parsers/Access/parseAccessRightsElements.h>
#include <boost/algorithm/string/predicate.hpp> #include <Parsers/CommonParsers.h>
#include <boost/range/algorithm_ext/erase.hpp>
namespace DB namespace DB
{ {
namespace ErrorCodes
{
extern const int INVALID_GRANT;
}
namespace
{
bool parseAccessFlags(IParser::Pos & pos, Expected & expected, AccessFlags & access_flags)
{
static constexpr auto is_one_of_access_type_words = [](IParser::Pos & pos_)
{
if (pos_->type != TokenType::BareWord)
return false;
std::string_view word{pos_->begin, pos_->size()};
return !(boost::iequals(word, toStringView(Keyword::ON)) || boost::iequals(word, toStringView(Keyword::TO)) || boost::iequals(word, toStringView(Keyword::FROM)));
};
expected.add(pos, "access type");
return IParserBase::wrapParseImpl(pos, [&]
{
if (!is_one_of_access_type_words(pos))
return false;
String str;
do
{
if (!str.empty())
str += " ";
str += std::string_view(pos->begin, pos->size());
++pos;
}
while (is_one_of_access_type_words(pos));
try
{
access_flags = AccessFlags{str};
}
catch (...)
{
return false;
}
return true;
});
}
bool parseColumnNames(IParser::Pos & pos, Expected & expected, Strings & columns)
{
return IParserBase::wrapParseImpl(pos, [&]
{
if (!ParserToken{TokenType::OpeningRoundBracket}.ignore(pos, expected))
return false;
ASTPtr ast;
if (!ParserList{std::make_unique<ParserIdentifier>(), std::make_unique<ParserToken>(TokenType::Comma), false}.parse(pos, ast, expected))
return false;
Strings res_columns;
for (const auto & child : ast->children)
res_columns.emplace_back(getIdentifierName(child));
if (!ParserToken{TokenType::ClosingRoundBracket}.ignore(pos, expected))
return false;
columns = std::move(res_columns);
return true;
});
}
bool parseAccessFlagsWithColumns(IParser::Pos & pos, Expected & expected,
std::vector<std::pair<AccessFlags, Strings>> & access_and_columns)
{
std::vector<std::pair<AccessFlags, Strings>> res;
auto parse_access_and_columns = [&]
{
AccessFlags access_flags;
if (!parseAccessFlags(pos, expected, access_flags))
return false;
Strings columns;
parseColumnNames(pos, expected, columns);
res.emplace_back(access_flags, std::move(columns));
return true;
};
if (!ParserList::parseUtil(pos, expected, parse_access_and_columns, false))
return false;
access_and_columns = std::move(res);
return true;
}
bool parseElements(IParser::Pos & pos, Expected & expected, AccessRightsElements & elements)
{
return IParserBase::wrapParseImpl(pos, [&]
{
AccessRightsElements res_elements;
auto parse_around_on = [&]
{
std::vector<std::pair<AccessFlags, Strings>> access_and_columns;
if (!parseAccessFlagsWithColumns(pos, expected, access_and_columns))
return false;
String database_name, table_name, parameter;
size_t is_global_with_parameter = 0;
for (const auto & elem : access_and_columns)
{
if (elem.first.isGlobalWithParameter())
++is_global_with_parameter;
}
if (!ParserKeyword{Keyword::ON}.ignore(pos, expected))
return false;
bool wildcard = false;
bool default_database = false;
if (is_global_with_parameter && is_global_with_parameter == access_and_columns.size())
{
ASTPtr parameter_ast;
if (!ParserToken{TokenType::Asterisk}.ignore(pos, expected))
{
if (ParserIdentifier{}.parse(pos, parameter_ast, expected))
parameter = getIdentifierName(parameter_ast);
else
return false;
}
if (ParserToken{TokenType::Asterisk}.ignore(pos, expected))
wildcard = true;
}
else if (!parseDatabaseAndTableNameOrAsterisks(pos, expected, database_name, table_name, wildcard, default_database))
return false;
for (auto & [access_flags, columns] : access_and_columns)
{
AccessRightsElement element;
element.access_flags = access_flags;
element.columns = std::move(columns);
element.database = database_name;
element.table = table_name;
element.parameter = parameter;
element.wildcard = wildcard;
element.default_database = default_database;
res_elements.emplace_back(std::move(element));
}
return true;
};
if (!ParserList::parseUtil(pos, expected, parse_around_on, false))
return false;
elements = std::move(res_elements);
return true;
});
}
void throwIfNotGrantable(AccessRightsElements & elements)
{
std::erase_if(elements, [](AccessRightsElement & element)
{
if (element.empty())
return true;
auto old_flags = element.access_flags;
element.eraseNonGrantable();
if (!element.empty())
return false;
if (!element.anyColumn())
throw Exception(ErrorCodes::INVALID_GRANT, "{} cannot check grant on the column level", old_flags.toString());
else if (!element.anyTable())
throw Exception(ErrorCodes::INVALID_GRANT, "{} cannot check grant on the table level", old_flags.toString());
else if (!element.anyDatabase())
throw Exception(ErrorCodes::INVALID_GRANT, "{} cannot check grant on the database level", old_flags.toString());
else if (!element.anyParameter())
throw Exception(ErrorCodes::INVALID_GRANT, "{} cannot check grant on the global with parameter level", old_flags.toString());
else
throw Exception(ErrorCodes::INVALID_GRANT, "{} cannot check grant", old_flags.toString());
});
}
}
bool ParserCheckGrantQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) bool ParserCheckGrantQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{ {
if (!ParserKeyword{Keyword::CHECK_GRANT}.ignore(pos, expected)) if (!ParserKeyword{Keyword::CHECK_GRANT}.ignore(pos, expected))
return false; return false;
AccessRightsElements elements; AccessRightsElements elements;
if (!parseAccessRightsElementsWithoutOptions(pos, expected, elements))
if (!parseElements(pos, expected, elements))
return false; return false;
throwIfNotGrantable(elements); elements.throwIfNotGrantable();
auto query = std::make_shared<ASTCheckGrantQuery>(); auto query = std::make_shared<ASTCheckGrantQuery>();
node = query; node = query;

View File

@ -1,190 +1,29 @@
#include <Parsers/ASTIdentifier_fwd.h> #include <Parsers/Access/ParserGrantQuery.h>
#include <Parsers/ASTLiteral.h>
#include <Access/Common/AccessRightsElement.h>
#include <Parsers/ASTQueryWithOnCluster.h>
#include <Parsers/Access/ASTGrantQuery.h> #include <Parsers/Access/ASTGrantQuery.h>
#include <Parsers/Access/ASTRolesOrUsersSet.h> #include <Parsers/Access/ASTRolesOrUsersSet.h>
#include <Parsers/Access/ParserGrantQuery.h>
#include <Parsers/Access/ParserRolesOrUsersSet.h> #include <Parsers/Access/ParserRolesOrUsersSet.h>
#include <Parsers/ExpressionElementParsers.h> #include <Parsers/Access/parseAccessRightsElements.h>
#include <Parsers/ExpressionListParsers.h> #include <Parsers/CommonParsers.h>
#include <Parsers/parseDatabaseAndTableName.h> #include <Parsers/parseDatabaseAndTableName.h>
#include <boost/algorithm/string/predicate.hpp>
#include <boost/range/algorithm_ext/erase.hpp>
namespace DB namespace DB
{ {
namespace ErrorCodes namespace ErrorCodes
{ {
extern const int INVALID_GRANT;
extern const int SYNTAX_ERROR; extern const int SYNTAX_ERROR;
} }
namespace namespace
{ {
bool parseAccessFlags(IParser::Pos & pos, Expected & expected, AccessFlags & access_flags)
{
static constexpr auto is_one_of_access_type_words = [](IParser::Pos & pos_)
{
if (pos_->type != TokenType::BareWord)
return false;
std::string_view word{pos_->begin, pos_->size()};
return !(boost::iequals(word, toStringView(Keyword::ON)) || boost::iequals(word, toStringView(Keyword::TO)) || boost::iequals(word, toStringView(Keyword::FROM)));
};
expected.add(pos, "access type");
return IParserBase::wrapParseImpl(pos, [&]
{
if (!is_one_of_access_type_words(pos))
return false;
String str;
do
{
if (!str.empty())
str += " ";
str += std::string_view(pos->begin, pos->size());
++pos;
}
while (is_one_of_access_type_words(pos));
try
{
access_flags = AccessFlags{str};
}
catch (...)
{
return false;
}
return true;
});
}
bool parseColumnNames(IParser::Pos & pos, Expected & expected, Strings & columns)
{
return IParserBase::wrapParseImpl(pos, [&]
{
if (!ParserToken{TokenType::OpeningRoundBracket}.ignore(pos, expected))
return false;
ASTPtr ast;
if (!ParserList{std::make_unique<ParserIdentifier>(), std::make_unique<ParserToken>(TokenType::Comma), false}.parse(pos, ast, expected))
return false;
Strings res_columns;
for (const auto & child : ast->children)
res_columns.emplace_back(getIdentifierName(child));
if (!ParserToken{TokenType::ClosingRoundBracket}.ignore(pos, expected))
return false;
columns = std::move(res_columns);
return true;
});
}
bool parseAccessFlagsWithColumns(IParser::Pos & pos, Expected & expected,
std::vector<std::pair<AccessFlags, Strings>> & access_and_columns)
{
std::vector<std::pair<AccessFlags, Strings>> res;
auto parse_access_and_columns = [&]
{
AccessFlags access_flags;
if (!parseAccessFlags(pos, expected, access_flags))
return false;
Strings columns;
parseColumnNames(pos, expected, columns);
res.emplace_back(access_flags, std::move(columns));
return true;
};
if (!ParserList::parseUtil(pos, expected, parse_access_and_columns, false))
return false;
access_and_columns = std::move(res);
return true;
}
bool parseElementsWithoutOptions(IParser::Pos & pos, Expected & expected, AccessRightsElements & elements)
{
return IParserBase::wrapParseImpl(pos, [&]
{
AccessRightsElements res_elements;
auto parse_around_on = [&]
{
std::vector<std::pair<AccessFlags, Strings>> access_and_columns;
if (!parseAccessFlagsWithColumns(pos, expected, access_and_columns))
return false;
String database_name, table_name, parameter;
size_t is_global_with_parameter = 0;
for (const auto & elem : access_and_columns)
{
if (elem.first.isGlobalWithParameter())
++is_global_with_parameter;
}
if (!ParserKeyword{Keyword::ON}.ignore(pos, expected))
return false;
bool wildcard = false;
bool default_database = false;
if (is_global_with_parameter && is_global_with_parameter == access_and_columns.size())
{
ASTPtr parameter_ast;
if (!ParserToken{TokenType::Asterisk}.ignore(pos, expected))
{
if (ParserIdentifier{}.parse(pos, parameter_ast, expected))
parameter = getIdentifierName(parameter_ast);
else
return false;
}
if (ParserToken{TokenType::Asterisk}.ignore(pos, expected))
wildcard = true;
}
else if (!parseDatabaseAndTableNameOrAsterisks(pos, expected, database_name, table_name, wildcard, default_database))
return false;
for (auto & [access_flags, columns] : access_and_columns)
{
if (wildcard && !columns.empty())
return false;
AccessRightsElement element;
element.access_flags = access_flags;
element.columns = std::move(columns);
element.database = database_name;
element.table = table_name;
element.parameter = parameter;
element.wildcard = wildcard;
element.default_database = default_database;
res_elements.emplace_back(std::move(element));
}
return true;
};
if (!ParserList::parseUtil(pos, expected, parse_around_on, false))
return false;
elements = std::move(res_elements);
return true;
});
}
bool parseCurrentGrants(IParser::Pos & pos, Expected & expected, AccessRightsElements & elements) bool parseCurrentGrants(IParser::Pos & pos, Expected & expected, AccessRightsElements & elements)
{ {
if (ParserToken(TokenType::OpeningRoundBracket).ignore(pos, expected)) if (ParserToken(TokenType::OpeningRoundBracket).ignore(pos, expected))
{ {
if (!parseElementsWithoutOptions(pos, expected, elements)) if (!parseAccessRightsElementsWithoutOptions(pos, expected, elements))
return false; return false;
if (!ParserToken(TokenType::ClosingRoundBracket).ignore(pos, expected)) if (!ParserToken(TokenType::ClosingRoundBracket).ignore(pos, expected))
@ -214,30 +53,6 @@ namespace
return true; return true;
} }
void throwIfNotGrantable(AccessRightsElements & elements)
{
std::erase_if(elements, [](AccessRightsElement & element)
{
if (element.empty())
return true;
auto old_flags = element.access_flags;
element.eraseNonGrantable();
if (!element.empty())
return false;
if (!element.anyColumn())
throw Exception(ErrorCodes::INVALID_GRANT, "{} cannot be granted on the column level", old_flags.toString());
if (!element.anyTable())
throw Exception(ErrorCodes::INVALID_GRANT, "{} cannot be granted on the table level", old_flags.toString());
if (!element.anyDatabase())
throw Exception(ErrorCodes::INVALID_GRANT, "{} cannot be granted on the database level", old_flags.toString());
if (!element.anyParameter())
throw Exception(ErrorCodes::INVALID_GRANT, "{} cannot be granted on the global with parameter level", old_flags.toString());
throw Exception(ErrorCodes::INVALID_GRANT, "{} cannot be granted", old_flags.toString());
});
}
bool parseRoles(IParser::Pos & pos, Expected & expected, bool is_revoke, bool id_mode, std::shared_ptr<ASTRolesOrUsersSet> & roles) bool parseRoles(IParser::Pos & pos, Expected & expected, bool is_revoke, bool id_mode, std::shared_ptr<ASTRolesOrUsersSet> & roles)
{ {
@ -323,7 +138,7 @@ bool ParserGrantQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
} }
else else
{ {
if (!parseElementsWithoutOptions(pos, expected, elements) && !parseRoles(pos, expected, is_revoke, attach_mode, roles)) if (!parseAccessRightsElementsWithoutOptions(pos, expected, elements) && !parseRoles(pos, expected, is_revoke, attach_mode, roles))
return false; return false;
} }
@ -373,13 +188,8 @@ bool ParserGrantQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
replace_access = true; replace_access = true;
} }
if (!is_revoke) if (!is_revoke && !attach_mode)
{ elements.throwIfNotGrantable();
if (attach_mode)
elements.eraseNonGrantable();
else
throwIfNotGrantable(elements);
}
auto query = std::make_shared<ASTGrantQuery>(); auto query = std::make_shared<ASTGrantQuery>();
node = query; node = query;

View File

@ -0,0 +1,178 @@
#include <Parsers/Access/parseAccessRightsElements.h>
#include <Access/Common/AccessRightsElement.h>
#include <Parsers/ASTIdentifier_fwd.h>
#include <Parsers/CommonParsers.h>
#include <Parsers/ExpressionListParsers.h>
#include <Parsers/IAST.h>
#include <Parsers/IParserBase.h>
#include <Parsers/parseDatabaseAndTableName.h>
#include <boost/algorithm/string/predicate.hpp>
namespace DB
{
namespace
{
bool parseColumnNames(IParser::Pos & pos, Expected & expected, Strings & columns)
{
return IParserBase::wrapParseImpl(pos, [&]
{
if (!ParserToken{TokenType::OpeningRoundBracket}.ignore(pos, expected))
return false;
ASTPtr ast;
if (!ParserList{std::make_unique<ParserIdentifier>(), std::make_unique<ParserToken>(TokenType::Comma), false}.parse(pos, ast, expected))
return false;
Strings res_columns;
for (const auto & child : ast->children)
res_columns.emplace_back(getIdentifierName(child));
if (!ParserToken{TokenType::ClosingRoundBracket}.ignore(pos, expected))
return false;
columns = std::move(res_columns);
return true;
});
}
}
bool parseAccessFlags(IParser::Pos & pos, Expected & expected, AccessFlags & access_flags)
{
static constexpr auto is_one_of_access_type_words = [](IParser::Pos & pos_)
{
if (pos_->type != TokenType::BareWord)
return false;
std::string_view word{pos_->begin, pos_->size()};
return !(boost::iequals(word, toStringView(Keyword::ON)) || boost::iequals(word, toStringView(Keyword::TO)) || boost::iequals(word, toStringView(Keyword::FROM)));
};
expected.add(pos, "access type");
return IParserBase::wrapParseImpl(pos, [&]
{
if (!is_one_of_access_type_words(pos))
return false;
String str;
do
{
if (!str.empty())
str += " ";
str += std::string_view(pos->begin, pos->size());
++pos;
}
while (is_one_of_access_type_words(pos));
try
{
access_flags = AccessFlags{str};
}
catch (...)
{
return false;
}
return true;
});
}
bool parseAccessFlagsWithColumns(IParser::Pos & pos, Expected & expected,
std::vector<std::pair<AccessFlags, Strings>> & access_and_columns)
{
std::vector<std::pair<AccessFlags, Strings>> res;
auto parse_access_and_columns = [&]
{
AccessFlags access_flags;
if (!parseAccessFlags(pos, expected, access_flags))
return false;
Strings columns;
parseColumnNames(pos, expected, columns);
res.emplace_back(access_flags, std::move(columns));
return true;
};
if (!ParserList::parseUtil(pos, expected, parse_access_and_columns, false))
return false;
access_and_columns = std::move(res);
return true;
}
bool parseAccessRightsElementsWithoutOptions(IParser::Pos & pos, Expected & expected, AccessRightsElements & elements)
{
return IParserBase::wrapParseImpl(pos, [&]
{
AccessRightsElements res_elements;
auto parse_around_on = [&]
{
std::vector<std::pair<AccessFlags, Strings>> access_and_columns;
if (!parseAccessFlagsWithColumns(pos, expected, access_and_columns))
return false;
String database_name, table_name, parameter;
size_t is_global_with_parameter = 0;
for (const auto & elem : access_and_columns)
{
if (elem.first.isGlobalWithParameter())
++is_global_with_parameter;
}
if (!ParserKeyword{Keyword::ON}.ignore(pos, expected))
return false;
bool wildcard = false;
bool default_database = false;
if (is_global_with_parameter && is_global_with_parameter == access_and_columns.size())
{
ASTPtr parameter_ast;
if (!ParserToken{TokenType::Asterisk}.ignore(pos, expected))
{
if (ParserIdentifier{}.parse(pos, parameter_ast, expected))
parameter = getIdentifierName(parameter_ast);
else
return false;
}
if (ParserToken{TokenType::Asterisk}.ignore(pos, expected))
wildcard = true;
}
else if (!parseDatabaseAndTableNameOrAsterisks(pos, expected, database_name, table_name, wildcard, default_database))
return false;
for (auto & [access_flags, columns] : access_and_columns)
{
if (wildcard && !columns.empty())
return false;
AccessRightsElement element;
element.access_flags = access_flags;
element.columns = std::move(columns);
element.database = database_name;
element.table = table_name;
element.parameter = parameter;
element.wildcard = wildcard;
element.default_database = default_database;
res_elements.emplace_back(std::move(element));
}
return true;
};
if (!ParserList::parseUtil(pos, expected, parse_around_on, false))
return false;
elements = std::move(res_elements);
return true;
});
}
}

View File

@ -0,0 +1,24 @@
#pragma once
#include <Core/Types.h>
#include <Parsers/IParser.h>
namespace DB
{
class AccessFlags;
class AccessRightsElements;
/// Parses a list of privileges, for example "SELECT, INSERT".
bool parseAccessFlags(IParser::Pos & pos, Expected & expected, AccessFlags & access_flags);
/// Parses a list of privileges which can be written with lists of columns.
/// For example "SELECT(a), INSERT(b, c), DROP".
bool parseAccessFlagsWithColumns(IParser::Pos & pos, Expected & expected,
std::vector<std::pair<AccessFlags, Strings>> & access_and_columns);
/// Parses a list of privileges with columns and tables or databases or wildcards,
/// For examples, "SELECT(a), INSERT(b,c) ON mydb.mytable, DROP ON mydb.*"
bool parseAccessRightsElementsWithoutOptions(IParser::Pos & pos, Expected & expected, AccessRightsElements & elements);
}

View File

@ -11,11 +11,13 @@
#include <DataTypes/DataTypeNullable.h> #include <DataTypes/DataTypeNullable.h>
#include <Columns/getLeastSuperColumn.h> #include <Columns/getLeastSuperColumn.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnSet.h> #include <Columns/ColumnSet.h>
#include <IO/WriteBufferFromString.h> #include <IO/WriteBufferFromString.h>
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Functions/IFunctionAdaptors.h>
#include <Functions/indexHint.h> #include <Functions/indexHint.h>
#include <Storages/StorageDummy.h> #include <Storages/StorageDummy.h>

View File

@ -2,6 +2,7 @@
#if USE_ARROW || USE_PARQUET #if USE_ARROW || USE_PARQUET
#include <Core/DecimalFunctions.h>
#include <Columns/ColumnFixedString.h> #include <Columns/ColumnFixedString.h>
#include <Columns/ColumnNullable.h> #include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h> #include <Columns/ColumnString.h>

View File

@ -7,6 +7,7 @@
#include <Common/JSONBuilder.h> #include <Common/JSONBuilder.h>
#include <DataTypes/DataTypeFactory.h> #include <DataTypes/DataTypeFactory.h>
#include <DataTypes/DataTypeLowCardinality.h> #include <DataTypes/DataTypeLowCardinality.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
#include <Functions/IFunction.h> #include <Functions/IFunction.h>
#include <stack> #include <stack>

View File

@ -2350,7 +2350,7 @@ struct WindowFunctionLagLeadInFrame final : public StatelessWindowFunction
} }
}; };
return func_cast->execute(arguments, argument_types[0], columns[idx[2]]->size()); return func_cast->execute(arguments, argument_types[0], columns[idx[2]]->size(), /* dry_run = */ false);
} }
static DataTypePtr createResultType(const DataTypes & argument_types_, const std::string & name_) static DataTypePtr createResultType(const DataTypes & argument_types_, const std::string & name_)

View File

@ -17,6 +17,7 @@
#include <Functions/indexHint.h> #include <Functions/indexHint.h>
#include <Functions/CastOverloadResolver.h> #include <Functions/CastOverloadResolver.h>
#include <Functions/IFunction.h> #include <Functions/IFunction.h>
#include <Functions/IFunctionAdaptors.h>
#include <Functions/IFunctionDateOrDateTime.h> #include <Functions/IFunctionDateOrDateTime.h>
#include <Functions/geometryConverters.h> #include <Functions/geometryConverters.h>
#include <Common/FieldVisitorToString.h> #include <Common/FieldVisitorToString.h>
@ -905,7 +906,7 @@ static Field applyFunctionForField(
{ arg_type->createColumnConst(1, arg_value), arg_type, "x" }, { arg_type->createColumnConst(1, arg_value), arg_type, "x" },
}; };
auto col = func->execute(columns, func->getResultType(), 1); auto col = func->execute(columns, func->getResultType(), 1, /* dry_run = */ false);
return (*col)[0]; return (*col)[0];
} }
@ -939,7 +940,7 @@ static FieldRef applyFunction(const FunctionBasePtr & func, const DataTypePtr &
/// When cache is missed, we calculate the whole column where the field comes from. This will avoid repeated calculation. /// When cache is missed, we calculate the whole column where the field comes from. This will avoid repeated calculation.
ColumnsWithTypeAndName args{(*columns)[field.column_idx]}; ColumnsWithTypeAndName args{(*columns)[field.column_idx]};
field.columns->emplace_back(ColumnWithTypeAndName {nullptr, func->getResultType(), result_name}); field.columns->emplace_back(ColumnWithTypeAndName {nullptr, func->getResultType(), result_name});
(*columns)[result_idx].column = func->execute(args, (*columns)[result_idx].type, columns->front().column->size()); (*columns)[result_idx].column = func->execute(args, (*columns)[result_idx].type, columns->front().column->size(), /* dry_run = */ false);
} }
return {field.columns, field.row_idx, result_idx}; return {field.columns, field.row_idx, result_idx};
@ -1012,7 +1013,7 @@ bool applyFunctionChainToColumn(
return false; return false;
result_column = castColumnAccurate({result_column, result_type, ""}, argument_type); result_column = castColumnAccurate({result_column, result_type, ""}, argument_type);
result_column = func->execute({{result_column, argument_type, ""}}, func->getResultType(), result_column->size()); result_column = func->execute({{result_column, argument_type, ""}}, func->getResultType(), result_column->size(), /* dry_run = */ false);
result_type = func->getResultType(); result_type = func->getResultType();
// Transforming nullable columns to the nested ones, in case no nulls found // Transforming nullable columns to the nested ones, in case no nulls found

Some files were not shown because too many files have changed in this diff Show More