Cache cast function in set during execution (#55712)

* Cache cast function in set during execution

Signed-off-by: Duc Canh Le <duccanh.le@ahrefs.com>

* minor fix for performance test

Signed-off-by: Duc Canh Le <duccanh.le@ahrefs.com>

* Update src/Interpreters/castColumn.cpp

Co-authored-by: Nikita Taranov <nickita.taranov@gmail.com>

* improvement

Signed-off-by: Duc Canh Le <duccanh.le@ahrefs.com>

* fix use-after-free

Signed-off-by: Duc Canh Le <duccanh.le@ahrefs.com>

---------

Signed-off-by: Duc Canh Le <duccanh.le@ahrefs.com>
Co-authored-by: Nikita Taranov <nickita.taranov@gmail.com>
This commit is contained in:
Duc Canh Le 2023-10-23 19:31:44 +08:00 committed by GitHub
parent 3d8875a342
commit 5923e1b116
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 69 additions and 19 deletions

View File

@ -324,11 +324,11 @@ ColumnPtr Set::execute(const ColumnsWithTypeAndName & columns, bool negative) co
if (!transform_null_in && data_types[i]->canBeInsideNullable()) if (!transform_null_in && data_types[i]->canBeInsideNullable())
{ {
result = castColumnAccurateOrNull(column_to_cast, data_types[i]); result = castColumnAccurateOrNull(column_to_cast, data_types[i], cast_cache.get());
} }
else else
{ {
result = castColumnAccurate(column_to_cast, data_types[i]); result = castColumnAccurate(column_to_cast, data_types[i], cast_cache.get());
} }
materialized_columns.emplace_back() = result; materialized_columns.emplace_back() = result;

View File

@ -9,6 +9,7 @@
#include <Storages/MergeTree/BoolMask.h> #include <Storages/MergeTree/BoolMask.h>
#include <Common/SharedMutex.h> #include <Common/SharedMutex.h>
#include <Interpreters/castColumn.h>
namespace DB namespace DB
@ -33,9 +34,9 @@ public:
/// This is needed for subsequent use for index. /// This is needed for subsequent use for index.
Set(const SizeLimits & limits_, size_t max_elements_to_fill_, bool transform_null_in_) Set(const SizeLimits & limits_, size_t max_elements_to_fill_, bool transform_null_in_)
: log(&Poco::Logger::get("Set")), : log(&Poco::Logger::get("Set")),
limits(limits_), max_elements_to_fill(max_elements_to_fill_), transform_null_in(transform_null_in_) limits(limits_), max_elements_to_fill(max_elements_to_fill_), transform_null_in(transform_null_in_),
{ cast_cache(std::make_unique<InternalCastFunctionCache>())
} {}
/** Set can be created either from AST or from a stream of data (subquery result). /** Set can be created either from AST or from a stream of data (subquery result).
*/ */
@ -142,6 +143,10 @@ private:
*/ */
mutable SharedMutex rwlock; mutable SharedMutex rwlock;
/// A cache for cast functions (if any) to avoid rebuilding cast functions
/// for every call to `execute`
mutable std::unique_ptr<InternalCastFunctionCache> cast_cache;
template <typename Method> template <typename Method>
void insertFromBlockImpl( void insertFromBlockImpl(
Method & method, Method & method,

View File

@ -7,24 +7,29 @@ namespace DB
{ {
template <CastType cast_type = CastType::nonAccurate> template <CastType cast_type = CastType::nonAccurate>
static ColumnPtr castColumn(const ColumnWithTypeAndName & arg, const DataTypePtr & type) static ColumnPtr castColumn(const ColumnWithTypeAndName & arg, const DataTypePtr & type, InternalCastFunctionCache * cache = nullptr)
{ {
if (arg.type->equals(*type) && cast_type != CastType::accurateOrNull) if (arg.type->equals(*type) && cast_type != CastType::accurateOrNull)
return arg.column; return arg.column;
const auto from_name = arg.type->getName();
const auto to_name = type->getName();
ColumnsWithTypeAndName arguments ColumnsWithTypeAndName arguments
{ {
arg, arg,
{ {
DataTypeString().createColumnConst(arg.column->size(), type->getName()), DataTypeString().createColumnConst(arg.column->size(), to_name),
std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>(),
"" ""
} }
}; };
auto get_cast_func = [&arguments]
{
FunctionOverloadResolverPtr func_builder_cast = CastInternalOverloadResolver<cast_type>::createImpl();
return func_builder_cast->build(arguments);
};
FunctionOverloadResolverPtr func_builder_cast = CastInternalOverloadResolver<cast_type>::createImpl(); FunctionBasePtr func_cast = cache ? cache->getOrSet(cast_type, from_name, to_name, std::move(get_cast_func)) : get_cast_func();
auto func_cast = func_builder_cast->build(arguments);
if constexpr (cast_type == CastType::accurateOrNull) if constexpr (cast_type == CastType::accurateOrNull)
{ {
@ -36,19 +41,19 @@ static ColumnPtr castColumn(const ColumnWithTypeAndName & arg, const DataTypePtr
} }
} }
ColumnPtr castColumn(const ColumnWithTypeAndName & arg, const DataTypePtr & type) ColumnPtr castColumn(const ColumnWithTypeAndName & arg, const DataTypePtr & type, InternalCastFunctionCache * cache)
{ {
return castColumn<CastType::nonAccurate>(arg, type); return castColumn<CastType::nonAccurate>(arg, type, cache);
} }
ColumnPtr castColumnAccurate(const ColumnWithTypeAndName & arg, const DataTypePtr & type) ColumnPtr castColumnAccurate(const ColumnWithTypeAndName & arg, const DataTypePtr & type, InternalCastFunctionCache * cache)
{ {
return castColumn<CastType::accurate>(arg, type); return castColumn<CastType::accurate>(arg, type, cache);
} }
ColumnPtr castColumnAccurateOrNull(const ColumnWithTypeAndName & arg, const DataTypePtr & type) ColumnPtr castColumnAccurateOrNull(const ColumnWithTypeAndName & arg, const DataTypePtr & type, InternalCastFunctionCache * cache)
{ {
return castColumn<CastType::accurateOrNull>(arg, type); return castColumn<CastType::accurateOrNull>(arg, type, cache);
} }
} }

View File

@ -1,12 +1,34 @@
#pragma once #pragma once
#include <tuple>
#include <Core/ColumnWithTypeAndName.h> #include <Core/ColumnWithTypeAndName.h>
#include <Functions/FunctionsConversion.h>
namespace DB namespace DB
{ {
ColumnPtr castColumn(const ColumnWithTypeAndName & arg, const DataTypePtr & type); struct InternalCastFunctionCache
ColumnPtr castColumnAccurate(const ColumnWithTypeAndName & arg, const DataTypePtr & type); {
ColumnPtr castColumnAccurateOrNull(const ColumnWithTypeAndName & arg, const DataTypePtr & type); private:
/// Maps <cast_type, from_type, to_type> -> cast functions
/// Doesn't own key, never refer to key after inserted
std::map<std::tuple<CastType, String, String>, FunctionBasePtr> impl;
mutable std::mutex mutex;
public:
template<typename Getter>
FunctionBasePtr getOrSet(CastType cast_type, const String & from, const String & to, Getter && getter)
{
std::lock_guard lock{mutex};
auto key = std::forward_as_tuple(cast_type, from, to);
auto it = impl.find(key);
if (it == impl.end())
it = impl.emplace(key, getter()).first;
return it->second;
}
};
ColumnPtr castColumn(const ColumnWithTypeAndName & arg, const DataTypePtr & type, InternalCastFunctionCache * cache = nullptr);
ColumnPtr castColumnAccurate(const ColumnWithTypeAndName & arg, const DataTypePtr & type, InternalCastFunctionCache * cache = nullptr);
ColumnPtr castColumnAccurateOrNull(const ColumnWithTypeAndName & arg, const DataTypePtr & type, InternalCastFunctionCache * cache = nullptr);
} }

View File

@ -0,0 +1,18 @@
<test>
<!-- high cardinality -->
<create_query>
CREATE TABLE iso_3166_1_alpha_2
(
`c` Enum8('LI' = -128, 'LT' = -127, 'LU' = -126, 'MO' = -125, 'MK' = -124, 'MG' = -123, 'MW' = -122, 'MY' = -121, 'MV' = -120, 'ML' = -119, 'MT' = -118, 'MH' = -117, 'MQ' = -116, 'MR' = -115, 'MU' = -114, 'YT' = -113, 'MX' = -112, 'FM' = -111, 'MD' = -110, 'MC' = -109, 'MN' = -108, 'ME' = -107, 'MS' = -106, 'MA' = -105, 'MZ' = -104, 'MM' = -103, 'NA' = -102, 'NR' = -101, 'NP' = -100, 'NL' = -99, 'NC' = -98, 'NZ' = -97, 'NI' = -96, 'NE' = -95, 'NG' = -94, 'NU' = -93, 'NF' = -92, 'MP' = -91, 'NO' = -90, 'OM' = -89, 'PK' = -88, 'PW' = -87, 'PS' = -86, 'PA' = -85, 'PG' = -84, 'PY' = -83, 'PE' = -82, 'PH' = -81, 'PN' = -80, 'PL' = -79, 'PT' = -78, 'PR' = -77, 'QA' = -76, 'RE' = -75, 'RO' = -74, 'RU' = -73, 'RW' = -72, 'BL' = -71, 'SH' = -70, 'KN' = -69, 'LC' = -68, 'MF' = -67, 'PM' = -66, 'VC' = -65, 'WS' = -64, 'SM' = -63, 'ST' = -62, 'SA' = -61, 'SN' = -60, 'RS' = -59, 'SC' = -58, 'SL' = -57, 'SG' = -56, 'SX' = -55, 'SK' = -54, 'SI' = -53, 'SB' = -52, 'SO' = -51, 'ZA' = -50, 'GS' = -49, 'SS' = -48, 'ES' = -47, 'LK' = -46, 'SD' = -45, 'SR' = -44, 'SJ' = -43, 'SZ' = -42, 'SE' = -41, 'CH' = -40, 'SY' = -39, 'TW' = -38, 'TJ' = -37, 'TZ' = -36, 'TH' = -35, 'TL' = -34, 'TG' = -33, 'TK' = -32, 'TO' = -31, 'TT' = -30, 'TN' = -29, 'TR' = -28, 'TM' = -27, 'TC' = -26, 'TV' = -25, 'UG' = -24, 'UA' = -23, 'AE' = -22, 'GB' = -21, 'UM' = -20, 'US' = -19, 'UY' = -18, 'UZ' = -17, 'VU' = -16, 'VE' = -15, 'VN' = -14, 'VG' = -13, 'VI' = -12, 'WF' = -11, 'EH' = -10, 'YE' = -9, 'ZM' = -8, 'ZW' = -7, 'OTHER' = 0, 'AF' = 1, 'AX' = 2, 'AL' = 3, 'DZ' = 4, 'AS' = 5, 'AD' = 6, 'AO' = 7, 'AI' = 8, 'AQ' = 9, 'AG' = 10, 'AR' = 11, 'AM' = 12, 'AW' = 13, 'AU' = 14, 'AT' = 15, 'AZ' = 16, 'BS' = 17, 'BH' = 18, 'BD' = 19, 'BB' = 20, 'BY' = 21, 'BE' = 22, 'BZ' = 23, 'BJ' = 24, 'BM' = 25, 'BT' = 26, 'BO' = 27, 'BQ' = 28, 'BA' = 29, 'BW' = 30, 'BV' = 31, 'BR' = 32, 'IO' = 33, 'BN' = 34, 'BG' = 35, 'BF' = 36, 'BI' = 37, 'CV' = 38, 'KH' = 39, 'CM' = 40, 'CA' = 41, 'KY' = 42, 'CF' = 43, 'TD' = 44, 'CL' = 45, 'CN' = 46, 'CX' = 47, 'CC' = 48, 'CO' = 49, 'KM' = 50, 'CD' = 51, 'CG' = 52, 'CK' = 53, 'CR' = 54, 'CI' = 55, 'HR' = 56, 'CU' = 57, 'CW' = 58, 'CY' = 59, 'CZ' = 60, 'DK' = 61, 'DJ' = 62, 'DM' = 63, 'DO' = 64, 'EC' = 65, 'EG' = 66, 'SV' = 67, 'GQ' = 68, 'ER' = 69, 'EE' = 70, 'ET' = 71, 'FK' = 72, 'FO' = 73, 'FJ' = 74, 'FI' = 75, 'FR' = 76, 'GF' = 77, 'PF' = 78, 'TF' = 79, 'GA' = 80, 'GM' = 81, 'GE' = 82, 'DE' = 83, 'GH' = 84, 'GI' = 85, 'GR' = 86, 'GL' = 87, 'GD' = 88, 'GP' = 89, 'GU' = 90, 'GT' = 91, 'GG' = 92, 'GN' = 93, 'GW' = 94, 'GY' = 95, 'HT' = 96, 'HM' = 97, 'VA' = 98, 'HN' = 99, 'HK' = 100, 'HU' = 101, 'IS' = 102, 'IN' = 103, 'ID' = 104, 'IR' = 105, 'IQ' = 106, 'IE' = 107, 'IM' = 108, 'IL' = 109, 'IT' = 110, 'JM' = 111, 'JP' = 112, 'JE' = 113, 'JO' = 114, 'KZ' = 115, 'KE' = 116, 'KI' = 117, 'KP' = 118, 'KR' = 119, 'KW' = 120, 'KG' = 121, 'LA' = 122, 'LV' = 123, 'LB' = 124, 'LS' = 125, 'LR' = 126, 'LY' = 127)
)
ENGINE = MergeTree
ORDER BY tuple()
SETTINGS index_granularity = 8192
</create_query>
<fill_query>INSERT INTO iso_3166_1_alpha_2 SELECT (rand(number) % 256) - 128 FROM numbers(200000000)</fill_query>
<fill_query>OPTIMIZE TABLE iso_3166_1_alpha_2 FINAL</fill_query>
<query>SELECT count() FROM iso_3166_1_alpha_2 WHERE c NOT IN ('CU', 'BN', 'VI', 'US', 'AQ', 'AG', 'AR', 'AM', 'AW', 'AU', 'AT', 'AZ', 'BS', 'BH', 'BD', 'BB', 'BY', 'BE') FORMAT Null SETTINGS max_threads = 1</query>
<drop_query>DROP TABLE IF EXISTS iso_3166_1_alpha_2</drop_query>
</test>