mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-22 07:31:57 +00:00
Merge branch 'master' of github.com:ClickHouse/ClickHouse into filter-large-translation-units
This commit is contained in:
commit
c2127b05f6
2
contrib/arrow
vendored
2
contrib/arrow
vendored
@ -1 +1 @@
|
||||
Subproject commit 1d93838f69a802639ca144ea5704a98e2481810d
|
||||
Subproject commit ba5c67934e8274d649befcffab56731632dc5253
|
@ -109,7 +109,6 @@ set (ORC_CXX_HAS_CSTDINT 1)
|
||||
set (ORC_CXX_HAS_THREAD_LOCAL 1)
|
||||
|
||||
include(orc_check.cmake)
|
||||
configure_file("${ORC_INCLUDE_DIR}/orc/orc-config.hh.in" "${ORC_BUILD_INCLUDE_DIR}/orc/orc-config.hh")
|
||||
configure_file("${ORC_SOURCE_SRC_DIR}/Adaptor.hh.in" "${ORC_BUILD_INCLUDE_DIR}/Adaptor.hh")
|
||||
|
||||
|
||||
@ -198,7 +197,9 @@ target_link_libraries(_orc PRIVATE
|
||||
ch_contrib::snappy
|
||||
ch_contrib::zlib
|
||||
ch_contrib::zstd)
|
||||
target_include_directories(_orc SYSTEM BEFORE PUBLIC ${ORC_INCLUDE_DIR})
|
||||
target_include_directories(_orc SYSTEM BEFORE PUBLIC
|
||||
${ORC_INCLUDE_DIR}
|
||||
"${ClickHouse_SOURCE_DIR}/contrib/arrow-cmake/cpp/src/orc/c++/include")
|
||||
target_include_directories(_orc SYSTEM BEFORE PUBLIC ${ORC_BUILD_INCLUDE_DIR})
|
||||
target_include_directories(_orc SYSTEM PRIVATE
|
||||
${ORC_SOURCE_SRC_DIR}
|
||||
@ -212,8 +213,6 @@ target_include_directories(_orc SYSTEM PRIVATE
|
||||
|
||||
set(LIBRARY_DIR "${ClickHouse_SOURCE_DIR}/contrib/arrow/cpp/src/arrow")
|
||||
|
||||
configure_file("${LIBRARY_DIR}/util/config.h.cmake" "${CMAKE_CURRENT_BINARY_DIR}/cpp/src/arrow/util/config.h")
|
||||
|
||||
# arrow/cpp/src/arrow/CMakeLists.txt (ARROW_SRCS + ARROW_COMPUTE + ARROW_IPC)
|
||||
set(ARROW_SRCS
|
||||
"${LIBRARY_DIR}/array/array_base.cc"
|
||||
@ -230,6 +229,8 @@ set(ARROW_SRCS
|
||||
"${LIBRARY_DIR}/array/builder_nested.cc"
|
||||
"${LIBRARY_DIR}/array/builder_primitive.cc"
|
||||
"${LIBRARY_DIR}/array/builder_union.cc"
|
||||
"${LIBRARY_DIR}/array/builder_run_end.cc"
|
||||
"${LIBRARY_DIR}/array/array_run_end.cc"
|
||||
"${LIBRARY_DIR}/array/concatenate.cc"
|
||||
"${LIBRARY_DIR}/array/data.cc"
|
||||
"${LIBRARY_DIR}/array/diff.cc"
|
||||
@ -309,9 +310,12 @@ set(ARROW_SRCS
|
||||
"${LIBRARY_DIR}/util/debug.cc"
|
||||
"${LIBRARY_DIR}/util/tracing.cc"
|
||||
"${LIBRARY_DIR}/util/atfork_internal.cc"
|
||||
"${LIBRARY_DIR}/util/crc32.cc"
|
||||
"${LIBRARY_DIR}/util/hashing.cc"
|
||||
"${LIBRARY_DIR}/util/ree_util.cc"
|
||||
"${LIBRARY_DIR}/util/union_util.cc"
|
||||
"${LIBRARY_DIR}/vendored/base64.cpp"
|
||||
"${LIBRARY_DIR}/vendored/datetime/tz.cpp"
|
||||
|
||||
"${LIBRARY_DIR}/vendored/musl/strptime.c"
|
||||
"${LIBRARY_DIR}/vendored/uriparser/UriCommon.c"
|
||||
"${LIBRARY_DIR}/vendored/uriparser/UriCompare.c"
|
||||
@ -328,39 +332,20 @@ set(ARROW_SRCS
|
||||
"${LIBRARY_DIR}/vendored/uriparser/UriRecompose.c"
|
||||
"${LIBRARY_DIR}/vendored/uriparser/UriResolve.c"
|
||||
"${LIBRARY_DIR}/vendored/uriparser/UriShorten.c"
|
||||
"${LIBRARY_DIR}/vendored/double-conversion/bignum.cc"
|
||||
"${LIBRARY_DIR}/vendored/double-conversion/bignum-dtoa.cc"
|
||||
"${LIBRARY_DIR}/vendored/double-conversion/cached-powers.cc"
|
||||
"${LIBRARY_DIR}/vendored/double-conversion/double-to-string.cc"
|
||||
"${LIBRARY_DIR}/vendored/double-conversion/fast-dtoa.cc"
|
||||
"${LIBRARY_DIR}/vendored/double-conversion/fixed-dtoa.cc"
|
||||
"${LIBRARY_DIR}/vendored/double-conversion/string-to-double.cc"
|
||||
"${LIBRARY_DIR}/vendored/double-conversion/strtod.cc"
|
||||
|
||||
"${LIBRARY_DIR}/compute/api_aggregate.cc"
|
||||
"${LIBRARY_DIR}/compute/api_scalar.cc"
|
||||
"${LIBRARY_DIR}/compute/api_vector.cc"
|
||||
"${LIBRARY_DIR}/compute/cast.cc"
|
||||
"${LIBRARY_DIR}/compute/exec.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/accumulation_queue.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/accumulation_queue.h"
|
||||
"${LIBRARY_DIR}/compute/exec/aggregate.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/aggregate_node.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/asof_join_node.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/bloom_filter.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/exec_plan.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/expression.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/filter_node.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/hash_join.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/hash_join_dict.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/hash_join_node.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/key_hash.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/key_map.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/map_node.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/options.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/order_by_impl.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/partition_util.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/project_node.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/query_context.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/sink_node.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/source_node.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/swiss_join.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/task_util.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/tpch_node.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/union_node.cc"
|
||||
"${LIBRARY_DIR}/compute/exec/util.cc"
|
||||
"${LIBRARY_DIR}/compute/function.cc"
|
||||
"${LIBRARY_DIR}/compute/function_internal.cc"
|
||||
"${LIBRARY_DIR}/compute/kernel.cc"
|
||||
@ -403,8 +388,13 @@ set(ARROW_SRCS
|
||||
"${LIBRARY_DIR}/compute/kernels/vector_select_k.cc"
|
||||
"${LIBRARY_DIR}/compute/kernels/vector_selection.cc"
|
||||
"${LIBRARY_DIR}/compute/kernels/vector_sort.cc"
|
||||
"${LIBRARY_DIR}/compute/kernels/vector_selection_internal.cc"
|
||||
"${LIBRARY_DIR}/compute/kernels/vector_selection_filter_internal.cc"
|
||||
"${LIBRARY_DIR}/compute/kernels/vector_selection_take_internal.cc"
|
||||
"${LIBRARY_DIR}/compute/light_array.cc"
|
||||
"${LIBRARY_DIR}/compute/registry.cc"
|
||||
"${LIBRARY_DIR}/compute/expression.cc"
|
||||
"${LIBRARY_DIR}/compute/ordering.cc"
|
||||
"${LIBRARY_DIR}/compute/row/compare_internal.cc"
|
||||
"${LIBRARY_DIR}/compute/row/encode_internal.cc"
|
||||
"${LIBRARY_DIR}/compute/row/grouper.cc"
|
||||
@ -459,7 +449,7 @@ target_link_libraries(_arrow PUBLIC _orc)
|
||||
add_dependencies(_arrow protoc)
|
||||
|
||||
target_include_directories(_arrow SYSTEM BEFORE PUBLIC ${ARROW_SRC_DIR})
|
||||
target_include_directories(_arrow SYSTEM BEFORE PUBLIC "${CMAKE_CURRENT_BINARY_DIR}/cpp/src")
|
||||
target_include_directories(_arrow SYSTEM BEFORE PUBLIC "${ClickHouse_SOURCE_DIR}/contrib/arrow-cmake/cpp/src")
|
||||
|
||||
target_include_directories(_arrow SYSTEM PRIVATE ${ARROW_SRC_DIR})
|
||||
target_include_directories(_arrow SYSTEM PRIVATE ${HDFS_INCLUDE_DIR})
|
||||
@ -488,10 +478,10 @@ set(PARQUET_SRCS
|
||||
"${LIBRARY_DIR}/exception.cc"
|
||||
"${LIBRARY_DIR}/file_reader.cc"
|
||||
"${LIBRARY_DIR}/file_writer.cc"
|
||||
"${LIBRARY_DIR}/page_index.cc"
|
||||
"${LIBRARY_DIR}/level_conversion.cc"
|
||||
"${LIBRARY_DIR}/level_comparison.cc"
|
||||
"${LIBRARY_DIR}/metadata.cc"
|
||||
"${LIBRARY_DIR}/murmur3.cc"
|
||||
"${LIBRARY_DIR}/platform.cc"
|
||||
"${LIBRARY_DIR}/printer.cc"
|
||||
"${LIBRARY_DIR}/properties.cc"
|
||||
@ -500,6 +490,8 @@ set(PARQUET_SRCS
|
||||
"${LIBRARY_DIR}/stream_reader.cc"
|
||||
"${LIBRARY_DIR}/stream_writer.cc"
|
||||
"${LIBRARY_DIR}/types.cc"
|
||||
"${LIBRARY_DIR}/bloom_filter_reader.cc"
|
||||
"${LIBRARY_DIR}/xxhasher.cc"
|
||||
|
||||
"${GEN_LIBRARY_DIR}/parquet_constants.cpp"
|
||||
"${GEN_LIBRARY_DIR}/parquet_types.cpp"
|
||||
|
@ -1 +0,0 @@
|
||||
../../../thrift/build/cmake/config.h.in
|
61
contrib/arrow-cmake/cpp/src/arrow/util/config.h
Normal file
61
contrib/arrow-cmake/cpp/src/arrow/util/config.h
Normal file
@ -0,0 +1,61 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#define ARROW_VERSION_MAJOR 11
|
||||
#define ARROW_VERSION_MINOR 0
|
||||
#define ARROW_VERSION_PATCH 0
|
||||
#define ARROW_VERSION ((ARROW_VERSION_MAJOR * 1000) + ARROW_VERSION_MINOR) * 1000 + ARROW_VERSION_PATCH
|
||||
|
||||
#define ARROW_VERSION_STRING "11.0.0"
|
||||
|
||||
#define ARROW_SO_VERSION "1100"
|
||||
#define ARROW_FULL_SO_VERSION "1100.0.0"
|
||||
|
||||
#define ARROW_CXX_COMPILER_ID "Clang"
|
||||
#define ARROW_CXX_COMPILER_VERSION "ClickHouse"
|
||||
#define ARROW_CXX_COMPILER_FLAGS ""
|
||||
|
||||
#define ARROW_BUILD_TYPE ""
|
||||
|
||||
#define ARROW_GIT_ID ""
|
||||
#define ARROW_GIT_DESCRIPTION ""
|
||||
|
||||
#define ARROW_PACKAGE_KIND ""
|
||||
|
||||
/* #undef ARROW_COMPUTE */
|
||||
/* #undef ARROW_CSV */
|
||||
/* #undef ARROW_CUDA */
|
||||
/* #undef ARROW_DATASET */
|
||||
/* #undef ARROW_FILESYSTEM */
|
||||
/* #undef ARROW_FLIGHT */
|
||||
/* #undef ARROW_FLIGHT_SQL */
|
||||
/* #undef ARROW_IPC */
|
||||
/* #undef ARROW_JEMALLOC */
|
||||
/* #undef ARROW_JEMALLOC_VENDORED */
|
||||
/* #undef ARROW_JSON */
|
||||
/* #undef ARROW_ORC */
|
||||
/* #undef ARROW_PARQUET */
|
||||
/* #undef ARROW_SUBSTRAIT */
|
||||
|
||||
/* #undef ARROW_GCS */
|
||||
/* #undef ARROW_S3 */
|
||||
/* #undef ARROW_USE_NATIVE_INT128 */
|
||||
/* #undef ARROW_WITH_MUSL */
|
||||
/* #undef ARROW_WITH_OPENTELEMETRY */
|
||||
/* #undef ARROW_WITH_UCX */
|
||||
|
||||
/* #undef GRPCPP_PP_INCLUDE */
|
@ -0,0 +1,38 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ORC_CONFIG_HH
|
||||
#define ORC_CONFIG_HH
|
||||
|
||||
#define ORC_VERSION ""
|
||||
|
||||
#define ORC_CXX_HAS_CSTDINT
|
||||
|
||||
#ifdef ORC_CXX_HAS_CSTDINT
|
||||
#include <cstdint>
|
||||
#else
|
||||
#include <stdint.h>
|
||||
#endif
|
||||
|
||||
// Following MACROS should be keeped for backward compatibility.
|
||||
#define ORC_NOEXCEPT noexcept
|
||||
#define ORC_NULLPTR nullptr
|
||||
#define ORC_OVERRIDE override
|
||||
#define ORC_UNIQUE_PTR std::unique_ptr
|
||||
|
||||
#endif
|
@ -7,12 +7,6 @@ endif()
|
||||
|
||||
set(LIB_SOURCE_DIR "${ClickHouse_SOURCE_DIR}/contrib/libssh")
|
||||
set(LIB_BINARY_DIR "${ClickHouse_BINARY_DIR}/contrib/libssh")
|
||||
# Specify search path for CMake modules to be loaded by include()
|
||||
# and find_package()
|
||||
list(APPEND CMAKE_MODULE_PATH "${LIB_SOURCE_DIR}/cmake/Modules")
|
||||
|
||||
include(DefineCMakeDefaults)
|
||||
include(DefineCompilerFlags)
|
||||
|
||||
project(libssh VERSION 0.9.7 LANGUAGES C)
|
||||
|
||||
@ -29,12 +23,6 @@ set(APPLICATION_NAME ${PROJECT_NAME})
|
||||
set(LIBRARY_VERSION "4.8.7")
|
||||
set(LIBRARY_SOVERSION "4")
|
||||
|
||||
# where to look first for cmake modules, before ${CMAKE_ROOT}/Modules/ is checked
|
||||
|
||||
# add definitions
|
||||
|
||||
include(DefinePlatformDefaults)
|
||||
|
||||
# Copy library files to a lib sub-directory
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${LIB_BINARY_DIR}/lib")
|
||||
|
||||
|
@ -1,20 +1,8 @@
|
||||
set(LIBSSH_LINK_LIBRARIES
|
||||
${LIBSSH_REQUIRED_LIBRARIES}
|
||||
)
|
||||
|
||||
|
||||
set(LIBSSH_LINK_LIBRARIES
|
||||
${LIBSSH_LINK_LIBRARIES}
|
||||
OpenSSL::Crypto
|
||||
)
|
||||
|
||||
if (MINGW AND Threads_FOUND)
|
||||
set(LIBSSH_LINK_LIBRARIES
|
||||
${LIBSSH_LINK_LIBRARIES}
|
||||
Threads::Threads
|
||||
)
|
||||
endif()
|
||||
|
||||
set(libssh_SRCS
|
||||
${LIB_SOURCE_DIR}/src/agent.c
|
||||
${LIB_SOURCE_DIR}/src/auth.c
|
||||
@ -66,30 +54,11 @@ set(libssh_SRCS
|
||||
${LIB_SOURCE_DIR}/src/pki_ed25519_common.c
|
||||
)
|
||||
|
||||
if (DEFAULT_C_NO_DEPRECATION_FLAGS)
|
||||
set_source_files_properties(known_hosts.c
|
||||
PROPERTIES
|
||||
COMPILE_FLAGS ${DEFAULT_C_NO_DEPRECATION_FLAGS})
|
||||
endif()
|
||||
|
||||
if (CMAKE_USE_PTHREADS_INIT)
|
||||
set(libssh_SRCS
|
||||
${libssh_SRCS}
|
||||
${LIB_SOURCE_DIR}/src/threads/noop.c
|
||||
${LIB_SOURCE_DIR}/src/threads/pthread.c
|
||||
)
|
||||
elseif (CMAKE_USE_WIN32_THREADS_INIT)
|
||||
set(libssh_SRCS
|
||||
${libssh_SRCS}
|
||||
${LIB_SOURCE_DIR}/src/threads/noop.c
|
||||
${LIB_SOURCE_DIR}/src/threads/winlocks.c
|
||||
)
|
||||
else()
|
||||
set(libssh_SRCS
|
||||
${libssh_SRCS}
|
||||
${LIB_SOURCE_DIR}/src/threads/noop.c
|
||||
)
|
||||
endif()
|
||||
set(libssh_SRCS
|
||||
${libssh_SRCS}
|
||||
${LIB_SOURCE_DIR}/src/threads/noop.c
|
||||
${LIB_SOURCE_DIR}/src/threads/pthread.c
|
||||
)
|
||||
|
||||
# LIBCRYPT specific
|
||||
set(libssh_SRCS
|
||||
@ -127,14 +96,3 @@ target_compile_options(_ssh
|
||||
PRIVATE
|
||||
${DEFAULT_C_COMPILE_FLAGS}
|
||||
-D_GNU_SOURCE)
|
||||
|
||||
|
||||
set_target_properties(_ssh
|
||||
PROPERTIES
|
||||
VERSION
|
||||
${LIBRARY_VERSION}
|
||||
SOVERSION
|
||||
${LIBRARY_SOVERSION}
|
||||
DEFINE_SYMBOL
|
||||
LIBSSH_EXPORTS
|
||||
)
|
||||
|
@ -74,8 +74,6 @@ function configure()
|
||||
|
||||
randomize_config_boolean_value use_compression zookeeper
|
||||
|
||||
randomize_config_boolean_value allow_experimental_block_number_column merge_tree_settings
|
||||
|
||||
# for clickhouse-server (via service)
|
||||
echo "ASAN_OPTIONS='malloc_context_size=10 verbosity=1 allocator_release_to_os_interval_ms=10000'" >> /etc/environment
|
||||
# for clickhouse-client
|
||||
|
@ -18,7 +18,15 @@ function, table engine, database, etc. In the examples below the parameter list
|
||||
linked to for each type.
|
||||
|
||||
Parameters set in a named collection can be overridden in SQL, this is shown in the examples
|
||||
below.
|
||||
below. This ability can be limited using `[NOT] OVERRIDABLE` keywords and XML attributes
|
||||
and/or the configuration option `allow_named_collection_override_by_default`.
|
||||
|
||||
:::warning
|
||||
If override is allowed, it may be possible for users without administrative access to
|
||||
figure out the credentials that you are trying to hide.
|
||||
If you are using named collections with that purpose, you should disable
|
||||
`allow_named_collection_override_by_default` (which is enabled by default).
|
||||
:::
|
||||
|
||||
## Storing named collections in the system database
|
||||
|
||||
@ -26,11 +34,17 @@ below.
|
||||
|
||||
```sql
|
||||
CREATE NAMED COLLECTION name AS
|
||||
key_1 = 'value',
|
||||
key_2 = 'value2',
|
||||
key_1 = 'value' OVERRIDABLE,
|
||||
key_2 = 'value2' NOT OVERRIDABLE,
|
||||
url = 'https://connection.url/'
|
||||
```
|
||||
|
||||
In the above example:
|
||||
|
||||
* `key_1` can always be overridden.
|
||||
* `key_2` can never be overridden.
|
||||
* `url` can be overridden or not depending on the value of `allow_named_collection_override_by_default`.
|
||||
|
||||
### Permissions to create named collections with DDL
|
||||
|
||||
To manage named collections with DDL a user must have the `named_control_collection` privilege. This can be assigned by adding a file to `/etc/clickhouse-server/users.d/`. The example gives the user `default` both the `access_management` and `named_collection_control` privileges:
|
||||
@ -61,25 +75,37 @@ In the above example the `password_sha256_hex` value is the hexadecimal represen
|
||||
<clickhouse>
|
||||
<named_collections>
|
||||
<name>
|
||||
<key_1>value</key_1>
|
||||
<key_2>value_2</key_2>
|
||||
<key_1 overridable="true">value</key_1>
|
||||
<key_2 overridable="false">value_2</key_2>
|
||||
<url>https://connection.url/</url>
|
||||
</name>
|
||||
</named_collections>
|
||||
</clickhouse>
|
||||
```
|
||||
|
||||
In the above example:
|
||||
|
||||
* `key_1` can always be overridden.
|
||||
* `key_2` can never be overridden.
|
||||
* `url` can be overridden or not depending on the value of `allow_named_collection_override_by_default`.
|
||||
|
||||
## Modifying named collections
|
||||
|
||||
Named collections that are created with DDL queries can be altered or dropped with DDL. Named collections created with XML files can be managed by editing or deleting the corresponding XML.
|
||||
|
||||
### Alter a DDL named collection
|
||||
|
||||
Change or add the keys `key1` and `key3` of the collection `collection2`:
|
||||
Change or add the keys `key1` and `key3` of the collection `collection2`
|
||||
(this will not change the value of the `overridable` flag for those keys):
|
||||
```sql
|
||||
ALTER NAMED COLLECTION collection2 SET key1=4, key3='value3'
|
||||
```
|
||||
|
||||
Change or add the key `key1` and allow it to be always overridden:
|
||||
```sql
|
||||
ALTER NAMED COLLECTION collection2 SET key1=4 OVERRIDABLE
|
||||
```
|
||||
|
||||
Remove the key `key2` from `collection2`:
|
||||
```sql
|
||||
ALTER NAMED COLLECTION collection2 DELETE key2
|
||||
@ -90,6 +116,13 @@ Change or add the key `key1` and delete the key `key3` of the collection `collec
|
||||
ALTER NAMED COLLECTION collection2 SET key1=4, DELETE key3
|
||||
```
|
||||
|
||||
To force a key to use the default settings for the `overridable` flag, you have to
|
||||
remove and re-add the key.
|
||||
```sql
|
||||
ALTER NAMED COLLECTION collection2 DELETE key1;
|
||||
ALTER NAMED COLLECTION collection2 SET key1=4;
|
||||
```
|
||||
|
||||
### Drop the DDL named collection `collection2`:
|
||||
```sql
|
||||
DROP NAMED COLLECTION collection2
|
||||
|
@ -1769,7 +1769,7 @@ Example of settings:
|
||||
<password>qwerty123</password>
|
||||
<keyspase>database_name</keyspase>
|
||||
<column_family>table_name</column_family>
|
||||
<allow_filering>1</allow_filering>
|
||||
<allow_filtering>1</allow_filtering>
|
||||
<partition_key_prefix>1</partition_key_prefix>
|
||||
<consistency>One</consistency>
|
||||
<where>"SomeColumn" = 42</where>
|
||||
@ -1787,7 +1787,7 @@ Setting fields:
|
||||
- `password` – Password of the Cassandra user.
|
||||
- `keyspace` – Name of the keyspace (database).
|
||||
- `column_family` – Name of the column family (table).
|
||||
- `allow_filering` – Flag to allow or not potentially expensive conditions on clustering key columns. Default value is 1.
|
||||
- `allow_filtering` – Flag to allow or not potentially expensive conditions on clustering key columns. Default value is 1.
|
||||
- `partition_key_prefix` – Number of partition key columns in primary key of the Cassandra table. Required for compose key dictionaries. Order of key columns in the dictionary definition must be the same as in Cassandra. Default value is 1 (the first key column is a partition key and other key columns are clustering key).
|
||||
- `consistency` – Consistency level. Possible values: `One`, `Two`, `Three`, `All`, `EachQuorum`, `Quorum`, `LocalQuorum`, `LocalOne`, `Serial`, `LocalSerial`. Default value is `One`.
|
||||
- `where` – Optional selection criteria.
|
||||
|
@ -12,9 +12,9 @@ This query intends to modify already existing named collections.
|
||||
```sql
|
||||
ALTER NAMED COLLECTION [IF EXISTS] name [ON CLUSTER cluster]
|
||||
[ SET
|
||||
key_name1 = 'some value',
|
||||
key_name2 = 'some value',
|
||||
key_name3 = 'some value',
|
||||
key_name1 = 'some value' [[NOT] OVERRIDABLE],
|
||||
key_name2 = 'some value' [[NOT] OVERRIDABLE],
|
||||
key_name3 = 'some value' [[NOT] OVERRIDABLE],
|
||||
... ] |
|
||||
[ DELETE key_name4, key_name5, ... ]
|
||||
```
|
||||
@ -22,9 +22,9 @@ key_name3 = 'some value',
|
||||
**Example**
|
||||
|
||||
```sql
|
||||
CREATE NAMED COLLECTION foobar AS a = '1', b = '2';
|
||||
CREATE NAMED COLLECTION foobar AS a = '1' NOT OVERRIDABLE, b = '2';
|
||||
|
||||
ALTER NAMED COLLECTION foobar SET a = '2', c = '3';
|
||||
ALTER NAMED COLLECTION foobar SET a = '2' OVERRIDABLE, c = '3';
|
||||
|
||||
ALTER NAMED COLLECTION foobar DELETE b;
|
||||
```
|
||||
|
@ -11,16 +11,16 @@ Creates a new named collection.
|
||||
|
||||
```sql
|
||||
CREATE NAMED COLLECTION [IF NOT EXISTS] name [ON CLUSTER cluster] AS
|
||||
key_name1 = 'some value',
|
||||
key_name2 = 'some value',
|
||||
key_name3 = 'some value',
|
||||
key_name1 = 'some value' [[NOT] OVERRIDABLE],
|
||||
key_name2 = 'some value' [[NOT] OVERRIDABLE],
|
||||
key_name3 = 'some value' [[NOT] OVERRIDABLE],
|
||||
...
|
||||
```
|
||||
|
||||
**Example**
|
||||
|
||||
```sql
|
||||
CREATE NAMED COLLECTION foobar AS a = '1', b = '2';
|
||||
CREATE NAMED COLLECTION foobar AS a = '1', b = '2' OVERRIDABLE;
|
||||
```
|
||||
|
||||
**Related statements**
|
||||
|
@ -723,7 +723,7 @@ SOURCE(REDIS(
|
||||
<password>qwerty123</password>
|
||||
<keyspase>database_name</keyspase>
|
||||
<column_family>table_name</column_family>
|
||||
<allow_filering>1</allow_filering>
|
||||
<allow_filtering>1</allow_filtering>
|
||||
<partition_key_prefix>1</partition_key_prefix>
|
||||
<consistency>One</consistency>
|
||||
<where>"SomeColumn" = 42</where>
|
||||
@ -741,7 +741,7 @@ SOURCE(REDIS(
|
||||
- `password` – пароль для соединения с Cassandra.
|
||||
- `keyspace` – имя keyspace (база данных).
|
||||
- `column_family` – имя семейства столбцов (таблица).
|
||||
- `allow_filering` – флаг, разрешающий или не разрешающий потенциально дорогостоящие условия на кластеризации ключевых столбцов. Значение по умолчанию: 1.
|
||||
- `allow_filtering` – флаг, разрешающий или не разрешающий потенциально дорогостоящие условия на кластеризации ключевых столбцов. Значение по умолчанию: 1.
|
||||
- `partition_key_prefix` – количество партиций ключевых столбцов в первичном ключе таблицы Cassandra.
|
||||
Необходимо для составления ключей словаря. Порядок ключевых столбцов в определении словаря должен быть таким же, как в Cassandra.
|
||||
Значение по умолчанию: 1 (первый ключевой столбец - это ключ партицирования, остальные ключевые столбцы - ключи кластеризации).
|
||||
|
@ -425,7 +425,7 @@ void Client::connect()
|
||||
if (hosts_and_ports.empty())
|
||||
{
|
||||
String host = config().getString("host", "localhost");
|
||||
UInt16 port = ConnectionParameters::getPortFromConfig(config());
|
||||
UInt16 port = ConnectionParameters::getPortFromConfig(config(), host);
|
||||
hosts_and_ports.emplace_back(HostAndPort{host, port});
|
||||
}
|
||||
|
||||
|
@ -424,7 +424,7 @@ void LocalServer::setupUsers()
|
||||
|
||||
void LocalServer::connect()
|
||||
{
|
||||
connection_parameters = ConnectionParameters(config());
|
||||
connection_parameters = ConnectionParameters(config(), "localhost");
|
||||
connection = LocalConnection::createConnection(
|
||||
connection_parameters, global_context, need_render_progress, need_render_profile_events, server_display_name);
|
||||
}
|
||||
|
@ -325,7 +325,7 @@
|
||||
Query can upscale to desired number of threads during execution if more threads become available.
|
||||
-->
|
||||
<concurrent_threads_soft_limit_num>0</concurrent_threads_soft_limit_num>
|
||||
<concurrent_threads_soft_limit_ratio_to_cores>0</concurrent_threads_soft_limit_ratio_to_cores>
|
||||
<concurrent_threads_soft_limit_ratio_to_cores>2</concurrent_threads_soft_limit_ratio_to_cores>
|
||||
|
||||
<!-- Maximum number of concurrent queries. -->
|
||||
<max_concurrent_queries>1000</max_concurrent_queries>
|
||||
|
@ -44,7 +44,7 @@ private:
|
||||
krb5_ccache defcache = nullptr;
|
||||
krb5_get_init_creds_opt * options = nullptr;
|
||||
// Credentials structure including ticket, session key, and lifetime info.
|
||||
krb5_creds my_creds;
|
||||
krb5_creds my_creds {};
|
||||
krb5_keytab keytab = nullptr;
|
||||
krb5_principal defcache_princ = nullptr;
|
||||
String fmtError(krb5_error_code code) const;
|
||||
|
@ -12,7 +12,7 @@
|
||||
#include <Common/Config/ConfigReloader.h>
|
||||
#include <Common/StringUtils/StringUtils.h>
|
||||
#include <Common/quoteString.h>
|
||||
#include <Common/TransformEndianness.hpp>
|
||||
#include <Common/transformEndianness.h>
|
||||
#include <Core/Settings.h>
|
||||
#include <Interpreters/executeQuery.h>
|
||||
#include <Parsers/Access/ASTGrantQuery.h>
|
||||
|
@ -1,7 +1,18 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionAnalysisOfVariance.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
|
||||
#include <IO/VarInt.h>
|
||||
|
||||
#include <array>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <Columns/ColumnNullable.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/Moments.h>
|
||||
#include <Common/NaNUtils.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
@ -13,6 +24,82 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
using AggregateFunctionAnalysisOfVarianceData = AnalysisOfVarianceMoments<Float64>;
|
||||
|
||||
|
||||
/// One way analysis of variance
|
||||
/// Provides a statistical test of whether two or more population means are equal (null hypothesis)
|
||||
/// Has an assumption that subjects from group i have normal distribution.
|
||||
/// Accepts two arguments - a value and a group number which this value belongs to.
|
||||
/// Groups are enumerated starting from 0 and there should be at least two groups to perform a test
|
||||
/// Moreover there should be at least one group with the number of observations greater than one.
|
||||
class AggregateFunctionAnalysisOfVariance final : public IAggregateFunctionDataHelper<AggregateFunctionAnalysisOfVarianceData, AggregateFunctionAnalysisOfVariance>
|
||||
{
|
||||
public:
|
||||
explicit AggregateFunctionAnalysisOfVariance(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper(arguments, params, createResultType())
|
||||
{}
|
||||
|
||||
DataTypePtr createResultType() const
|
||||
{
|
||||
DataTypes types {std::make_shared<DataTypeNumber<Float64>>(), std::make_shared<DataTypeNumber<Float64>>() };
|
||||
Strings names {"f_statistic", "p_value"};
|
||||
return std::make_shared<DataTypeTuple>(
|
||||
std::move(types),
|
||||
std::move(names)
|
||||
);
|
||||
}
|
||||
|
||||
String getName() const override { return "analysisOfVariance"; }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
data(place).add(columns[0]->getFloat64(row_num), columns[1]->getUInt(row_num));
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
data(place).merge(data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
data(place).write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
data(place).read(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto f_stat = data(place).getFStatistic();
|
||||
|
||||
auto & column_tuple = assert_cast<ColumnTuple &>(to);
|
||||
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
|
||||
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
|
||||
|
||||
if (unlikely(!std::isfinite(f_stat) || f_stat < 0))
|
||||
{
|
||||
column_stat.getData().push_back(std::numeric_limits<Float64>::quiet_NaN());
|
||||
column_value.getData().push_back(std::numeric_limits<Float64>::quiet_NaN());
|
||||
return;
|
||||
}
|
||||
|
||||
auto p_value = data(place).getPValue(f_stat);
|
||||
|
||||
/// Because p-value is a probability.
|
||||
p_value = std::min(1.0, std::max(0.0, p_value));
|
||||
|
||||
column_stat.getData().push_back(f_stat);
|
||||
column_value.getData().push_back(p_value);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionAnalysisOfVariance(const std::string & name, const DataTypes & arguments, const Array & parameters, const Settings *)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
|
@ -1,97 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <IO/VarInt.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
#include <array>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <Columns/ColumnNullable.h>
|
||||
#include <Columns/ColumnsCommon.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/Moments.h>
|
||||
#include "Common/NaNUtils.h"
|
||||
#include <Common/assert_cast.h>
|
||||
#include <Core/Types.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
using AggregateFunctionAnalysisOfVarianceData = AnalysisOfVarianceMoments<Float64>;
|
||||
|
||||
|
||||
/// One way analysis of variance
|
||||
/// Provides a statistical test of whether two or more population means are equal (null hypothesis)
|
||||
/// Has an assumption that subjects from group i have normal distribution.
|
||||
/// Accepts two arguments - a value and a group number which this value belongs to.
|
||||
/// Groups are enumerated starting from 0 and there should be at least two groups to perform a test
|
||||
/// Moreover there should be at least one group with the number of observations greater than one.
|
||||
class AggregateFunctionAnalysisOfVariance final : public IAggregateFunctionDataHelper<AggregateFunctionAnalysisOfVarianceData, AggregateFunctionAnalysisOfVariance>
|
||||
{
|
||||
public:
|
||||
explicit AggregateFunctionAnalysisOfVariance(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper(arguments, params, createResultType())
|
||||
{}
|
||||
|
||||
DataTypePtr createResultType() const
|
||||
{
|
||||
DataTypes types {std::make_shared<DataTypeNumber<Float64>>(), std::make_shared<DataTypeNumber<Float64>>() };
|
||||
Strings names {"f_statistic", "p_value"};
|
||||
return std::make_shared<DataTypeTuple>(
|
||||
std::move(types),
|
||||
std::move(names)
|
||||
);
|
||||
}
|
||||
|
||||
String getName() const override { return "analysisOfVariance"; }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
data(place).add(columns[0]->getFloat64(row_num), columns[1]->getUInt(row_num));
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
data(place).merge(data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
data(place).write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
data(place).read(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto f_stat = data(place).getFStatistic();
|
||||
|
||||
auto & column_tuple = assert_cast<ColumnTuple &>(to);
|
||||
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
|
||||
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
|
||||
|
||||
if (unlikely(!std::isfinite(f_stat) || f_stat < 0))
|
||||
{
|
||||
column_stat.getData().push_back(std::numeric_limits<Float64>::quiet_NaN());
|
||||
column_value.getData().push_back(std::numeric_limits<Float64>::quiet_NaN());
|
||||
return;
|
||||
}
|
||||
|
||||
auto p_value = data(place).getPValue(f_stat);
|
||||
|
||||
/// Because p-value is a probability.
|
||||
p_value = std::min(1.0, std::max(0.0, p_value));
|
||||
|
||||
column_stat.getData().push_back(f_stat);
|
||||
column_value.getData().push_back(p_value);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}
|
@ -1,12 +1,14 @@
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <AggregateFunctions/AggregateFunctionAvg.h>
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionAvgWeighted.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
@ -16,13 +18,93 @@ namespace ErrorCodes
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
using AvgWeightedFieldType = std::conditional_t<DecimalOrExtendedInt<T>,
|
||||
Float64, // no way to do UInt128 * UInt128, better cast to Float64
|
||||
NearestFieldType<T>>;
|
||||
|
||||
template <typename T, typename U>
|
||||
using MaxFieldType = std::conditional_t<(sizeof(AvgWeightedFieldType<T>) > sizeof(AvgWeightedFieldType<U>)),
|
||||
AvgWeightedFieldType<T>, AvgWeightedFieldType<U>>;
|
||||
|
||||
template <typename Value, typename Weight>
|
||||
class AggregateFunctionAvgWeighted final :
|
||||
public AggregateFunctionAvgBase<
|
||||
MaxFieldType<Value, Weight>, AvgWeightedFieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>
|
||||
{
|
||||
public:
|
||||
using Base = AggregateFunctionAvgBase<
|
||||
MaxFieldType<Value, Weight>, AvgWeightedFieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>;
|
||||
using Base::Base;
|
||||
|
||||
using Numerator = typename Base::Numerator;
|
||||
using Denominator = typename Base::Denominator;
|
||||
using Fraction = typename Base::Fraction;
|
||||
|
||||
void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
const auto & weights = static_cast<const ColumnVector<Weight> &>(*columns[1]);
|
||||
|
||||
this->data(place).numerator += static_cast<Numerator>(
|
||||
static_cast<const ColumnVector<Value> &>(*columns[0]).getData()[row_num])
|
||||
* static_cast<Numerator>(weights.getData()[row_num]);
|
||||
|
||||
this->data(place).denominator += static_cast<Denominator>(weights.getData()[row_num]);
|
||||
}
|
||||
|
||||
String getName() const override { return "avgWeighted"; }
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
bool isCompilable() const override
|
||||
{
|
||||
bool can_be_compiled = Base::isCompilable();
|
||||
can_be_compiled &= canBeNativeType<Weight>();
|
||||
|
||||
return can_be_compiled;
|
||||
}
|
||||
|
||||
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * numerator_type = toNativeType<Numerator>(b);
|
||||
auto * numerator_ptr = aggregate_data_ptr;
|
||||
auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr);
|
||||
|
||||
auto numerator_data_type = toNativeDataType<Numerator>();
|
||||
auto * argument = nativeCast(b, arguments[0], numerator_data_type);
|
||||
auto * weight = nativeCast(b, arguments[1], numerator_data_type);
|
||||
|
||||
llvm::Value * value_weight_multiplication = argument->getType()->isIntegerTy() ? b.CreateMul(argument, weight) : b.CreateFMul(argument, weight);
|
||||
auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_value, value_weight_multiplication) : b.CreateFAdd(numerator_value, value_weight_multiplication);
|
||||
b.CreateStore(numerator_result_value, numerator_ptr);
|
||||
|
||||
auto * denominator_type = toNativeType<Denominator>(b);
|
||||
|
||||
static constexpr size_t denominator_offset = offsetof(Fraction, denominator);
|
||||
auto * denominator_ptr = b.CreateConstInBoundsGEP1_64(b.getInt8Ty(), aggregate_data_ptr, denominator_offset);
|
||||
|
||||
auto * weight_cast_to_denominator = nativeCast(b, arguments[1], toNativeDataType<Denominator>());
|
||||
|
||||
auto * denominator_value = b.CreateLoad(denominator_type, denominator_ptr);
|
||||
auto * denominator_value_updated = denominator_type->isIntegerTy() ? b.CreateAdd(denominator_value, weight_cast_to_denominator) : b.CreateFAdd(denominator_value, weight_cast_to_denominator);
|
||||
|
||||
b.CreateStore(denominator_value_updated, denominator_ptr);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
};
|
||||
|
||||
bool allowTypes(const DataTypePtr& left, const DataTypePtr& right) noexcept
|
||||
{
|
||||
const WhichDataType l_dt(left), r_dt(right);
|
||||
|
||||
constexpr auto allow = [](WhichDataType t)
|
||||
{
|
||||
return t.isInt() || t.isUInt() || t.isFloat() || t.isDecimal();
|
||||
return t.isInt() || t.isUInt() || t.isFloat();
|
||||
};
|
||||
|
||||
return allow(l_dt) && allow(r_dt);
|
||||
@ -33,7 +115,6 @@ bool allowTypes(const DataTypePtr& left, const DataTypePtr& right) noexcept
|
||||
{ \
|
||||
LINE(Int8); LINE(Int16); LINE(Int32); LINE(Int64); LINE(Int128); LINE(Int256); \
|
||||
LINE(UInt8); LINE(UInt16); LINE(UInt32); LINE(UInt64); LINE(UInt128); LINE(UInt256); \
|
||||
LINE(Decimal32); LINE(Decimal64); LINE(Decimal128); LINE(Decimal256); \
|
||||
LINE(Float32); LINE(Float64); \
|
||||
default: return nullptr; \
|
||||
}
|
||||
@ -75,31 +156,14 @@ createAggregateFunctionAvgWeighted(const std::string & name, const DataTypes & a
|
||||
"Types {} and {} are non-conforming as arguments for aggregate function {}",
|
||||
data_type->getName(), data_type_weight->getName(), name);
|
||||
|
||||
AggregateFunctionPtr ptr;
|
||||
|
||||
const bool left_decimal = isDecimal(data_type);
|
||||
const bool right_decimal = isDecimal(data_type_weight);
|
||||
|
||||
/// We multiply value by weight, so actual scale of numerator is <scale of value> + <scale of weight>
|
||||
if (left_decimal && right_decimal)
|
||||
ptr.reset(create(*data_type, *data_type_weight,
|
||||
argument_types,
|
||||
getDecimalScale(*data_type) + getDecimalScale(*data_type_weight), getDecimalScale(*data_type_weight)));
|
||||
else if (left_decimal)
|
||||
ptr.reset(create(*data_type, *data_type_weight, argument_types,
|
||||
getDecimalScale(*data_type)));
|
||||
else if (right_decimal)
|
||||
ptr.reset(create(*data_type, *data_type_weight, argument_types,
|
||||
getDecimalScale(*data_type_weight), getDecimalScale(*data_type_weight)));
|
||||
else
|
||||
ptr.reset(create(*data_type, *data_type_weight, argument_types));
|
||||
|
||||
return ptr;
|
||||
return AggregateFunctionPtr(create(*data_type, *data_type_weight, argument_types));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionAvgWeighted(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("avgWeighted", createAggregateFunctionAvgWeighted, AggregateFunctionFactory::CaseSensitive);
|
||||
factory.registerFunction("avgWeighted", createAggregateFunctionAvgWeighted);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,90 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include <AggregateFunctions/AggregateFunctionAvg.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
template <typename T>
|
||||
using AvgWeightedFieldType = std::conditional_t<is_decimal<T>,
|
||||
std::conditional_t<std::is_same_v<T, Decimal256>, Decimal256, Decimal128>,
|
||||
std::conditional_t<DecimalOrExtendedInt<T>,
|
||||
Float64, // no way to do UInt128 * UInt128, better cast to Float64
|
||||
NearestFieldType<T>>>;
|
||||
|
||||
template <typename T, typename U>
|
||||
using MaxFieldType = std::conditional_t<(sizeof(AvgWeightedFieldType<T>) > sizeof(AvgWeightedFieldType<U>)),
|
||||
AvgWeightedFieldType<T>, AvgWeightedFieldType<U>>;
|
||||
|
||||
template <typename Value, typename Weight>
|
||||
class AggregateFunctionAvgWeighted final :
|
||||
public AggregateFunctionAvgBase<
|
||||
MaxFieldType<Value, Weight>, AvgWeightedFieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>
|
||||
{
|
||||
public:
|
||||
using Base = AggregateFunctionAvgBase<
|
||||
MaxFieldType<Value, Weight>, AvgWeightedFieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>;
|
||||
using Base::Base;
|
||||
|
||||
using Numerator = typename Base::Numerator;
|
||||
using Denominator = typename Base::Denominator;
|
||||
using Fraction = typename Base::Fraction;
|
||||
|
||||
void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
const auto& weights = static_cast<const ColumnVectorOrDecimal<Weight> &>(*columns[1]);
|
||||
|
||||
this->data(place).numerator += static_cast<Numerator>(
|
||||
static_cast<const ColumnVectorOrDecimal<Value> &>(*columns[0]).getData()[row_num]) *
|
||||
static_cast<Numerator>(weights.getData()[row_num]);
|
||||
|
||||
this->data(place).denominator += static_cast<Denominator>(weights.getData()[row_num]);
|
||||
}
|
||||
|
||||
String getName() const override { return "avgWeighted"; }
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
bool isCompilable() const override
|
||||
{
|
||||
bool can_be_compiled = Base::isCompilable();
|
||||
can_be_compiled &= canBeNativeType<Weight>();
|
||||
|
||||
return can_be_compiled;
|
||||
}
|
||||
|
||||
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * numerator_type = toNativeType<Numerator>(b);
|
||||
auto * numerator_ptr = aggregate_data_ptr;
|
||||
auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr);
|
||||
|
||||
auto numerator_data_type = toNativeDataType<Numerator>();
|
||||
auto * argument = nativeCast(b, arguments[0], numerator_data_type);
|
||||
auto * weight = nativeCast(b, arguments[1], numerator_data_type);
|
||||
|
||||
llvm::Value * value_weight_multiplication = argument->getType()->isIntegerTy() ? b.CreateMul(argument, weight) : b.CreateFMul(argument, weight);
|
||||
auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_value, value_weight_multiplication) : b.CreateFAdd(numerator_value, value_weight_multiplication);
|
||||
b.CreateStore(numerator_result_value, numerator_ptr);
|
||||
|
||||
auto * denominator_type = toNativeType<Denominator>(b);
|
||||
|
||||
static constexpr size_t denominator_offset = offsetof(Fraction, denominator);
|
||||
auto * denominator_ptr = b.CreateConstInBoundsGEP1_64(b.getInt8Ty(), aggregate_data_ptr, denominator_offset);
|
||||
|
||||
auto * weight_cast_to_denominator = nativeCast(b, arguments[1], toNativeDataType<Denominator>());
|
||||
|
||||
auto * denominator_value = b.CreateLoad(denominator_type, denominator_ptr);
|
||||
auto * denominator_value_updated = denominator_type->isIntegerTy() ? b.CreateAdd(denominator_value, weight_cast_to_denominator) : b.CreateFAdd(denominator_value, weight_cast_to_denominator);
|
||||
|
||||
b.CreateStore(denominator_value_updated, denominator_ptr);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
};
|
||||
}
|
@ -1,11 +1,27 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionBitwise.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
#include "config.h"
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
# include <llvm/IR/IRBuilder.h>
|
||||
# include <DataTypes/Native.h>
|
||||
#endif
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
@ -16,6 +32,179 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
struct AggregateFunctionGroupBitOrData
|
||||
{
|
||||
T value = 0;
|
||||
static const char * name() { return "groupBitOr"; }
|
||||
void update(T x) { value |= x; }
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
static void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * value_ptr)
|
||||
{
|
||||
auto type = toNativeType<T>(builder);
|
||||
builder.CreateStore(llvm::Constant::getNullValue(type), value_ptr);
|
||||
}
|
||||
|
||||
static llvm::Value* compileUpdate(llvm::IRBuilderBase & builder, llvm::Value * lhs, llvm::Value * rhs)
|
||||
{
|
||||
return builder.CreateOr(lhs, rhs);
|
||||
}
|
||||
|
||||
#endif
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct AggregateFunctionGroupBitAndData
|
||||
{
|
||||
T value = -1; /// Two's complement arithmetic, sign extension.
|
||||
static const char * name() { return "groupBitAnd"; }
|
||||
void update(T x) { value &= x; }
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
static void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * value_ptr)
|
||||
{
|
||||
auto type = toNativeType<T>(builder);
|
||||
builder.CreateStore(llvm::ConstantInt::get(type, -1), value_ptr);
|
||||
}
|
||||
|
||||
static llvm::Value* compileUpdate(llvm::IRBuilderBase & builder, llvm::Value * lhs, llvm::Value * rhs)
|
||||
{
|
||||
return builder.CreateAnd(lhs, rhs);
|
||||
}
|
||||
|
||||
#endif
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct AggregateFunctionGroupBitXorData
|
||||
{
|
||||
T value = 0;
|
||||
static const char * name() { return "groupBitXor"; }
|
||||
void update(T x) { value ^= x; }
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
static void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * value_ptr)
|
||||
{
|
||||
auto type = toNativeType<T>(builder);
|
||||
builder.CreateStore(llvm::Constant::getNullValue(type), value_ptr);
|
||||
}
|
||||
|
||||
static llvm::Value* compileUpdate(llvm::IRBuilderBase & builder, llvm::Value * lhs, llvm::Value * rhs)
|
||||
{
|
||||
return builder.CreateXor(lhs, rhs);
|
||||
}
|
||||
|
||||
#endif
|
||||
};
|
||||
|
||||
|
||||
/// Counts bitwise operation on numbers.
|
||||
template <typename T, typename Data>
|
||||
class AggregateFunctionBitwise final : public IAggregateFunctionDataHelper<Data, AggregateFunctionBitwise<T, Data>>
|
||||
{
|
||||
public:
|
||||
explicit AggregateFunctionBitwise(const DataTypePtr & type)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitwise<T, Data>>({type}, {}, createResultType())
|
||||
{}
|
||||
|
||||
String getName() const override { return Data::name(); }
|
||||
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
return std::make_shared<DataTypeNumber<T>>();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
this->data(place).update(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).update(this->data(rhs).value);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
writeBinary(this->data(place).value, buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
readBinary(this->data(place).value, buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
assert_cast<ColumnVector<T> &>(to).getData().push_back(this->data(place).value);
|
||||
}
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
bool isCompilable() const override
|
||||
{
|
||||
auto return_type = this->getResultType();
|
||||
return canBeNativeType(*return_type);
|
||||
}
|
||||
|
||||
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
|
||||
{
|
||||
auto * value_ptr = aggregate_data_ptr;
|
||||
Data::compileCreate(builder, value_ptr);
|
||||
}
|
||||
|
||||
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, this->getResultType());
|
||||
|
||||
auto * value_ptr = aggregate_data_ptr;
|
||||
auto * value = b.CreateLoad(return_type, value_ptr);
|
||||
|
||||
auto * result_value = Data::compileUpdate(builder, value, arguments[0].value);
|
||||
|
||||
b.CreateStore(result_value, value_ptr);
|
||||
}
|
||||
|
||||
void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, this->getResultType());
|
||||
|
||||
auto * value_dst_ptr = aggregate_data_dst_ptr;
|
||||
auto * value_dst = b.CreateLoad(return_type, value_dst_ptr);
|
||||
|
||||
auto * value_src_ptr = aggregate_data_src_ptr;
|
||||
auto * value_src = b.CreateLoad(return_type, value_src_ptr);
|
||||
|
||||
auto * result_value = Data::compileUpdate(builder, value_dst, value_src);
|
||||
|
||||
b.CreateStore(result_value, value_dst_ptr);
|
||||
}
|
||||
|
||||
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, this->getResultType());
|
||||
auto * value_ptr = aggregate_data_ptr;
|
||||
|
||||
return b.CreateLoad(return_type, value_ptr);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
};
|
||||
|
||||
|
||||
template <template <typename> class Data>
|
||||
AggregateFunctionPtr createAggregateFunctionBitwise(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
|
@ -1,197 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
#include "config.h"
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
# include <llvm/IR/IRBuilder.h>
|
||||
# include <DataTypes/Native.h>
|
||||
#endif
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct AggregateFunctionGroupBitOrData
|
||||
{
|
||||
T value = 0;
|
||||
static const char * name() { return "groupBitOr"; }
|
||||
void update(T x) { value |= x; }
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
static void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * value_ptr)
|
||||
{
|
||||
auto type = toNativeType<T>(builder);
|
||||
builder.CreateStore(llvm::Constant::getNullValue(type), value_ptr);
|
||||
}
|
||||
|
||||
static llvm::Value* compileUpdate(llvm::IRBuilderBase & builder, llvm::Value * lhs, llvm::Value * rhs)
|
||||
{
|
||||
return builder.CreateOr(lhs, rhs);
|
||||
}
|
||||
|
||||
#endif
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct AggregateFunctionGroupBitAndData
|
||||
{
|
||||
T value = -1; /// Two's complement arithmetic, sign extension.
|
||||
static const char * name() { return "groupBitAnd"; }
|
||||
void update(T x) { value &= x; }
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
static void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * value_ptr)
|
||||
{
|
||||
auto type = toNativeType<T>(builder);
|
||||
builder.CreateStore(llvm::ConstantInt::get(type, -1), value_ptr);
|
||||
}
|
||||
|
||||
static llvm::Value* compileUpdate(llvm::IRBuilderBase & builder, llvm::Value * lhs, llvm::Value * rhs)
|
||||
{
|
||||
return builder.CreateAnd(lhs, rhs);
|
||||
}
|
||||
|
||||
#endif
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct AggregateFunctionGroupBitXorData
|
||||
{
|
||||
T value = 0;
|
||||
static const char * name() { return "groupBitXor"; }
|
||||
void update(T x) { value ^= x; }
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
static void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * value_ptr)
|
||||
{
|
||||
auto type = toNativeType<T>(builder);
|
||||
builder.CreateStore(llvm::Constant::getNullValue(type), value_ptr);
|
||||
}
|
||||
|
||||
static llvm::Value* compileUpdate(llvm::IRBuilderBase & builder, llvm::Value * lhs, llvm::Value * rhs)
|
||||
{
|
||||
return builder.CreateXor(lhs, rhs);
|
||||
}
|
||||
|
||||
#endif
|
||||
};
|
||||
|
||||
|
||||
/// Counts bitwise operation on numbers.
|
||||
template <typename T, typename Data>
|
||||
class AggregateFunctionBitwise final : public IAggregateFunctionDataHelper<Data, AggregateFunctionBitwise<T, Data>>
|
||||
{
|
||||
public:
|
||||
explicit AggregateFunctionBitwise(const DataTypePtr & type)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitwise<T, Data>>({type}, {}, createResultType())
|
||||
{}
|
||||
|
||||
String getName() const override { return Data::name(); }
|
||||
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
return std::make_shared<DataTypeNumber<T>>();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
this->data(place).update(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).update(this->data(rhs).value);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
writeBinary(this->data(place).value, buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
readBinary(this->data(place).value, buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
assert_cast<ColumnVector<T> &>(to).getData().push_back(this->data(place).value);
|
||||
}
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
bool isCompilable() const override
|
||||
{
|
||||
auto return_type = this->getResultType();
|
||||
return canBeNativeType(*return_type);
|
||||
}
|
||||
|
||||
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
|
||||
{
|
||||
auto * value_ptr = aggregate_data_ptr;
|
||||
Data::compileCreate(builder, value_ptr);
|
||||
}
|
||||
|
||||
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, this->getResultType());
|
||||
|
||||
auto * value_ptr = aggregate_data_ptr;
|
||||
auto * value = b.CreateLoad(return_type, value_ptr);
|
||||
|
||||
auto * result_value = Data::compileUpdate(builder, value, arguments[0].value);
|
||||
|
||||
b.CreateStore(result_value, value_ptr);
|
||||
}
|
||||
|
||||
void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, this->getResultType());
|
||||
|
||||
auto * value_dst_ptr = aggregate_data_dst_ptr;
|
||||
auto * value_dst = b.CreateLoad(return_type, value_dst_ptr);
|
||||
|
||||
auto * value_src_ptr = aggregate_data_src_ptr;
|
||||
auto * value_src = b.CreateLoad(return_type, value_src_ptr);
|
||||
|
||||
auto * result_value = Data::compileUpdate(builder, value_dst, value_src);
|
||||
|
||||
b.CreateStore(result_value, value_dst_ptr);
|
||||
}
|
||||
|
||||
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, this->getResultType());
|
||||
auto * value_ptr = aggregate_data_ptr;
|
||||
|
||||
return b.CreateLoad(return_type, value_ptr);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
};
|
||||
|
||||
|
||||
}
|
@ -1,7 +1,14 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionBoundingRatio.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <Common/transformEndianness.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -10,11 +17,169 @@ struct Settings;
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
/** Tracks the leftmost and rightmost (x, y) data points.
|
||||
*/
|
||||
struct AggregateFunctionBoundingRatioData
|
||||
{
|
||||
struct Point
|
||||
{
|
||||
Float64 x;
|
||||
Float64 y;
|
||||
};
|
||||
|
||||
bool empty = true;
|
||||
Point left;
|
||||
Point right;
|
||||
|
||||
void add(Float64 x, Float64 y)
|
||||
{
|
||||
Point point{x, y};
|
||||
|
||||
if (empty)
|
||||
{
|
||||
left = point;
|
||||
right = point;
|
||||
empty = false;
|
||||
}
|
||||
else if (point.x < left.x)
|
||||
{
|
||||
left = point;
|
||||
}
|
||||
else if (point.x > right.x)
|
||||
{
|
||||
right = point;
|
||||
}
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionBoundingRatioData & other)
|
||||
{
|
||||
if (empty)
|
||||
{
|
||||
*this = other;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (other.left.x < left.x)
|
||||
left = other.left;
|
||||
if (other.right.x > right.x)
|
||||
right = other.right;
|
||||
}
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const;
|
||||
void deserialize(ReadBuffer & buf);
|
||||
};
|
||||
|
||||
template <std::endian endian>
|
||||
inline void transformEndianness(AggregateFunctionBoundingRatioData::Point & p)
|
||||
{
|
||||
DB::transformEndianness<endian>(p.x);
|
||||
DB::transformEndianness<endian>(p.y);
|
||||
}
|
||||
|
||||
void AggregateFunctionBoundingRatioData::serialize(WriteBuffer & buf) const
|
||||
{
|
||||
writeBinaryLittleEndian(empty, buf);
|
||||
|
||||
if (!empty)
|
||||
{
|
||||
writeBinaryLittleEndian(left, buf);
|
||||
writeBinaryLittleEndian(right, buf);
|
||||
}
|
||||
}
|
||||
|
||||
void AggregateFunctionBoundingRatioData::deserialize(ReadBuffer & buf)
|
||||
{
|
||||
readBinaryLittleEndian(empty, buf);
|
||||
|
||||
if (!empty)
|
||||
{
|
||||
readBinaryLittleEndian(left, buf);
|
||||
readBinaryLittleEndian(right, buf);
|
||||
}
|
||||
}
|
||||
|
||||
inline void writeBinary(const AggregateFunctionBoundingRatioData::Point & p, WriteBuffer & buf)
|
||||
{
|
||||
writePODBinary(p, buf);
|
||||
}
|
||||
|
||||
inline void readBinary(AggregateFunctionBoundingRatioData::Point & p, ReadBuffer & buf)
|
||||
{
|
||||
readPODBinary(p, buf);
|
||||
}
|
||||
|
||||
|
||||
class AggregateFunctionBoundingRatio final : public IAggregateFunctionDataHelper<AggregateFunctionBoundingRatioData, AggregateFunctionBoundingRatio>
|
||||
{
|
||||
private:
|
||||
/** Calculates the slope of a line between leftmost and rightmost data points.
|
||||
* (y2 - y1) / (x2 - x1)
|
||||
*/
|
||||
static Float64 NO_SANITIZE_UNDEFINED getBoundingRatio(const AggregateFunctionBoundingRatioData & data)
|
||||
{
|
||||
if (data.empty)
|
||||
return std::numeric_limits<Float64>::quiet_NaN();
|
||||
|
||||
return (data.right.y - data.left.y) / (data.right.x - data.left.x);
|
||||
}
|
||||
|
||||
public:
|
||||
String getName() const override
|
||||
{
|
||||
return "boundingRatio";
|
||||
}
|
||||
|
||||
explicit AggregateFunctionBoundingRatio(const DataTypes & arguments)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionBoundingRatioData, AggregateFunctionBoundingRatio>(arguments, {}, std::make_shared<DataTypeFloat64>())
|
||||
{
|
||||
const auto * x_arg = arguments.at(0).get();
|
||||
const auto * y_arg = arguments.at(1).get();
|
||||
|
||||
if (!x_arg->isValueRepresentedByNumber() || !y_arg->isValueRepresentedByNumber())
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS,
|
||||
"Illegal types of arguments of aggregate function {}, must have number representation.",
|
||||
getName());
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override
|
||||
{
|
||||
/// NOTE Slightly inefficient.
|
||||
const auto x = columns[0]->getFloat64(row_num);
|
||||
const auto y = columns[1]->getFloat64(row_num);
|
||||
data(place).add(x, y);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
data(place).merge(data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
assert_cast<ColumnFloat64 &>(to).getData().push_back(getBoundingRatio(data(place)));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionRate(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
|
@ -1,177 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
/** Tracks the leftmost and rightmost (x, y) data points.
|
||||
*/
|
||||
struct AggregateFunctionBoundingRatioData
|
||||
{
|
||||
struct Point
|
||||
{
|
||||
Float64 x;
|
||||
Float64 y;
|
||||
};
|
||||
|
||||
bool empty = true;
|
||||
Point left;
|
||||
Point right;
|
||||
|
||||
void add(Float64 x, Float64 y)
|
||||
{
|
||||
Point point{x, y};
|
||||
|
||||
if (empty)
|
||||
{
|
||||
left = point;
|
||||
right = point;
|
||||
empty = false;
|
||||
}
|
||||
else if (point.x < left.x)
|
||||
{
|
||||
left = point;
|
||||
}
|
||||
else if (point.x > right.x)
|
||||
{
|
||||
right = point;
|
||||
}
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionBoundingRatioData & other)
|
||||
{
|
||||
if (empty)
|
||||
{
|
||||
*this = other;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (other.left.x < left.x)
|
||||
left = other.left;
|
||||
if (other.right.x > right.x)
|
||||
right = other.right;
|
||||
}
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const;
|
||||
void deserialize(ReadBuffer & buf);
|
||||
};
|
||||
|
||||
template <std::endian endian>
|
||||
inline void transformEndianness(AggregateFunctionBoundingRatioData::Point & p)
|
||||
{
|
||||
transformEndianness<endian>(p.x);
|
||||
transformEndianness<endian>(p.y);
|
||||
}
|
||||
|
||||
void AggregateFunctionBoundingRatioData::serialize(WriteBuffer & buf) const
|
||||
{
|
||||
writeBinaryLittleEndian(empty, buf);
|
||||
|
||||
if (!empty)
|
||||
{
|
||||
writeBinaryLittleEndian(left, buf);
|
||||
writeBinaryLittleEndian(right, buf);
|
||||
}
|
||||
}
|
||||
|
||||
void AggregateFunctionBoundingRatioData::deserialize(ReadBuffer & buf)
|
||||
{
|
||||
readBinaryLittleEndian(empty, buf);
|
||||
|
||||
if (!empty)
|
||||
{
|
||||
readBinaryLittleEndian(left, buf);
|
||||
readBinaryLittleEndian(right, buf);
|
||||
}
|
||||
}
|
||||
|
||||
inline void writeBinary(const AggregateFunctionBoundingRatioData::Point & p, WriteBuffer & buf)
|
||||
{
|
||||
writePODBinary(p, buf);
|
||||
}
|
||||
|
||||
inline void readBinary(AggregateFunctionBoundingRatioData::Point & p, ReadBuffer & buf)
|
||||
{
|
||||
readPODBinary(p, buf);
|
||||
}
|
||||
|
||||
|
||||
class AggregateFunctionBoundingRatio final : public IAggregateFunctionDataHelper<AggregateFunctionBoundingRatioData, AggregateFunctionBoundingRatio>
|
||||
{
|
||||
private:
|
||||
/** Calculates the slope of a line between leftmost and rightmost data points.
|
||||
* (y2 - y1) / (x2 - x1)
|
||||
*/
|
||||
static Float64 NO_SANITIZE_UNDEFINED getBoundingRatio(const AggregateFunctionBoundingRatioData & data)
|
||||
{
|
||||
if (data.empty)
|
||||
return std::numeric_limits<Float64>::quiet_NaN();
|
||||
|
||||
return (data.right.y - data.left.y) / (data.right.x - data.left.x);
|
||||
}
|
||||
|
||||
public:
|
||||
String getName() const override
|
||||
{
|
||||
return "boundingRatio";
|
||||
}
|
||||
|
||||
explicit AggregateFunctionBoundingRatio(const DataTypes & arguments)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionBoundingRatioData, AggregateFunctionBoundingRatio>(arguments, {}, std::make_shared<DataTypeFloat64>())
|
||||
{
|
||||
const auto * x_arg = arguments.at(0).get();
|
||||
const auto * y_arg = arguments.at(1).get();
|
||||
|
||||
if (!x_arg->isValueRepresentedByNumber() || !y_arg->isValueRepresentedByNumber())
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS,
|
||||
"Illegal types of arguments of aggregate function {}, must have number representation.",
|
||||
getName());
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override
|
||||
{
|
||||
/// NOTE Slightly inefficient.
|
||||
const auto x = columns[0]->getFloat64(row_num);
|
||||
const auto y = columns[1]->getFloat64(row_num);
|
||||
data(place).add(x, y);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
data(place).merge(data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
assert_cast<ColumnFloat64 &>(to).getData().push_back(getBoundingRatio(data(place)));
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -1,9 +1,15 @@
|
||||
#include <AggregateFunctions/AggregateFunctionDeltaSum.h>
|
||||
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -18,6 +24,113 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
struct AggregationFunctionDeltaSumData
|
||||
{
|
||||
T sum = 0;
|
||||
T last = 0;
|
||||
T first = 0;
|
||||
bool seen = false;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class AggregationFunctionDeltaSum final
|
||||
: public IAggregateFunctionDataHelper<AggregationFunctionDeltaSumData<T>, AggregationFunctionDeltaSum<T>>
|
||||
{
|
||||
public:
|
||||
AggregationFunctionDeltaSum(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregationFunctionDeltaSumData<T>, AggregationFunctionDeltaSum<T>>{arguments, params, createResultType()}
|
||||
{}
|
||||
|
||||
AggregationFunctionDeltaSum()
|
||||
: IAggregateFunctionDataHelper<AggregationFunctionDeltaSumData<T>, AggregationFunctionDeltaSum<T>>{}
|
||||
{}
|
||||
|
||||
String getName() const override { return "deltaSum"; }
|
||||
|
||||
static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<T>>(); }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void NO_SANITIZE_UNDEFINED ALWAYS_INLINE add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
auto value = assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];
|
||||
|
||||
if ((this->data(place).last < value) && this->data(place).seen)
|
||||
{
|
||||
this->data(place).sum += (value - this->data(place).last);
|
||||
}
|
||||
|
||||
this->data(place).last = value;
|
||||
|
||||
if (!this->data(place).seen)
|
||||
{
|
||||
this->data(place).first = value;
|
||||
this->data(place).seen = true;
|
||||
}
|
||||
}
|
||||
|
||||
void NO_SANITIZE_UNDEFINED merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
auto place_data = &this->data(place);
|
||||
auto rhs_data = &this->data(rhs);
|
||||
|
||||
if ((place_data->last < rhs_data->first) && place_data->seen && rhs_data->seen)
|
||||
{
|
||||
// If the lhs last number seen is less than the first number the rhs saw, the lhs is before
|
||||
// the rhs, for example [0, 2] [4, 7]. So we want to add the deltasums, but also add the
|
||||
// difference between lhs last number and rhs first number (the 2 and 4). Then we want to
|
||||
// take last value from the rhs, so first and last become 0 and 7.
|
||||
|
||||
place_data->sum += rhs_data->sum + (rhs_data->first - place_data->last);
|
||||
place_data->last = rhs_data->last;
|
||||
}
|
||||
else if ((rhs_data->first < place_data->last && rhs_data->seen && place_data->seen))
|
||||
{
|
||||
// In the opposite scenario, the lhs comes after the rhs, e.g. [4, 6] [1, 2]. Since we
|
||||
// assume the input interval states are sorted by time, we assume this is a counter
|
||||
// reset, and therefore do *not* add the difference between our first value and the
|
||||
// rhs last value.
|
||||
|
||||
place_data->sum += rhs_data->sum;
|
||||
place_data->last = rhs_data->last;
|
||||
}
|
||||
else if (rhs_data->seen && !place_data->seen)
|
||||
{
|
||||
// If we're here then the lhs is an empty state and the rhs does have some state, so
|
||||
// we'll just take that state.
|
||||
|
||||
place_data->first = rhs_data->first;
|
||||
place_data->last = rhs_data->last;
|
||||
place_data->sum = rhs_data->sum;
|
||||
place_data->seen = rhs_data->seen;
|
||||
}
|
||||
|
||||
// Otherwise lhs either has data or is uninitialized, so we don't need to modify its values.
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
writeBinaryLittleEndian(this->data(place).sum, buf);
|
||||
writeBinaryLittleEndian(this->data(place).first, buf);
|
||||
writeBinaryLittleEndian(this->data(place).last, buf);
|
||||
writeBinaryLittleEndian(this->data(place).seen, buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
readBinaryLittleEndian(this->data(place).sum, buf);
|
||||
readBinaryLittleEndian(this->data(place).first, buf);
|
||||
readBinaryLittleEndian(this->data(place).last, buf);
|
||||
readBinaryLittleEndian(this->data(place).seen, buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
assert_cast<ColumnVector<T> &>(to).getData().push_back(this->data(place).sum);
|
||||
}
|
||||
};
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionDeltaSum(
|
||||
const String & name,
|
||||
const DataTypes & arguments,
|
||||
|
@ -1,126 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <DataTypes/DataTypesDecimal.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
template <typename T>
|
||||
struct AggregationFunctionDeltaSumData
|
||||
{
|
||||
T sum = 0;
|
||||
T last = 0;
|
||||
T first = 0;
|
||||
bool seen = false;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class AggregationFunctionDeltaSum final
|
||||
: public IAggregateFunctionDataHelper<AggregationFunctionDeltaSumData<T>, AggregationFunctionDeltaSum<T>>
|
||||
{
|
||||
public:
|
||||
AggregationFunctionDeltaSum(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregationFunctionDeltaSumData<T>, AggregationFunctionDeltaSum<T>>{arguments, params, createResultType()}
|
||||
{}
|
||||
|
||||
AggregationFunctionDeltaSum()
|
||||
: IAggregateFunctionDataHelper<AggregationFunctionDeltaSumData<T>, AggregationFunctionDeltaSum<T>>{}
|
||||
{}
|
||||
|
||||
String getName() const override { return "deltaSum"; }
|
||||
|
||||
static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<T>>(); }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void NO_SANITIZE_UNDEFINED ALWAYS_INLINE add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
auto value = assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];
|
||||
|
||||
if ((this->data(place).last < value) && this->data(place).seen)
|
||||
{
|
||||
this->data(place).sum += (value - this->data(place).last);
|
||||
}
|
||||
|
||||
this->data(place).last = value;
|
||||
|
||||
if (!this->data(place).seen)
|
||||
{
|
||||
this->data(place).first = value;
|
||||
this->data(place).seen = true;
|
||||
}
|
||||
}
|
||||
|
||||
void NO_SANITIZE_UNDEFINED ALWAYS_INLINE merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
auto place_data = &this->data(place);
|
||||
auto rhs_data = &this->data(rhs);
|
||||
|
||||
if ((place_data->last < rhs_data->first) && place_data->seen && rhs_data->seen)
|
||||
{
|
||||
// If the lhs last number seen is less than the first number the rhs saw, the lhs is before
|
||||
// the rhs, for example [0, 2] [4, 7]. So we want to add the deltasums, but also add the
|
||||
// difference between lhs last number and rhs first number (the 2 and 4). Then we want to
|
||||
// take last value from the rhs, so first and last become 0 and 7.
|
||||
|
||||
place_data->sum += rhs_data->sum + (rhs_data->first - place_data->last);
|
||||
place_data->last = rhs_data->last;
|
||||
}
|
||||
else if ((rhs_data->first < place_data->last && rhs_data->seen && place_data->seen))
|
||||
{
|
||||
// In the opposite scenario, the lhs comes after the rhs, e.g. [4, 6] [1, 2]. Since we
|
||||
// assume the input interval states are sorted by time, we assume this is a counter
|
||||
// reset, and therefore do *not* add the difference between our first value and the
|
||||
// rhs last value.
|
||||
|
||||
place_data->sum += rhs_data->sum;
|
||||
place_data->last = rhs_data->last;
|
||||
}
|
||||
else if (rhs_data->seen && !place_data->seen)
|
||||
{
|
||||
// If we're here then the lhs is an empty state and the rhs does have some state, so
|
||||
// we'll just take that state.
|
||||
|
||||
place_data->first = rhs_data->first;
|
||||
place_data->last = rhs_data->last;
|
||||
place_data->sum = rhs_data->sum;
|
||||
place_data->seen = rhs_data->seen;
|
||||
}
|
||||
|
||||
// Otherwise lhs either has data or is uninitialized, so we don't need to modify its values.
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
writeBinaryLittleEndian(this->data(place).sum, buf);
|
||||
writeBinaryLittleEndian(this->data(place).first, buf);
|
||||
writeBinaryLittleEndian(this->data(place).last, buf);
|
||||
writeBinaryLittleEndian(this->data(place).seen, buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
readBinaryLittleEndian(this->data(place).sum, buf);
|
||||
readBinaryLittleEndian(this->data(place).first, buf);
|
||||
readBinaryLittleEndian(this->data(place).last, buf);
|
||||
readBinaryLittleEndian(this->data(place).seen, buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
assert_cast<ColumnVector<T> &>(to).getData().push_back(this->data(place).sum);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -1,22 +1,181 @@
|
||||
#include <AggregateFunctions/AggregateFunctionDeltaSumTimestamp.h>
|
||||
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename ValueType, typename TimestampType>
|
||||
struct AggregationFunctionDeltaSumTimestampData
|
||||
{
|
||||
ValueType sum = 0;
|
||||
ValueType first = 0;
|
||||
ValueType last = 0;
|
||||
TimestampType first_ts = 0;
|
||||
TimestampType last_ts = 0;
|
||||
bool seen = false;
|
||||
};
|
||||
|
||||
template <typename ValueType, typename TimestampType>
|
||||
class AggregationFunctionDeltaSumTimestamp final
|
||||
: public IAggregateFunctionDataHelper<
|
||||
AggregationFunctionDeltaSumTimestampData<ValueType, TimestampType>,
|
||||
AggregationFunctionDeltaSumTimestamp<ValueType, TimestampType>
|
||||
>
|
||||
{
|
||||
public:
|
||||
AggregationFunctionDeltaSumTimestamp(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<
|
||||
AggregationFunctionDeltaSumTimestampData<ValueType, TimestampType>,
|
||||
AggregationFunctionDeltaSumTimestamp<ValueType, TimestampType>
|
||||
>{arguments, params, createResultType()}
|
||||
{}
|
||||
|
||||
AggregationFunctionDeltaSumTimestamp()
|
||||
: IAggregateFunctionDataHelper<
|
||||
AggregationFunctionDeltaSumTimestampData<ValueType, TimestampType>,
|
||||
AggregationFunctionDeltaSumTimestamp<ValueType, TimestampType>
|
||||
>{}
|
||||
{}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
String getName() const override { return "deltaSumTimestamp"; }
|
||||
|
||||
static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<ValueType>>(); }
|
||||
|
||||
void NO_SANITIZE_UNDEFINED ALWAYS_INLINE add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
auto value = assert_cast<const ColumnVector<ValueType> &>(*columns[0]).getData()[row_num];
|
||||
auto ts = assert_cast<const ColumnVector<TimestampType> &>(*columns[1]).getData()[row_num];
|
||||
|
||||
auto & data = this->data(place);
|
||||
|
||||
if ((data.last < value) && data.seen)
|
||||
{
|
||||
data.sum += (value - data.last);
|
||||
}
|
||||
|
||||
data.last = value;
|
||||
data.last_ts = ts;
|
||||
|
||||
if (!data.seen)
|
||||
{
|
||||
data.first = value;
|
||||
data.seen = true;
|
||||
data.first_ts = ts;
|
||||
}
|
||||
}
|
||||
|
||||
// before returns true if lhs is before rhs or false if it is not or can't be determined
|
||||
bool ALWAYS_INLINE before(
|
||||
const AggregationFunctionDeltaSumTimestampData<ValueType, TimestampType> & lhs,
|
||||
const AggregationFunctionDeltaSumTimestampData<ValueType, TimestampType> & rhs) const
|
||||
{
|
||||
if (lhs.last_ts < rhs.first_ts)
|
||||
return true;
|
||||
if (lhs.last_ts == rhs.first_ts && (lhs.last_ts < rhs.last_ts || lhs.first_ts < rhs.first_ts))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
void NO_SANITIZE_UNDEFINED merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
auto & place_data = this->data(place);
|
||||
auto & rhs_data = this->data(rhs);
|
||||
|
||||
if (!place_data.seen && rhs_data.seen)
|
||||
{
|
||||
place_data.sum = rhs_data.sum;
|
||||
place_data.seen = true;
|
||||
place_data.first = rhs_data.first;
|
||||
place_data.first_ts = rhs_data.first_ts;
|
||||
place_data.last = rhs_data.last;
|
||||
place_data.last_ts = rhs_data.last_ts;
|
||||
}
|
||||
else if (place_data.seen && !rhs_data.seen)
|
||||
{
|
||||
return;
|
||||
}
|
||||
else if (before(place_data, rhs_data))
|
||||
{
|
||||
// This state came before the rhs state
|
||||
|
||||
if (rhs_data.first > place_data.last)
|
||||
place_data.sum += (rhs_data.first - place_data.last);
|
||||
place_data.sum += rhs_data.sum;
|
||||
place_data.last = rhs_data.last;
|
||||
place_data.last_ts = rhs_data.last_ts;
|
||||
}
|
||||
else if (before(rhs_data, place_data))
|
||||
{
|
||||
// This state came after the rhs state
|
||||
|
||||
if (place_data.first > rhs_data.last)
|
||||
place_data.sum += (place_data.first - rhs_data.last);
|
||||
place_data.sum += rhs_data.sum;
|
||||
place_data.first = rhs_data.first;
|
||||
place_data.first_ts = rhs_data.first_ts;
|
||||
}
|
||||
else
|
||||
{
|
||||
// If none of those conditions matched, it means both states we are merging have all
|
||||
// same timestamps. We have to pick either the smaller or larger value so that the
|
||||
// result is deterministic.
|
||||
|
||||
if (place_data.first < rhs_data.first)
|
||||
{
|
||||
place_data.first = rhs_data.first;
|
||||
place_data.last = rhs_data.last;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
const auto & data = this->data(place);
|
||||
writeBinaryLittleEndian(data.sum, buf);
|
||||
writeBinaryLittleEndian(data.first, buf);
|
||||
writeBinaryLittleEndian(data.first_ts, buf);
|
||||
writeBinaryLittleEndian(data.last, buf);
|
||||
writeBinaryLittleEndian(data.last_ts, buf);
|
||||
writeBinaryLittleEndian(data.seen, buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
auto & data = this->data(place);
|
||||
readBinaryLittleEndian(data.sum, buf);
|
||||
readBinaryLittleEndian(data.first, buf);
|
||||
readBinaryLittleEndian(data.first_ts, buf);
|
||||
readBinaryLittleEndian(data.last, buf);
|
||||
readBinaryLittleEndian(data.last_ts, buf);
|
||||
readBinaryLittleEndian(data.seen, buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
assert_cast<ColumnVector<ValueType> &>(to).getData().push_back(this->data(place).sum);
|
||||
}
|
||||
};
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionDeltaSumTimestamp(
|
||||
const String & name,
|
||||
const DataTypes & arguments,
|
||||
@ -24,10 +183,7 @@ AggregateFunctionPtr createAggregateFunctionDeltaSumTimestamp(
|
||||
const Settings *)
|
||||
{
|
||||
assertNoParameters(name, params);
|
||||
|
||||
if (arguments.size() != 2)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
|
||||
"Incorrect number of arguments for aggregate function {}", name);
|
||||
assertBinary(name, arguments);
|
||||
|
||||
if (!isInteger(arguments[0]) && !isFloat(arguments[0]) && !isDate(arguments[0]) && !isDateTime(arguments[0]))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}, "
|
||||
|
@ -1,171 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <DataTypes/DataTypesDecimal.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
template <typename ValueType, typename TimestampType>
|
||||
struct AggregationFunctionDeltaSumTimestampData
|
||||
{
|
||||
ValueType sum = 0;
|
||||
ValueType first = 0;
|
||||
ValueType last = 0;
|
||||
TimestampType first_ts = 0;
|
||||
TimestampType last_ts = 0;
|
||||
bool seen = false;
|
||||
};
|
||||
|
||||
template <typename ValueType, typename TimestampType>
|
||||
class AggregationFunctionDeltaSumTimestamp final
|
||||
: public IAggregateFunctionDataHelper<
|
||||
AggregationFunctionDeltaSumTimestampData<ValueType, TimestampType>,
|
||||
AggregationFunctionDeltaSumTimestamp<ValueType, TimestampType>
|
||||
>
|
||||
{
|
||||
public:
|
||||
AggregationFunctionDeltaSumTimestamp(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<
|
||||
AggregationFunctionDeltaSumTimestampData<ValueType, TimestampType>,
|
||||
AggregationFunctionDeltaSumTimestamp<ValueType, TimestampType>
|
||||
>{arguments, params, createResultType()}
|
||||
{}
|
||||
|
||||
AggregationFunctionDeltaSumTimestamp()
|
||||
: IAggregateFunctionDataHelper<
|
||||
AggregationFunctionDeltaSumTimestampData<ValueType, TimestampType>,
|
||||
AggregationFunctionDeltaSumTimestamp<ValueType, TimestampType>
|
||||
>{}
|
||||
{}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
String getName() const override { return "deltaSumTimestamp"; }
|
||||
|
||||
static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<ValueType>>(); }
|
||||
|
||||
void NO_SANITIZE_UNDEFINED ALWAYS_INLINE add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
auto value = assert_cast<const ColumnVector<ValueType> &>(*columns[0]).getData()[row_num];
|
||||
auto ts = assert_cast<const ColumnVector<TimestampType> &>(*columns[1]).getData()[row_num];
|
||||
|
||||
if ((this->data(place).last < value) && this->data(place).seen)
|
||||
{
|
||||
this->data(place).sum += (value - this->data(place).last);
|
||||
}
|
||||
|
||||
this->data(place).last = value;
|
||||
this->data(place).last_ts = ts;
|
||||
|
||||
if (!this->data(place).seen)
|
||||
{
|
||||
this->data(place).first = value;
|
||||
this->data(place).seen = true;
|
||||
this->data(place).first_ts = ts;
|
||||
}
|
||||
}
|
||||
|
||||
// before returns true if lhs is before rhs or false if it is not or can't be determined
|
||||
bool ALWAYS_INLINE before (
|
||||
const AggregationFunctionDeltaSumTimestampData<ValueType, TimestampType> * lhs,
|
||||
const AggregationFunctionDeltaSumTimestampData<ValueType, TimestampType> * rhs
|
||||
) const
|
||||
{
|
||||
if (lhs->last_ts < rhs->first_ts)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
if (lhs->last_ts == rhs->first_ts && (lhs->last_ts < rhs->last_ts || lhs->first_ts < rhs->first_ts))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void NO_SANITIZE_UNDEFINED ALWAYS_INLINE merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
auto place_data = &this->data(place);
|
||||
auto rhs_data = &this->data(rhs);
|
||||
|
||||
if (!place_data->seen && rhs_data->seen)
|
||||
{
|
||||
place_data->sum = rhs_data->sum;
|
||||
place_data->seen = true;
|
||||
place_data->first = rhs_data->first;
|
||||
place_data->first_ts = rhs_data->first_ts;
|
||||
place_data->last = rhs_data->last;
|
||||
place_data->last_ts = rhs_data->last_ts;
|
||||
}
|
||||
else if (place_data->seen && !rhs_data->seen)
|
||||
return;
|
||||
else if (before(place_data, rhs_data))
|
||||
{
|
||||
// This state came before the rhs state
|
||||
|
||||
if (rhs_data->first > place_data->last)
|
||||
place_data->sum += (rhs_data->first - place_data->last);
|
||||
place_data->sum += rhs_data->sum;
|
||||
place_data->last = rhs_data->last;
|
||||
place_data->last_ts = rhs_data->last_ts;
|
||||
}
|
||||
else if (before(rhs_data, place_data))
|
||||
{
|
||||
// This state came after the rhs state
|
||||
|
||||
if (place_data->first > rhs_data->last)
|
||||
place_data->sum += (place_data->first - rhs_data->last);
|
||||
place_data->sum += rhs_data->sum;
|
||||
place_data->first = rhs_data->first;
|
||||
place_data->first_ts = rhs_data->first_ts;
|
||||
}
|
||||
else
|
||||
{
|
||||
// If none of those conditions matched, it means both states we are merging have all
|
||||
// same timestamps. We have to pick either the smaller or larger value so that the
|
||||
// result is deterministic.
|
||||
|
||||
if (place_data->first < rhs_data->first)
|
||||
{
|
||||
place_data->first = rhs_data->first;
|
||||
place_data->last = rhs_data->last;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
writeBinaryLittleEndian(this->data(place).sum, buf);
|
||||
writeBinaryLittleEndian(this->data(place).first, buf);
|
||||
writeBinaryLittleEndian(this->data(place).first_ts, buf);
|
||||
writeBinaryLittleEndian(this->data(place).last, buf);
|
||||
writeBinaryLittleEndian(this->data(place).last_ts, buf);
|
||||
writeBinaryLittleEndian(this->data(place).seen, buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
readBinaryLittleEndian(this->data(place).sum, buf);
|
||||
readBinaryLittleEndian(this->data(place).first, buf);
|
||||
readBinaryLittleEndian(this->data(place).first_ts, buf);
|
||||
readBinaryLittleEndian(this->data(place).last, buf);
|
||||
readBinaryLittleEndian(this->data(place).last_ts, buf);
|
||||
readBinaryLittleEndian(this->data(place).seen, buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
assert_cast<ColumnVector<ValueType> &>(to).getData().push_back(this->data(place).sum);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -1,8 +1,18 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionEntropy.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
|
||||
#include <Common/HashTable/HashMap.h>
|
||||
#include <Common/NaNUtils.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/UniqVariadicHash.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <cmath>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -16,6 +26,133 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
/** Calculates Shannon Entropy, using HashMap and computing empirical distribution function.
|
||||
* Entropy is measured in bits (base-2 logarithm is used).
|
||||
*/
|
||||
template <typename Value>
|
||||
struct EntropyData
|
||||
{
|
||||
using Weight = UInt64;
|
||||
|
||||
using HashingMap = HashMapWithStackMemory<Value, Weight, HashCRC32<Value>, 4>;
|
||||
|
||||
/// For the case of pre-hashed values.
|
||||
using TrivialMap = HashMapWithStackMemory<Value, Weight, UInt128TrivialHash, 4>;
|
||||
|
||||
using Map = std::conditional_t<std::is_same_v<UInt128, Value>, TrivialMap, HashingMap>;
|
||||
|
||||
Map map;
|
||||
|
||||
void add(const Value & x)
|
||||
{
|
||||
if (!isNaN(x))
|
||||
++map[x];
|
||||
}
|
||||
|
||||
void add(const Value & x, const Weight & weight)
|
||||
{
|
||||
if (!isNaN(x))
|
||||
map[x] += weight;
|
||||
}
|
||||
|
||||
void merge(const EntropyData & rhs)
|
||||
{
|
||||
for (const auto & pair : rhs.map)
|
||||
map[pair.getKey()] += pair.getMapped();
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
map.write(buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
typename Map::Reader reader(buf);
|
||||
while (reader.next())
|
||||
{
|
||||
const auto & pair = reader.get();
|
||||
map[pair.first] = pair.second;
|
||||
}
|
||||
}
|
||||
|
||||
Float64 get() const
|
||||
{
|
||||
UInt64 total_value = 0;
|
||||
for (const auto & pair : map)
|
||||
total_value += pair.getMapped();
|
||||
|
||||
Float64 shannon_entropy = 0;
|
||||
for (const auto & pair : map)
|
||||
{
|
||||
Float64 frequency = Float64(pair.getMapped()) / total_value;
|
||||
shannon_entropy -= frequency * log2(frequency);
|
||||
}
|
||||
|
||||
return shannon_entropy;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename Value>
|
||||
class AggregateFunctionEntropy final : public IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>>
|
||||
{
|
||||
private:
|
||||
size_t num_args;
|
||||
|
||||
public:
|
||||
explicit AggregateFunctionEntropy(const DataTypes & argument_types_)
|
||||
: IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>>(argument_types_, {}, createResultType())
|
||||
, num_args(argument_types_.size())
|
||||
{
|
||||
}
|
||||
|
||||
String getName() const override { return "entropy"; }
|
||||
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
return std::make_shared<DataTypeNumber<Float64>>();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
if constexpr (!std::is_same_v<UInt128, Value>)
|
||||
{
|
||||
/// Here we manage only with numerical types
|
||||
const auto & column = assert_cast<const ColumnVector <Value> &>(*columns[0]);
|
||||
this->data(place).add(column.getData()[row_num]);
|
||||
}
|
||||
else
|
||||
{
|
||||
this->data(place).add(UniqVariadicHash<true, false>::apply(num_args, columns, row_num));
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(const_cast<AggregateDataPtr>(place)).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto & column = assert_cast<ColumnVector<Float64> &>(to);
|
||||
column.getData().push_back(this->data(place).get());
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionEntropy(
|
||||
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
|
@ -1,145 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <Common/HashTable/HashMap.h>
|
||||
#include <Common/NaNUtils.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/UniqVariadicHash.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <cmath>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
/** Calculates Shannon Entropy, using HashMap and computing empirical distribution function.
|
||||
* Entropy is measured in bits (base-2 logarithm is used).
|
||||
*/
|
||||
template <typename Value>
|
||||
struct EntropyData
|
||||
{
|
||||
using Weight = UInt64;
|
||||
|
||||
using HashingMap = HashMapWithStackMemory<Value, Weight, HashCRC32<Value>, 4>;
|
||||
|
||||
/// For the case of pre-hashed values.
|
||||
using TrivialMap = HashMapWithStackMemory<Value, Weight, UInt128TrivialHash, 4>;
|
||||
|
||||
using Map = std::conditional_t<std::is_same_v<UInt128, Value>, TrivialMap, HashingMap>;
|
||||
|
||||
Map map;
|
||||
|
||||
void add(const Value & x)
|
||||
{
|
||||
if (!isNaN(x))
|
||||
++map[x];
|
||||
}
|
||||
|
||||
void add(const Value & x, const Weight & weight)
|
||||
{
|
||||
if (!isNaN(x))
|
||||
map[x] += weight;
|
||||
}
|
||||
|
||||
void merge(const EntropyData & rhs)
|
||||
{
|
||||
for (const auto & pair : rhs.map)
|
||||
map[pair.getKey()] += pair.getMapped();
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
map.write(buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
typename Map::Reader reader(buf);
|
||||
while (reader.next())
|
||||
{
|
||||
const auto & pair = reader.get();
|
||||
map[pair.first] = pair.second;
|
||||
}
|
||||
}
|
||||
|
||||
Float64 get() const
|
||||
{
|
||||
UInt64 total_value = 0;
|
||||
for (const auto & pair : map)
|
||||
total_value += pair.getMapped();
|
||||
|
||||
Float64 shannon_entropy = 0;
|
||||
for (const auto & pair : map)
|
||||
{
|
||||
Float64 frequency = Float64(pair.getMapped()) / total_value;
|
||||
shannon_entropy -= frequency * log2(frequency);
|
||||
}
|
||||
|
||||
return shannon_entropy;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename Value>
|
||||
class AggregateFunctionEntropy final : public IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>>
|
||||
{
|
||||
private:
|
||||
size_t num_args;
|
||||
|
||||
public:
|
||||
explicit AggregateFunctionEntropy(const DataTypes & argument_types_)
|
||||
: IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>>(argument_types_, {}, createResultType())
|
||||
, num_args(argument_types_.size())
|
||||
{
|
||||
}
|
||||
|
||||
String getName() const override { return "entropy"; }
|
||||
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
return std::make_shared<DataTypeNumber<Float64>>();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
if constexpr (!std::is_same_v<UInt128, Value>)
|
||||
{
|
||||
/// Here we manage only with numerical types
|
||||
const auto & column = assert_cast<const ColumnVector <Value> &>(*columns[0]);
|
||||
this->data(place).add(column.getData()[row_num]);
|
||||
}
|
||||
else
|
||||
{
|
||||
this->data(place).add(UniqVariadicHash<true, false>::apply(num_args, columns, row_num));
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(const_cast<AggregateDataPtr>(place)).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto & column = assert_cast<ColumnVector<Float64> &>(to);
|
||||
column.getData().push_back(this->data(place).get());
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -1,12 +1,32 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionGroupArray.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <DataTypes/DataTypeDate.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Core/ServerSettings.h>
|
||||
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <IO/WriteBufferFromString.h>
|
||||
#include <IO/Operators.h>
|
||||
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
|
||||
#include <Common/ArenaAllocator.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#define AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ELEMENT_SIZE 0xFFFFFF
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -16,11 +36,670 @@ namespace ErrorCodes
|
||||
{
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
extern const int TOO_LARGE_ARRAY_SIZE;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
enum class Sampler
|
||||
{
|
||||
NONE,
|
||||
RNG,
|
||||
};
|
||||
|
||||
template <bool Thas_limit, bool Tlast, Sampler Tsampler>
|
||||
struct GroupArrayTrait
|
||||
{
|
||||
static constexpr bool has_limit = Thas_limit;
|
||||
static constexpr bool last = Tlast;
|
||||
static constexpr Sampler sampler = Tsampler;
|
||||
};
|
||||
|
||||
template <typename Trait>
|
||||
constexpr const char * getNameByTrait()
|
||||
{
|
||||
if (Trait::last)
|
||||
return "groupArrayLast";
|
||||
if (Trait::sampler == Sampler::NONE)
|
||||
return "groupArray";
|
||||
else if (Trait::sampler == Sampler::RNG)
|
||||
return "groupArraySample";
|
||||
|
||||
UNREACHABLE();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct GroupArraySamplerData
|
||||
{
|
||||
/// For easy serialization.
|
||||
static_assert(std::has_unique_object_representations_v<T> || std::is_floating_point_v<T>);
|
||||
|
||||
// Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena
|
||||
using Allocator = MixedAlignedArenaAllocator<alignof(T), 4096>;
|
||||
using Array = PODArray<T, 32, Allocator>;
|
||||
|
||||
Array value;
|
||||
size_t total_values = 0;
|
||||
pcg32_fast rng;
|
||||
|
||||
UInt64 genRandom(size_t lim)
|
||||
{
|
||||
chassert(lim != 0);
|
||||
|
||||
/// With a large number of values, we will generate random numbers several times slower.
|
||||
if (lim <= static_cast<UInt64>(rng.max()))
|
||||
return rng() % lim;
|
||||
else
|
||||
return (static_cast<UInt64>(rng()) * (static_cast<UInt64>(rng.max()) + 1ULL) + static_cast<UInt64>(rng())) % lim;
|
||||
}
|
||||
|
||||
void randomShuffle()
|
||||
{
|
||||
size_t size = value.size();
|
||||
chassert(size < std::numeric_limits<size_t>::max());
|
||||
|
||||
for (size_t i = 1; i < size; ++i)
|
||||
{
|
||||
size_t j = genRandom(i + 1);
|
||||
std::swap(value[i], value[j]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// A particular case is an implementation for numeric types.
|
||||
template <typename T, bool has_sampler>
|
||||
struct GroupArrayNumericData;
|
||||
|
||||
template <typename T>
|
||||
struct GroupArrayNumericData<T, false>
|
||||
{
|
||||
/// For easy serialization.
|
||||
static_assert(std::has_unique_object_representations_v<T> || std::is_floating_point_v<T>);
|
||||
|
||||
// Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena
|
||||
using Allocator = MixedAlignedArenaAllocator<alignof(T), 4096>;
|
||||
using Array = PODArray<T, 32, Allocator>;
|
||||
|
||||
// For groupArrayLast()
|
||||
size_t total_values = 0;
|
||||
Array value;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GroupArrayNumericData<T, true> : public GroupArraySamplerData<T>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T, typename Trait>
|
||||
class GroupArrayNumericImpl final
|
||||
: public IAggregateFunctionDataHelper<GroupArrayNumericData<T, Trait::sampler != Sampler::NONE>, GroupArrayNumericImpl<T, Trait>>
|
||||
{
|
||||
using Data = GroupArrayNumericData<T, Trait::sampler != Sampler::NONE>;
|
||||
static constexpr bool limit_num_elems = Trait::has_limit;
|
||||
UInt64 max_elems;
|
||||
std::optional<UInt64> seed;
|
||||
|
||||
public:
|
||||
explicit GroupArrayNumericImpl(
|
||||
const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_, std::optional<UInt64> seed_)
|
||||
: IAggregateFunctionDataHelper<GroupArrayNumericData<T, Trait::sampler != Sampler::NONE>, GroupArrayNumericImpl<T, Trait>>(
|
||||
{data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
|
||||
, max_elems(max_elems_)
|
||||
, seed(seed_)
|
||||
{
|
||||
}
|
||||
|
||||
String getName() const override { return getNameByTrait<Trait>(); }
|
||||
|
||||
void insertWithSampler(Data & a, const T & v, Arena * arena) const
|
||||
{
|
||||
++a.total_values;
|
||||
if (a.value.size() < max_elems)
|
||||
a.value.push_back(v, arena);
|
||||
else
|
||||
{
|
||||
UInt64 rnd = a.genRandom(a.total_values);
|
||||
if (rnd < max_elems)
|
||||
a.value[rnd] = v;
|
||||
}
|
||||
}
|
||||
|
||||
void create(AggregateDataPtr __restrict place) const override /// NOLINT
|
||||
{
|
||||
[[maybe_unused]] auto a = new (place) Data;
|
||||
if constexpr (Trait::sampler == Sampler::RNG)
|
||||
a->rng.seed(seed.value_or(thread_local_rng()));
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
const auto & row_value = assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];
|
||||
auto & cur_elems = this->data(place);
|
||||
|
||||
++cur_elems.total_values;
|
||||
|
||||
if constexpr (Trait::sampler == Sampler::NONE)
|
||||
{
|
||||
if (limit_num_elems && cur_elems.value.size() >= max_elems)
|
||||
{
|
||||
if constexpr (Trait::last)
|
||||
cur_elems.value[(cur_elems.total_values - 1) % max_elems] = row_value;
|
||||
return;
|
||||
}
|
||||
|
||||
cur_elems.value.push_back(row_value, arena);
|
||||
}
|
||||
|
||||
if constexpr (Trait::sampler == Sampler::RNG)
|
||||
{
|
||||
if (cur_elems.value.size() < max_elems)
|
||||
cur_elems.value.push_back(row_value, arena);
|
||||
else
|
||||
{
|
||||
UInt64 rnd = cur_elems.genRandom(cur_elems.total_values);
|
||||
if (rnd < max_elems)
|
||||
cur_elems.value[rnd] = row_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
auto & cur_elems = this->data(place);
|
||||
auto & rhs_elems = this->data(rhs);
|
||||
|
||||
if (rhs_elems.value.empty())
|
||||
return;
|
||||
|
||||
if constexpr (Trait::last)
|
||||
mergeNoSamplerLast(cur_elems, rhs_elems, arena);
|
||||
else if constexpr (Trait::sampler == Sampler::NONE)
|
||||
mergeNoSampler(cur_elems, rhs_elems, arena);
|
||||
else if constexpr (Trait::sampler == Sampler::RNG)
|
||||
mergeWithRNGSampler(cur_elems, rhs_elems, arena);
|
||||
}
|
||||
|
||||
void mergeNoSamplerLast(Data & cur_elems, const Data & rhs_elems, Arena * arena) const
|
||||
{
|
||||
UInt64 new_elements = std::min(static_cast<size_t>(max_elems), cur_elems.value.size() + rhs_elems.value.size());
|
||||
cur_elems.value.resize_exact(new_elements, arena);
|
||||
for (auto & value : rhs_elems.value)
|
||||
{
|
||||
cur_elems.value[cur_elems.total_values % max_elems] = value;
|
||||
++cur_elems.total_values;
|
||||
}
|
||||
chassert(rhs_elems.total_values >= rhs_elems.value.size());
|
||||
cur_elems.total_values += rhs_elems.total_values - rhs_elems.value.size();
|
||||
}
|
||||
|
||||
void mergeNoSampler(Data & cur_elems, const Data & rhs_elems, Arena * arena) const
|
||||
{
|
||||
if (!limit_num_elems)
|
||||
{
|
||||
if (rhs_elems.value.size())
|
||||
cur_elems.value.insertByOffsets(rhs_elems.value, 0, rhs_elems.value.size(), arena);
|
||||
}
|
||||
else
|
||||
{
|
||||
UInt64 elems_to_insert = std::min(static_cast<size_t>(max_elems) - cur_elems.value.size(), rhs_elems.value.size());
|
||||
if (elems_to_insert)
|
||||
cur_elems.value.insertByOffsets(rhs_elems.value, 0, elems_to_insert, arena);
|
||||
}
|
||||
}
|
||||
|
||||
void mergeWithRNGSampler(Data & cur_elems, const Data & rhs_elems, Arena * arena) const
|
||||
{
|
||||
if (rhs_elems.total_values <= max_elems)
|
||||
{
|
||||
for (size_t i = 0; i < rhs_elems.value.size(); ++i)
|
||||
insertWithSampler(cur_elems, rhs_elems.value[i], arena);
|
||||
}
|
||||
else if (cur_elems.total_values <= max_elems)
|
||||
{
|
||||
decltype(cur_elems.value) from;
|
||||
from.swap(cur_elems.value, arena);
|
||||
cur_elems.value.assign(rhs_elems.value.begin(), rhs_elems.value.end(), arena);
|
||||
cur_elems.total_values = rhs_elems.total_values;
|
||||
for (size_t i = 0; i < from.size(); ++i)
|
||||
insertWithSampler(cur_elems, from[i], arena);
|
||||
}
|
||||
else
|
||||
{
|
||||
cur_elems.randomShuffle();
|
||||
cur_elems.total_values += rhs_elems.total_values;
|
||||
for (size_t i = 0; i < max_elems; ++i)
|
||||
{
|
||||
UInt64 rnd = cur_elems.genRandom(cur_elems.total_values);
|
||||
if (rnd < rhs_elems.total_values)
|
||||
cur_elems.value[i] = rhs_elems.value[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void checkArraySize(size_t elems, size_t max_elems)
|
||||
{
|
||||
if (unlikely(elems > max_elems))
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
|
||||
"Too large array size {} (maximum: {})", elems, max_elems);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
const auto & value = this->data(place).value;
|
||||
const UInt64 size = value.size();
|
||||
checkArraySize(size, max_elems);
|
||||
writeVarUInt(size, buf);
|
||||
for (const auto & element : value)
|
||||
writeBinaryLittleEndian(element, buf);
|
||||
|
||||
if constexpr (Trait::last)
|
||||
writeBinaryLittleEndian(this->data(place).total_values, buf);
|
||||
|
||||
if constexpr (Trait::sampler == Sampler::RNG)
|
||||
{
|
||||
writeBinaryLittleEndian(this->data(place).total_values, buf);
|
||||
WriteBufferFromOwnString rng_buf;
|
||||
rng_buf << this->data(place).rng;
|
||||
writeStringBinary(rng_buf.str(), buf);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
size_t size = 0;
|
||||
readVarUInt(size, buf);
|
||||
checkArraySize(size, max_elems);
|
||||
|
||||
auto & value = this->data(place).value;
|
||||
|
||||
value.resize_exact(size, arena);
|
||||
for (auto & element : value)
|
||||
readBinaryLittleEndian(element, buf);
|
||||
|
||||
if constexpr (Trait::last)
|
||||
readBinaryLittleEndian(this->data(place).total_values, buf);
|
||||
|
||||
if constexpr (Trait::sampler == Sampler::RNG)
|
||||
{
|
||||
readBinaryLittleEndian(this->data(place).total_values, buf);
|
||||
std::string rng_string;
|
||||
readStringBinary(rng_string, buf);
|
||||
ReadBufferFromString rng_buf(rng_string);
|
||||
rng_buf >> this->data(place).rng;
|
||||
}
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
const auto & value = this->data(place).value;
|
||||
size_t size = value.size();
|
||||
|
||||
ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
|
||||
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
|
||||
|
||||
offsets_to.push_back(offsets_to.back() + size);
|
||||
|
||||
if (size)
|
||||
{
|
||||
typename ColumnVector<T>::Container & data_to = assert_cast<ColumnVector<T> &>(arr_to.getData()).getData();
|
||||
data_to.insert(this->data(place).value.begin(), this->data(place).value.end());
|
||||
}
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return true; }
|
||||
};
|
||||
|
||||
|
||||
/// General case
|
||||
|
||||
|
||||
/// Nodes used to implement a linked list for storage of groupArray states
|
||||
|
||||
template <typename Node>
|
||||
struct GroupArrayNodeBase
|
||||
{
|
||||
UInt64 size; // size of payload
|
||||
|
||||
/// Returns pointer to actual payload
|
||||
char * data() { return reinterpret_cast<char *>(this) + sizeof(Node); }
|
||||
|
||||
const char * data() const { return reinterpret_cast<const char *>(this) + sizeof(Node); }
|
||||
|
||||
/// Clones existing node (does not modify next field)
|
||||
Node * clone(Arena * arena) const
|
||||
{
|
||||
return reinterpret_cast<Node *>(
|
||||
const_cast<char *>(arena->alignedInsert(reinterpret_cast<const char *>(this), sizeof(Node) + size, alignof(Node))));
|
||||
}
|
||||
|
||||
static void checkElementSize(size_t size, size_t max_size)
|
||||
{
|
||||
if (unlikely(size > max_size))
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
|
||||
"Too large array element size {} (maximum: {})", size, max_size);
|
||||
}
|
||||
|
||||
/// Write node to buffer
|
||||
void write(WriteBuffer & buf) const
|
||||
{
|
||||
checkElementSize(size, AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ELEMENT_SIZE);
|
||||
writeVarUInt(size, buf);
|
||||
buf.write(data(), size);
|
||||
}
|
||||
|
||||
/// Reads and allocates node from ReadBuffer's data (doesn't set next)
|
||||
static Node * read(ReadBuffer & buf, Arena * arena)
|
||||
{
|
||||
UInt64 size;
|
||||
readVarUInt(size, buf);
|
||||
checkElementSize(size, AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ELEMENT_SIZE);
|
||||
|
||||
Node * node = reinterpret_cast<Node *>(arena->alignedAlloc(sizeof(Node) + size, alignof(Node)));
|
||||
node->size = size;
|
||||
buf.readStrict(node->data(), size);
|
||||
return node;
|
||||
}
|
||||
};
|
||||
|
||||
struct GroupArrayNodeString : public GroupArrayNodeBase<GroupArrayNodeString>
|
||||
{
|
||||
using Node = GroupArrayNodeString;
|
||||
|
||||
/// Create node from string
|
||||
static Node * allocate(const IColumn & column, size_t row_num, Arena * arena)
|
||||
{
|
||||
StringRef string = assert_cast<const ColumnString &>(column).getDataAt(row_num);
|
||||
|
||||
Node * node = reinterpret_cast<Node *>(arena->alignedAlloc(sizeof(Node) + string.size, alignof(Node)));
|
||||
node->size = string.size;
|
||||
memcpy(node->data(), string.data, string.size);
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
void insertInto(IColumn & column)
|
||||
{
|
||||
assert_cast<ColumnString &>(column).insertData(data(), size);
|
||||
}
|
||||
};
|
||||
|
||||
struct GroupArrayNodeGeneral : public GroupArrayNodeBase<GroupArrayNodeGeneral>
|
||||
{
|
||||
using Node = GroupArrayNodeGeneral;
|
||||
|
||||
static Node * allocate(const IColumn & column, size_t row_num, Arena * arena)
|
||||
{
|
||||
const char * begin = arena->alignedAlloc(sizeof(Node), alignof(Node));
|
||||
StringRef value = column.serializeValueIntoArena(row_num, *arena, begin);
|
||||
|
||||
Node * node = reinterpret_cast<Node *>(const_cast<char *>(begin));
|
||||
node->size = value.size;
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
void insertInto(IColumn & column) { column.deserializeAndInsertFromArena(data()); }
|
||||
};
|
||||
|
||||
template <typename Node, bool has_sampler>
|
||||
struct GroupArrayGeneralData;
|
||||
|
||||
template <typename Node>
|
||||
struct GroupArrayGeneralData<Node, false>
|
||||
{
|
||||
// Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena
|
||||
using Allocator = MixedAlignedArenaAllocator<alignof(Node *), 4096>;
|
||||
using Array = PODArray<Node *, 32, Allocator>;
|
||||
|
||||
// For groupArrayLast()
|
||||
size_t total_values = 0;
|
||||
Array value;
|
||||
};
|
||||
|
||||
template <typename Node>
|
||||
struct GroupArrayGeneralData<Node, true> : public GroupArraySamplerData<Node *>
|
||||
{
|
||||
};
|
||||
|
||||
/// Implementation of groupArray for String or any ComplexObject via Array
|
||||
template <typename Node, typename Trait>
|
||||
class GroupArrayGeneralImpl final
|
||||
: public IAggregateFunctionDataHelper<GroupArrayGeneralData<Node, Trait::sampler != Sampler::NONE>, GroupArrayGeneralImpl<Node, Trait>>
|
||||
{
|
||||
static constexpr bool limit_num_elems = Trait::has_limit;
|
||||
using Data = GroupArrayGeneralData<Node, Trait::sampler != Sampler::NONE>;
|
||||
static Data & data(AggregateDataPtr __restrict place) { return *reinterpret_cast<Data *>(place); }
|
||||
static const Data & data(ConstAggregateDataPtr __restrict place) { return *reinterpret_cast<const Data *>(place); }
|
||||
|
||||
DataTypePtr & data_type;
|
||||
UInt64 max_elems;
|
||||
std::optional<UInt64> seed;
|
||||
|
||||
public:
|
||||
GroupArrayGeneralImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_, std::optional<UInt64> seed_)
|
||||
: IAggregateFunctionDataHelper<GroupArrayGeneralData<Node, Trait::sampler != Sampler::NONE>, GroupArrayGeneralImpl<Node, Trait>>(
|
||||
{data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
|
||||
, data_type(this->argument_types[0])
|
||||
, max_elems(max_elems_)
|
||||
, seed(seed_)
|
||||
{
|
||||
}
|
||||
|
||||
String getName() const override { return getNameByTrait<Trait>(); }
|
||||
|
||||
void insertWithSampler(Data & a, const Node * v, Arena * arena) const
|
||||
{
|
||||
++a.total_values;
|
||||
if (a.value.size() < max_elems)
|
||||
a.value.push_back(v->clone(arena), arena);
|
||||
else
|
||||
{
|
||||
UInt64 rnd = a.genRandom(a.total_values);
|
||||
if (rnd < max_elems)
|
||||
a.value[rnd] = v->clone(arena);
|
||||
}
|
||||
}
|
||||
|
||||
void create(AggregateDataPtr __restrict place) const override /// NOLINT
|
||||
{
|
||||
[[maybe_unused]] auto a = new (place) Data;
|
||||
if constexpr (Trait::sampler == Sampler::RNG)
|
||||
a->rng.seed(seed.value_or(thread_local_rng()));
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
auto & cur_elems = data(place);
|
||||
|
||||
++cur_elems.total_values;
|
||||
|
||||
if constexpr (Trait::sampler == Sampler::NONE)
|
||||
{
|
||||
if (limit_num_elems && cur_elems.value.size() >= max_elems)
|
||||
{
|
||||
if (Trait::last)
|
||||
{
|
||||
Node * node = Node::allocate(*columns[0], row_num, arena);
|
||||
cur_elems.value[(cur_elems.total_values - 1) % max_elems] = node;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
Node * node = Node::allocate(*columns[0], row_num, arena);
|
||||
cur_elems.value.push_back(node, arena);
|
||||
}
|
||||
|
||||
if constexpr (Trait::sampler == Sampler::RNG)
|
||||
{
|
||||
if (cur_elems.value.size() < max_elems)
|
||||
cur_elems.value.push_back(Node::allocate(*columns[0], row_num, arena), arena);
|
||||
else
|
||||
{
|
||||
UInt64 rnd = cur_elems.genRandom(cur_elems.total_values);
|
||||
if (rnd < max_elems)
|
||||
cur_elems.value[rnd] = Node::allocate(*columns[0], row_num, arena);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
auto & cur_elems = data(place);
|
||||
auto & rhs_elems = data(rhs);
|
||||
|
||||
if (rhs_elems.value.empty())
|
||||
return;
|
||||
|
||||
if constexpr (Trait::last)
|
||||
mergeNoSamplerLast(cur_elems, rhs_elems, arena);
|
||||
else if constexpr (Trait::sampler == Sampler::NONE)
|
||||
mergeNoSampler(cur_elems, rhs_elems, arena);
|
||||
else if constexpr (Trait::sampler == Sampler::RNG)
|
||||
mergeWithRNGSampler(cur_elems, rhs_elems, arena);
|
||||
}
|
||||
|
||||
void ALWAYS_INLINE mergeNoSamplerLast(Data & cur_elems, const Data & rhs_elems, Arena * arena) const
|
||||
{
|
||||
UInt64 new_elements = std::min(static_cast<size_t>(max_elems), cur_elems.value.size() + rhs_elems.value.size());
|
||||
cur_elems.value.resize_exact(new_elements, arena);
|
||||
for (auto & value : rhs_elems.value)
|
||||
{
|
||||
cur_elems.value[cur_elems.total_values % max_elems] = value->clone(arena);
|
||||
++cur_elems.total_values;
|
||||
}
|
||||
chassert(rhs_elems.total_values >= rhs_elems.value.size());
|
||||
cur_elems.total_values += rhs_elems.total_values - rhs_elems.value.size();
|
||||
}
|
||||
|
||||
void ALWAYS_INLINE mergeNoSampler(Data & cur_elems, const Data & rhs_elems, Arena * arena) const
|
||||
{
|
||||
UInt64 new_elems;
|
||||
if (limit_num_elems)
|
||||
{
|
||||
if (cur_elems.value.size() >= max_elems)
|
||||
return;
|
||||
new_elems = std::min(rhs_elems.value.size(), static_cast<size_t>(max_elems) - cur_elems.value.size());
|
||||
}
|
||||
else
|
||||
new_elems = rhs_elems.value.size();
|
||||
|
||||
for (UInt64 i = 0; i < new_elems; ++i)
|
||||
cur_elems.value.push_back(rhs_elems.value[i]->clone(arena), arena);
|
||||
}
|
||||
|
||||
void ALWAYS_INLINE mergeWithRNGSampler(Data & cur_elems, const Data & rhs_elems, Arena * arena) const
|
||||
{
|
||||
if (rhs_elems.total_values <= max_elems)
|
||||
{
|
||||
for (size_t i = 0; i < rhs_elems.value.size(); ++i)
|
||||
insertWithSampler(cur_elems, rhs_elems.value[i], arena);
|
||||
}
|
||||
else if (cur_elems.total_values <= max_elems)
|
||||
{
|
||||
decltype(cur_elems.value) from;
|
||||
from.swap(cur_elems.value, arena);
|
||||
for (auto & node : rhs_elems.value)
|
||||
cur_elems.value.push_back(node->clone(arena), arena);
|
||||
cur_elems.total_values = rhs_elems.total_values;
|
||||
for (size_t i = 0; i < from.size(); ++i)
|
||||
insertWithSampler(cur_elems, from[i], arena);
|
||||
}
|
||||
else
|
||||
{
|
||||
cur_elems.randomShuffle();
|
||||
cur_elems.total_values += rhs_elems.total_values;
|
||||
for (size_t i = 0; i < max_elems; ++i)
|
||||
{
|
||||
UInt64 rnd = cur_elems.genRandom(cur_elems.total_values);
|
||||
if (rnd < rhs_elems.total_values)
|
||||
cur_elems.value[i] = rhs_elems.value[i]->clone(arena);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void checkArraySize(size_t elems, size_t max_elems)
|
||||
{
|
||||
if (unlikely(elems > max_elems))
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
|
||||
"Too large array size {} (maximum: {})", elems, max_elems);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
UInt64 elems = data(place).value.size();
|
||||
checkArraySize(elems, max_elems);
|
||||
writeVarUInt(elems, buf);
|
||||
|
||||
auto & value = data(place).value;
|
||||
for (auto & node : value)
|
||||
node->write(buf);
|
||||
|
||||
if constexpr (Trait::last)
|
||||
writeBinaryLittleEndian(data(place).total_values, buf);
|
||||
|
||||
if constexpr (Trait::sampler == Sampler::RNG)
|
||||
{
|
||||
writeBinaryLittleEndian(data(place).total_values, buf);
|
||||
WriteBufferFromOwnString rng_buf;
|
||||
rng_buf << data(place).rng;
|
||||
writeStringBinary(rng_buf.str(), buf);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
UInt64 elems;
|
||||
readVarUInt(elems, buf);
|
||||
|
||||
if (unlikely(elems == 0))
|
||||
return;
|
||||
|
||||
checkArraySize(elems, max_elems);
|
||||
|
||||
auto & value = data(place).value;
|
||||
|
||||
value.resize_exact(elems, arena);
|
||||
for (UInt64 i = 0; i < elems; ++i)
|
||||
value[i] = Node::read(buf, arena);
|
||||
|
||||
if constexpr (Trait::last)
|
||||
readBinaryLittleEndian(data(place).total_values, buf);
|
||||
|
||||
if constexpr (Trait::sampler == Sampler::RNG)
|
||||
{
|
||||
readBinaryLittleEndian(data(place).total_values, buf);
|
||||
std::string rng_string;
|
||||
readStringBinary(rng_string, buf);
|
||||
ReadBufferFromString rng_buf(rng_string);
|
||||
rng_buf >> data(place).rng;
|
||||
}
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto & column_array = assert_cast<ColumnArray &>(to);
|
||||
|
||||
auto & offsets = column_array.getOffsets();
|
||||
offsets.push_back(offsets.back() + data(place).value.size());
|
||||
|
||||
auto & column_data = column_array.getData();
|
||||
|
||||
if (std::is_same_v<Node, GroupArrayNodeString>)
|
||||
{
|
||||
auto & string_offsets = assert_cast<ColumnString &>(column_data).getOffsets();
|
||||
string_offsets.reserve(string_offsets.size() + data(place).value.size());
|
||||
}
|
||||
|
||||
auto & value = data(place).value;
|
||||
for (auto & node : value)
|
||||
node->insertInto(column_data);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return true; }
|
||||
};
|
||||
|
||||
|
||||
template <template <typename, typename> class AggregateFunctionTemplate, typename Data, typename ... TArgs>
|
||||
IAggregateFunction * createWithNumericOrTimeType(const IDataType & argument_type, TArgs && ... args)
|
||||
{
|
||||
@ -87,10 +766,10 @@ AggregateFunctionPtr createAggregateFunctionGroupArray(
|
||||
{
|
||||
if (Tlast)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "groupArrayLast make sense only with max_elems (groupArrayLast(max_elems)())");
|
||||
return createAggregateFunctionGroupArrayImpl<GroupArrayTrait</* Thas_limit= */ false, Tlast, /* Tsampler= */ Sampler::NONE>>(argument_types[0], parameters, max_elems);
|
||||
return createAggregateFunctionGroupArrayImpl<GroupArrayTrait</* Thas_limit= */ false, Tlast, /* Tsampler= */ Sampler::NONE>>(argument_types[0], parameters, max_elems, std::nullopt);
|
||||
}
|
||||
else
|
||||
return createAggregateFunctionGroupArrayImpl<GroupArrayTrait</* Thas_limit= */ true, Tlast, /* Tsampler= */ Sampler::NONE>>(argument_types[0], parameters, max_elems);
|
||||
return createAggregateFunctionGroupArrayImpl<GroupArrayTrait</* Thas_limit= */ true, Tlast, /* Tsampler= */ Sampler::NONE>>(argument_types[0], parameters, max_elems, std::nullopt);
|
||||
}
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionGroupArraySample(
|
||||
@ -117,11 +796,9 @@ AggregateFunctionPtr createAggregateFunctionGroupArraySample(
|
||||
|
||||
UInt64 max_elems = get_parameter(0);
|
||||
|
||||
UInt64 seed;
|
||||
std::optional<UInt64> seed;
|
||||
if (parameters.size() >= 2)
|
||||
seed = get_parameter(1);
|
||||
else
|
||||
seed = thread_local_rng();
|
||||
|
||||
return createAggregateFunctionGroupArrayImpl<GroupArrayTrait</* Thas_limit= */ true, /* Tlast= */ false, /* Tsampler= */ Sampler::RNG>>(argument_types[0], parameters, max_elems, seed);
|
||||
}
|
||||
|
@ -1,690 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <IO/WriteBufferFromString.h>
|
||||
#include <IO/Operators.h>
|
||||
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
|
||||
#include <Common/ArenaAllocator.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#define AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ELEMENT_SIZE 0xFFFFFF
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int TOO_LARGE_ARRAY_SIZE;
|
||||
}
|
||||
|
||||
enum class Sampler
|
||||
{
|
||||
NONE,
|
||||
RNG,
|
||||
};
|
||||
|
||||
template <bool Thas_limit, bool Tlast, Sampler Tsampler>
|
||||
struct GroupArrayTrait
|
||||
{
|
||||
static constexpr bool has_limit = Thas_limit;
|
||||
static constexpr bool last = Tlast;
|
||||
static constexpr Sampler sampler = Tsampler;
|
||||
};
|
||||
|
||||
template <typename Trait>
|
||||
static constexpr const char * getNameByTrait()
|
||||
{
|
||||
if (Trait::last)
|
||||
return "groupArrayLast";
|
||||
if (Trait::sampler == Sampler::NONE)
|
||||
return "groupArray";
|
||||
else if (Trait::sampler == Sampler::RNG)
|
||||
return "groupArraySample";
|
||||
|
||||
UNREACHABLE();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct GroupArraySamplerData
|
||||
{
|
||||
/// For easy serialization.
|
||||
static_assert(std::has_unique_object_representations_v<T> || std::is_floating_point_v<T>);
|
||||
|
||||
// Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena
|
||||
using Allocator = MixedAlignedArenaAllocator<alignof(T), 4096>;
|
||||
using Array = PODArray<T, 32, Allocator>;
|
||||
|
||||
Array value;
|
||||
size_t total_values = 0;
|
||||
pcg32_fast rng;
|
||||
|
||||
UInt64 genRandom(size_t lim)
|
||||
{
|
||||
/// With a large number of values, we will generate random numbers several times slower.
|
||||
if (lim <= static_cast<UInt64>(rng.max()))
|
||||
return static_cast<UInt32>(rng()) % static_cast<UInt32>(lim);
|
||||
else
|
||||
return (static_cast<UInt64>(rng()) * (static_cast<UInt64>(rng.max()) + 1ULL) + static_cast<UInt64>(rng())) % lim;
|
||||
}
|
||||
|
||||
void randomShuffle()
|
||||
{
|
||||
for (size_t i = 1; i < value.size(); ++i)
|
||||
{
|
||||
size_t j = genRandom(i + 1);
|
||||
std::swap(value[i], value[j]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// A particular case is an implementation for numeric types.
|
||||
template <typename T, bool has_sampler>
|
||||
struct GroupArrayNumericData;
|
||||
|
||||
template <typename T>
|
||||
struct GroupArrayNumericData<T, false>
|
||||
{
|
||||
/// For easy serialization.
|
||||
static_assert(std::has_unique_object_representations_v<T> || std::is_floating_point_v<T>);
|
||||
|
||||
// Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena
|
||||
using Allocator = MixedAlignedArenaAllocator<alignof(T), 4096>;
|
||||
using Array = PODArray<T, 32, Allocator>;
|
||||
|
||||
// For groupArrayLast()
|
||||
size_t total_values = 0;
|
||||
Array value;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GroupArrayNumericData<T, true> : public GroupArraySamplerData<T>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T, typename Trait>
|
||||
class GroupArrayNumericImpl final
|
||||
: public IAggregateFunctionDataHelper<GroupArrayNumericData<T, Trait::sampler != Sampler::NONE>, GroupArrayNumericImpl<T, Trait>>
|
||||
{
|
||||
using Data = GroupArrayNumericData<T, Trait::sampler != Sampler::NONE>;
|
||||
static constexpr bool limit_num_elems = Trait::has_limit;
|
||||
UInt64 max_elems;
|
||||
UInt64 seed;
|
||||
|
||||
public:
|
||||
explicit GroupArrayNumericImpl(
|
||||
const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_, UInt64 seed_ = 123456)
|
||||
: IAggregateFunctionDataHelper<GroupArrayNumericData<T, Trait::sampler != Sampler::NONE>, GroupArrayNumericImpl<T, Trait>>(
|
||||
{data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
|
||||
, max_elems(max_elems_)
|
||||
, seed(seed_)
|
||||
{
|
||||
}
|
||||
|
||||
String getName() const override { return getNameByTrait<Trait>(); }
|
||||
|
||||
void insertWithSampler(Data & a, const T & v, Arena * arena) const
|
||||
{
|
||||
++a.total_values;
|
||||
if (a.value.size() < max_elems)
|
||||
a.value.push_back(v, arena);
|
||||
else
|
||||
{
|
||||
UInt64 rnd = a.genRandom(a.total_values);
|
||||
if (rnd < max_elems)
|
||||
a.value[rnd] = v;
|
||||
}
|
||||
}
|
||||
|
||||
void create(AggregateDataPtr __restrict place) const override /// NOLINT
|
||||
{
|
||||
[[maybe_unused]] auto a = new (place) Data;
|
||||
if constexpr (Trait::sampler == Sampler::RNG)
|
||||
a->rng.seed(seed);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
const auto & row_value = assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];
|
||||
auto & cur_elems = this->data(place);
|
||||
|
||||
++cur_elems.total_values;
|
||||
|
||||
if constexpr (Trait::sampler == Sampler::NONE)
|
||||
{
|
||||
if (limit_num_elems && cur_elems.value.size() >= max_elems)
|
||||
{
|
||||
if constexpr (Trait::last)
|
||||
cur_elems.value[(cur_elems.total_values - 1) % max_elems] = row_value;
|
||||
return;
|
||||
}
|
||||
|
||||
cur_elems.value.push_back(row_value, arena);
|
||||
}
|
||||
|
||||
if constexpr (Trait::sampler == Sampler::RNG)
|
||||
{
|
||||
if (cur_elems.value.size() < max_elems)
|
||||
cur_elems.value.push_back(row_value, arena);
|
||||
else
|
||||
{
|
||||
UInt64 rnd = cur_elems.genRandom(cur_elems.total_values);
|
||||
if (rnd < max_elems)
|
||||
cur_elems.value[rnd] = row_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
auto & cur_elems = this->data(place);
|
||||
auto & rhs_elems = this->data(rhs);
|
||||
|
||||
if (rhs_elems.value.empty())
|
||||
return;
|
||||
|
||||
if constexpr (Trait::last)
|
||||
mergeNoSamplerLast(cur_elems, rhs_elems, arena);
|
||||
else if constexpr (Trait::sampler == Sampler::NONE)
|
||||
mergeNoSampler(cur_elems, rhs_elems, arena);
|
||||
else if constexpr (Trait::sampler == Sampler::RNG)
|
||||
mergeWithRNGSampler(cur_elems, rhs_elems, arena);
|
||||
}
|
||||
|
||||
void mergeNoSamplerLast(Data & cur_elems, const Data & rhs_elems, Arena * arena) const
|
||||
{
|
||||
UInt64 new_elements = std::min(static_cast<size_t>(max_elems), cur_elems.value.size() + rhs_elems.value.size());
|
||||
cur_elems.value.resize_exact(new_elements, arena);
|
||||
for (auto & value : rhs_elems.value)
|
||||
{
|
||||
cur_elems.value[cur_elems.total_values % max_elems] = value;
|
||||
++cur_elems.total_values;
|
||||
}
|
||||
assert(rhs_elems.total_values >= rhs_elems.value.size());
|
||||
cur_elems.total_values += rhs_elems.total_values - rhs_elems.value.size();
|
||||
}
|
||||
|
||||
void mergeNoSampler(Data & cur_elems, const Data & rhs_elems, Arena * arena) const
|
||||
{
|
||||
if (!limit_num_elems)
|
||||
{
|
||||
if (rhs_elems.value.size())
|
||||
cur_elems.value.insertByOffsets(rhs_elems.value, 0, rhs_elems.value.size(), arena);
|
||||
}
|
||||
else
|
||||
{
|
||||
UInt64 elems_to_insert = std::min(static_cast<size_t>(max_elems) - cur_elems.value.size(), rhs_elems.value.size());
|
||||
if (elems_to_insert)
|
||||
cur_elems.value.insertByOffsets(rhs_elems.value, 0, elems_to_insert, arena);
|
||||
}
|
||||
}
|
||||
|
||||
void mergeWithRNGSampler(Data & cur_elems, const Data & rhs_elems, Arena * arena) const
|
||||
{
|
||||
if (rhs_elems.total_values <= max_elems)
|
||||
{
|
||||
for (size_t i = 0; i < rhs_elems.value.size(); ++i)
|
||||
insertWithSampler(cur_elems, rhs_elems.value[i], arena);
|
||||
}
|
||||
else if (cur_elems.total_values <= max_elems)
|
||||
{
|
||||
decltype(cur_elems.value) from;
|
||||
from.swap(cur_elems.value, arena);
|
||||
cur_elems.value.assign(rhs_elems.value.begin(), rhs_elems.value.end(), arena);
|
||||
cur_elems.total_values = rhs_elems.total_values;
|
||||
for (size_t i = 0; i < from.size(); ++i)
|
||||
insertWithSampler(cur_elems, from[i], arena);
|
||||
}
|
||||
else
|
||||
{
|
||||
cur_elems.randomShuffle();
|
||||
cur_elems.total_values += rhs_elems.total_values;
|
||||
for (size_t i = 0; i < max_elems; ++i)
|
||||
{
|
||||
UInt64 rnd = cur_elems.genRandom(cur_elems.total_values);
|
||||
if (rnd < rhs_elems.total_values)
|
||||
cur_elems.value[i] = rhs_elems.value[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void checkArraySize(size_t elems, size_t max_elems)
|
||||
{
|
||||
if (unlikely(elems > max_elems))
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
|
||||
"Too large array size {} (maximum: {})", elems, max_elems);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
const auto & value = this->data(place).value;
|
||||
const UInt64 size = value.size();
|
||||
checkArraySize(size, max_elems);
|
||||
writeVarUInt(size, buf);
|
||||
for (const auto & element : value)
|
||||
writeBinaryLittleEndian(element, buf);
|
||||
|
||||
if constexpr (Trait::last)
|
||||
writeBinaryLittleEndian(this->data(place).total_values, buf);
|
||||
|
||||
if constexpr (Trait::sampler == Sampler::RNG)
|
||||
{
|
||||
writeBinaryLittleEndian(this->data(place).total_values, buf);
|
||||
WriteBufferFromOwnString rng_buf;
|
||||
rng_buf << this->data(place).rng;
|
||||
writeStringBinary(rng_buf.str(), buf);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
size_t size = 0;
|
||||
readVarUInt(size, buf);
|
||||
checkArraySize(size, max_elems);
|
||||
|
||||
auto & value = this->data(place).value;
|
||||
|
||||
value.resize_exact(size, arena);
|
||||
for (auto & element : value)
|
||||
readBinaryLittleEndian(element, buf);
|
||||
|
||||
if constexpr (Trait::last)
|
||||
readBinaryLittleEndian(this->data(place).total_values, buf);
|
||||
|
||||
if constexpr (Trait::sampler == Sampler::RNG)
|
||||
{
|
||||
readBinaryLittleEndian(this->data(place).total_values, buf);
|
||||
std::string rng_string;
|
||||
readStringBinary(rng_string, buf);
|
||||
ReadBufferFromString rng_buf(rng_string);
|
||||
rng_buf >> this->data(place).rng;
|
||||
}
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
const auto & value = this->data(place).value;
|
||||
size_t size = value.size();
|
||||
|
||||
ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
|
||||
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
|
||||
|
||||
offsets_to.push_back(offsets_to.back() + size);
|
||||
|
||||
if (size)
|
||||
{
|
||||
typename ColumnVector<T>::Container & data_to = assert_cast<ColumnVector<T> &>(arr_to.getData()).getData();
|
||||
data_to.insert(this->data(place).value.begin(), this->data(place).value.end());
|
||||
}
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return true; }
|
||||
};
|
||||
|
||||
|
||||
/// General case
|
||||
|
||||
|
||||
/// Nodes used to implement a linked list for storage of groupArray states
|
||||
|
||||
template <typename Node>
|
||||
struct GroupArrayNodeBase
|
||||
{
|
||||
UInt64 size; // size of payload
|
||||
|
||||
/// Returns pointer to actual payload
|
||||
char * data() { return reinterpret_cast<char *>(this) + sizeof(Node); }
|
||||
|
||||
const char * data() const { return reinterpret_cast<const char *>(this) + sizeof(Node); }
|
||||
|
||||
/// Clones existing node (does not modify next field)
|
||||
Node * clone(Arena * arena) const
|
||||
{
|
||||
return reinterpret_cast<Node *>(
|
||||
const_cast<char *>(arena->alignedInsert(reinterpret_cast<const char *>(this), sizeof(Node) + size, alignof(Node))));
|
||||
}
|
||||
|
||||
static void checkElementSize(size_t size, size_t max_size)
|
||||
{
|
||||
if (unlikely(size > max_size))
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
|
||||
"Too large array element size {} (maximum: {})", size, max_size);
|
||||
}
|
||||
|
||||
/// Write node to buffer
|
||||
void write(WriteBuffer & buf) const
|
||||
{
|
||||
checkElementSize(size, AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ELEMENT_SIZE);
|
||||
writeVarUInt(size, buf);
|
||||
buf.write(data(), size);
|
||||
}
|
||||
|
||||
/// Reads and allocates node from ReadBuffer's data (doesn't set next)
|
||||
static Node * read(ReadBuffer & buf, Arena * arena)
|
||||
{
|
||||
UInt64 size;
|
||||
readVarUInt(size, buf);
|
||||
checkElementSize(size, AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ELEMENT_SIZE);
|
||||
|
||||
Node * node = reinterpret_cast<Node *>(arena->alignedAlloc(sizeof(Node) + size, alignof(Node)));
|
||||
node->size = size;
|
||||
buf.readStrict(node->data(), size);
|
||||
return node;
|
||||
}
|
||||
};
|
||||
|
||||
struct GroupArrayNodeString : public GroupArrayNodeBase<GroupArrayNodeString>
|
||||
{
|
||||
using Node = GroupArrayNodeString;
|
||||
|
||||
/// Create node from string
|
||||
static Node * allocate(const IColumn & column, size_t row_num, Arena * arena)
|
||||
{
|
||||
StringRef string = assert_cast<const ColumnString &>(column).getDataAt(row_num);
|
||||
|
||||
Node * node = reinterpret_cast<Node *>(arena->alignedAlloc(sizeof(Node) + string.size, alignof(Node)));
|
||||
node->size = string.size;
|
||||
memcpy(node->data(), string.data, string.size);
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
void insertInto(IColumn & column)
|
||||
{
|
||||
assert_cast<ColumnString &>(column).insertData(data(), size);
|
||||
}
|
||||
};
|
||||
|
||||
struct GroupArrayNodeGeneral : public GroupArrayNodeBase<GroupArrayNodeGeneral>
|
||||
{
|
||||
using Node = GroupArrayNodeGeneral;
|
||||
|
||||
static Node * allocate(const IColumn & column, size_t row_num, Arena * arena)
|
||||
{
|
||||
const char * begin = arena->alignedAlloc(sizeof(Node), alignof(Node));
|
||||
StringRef value = column.serializeValueIntoArena(row_num, *arena, begin);
|
||||
|
||||
Node * node = reinterpret_cast<Node *>(const_cast<char *>(begin));
|
||||
node->size = value.size;
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
void insertInto(IColumn & column) { column.deserializeAndInsertFromArena(data()); }
|
||||
};
|
||||
|
||||
template <typename Node, bool has_sampler>
|
||||
struct GroupArrayGeneralData;
|
||||
|
||||
template <typename Node>
|
||||
struct GroupArrayGeneralData<Node, false>
|
||||
{
|
||||
// Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena
|
||||
using Allocator = MixedAlignedArenaAllocator<alignof(Node *), 4096>;
|
||||
using Array = PODArray<Node *, 32, Allocator>;
|
||||
|
||||
// For groupArrayLast()
|
||||
size_t total_values = 0;
|
||||
Array value;
|
||||
};
|
||||
|
||||
template <typename Node>
|
||||
struct GroupArrayGeneralData<Node, true> : public GroupArraySamplerData<Node *>
|
||||
{
|
||||
};
|
||||
|
||||
/// Implementation of groupArray for String or any ComplexObject via Array
|
||||
template <typename Node, typename Trait>
|
||||
class GroupArrayGeneralImpl final
|
||||
: public IAggregateFunctionDataHelper<GroupArrayGeneralData<Node, Trait::sampler != Sampler::NONE>, GroupArrayGeneralImpl<Node, Trait>>
|
||||
{
|
||||
static constexpr bool limit_num_elems = Trait::has_limit;
|
||||
using Data = GroupArrayGeneralData<Node, Trait::sampler != Sampler::NONE>;
|
||||
static Data & data(AggregateDataPtr __restrict place) { return *reinterpret_cast<Data *>(place); }
|
||||
static const Data & data(ConstAggregateDataPtr __restrict place) { return *reinterpret_cast<const Data *>(place); }
|
||||
|
||||
DataTypePtr & data_type;
|
||||
UInt64 max_elems;
|
||||
UInt64 seed;
|
||||
|
||||
public:
|
||||
GroupArrayGeneralImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_, UInt64 seed_ = 123456)
|
||||
: IAggregateFunctionDataHelper<GroupArrayGeneralData<Node, Trait::sampler != Sampler::NONE>, GroupArrayGeneralImpl<Node, Trait>>(
|
||||
{data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
|
||||
, data_type(this->argument_types[0])
|
||||
, max_elems(max_elems_)
|
||||
, seed(seed_)
|
||||
{
|
||||
}
|
||||
|
||||
String getName() const override { return getNameByTrait<Trait>(); }
|
||||
|
||||
void insertWithSampler(Data & a, const Node * v, Arena * arena) const
|
||||
{
|
||||
++a.total_values;
|
||||
if (a.value.size() < max_elems)
|
||||
a.value.push_back(v->clone(arena), arena);
|
||||
else
|
||||
{
|
||||
UInt64 rnd = a.genRandom(a.total_values);
|
||||
if (rnd < max_elems)
|
||||
a.value[rnd] = v->clone(arena);
|
||||
}
|
||||
}
|
||||
|
||||
void create(AggregateDataPtr __restrict place) const override /// NOLINT
|
||||
{
|
||||
[[maybe_unused]] auto a = new (place) Data;
|
||||
if constexpr (Trait::sampler == Sampler::RNG)
|
||||
a->rng.seed(seed);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
auto & cur_elems = data(place);
|
||||
|
||||
++cur_elems.total_values;
|
||||
|
||||
if constexpr (Trait::sampler == Sampler::NONE)
|
||||
{
|
||||
if (limit_num_elems && cur_elems.value.size() >= max_elems)
|
||||
{
|
||||
if (Trait::last)
|
||||
{
|
||||
Node * node = Node::allocate(*columns[0], row_num, arena);
|
||||
cur_elems.value[(cur_elems.total_values - 1) % max_elems] = node;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
Node * node = Node::allocate(*columns[0], row_num, arena);
|
||||
cur_elems.value.push_back(node, arena);
|
||||
}
|
||||
|
||||
if constexpr (Trait::sampler == Sampler::RNG)
|
||||
{
|
||||
if (cur_elems.value.size() < max_elems)
|
||||
cur_elems.value.push_back(Node::allocate(*columns[0], row_num, arena), arena);
|
||||
else
|
||||
{
|
||||
UInt64 rnd = cur_elems.genRandom(cur_elems.total_values);
|
||||
if (rnd < max_elems)
|
||||
cur_elems.value[rnd] = Node::allocate(*columns[0], row_num, arena);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
auto & cur_elems = data(place);
|
||||
auto & rhs_elems = data(rhs);
|
||||
|
||||
if (rhs_elems.value.empty())
|
||||
return;
|
||||
|
||||
if constexpr (Trait::last)
|
||||
mergeNoSamplerLast(cur_elems, rhs_elems, arena);
|
||||
else if constexpr (Trait::sampler == Sampler::NONE)
|
||||
mergeNoSampler(cur_elems, rhs_elems, arena);
|
||||
else if constexpr (Trait::sampler == Sampler::RNG)
|
||||
mergeWithRNGSampler(cur_elems, rhs_elems, arena);
|
||||
}
|
||||
|
||||
void ALWAYS_INLINE mergeNoSamplerLast(Data & cur_elems, const Data & rhs_elems, Arena * arena) const
|
||||
{
|
||||
UInt64 new_elements = std::min(static_cast<size_t>(max_elems), cur_elems.value.size() + rhs_elems.value.size());
|
||||
cur_elems.value.resize_exact(new_elements, arena);
|
||||
for (auto & value : rhs_elems.value)
|
||||
{
|
||||
cur_elems.value[cur_elems.total_values % max_elems] = value->clone(arena);
|
||||
++cur_elems.total_values;
|
||||
}
|
||||
assert(rhs_elems.total_values >= rhs_elems.value.size());
|
||||
cur_elems.total_values += rhs_elems.total_values - rhs_elems.value.size();
|
||||
}
|
||||
|
||||
void ALWAYS_INLINE mergeNoSampler(Data & cur_elems, const Data & rhs_elems, Arena * arena) const
|
||||
{
|
||||
UInt64 new_elems;
|
||||
if (limit_num_elems)
|
||||
{
|
||||
if (cur_elems.value.size() >= max_elems)
|
||||
return;
|
||||
new_elems = std::min(rhs_elems.value.size(), static_cast<size_t>(max_elems) - cur_elems.value.size());
|
||||
}
|
||||
else
|
||||
new_elems = rhs_elems.value.size();
|
||||
|
||||
for (UInt64 i = 0; i < new_elems; ++i)
|
||||
cur_elems.value.push_back(rhs_elems.value[i]->clone(arena), arena);
|
||||
}
|
||||
|
||||
void ALWAYS_INLINE mergeWithRNGSampler(Data & cur_elems, const Data & rhs_elems, Arena * arena) const
|
||||
{
|
||||
if (rhs_elems.total_values <= max_elems)
|
||||
{
|
||||
for (size_t i = 0; i < rhs_elems.value.size(); ++i)
|
||||
insertWithSampler(cur_elems, rhs_elems.value[i], arena);
|
||||
}
|
||||
else if (cur_elems.total_values <= max_elems)
|
||||
{
|
||||
decltype(cur_elems.value) from;
|
||||
from.swap(cur_elems.value, arena);
|
||||
for (auto & node : rhs_elems.value)
|
||||
cur_elems.value.push_back(node->clone(arena), arena);
|
||||
cur_elems.total_values = rhs_elems.total_values;
|
||||
for (size_t i = 0; i < from.size(); ++i)
|
||||
insertWithSampler(cur_elems, from[i], arena);
|
||||
}
|
||||
else
|
||||
{
|
||||
cur_elems.randomShuffle();
|
||||
cur_elems.total_values += rhs_elems.total_values;
|
||||
for (size_t i = 0; i < max_elems; ++i)
|
||||
{
|
||||
UInt64 rnd = cur_elems.genRandom(cur_elems.total_values);
|
||||
if (rnd < rhs_elems.total_values)
|
||||
cur_elems.value[i] = rhs_elems.value[i]->clone(arena);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void checkArraySize(size_t elems, size_t max_elems)
|
||||
{
|
||||
if (unlikely(elems > max_elems))
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
|
||||
"Too large array size {} (maximum: {})", elems, max_elems);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
UInt64 elems = data(place).value.size();
|
||||
checkArraySize(elems, max_elems);
|
||||
writeVarUInt(elems, buf);
|
||||
|
||||
auto & value = data(place).value;
|
||||
for (auto & node : value)
|
||||
node->write(buf);
|
||||
|
||||
if constexpr (Trait::last)
|
||||
writeBinaryLittleEndian(data(place).total_values, buf);
|
||||
|
||||
if constexpr (Trait::sampler == Sampler::RNG)
|
||||
{
|
||||
writeBinaryLittleEndian(data(place).total_values, buf);
|
||||
WriteBufferFromOwnString rng_buf;
|
||||
rng_buf << data(place).rng;
|
||||
writeStringBinary(rng_buf.str(), buf);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
UInt64 elems;
|
||||
readVarUInt(elems, buf);
|
||||
|
||||
if (unlikely(elems == 0))
|
||||
return;
|
||||
|
||||
checkArraySize(elems, max_elems);
|
||||
|
||||
auto & value = data(place).value;
|
||||
|
||||
value.resize_exact(elems, arena);
|
||||
for (UInt64 i = 0; i < elems; ++i)
|
||||
value[i] = Node::read(buf, arena);
|
||||
|
||||
if constexpr (Trait::last)
|
||||
readBinaryLittleEndian(data(place).total_values, buf);
|
||||
|
||||
if constexpr (Trait::sampler == Sampler::RNG)
|
||||
{
|
||||
readBinaryLittleEndian(data(place).total_values, buf);
|
||||
std::string rng_string;
|
||||
readStringBinary(rng_string, buf);
|
||||
ReadBufferFromString rng_buf(rng_string);
|
||||
rng_buf >> data(place).rng;
|
||||
}
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto & column_array = assert_cast<ColumnArray &>(to);
|
||||
|
||||
auto & offsets = column_array.getOffsets();
|
||||
offsets.push_back(offsets.back() + data(place).value.size());
|
||||
|
||||
auto & column_data = column_array.getData();
|
||||
|
||||
if (std::is_same_v<Node, GroupArrayNodeString>)
|
||||
{
|
||||
auto & string_offsets = assert_cast<ColumnString &>(column_data).getOffsets();
|
||||
string_offsets.reserve(string_offsets.size() + data(place).value.size());
|
||||
}
|
||||
|
||||
auto & value = data(place).value;
|
||||
for (auto & node : value)
|
||||
node->insertInto(column_data);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return true; }
|
||||
};
|
||||
|
||||
#undef AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ELEMENT_SIZE
|
||||
|
||||
}
|
@ -1,21 +1,218 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionGroupArrayInsertAt.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
|
||||
#include <Columns/ColumnArray.h>
|
||||
|
||||
#include <Common/FieldVisitorToString.h>
|
||||
#include <Common/FieldVisitorConvertToNumber.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <Interpreters/convertFieldToType.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
#define AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE 0xFFFFFF
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int TOO_LARGE_ARRAY_SIZE;
|
||||
extern const int CANNOT_CONVERT_TYPE;
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
/** Aggregate function, that takes two arguments: value and position,
|
||||
* and as a result, builds an array with values are located at corresponding positions.
|
||||
*
|
||||
* If more than one value was inserted to single position, the any value (first in case of single thread) is stored.
|
||||
* If no values was inserted to some position, then default value will be substituted.
|
||||
*
|
||||
* Aggregate function also accept optional parameters:
|
||||
* - default value to substitute;
|
||||
* - length to resize result arrays (if you want to have results of same length for all aggregation keys);
|
||||
*
|
||||
* If you want to pass length, default value should be also given.
|
||||
*/
|
||||
|
||||
|
||||
/// Generic case (inefficient).
|
||||
struct AggregateFunctionGroupArrayInsertAtDataGeneric
|
||||
{
|
||||
Array value; /// TODO Add MemoryTracker
|
||||
};
|
||||
|
||||
|
||||
class AggregateFunctionGroupArrayInsertAtGeneric final
|
||||
: public IAggregateFunctionDataHelper<AggregateFunctionGroupArrayInsertAtDataGeneric, AggregateFunctionGroupArrayInsertAtGeneric>
|
||||
{
|
||||
private:
|
||||
DataTypePtr type;
|
||||
SerializationPtr serialization;
|
||||
Field default_value;
|
||||
UInt64 length_to_resize = 0; /// zero means - do not do resizing.
|
||||
|
||||
public:
|
||||
AggregateFunctionGroupArrayInsertAtGeneric(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionGroupArrayInsertAtDataGeneric, AggregateFunctionGroupArrayInsertAtGeneric>(arguments, params, std::make_shared<DataTypeArray>(arguments[0]))
|
||||
, type(argument_types[0])
|
||||
, serialization(type->getDefaultSerialization())
|
||||
{
|
||||
if (!params.empty())
|
||||
{
|
||||
if (params.size() > 2)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires at most two parameters.", getName());
|
||||
|
||||
default_value = params[0];
|
||||
|
||||
if (params.size() == 2)
|
||||
{
|
||||
length_to_resize = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[1]);
|
||||
if (length_to_resize > AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE)
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
|
||||
"Too large array size (maximum: {})", AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE);
|
||||
}
|
||||
}
|
||||
|
||||
if (!isUInt(arguments[1]))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument of aggregate function {} must be unsigned integer.", getName());
|
||||
|
||||
if (default_value.isNull())
|
||||
default_value = type->getDefault();
|
||||
else
|
||||
{
|
||||
Field converted = convertFieldToType(default_value, *type);
|
||||
if (converted.isNull())
|
||||
throw Exception(ErrorCodes::CANNOT_CONVERT_TYPE, "Cannot convert parameter of aggregate function {} ({}) "
|
||||
"to type {} to be used as default value in array",
|
||||
getName(), applyVisitor(FieldVisitorToString(), default_value), type->getName());
|
||||
|
||||
default_value = converted;
|
||||
}
|
||||
}
|
||||
|
||||
String getName() const override { return "groupArrayInsertAt"; }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
/// TODO Do positions need to be 1-based for this function?
|
||||
size_t position = columns[1]->getUInt(row_num);
|
||||
|
||||
/// If position is larger than size to which array will be cut - simply ignore value.
|
||||
if (length_to_resize && position >= length_to_resize)
|
||||
return;
|
||||
|
||||
if (position >= AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE)
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large array size: "
|
||||
"position argument ({}) is greater or equals to limit ({})",
|
||||
position, AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE);
|
||||
|
||||
Array & arr = data(place).value;
|
||||
|
||||
if (arr.size() <= position)
|
||||
arr.resize(position + 1);
|
||||
else if (!arr[position].isNull())
|
||||
return; /// Element was already inserted to the specified position.
|
||||
|
||||
columns[0]->get(row_num, arr[position]);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
Array & arr_lhs = data(place).value;
|
||||
const Array & arr_rhs = data(rhs).value;
|
||||
|
||||
if (arr_lhs.size() < arr_rhs.size())
|
||||
arr_lhs.resize(arr_rhs.size());
|
||||
|
||||
for (size_t i = 0, size = arr_rhs.size(); i < size; ++i)
|
||||
if (arr_lhs[i].isNull() && !arr_rhs[i].isNull())
|
||||
arr_lhs[i] = arr_rhs[i];
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
const Array & arr = data(place).value;
|
||||
size_t size = arr.size();
|
||||
writeVarUInt(size, buf);
|
||||
|
||||
for (const Field & elem : arr)
|
||||
{
|
||||
if (elem.isNull())
|
||||
{
|
||||
writeBinary(UInt8(1), buf);
|
||||
}
|
||||
else
|
||||
{
|
||||
writeBinary(UInt8(0), buf);
|
||||
serialization->serializeBinary(elem, buf, {});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
size_t size = 0;
|
||||
readVarUInt(size, buf);
|
||||
|
||||
if (size > AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE)
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
|
||||
"Too large array size (maximum: {})", AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE);
|
||||
|
||||
Array & arr = data(place).value;
|
||||
|
||||
arr.resize(size);
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
UInt8 is_null = 0;
|
||||
readBinary(is_null, buf);
|
||||
if (!is_null)
|
||||
serialization->deserializeBinary(arr[i], buf, {});
|
||||
}
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
ColumnArray & to_array = assert_cast<ColumnArray &>(to);
|
||||
IColumn & to_data = to_array.getData();
|
||||
ColumnArray::Offsets & to_offsets = to_array.getOffsets();
|
||||
|
||||
const Array & arr = data(place).value;
|
||||
|
||||
for (const Field & elem : arr)
|
||||
{
|
||||
if (!elem.isNull())
|
||||
to_data.insert(elem);
|
||||
else
|
||||
to_data.insert(default_value);
|
||||
}
|
||||
|
||||
size_t result_array_size = length_to_resize ? length_to_resize : arr.size();
|
||||
|
||||
/// Pad array if need.
|
||||
for (size_t i = arr.size(); i < result_array_size; ++i)
|
||||
to_data.insert(default_value);
|
||||
|
||||
to_offsets.push_back(to_offsets.back() + result_array_size);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionGroupArrayInsertAt(
|
||||
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
|
@ -1,215 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
|
||||
#include <Common/FieldVisitorToString.h>
|
||||
#include <Common/FieldVisitorConvertToNumber.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <Interpreters/convertFieldToType.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
#define AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE 0xFFFFFF
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int TOO_LARGE_ARRAY_SIZE;
|
||||
extern const int CANNOT_CONVERT_TYPE;
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
}
|
||||
|
||||
|
||||
/** Aggregate function, that takes two arguments: value and position,
|
||||
* and as a result, builds an array with values are located at corresponding positions.
|
||||
*
|
||||
* If more than one value was inserted to single position, the any value (first in case of single thread) is stored.
|
||||
* If no values was inserted to some position, then default value will be substituted.
|
||||
*
|
||||
* Aggregate function also accept optional parameters:
|
||||
* - default value to substitute;
|
||||
* - length to resize result arrays (if you want to have results of same length for all aggregation keys);
|
||||
*
|
||||
* If you want to pass length, default value should be also given.
|
||||
*/
|
||||
|
||||
|
||||
/// Generic case (inefficient).
|
||||
struct AggregateFunctionGroupArrayInsertAtDataGeneric
|
||||
{
|
||||
Array value; /// TODO Add MemoryTracker
|
||||
};
|
||||
|
||||
|
||||
class AggregateFunctionGroupArrayInsertAtGeneric final
|
||||
: public IAggregateFunctionDataHelper<AggregateFunctionGroupArrayInsertAtDataGeneric, AggregateFunctionGroupArrayInsertAtGeneric>
|
||||
{
|
||||
private:
|
||||
DataTypePtr type;
|
||||
SerializationPtr serialization;
|
||||
Field default_value;
|
||||
UInt64 length_to_resize = 0; /// zero means - do not do resizing.
|
||||
|
||||
public:
|
||||
AggregateFunctionGroupArrayInsertAtGeneric(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionGroupArrayInsertAtDataGeneric, AggregateFunctionGroupArrayInsertAtGeneric>(arguments, params, std::make_shared<DataTypeArray>(arguments[0]))
|
||||
, type(argument_types[0])
|
||||
, serialization(type->getDefaultSerialization())
|
||||
{
|
||||
if (!params.empty())
|
||||
{
|
||||
if (params.size() > 2)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires at most two parameters.", getName());
|
||||
|
||||
default_value = params[0];
|
||||
|
||||
if (params.size() == 2)
|
||||
{
|
||||
length_to_resize = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[1]);
|
||||
if (length_to_resize > AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE)
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
|
||||
"Too large array size (maximum: {})", AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE);
|
||||
}
|
||||
}
|
||||
|
||||
if (!isUInt(arguments[1]))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument of aggregate function {} must be unsigned integer.", getName());
|
||||
|
||||
if (default_value.isNull())
|
||||
default_value = type->getDefault();
|
||||
else
|
||||
{
|
||||
Field converted = convertFieldToType(default_value, *type);
|
||||
if (converted.isNull())
|
||||
throw Exception(ErrorCodes::CANNOT_CONVERT_TYPE, "Cannot convert parameter of aggregate function {} ({}) "
|
||||
"to type {} to be used as default value in array",
|
||||
getName(), applyVisitor(FieldVisitorToString(), default_value), type->getName());
|
||||
|
||||
default_value = converted;
|
||||
}
|
||||
}
|
||||
|
||||
String getName() const override { return "groupArrayInsertAt"; }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
/// TODO Do positions need to be 1-based for this function?
|
||||
size_t position = columns[1]->getUInt(row_num);
|
||||
|
||||
/// If position is larger than size to which array will be cut - simply ignore value.
|
||||
if (length_to_resize && position >= length_to_resize)
|
||||
return;
|
||||
|
||||
if (position >= AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE)
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large array size: "
|
||||
"position argument ({}) is greater or equals to limit ({})",
|
||||
position, AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE);
|
||||
|
||||
Array & arr = data(place).value;
|
||||
|
||||
if (arr.size() <= position)
|
||||
arr.resize(position + 1);
|
||||
else if (!arr[position].isNull())
|
||||
return; /// Element was already inserted to the specified position.
|
||||
|
||||
columns[0]->get(row_num, arr[position]);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
Array & arr_lhs = data(place).value;
|
||||
const Array & arr_rhs = data(rhs).value;
|
||||
|
||||
if (arr_lhs.size() < arr_rhs.size())
|
||||
arr_lhs.resize(arr_rhs.size());
|
||||
|
||||
for (size_t i = 0, size = arr_rhs.size(); i < size; ++i)
|
||||
if (arr_lhs[i].isNull() && !arr_rhs[i].isNull())
|
||||
arr_lhs[i] = arr_rhs[i];
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
const Array & arr = data(place).value;
|
||||
size_t size = arr.size();
|
||||
writeVarUInt(size, buf);
|
||||
|
||||
for (const Field & elem : arr)
|
||||
{
|
||||
if (elem.isNull())
|
||||
{
|
||||
writeBinary(UInt8(1), buf);
|
||||
}
|
||||
else
|
||||
{
|
||||
writeBinary(UInt8(0), buf);
|
||||
serialization->serializeBinary(elem, buf, {});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
size_t size = 0;
|
||||
readVarUInt(size, buf);
|
||||
|
||||
if (size > AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE)
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
|
||||
"Too large array size (maximum: {})", AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE);
|
||||
|
||||
Array & arr = data(place).value;
|
||||
|
||||
arr.resize(size);
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
UInt8 is_null = 0;
|
||||
readBinary(is_null, buf);
|
||||
if (!is_null)
|
||||
serialization->deserializeBinary(arr[i], buf, {});
|
||||
}
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
ColumnArray & to_array = assert_cast<ColumnArray &>(to);
|
||||
IColumn & to_data = to_array.getData();
|
||||
ColumnArray::Offsets & to_offsets = to_array.getOffsets();
|
||||
|
||||
const Array & arr = data(place).value;
|
||||
|
||||
for (const Field & elem : arr)
|
||||
{
|
||||
if (!elem.isNull())
|
||||
to_data.insert(elem);
|
||||
else
|
||||
to_data.insert(default_value);
|
||||
}
|
||||
|
||||
size_t result_array_size = length_to_resize ? length_to_resize : arr.size();
|
||||
|
||||
/// Pad array if need.
|
||||
for (size_t i = arr.size(); i < result_array_size; ++i)
|
||||
to_data.insert(default_value);
|
||||
|
||||
to_offsets.push_back(to_offsets.back() + result_array_size);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
#undef AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE
|
||||
|
||||
}
|
@ -2,8 +2,14 @@
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <DataTypes/DataTypeAggregateFunction.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <Columns/ColumnAggregateFunction.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
// TODO include this last because of a broken roaring header. See the comment inside.
|
||||
#include <AggregateFunctions/AggregateFunctionGroupBitmap.h>
|
||||
#include <AggregateFunctions/AggregateFunctionGroupBitmapData.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -17,77 +23,255 @@ namespace ErrorCodes
|
||||
|
||||
namespace
|
||||
{
|
||||
template <template <typename, typename> class AggregateFunctionTemplate, template <typename> typename Data, typename... TArgs>
|
||||
IAggregateFunction * createWithIntegerType(const IDataType & argument_type, TArgs &&... args)
|
||||
|
||||
/// Counts bitmap operation on numbers.
|
||||
template <typename T, typename Data>
|
||||
class AggregateFunctionBitmap final : public IAggregateFunctionDataHelper<Data, AggregateFunctionBitmap<T, Data>>
|
||||
{
|
||||
public:
|
||||
explicit AggregateFunctionBitmap(const DataTypePtr & type)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitmap<T, Data>>({type}, {}, createResultType())
|
||||
{
|
||||
WhichDataType which(argument_type);
|
||||
if (which.idx == TypeIndex::UInt8) return new AggregateFunctionTemplate<UInt8, Data<UInt8>>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::UInt16) return new AggregateFunctionTemplate<UInt16, Data<UInt16>>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::UInt32) return new AggregateFunctionTemplate<UInt32, Data<UInt32>>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::UInt64) return new AggregateFunctionTemplate<UInt64, Data<UInt64>>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::Int8) return new AggregateFunctionTemplate<Int8, Data<Int8>>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::Int16) return new AggregateFunctionTemplate<Int16, Data<Int16>>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::Int32) return new AggregateFunctionTemplate<Int32, Data<Int32>>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::Int64) return new AggregateFunctionTemplate<Int64, Data<Int64>>(std::forward<TArgs>(args)...);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <template <typename> typename Data>
|
||||
AggregateFunctionPtr createAggregateFunctionBitmap(
|
||||
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
String getName() const override { return Data::name(); }
|
||||
|
||||
static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<T>>(); }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
assertUnary(name, argument_types);
|
||||
|
||||
if (!argument_types[0]->canBeUsedInBitOperations())
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"The type {} of argument for aggregate function {} "
|
||||
"is illegal, because it cannot be used in Bitmap operations",
|
||||
argument_types[0]->getName(), name);
|
||||
|
||||
AggregateFunctionPtr res(createWithIntegerType<AggregateFunctionBitmap, Data>(*argument_types[0], argument_types[0]));
|
||||
|
||||
if (!res)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}",
|
||||
argument_types[0]->getName(), name);
|
||||
|
||||
return res;
|
||||
this->data(place).roaring_bitmap_with_small_set.add(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
|
||||
}
|
||||
|
||||
// Additional aggregate functions to manipulate bitmaps.
|
||||
template <template <typename, typename> typename AggregateFunctionTemplate>
|
||||
AggregateFunctionPtr createAggregateFunctionBitmapL2(
|
||||
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
assertUnary(name, argument_types);
|
||||
|
||||
DataTypePtr argument_type_ptr = argument_types[0];
|
||||
WhichDataType which(*argument_type_ptr);
|
||||
if (which.idx != TypeIndex::AggregateFunction)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}",
|
||||
argument_types[0]->getName(), name);
|
||||
|
||||
/// groupBitmap needs to know about the data type that was used to create bitmaps.
|
||||
/// We need to look inside the type of its argument to obtain it.
|
||||
const DataTypeAggregateFunction & datatype_aggfunc = dynamic_cast<const DataTypeAggregateFunction &>(*argument_type_ptr);
|
||||
AggregateFunctionPtr aggfunc = datatype_aggfunc.getFunction();
|
||||
|
||||
if (aggfunc->getName() != AggregateFunctionGroupBitmapData<UInt8>::name())
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}",
|
||||
argument_types[0]->getName(), name);
|
||||
|
||||
DataTypePtr nested_argument_type_ptr = aggfunc->getArgumentTypes()[0];
|
||||
|
||||
AggregateFunctionPtr res(createWithIntegerType<AggregateFunctionTemplate, AggregateFunctionGroupBitmapData>(
|
||||
*nested_argument_type_ptr, argument_type_ptr));
|
||||
|
||||
if (!res)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}",
|
||||
argument_types[0]->getName(), name);
|
||||
|
||||
return res;
|
||||
this->data(place).roaring_bitmap_with_small_set.merge(this->data(rhs).roaring_bitmap_with_small_set);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).roaring_bitmap_with_small_set.write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).roaring_bitmap_with_small_set.read(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
assert_cast<ColumnVector<T> &>(to).getData().push_back(
|
||||
static_cast<T>(this->data(place).roaring_bitmap_with_small_set.size()));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// This aggregate function takes the states of AggregateFunctionBitmap as its argument.
|
||||
template <typename T, typename Data, typename Policy>
|
||||
class AggregateFunctionBitmapL2 final : public IAggregateFunctionDataHelper<Data, AggregateFunctionBitmapL2<T, Data, Policy>>
|
||||
{
|
||||
private:
|
||||
static constexpr size_t STATE_VERSION_1_MIN_REVISION = 54455;
|
||||
public:
|
||||
explicit AggregateFunctionBitmapL2(const DataTypePtr & type)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitmapL2<T, Data, Policy>>({type}, {}, createResultType())
|
||||
{
|
||||
}
|
||||
|
||||
String getName() const override { return Policy::name; }
|
||||
|
||||
static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<T>>(); }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
DataTypePtr getStateType() const override
|
||||
{
|
||||
return this->argument_types.at(0);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
Data & data_lhs = this->data(place);
|
||||
const Data & data_rhs = this->data(assert_cast<const ColumnAggregateFunction &>(*columns[0]).getData()[row_num]);
|
||||
if (!data_lhs.init)
|
||||
{
|
||||
data_lhs.init = true;
|
||||
data_lhs.roaring_bitmap_with_small_set.merge(data_rhs.roaring_bitmap_with_small_set);
|
||||
}
|
||||
else
|
||||
{
|
||||
Policy::apply(data_lhs, data_rhs);
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
Data & data_lhs = this->data(place);
|
||||
const Data & data_rhs = this->data(rhs);
|
||||
|
||||
if (!data_rhs.init)
|
||||
return;
|
||||
|
||||
if (!data_lhs.init)
|
||||
{
|
||||
data_lhs.init = true;
|
||||
data_lhs.roaring_bitmap_with_small_set.merge(data_rhs.roaring_bitmap_with_small_set);
|
||||
}
|
||||
else
|
||||
{
|
||||
Policy::apply(data_lhs, data_rhs);
|
||||
}
|
||||
}
|
||||
|
||||
bool isVersioned() const override { return true; }
|
||||
|
||||
size_t getDefaultVersion() const override { return 1; }
|
||||
|
||||
size_t getVersionFromRevision(size_t revision) const override
|
||||
{
|
||||
if (revision >= STATE_VERSION_1_MIN_REVISION)
|
||||
return 1;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version) const override
|
||||
{
|
||||
if (!version)
|
||||
version = getDefaultVersion();
|
||||
|
||||
if (*version >= 1)
|
||||
DB::writeBoolText(this->data(place).init, buf);
|
||||
|
||||
this->data(place).roaring_bitmap_with_small_set.write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> version, Arena *) const override
|
||||
{
|
||||
if (!version)
|
||||
version = getDefaultVersion();
|
||||
|
||||
if (*version >= 1)
|
||||
DB::readBoolText(this->data(place).init, buf);
|
||||
this->data(place).roaring_bitmap_with_small_set.read(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
assert_cast<ColumnVector<T> &>(to).getData().push_back(
|
||||
static_cast<T>(this->data(place).roaring_bitmap_with_small_set.size()));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename Data>
|
||||
class BitmapAndPolicy
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = "groupBitmapAnd";
|
||||
static void apply(Data & lhs, const Data & rhs) { lhs.roaring_bitmap_with_small_set.rb_and(rhs.roaring_bitmap_with_small_set); }
|
||||
};
|
||||
|
||||
template <typename Data>
|
||||
class BitmapOrPolicy
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = "groupBitmapOr";
|
||||
static void apply(Data & lhs, const Data & rhs) { lhs.roaring_bitmap_with_small_set.rb_or(rhs.roaring_bitmap_with_small_set); }
|
||||
};
|
||||
|
||||
template <typename Data>
|
||||
class BitmapXorPolicy
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = "groupBitmapXor";
|
||||
static void apply(Data & lhs, const Data & rhs) { lhs.roaring_bitmap_with_small_set.rb_xor(rhs.roaring_bitmap_with_small_set); }
|
||||
};
|
||||
|
||||
template <typename T, typename Data>
|
||||
using AggregateFunctionBitmapL2And = AggregateFunctionBitmapL2<T, Data, BitmapAndPolicy<Data>>;
|
||||
|
||||
template <typename T, typename Data>
|
||||
using AggregateFunctionBitmapL2Or = AggregateFunctionBitmapL2<T, Data, BitmapOrPolicy<Data>>;
|
||||
|
||||
template <typename T, typename Data>
|
||||
using AggregateFunctionBitmapL2Xor = AggregateFunctionBitmapL2<T, Data, BitmapXorPolicy<Data>>;
|
||||
|
||||
|
||||
template <template <typename, typename> class AggregateFunctionTemplate, template <typename> typename Data, typename... TArgs>
|
||||
IAggregateFunction * createWithIntegerType(const IDataType & argument_type, TArgs &&... args)
|
||||
{
|
||||
WhichDataType which(argument_type);
|
||||
if (which.idx == TypeIndex::UInt8) return new AggregateFunctionTemplate<UInt8, Data<UInt8>>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::UInt16) return new AggregateFunctionTemplate<UInt16, Data<UInt16>>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::UInt32) return new AggregateFunctionTemplate<UInt32, Data<UInt32>>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::UInt64) return new AggregateFunctionTemplate<UInt64, Data<UInt64>>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::Int8) return new AggregateFunctionTemplate<Int8, Data<Int8>>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::Int16) return new AggregateFunctionTemplate<Int16, Data<Int16>>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::Int32) return new AggregateFunctionTemplate<Int32, Data<Int32>>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::Int64) return new AggregateFunctionTemplate<Int64, Data<Int64>>(std::forward<TArgs>(args)...);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <template <typename> typename Data>
|
||||
AggregateFunctionPtr createAggregateFunctionBitmap(
|
||||
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
assertUnary(name, argument_types);
|
||||
|
||||
if (!argument_types[0]->canBeUsedInBitOperations())
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"The type {} of argument for aggregate function {} "
|
||||
"is illegal, because it cannot be used in Bitmap operations",
|
||||
argument_types[0]->getName(), name);
|
||||
|
||||
AggregateFunctionPtr res(createWithIntegerType<AggregateFunctionBitmap, Data>(*argument_types[0], argument_types[0]));
|
||||
|
||||
if (!res)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}",
|
||||
argument_types[0]->getName(), name);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// Additional aggregate functions to manipulate bitmaps.
|
||||
template <template <typename, typename> typename AggregateFunctionTemplate>
|
||||
AggregateFunctionPtr createAggregateFunctionBitmapL2(
|
||||
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
assertUnary(name, argument_types);
|
||||
|
||||
DataTypePtr argument_type_ptr = argument_types[0];
|
||||
WhichDataType which(*argument_type_ptr);
|
||||
if (which.idx != TypeIndex::AggregateFunction)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}",
|
||||
argument_types[0]->getName(), name);
|
||||
|
||||
/// groupBitmap needs to know about the data type that was used to create bitmaps.
|
||||
/// We need to look inside the type of its argument to obtain it.
|
||||
const DataTypeAggregateFunction & datatype_aggfunc = dynamic_cast<const DataTypeAggregateFunction &>(*argument_type_ptr);
|
||||
AggregateFunctionPtr aggfunc = datatype_aggfunc.getFunction();
|
||||
|
||||
if (aggfunc->getName() != AggregateFunctionGroupBitmapData<UInt8>::name())
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}",
|
||||
argument_types[0]->getName(), name);
|
||||
|
||||
DataTypePtr nested_argument_type_ptr = aggfunc->getArgumentTypes()[0];
|
||||
|
||||
AggregateFunctionPtr res(createWithIntegerType<AggregateFunctionTemplate, AggregateFunctionGroupBitmapData>(
|
||||
*nested_argument_type_ptr, argument_type_ptr));
|
||||
|
||||
if (!res)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}",
|
||||
argument_types[0]->getName(), name);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionsBitmap(AggregateFunctionFactory & factory)
|
||||
|
@ -1,191 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <Columns/ColumnAggregateFunction.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
// TODO include this last because of a broken roaring header. See the comment inside.
|
||||
#include <AggregateFunctions/AggregateFunctionGroupBitmapData.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/// Counts bitmap operation on numbers.
|
||||
template <typename T, typename Data>
|
||||
class AggregateFunctionBitmap final : public IAggregateFunctionDataHelper<Data, AggregateFunctionBitmap<T, Data>>
|
||||
{
|
||||
public:
|
||||
explicit AggregateFunctionBitmap(const DataTypePtr & type)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitmap<T, Data>>({type}, {}, createResultType())
|
||||
{
|
||||
}
|
||||
|
||||
String getName() const override { return Data::name(); }
|
||||
|
||||
static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<T>>(); }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
this->data(place).roaring_bitmap_with_small_set.add(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).roaring_bitmap_with_small_set.merge(this->data(rhs).roaring_bitmap_with_small_set);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).roaring_bitmap_with_small_set.write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).roaring_bitmap_with_small_set.read(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
assert_cast<ColumnVector<T> &>(to).getData().push_back(
|
||||
static_cast<T>(this->data(place).roaring_bitmap_with_small_set.size()));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// This aggregate function takes the states of AggregateFunctionBitmap as its argument.
|
||||
template <typename T, typename Data, typename Policy>
|
||||
class AggregateFunctionBitmapL2 final : public IAggregateFunctionDataHelper<Data, AggregateFunctionBitmapL2<T, Data, Policy>>
|
||||
{
|
||||
private:
|
||||
static constexpr size_t STATE_VERSION_1_MIN_REVISION = 54455;
|
||||
public:
|
||||
explicit AggregateFunctionBitmapL2(const DataTypePtr & type)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitmapL2<T, Data, Policy>>({type}, {}, createResultType())
|
||||
{
|
||||
}
|
||||
|
||||
String getName() const override { return Policy::name; }
|
||||
|
||||
static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<T>>(); }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
DataTypePtr getStateType() const override
|
||||
{
|
||||
return this->argument_types.at(0);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
Data & data_lhs = this->data(place);
|
||||
const Data & data_rhs = this->data(assert_cast<const ColumnAggregateFunction &>(*columns[0]).getData()[row_num]);
|
||||
if (!data_lhs.init)
|
||||
{
|
||||
data_lhs.init = true;
|
||||
data_lhs.roaring_bitmap_with_small_set.merge(data_rhs.roaring_bitmap_with_small_set);
|
||||
}
|
||||
else
|
||||
{
|
||||
Policy::apply(data_lhs, data_rhs);
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
Data & data_lhs = this->data(place);
|
||||
const Data & data_rhs = this->data(rhs);
|
||||
|
||||
if (!data_rhs.init)
|
||||
return;
|
||||
|
||||
if (!data_lhs.init)
|
||||
{
|
||||
data_lhs.init = true;
|
||||
data_lhs.roaring_bitmap_with_small_set.merge(data_rhs.roaring_bitmap_with_small_set);
|
||||
}
|
||||
else
|
||||
{
|
||||
Policy::apply(data_lhs, data_rhs);
|
||||
}
|
||||
}
|
||||
|
||||
bool isVersioned() const override { return true; }
|
||||
|
||||
size_t getDefaultVersion() const override { return 1; }
|
||||
|
||||
size_t getVersionFromRevision(size_t revision) const override
|
||||
{
|
||||
if (revision >= STATE_VERSION_1_MIN_REVISION)
|
||||
return 1;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version) const override
|
||||
{
|
||||
if (!version)
|
||||
version = getDefaultVersion();
|
||||
|
||||
if (*version >= 1)
|
||||
DB::writeBoolText(this->data(place).init, buf);
|
||||
|
||||
this->data(place).roaring_bitmap_with_small_set.write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> version, Arena *) const override
|
||||
{
|
||||
if (!version)
|
||||
version = getDefaultVersion();
|
||||
|
||||
if (*version >= 1)
|
||||
DB::readBoolText(this->data(place).init, buf);
|
||||
this->data(place).roaring_bitmap_with_small_set.read(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
assert_cast<ColumnVector<T> &>(to).getData().push_back(
|
||||
static_cast<T>(this->data(place).roaring_bitmap_with_small_set.size()));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename Data>
|
||||
class BitmapAndPolicy
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = "groupBitmapAnd";
|
||||
static void apply(Data & lhs, const Data & rhs) { lhs.roaring_bitmap_with_small_set.rb_and(rhs.roaring_bitmap_with_small_set); }
|
||||
};
|
||||
|
||||
template <typename Data>
|
||||
class BitmapOrPolicy
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = "groupBitmapOr";
|
||||
static void apply(Data & lhs, const Data & rhs) { lhs.roaring_bitmap_with_small_set.rb_or(rhs.roaring_bitmap_with_small_set); }
|
||||
};
|
||||
|
||||
template <typename Data>
|
||||
class BitmapXorPolicy
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = "groupBitmapXor";
|
||||
static void apply(Data & lhs, const Data & rhs) { lhs.roaring_bitmap_with_small_set.rb_xor(rhs.roaring_bitmap_with_small_set); }
|
||||
};
|
||||
|
||||
template <typename T, typename Data>
|
||||
using AggregateFunctionBitmapL2And = AggregateFunctionBitmapL2<T, Data, BitmapAndPolicy<Data>>;
|
||||
|
||||
template <typename T, typename Data>
|
||||
using AggregateFunctionBitmapL2Or = AggregateFunctionBitmapL2<T, Data, BitmapOrPolicy<Data>>;
|
||||
|
||||
template <typename T, typename Data>
|
||||
using AggregateFunctionBitmapL2Xor = AggregateFunctionBitmapL2<T, Data, BitmapXorPolicy<Data>>;
|
||||
|
||||
}
|
@ -1,14 +1,31 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionGroupUniqArray.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <DataTypes/DataTypeDate.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
#include <DataTypes/DataTypeIPv4andIPv6.h>
|
||||
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/ReadHelpersArena.h>
|
||||
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
|
||||
#include <Columns/ColumnArray.h>
|
||||
|
||||
#include <Common/HashTable/HashSet.h>
|
||||
#include <Common/HashTable/HashTableKeyHolder.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/KeyHolderHelpers.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
@ -21,6 +38,211 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
struct AggregateFunctionGroupUniqArrayData
|
||||
{
|
||||
/// When creating, the hash table must be small.
|
||||
using Set = HashSetWithStackMemory<T, DefaultHash<T>, 4>;
|
||||
|
||||
Set value;
|
||||
};
|
||||
|
||||
|
||||
/// Puts all values to the hash set. Returns an array of unique values. Implemented for numeric types.
|
||||
template <typename T, typename LimitNumElems>
|
||||
class AggregateFunctionGroupUniqArray
|
||||
: public IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayData<T>, AggregateFunctionGroupUniqArray<T, LimitNumElems>>
|
||||
{
|
||||
static constexpr bool limit_num_elems = LimitNumElems::value;
|
||||
UInt64 max_elems;
|
||||
|
||||
private:
|
||||
using State = AggregateFunctionGroupUniqArrayData<T>;
|
||||
|
||||
public:
|
||||
AggregateFunctionGroupUniqArray(const DataTypePtr & argument_type, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayData<T>,
|
||||
AggregateFunctionGroupUniqArray<T, LimitNumElems>>({argument_type}, parameters_, std::make_shared<DataTypeArray>(argument_type)),
|
||||
max_elems(max_elems_) {}
|
||||
|
||||
AggregateFunctionGroupUniqArray(const DataTypePtr & argument_type, const Array & parameters_, const DataTypePtr & result_type_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayData<T>,
|
||||
AggregateFunctionGroupUniqArray<T, LimitNumElems>>({argument_type}, parameters_, result_type_),
|
||||
max_elems(max_elems_) {}
|
||||
|
||||
|
||||
String getName() const override { return "groupUniqArray"; }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
if (limit_num_elems && this->data(place).value.size() >= max_elems)
|
||||
return;
|
||||
this->data(place).value.insert(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
if (!limit_num_elems)
|
||||
this->data(place).value.merge(this->data(rhs).value);
|
||||
else
|
||||
{
|
||||
auto & cur_set = this->data(place).value;
|
||||
auto & rhs_set = this->data(rhs).value;
|
||||
|
||||
for (auto & rhs_elem : rhs_set)
|
||||
{
|
||||
if (cur_set.size() >= max_elems)
|
||||
return;
|
||||
cur_set.insert(rhs_elem.getValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
auto & set = this->data(place).value;
|
||||
size_t size = set.size();
|
||||
writeVarUInt(size, buf);
|
||||
for (const auto & elem : set)
|
||||
writeBinaryLittleEndian(elem.key, buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).value.read(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
|
||||
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
|
||||
|
||||
const typename State::Set & set = this->data(place).value;
|
||||
size_t size = set.size();
|
||||
|
||||
offsets_to.push_back(offsets_to.back() + size);
|
||||
|
||||
typename ColumnVector<T>::Container & data_to = assert_cast<ColumnVector<T> &>(arr_to.getData()).getData();
|
||||
size_t old_size = data_to.size();
|
||||
data_to.resize(old_size + size);
|
||||
|
||||
size_t i = 0;
|
||||
for (auto it = set.begin(); it != set.end(); ++it, ++i)
|
||||
data_to[old_size + i] = it->getValue();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// Generic implementation, it uses serialized representation as object descriptor.
|
||||
struct AggregateFunctionGroupUniqArrayGenericData
|
||||
{
|
||||
static constexpr size_t INITIAL_SIZE_DEGREE = 3; /// adjustable
|
||||
|
||||
using Set = HashSetWithSavedHashWithStackMemory<StringRef, StringRefHash,
|
||||
INITIAL_SIZE_DEGREE>;
|
||||
|
||||
Set value;
|
||||
};
|
||||
|
||||
template <bool is_plain_column>
|
||||
static void deserializeAndInsertImpl(StringRef str, IColumn & data_to);
|
||||
|
||||
/** Template parameter with true value should be used for columns that store their elements in memory continuously.
|
||||
* For such columns groupUniqArray() can be implemented more efficiently (especially for small numeric arrays).
|
||||
*/
|
||||
template <bool is_plain_column = false, typename LimitNumElems = std::false_type>
|
||||
class AggregateFunctionGroupUniqArrayGeneric
|
||||
: public IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayGenericData,
|
||||
AggregateFunctionGroupUniqArrayGeneric<is_plain_column, LimitNumElems>>
|
||||
{
|
||||
DataTypePtr & input_data_type;
|
||||
|
||||
static constexpr bool limit_num_elems = LimitNumElems::value;
|
||||
UInt64 max_elems;
|
||||
|
||||
using State = AggregateFunctionGroupUniqArrayGenericData;
|
||||
|
||||
public:
|
||||
AggregateFunctionGroupUniqArrayGeneric(const DataTypePtr & input_data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayGenericData, AggregateFunctionGroupUniqArrayGeneric<is_plain_column, LimitNumElems>>({input_data_type_}, parameters_, std::make_shared<DataTypeArray>(input_data_type_))
|
||||
, input_data_type(this->argument_types[0])
|
||||
, max_elems(max_elems_) {}
|
||||
|
||||
String getName() const override { return "groupUniqArray"; }
|
||||
|
||||
bool allocatesMemoryInArena() const override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
auto & set = this->data(place).value;
|
||||
writeVarUInt(set.size(), buf);
|
||||
|
||||
for (const auto & elem : set)
|
||||
{
|
||||
writeStringBinary(elem.getValue(), buf);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
auto & set = this->data(place).value;
|
||||
size_t size;
|
||||
readVarUInt(size, buf);
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
set.insert(readStringBinaryInto(*arena, buf));
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
auto & set = this->data(place).value;
|
||||
if (limit_num_elems && set.size() >= max_elems)
|
||||
return;
|
||||
|
||||
bool inserted;
|
||||
State::Set::LookupResult it;
|
||||
auto key_holder = getKeyHolder<is_plain_column>(*columns[0], row_num, *arena);
|
||||
set.emplace(key_holder, it, inserted);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
auto & cur_set = this->data(place).value;
|
||||
auto & rhs_set = this->data(rhs).value;
|
||||
|
||||
bool inserted;
|
||||
State::Set::LookupResult it;
|
||||
for (auto & rhs_elem : rhs_set)
|
||||
{
|
||||
if (limit_num_elems && cur_set.size() >= max_elems)
|
||||
return;
|
||||
|
||||
// We have to copy the keys to our arena.
|
||||
chassert(arena != nullptr);
|
||||
cur_set.emplace(ArenaKeyHolder{rhs_elem.getValue(), *arena}, it, inserted);
|
||||
}
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
|
||||
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
|
||||
IColumn & data_to = arr_to.getData();
|
||||
|
||||
auto & set = this->data(place).value;
|
||||
offsets_to.push_back(offsets_to.back() + set.size());
|
||||
|
||||
for (auto & elem : set)
|
||||
deserializeAndInsert<is_plain_column>(elem.getValue(), data_to);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// Substitute return type for Date and DateTime
|
||||
template <typename HasLimit>
|
||||
class AggregateFunctionGroupUniqArrayDate : public AggregateFunctionGroupUniqArray<DataTypeDate::FieldType, HasLimit>
|
||||
|
@ -1,236 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/ReadHelpersArena.h>
|
||||
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
|
||||
#include <Columns/ColumnArray.h>
|
||||
|
||||
#include <Common/HashTable/HashSet.h>
|
||||
#include <Common/HashTable/HashTableKeyHolder.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/KeyHolderHelpers.h>
|
||||
|
||||
#define AGGREGATE_FUNCTION_GROUP_ARRAY_UNIQ_MAX_SIZE 0xFFFFFF
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct AggregateFunctionGroupUniqArrayData
|
||||
{
|
||||
/// When creating, the hash table must be small.
|
||||
using Set = HashSetWithStackMemory<T, DefaultHash<T>, 4>;
|
||||
|
||||
Set value;
|
||||
};
|
||||
|
||||
|
||||
/// Puts all values to the hash set. Returns an array of unique values. Implemented for numeric types.
|
||||
template <typename T, typename LimitNumElems>
|
||||
class AggregateFunctionGroupUniqArray
|
||||
: public IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayData<T>, AggregateFunctionGroupUniqArray<T, LimitNumElems>>
|
||||
{
|
||||
static constexpr bool limit_num_elems = LimitNumElems::value;
|
||||
UInt64 max_elems;
|
||||
|
||||
private:
|
||||
using State = AggregateFunctionGroupUniqArrayData<T>;
|
||||
|
||||
public:
|
||||
AggregateFunctionGroupUniqArray(const DataTypePtr & argument_type, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayData<T>,
|
||||
AggregateFunctionGroupUniqArray<T, LimitNumElems>>({argument_type}, parameters_, std::make_shared<DataTypeArray>(argument_type)),
|
||||
max_elems(max_elems_) {}
|
||||
|
||||
AggregateFunctionGroupUniqArray(const DataTypePtr & argument_type, const Array & parameters_, const DataTypePtr & result_type_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayData<T>,
|
||||
AggregateFunctionGroupUniqArray<T, LimitNumElems>>({argument_type}, parameters_, result_type_),
|
||||
max_elems(max_elems_) {}
|
||||
|
||||
|
||||
String getName() const override { return "groupUniqArray"; }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
if (limit_num_elems && this->data(place).value.size() >= max_elems)
|
||||
return;
|
||||
this->data(place).value.insert(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
if (!limit_num_elems)
|
||||
this->data(place).value.merge(this->data(rhs).value);
|
||||
else
|
||||
{
|
||||
auto & cur_set = this->data(place).value;
|
||||
auto & rhs_set = this->data(rhs).value;
|
||||
|
||||
for (auto & rhs_elem : rhs_set)
|
||||
{
|
||||
if (cur_set.size() >= max_elems)
|
||||
return;
|
||||
cur_set.insert(rhs_elem.getValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
auto & set = this->data(place).value;
|
||||
size_t size = set.size();
|
||||
writeVarUInt(size, buf);
|
||||
for (const auto & elem : set)
|
||||
writeBinaryLittleEndian(elem.key, buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).value.read(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
|
||||
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
|
||||
|
||||
const typename State::Set & set = this->data(place).value;
|
||||
size_t size = set.size();
|
||||
|
||||
offsets_to.push_back(offsets_to.back() + size);
|
||||
|
||||
typename ColumnVector<T>::Container & data_to = assert_cast<ColumnVector<T> &>(arr_to.getData()).getData();
|
||||
size_t old_size = data_to.size();
|
||||
data_to.resize(old_size + size);
|
||||
|
||||
size_t i = 0;
|
||||
for (auto it = set.begin(); it != set.end(); ++it, ++i)
|
||||
data_to[old_size + i] = it->getValue();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// Generic implementation, it uses serialized representation as object descriptor.
|
||||
struct AggregateFunctionGroupUniqArrayGenericData
|
||||
{
|
||||
static constexpr size_t INITIAL_SIZE_DEGREE = 3; /// adjustable
|
||||
|
||||
using Set = HashSetWithSavedHashWithStackMemory<StringRef, StringRefHash,
|
||||
INITIAL_SIZE_DEGREE>;
|
||||
|
||||
Set value;
|
||||
};
|
||||
|
||||
template <bool is_plain_column>
|
||||
static void deserializeAndInsertImpl(StringRef str, IColumn & data_to);
|
||||
|
||||
/** Template parameter with true value should be used for columns that store their elements in memory continuously.
|
||||
* For such columns groupUniqArray() can be implemented more efficiently (especially for small numeric arrays).
|
||||
*/
|
||||
template <bool is_plain_column = false, typename LimitNumElems = std::false_type>
|
||||
class AggregateFunctionGroupUniqArrayGeneric
|
||||
: public IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayGenericData,
|
||||
AggregateFunctionGroupUniqArrayGeneric<is_plain_column, LimitNumElems>>
|
||||
{
|
||||
DataTypePtr & input_data_type;
|
||||
|
||||
static constexpr bool limit_num_elems = LimitNumElems::value;
|
||||
UInt64 max_elems;
|
||||
|
||||
using State = AggregateFunctionGroupUniqArrayGenericData;
|
||||
|
||||
public:
|
||||
AggregateFunctionGroupUniqArrayGeneric(const DataTypePtr & input_data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayGenericData, AggregateFunctionGroupUniqArrayGeneric<is_plain_column, LimitNumElems>>({input_data_type_}, parameters_, std::make_shared<DataTypeArray>(input_data_type_))
|
||||
, input_data_type(this->argument_types[0])
|
||||
, max_elems(max_elems_) {}
|
||||
|
||||
String getName() const override { return "groupUniqArray"; }
|
||||
|
||||
bool allocatesMemoryInArena() const override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
auto & set = this->data(place).value;
|
||||
writeVarUInt(set.size(), buf);
|
||||
|
||||
for (const auto & elem : set)
|
||||
{
|
||||
writeStringBinary(elem.getValue(), buf);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
auto & set = this->data(place).value;
|
||||
size_t size;
|
||||
readVarUInt(size, buf);
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
set.insert(readStringBinaryInto(*arena, buf));
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
auto & set = this->data(place).value;
|
||||
if (limit_num_elems && set.size() >= max_elems)
|
||||
return;
|
||||
|
||||
bool inserted;
|
||||
State::Set::LookupResult it;
|
||||
auto key_holder = getKeyHolder<is_plain_column>(*columns[0], row_num, *arena);
|
||||
set.emplace(key_holder, it, inserted);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
auto & cur_set = this->data(place).value;
|
||||
auto & rhs_set = this->data(rhs).value;
|
||||
|
||||
bool inserted;
|
||||
State::Set::LookupResult it;
|
||||
for (auto & rhs_elem : rhs_set)
|
||||
{
|
||||
if (limit_num_elems && cur_set.size() >= max_elems)
|
||||
return;
|
||||
|
||||
// We have to copy the keys to our arena.
|
||||
assert(arena != nullptr);
|
||||
cur_set.emplace(ArenaKeyHolder{rhs_elem.getValue(), *arena}, it, inserted);
|
||||
}
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
|
||||
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
|
||||
IColumn & data_to = arr_to.getData();
|
||||
|
||||
auto & set = this->data(place).value;
|
||||
offsets_to.push_back(offsets_to.back() + set.size());
|
||||
|
||||
for (auto & elem : set)
|
||||
deserializeAndInsert<is_plain_column>(elem.getValue(), data_to);
|
||||
}
|
||||
};
|
||||
|
||||
#undef AGGREGATE_FUNCTION_GROUP_ARRAY_UNIQ_MAX_SIZE
|
||||
|
||||
}
|
@ -1,9 +1,31 @@
|
||||
#include <AggregateFunctions/AggregateFunctionHistogram.h>
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <Common/FieldVisitorConvertToNumber.h>
|
||||
|
||||
#include <Common/NaNUtils.h>
|
||||
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
|
||||
#include <IO/WriteBuffer.h>
|
||||
#include <IO/ReadBuffer.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/VarInt.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
#include <queue>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -16,12 +38,357 @@ namespace ErrorCodes
|
||||
extern const int BAD_ARGUMENTS;
|
||||
extern const int UNSUPPORTED_PARAMETER;
|
||||
extern const int PARAMETER_OUT_OF_BOUND;
|
||||
extern const int TOO_LARGE_ARRAY_SIZE;
|
||||
extern const int INCORRECT_DATA;
|
||||
}
|
||||
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
/** distance compression algorithm implementation
|
||||
* http://jmlr.org/papers/volume11/ben-haim10a/ben-haim10a.pdf
|
||||
*/
|
||||
class AggregateFunctionHistogramData
|
||||
{
|
||||
public:
|
||||
using Mean = Float64;
|
||||
using Weight = Float64;
|
||||
|
||||
constexpr static size_t bins_count_limit = 250;
|
||||
|
||||
private:
|
||||
struct WeightedValue
|
||||
{
|
||||
Mean mean;
|
||||
Weight weight;
|
||||
|
||||
WeightedValue operator+(const WeightedValue & other) const
|
||||
{
|
||||
return {mean + other.weight * (other.mean - mean) / (other.weight + weight), other.weight + weight};
|
||||
}
|
||||
};
|
||||
|
||||
// quantity of stored weighted-values
|
||||
UInt32 size;
|
||||
|
||||
// calculated lower and upper bounds of seen points
|
||||
Mean lower_bound;
|
||||
Mean upper_bound;
|
||||
|
||||
// Weighted values representation of histogram.
|
||||
WeightedValue points[0];
|
||||
|
||||
void sort()
|
||||
{
|
||||
::sort(points, points + size,
|
||||
[](const WeightedValue & first, const WeightedValue & second)
|
||||
{
|
||||
return first.mean < second.mean;
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct PriorityQueueStorage
|
||||
{
|
||||
size_t size = 0;
|
||||
T * data_ptr;
|
||||
|
||||
explicit PriorityQueueStorage(T * value)
|
||||
: data_ptr(value)
|
||||
{
|
||||
}
|
||||
|
||||
void push_back(T val) /// NOLINT
|
||||
{
|
||||
data_ptr[size] = std::move(val);
|
||||
++size;
|
||||
}
|
||||
|
||||
void pop_back() { --size; } /// NOLINT
|
||||
T * begin() { return data_ptr; }
|
||||
T * end() const { return data_ptr + size; }
|
||||
bool empty() const { return size == 0; }
|
||||
T & front() { return *data_ptr; }
|
||||
const T & front() const { return *data_ptr; }
|
||||
|
||||
using value_type = T;
|
||||
using reference = T&;
|
||||
using const_reference = const T&;
|
||||
using size_type = size_t;
|
||||
};
|
||||
|
||||
/**
|
||||
* Repeatedly fuse most close values until max_bins bins left
|
||||
*/
|
||||
void compress(UInt32 max_bins)
|
||||
{
|
||||
sort();
|
||||
auto new_size = size;
|
||||
if (size <= max_bins)
|
||||
return;
|
||||
|
||||
// Maintain doubly-linked list of "active" points
|
||||
// and store neighbour pairs in priority queue by distance
|
||||
UInt32 previous[size + 1];
|
||||
UInt32 next[size + 1];
|
||||
bool active[size + 1];
|
||||
std::fill(active, active + size, true);
|
||||
active[size] = false;
|
||||
|
||||
auto delete_node = [&](UInt32 i)
|
||||
{
|
||||
previous[next[i]] = previous[i];
|
||||
next[previous[i]] = next[i];
|
||||
active[i] = false;
|
||||
};
|
||||
|
||||
for (size_t i = 0; i <= size; ++i)
|
||||
{
|
||||
previous[i] = static_cast<UInt32>(i - 1);
|
||||
next[i] = static_cast<UInt32>(i + 1);
|
||||
}
|
||||
|
||||
next[size] = 0;
|
||||
previous[0] = size;
|
||||
|
||||
using QueueItem = std::pair<Mean, UInt32>;
|
||||
|
||||
QueueItem storage[2 * size - max_bins];
|
||||
|
||||
std::priority_queue<
|
||||
QueueItem,
|
||||
PriorityQueueStorage<QueueItem>,
|
||||
std::greater<>>
|
||||
queue{std::greater<>(),
|
||||
PriorityQueueStorage<QueueItem>(storage)};
|
||||
|
||||
auto quality = [&](UInt32 i) { return points[next[i]].mean - points[i].mean; };
|
||||
|
||||
for (size_t i = 0; i + 1 < size; ++i)
|
||||
queue.push({quality(static_cast<UInt32>(i)), i});
|
||||
|
||||
while (new_size > max_bins && !queue.empty())
|
||||
{
|
||||
auto min_item = queue.top();
|
||||
queue.pop();
|
||||
auto left = min_item.second;
|
||||
auto right = next[left];
|
||||
|
||||
if (!active[left] || !active[right] || quality(left) > min_item.first)
|
||||
continue;
|
||||
|
||||
points[left] = points[left] + points[right];
|
||||
|
||||
delete_node(right);
|
||||
if (active[next[left]])
|
||||
queue.push({quality(left), left});
|
||||
if (active[previous[left]])
|
||||
queue.push({quality(previous[left]), previous[left]});
|
||||
|
||||
--new_size;
|
||||
}
|
||||
|
||||
size_t left = 0;
|
||||
for (size_t right = 0; right < size; ++right)
|
||||
{
|
||||
if (active[right])
|
||||
{
|
||||
points[left] = points[right];
|
||||
++left;
|
||||
}
|
||||
}
|
||||
size = new_size;
|
||||
}
|
||||
|
||||
/***
|
||||
* Delete too close points from histogram.
|
||||
* Assumes that points are sorted.
|
||||
*/
|
||||
void unique()
|
||||
{
|
||||
if (size == 0)
|
||||
return;
|
||||
|
||||
size_t left = 0;
|
||||
|
||||
for (auto right = left + 1; right < size; ++right)
|
||||
{
|
||||
// Fuse points if their text representations differ only in last digit
|
||||
auto min_diff = 10 * (points[left].mean + points[right].mean) * std::numeric_limits<Mean>::epsilon();
|
||||
if (points[left].mean + std::fabs(min_diff) >= points[right].mean)
|
||||
{
|
||||
points[left] = points[left] + points[right];
|
||||
}
|
||||
else
|
||||
{
|
||||
++left;
|
||||
points[left] = points[right];
|
||||
}
|
||||
}
|
||||
size = static_cast<UInt32>(left + 1);
|
||||
}
|
||||
|
||||
public:
|
||||
AggregateFunctionHistogramData()
|
||||
: size(0)
|
||||
, lower_bound(std::numeric_limits<Mean>::max())
|
||||
, upper_bound(std::numeric_limits<Mean>::lowest())
|
||||
{
|
||||
static_assert(offsetof(AggregateFunctionHistogramData, points) == sizeof(AggregateFunctionHistogramData), "points should be last member");
|
||||
}
|
||||
|
||||
static size_t structSize(size_t max_bins)
|
||||
{
|
||||
return sizeof(AggregateFunctionHistogramData) + max_bins * 2 * sizeof(WeightedValue);
|
||||
}
|
||||
|
||||
void insertResultInto(ColumnVector<Mean> & to_lower, ColumnVector<Mean> & to_upper, ColumnVector<Weight> & to_weights, UInt32 max_bins)
|
||||
{
|
||||
compress(max_bins);
|
||||
unique();
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
to_lower.insertValue((i == 0) ? lower_bound : (points[i].mean + points[i - 1].mean) / 2);
|
||||
to_upper.insertValue((i + 1 == size) ? upper_bound : (points[i].mean + points[i + 1].mean) / 2);
|
||||
|
||||
// linear density approximation
|
||||
Weight lower_weight = (i == 0) ? points[i].weight : ((points[i - 1].weight) + points[i].weight * 3) / 4;
|
||||
Weight upper_weight = (i + 1 == size) ? points[i].weight : (points[i + 1].weight + points[i].weight * 3) / 4;
|
||||
to_weights.insertValue((lower_weight + upper_weight) / 2);
|
||||
}
|
||||
}
|
||||
|
||||
void add(Mean value, Weight weight, UInt32 max_bins)
|
||||
{
|
||||
// nans break sort and compression
|
||||
// infs don't fit in bins partition method
|
||||
if (!isFinite(value))
|
||||
throw Exception(ErrorCodes::INCORRECT_DATA, "Invalid value (inf or nan) for aggregation by 'histogram' function");
|
||||
|
||||
points[size] = {value, weight};
|
||||
++size;
|
||||
lower_bound = std::min(lower_bound, value);
|
||||
upper_bound = std::max(upper_bound, value);
|
||||
|
||||
if (size >= max_bins * 2)
|
||||
compress(max_bins);
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionHistogramData & other, UInt32 max_bins)
|
||||
{
|
||||
lower_bound = std::min(lower_bound, other.lower_bound);
|
||||
upper_bound = std::max(upper_bound, other.upper_bound);
|
||||
for (size_t i = 0; i < other.size; ++i)
|
||||
add(other.points[i].mean, other.points[i].weight, max_bins);
|
||||
}
|
||||
|
||||
void write(WriteBuffer & buf) const
|
||||
{
|
||||
writeBinary(lower_bound, buf);
|
||||
writeBinary(upper_bound, buf);
|
||||
|
||||
writeVarUInt(size, buf);
|
||||
buf.write(reinterpret_cast<const char *>(points), size * sizeof(WeightedValue));
|
||||
}
|
||||
|
||||
void read(ReadBuffer & buf, UInt32 max_bins)
|
||||
{
|
||||
readBinary(lower_bound, buf);
|
||||
readBinary(upper_bound, buf);
|
||||
|
||||
readVarUInt(size, buf);
|
||||
if (size > max_bins * 2)
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too many bins");
|
||||
static constexpr size_t max_size = 1_GiB;
|
||||
if (size > max_size)
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
|
||||
"Too large array size in histogram (maximum: {})", max_size);
|
||||
|
||||
buf.readStrict(reinterpret_cast<char *>(points), size * sizeof(WeightedValue));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class AggregateFunctionHistogram final: public IAggregateFunctionDataHelper<AggregateFunctionHistogramData, AggregateFunctionHistogram<T>>
|
||||
{
|
||||
private:
|
||||
using Data = AggregateFunctionHistogramData;
|
||||
|
||||
const UInt32 max_bins;
|
||||
|
||||
public:
|
||||
AggregateFunctionHistogram(UInt32 max_bins_, const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionHistogramData, AggregateFunctionHistogram<T>>(arguments, params, createResultType())
|
||||
, max_bins(max_bins_)
|
||||
{
|
||||
}
|
||||
|
||||
size_t sizeOfData() const override
|
||||
{
|
||||
return Data::structSize(max_bins);
|
||||
}
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
DataTypes types;
|
||||
auto mean = std::make_shared<DataTypeNumber<Data::Mean>>();
|
||||
auto weight = std::make_shared<DataTypeNumber<Data::Weight>>();
|
||||
|
||||
// lower bound
|
||||
types.emplace_back(mean);
|
||||
// upper bound
|
||||
types.emplace_back(mean);
|
||||
// weight
|
||||
types.emplace_back(weight);
|
||||
|
||||
auto tuple = std::make_shared<DataTypeTuple>(types);
|
||||
return std::make_shared<DataTypeArray>(tuple);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
auto val = assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];
|
||||
this->data(place).add(static_cast<Data::Mean>(val), 1, max_bins);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs), max_bins);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).read(buf, max_bins);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto & data = this->data(place);
|
||||
|
||||
auto & to_array = assert_cast<ColumnArray &>(to);
|
||||
ColumnArray::Offsets & offsets_to = to_array.getOffsets();
|
||||
auto & to_tuple = assert_cast<ColumnTuple &>(to_array.getData());
|
||||
|
||||
auto & to_lower = assert_cast<ColumnVector<Data::Mean> &>(to_tuple.getColumn(0));
|
||||
auto & to_upper = assert_cast<ColumnVector<Data::Mean> &>(to_tuple.getColumn(1));
|
||||
auto & to_weights = assert_cast<ColumnVector<Data::Weight> &>(to_tuple.getColumn(2));
|
||||
data.insertResultInto(to_lower, to_upper, to_weights, max_bins);
|
||||
|
||||
offsets_to.push_back(to_tuple.size());
|
||||
}
|
||||
|
||||
String getName() const override { return "histogram"; }
|
||||
};
|
||||
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionHistogram(const std::string & name, const DataTypes & arguments, const Array & params, const Settings *)
|
||||
{
|
||||
if (params.size() != 1)
|
||||
|
@ -1,382 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <base/sort.h>
|
||||
|
||||
#include <Common/NaNUtils.h>
|
||||
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
|
||||
#include <IO/WriteBuffer.h>
|
||||
#include <IO/ReadBuffer.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/VarInt.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
#include <math.h>
|
||||
#include <queue>
|
||||
#include <stddef.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
class Arena;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int TOO_LARGE_ARRAY_SIZE;
|
||||
extern const int INCORRECT_DATA;
|
||||
}
|
||||
|
||||
/**
|
||||
* distance compression algorithm implementation
|
||||
* http://jmlr.org/papers/volume11/ben-haim10a/ben-haim10a.pdf
|
||||
*/
|
||||
class AggregateFunctionHistogramData
|
||||
{
|
||||
public:
|
||||
using Mean = Float64;
|
||||
using Weight = Float64;
|
||||
|
||||
constexpr static size_t bins_count_limit = 250;
|
||||
|
||||
private:
|
||||
struct WeightedValue
|
||||
{
|
||||
Mean mean;
|
||||
Weight weight;
|
||||
|
||||
WeightedValue operator+(const WeightedValue & other) const
|
||||
{
|
||||
return {mean + other.weight * (other.mean - mean) / (other.weight + weight), other.weight + weight};
|
||||
}
|
||||
};
|
||||
|
||||
// quantity of stored weighted-values
|
||||
UInt32 size;
|
||||
|
||||
// calculated lower and upper bounds of seen points
|
||||
Mean lower_bound;
|
||||
Mean upper_bound;
|
||||
|
||||
// Weighted values representation of histogram.
|
||||
WeightedValue points[0];
|
||||
|
||||
void sort()
|
||||
{
|
||||
::sort(points, points + size,
|
||||
[](const WeightedValue & first, const WeightedValue & second)
|
||||
{
|
||||
return first.mean < second.mean;
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct PriorityQueueStorage
|
||||
{
|
||||
size_t size = 0;
|
||||
T * data_ptr;
|
||||
|
||||
explicit PriorityQueueStorage(T * value)
|
||||
: data_ptr(value)
|
||||
{
|
||||
}
|
||||
|
||||
void push_back(T val) /// NOLINT
|
||||
{
|
||||
data_ptr[size] = std::move(val);
|
||||
++size;
|
||||
}
|
||||
|
||||
void pop_back() { --size; } /// NOLINT
|
||||
T * begin() { return data_ptr; }
|
||||
T * end() const { return data_ptr + size; }
|
||||
bool empty() const { return size == 0; }
|
||||
T & front() { return *data_ptr; }
|
||||
const T & front() const { return *data_ptr; }
|
||||
|
||||
using value_type = T;
|
||||
using reference = T&;
|
||||
using const_reference = const T&;
|
||||
using size_type = size_t;
|
||||
};
|
||||
|
||||
/**
|
||||
* Repeatedly fuse most close values until max_bins bins left
|
||||
*/
|
||||
void compress(UInt32 max_bins)
|
||||
{
|
||||
sort();
|
||||
auto new_size = size;
|
||||
if (size <= max_bins)
|
||||
return;
|
||||
|
||||
// Maintain doubly-linked list of "active" points
|
||||
// and store neighbour pairs in priority queue by distance
|
||||
UInt32 previous[size + 1];
|
||||
UInt32 next[size + 1];
|
||||
bool active[size + 1];
|
||||
std::fill(active, active + size, true);
|
||||
active[size] = false;
|
||||
|
||||
auto delete_node = [&](UInt32 i)
|
||||
{
|
||||
previous[next[i]] = previous[i];
|
||||
next[previous[i]] = next[i];
|
||||
active[i] = false;
|
||||
};
|
||||
|
||||
for (size_t i = 0; i <= size; ++i)
|
||||
{
|
||||
previous[i] = static_cast<UInt32>(i - 1);
|
||||
next[i] = static_cast<UInt32>(i + 1);
|
||||
}
|
||||
|
||||
next[size] = 0;
|
||||
previous[0] = size;
|
||||
|
||||
using QueueItem = std::pair<Mean, UInt32>;
|
||||
|
||||
QueueItem storage[2 * size - max_bins];
|
||||
|
||||
std::priority_queue<
|
||||
QueueItem,
|
||||
PriorityQueueStorage<QueueItem>,
|
||||
std::greater<QueueItem>>
|
||||
queue{std::greater<QueueItem>(),
|
||||
PriorityQueueStorage<QueueItem>(storage)};
|
||||
|
||||
auto quality = [&](UInt32 i) { return points[next[i]].mean - points[i].mean; };
|
||||
|
||||
for (size_t i = 0; i + 1 < size; ++i)
|
||||
queue.push({quality(static_cast<UInt32>(i)), i});
|
||||
|
||||
while (new_size > max_bins && !queue.empty())
|
||||
{
|
||||
auto min_item = queue.top();
|
||||
queue.pop();
|
||||
auto left = min_item.second;
|
||||
auto right = next[left];
|
||||
|
||||
if (!active[left] || !active[right] || quality(left) > min_item.first)
|
||||
continue;
|
||||
|
||||
points[left] = points[left] + points[right];
|
||||
|
||||
delete_node(right);
|
||||
if (active[next[left]])
|
||||
queue.push({quality(left), left});
|
||||
if (active[previous[left]])
|
||||
queue.push({quality(previous[left]), previous[left]});
|
||||
|
||||
--new_size;
|
||||
}
|
||||
|
||||
size_t left = 0;
|
||||
for (size_t right = 0; right < size; ++right)
|
||||
{
|
||||
if (active[right])
|
||||
{
|
||||
points[left] = points[right];
|
||||
++left;
|
||||
}
|
||||
}
|
||||
size = new_size;
|
||||
}
|
||||
|
||||
/***
|
||||
* Delete too close points from histogram.
|
||||
* Assumes that points are sorted.
|
||||
*/
|
||||
void unique()
|
||||
{
|
||||
if (size == 0)
|
||||
return;
|
||||
|
||||
size_t left = 0;
|
||||
|
||||
for (auto right = left + 1; right < size; ++right)
|
||||
{
|
||||
// Fuse points if their text representations differ only in last digit
|
||||
auto min_diff = 10 * (points[left].mean + points[right].mean) * std::numeric_limits<Mean>::epsilon();
|
||||
if (points[left].mean + std::fabs(min_diff) >= points[right].mean)
|
||||
{
|
||||
points[left] = points[left] + points[right];
|
||||
}
|
||||
else
|
||||
{
|
||||
++left;
|
||||
points[left] = points[right];
|
||||
}
|
||||
}
|
||||
size = static_cast<UInt32>(left + 1);
|
||||
}
|
||||
|
||||
public:
|
||||
AggregateFunctionHistogramData()
|
||||
: size(0)
|
||||
, lower_bound(std::numeric_limits<Mean>::max())
|
||||
, upper_bound(std::numeric_limits<Mean>::lowest())
|
||||
{
|
||||
static_assert(offsetof(AggregateFunctionHistogramData, points) == sizeof(AggregateFunctionHistogramData), "points should be last member");
|
||||
}
|
||||
|
||||
static size_t structSize(size_t max_bins)
|
||||
{
|
||||
return sizeof(AggregateFunctionHistogramData) + max_bins * 2 * sizeof(WeightedValue);
|
||||
}
|
||||
|
||||
void insertResultInto(ColumnVector<Mean> & to_lower, ColumnVector<Mean> & to_upper, ColumnVector<Weight> & to_weights, UInt32 max_bins)
|
||||
{
|
||||
compress(max_bins);
|
||||
unique();
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
to_lower.insertValue((i == 0) ? lower_bound : (points[i].mean + points[i - 1].mean) / 2);
|
||||
to_upper.insertValue((i + 1 == size) ? upper_bound : (points[i].mean + points[i + 1].mean) / 2);
|
||||
|
||||
// linear density approximation
|
||||
Weight lower_weight = (i == 0) ? points[i].weight : ((points[i - 1].weight) + points[i].weight * 3) / 4;
|
||||
Weight upper_weight = (i + 1 == size) ? points[i].weight : (points[i + 1].weight + points[i].weight * 3) / 4;
|
||||
to_weights.insertValue((lower_weight + upper_weight) / 2);
|
||||
}
|
||||
}
|
||||
|
||||
void add(Mean value, Weight weight, UInt32 max_bins)
|
||||
{
|
||||
// nans break sort and compression
|
||||
// infs don't fit in bins partition method
|
||||
if (!isFinite(value))
|
||||
throw Exception(ErrorCodes::INCORRECT_DATA, "Invalid value (inf or nan) for aggregation by 'histogram' function");
|
||||
|
||||
points[size] = {value, weight};
|
||||
++size;
|
||||
lower_bound = std::min(lower_bound, value);
|
||||
upper_bound = std::max(upper_bound, value);
|
||||
|
||||
if (size >= max_bins * 2)
|
||||
compress(max_bins);
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionHistogramData & other, UInt32 max_bins)
|
||||
{
|
||||
lower_bound = std::min(lower_bound, other.lower_bound);
|
||||
upper_bound = std::max(upper_bound, other.upper_bound);
|
||||
for (size_t i = 0; i < other.size; ++i)
|
||||
add(other.points[i].mean, other.points[i].weight, max_bins);
|
||||
}
|
||||
|
||||
void write(WriteBuffer & buf) const
|
||||
{
|
||||
writeBinary(lower_bound, buf);
|
||||
writeBinary(upper_bound, buf);
|
||||
|
||||
writeVarUInt(size, buf);
|
||||
buf.write(reinterpret_cast<const char *>(points), size * sizeof(WeightedValue));
|
||||
}
|
||||
|
||||
void read(ReadBuffer & buf, UInt32 max_bins)
|
||||
{
|
||||
readBinary(lower_bound, buf);
|
||||
readBinary(upper_bound, buf);
|
||||
|
||||
readVarUInt(size, buf);
|
||||
if (size > max_bins * 2)
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too many bins");
|
||||
static constexpr size_t max_size = 1_GiB;
|
||||
if (size > max_size)
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
|
||||
"Too large array size in histogram (maximum: {})", max_size);
|
||||
|
||||
buf.readStrict(reinterpret_cast<char *>(points), size * sizeof(WeightedValue));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class AggregateFunctionHistogram final: public IAggregateFunctionDataHelper<AggregateFunctionHistogramData, AggregateFunctionHistogram<T>>
|
||||
{
|
||||
private:
|
||||
using Data = AggregateFunctionHistogramData;
|
||||
|
||||
const UInt32 max_bins;
|
||||
|
||||
public:
|
||||
AggregateFunctionHistogram(UInt32 max_bins_, const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionHistogramData, AggregateFunctionHistogram<T>>(arguments, params, createResultType())
|
||||
, max_bins(max_bins_)
|
||||
{
|
||||
}
|
||||
|
||||
size_t sizeOfData() const override
|
||||
{
|
||||
return Data::structSize(max_bins);
|
||||
}
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
DataTypes types;
|
||||
auto mean = std::make_shared<DataTypeNumber<Data::Mean>>();
|
||||
auto weight = std::make_shared<DataTypeNumber<Data::Weight>>();
|
||||
|
||||
// lower bound
|
||||
types.emplace_back(mean);
|
||||
// upper bound
|
||||
types.emplace_back(mean);
|
||||
// weight
|
||||
types.emplace_back(weight);
|
||||
|
||||
auto tuple = std::make_shared<DataTypeTuple>(types);
|
||||
return std::make_shared<DataTypeArray>(tuple);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
auto val = assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];
|
||||
this->data(place).add(static_cast<Data::Mean>(val), 1, max_bins);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs), max_bins);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).read(buf, max_bins);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto & data = this->data(place);
|
||||
|
||||
auto & to_array = assert_cast<ColumnArray &>(to);
|
||||
ColumnArray::Offsets & offsets_to = to_array.getOffsets();
|
||||
auto & to_tuple = assert_cast<ColumnTuple &>(to_array.getData());
|
||||
|
||||
auto & to_lower = assert_cast<ColumnVector<Data::Mean> &>(to_tuple.getColumn(0));
|
||||
auto & to_upper = assert_cast<ColumnVector<Data::Mean> &>(to_tuple.getColumn(1));
|
||||
auto & to_weights = assert_cast<ColumnVector<Data::Weight> &>(to_tuple.getColumn(2));
|
||||
data.insertResultInto(to_lower, to_upper, to_weights, max_bins);
|
||||
|
||||
offsets_to.push_back(to_tuple.size());
|
||||
}
|
||||
|
||||
String getName() const override { return "histogram"; }
|
||||
};
|
||||
|
||||
}
|
@ -1,57 +1,272 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionIntervalLengthSum.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <DataTypes/DataTypeDate.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
|
||||
#include <base/range.h>
|
||||
#include <unordered_set>
|
||||
|
||||
#include <AggregateFunctions/Combinators/AggregateFunctionNull.h>
|
||||
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
|
||||
#include <Common/assert_cast.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int TOO_LARGE_ARRAY_SIZE;
|
||||
}
|
||||
|
||||
struct Settings;
|
||||
|
||||
namespace
|
||||
{
|
||||
template <template <typename> class Data>
|
||||
AggregateFunctionPtr
|
||||
createAggregateFunctionIntervalLengthSum(const std::string & name, const DataTypes & arguments, const Array &, const Settings *)
|
||||
|
||||
/** Calculate total length of intervals without intersections. Each interval is the pair of numbers [begin, end];
|
||||
* Returns UInt64 for integral types (UInt/Int*, Date/DateTime) and returns Float64 for Float*.
|
||||
*
|
||||
* Implementation simply stores intervals sorted by beginning and sums lengths at final.
|
||||
*/
|
||||
template <typename T>
|
||||
struct AggregateFunctionIntervalLengthSumData
|
||||
{
|
||||
constexpr static size_t MAX_ARRAY_SIZE = 0xFFFFFF;
|
||||
|
||||
using Segment = std::pair<T, T>;
|
||||
using Segments = PODArrayWithStackMemory<Segment, 64>;
|
||||
|
||||
bool sorted = false;
|
||||
|
||||
Segments segments;
|
||||
|
||||
void add(T begin, T end)
|
||||
{
|
||||
if (arguments.size() != 2)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
|
||||
"Aggregate function {} requires two timestamps argument.", name);
|
||||
/// Reversed intervals are counted by absolute value of their length.
|
||||
if (unlikely(end < begin))
|
||||
std::swap(begin, end);
|
||||
else if (unlikely(begin == end))
|
||||
return;
|
||||
|
||||
auto args = {arguments[0].get(), arguments[1].get()};
|
||||
if (sorted && !segments.empty())
|
||||
sorted = segments.back().first <= begin;
|
||||
segments.emplace_back(begin, end);
|
||||
}
|
||||
|
||||
if (WhichDataType{args.begin()[0]}.idx != WhichDataType{args.begin()[1]}.idx)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Illegal types {} and {} of arguments "
|
||||
"of aggregate function {}, both arguments should have same data type",
|
||||
args.begin()[0]->getName(), args.begin()[1]->getName(), name);
|
||||
void merge(const AggregateFunctionIntervalLengthSumData & other)
|
||||
{
|
||||
if (other.segments.empty())
|
||||
return;
|
||||
|
||||
for (const auto & arg : args)
|
||||
const auto size = segments.size();
|
||||
|
||||
segments.insert(std::begin(other.segments), std::end(other.segments));
|
||||
|
||||
/// either sort whole container or do so partially merging ranges afterwards
|
||||
if (!sorted && !other.sorted)
|
||||
{
|
||||
if (!isNativeNumber(arg) && !isDate(arg) && !isDateTime(arg))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Illegal type {} of argument of aggregate function {}, must "
|
||||
"be native integral type, Date/DateTime or Float", arg->getName(), name);
|
||||
::sort(std::begin(segments), std::end(segments));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto begin = std::begin(segments);
|
||||
const auto middle = std::next(begin, size);
|
||||
const auto end = std::end(segments);
|
||||
|
||||
if (!sorted)
|
||||
::sort(begin, middle);
|
||||
|
||||
if (!other.sorted)
|
||||
::sort(middle, end);
|
||||
|
||||
std::inplace_merge(begin, middle, end);
|
||||
}
|
||||
|
||||
AggregateFunctionPtr res(createWithBasicNumberOrDateOrDateTime<AggregateFunctionIntervalLengthSum, Data>(*arguments[0], arguments));
|
||||
sorted = true;
|
||||
}
|
||||
|
||||
if (res)
|
||||
return res;
|
||||
void sort()
|
||||
{
|
||||
if (sorted)
|
||||
return;
|
||||
|
||||
::sort(std::begin(segments), std::end(segments));
|
||||
sorted = true;
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
writeBinary(sorted, buf);
|
||||
writeBinary(segments.size(), buf);
|
||||
|
||||
for (const auto & time_gap : segments)
|
||||
{
|
||||
writeBinary(time_gap.first, buf);
|
||||
writeBinary(time_gap.second, buf);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
readBinary(sorted, buf);
|
||||
|
||||
size_t size;
|
||||
readBinary(size, buf);
|
||||
|
||||
if (unlikely(size > MAX_ARRAY_SIZE))
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large array size (maximum: {})", MAX_ARRAY_SIZE);
|
||||
|
||||
segments.clear();
|
||||
segments.reserve(size);
|
||||
|
||||
Segment segment;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
readBinary(segment.first, buf);
|
||||
readBinary(segment.second, buf);
|
||||
segments.emplace_back(segment);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Data>
|
||||
class AggregateFunctionIntervalLengthSum final : public IAggregateFunctionDataHelper<Data, AggregateFunctionIntervalLengthSum<T, Data>>
|
||||
{
|
||||
private:
|
||||
static auto NO_SANITIZE_UNDEFINED length(typename Data::Segment segment)
|
||||
{
|
||||
return segment.second - segment.first;
|
||||
}
|
||||
|
||||
template <typename TResult>
|
||||
TResult getIntervalLengthSum(Data & data) const
|
||||
{
|
||||
if (data.segments.empty())
|
||||
return 0;
|
||||
|
||||
data.sort();
|
||||
|
||||
TResult res = 0;
|
||||
|
||||
typename Data::Segment curr_segment = data.segments[0];
|
||||
|
||||
for (size_t i = 1, size = data.segments.size(); i < size; ++i)
|
||||
{
|
||||
const typename Data::Segment & next_segment = data.segments[i];
|
||||
|
||||
/// Check if current interval intersects with next one then add length, otherwise advance interval end.
|
||||
if (curr_segment.second < next_segment.first)
|
||||
{
|
||||
res += length(curr_segment);
|
||||
curr_segment = next_segment;
|
||||
}
|
||||
else if (next_segment.second > curr_segment.second)
|
||||
{
|
||||
curr_segment.second = next_segment.second;
|
||||
}
|
||||
}
|
||||
res += length(curr_segment);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
public:
|
||||
String getName() const override { return "intervalLengthSum"; }
|
||||
|
||||
explicit AggregateFunctionIntervalLengthSum(const DataTypes & arguments)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionIntervalLengthSum<T, Data>>(arguments, {}, createResultType())
|
||||
{
|
||||
}
|
||||
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
if constexpr (std::is_floating_point_v<T>)
|
||||
return std::make_shared<DataTypeFloat64>();
|
||||
return std::make_shared<DataTypeUInt64>();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
AggregateFunctionPtr getOwnNullAdapter(
|
||||
const AggregateFunctionPtr & nested_function,
|
||||
const DataTypes & arguments,
|
||||
const Array & params,
|
||||
const AggregateFunctionProperties & /*properties*/) const override
|
||||
{
|
||||
return std::make_shared<AggregateFunctionNullVariadic<false, false>>(nested_function, arguments, params);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override
|
||||
{
|
||||
auto begin = assert_cast<const ColumnVector<T> *>(columns[0])->getData()[row_num];
|
||||
auto end = assert_cast<const ColumnVector<T> *>(columns[1])->getData()[row_num];
|
||||
this->data(place).add(begin, end);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
if constexpr (std::is_floating_point_v<T>)
|
||||
assert_cast<ColumnFloat64 &>(to).getData().push_back(getIntervalLengthSum<Float64>(this->data(place)));
|
||||
else
|
||||
assert_cast<ColumnUInt64 &>(to).getData().push_back(getIntervalLengthSum<UInt64>(this->data(place)));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <template <typename> class Data>
|
||||
AggregateFunctionPtr
|
||||
createAggregateFunctionIntervalLengthSum(const std::string & name, const DataTypes & arguments, const Array &, const Settings *)
|
||||
{
|
||||
if (arguments.size() != 2)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
|
||||
"Aggregate function {} requires two timestamps argument.", name);
|
||||
|
||||
auto args = {arguments[0].get(), arguments[1].get()};
|
||||
|
||||
if (WhichDataType{args.begin()[0]}.idx != WhichDataType{args.begin()[1]}.idx)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Illegal type {} of argument of aggregate function {}, must "
|
||||
"be native integral type, Date/DateTime or Float", arguments.front().get()->getName(), name);
|
||||
"Illegal types {} and {} of arguments "
|
||||
"of aggregate function {}, both arguments should have same data type",
|
||||
args.begin()[0]->getName(), args.begin()[1]->getName(), name);
|
||||
|
||||
for (const auto & arg : args)
|
||||
{
|
||||
if (!isNativeNumber(arg) && !isDate(arg) && !isDateTime(arg))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Illegal type {} of argument of aggregate function {}, must "
|
||||
"be native integral type, Date/DateTime or Float", arg->getName(), name);
|
||||
}
|
||||
|
||||
AggregateFunctionPtr res(createWithBasicNumberOrDateOrDateTime<AggregateFunctionIntervalLengthSum, Data>(*arguments[0], arguments));
|
||||
|
||||
if (res)
|
||||
return res;
|
||||
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Illegal type {} of argument of aggregate function {}, must "
|
||||
"be native integral type, Date/DateTime or Float", arguments.front().get()->getName(), name);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,232 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include <AggregateFunctions/Combinators/AggregateFunctionNull.h>
|
||||
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
|
||||
#include <Common/assert_cast.h>
|
||||
#include <base/arithmeticOverflow.h>
|
||||
#include <base/sort.h>
|
||||
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int TOO_LARGE_ARRAY_SIZE;
|
||||
}
|
||||
|
||||
/** Calculate total length of intervals without intersections. Each interval is the pair of numbers [begin, end];
|
||||
* Returns UInt64 for integral types (UInt/Int*, Date/DateTime) and returns Float64 for Float*.
|
||||
*
|
||||
* Implementation simply stores intervals sorted by beginning and sums lengths at final.
|
||||
*/
|
||||
template <typename T>
|
||||
struct AggregateFunctionIntervalLengthSumData
|
||||
{
|
||||
constexpr static size_t MAX_ARRAY_SIZE = 0xFFFFFF;
|
||||
|
||||
using Segment = std::pair<T, T>;
|
||||
using Segments = PODArrayWithStackMemory<Segment, 64>;
|
||||
|
||||
bool sorted = false;
|
||||
|
||||
Segments segments;
|
||||
|
||||
void add(T begin, T end)
|
||||
{
|
||||
/// Reversed intervals are counted by absolute value of their length.
|
||||
if (unlikely(end < begin))
|
||||
std::swap(begin, end);
|
||||
else if (unlikely(begin == end))
|
||||
return;
|
||||
|
||||
if (sorted && !segments.empty())
|
||||
sorted = segments.back().first <= begin;
|
||||
segments.emplace_back(begin, end);
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionIntervalLengthSumData & other)
|
||||
{
|
||||
if (other.segments.empty())
|
||||
return;
|
||||
|
||||
const auto size = segments.size();
|
||||
|
||||
segments.insert(std::begin(other.segments), std::end(other.segments));
|
||||
|
||||
/// either sort whole container or do so partially merging ranges afterwards
|
||||
if (!sorted && !other.sorted)
|
||||
{
|
||||
::sort(std::begin(segments), std::end(segments));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto begin = std::begin(segments);
|
||||
const auto middle = std::next(begin, size);
|
||||
const auto end = std::end(segments);
|
||||
|
||||
if (!sorted)
|
||||
::sort(begin, middle);
|
||||
|
||||
if (!other.sorted)
|
||||
::sort(middle, end);
|
||||
|
||||
std::inplace_merge(begin, middle, end);
|
||||
}
|
||||
|
||||
sorted = true;
|
||||
}
|
||||
|
||||
void sort()
|
||||
{
|
||||
if (sorted)
|
||||
return;
|
||||
|
||||
::sort(std::begin(segments), std::end(segments));
|
||||
sorted = true;
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
writeBinary(sorted, buf);
|
||||
writeBinary(segments.size(), buf);
|
||||
|
||||
for (const auto & time_gap : segments)
|
||||
{
|
||||
writeBinary(time_gap.first, buf);
|
||||
writeBinary(time_gap.second, buf);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
readBinary(sorted, buf);
|
||||
|
||||
size_t size;
|
||||
readBinary(size, buf);
|
||||
|
||||
if (unlikely(size > MAX_ARRAY_SIZE))
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large array size (maximum: {})", MAX_ARRAY_SIZE);
|
||||
|
||||
segments.clear();
|
||||
segments.reserve(size);
|
||||
|
||||
Segment segment;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
readBinary(segment.first, buf);
|
||||
readBinary(segment.second, buf);
|
||||
segments.emplace_back(segment);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Data>
|
||||
class AggregateFunctionIntervalLengthSum final : public IAggregateFunctionDataHelper<Data, AggregateFunctionIntervalLengthSum<T, Data>>
|
||||
{
|
||||
private:
|
||||
static auto NO_SANITIZE_UNDEFINED length(typename Data::Segment segment)
|
||||
{
|
||||
return segment.second - segment.first;
|
||||
}
|
||||
|
||||
template <typename TResult>
|
||||
TResult getIntervalLengthSum(Data & data) const
|
||||
{
|
||||
if (data.segments.empty())
|
||||
return 0;
|
||||
|
||||
data.sort();
|
||||
|
||||
TResult res = 0;
|
||||
|
||||
typename Data::Segment curr_segment = data.segments[0];
|
||||
|
||||
for (size_t i = 1, size = data.segments.size(); i < size; ++i)
|
||||
{
|
||||
const typename Data::Segment & next_segment = data.segments[i];
|
||||
|
||||
/// Check if current interval intersects with next one then add length, otherwise advance interval end.
|
||||
if (curr_segment.second < next_segment.first)
|
||||
{
|
||||
res += length(curr_segment);
|
||||
curr_segment = next_segment;
|
||||
}
|
||||
else if (next_segment.second > curr_segment.second)
|
||||
{
|
||||
curr_segment.second = next_segment.second;
|
||||
}
|
||||
}
|
||||
res += length(curr_segment);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
public:
|
||||
String getName() const override { return "intervalLengthSum"; }
|
||||
|
||||
explicit AggregateFunctionIntervalLengthSum(const DataTypes & arguments)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionIntervalLengthSum<T, Data>>(arguments, {}, createResultType())
|
||||
{
|
||||
}
|
||||
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
if constexpr (std::is_floating_point_v<T>)
|
||||
return std::make_shared<DataTypeFloat64>();
|
||||
return std::make_shared<DataTypeUInt64>();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
AggregateFunctionPtr getOwnNullAdapter(
|
||||
const AggregateFunctionPtr & nested_function,
|
||||
const DataTypes & arguments,
|
||||
const Array & params,
|
||||
const AggregateFunctionProperties & /*properties*/) const override
|
||||
{
|
||||
return std::make_shared<AggregateFunctionNullVariadic<false, false>>(nested_function, arguments, params);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override
|
||||
{
|
||||
auto begin = assert_cast<const ColumnVector<T> *>(columns[0])->getData()[row_num];
|
||||
auto end = assert_cast<const ColumnVector<T> *>(columns[1])->getData()[row_num];
|
||||
this->data(place).add(begin, end);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
if constexpr (std::is_floating_point_v<T>)
|
||||
assert_cast<ColumnFloat64 &>(to).getData().push_back(getIntervalLengthSum<Float64>(this->data(place)));
|
||||
else
|
||||
assert_cast<ColumnUInt64 &>(to).getData().push_back(getIntervalLengthSum<UInt64>(this->data(place)));
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -1,19 +1,339 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionKolmogorovSmirnovTest.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/StatCommon.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <Common/PODArray_fwd.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NOT_IMPLEMENTED;
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
struct Settings;
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
struct KolmogorovSmirnov : public StatisticalSample<Float64, Float64>
|
||||
{
|
||||
enum class Alternative
|
||||
{
|
||||
TwoSided,
|
||||
Less,
|
||||
Greater
|
||||
};
|
||||
|
||||
std::pair<Float64, Float64> getResult(Alternative alternative, String method)
|
||||
{
|
||||
::sort(x.begin(), x.end());
|
||||
::sort(y.begin(), y.end());
|
||||
|
||||
Float64 max_s = std::numeric_limits<Float64>::min();
|
||||
Float64 min_s = std::numeric_limits<Float64>::max();
|
||||
Float64 now_s = 0;
|
||||
UInt64 pos_x = 0;
|
||||
UInt64 pos_y = 0;
|
||||
UInt64 pos_tmp;
|
||||
UInt64 n1 = x.size();
|
||||
UInt64 n2 = y.size();
|
||||
|
||||
const Float64 n1_d = 1. / n1;
|
||||
const Float64 n2_d = 1. / n2;
|
||||
const Float64 tol = 1e-7;
|
||||
|
||||
// reference: https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test
|
||||
while (pos_x < x.size() && pos_y < y.size())
|
||||
{
|
||||
if (likely(fabs(x[pos_x] - y[pos_y]) >= tol))
|
||||
{
|
||||
if (x[pos_x] < y[pos_y])
|
||||
{
|
||||
now_s += n1_d;
|
||||
++pos_x;
|
||||
}
|
||||
else
|
||||
{
|
||||
now_s -= n2_d;
|
||||
++pos_y;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
pos_tmp = pos_x + 1;
|
||||
while (pos_tmp < x.size() && unlikely(fabs(x[pos_tmp] - x[pos_x]) <= tol))
|
||||
pos_tmp++;
|
||||
now_s += n1_d * (pos_tmp - pos_x);
|
||||
pos_x = pos_tmp;
|
||||
pos_tmp = pos_y + 1;
|
||||
while (pos_tmp < y.size() && unlikely(fabs(y[pos_tmp] - y[pos_y]) <= tol))
|
||||
pos_tmp++;
|
||||
now_s -= n2_d * (pos_tmp - pos_y);
|
||||
pos_y = pos_tmp;
|
||||
}
|
||||
max_s = std::max(max_s, now_s);
|
||||
min_s = std::min(min_s, now_s);
|
||||
}
|
||||
now_s += n1_d * (x.size() - pos_x) - n2_d * (y.size() - pos_y);
|
||||
min_s = std::min(min_s, now_s);
|
||||
max_s = std::max(max_s, now_s);
|
||||
|
||||
Float64 d = 0;
|
||||
if (alternative == Alternative::TwoSided)
|
||||
d = std::max(std::abs(max_s), std::abs(min_s));
|
||||
else if (alternative == Alternative::Less)
|
||||
d = -min_s;
|
||||
else if (alternative == Alternative::Greater)
|
||||
d = max_s;
|
||||
|
||||
UInt64 g = std::__gcd(n1, n2);
|
||||
UInt64 nx_g = n1 / g;
|
||||
UInt64 ny_g = n2 / g;
|
||||
|
||||
if (method == "auto")
|
||||
method = std::max(n1, n2) <= 10000 ? "exact" : "asymptotic";
|
||||
else if (method == "exact" && nx_g >= std::numeric_limits<Int32>::max() / ny_g)
|
||||
method = "asymptotic";
|
||||
|
||||
Float64 p_value = std::numeric_limits<Float64>::infinity();
|
||||
|
||||
if (method == "exact")
|
||||
{
|
||||
/* reference:
|
||||
* Gunar Schröer and Dietrich Trenkler
|
||||
* Exact and Randomization Distributions of Kolmogorov-Smirnov, Tests for Two or Three Samples
|
||||
*
|
||||
* and
|
||||
*
|
||||
* Thomas Viehmann
|
||||
* Numerically more stable computation of the p-values for the two-sample Kolmogorov-Smirnov test
|
||||
*/
|
||||
if (n2 > n1)
|
||||
std::swap(n1, n2);
|
||||
|
||||
const Float64 f_n1 = static_cast<Float64>(n1);
|
||||
const Float64 f_n2 = static_cast<Float64>(n2);
|
||||
const Float64 k_d = (0.5 + floor(d * f_n2 * f_n1 - tol)) / (f_n2 * f_n1);
|
||||
PaddedPODArray<Float64> c(n1 + 1);
|
||||
|
||||
auto check = alternative == Alternative::TwoSided ?
|
||||
[](const Float64 & q, const Float64 & r, const Float64 & s) { return fabs(r - s) >= q; }
|
||||
: [](const Float64 & q, const Float64 & r, const Float64 & s) { return r - s >= q; };
|
||||
|
||||
c[0] = 0;
|
||||
for (UInt64 j = 1; j <= n1; j++)
|
||||
if (check(k_d, 0., j / f_n1))
|
||||
c[j] = 1.;
|
||||
else
|
||||
c[j] = c[j - 1];
|
||||
|
||||
for (UInt64 i = 1; i <= n2; i++)
|
||||
{
|
||||
if (check(k_d, i / f_n2, 0.))
|
||||
c[0] = 1.;
|
||||
for (UInt64 j = 1; j <= n1; j++)
|
||||
if (check(k_d, i / f_n2, j / f_n1))
|
||||
c[j] = 1.;
|
||||
else
|
||||
{
|
||||
Float64 v = i / static_cast<Float64>(i + j);
|
||||
Float64 w = j / static_cast<Float64>(i + j);
|
||||
c[j] = v * c[j] + w * c[j - 1];
|
||||
}
|
||||
}
|
||||
p_value = c[n1];
|
||||
}
|
||||
else if (method == "asymp" || method == "asymptotic")
|
||||
{
|
||||
Float64 n = std::min(n1, n2);
|
||||
Float64 m = std::max(n1, n2);
|
||||
Float64 p = sqrt((n * m) / (n + m)) * d;
|
||||
|
||||
if (alternative == Alternative::TwoSided)
|
||||
{
|
||||
/* reference:
|
||||
* J.DURBIN
|
||||
* Distribution theory for tests based on the sample distribution function
|
||||
*/
|
||||
Float64 new_val, old_val, s, w, z;
|
||||
UInt64 k_max = static_cast<UInt64>(sqrt(2 - log(tol)));
|
||||
|
||||
if (p < 1)
|
||||
{
|
||||
z = - (M_PI_2 * M_PI_4) / (p * p);
|
||||
w = log(p);
|
||||
s = 0;
|
||||
for (UInt64 k = 1; k < k_max; k += 2)
|
||||
s += exp(k * k * z - w);
|
||||
p = s / 0.398942280401432677939946059934;
|
||||
}
|
||||
else
|
||||
{
|
||||
z = -2 * p * p;
|
||||
s = -1;
|
||||
UInt64 k = 1;
|
||||
old_val = 0;
|
||||
new_val = 1;
|
||||
while (fabs(old_val - new_val) > tol)
|
||||
{
|
||||
old_val = new_val;
|
||||
new_val += 2 * s * exp(z * k * k);
|
||||
s *= -1;
|
||||
k++;
|
||||
}
|
||||
p = new_val;
|
||||
}
|
||||
p_value = 1 - p;
|
||||
}
|
||||
else
|
||||
{
|
||||
/* reference:
|
||||
* J. L. HODGES, Jr
|
||||
* The significance probability of the Smirnov two-sample test
|
||||
*/
|
||||
|
||||
// Use Hodges' suggested approximation Eqn 5.3
|
||||
// Requires m to be the larger of (n1, n2)
|
||||
Float64 expt = -2 * p * p - 2 * p * (m + 2 * n) / sqrt(m * n * (m + n)) / 3.0;
|
||||
p_value = exp(expt);
|
||||
}
|
||||
}
|
||||
return {d, p_value};
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class AggregateFunctionKolmogorovSmirnov final:
|
||||
public IAggregateFunctionDataHelper<KolmogorovSmirnov, AggregateFunctionKolmogorovSmirnov>
|
||||
{
|
||||
private:
|
||||
using Alternative = typename KolmogorovSmirnov::Alternative;
|
||||
Alternative alternative = Alternative::TwoSided;
|
||||
String method = "auto";
|
||||
|
||||
public:
|
||||
explicit AggregateFunctionKolmogorovSmirnov(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<KolmogorovSmirnov, AggregateFunctionKolmogorovSmirnov> ({arguments}, {}, createResultType())
|
||||
{
|
||||
if (params.size() > 2)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} require two parameter or less", getName());
|
||||
|
||||
if (params.empty())
|
||||
return;
|
||||
|
||||
if (params[0].getType() != Field::Types::String)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} require first parameter to be a String", getName());
|
||||
|
||||
const auto & param = params[0].get<String>();
|
||||
if (param == "two-sided")
|
||||
alternative = Alternative::TwoSided;
|
||||
else if (param == "less")
|
||||
alternative = Alternative::Less;
|
||||
else if (param == "greater")
|
||||
alternative = Alternative::Greater;
|
||||
else
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown parameter in aggregate function {}. "
|
||||
"It must be one of: 'two-sided', 'less', 'greater'", getName());
|
||||
|
||||
if (params.size() != 2)
|
||||
return;
|
||||
|
||||
if (params[1].getType() != Field::Types::String)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} require second parameter to be a String", getName());
|
||||
|
||||
method = params[1].get<String>();
|
||||
if (method != "auto" && method != "exact" && method != "asymp" && method != "asymptotic")
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown method in aggregate function {}. "
|
||||
"It must be one of: 'auto', 'exact', 'asymp' (or 'asymptotic')", getName());
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return "kolmogorovSmirnovTest";
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return true; }
|
||||
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
DataTypes types
|
||||
{
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
};
|
||||
|
||||
Strings names
|
||||
{
|
||||
"d_statistic",
|
||||
"p_value"
|
||||
};
|
||||
|
||||
return std::make_shared<DataTypeTuple>(
|
||||
std::move(types),
|
||||
std::move(names)
|
||||
);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
Float64 value = columns[0]->getFloat64(row_num);
|
||||
UInt8 is_second = columns[1]->getUInt(row_num);
|
||||
if (is_second)
|
||||
this->data(place).addY(value, arena);
|
||||
else
|
||||
this->data(place).addX(value, arena);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs), arena);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
this->data(place).read(buf, arena);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
if (!this->data(place).size_x || !this->data(place).size_y)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} require both samples to be non empty", getName());
|
||||
|
||||
auto [d_statistic, p_value] = this->data(place).getResult(alternative, method);
|
||||
|
||||
/// Because p-value is a probability.
|
||||
p_value = std::min(1.0, std::max(0.0, p_value));
|
||||
|
||||
auto & column_tuple = assert_cast<ColumnTuple &>(to);
|
||||
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
|
||||
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
|
||||
|
||||
column_stat.getData().push_back(d_statistic);
|
||||
column_value.getData().push_back(p_value);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionKolmogorovSmirnovTest(
|
||||
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
|
@ -1,331 +0,0 @@
|
||||
#pragma once
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/StatCommon.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <Common/PODArray_fwd.h>
|
||||
#include <base/types.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
struct KolmogorovSmirnov : public StatisticalSample<Float64, Float64>
|
||||
{
|
||||
enum class Alternative
|
||||
{
|
||||
TwoSided,
|
||||
Less,
|
||||
Greater
|
||||
};
|
||||
|
||||
std::pair<Float64, Float64> getResult(Alternative alternative, String method)
|
||||
{
|
||||
::sort(x.begin(), x.end());
|
||||
::sort(y.begin(), y.end());
|
||||
|
||||
Float64 max_s = std::numeric_limits<Float64>::min();
|
||||
Float64 min_s = std::numeric_limits<Float64>::max();
|
||||
Float64 now_s = 0;
|
||||
UInt64 pos_x = 0;
|
||||
UInt64 pos_y = 0;
|
||||
UInt64 pos_tmp;
|
||||
UInt64 n1 = x.size();
|
||||
UInt64 n2 = y.size();
|
||||
|
||||
const Float64 n1_d = 1. / n1;
|
||||
const Float64 n2_d = 1. / n2;
|
||||
const Float64 tol = 1e-7;
|
||||
|
||||
// reference: https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test
|
||||
while (pos_x < x.size() && pos_y < y.size())
|
||||
{
|
||||
if (likely(fabs(x[pos_x] - y[pos_y]) >= tol))
|
||||
{
|
||||
if (x[pos_x] < y[pos_y])
|
||||
{
|
||||
now_s += n1_d;
|
||||
++pos_x;
|
||||
}
|
||||
else
|
||||
{
|
||||
now_s -= n2_d;
|
||||
++pos_y;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
pos_tmp = pos_x + 1;
|
||||
while (pos_tmp < x.size() && unlikely(fabs(x[pos_tmp] - x[pos_x]) <= tol))
|
||||
pos_tmp++;
|
||||
now_s += n1_d * (pos_tmp - pos_x);
|
||||
pos_x = pos_tmp;
|
||||
pos_tmp = pos_y + 1;
|
||||
while (pos_tmp < y.size() && unlikely(fabs(y[pos_tmp] - y[pos_y]) <= tol))
|
||||
pos_tmp++;
|
||||
now_s -= n2_d * (pos_tmp - pos_y);
|
||||
pos_y = pos_tmp;
|
||||
}
|
||||
max_s = std::max(max_s, now_s);
|
||||
min_s = std::min(min_s, now_s);
|
||||
}
|
||||
now_s += n1_d * (x.size() - pos_x) - n2_d * (y.size() - pos_y);
|
||||
min_s = std::min(min_s, now_s);
|
||||
max_s = std::max(max_s, now_s);
|
||||
|
||||
Float64 d = 0;
|
||||
if (alternative == Alternative::TwoSided)
|
||||
d = std::max(std::abs(max_s), std::abs(min_s));
|
||||
else if (alternative == Alternative::Less)
|
||||
d = -min_s;
|
||||
else if (alternative == Alternative::Greater)
|
||||
d = max_s;
|
||||
|
||||
UInt64 g = std::__gcd(n1, n2);
|
||||
UInt64 nx_g = n1 / g;
|
||||
UInt64 ny_g = n2 / g;
|
||||
|
||||
if (method == "auto")
|
||||
method = std::max(n1, n2) <= 10000 ? "exact" : "asymptotic";
|
||||
else if (method == "exact" && nx_g >= std::numeric_limits<Int32>::max() / ny_g)
|
||||
method = "asymptotic";
|
||||
|
||||
Float64 p_value = std::numeric_limits<Float64>::infinity();
|
||||
|
||||
if (method == "exact")
|
||||
{
|
||||
/* reference:
|
||||
* Gunar Schröer and Dietrich Trenkler
|
||||
* Exact and Randomization Distributions of Kolmogorov-Smirnov, Tests for Two or Three Samples
|
||||
*
|
||||
* and
|
||||
*
|
||||
* Thomas Viehmann
|
||||
* Numerically more stable computation of the p-values for the two-sample Kolmogorov-Smirnov test
|
||||
*/
|
||||
if (n2 > n1)
|
||||
std::swap(n1, n2);
|
||||
|
||||
const Float64 f_n1 = static_cast<Float64>(n1);
|
||||
const Float64 f_n2 = static_cast<Float64>(n2);
|
||||
const Float64 k_d = (0.5 + floor(d * f_n2 * f_n1 - tol)) / (f_n2 * f_n1);
|
||||
PaddedPODArray<Float64> c(n1 + 1);
|
||||
|
||||
auto check = alternative == Alternative::TwoSided ?
|
||||
[](const Float64 & q, const Float64 & r, const Float64 & s) { return fabs(r - s) >= q; }
|
||||
: [](const Float64 & q, const Float64 & r, const Float64 & s) { return r - s >= q; };
|
||||
|
||||
c[0] = 0;
|
||||
for (UInt64 j = 1; j <= n1; j++)
|
||||
if (check(k_d, 0., j / f_n1))
|
||||
c[j] = 1.;
|
||||
else
|
||||
c[j] = c[j - 1];
|
||||
|
||||
for (UInt64 i = 1; i <= n2; i++)
|
||||
{
|
||||
if (check(k_d, i / f_n2, 0.))
|
||||
c[0] = 1.;
|
||||
for (UInt64 j = 1; j <= n1; j++)
|
||||
if (check(k_d, i / f_n2, j / f_n1))
|
||||
c[j] = 1.;
|
||||
else
|
||||
{
|
||||
Float64 v = i / static_cast<Float64>(i + j);
|
||||
Float64 w = j / static_cast<Float64>(i + j);
|
||||
c[j] = v * c[j] + w * c[j - 1];
|
||||
}
|
||||
}
|
||||
p_value = c[n1];
|
||||
}
|
||||
else if (method == "asymp" || method == "asymptotic")
|
||||
{
|
||||
Float64 n = std::min(n1, n2);
|
||||
Float64 m = std::max(n1, n2);
|
||||
Float64 p = sqrt((n * m) / (n + m)) * d;
|
||||
|
||||
if (alternative == Alternative::TwoSided)
|
||||
{
|
||||
/* reference:
|
||||
* J.DURBIN
|
||||
* Distribution theory for tests based on the sample distribution function
|
||||
*/
|
||||
Float64 new_val, old_val, s, w, z;
|
||||
UInt64 k_max = static_cast<UInt64>(sqrt(2 - log(tol)));
|
||||
|
||||
if (p < 1)
|
||||
{
|
||||
z = - (M_PI_2 * M_PI_4) / (p * p);
|
||||
w = log(p);
|
||||
s = 0;
|
||||
for (UInt64 k = 1; k < k_max; k += 2)
|
||||
s += exp(k * k * z - w);
|
||||
p = s / 0.398942280401432677939946059934;
|
||||
}
|
||||
else
|
||||
{
|
||||
z = -2 * p * p;
|
||||
s = -1;
|
||||
UInt64 k = 1;
|
||||
old_val = 0;
|
||||
new_val = 1;
|
||||
while (fabs(old_val - new_val) > tol)
|
||||
{
|
||||
old_val = new_val;
|
||||
new_val += 2 * s * exp(z * k * k);
|
||||
s *= -1;
|
||||
k++;
|
||||
}
|
||||
p = new_val;
|
||||
}
|
||||
p_value = 1 - p;
|
||||
}
|
||||
else
|
||||
{
|
||||
/* reference:
|
||||
* J. L. HODGES, Jr
|
||||
* The significance probability of the Smirnov two-sample test
|
||||
*/
|
||||
|
||||
// Use Hodges' suggested approximation Eqn 5.3
|
||||
// Requires m to be the larger of (n1, n2)
|
||||
Float64 expt = -2 * p * p - 2 * p * (m + 2 * n) / sqrt(m * n * (m + n)) / 3.0;
|
||||
p_value = exp(expt);
|
||||
}
|
||||
}
|
||||
return {d, p_value};
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class AggregateFunctionKolmogorovSmirnov final:
|
||||
public IAggregateFunctionDataHelper<KolmogorovSmirnov, AggregateFunctionKolmogorovSmirnov>
|
||||
{
|
||||
private:
|
||||
using Alternative = typename KolmogorovSmirnov::Alternative;
|
||||
Alternative alternative = Alternative::TwoSided;
|
||||
String method = "auto";
|
||||
|
||||
public:
|
||||
explicit AggregateFunctionKolmogorovSmirnov(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<KolmogorovSmirnov, AggregateFunctionKolmogorovSmirnov> ({arguments}, {}, createResultType())
|
||||
{
|
||||
if (params.size() > 2)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} require two parameter or less", getName());
|
||||
|
||||
if (params.empty())
|
||||
return;
|
||||
|
||||
if (params[0].getType() != Field::Types::String)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} require first parameter to be a String", getName());
|
||||
|
||||
const auto & param = params[0].get<String>();
|
||||
if (param == "two-sided")
|
||||
alternative = Alternative::TwoSided;
|
||||
else if (param == "less")
|
||||
alternative = Alternative::Less;
|
||||
else if (param == "greater")
|
||||
alternative = Alternative::Greater;
|
||||
else
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown parameter in aggregate function {}. "
|
||||
"It must be one of: 'two-sided', 'less', 'greater'", getName());
|
||||
|
||||
if (params.size() != 2)
|
||||
return;
|
||||
|
||||
if (params[1].getType() != Field::Types::String)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} require second parameter to be a String", getName());
|
||||
|
||||
method = params[1].get<String>();
|
||||
if (method != "auto" && method != "exact" && method != "asymp" && method != "asymptotic")
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown method in aggregate function {}. "
|
||||
"It must be one of: 'auto', 'exact', 'asymp' (or 'asymptotic')", getName());
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return "kolmogorovSmirnovTest";
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return true; }
|
||||
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
DataTypes types
|
||||
{
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
};
|
||||
|
||||
Strings names
|
||||
{
|
||||
"d_statistic",
|
||||
"p_value"
|
||||
};
|
||||
|
||||
return std::make_shared<DataTypeTuple>(
|
||||
std::move(types),
|
||||
std::move(names)
|
||||
);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
Float64 value = columns[0]->getFloat64(row_num);
|
||||
UInt8 is_second = columns[1]->getUInt(row_num);
|
||||
if (is_second)
|
||||
this->data(place).addY(value, arena);
|
||||
else
|
||||
this->data(place).addX(value, arena);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs), arena);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
this->data(place).read(buf, arena);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
if (!this->data(place).size_x || !this->data(place).size_y)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} require both samples to be non empty", getName());
|
||||
|
||||
auto [d_statistic, p_value] = this->data(place).getResult(alternative, method);
|
||||
|
||||
/// Because p-value is a probability.
|
||||
p_value = std::min(1.0, std::max(0.0, p_value));
|
||||
|
||||
auto & column_tuple = assert_cast<ColumnTuple &>(to);
|
||||
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
|
||||
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
|
||||
|
||||
column_stat.getData().push_back(d_statistic);
|
||||
column_value.getData().push_back(p_value);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}
|
@ -1,12 +1,30 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionLargestTriangleThreeBuckets.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
|
||||
#include <numeric>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/StatCommon.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnsDateTime.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <DataTypes/DataTypesDecimal.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <Common/PODArray.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <boost/math/distributions/normal.hpp>
|
||||
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NOT_IMPLEMENTED;
|
||||
extern const int NOT_IMPLEMENTED;
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
}
|
||||
|
||||
namespace DB
|
||||
@ -16,29 +34,320 @@ struct Settings;
|
||||
namespace
|
||||
{
|
||||
|
||||
AggregateFunctionPtr
|
||||
createAggregateFunctionLargestTriangleThreeBuckets(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
struct LargestTriangleThreeBucketsData : public StatisticalSample<Float64, Float64>
|
||||
{
|
||||
void add(const Float64 xval, const Float64 yval, Arena * arena)
|
||||
{
|
||||
assertBinary(name, argument_types);
|
||||
|
||||
|
||||
if (!(isNumber(argument_types[0]) || isDateOrDate32(argument_types[0]) || isDateTime(argument_types[0])
|
||||
|| isDateTime64(argument_types[0])))
|
||||
throw Exception(
|
||||
ErrorCodes::NOT_IMPLEMENTED,
|
||||
"Aggregate function {} only supports Date, Date32, DateTime, DateTime64 and Number as the first argument",
|
||||
name);
|
||||
|
||||
if (!(isNumber(argument_types[1]) || isDateOrDate32(argument_types[1]) || isDateTime(argument_types[1])
|
||||
|| isDateTime64(argument_types[1])))
|
||||
throw Exception(
|
||||
ErrorCodes::NOT_IMPLEMENTED,
|
||||
"Aggregate function {} only supports Date, Date32, DateTime, DateTime64 and Number as the second argument",
|
||||
name);
|
||||
|
||||
return std::make_shared<AggregateFunctionLargestTriangleThreeBuckets>(argument_types, parameters);
|
||||
this->addX(xval, arena);
|
||||
this->addY(yval, arena);
|
||||
}
|
||||
|
||||
void sort(Arena * arena)
|
||||
{
|
||||
// sort the this->x and this->y in ascending order of this->x using index
|
||||
std::vector<size_t> index(this->x.size());
|
||||
|
||||
std::iota(index.begin(), index.end(), 0);
|
||||
::sort(index.begin(), index.end(), [&](size_t i1, size_t i2) { return this->x[i1] < this->x[i2]; });
|
||||
|
||||
SampleX temp_x{};
|
||||
SampleY temp_y{};
|
||||
|
||||
for (size_t i = 0; i < this->x.size(); ++i)
|
||||
{
|
||||
temp_x.push_back(this->x[index[i]], arena);
|
||||
temp_y.push_back(this->y[index[i]], arena);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < this->x.size(); ++i)
|
||||
{
|
||||
this->x[i] = temp_x[i];
|
||||
this->y[i] = temp_y[i];
|
||||
}
|
||||
}
|
||||
|
||||
PODArray<std::pair<Float64, Float64>> getResult(size_t total_buckets, Arena * arena)
|
||||
{
|
||||
// Sort the data
|
||||
this->sort(arena);
|
||||
|
||||
PODArray<std::pair<Float64, Float64>> result;
|
||||
|
||||
// Handle special cases for small data list
|
||||
if (this->x.size() <= total_buckets)
|
||||
{
|
||||
for (size_t i = 0; i < this->x.size(); ++i)
|
||||
{
|
||||
result.emplace_back(std::make_pair(this->x[i], this->y[i]));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Handle special cases for 0 or 1 or 2 buckets
|
||||
if (total_buckets == 0)
|
||||
return result;
|
||||
if (total_buckets == 1)
|
||||
{
|
||||
result.emplace_back(std::make_pair(this->x.front(), this->y.front()));
|
||||
return result;
|
||||
}
|
||||
if (total_buckets == 2)
|
||||
{
|
||||
result.emplace_back(std::make_pair(this->x.front(), this->y.front()));
|
||||
result.emplace_back(std::make_pair(this->x.back(), this->y.back()));
|
||||
return result;
|
||||
}
|
||||
|
||||
// Find the size of each bucket
|
||||
size_t single_bucket_size = this->x.size() / total_buckets;
|
||||
|
||||
// Include the first data point
|
||||
result.emplace_back(std::make_pair(this->x[0], this->y[0]));
|
||||
|
||||
for (size_t i = 1; i < total_buckets - 1; ++i) // Skip the first and last bucket
|
||||
{
|
||||
size_t start_index = i * single_bucket_size;
|
||||
size_t end_index = (i + 1) * single_bucket_size;
|
||||
|
||||
// Compute the average point in the next bucket
|
||||
Float64 avg_x = 0;
|
||||
Float64 avg_y = 0;
|
||||
for (size_t j = end_index; j < (i + 2) * single_bucket_size; ++j)
|
||||
{
|
||||
avg_x += this->x[j];
|
||||
avg_y += this->y[j];
|
||||
}
|
||||
avg_x /= single_bucket_size;
|
||||
avg_y /= single_bucket_size;
|
||||
|
||||
// Find the point in the current bucket that forms the largest triangle
|
||||
size_t max_index = start_index;
|
||||
Float64 max_area = 0.0;
|
||||
for (size_t j = start_index; j < end_index; ++j)
|
||||
{
|
||||
Float64 area = std::abs(
|
||||
0.5
|
||||
* (result.back().first * this->y[j] + this->x[j] * avg_y + avg_x * result.back().second - result.back().first * avg_y
|
||||
- this->x[j] * result.back().second - avg_x * this->y[j]));
|
||||
if (area > max_area)
|
||||
{
|
||||
max_area = area;
|
||||
max_index = j;
|
||||
}
|
||||
}
|
||||
|
||||
// Include the selected point
|
||||
result.emplace_back(std::make_pair(this->x[max_index], this->y[max_index]));
|
||||
}
|
||||
|
||||
// Include the last data point
|
||||
result.emplace_back(std::make_pair(this->x.back(), this->y.back()));
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
class AggregateFunctionLargestTriangleThreeBuckets final : public IAggregateFunctionDataHelper<LargestTriangleThreeBucketsData, AggregateFunctionLargestTriangleThreeBuckets>
|
||||
{
|
||||
private:
|
||||
UInt64 total_buckets{0};
|
||||
TypeIndex x_type;
|
||||
TypeIndex y_type;
|
||||
|
||||
public:
|
||||
explicit AggregateFunctionLargestTriangleThreeBuckets(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<LargestTriangleThreeBucketsData, AggregateFunctionLargestTriangleThreeBuckets>({arguments}, {}, createResultType(arguments))
|
||||
{
|
||||
if (params.size() != 1)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} require one parameter", getName());
|
||||
|
||||
if (params[0].getType() != Field::Types::UInt64)
|
||||
throw Exception(
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} require first parameter to be a UInt64", getName());
|
||||
|
||||
total_buckets = params[0].get<UInt64>();
|
||||
|
||||
this->x_type = WhichDataType(arguments[0]).idx;
|
||||
this->y_type = WhichDataType(arguments[1]).idx;
|
||||
}
|
||||
|
||||
static constexpr auto name = "largestTriangleThreeBuckets";
|
||||
|
||||
String getName() const override { return name; }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return true; }
|
||||
|
||||
static DataTypePtr createResultType(const DataTypes & arguments)
|
||||
{
|
||||
TypeIndex x_type = arguments[0]->getTypeId();
|
||||
TypeIndex y_type = arguments[1]->getTypeId();
|
||||
|
||||
UInt32 x_scale = 0;
|
||||
UInt32 y_scale = 0;
|
||||
|
||||
if (const auto * datetime64_type = typeid_cast<const DataTypeDateTime64 *>(arguments[0].get()))
|
||||
{
|
||||
x_scale = datetime64_type->getScale();
|
||||
}
|
||||
|
||||
if (const auto * datetime64_type = typeid_cast<const DataTypeDateTime64 *>(arguments[1].get()))
|
||||
{
|
||||
y_scale = datetime64_type->getScale();
|
||||
}
|
||||
|
||||
DataTypes types = {getDataTypeFromTypeIndex(x_type, x_scale), getDataTypeFromTypeIndex(y_type, y_scale)};
|
||||
|
||||
auto tuple = std::make_shared<DataTypeTuple>(std::move(types));
|
||||
|
||||
return std::make_shared<DataTypeArray>(tuple);
|
||||
}
|
||||
|
||||
static DataTypePtr getDataTypeFromTypeIndex(TypeIndex type_index, UInt32 scale)
|
||||
{
|
||||
DataTypePtr data_type;
|
||||
switch (type_index)
|
||||
{
|
||||
case TypeIndex::Date:
|
||||
data_type = std::make_shared<DataTypeDate>();
|
||||
break;
|
||||
case TypeIndex::Date32:
|
||||
data_type = std::make_shared<DataTypeDate32>();
|
||||
break;
|
||||
case TypeIndex::DateTime:
|
||||
data_type = std::make_shared<DataTypeDateTime>();
|
||||
break;
|
||||
case TypeIndex::DateTime64:
|
||||
data_type = std::make_shared<DataTypeDateTime64>(scale);
|
||||
break;
|
||||
default:
|
||||
data_type = std::make_shared<DataTypeNumber<Float64>>();
|
||||
}
|
||||
return data_type;
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
Float64 x = getFloat64DataFromColumn(columns[0], row_num, this->x_type);
|
||||
Float64 y = getFloat64DataFromColumn(columns[1], row_num, this->y_type);
|
||||
this->data(place).add(x, y, arena);
|
||||
}
|
||||
|
||||
Float64 getFloat64DataFromColumn(const IColumn * column, size_t row_num, TypeIndex type_index) const
|
||||
{
|
||||
switch (type_index)
|
||||
{
|
||||
case TypeIndex::Date:
|
||||
return static_cast<const ColumnDate &>(*column).getData()[row_num];
|
||||
case TypeIndex::Date32:
|
||||
return static_cast<const ColumnDate32 &>(*column).getData()[row_num];
|
||||
case TypeIndex::DateTime:
|
||||
return static_cast<const ColumnDateTime &>(*column).getData()[row_num];
|
||||
case TypeIndex::DateTime64:
|
||||
return static_cast<const ColumnDateTime64 &>(*column).getData()[row_num];
|
||||
default:
|
||||
return column->getFloat64(row_num);
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
auto & a = this->data(place);
|
||||
const auto & b = this->data(rhs);
|
||||
|
||||
a.merge(b, arena);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
this->data(place).read(buf, arena);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
|
||||
{
|
||||
auto res = this->data(place).getResult(total_buckets, arena);
|
||||
|
||||
auto & col = assert_cast<ColumnArray &>(to);
|
||||
auto & col_offsets = assert_cast<ColumnArray::ColumnOffsets &>(col.getOffsetsColumn());
|
||||
|
||||
auto column_x_adder_func = getColumnAdderFunc(x_type);
|
||||
auto column_y_adder_func = getColumnAdderFunc(y_type);
|
||||
|
||||
for (const auto & elem : res)
|
||||
{
|
||||
auto & column_tuple = assert_cast<ColumnTuple &>(col.getData());
|
||||
column_x_adder_func(column_tuple.getColumn(0), elem.first);
|
||||
column_y_adder_func(column_tuple.getColumn(1), elem.second);
|
||||
}
|
||||
|
||||
col_offsets.getData().push_back(col.getData().size());
|
||||
}
|
||||
|
||||
std::function<void(IColumn &, Float64)> getColumnAdderFunc(TypeIndex type_index) const
|
||||
{
|
||||
switch (type_index)
|
||||
{
|
||||
case TypeIndex::Date:
|
||||
return [](IColumn & column, Float64 value)
|
||||
{
|
||||
auto & col = assert_cast<ColumnDate &>(column);
|
||||
col.getData().push_back(static_cast<UInt16>(value));
|
||||
};
|
||||
case TypeIndex::Date32:
|
||||
return [](IColumn & column, Float64 value)
|
||||
{
|
||||
auto & col = assert_cast<ColumnDate32 &>(column);
|
||||
col.getData().push_back(static_cast<UInt32>(value));
|
||||
};
|
||||
case TypeIndex::DateTime:
|
||||
return [](IColumn & column, Float64 value)
|
||||
{
|
||||
auto & col = assert_cast<ColumnDateTime &>(column);
|
||||
col.getData().push_back(static_cast<UInt32>(value));
|
||||
};
|
||||
case TypeIndex::DateTime64:
|
||||
return [](IColumn & column, Float64 value)
|
||||
{
|
||||
auto & col = assert_cast<ColumnDateTime64 &>(column);
|
||||
col.getData().push_back(static_cast<UInt64>(value));
|
||||
};
|
||||
default:
|
||||
return [](IColumn & column, Float64 value)
|
||||
{
|
||||
auto & col = assert_cast<ColumnFloat64 &>(column);
|
||||
col.getData().push_back(value);
|
||||
};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
AggregateFunctionPtr
|
||||
createAggregateFunctionLargestTriangleThreeBuckets(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
assertBinary(name, argument_types);
|
||||
|
||||
if (!(isNumber(argument_types[0]) || isDateOrDate32(argument_types[0]) || isDateTime(argument_types[0])
|
||||
|| isDateTime64(argument_types[0])))
|
||||
throw Exception(
|
||||
ErrorCodes::NOT_IMPLEMENTED,
|
||||
"Aggregate function {} only supports Date, Date32, DateTime, DateTime64 and Number as the first argument",
|
||||
name);
|
||||
|
||||
if (!(isNumber(argument_types[1]) || isDateOrDate32(argument_types[1]) || isDateTime(argument_types[1])
|
||||
|| isDateTime64(argument_types[1])))
|
||||
throw Exception(
|
||||
ErrorCodes::NOT_IMPLEMENTED,
|
||||
"Aggregate function {} only supports Date, Date32, DateTime, DateTime64 and Number as the second argument",
|
||||
name);
|
||||
|
||||
return std::make_shared<AggregateFunctionLargestTriangleThreeBuckets>(argument_types, parameters);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -1,327 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/StatCommon.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnsDateTime.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <DataTypes/DataTypesDecimal.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <base/types.h>
|
||||
#include <Common/PODArray_fwd.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <boost/math/distributions/normal.hpp>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
}
|
||||
|
||||
|
||||
struct LargestTriangleThreeBucketsData : public StatisticalSample<Float64, Float64>
|
||||
{
|
||||
void add(const Float64 xval, const Float64 yval, Arena * arena)
|
||||
{
|
||||
this->addX(xval, arena);
|
||||
this->addY(yval, arena);
|
||||
}
|
||||
|
||||
void sort(Arena * arena)
|
||||
{
|
||||
// sort the this->x and this->y in ascending order of this->x using index
|
||||
std::vector<size_t> index(this->x.size());
|
||||
|
||||
std::iota(index.begin(), index.end(), 0);
|
||||
::sort(index.begin(), index.end(), [&](size_t i1, size_t i2) { return this->x[i1] < this->x[i2]; });
|
||||
|
||||
SampleX temp_x{};
|
||||
SampleY temp_y{};
|
||||
|
||||
for (size_t i = 0; i < this->x.size(); ++i)
|
||||
{
|
||||
temp_x.push_back(this->x[index[i]], arena);
|
||||
temp_y.push_back(this->y[index[i]], arena);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < this->x.size(); ++i)
|
||||
{
|
||||
this->x[i] = temp_x[i];
|
||||
this->y[i] = temp_y[i];
|
||||
}
|
||||
}
|
||||
|
||||
PODArray<std::pair<Float64, Float64>> getResult(size_t total_buckets, Arena * arena)
|
||||
{
|
||||
// Sort the data
|
||||
this->sort(arena);
|
||||
|
||||
PODArray<std::pair<Float64, Float64>> result;
|
||||
|
||||
// Handle special cases for small data list
|
||||
if (this->x.size() <= total_buckets)
|
||||
{
|
||||
for (size_t i = 0; i < this->x.size(); ++i)
|
||||
{
|
||||
result.emplace_back(std::make_pair(this->x[i], this->y[i]));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Handle special cases for 0 or 1 or 2 buckets
|
||||
if (total_buckets == 0)
|
||||
return result;
|
||||
if (total_buckets == 1)
|
||||
{
|
||||
result.emplace_back(std::make_pair(this->x.front(), this->y.front()));
|
||||
return result;
|
||||
}
|
||||
if (total_buckets == 2)
|
||||
{
|
||||
result.emplace_back(std::make_pair(this->x.front(), this->y.front()));
|
||||
result.emplace_back(std::make_pair(this->x.back(), this->y.back()));
|
||||
return result;
|
||||
}
|
||||
|
||||
// Find the size of each bucket
|
||||
size_t single_bucket_size = this->x.size() / total_buckets;
|
||||
|
||||
// Include the first data point
|
||||
result.emplace_back(std::make_pair(this->x[0], this->y[0]));
|
||||
|
||||
for (size_t i = 1; i < total_buckets - 1; ++i) // Skip the first and last bucket
|
||||
{
|
||||
size_t start_index = i * single_bucket_size;
|
||||
size_t end_index = (i + 1) * single_bucket_size;
|
||||
|
||||
// Compute the average point in the next bucket
|
||||
Float64 avg_x = 0;
|
||||
Float64 avg_y = 0;
|
||||
for (size_t j = end_index; j < (i + 2) * single_bucket_size; ++j)
|
||||
{
|
||||
avg_x += this->x[j];
|
||||
avg_y += this->y[j];
|
||||
}
|
||||
avg_x /= single_bucket_size;
|
||||
avg_y /= single_bucket_size;
|
||||
|
||||
// Find the point in the current bucket that forms the largest triangle
|
||||
size_t max_index = start_index;
|
||||
Float64 max_area = 0.0;
|
||||
for (size_t j = start_index; j < end_index; ++j)
|
||||
{
|
||||
Float64 area = std::abs(
|
||||
0.5
|
||||
* (result.back().first * this->y[j] + this->x[j] * avg_y + avg_x * result.back().second - result.back().first * avg_y
|
||||
- this->x[j] * result.back().second - avg_x * this->y[j]));
|
||||
if (area > max_area)
|
||||
{
|
||||
max_area = area;
|
||||
max_index = j;
|
||||
}
|
||||
}
|
||||
|
||||
// Include the selected point
|
||||
result.emplace_back(std::make_pair(this->x[max_index], this->y[max_index]));
|
||||
}
|
||||
|
||||
// Include the last data point
|
||||
result.emplace_back(std::make_pair(this->x.back(), this->y.back()));
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
class AggregateFunctionLargestTriangleThreeBuckets final : public IAggregateFunctionDataHelper<LargestTriangleThreeBucketsData, AggregateFunctionLargestTriangleThreeBuckets>
|
||||
{
|
||||
private:
|
||||
UInt64 total_buckets{0};
|
||||
TypeIndex x_type;
|
||||
TypeIndex y_type;
|
||||
|
||||
public:
|
||||
explicit AggregateFunctionLargestTriangleThreeBuckets(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<LargestTriangleThreeBucketsData, AggregateFunctionLargestTriangleThreeBuckets>({arguments}, {}, createResultType(arguments))
|
||||
{
|
||||
if (params.size() != 1)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} require one parameter", getName());
|
||||
|
||||
if (params[0].getType() != Field::Types::UInt64)
|
||||
throw Exception(
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} require first parameter to be a UInt64", getName());
|
||||
|
||||
total_buckets = params[0].get<UInt64>();
|
||||
|
||||
this->x_type = WhichDataType(arguments[0]).idx;
|
||||
this->y_type = WhichDataType(arguments[1]).idx;
|
||||
}
|
||||
|
||||
static constexpr auto name = "largestTriangleThreeBuckets";
|
||||
|
||||
String getName() const override { return name; }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return true; }
|
||||
|
||||
static DataTypePtr createResultType(const DataTypes & arguments)
|
||||
{
|
||||
TypeIndex x_type = arguments[0]->getTypeId();
|
||||
TypeIndex y_type = arguments[1]->getTypeId();
|
||||
|
||||
UInt32 x_scale = 0;
|
||||
UInt32 y_scale = 0;
|
||||
|
||||
if (const auto * datetime64_type = typeid_cast<const DataTypeDateTime64 *>(arguments[0].get()))
|
||||
{
|
||||
x_scale = datetime64_type->getScale();
|
||||
}
|
||||
|
||||
if (const auto * datetime64_type = typeid_cast<const DataTypeDateTime64 *>(arguments[1].get()))
|
||||
{
|
||||
y_scale = datetime64_type->getScale();
|
||||
}
|
||||
|
||||
DataTypes types = {getDataTypeFromTypeIndex(x_type, x_scale), getDataTypeFromTypeIndex(y_type, y_scale)};
|
||||
|
||||
auto tuple = std::make_shared<DataTypeTuple>(std::move(types));
|
||||
|
||||
return std::make_shared<DataTypeArray>(tuple);
|
||||
}
|
||||
|
||||
static DataTypePtr getDataTypeFromTypeIndex(TypeIndex type_index, UInt32 scale)
|
||||
{
|
||||
DataTypePtr data_type;
|
||||
switch (type_index)
|
||||
{
|
||||
case TypeIndex::Date:
|
||||
data_type = std::make_shared<DataTypeDate>();
|
||||
break;
|
||||
case TypeIndex::Date32:
|
||||
data_type = std::make_shared<DataTypeDate32>();
|
||||
break;
|
||||
case TypeIndex::DateTime:
|
||||
data_type = std::make_shared<DataTypeDateTime>();
|
||||
break;
|
||||
case TypeIndex::DateTime64:
|
||||
data_type = std::make_shared<DataTypeDateTime64>(scale);
|
||||
break;
|
||||
default:
|
||||
data_type = std::make_shared<DataTypeNumber<Float64>>();
|
||||
}
|
||||
return data_type;
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
Float64 x = getFloat64DataFromColumn(columns[0], row_num, this->x_type);
|
||||
Float64 y = getFloat64DataFromColumn(columns[1], row_num, this->y_type);
|
||||
this->data(place).add(x, y, arena);
|
||||
}
|
||||
|
||||
Float64 getFloat64DataFromColumn(const IColumn * column, size_t row_num, TypeIndex type_index) const
|
||||
{
|
||||
switch (type_index)
|
||||
{
|
||||
case TypeIndex::Date:
|
||||
return static_cast<const ColumnDate &>(*column).getData()[row_num];
|
||||
case TypeIndex::Date32:
|
||||
return static_cast<const ColumnDate32 &>(*column).getData()[row_num];
|
||||
case TypeIndex::DateTime:
|
||||
return static_cast<const ColumnDateTime &>(*column).getData()[row_num];
|
||||
case TypeIndex::DateTime64:
|
||||
return static_cast<const ColumnDateTime64 &>(*column).getData()[row_num];
|
||||
default:
|
||||
return column->getFloat64(row_num);
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
auto & a = this->data(place);
|
||||
const auto & b = this->data(rhs);
|
||||
|
||||
a.merge(b, arena);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
this->data(place).read(buf, arena);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
|
||||
{
|
||||
auto res = this->data(place).getResult(total_buckets, arena);
|
||||
|
||||
auto & col = assert_cast<ColumnArray &>(to);
|
||||
auto & col_offsets = assert_cast<ColumnArray::ColumnOffsets &>(col.getOffsetsColumn());
|
||||
|
||||
auto column_x_adder_func = getColumnAdderFunc(x_type);
|
||||
auto column_y_adder_func = getColumnAdderFunc(y_type);
|
||||
|
||||
for (size_t i = 0; i < res.size(); ++i)
|
||||
{
|
||||
auto & column_tuple = assert_cast<ColumnTuple &>(col.getData());
|
||||
column_x_adder_func(column_tuple.getColumn(0), res[i].first);
|
||||
column_y_adder_func(column_tuple.getColumn(1), res[i].second);
|
||||
}
|
||||
|
||||
col_offsets.getData().push_back(col.getData().size());
|
||||
}
|
||||
|
||||
std::function<void(IColumn &, Float64)> getColumnAdderFunc(TypeIndex type_index) const
|
||||
{
|
||||
switch (type_index)
|
||||
{
|
||||
case TypeIndex::Date:
|
||||
return [](IColumn & column, Float64 value)
|
||||
{
|
||||
auto & col = assert_cast<ColumnDate &>(column);
|
||||
col.getData().push_back(static_cast<UInt16>(value));
|
||||
};
|
||||
case TypeIndex::Date32:
|
||||
return [](IColumn & column, Float64 value)
|
||||
{
|
||||
auto & col = assert_cast<ColumnDate32 &>(column);
|
||||
col.getData().push_back(static_cast<UInt32>(value));
|
||||
};
|
||||
case TypeIndex::DateTime:
|
||||
return [](IColumn & column, Float64 value)
|
||||
{
|
||||
auto & col = assert_cast<ColumnDateTime &>(column);
|
||||
col.getData().push_back(static_cast<UInt32>(value));
|
||||
};
|
||||
case TypeIndex::DateTime64:
|
||||
return [](IColumn & column, Float64 value)
|
||||
{
|
||||
auto & col = assert_cast<ColumnDateTime64 &>(column);
|
||||
col.getData().push_back(static_cast<UInt64>(value));
|
||||
};
|
||||
default:
|
||||
return [](IColumn & column, Float64 value)
|
||||
{
|
||||
auto & col = assert_cast<ColumnFloat64 &>(column);
|
||||
col.getData().push_back(value);
|
||||
};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -1,21 +1,254 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionMannWhitney.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/StatCommon.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <Common/PODArray.h>
|
||||
#include <DataTypes/DataTypesDecimal.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <limits>
|
||||
|
||||
#include <boost/math/distributions/normal.hpp>
|
||||
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NOT_IMPLEMENTED;
|
||||
extern const int NOT_IMPLEMENTED;
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
struct Settings;
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
struct MannWhitneyData : public StatisticalSample<Float64, Float64>
|
||||
{
|
||||
/*Since null hypothesis is "for randomly selected values X and Y from two populations,
|
||||
*the probability of X being greater than Y is equal to the probability of Y being greater than X".
|
||||
*Or "the distribution F of first sample equals to the distribution G of second sample".
|
||||
*Then alternative for this hypothesis (H1) is "two-sided"(F != G), "less"(F < G), "greater" (F > G). */
|
||||
enum class Alternative
|
||||
{
|
||||
TwoSided,
|
||||
Less,
|
||||
Greater
|
||||
};
|
||||
|
||||
/// The behaviour equals to the similar function from scipy.
|
||||
/// https://github.com/scipy/scipy/blob/ab9e9f17e0b7b2d618c4d4d8402cd4c0c200d6c0/scipy/stats/stats.py#L6978
|
||||
std::pair<Float64, Float64> getResult(Alternative alternative, bool continuity_correction)
|
||||
{
|
||||
ConcatenatedSamples both(this->x, this->y);
|
||||
RanksArray ranks;
|
||||
Float64 tie_correction;
|
||||
|
||||
/// Compute ranks according to both samples.
|
||||
std::tie(ranks, tie_correction) = computeRanksAndTieCorrection(both);
|
||||
|
||||
const Float64 n1 = this->size_x;
|
||||
const Float64 n2 = this->size_y;
|
||||
|
||||
Float64 r1 = 0;
|
||||
for (size_t i = 0; i < n1; ++i)
|
||||
r1 += ranks[i];
|
||||
|
||||
const Float64 u1 = n1 * n2 + (n1 * (n1 + 1.)) / 2. - r1;
|
||||
const Float64 u2 = n1 * n2 - u1;
|
||||
|
||||
/// The distribution of U-statistic under null hypothesis H0 is symmetric with respect to meanrank.
|
||||
const Float64 meanrank = n1 * n2 /2. + 0.5 * continuity_correction;
|
||||
const Float64 sd = std::sqrt(tie_correction * n1 * n2 * (n1 + n2 + 1) / 12.0);
|
||||
|
||||
Float64 u = 0;
|
||||
if (alternative == Alternative::TwoSided)
|
||||
/// There is no difference which u_i to take as u, because z will be differ only in sign and we take std::abs() from it.
|
||||
u = std::max(u1, u2);
|
||||
else if (alternative == Alternative::Less)
|
||||
u = u1;
|
||||
else if (alternative == Alternative::Greater)
|
||||
u = u2;
|
||||
|
||||
Float64 z = (u - meanrank) / sd;
|
||||
|
||||
if (unlikely(!std::isfinite(z)))
|
||||
return {std::numeric_limits<Float64>::quiet_NaN(), std::numeric_limits<Float64>::quiet_NaN()};
|
||||
|
||||
if (alternative == Alternative::TwoSided)
|
||||
z = std::abs(z);
|
||||
|
||||
auto standard_normal_distribution = boost::math::normal_distribution<Float64>();
|
||||
auto cdf = boost::math::cdf(standard_normal_distribution, z);
|
||||
|
||||
Float64 p_value = 0;
|
||||
if (alternative == Alternative::TwoSided)
|
||||
p_value = 2 - 2 * cdf;
|
||||
else
|
||||
p_value = 1 - cdf;
|
||||
|
||||
return {u2, p_value};
|
||||
}
|
||||
|
||||
private:
|
||||
using Sample = typename StatisticalSample<Float64, Float64>::SampleX;
|
||||
|
||||
/// We need to compute ranks according to all samples. Use this class to avoid extra copy and memory allocation.
|
||||
class ConcatenatedSamples
|
||||
{
|
||||
public:
|
||||
ConcatenatedSamples(const Sample & first_, const Sample & second_)
|
||||
: first(first_), second(second_) {}
|
||||
|
||||
const Float64 & operator[](size_t ind) const
|
||||
{
|
||||
if (ind < first.size())
|
||||
return first[ind];
|
||||
return second[ind % first.size()];
|
||||
}
|
||||
|
||||
size_t size() const
|
||||
{
|
||||
return first.size() + second.size();
|
||||
}
|
||||
|
||||
private:
|
||||
const Sample & first;
|
||||
const Sample & second;
|
||||
};
|
||||
};
|
||||
|
||||
class AggregateFunctionMannWhitney final:
|
||||
public IAggregateFunctionDataHelper<MannWhitneyData, AggregateFunctionMannWhitney>
|
||||
{
|
||||
private:
|
||||
using Alternative = typename MannWhitneyData::Alternative;
|
||||
Alternative alternative;
|
||||
bool continuity_correction{true};
|
||||
|
||||
public:
|
||||
explicit AggregateFunctionMannWhitney(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<MannWhitneyData, AggregateFunctionMannWhitney> ({arguments}, {}, createResultType())
|
||||
{
|
||||
if (params.size() > 2)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} require two parameter or less", getName());
|
||||
|
||||
if (params.empty())
|
||||
{
|
||||
alternative = Alternative::TwoSided;
|
||||
return;
|
||||
}
|
||||
|
||||
if (params[0].getType() != Field::Types::String)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} require first parameter to be a String", getName());
|
||||
|
||||
const auto & param = params[0].get<String>();
|
||||
if (param == "two-sided")
|
||||
alternative = Alternative::TwoSided;
|
||||
else if (param == "less")
|
||||
alternative = Alternative::Less;
|
||||
else if (param == "greater")
|
||||
alternative = Alternative::Greater;
|
||||
else
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown parameter in aggregate function {}. "
|
||||
"It must be one of: 'two-sided', 'less', 'greater'", getName());
|
||||
|
||||
if (params.size() != 2)
|
||||
return;
|
||||
|
||||
if (params[1].getType() != Field::Types::UInt64)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} require second parameter to be a UInt64", getName());
|
||||
|
||||
continuity_correction = static_cast<bool>(params[1].get<UInt64>());
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return "mannWhitneyUTest";
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return true; }
|
||||
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
DataTypes types
|
||||
{
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
};
|
||||
|
||||
Strings names
|
||||
{
|
||||
"u_statistic",
|
||||
"p_value"
|
||||
};
|
||||
|
||||
return std::make_shared<DataTypeTuple>(
|
||||
std::move(types),
|
||||
std::move(names)
|
||||
);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
Float64 value = columns[0]->getFloat64(row_num);
|
||||
UInt8 is_second = columns[1]->getUInt(row_num);
|
||||
|
||||
if (is_second)
|
||||
this->data(place).addY(value, arena);
|
||||
else
|
||||
this->data(place).addX(value, arena);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
auto & a = this->data(place);
|
||||
const auto & b = this->data(rhs);
|
||||
|
||||
a.merge(b, arena);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
this->data(place).read(buf, arena);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
if (!this->data(place).size_x || !this->data(place).size_y)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} require both samples to be non empty", getName());
|
||||
|
||||
auto [u_statistic, p_value] = this->data(place).getResult(alternative, continuity_correction);
|
||||
|
||||
/// Because p-value is a probability.
|
||||
p_value = std::min(1.0, std::max(0.0, p_value));
|
||||
|
||||
auto & column_tuple = assert_cast<ColumnTuple &>(to);
|
||||
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
|
||||
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
|
||||
|
||||
column_stat.getData().push_back(u_statistic);
|
||||
column_value.getData().push_back(p_value);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionMannWhitneyUTest(
|
||||
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
|
@ -1,249 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/StatCommon.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <Common/PODArray_fwd.h>
|
||||
#include <base/types.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypesDecimal.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <limits>
|
||||
|
||||
#include <boost/math/distributions/normal.hpp>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
|
||||
struct MannWhitneyData : public StatisticalSample<Float64, Float64>
|
||||
{
|
||||
/*Since null hypothesis is "for randomly selected values X and Y from two populations,
|
||||
*the probability of X being greater than Y is equal to the probability of Y being greater than X".
|
||||
*Or "the distribution F of first sample equals to the distribution G of second sample".
|
||||
*Then alternative for this hypothesis (H1) is "two-sided"(F != G), "less"(F < G), "greater" (F > G). */
|
||||
enum class Alternative
|
||||
{
|
||||
TwoSided,
|
||||
Less,
|
||||
Greater
|
||||
};
|
||||
|
||||
/// The behaviour equals to the similar function from scipy.
|
||||
/// https://github.com/scipy/scipy/blob/ab9e9f17e0b7b2d618c4d4d8402cd4c0c200d6c0/scipy/stats/stats.py#L6978
|
||||
std::pair<Float64, Float64> getResult(Alternative alternative, bool continuity_correction)
|
||||
{
|
||||
ConcatenatedSamples both(this->x, this->y);
|
||||
RanksArray ranks;
|
||||
Float64 tie_correction;
|
||||
|
||||
/// Compute ranks according to both samples.
|
||||
std::tie(ranks, tie_correction) = computeRanksAndTieCorrection(both);
|
||||
|
||||
const Float64 n1 = this->size_x;
|
||||
const Float64 n2 = this->size_y;
|
||||
|
||||
Float64 r1 = 0;
|
||||
for (size_t i = 0; i < n1; ++i)
|
||||
r1 += ranks[i];
|
||||
|
||||
const Float64 u1 = n1 * n2 + (n1 * (n1 + 1.)) / 2. - r1;
|
||||
const Float64 u2 = n1 * n2 - u1;
|
||||
|
||||
/// The distribution of U-statistic under null hypothesis H0 is symmetric with respect to meanrank.
|
||||
const Float64 meanrank = n1 * n2 /2. + 0.5 * continuity_correction;
|
||||
const Float64 sd = std::sqrt(tie_correction * n1 * n2 * (n1 + n2 + 1) / 12.0);
|
||||
|
||||
Float64 u = 0;
|
||||
if (alternative == Alternative::TwoSided)
|
||||
/// There is no difference which u_i to take as u, because z will be differ only in sign and we take std::abs() from it.
|
||||
u = std::max(u1, u2);
|
||||
else if (alternative == Alternative::Less)
|
||||
u = u1;
|
||||
else if (alternative == Alternative::Greater)
|
||||
u = u2;
|
||||
|
||||
Float64 z = (u - meanrank) / sd;
|
||||
|
||||
if (unlikely(!std::isfinite(z)))
|
||||
return {std::numeric_limits<Float64>::quiet_NaN(), std::numeric_limits<Float64>::quiet_NaN()};
|
||||
|
||||
if (alternative == Alternative::TwoSided)
|
||||
z = std::abs(z);
|
||||
|
||||
auto standard_normal_distribution = boost::math::normal_distribution<Float64>();
|
||||
auto cdf = boost::math::cdf(standard_normal_distribution, z);
|
||||
|
||||
Float64 p_value = 0;
|
||||
if (alternative == Alternative::TwoSided)
|
||||
p_value = 2 - 2 * cdf;
|
||||
else
|
||||
p_value = 1 - cdf;
|
||||
|
||||
return {u2, p_value};
|
||||
}
|
||||
|
||||
private:
|
||||
using Sample = typename StatisticalSample<Float64, Float64>::SampleX;
|
||||
|
||||
/// We need to compute ranks according to all samples. Use this class to avoid extra copy and memory allocation.
|
||||
class ConcatenatedSamples
|
||||
{
|
||||
public:
|
||||
ConcatenatedSamples(const Sample & first_, const Sample & second_)
|
||||
: first(first_), second(second_) {}
|
||||
|
||||
const Float64 & operator[](size_t ind) const
|
||||
{
|
||||
if (ind < first.size())
|
||||
return first[ind];
|
||||
return second[ind % first.size()];
|
||||
}
|
||||
|
||||
size_t size() const
|
||||
{
|
||||
return first.size() + second.size();
|
||||
}
|
||||
|
||||
private:
|
||||
const Sample & first;
|
||||
const Sample & second;
|
||||
};
|
||||
};
|
||||
|
||||
class AggregateFunctionMannWhitney final:
|
||||
public IAggregateFunctionDataHelper<MannWhitneyData, AggregateFunctionMannWhitney>
|
||||
{
|
||||
private:
|
||||
using Alternative = typename MannWhitneyData::Alternative;
|
||||
Alternative alternative;
|
||||
bool continuity_correction{true};
|
||||
|
||||
public:
|
||||
explicit AggregateFunctionMannWhitney(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<MannWhitneyData, AggregateFunctionMannWhitney> ({arguments}, {}, createResultType())
|
||||
{
|
||||
if (params.size() > 2)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} require two parameter or less", getName());
|
||||
|
||||
if (params.empty())
|
||||
{
|
||||
alternative = Alternative::TwoSided;
|
||||
return;
|
||||
}
|
||||
|
||||
if (params[0].getType() != Field::Types::String)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} require first parameter to be a String", getName());
|
||||
|
||||
const auto & param = params[0].get<String>();
|
||||
if (param == "two-sided")
|
||||
alternative = Alternative::TwoSided;
|
||||
else if (param == "less")
|
||||
alternative = Alternative::Less;
|
||||
else if (param == "greater")
|
||||
alternative = Alternative::Greater;
|
||||
else
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown parameter in aggregate function {}. "
|
||||
"It must be one of: 'two-sided', 'less', 'greater'", getName());
|
||||
|
||||
if (params.size() != 2)
|
||||
return;
|
||||
|
||||
if (params[1].getType() != Field::Types::UInt64)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} require second parameter to be a UInt64", getName());
|
||||
|
||||
continuity_correction = static_cast<bool>(params[1].get<UInt64>());
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return "mannWhitneyUTest";
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return true; }
|
||||
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
DataTypes types
|
||||
{
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
};
|
||||
|
||||
Strings names
|
||||
{
|
||||
"u_statistic",
|
||||
"p_value"
|
||||
};
|
||||
|
||||
return std::make_shared<DataTypeTuple>(
|
||||
std::move(types),
|
||||
std::move(names)
|
||||
);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
Float64 value = columns[0]->getFloat64(row_num);
|
||||
UInt8 is_second = columns[1]->getUInt(row_num);
|
||||
|
||||
if (is_second)
|
||||
this->data(place).addY(value, arena);
|
||||
else
|
||||
this->data(place).addX(value, arena);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
auto & a = this->data(place);
|
||||
const auto & b = this->data(rhs);
|
||||
|
||||
a.merge(b, arena);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
this->data(place).read(buf, arena);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
if (!this->data(place).size_x || !this->data(place).size_y)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} require both samples to be non empty", getName());
|
||||
|
||||
auto [u_statistic, p_value] = this->data(place).getResult(alternative, continuity_correction);
|
||||
|
||||
/// Because p-value is a probability.
|
||||
p_value = std::min(1.0, std::max(0.0, p_value));
|
||||
|
||||
auto & column_tuple = assert_cast<ColumnTuple &>(to);
|
||||
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
|
||||
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
|
||||
|
||||
column_stat.getData().push_back(u_statistic);
|
||||
column_value.getData().push_back(p_value);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}
|
@ -1,8 +1,21 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionMaxIntersections.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
#include <Common/ArenaAllocator.h>
|
||||
#include <Common/NaNUtils.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
#define AGGREGATE_FUNCTION_MAX_INTERSECTIONS_MAX_ARRAY_SIZE 0xFFFFFF
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -11,24 +24,186 @@ struct Settings;
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int TOO_LARGE_ARRAY_SIZE;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
AggregateFunctionPtr createAggregateFunctionMaxIntersections(
|
||||
AggregateFunctionIntersectionsKind kind,
|
||||
const std::string & name, const DataTypes & argument_types, const Array & parameters)
|
||||
|
||||
/** maxIntersections: returns maximum count of the intersected intervals defined by start_column and end_column values,
|
||||
* maxIntersectionsPosition: returns leftmost position of maximum intersection of intervals.
|
||||
*/
|
||||
|
||||
/// Similar to GroupArrayNumericData.
|
||||
template <typename T>
|
||||
struct MaxIntersectionsData
|
||||
{
|
||||
/// Left or right end of the interval and signed weight; with positive sign for begin of interval and negative sign for end of interval.
|
||||
using Value = std::pair<T, Int64>;
|
||||
|
||||
// Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena
|
||||
using Allocator = MixedAlignedArenaAllocator<alignof(Value), 4096>;
|
||||
using Array = PODArray<Value, 32, Allocator>;
|
||||
|
||||
Array value;
|
||||
};
|
||||
|
||||
enum class AggregateFunctionIntersectionsKind
|
||||
{
|
||||
Count,
|
||||
Position
|
||||
};
|
||||
|
||||
template <typename PointType>
|
||||
class AggregateFunctionIntersectionsMax final
|
||||
: public IAggregateFunctionDataHelper<MaxIntersectionsData<PointType>, AggregateFunctionIntersectionsMax<PointType>>
|
||||
{
|
||||
private:
|
||||
AggregateFunctionIntersectionsKind kind;
|
||||
|
||||
public:
|
||||
AggregateFunctionIntersectionsMax(AggregateFunctionIntersectionsKind kind_, const DataTypes & arguments)
|
||||
: IAggregateFunctionDataHelper<MaxIntersectionsData<PointType>, AggregateFunctionIntersectionsMax<PointType>>(arguments, {}, createResultType(kind_))
|
||||
, kind(kind_)
|
||||
{
|
||||
assertBinary(name, argument_types);
|
||||
assertNoParameters(name, parameters);
|
||||
if (!isNativeNumber(arguments[0]))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "{}: first argument must be represented by integer", getName());
|
||||
|
||||
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionIntersectionsMax>(*argument_types[0], kind, argument_types));
|
||||
if (!res)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal types {} and {} of argument for aggregate function {}",
|
||||
argument_types[0]->getName(), argument_types[1]->getName(), name);
|
||||
if (!isNativeNumber(arguments[1]))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "{}: second argument must be represented by integer", getName());
|
||||
|
||||
return res;
|
||||
if (!arguments[0]->equals(*arguments[1]))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "{}: arguments must have the same type", getName());
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return kind == AggregateFunctionIntersectionsKind::Count
|
||||
? "maxIntersections"
|
||||
: "maxIntersectionsPosition";
|
||||
}
|
||||
|
||||
static DataTypePtr createResultType(AggregateFunctionIntersectionsKind kind_)
|
||||
{
|
||||
if (kind_ == AggregateFunctionIntersectionsKind::Count)
|
||||
return std::make_shared<DataTypeUInt64>();
|
||||
else
|
||||
return std::make_shared<DataTypeNumber<PointType>>();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
PointType left = assert_cast<const ColumnVector<PointType> &>(*columns[0]).getData()[row_num];
|
||||
PointType right = assert_cast<const ColumnVector<PointType> &>(*columns[1]).getData()[row_num];
|
||||
|
||||
if (!isNaN(left))
|
||||
this->data(place).value.push_back(std::make_pair(left, Int64(1)), arena);
|
||||
|
||||
if (!isNaN(right))
|
||||
this->data(place).value.push_back(std::make_pair(right, Int64(-1)), arena);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
auto & cur_elems = this->data(place);
|
||||
auto & rhs_elems = this->data(rhs);
|
||||
|
||||
cur_elems.value.insert(rhs_elems.value.begin(), rhs_elems.value.end(), arena);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
const auto & value = this->data(place).value;
|
||||
size_t size = value.size();
|
||||
writeVarUInt(size, buf);
|
||||
|
||||
/// In this version, pairs were serialized with padding.
|
||||
/// We must ensure that padding bytes are zero-filled.
|
||||
|
||||
static_assert(offsetof(typename MaxIntersectionsData<PointType>::Value, first) == 0);
|
||||
static_assert(offsetof(typename MaxIntersectionsData<PointType>::Value, second) > 0);
|
||||
|
||||
char zero_padding[offsetof(typename MaxIntersectionsData<PointType>::Value, second) - sizeof(value[0].first)]{};
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
writePODBinary(value[i].first, buf);
|
||||
writePODBinary(zero_padding, buf);
|
||||
if constexpr (std::endian::native == std::endian::little)
|
||||
writePODBinary(value[i].second, buf);
|
||||
else
|
||||
writePODBinary(std::byteswap(value[i].second), buf);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
size_t size = 0;
|
||||
readVarUInt(size, buf);
|
||||
|
||||
if (unlikely(size > AGGREGATE_FUNCTION_MAX_INTERSECTIONS_MAX_ARRAY_SIZE))
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
|
||||
"Too large array size (maximum: {})", AGGREGATE_FUNCTION_MAX_INTERSECTIONS_MAX_ARRAY_SIZE);
|
||||
|
||||
auto & value = this->data(place).value;
|
||||
|
||||
value.resize(size, arena);
|
||||
buf.readStrict(reinterpret_cast<char *>(value.data()), size * sizeof(value[0]));
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
Int64 current_intersections = 0;
|
||||
Int64 max_intersections = 0;
|
||||
PointType position_of_max_intersections = 0;
|
||||
|
||||
/// const_cast because we will sort the array
|
||||
auto & array = this->data(place).value;
|
||||
|
||||
/// Sort by position; for equal position, sort by weight to get deterministic result.
|
||||
::sort(array.begin(), array.end());
|
||||
|
||||
for (const auto & point_weight : array)
|
||||
{
|
||||
current_intersections += point_weight.second;
|
||||
if (current_intersections > max_intersections)
|
||||
{
|
||||
max_intersections = current_intersections;
|
||||
position_of_max_intersections = point_weight.first;
|
||||
}
|
||||
}
|
||||
|
||||
if (kind == AggregateFunctionIntersectionsKind::Count)
|
||||
{
|
||||
auto & result_column = assert_cast<ColumnUInt64 &>(to).getData();
|
||||
result_column.push_back(max_intersections);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto & result_column = assert_cast<ColumnVector<PointType> &>(to).getData();
|
||||
result_column.push_back(position_of_max_intersections);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionMaxIntersections(
|
||||
AggregateFunctionIntersectionsKind kind,
|
||||
const std::string & name, const DataTypes & argument_types, const Array & parameters)
|
||||
{
|
||||
assertBinary(name, argument_types);
|
||||
assertNoParameters(name, parameters);
|
||||
|
||||
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionIntersectionsMax>(*argument_types[0], kind, argument_types));
|
||||
if (!res)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal types {} and {} of argument for aggregate function {}",
|
||||
argument_types[0]->getName(), argument_types[1]->getName(), name);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory & factory)
|
||||
|
@ -1,189 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <base/sort.h>
|
||||
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
#include <Common/ArenaAllocator.h>
|
||||
#include <Common/NaNUtils.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
#define AGGREGATE_FUNCTION_MAX_INTERSECTIONS_MAX_ARRAY_SIZE 0xFFFFFF
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int TOO_LARGE_ARRAY_SIZE;
|
||||
}
|
||||
|
||||
|
||||
/** maxIntersections: returns maximum count of the intersected intervals defined by start_column and end_column values,
|
||||
* maxIntersectionsPosition: returns leftmost position of maximum intersection of intervals.
|
||||
*/
|
||||
|
||||
/// Similar to GroupArrayNumericData.
|
||||
template <typename T>
|
||||
struct MaxIntersectionsData
|
||||
{
|
||||
/// Left or right end of the interval and signed weight; with positive sign for begin of interval and negative sign for end of interval.
|
||||
using Value = std::pair<T, Int64>;
|
||||
|
||||
// Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena
|
||||
using Allocator = MixedAlignedArenaAllocator<alignof(Value), 4096>;
|
||||
using Array = PODArray<Value, 32, Allocator>;
|
||||
|
||||
Array value;
|
||||
};
|
||||
|
||||
enum class AggregateFunctionIntersectionsKind
|
||||
{
|
||||
Count,
|
||||
Position
|
||||
};
|
||||
|
||||
template <typename PointType>
|
||||
class AggregateFunctionIntersectionsMax final
|
||||
: public IAggregateFunctionDataHelper<MaxIntersectionsData<PointType>, AggregateFunctionIntersectionsMax<PointType>>
|
||||
{
|
||||
private:
|
||||
AggregateFunctionIntersectionsKind kind;
|
||||
|
||||
public:
|
||||
AggregateFunctionIntersectionsMax(AggregateFunctionIntersectionsKind kind_, const DataTypes & arguments)
|
||||
: IAggregateFunctionDataHelper<MaxIntersectionsData<PointType>, AggregateFunctionIntersectionsMax<PointType>>(arguments, {}, createResultType(kind_))
|
||||
, kind(kind_)
|
||||
{
|
||||
if (!isNativeNumber(arguments[0]))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "{}: first argument must be represented by integer", getName());
|
||||
|
||||
if (!isNativeNumber(arguments[1]))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "{}: second argument must be represented by integer", getName());
|
||||
|
||||
if (!arguments[0]->equals(*arguments[1]))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "{}: arguments must have the same type", getName());
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return kind == AggregateFunctionIntersectionsKind::Count
|
||||
? "maxIntersections"
|
||||
: "maxIntersectionsPosition";
|
||||
}
|
||||
|
||||
static DataTypePtr createResultType(AggregateFunctionIntersectionsKind kind_)
|
||||
{
|
||||
if (kind_ == AggregateFunctionIntersectionsKind::Count)
|
||||
return std::make_shared<DataTypeUInt64>();
|
||||
else
|
||||
return std::make_shared<DataTypeNumber<PointType>>();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
PointType left = assert_cast<const ColumnVector<PointType> &>(*columns[0]).getData()[row_num];
|
||||
PointType right = assert_cast<const ColumnVector<PointType> &>(*columns[1]).getData()[row_num];
|
||||
|
||||
if (!isNaN(left))
|
||||
this->data(place).value.push_back(std::make_pair(left, Int64(1)), arena);
|
||||
|
||||
if (!isNaN(right))
|
||||
this->data(place).value.push_back(std::make_pair(right, Int64(-1)), arena);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
auto & cur_elems = this->data(place);
|
||||
auto & rhs_elems = this->data(rhs);
|
||||
|
||||
cur_elems.value.insert(rhs_elems.value.begin(), rhs_elems.value.end(), arena);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
const auto & value = this->data(place).value;
|
||||
size_t size = value.size();
|
||||
writeVarUInt(size, buf);
|
||||
|
||||
/// In this version, pairs were serialized with padding.
|
||||
/// We must ensure that padding bytes are zero-filled.
|
||||
|
||||
static_assert(offsetof(typename MaxIntersectionsData<PointType>::Value, first) == 0);
|
||||
static_assert(offsetof(typename MaxIntersectionsData<PointType>::Value, second) > 0);
|
||||
|
||||
char zero_padding[offsetof(typename MaxIntersectionsData<PointType>::Value, second) - sizeof(value[0].first)]{};
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
writePODBinary(value[i].first, buf);
|
||||
writePODBinary(zero_padding, buf);
|
||||
if constexpr (std::endian::native == std::endian::little)
|
||||
writePODBinary(value[i].second, buf);
|
||||
else
|
||||
writePODBinary(std::byteswap(value[i].second), buf);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
size_t size = 0;
|
||||
readVarUInt(size, buf);
|
||||
|
||||
if (unlikely(size > AGGREGATE_FUNCTION_MAX_INTERSECTIONS_MAX_ARRAY_SIZE))
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
|
||||
"Too large array size (maximum: {})", AGGREGATE_FUNCTION_MAX_INTERSECTIONS_MAX_ARRAY_SIZE);
|
||||
|
||||
auto & value = this->data(place).value;
|
||||
|
||||
value.resize(size, arena);
|
||||
buf.readStrict(reinterpret_cast<char *>(value.data()), size * sizeof(value[0]));
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
Int64 current_intersections = 0;
|
||||
Int64 max_intersections = 0;
|
||||
PointType position_of_max_intersections = 0;
|
||||
|
||||
/// const_cast because we will sort the array
|
||||
auto & array = this->data(place).value;
|
||||
|
||||
/// Sort by position; for equal position, sort by weight to get deterministic result.
|
||||
::sort(array.begin(), array.end());
|
||||
|
||||
for (const auto & point_weight : array)
|
||||
{
|
||||
current_intersections += point_weight.second;
|
||||
if (current_intersections > max_intersections)
|
||||
{
|
||||
max_intersections = current_intersections;
|
||||
position_of_max_intersections = point_weight.first;
|
||||
}
|
||||
}
|
||||
|
||||
if (kind == AggregateFunctionIntersectionsKind::Count)
|
||||
{
|
||||
auto & result_column = assert_cast<ColumnUInt64 &>(to).getData();
|
||||
result_column.push_back(max_intersections);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto & result_column = assert_cast<ColumnVector<PointType> &>(to).getData();
|
||||
result_column.push_back(position_of_max_intersections);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -1,8 +1,16 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionMeanZTest.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <AggregateFunctions/Moments.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/StatCommon.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <cmath>
|
||||
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
@ -18,6 +26,121 @@ struct Settings;
|
||||
namespace
|
||||
{
|
||||
|
||||
/// Returns tuple of (z-statistic, p-value, confidence-interval-low, confidence-interval-high)
|
||||
template <typename Data>
|
||||
class AggregateFunctionMeanZTest :
|
||||
public IAggregateFunctionDataHelper<Data, AggregateFunctionMeanZTest<Data>>
|
||||
{
|
||||
private:
|
||||
Float64 pop_var_x;
|
||||
Float64 pop_var_y;
|
||||
Float64 confidence_level;
|
||||
|
||||
public:
|
||||
AggregateFunctionMeanZTest(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionMeanZTest<Data>>({arguments}, params, createResultType())
|
||||
{
|
||||
pop_var_x = params.at(0).safeGet<Float64>();
|
||||
pop_var_y = params.at(1).safeGet<Float64>();
|
||||
confidence_level = params.at(2).safeGet<Float64>();
|
||||
|
||||
if (!std::isfinite(pop_var_x) || !std::isfinite(pop_var_y) || !std::isfinite(confidence_level))
|
||||
{
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} requires finite parameter values.", Data::name);
|
||||
}
|
||||
|
||||
if (pop_var_x < 0.0 || pop_var_y < 0.0)
|
||||
{
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS,
|
||||
"Population variance parameters must be larger than or equal to zero "
|
||||
"in aggregate function {}.", Data::name);
|
||||
}
|
||||
|
||||
if (confidence_level <= 0.0 || confidence_level >= 1.0)
|
||||
{
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Confidence level parameter must be between 0 and 1 in aggregate function {}.", Data::name);
|
||||
}
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return Data::name;
|
||||
}
|
||||
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
DataTypes types
|
||||
{
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
};
|
||||
|
||||
Strings names
|
||||
{
|
||||
"z_statistic",
|
||||
"p_value",
|
||||
"confidence_interval_low",
|
||||
"confidence_interval_high"
|
||||
};
|
||||
|
||||
return std::make_shared<DataTypeTuple>(
|
||||
std::move(types),
|
||||
std::move(names)
|
||||
);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
Float64 value = columns[0]->getFloat64(row_num);
|
||||
UInt8 is_second = columns[1]->getUInt(row_num);
|
||||
|
||||
if (is_second)
|
||||
this->data(place).addY(value);
|
||||
else
|
||||
this->data(place).addX(value);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).read(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto [z_stat, p_value] = this->data(place).getResult(pop_var_x, pop_var_y);
|
||||
auto [ci_low, ci_high] = this->data(place).getConfidenceIntervals(pop_var_x, pop_var_y, confidence_level);
|
||||
|
||||
/// Because p-value is a probability.
|
||||
p_value = std::min(1.0, std::max(0.0, p_value));
|
||||
|
||||
auto & column_tuple = assert_cast<ColumnTuple &>(to);
|
||||
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
|
||||
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
|
||||
auto & column_ci_low = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(2));
|
||||
auto & column_ci_high = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(3));
|
||||
|
||||
column_stat.getData().push_back(z_stat);
|
||||
column_value.getData().push_back(p_value);
|
||||
column_ci_low.getData().push_back(ci_low);
|
||||
column_ci_high.getData().push_back(ci_high);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
struct MeanZTestData : public ZTestMoments<Float64>
|
||||
{
|
||||
static constexpr auto name = "meanZTest";
|
||||
|
@ -1,141 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/StatCommon.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <Core/Types.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <cmath>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
class ReadBuffer;
|
||||
class WriteBuffer;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
|
||||
/// Returns tuple of (z-statistic, p-value, confidence-interval-low, confidence-interval-high)
|
||||
template <typename Data>
|
||||
class AggregateFunctionMeanZTest :
|
||||
public IAggregateFunctionDataHelper<Data, AggregateFunctionMeanZTest<Data>>
|
||||
{
|
||||
private:
|
||||
Float64 pop_var_x;
|
||||
Float64 pop_var_y;
|
||||
Float64 confidence_level;
|
||||
|
||||
public:
|
||||
AggregateFunctionMeanZTest(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionMeanZTest<Data>>({arguments}, params, createResultType())
|
||||
{
|
||||
pop_var_x = params.at(0).safeGet<Float64>();
|
||||
pop_var_y = params.at(1).safeGet<Float64>();
|
||||
confidence_level = params.at(2).safeGet<Float64>();
|
||||
|
||||
if (!std::isfinite(pop_var_x) || !std::isfinite(pop_var_y) || !std::isfinite(confidence_level))
|
||||
{
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} requires finite parameter values.", Data::name);
|
||||
}
|
||||
|
||||
if (pop_var_x < 0.0 || pop_var_y < 0.0)
|
||||
{
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS,
|
||||
"Population variance parameters must be larger than or equal to zero "
|
||||
"in aggregate function {}.", Data::name);
|
||||
}
|
||||
|
||||
if (confidence_level <= 0.0 || confidence_level >= 1.0)
|
||||
{
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Confidence level parameter must be between 0 and 1 in aggregate function {}.", Data::name);
|
||||
}
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return Data::name;
|
||||
}
|
||||
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
DataTypes types
|
||||
{
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
};
|
||||
|
||||
Strings names
|
||||
{
|
||||
"z_statistic",
|
||||
"p_value",
|
||||
"confidence_interval_low",
|
||||
"confidence_interval_high"
|
||||
};
|
||||
|
||||
return std::make_shared<DataTypeTuple>(
|
||||
std::move(types),
|
||||
std::move(names)
|
||||
);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
Float64 value = columns[0]->getFloat64(row_num);
|
||||
UInt8 is_second = columns[1]->getUInt(row_num);
|
||||
|
||||
if (is_second)
|
||||
this->data(place).addY(value);
|
||||
else
|
||||
this->data(place).addX(value);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).read(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto [z_stat, p_value] = this->data(place).getResult(pop_var_x, pop_var_y);
|
||||
auto [ci_low, ci_high] = this->data(place).getConfidenceIntervals(pop_var_x, pop_var_y, confidence_level);
|
||||
|
||||
/// Because p-value is a probability.
|
||||
p_value = std::min(1.0, std::max(0.0, p_value));
|
||||
|
||||
auto & column_tuple = assert_cast<ColumnTuple &>(to);
|
||||
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
|
||||
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
|
||||
auto & column_ci_low = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(2));
|
||||
auto & column_ci_high = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(3));
|
||||
|
||||
column_stat.getData().push_back(z_stat);
|
||||
column_value.getData().push_back(p_value);
|
||||
column_ci_low.getData().push_back(ci_low);
|
||||
column_ci_high.getData().push_back(ci_high);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -1,5 +1,5 @@
|
||||
#include <AggregateFunctions/AggregateFunctionQuantile.h>
|
||||
#include <AggregateFunctions/QuantileReservoirSampler.h>
|
||||
#include <AggregateFunctions/ReservoirSampler.h>
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <DataTypes/DataTypeDate.h>
|
||||
@ -9,18 +9,108 @@
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename Value, bool float_return> using FuncQuantile = AggregateFunctionQuantile<Value, QuantileReservoirSampler<Value>, NameQuantile, false, std::conditional_t<float_return, Float64, void>, false>;
|
||||
template <typename Value, bool float_return> using FuncQuantiles = AggregateFunctionQuantile<Value, QuantileReservoirSampler<Value>, NameQuantiles, false, std::conditional_t<float_return, Float64, void>, true>;
|
||||
/** Quantile calculation with "reservoir sample" algorithm.
|
||||
* It collects pseudorandom subset of limited size from a stream of values,
|
||||
* and approximate quantile from it.
|
||||
* The result is non-deterministic. Also look at QuantileReservoirSamplerDeterministic.
|
||||
*
|
||||
* This algorithm is quite inefficient in terms of precision for memory usage,
|
||||
* but very efficient in CPU (though less efficient than QuantileTiming and than QuantileExact for small sets).
|
||||
*/
|
||||
template <typename Value>
|
||||
struct QuantileReservoirSampler
|
||||
{
|
||||
using Data = ReservoirSampler<Value, ReservoirSamplerOnEmpty::RETURN_NAN_OR_ZERO>;
|
||||
Data data;
|
||||
|
||||
void add(const Value & x)
|
||||
{
|
||||
data.insert(x);
|
||||
}
|
||||
|
||||
template <typename Weight>
|
||||
void add(const Value &, const Weight &)
|
||||
{
|
||||
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Method add with weight is not implemented for ReservoirSampler");
|
||||
}
|
||||
|
||||
void merge(const QuantileReservoirSampler & rhs)
|
||||
{
|
||||
data.merge(rhs.data);
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
data.write(buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
data.read(buf);
|
||||
}
|
||||
|
||||
/// Get the value of the `level` quantile. The level must be between 0 and 1.
|
||||
Value get(Float64 level)
|
||||
{
|
||||
if (data.empty())
|
||||
return {};
|
||||
|
||||
if constexpr (is_decimal<Value>)
|
||||
return Value(static_cast<typename Value::NativeType>(data.quantileInterpolated(level)));
|
||||
else
|
||||
return static_cast<Value>(data.quantileInterpolated(level));
|
||||
}
|
||||
|
||||
/// Get the `size` values of `levels` quantiles. Write `size` results starting with `result` address.
|
||||
/// indices - an array of index levels such that the corresponding elements will go in ascending order.
|
||||
void getMany(const Float64 * levels, const size_t * indices, size_t size, Value * result)
|
||||
{
|
||||
bool is_empty = data.empty();
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
if (is_empty)
|
||||
{
|
||||
result[i] = Value{};
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr (is_decimal<Value>)
|
||||
result[indices[i]] = Value(static_cast<typename Value::NativeType>(data.quantileInterpolated(levels[indices[i]])));
|
||||
else
|
||||
result[indices[i]] = Value(data.quantileInterpolated(levels[indices[i]]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The same, but in the case of an empty state, NaN is returned.
|
||||
Float64 getFloat(Float64 level)
|
||||
{
|
||||
return data.quantileInterpolated(level);
|
||||
}
|
||||
|
||||
void getManyFloat(const Float64 * levels, const size_t * indices, size_t size, Float64 * result)
|
||||
{
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
result[indices[i]] = data.quantileInterpolated(levels[indices[i]]);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename Value, bool float_return> using FuncQuantile = AggregateFunctionQuantile<Value, QuantileReservoirSampler<Value>, NameQuantile, false, std::conditional_t<float_return, Float64, void>, false, false>;
|
||||
template <typename Value, bool float_return> using FuncQuantiles = AggregateFunctionQuantile<Value, QuantileReservoirSampler<Value>, NameQuantiles, false, std::conditional_t<float_return, Float64, void>, true, false>;
|
||||
|
||||
template <template <typename, bool> class Function>
|
||||
AggregateFunctionPtr createAggregateFunctionQuantile(
|
||||
|
@ -54,15 +54,16 @@ template <
|
||||
typename FloatReturnType,
|
||||
/// If true, the function will accept multiple parameters with quantile levels
|
||||
/// and return an Array filled with many values of that quantiles.
|
||||
bool returns_many>
|
||||
bool returns_many,
|
||||
/// If the first parameter (before level) is accuracy.
|
||||
bool has_accuracy_parameter>
|
||||
class AggregateFunctionQuantile final
|
||||
: public IAggregateFunctionDataHelper<Data, AggregateFunctionQuantile<Value, Data, Name, has_second_arg, FloatReturnType, returns_many>>
|
||||
: public IAggregateFunctionDataHelper<Data, AggregateFunctionQuantile<Value, Data, Name, has_second_arg, FloatReturnType, returns_many, has_accuracy_parameter>>
|
||||
{
|
||||
private:
|
||||
using ColVecType = ColumnVectorOrDecimal<Value>;
|
||||
|
||||
static constexpr bool returns_float = !(std::is_same_v<FloatReturnType, void>);
|
||||
static constexpr bool is_quantile_gk = std::is_same_v<Data, QuantileGK<Value>>;
|
||||
static_assert(!is_decimal<Value> || !returns_float);
|
||||
|
||||
QuantileLevels<Float64> levels;
|
||||
@ -77,16 +78,16 @@ private:
|
||||
|
||||
public:
|
||||
AggregateFunctionQuantile(const DataTypes & argument_types_, const Array & params)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionQuantile<Value, Data, Name, has_second_arg, FloatReturnType, returns_many>>(
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionQuantile<Value, Data, Name, has_second_arg, FloatReturnType, returns_many, has_accuracy_parameter>>(
|
||||
argument_types_, params, createResultType(argument_types_))
|
||||
, levels(is_quantile_gk && !params.empty() ? Array(params.begin() + 1, params.end()) : params, returns_many)
|
||||
, levels(has_accuracy_parameter && !params.empty() ? Array(params.begin() + 1, params.end()) : params, returns_many)
|
||||
, level(levels.levels[0])
|
||||
, argument_type(this->argument_types[0])
|
||||
{
|
||||
if (!returns_many && levels.size() > 1)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires one level parameter or less", getName());
|
||||
|
||||
if constexpr (is_quantile_gk)
|
||||
if constexpr (has_accuracy_parameter)
|
||||
{
|
||||
if (params.empty())
|
||||
throw Exception(
|
||||
@ -115,7 +116,7 @@ public:
|
||||
|
||||
void create(AggregateDataPtr __restrict place) const override /// NOLINT
|
||||
{
|
||||
if constexpr (is_quantile_gk)
|
||||
if constexpr (has_accuracy_parameter)
|
||||
new (place) Data(accuracy);
|
||||
else
|
||||
new (place) Data;
|
||||
|
@ -1,71 +0,0 @@
|
||||
#include <AggregateFunctions/AggregateFunctionQuantile.h>
|
||||
#include <AggregateFunctions/QuantileApprox.h>
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <DataTypes/DataTypeDate.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
#include <Core/Field.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename Value, bool _> using FuncQuantileGK = AggregateFunctionQuantile<Value, QuantileGK<Value>, NameQuantileGK, false, void, false>;
|
||||
template <typename Value, bool _> using FuncQuantilesGK = AggregateFunctionQuantile<Value, QuantileGK<Value>, NameQuantilesGK, false, void, true>;
|
||||
|
||||
template <template <typename, bool> class Function>
|
||||
AggregateFunctionPtr createAggregateFunctionQuantile(
|
||||
const std::string & name, const DataTypes & argument_types, const Array & params, const Settings *)
|
||||
{
|
||||
/// Second argument type check doesn't depend on the type of the first one.
|
||||
Function<void, true>::assertSecondArg(argument_types);
|
||||
|
||||
const DataTypePtr & argument_type = argument_types[0];
|
||||
WhichDataType which(argument_type);
|
||||
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) \
|
||||
return std::make_shared<Function<TYPE, true>>(argument_types, params);
|
||||
FOR_BASIC_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
|
||||
if (which.idx == TypeIndex::Date) return std::make_shared<Function<DataTypeDate::FieldType, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::DateTime) return std::make_shared<Function<DataTypeDateTime::FieldType, false>>(argument_types, params);
|
||||
|
||||
if (which.idx == TypeIndex::Decimal32) return std::make_shared<Function<Decimal32, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::Decimal64) return std::make_shared<Function<Decimal64, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::Decimal128) return std::make_shared<Function<Decimal128, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::Decimal256) return std::make_shared<Function<Decimal256, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::DateTime64) return std::make_shared<Function<DateTime64, false>>(argument_types, params);
|
||||
|
||||
if (which.idx == TypeIndex::Int128) return std::make_shared<Function<Int128, true>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::UInt128) return std::make_shared<Function<UInt128, true>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::Int256) return std::make_shared<Function<Int256, true>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::UInt256) return std::make_shared<Function<UInt256, true>>(argument_types, params);
|
||||
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}",
|
||||
argument_type->getName(), name);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionsQuantileApprox(AggregateFunctionFactory & factory)
|
||||
{
|
||||
/// For aggregate functions returning array we cannot return NULL on empty set.
|
||||
AggregateFunctionProperties properties = { .returns_default_when_only_null = true };
|
||||
|
||||
factory.registerFunction(NameQuantileGK::name, createAggregateFunctionQuantile<FuncQuantileGK>);
|
||||
factory.registerFunction(NameQuantilesGK::name, {createAggregateFunctionQuantile<FuncQuantilesGK>, properties});
|
||||
|
||||
/// 'median' is an alias for 'quantile'
|
||||
factory.registerAlias("medianGK", NameQuantileGK::name);
|
||||
}
|
||||
|
||||
}
|
@ -19,8 +19,8 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename Value, bool float_return> using FuncQuantileBFloat16 = AggregateFunctionQuantile<Value, QuantileBFloat16Histogram<Value>, NameQuantileBFloat16, false, std::conditional_t<float_return, Float64, void>, false>;
|
||||
template <typename Value, bool float_return> using FuncQuantilesBFloat16 = AggregateFunctionQuantile<Value, QuantileBFloat16Histogram<Value>, NameQuantilesBFloat16, false, std::conditional_t<float_return, Float64, void>, true>;
|
||||
template <typename Value, bool float_return> using FuncQuantileBFloat16 = AggregateFunctionQuantile<Value, QuantileBFloat16Histogram<Value>, NameQuantileBFloat16, false, std::conditional_t<float_return, Float64, void>, false, false>;
|
||||
template <typename Value, bool float_return> using FuncQuantilesBFloat16 = AggregateFunctionQuantile<Value, QuantileBFloat16Histogram<Value>, NameQuantilesBFloat16, false, std::conditional_t<float_return, Float64, void>, true, false>;
|
||||
|
||||
template <template <typename, bool> class Function>
|
||||
AggregateFunctionPtr createAggregateFunctionQuantile(
|
||||
|
@ -19,8 +19,8 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename Value, bool float_return> using FuncQuantileBFloat16Weighted = AggregateFunctionQuantile<Value, QuantileBFloat16Histogram<Value>, NameQuantileBFloat16Weighted, true, std::conditional_t<float_return, Float64, void>, false>;
|
||||
template <typename Value, bool float_return> using FuncQuantilesBFloat16Weighted = AggregateFunctionQuantile<Value, QuantileBFloat16Histogram<Value>, NameQuantilesBFloat16Weighted, true, std::conditional_t<float_return, Float64, void>, true>;
|
||||
template <typename Value, bool float_return> using FuncQuantileBFloat16Weighted = AggregateFunctionQuantile<Value, QuantileBFloat16Histogram<Value>, NameQuantileBFloat16Weighted, true, std::conditional_t<float_return, Float64, void>, false, false>;
|
||||
template <typename Value, bool float_return> using FuncQuantilesBFloat16Weighted = AggregateFunctionQuantile<Value, QuantileBFloat16Histogram<Value>, NameQuantilesBFloat16Weighted, true, std::conditional_t<float_return, Float64, void>, true, false>;
|
||||
|
||||
template <template <typename, bool> class Function>
|
||||
AggregateFunctionPtr createAggregateFunctionQuantile(
|
||||
|
@ -1,5 +1,5 @@
|
||||
#include <AggregateFunctions/AggregateFunctionQuantile.h>
|
||||
#include <AggregateFunctions/QuantileReservoirSamplerDeterministic.h>
|
||||
#include <AggregateFunctions/ReservoirSamplerDeterministic.h>
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <DataTypes/DataTypeDate.h>
|
||||
@ -9,18 +9,108 @@
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename Value, bool float_return> using FuncQuantileDeterministic = AggregateFunctionQuantile<Value, QuantileReservoirSamplerDeterministic<Value>, NameQuantileDeterministic, true, std::conditional_t<float_return, Float64, void>, false>;
|
||||
template <typename Value, bool float_return> using FuncQuantilesDeterministic = AggregateFunctionQuantile<Value, QuantileReservoirSamplerDeterministic<Value>, NameQuantilesDeterministic, true, std::conditional_t<float_return, Float64, void>, true>;
|
||||
/** Quantile calculation with "reservoir sample" algorithm.
|
||||
* It collects pseudorandom subset of limited size from a stream of values,
|
||||
* and approximate quantile from it.
|
||||
* The function accept second argument, named "determinator"
|
||||
* and a hash function from it is calculated and used as a source for randomness
|
||||
* to apply random sampling.
|
||||
* The function is deterministic, but care should be taken with choose of "determinator" argument.
|
||||
*/
|
||||
template <typename Value>
|
||||
struct QuantileReservoirSamplerDeterministic
|
||||
{
|
||||
using Data = ReservoirSamplerDeterministic<Value, ReservoirSamplerDeterministicOnEmpty::RETURN_NAN_OR_ZERO>;
|
||||
Data data;
|
||||
|
||||
void add(const Value &)
|
||||
{
|
||||
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Method add without determinator is not implemented for ReservoirSamplerDeterministic");
|
||||
}
|
||||
|
||||
template <typename Determinator>
|
||||
void add(const Value & x, const Determinator & determinator)
|
||||
{
|
||||
data.insert(x, determinator);
|
||||
}
|
||||
|
||||
void merge(const QuantileReservoirSamplerDeterministic & rhs)
|
||||
{
|
||||
data.merge(rhs.data);
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
data.write(buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
data.read(buf);
|
||||
}
|
||||
|
||||
/// Get the value of the `level` quantile. The level must be between 0 and 1.
|
||||
Value get(Float64 level)
|
||||
{
|
||||
if (data.empty())
|
||||
return {};
|
||||
|
||||
if constexpr (is_decimal<Value>)
|
||||
return static_cast<typename Value::NativeType>(data.quantileInterpolated(level));
|
||||
else
|
||||
return static_cast<Value>(data.quantileInterpolated(level));
|
||||
}
|
||||
|
||||
/// Get the `size` values of `levels` quantiles. Write `size` results starting with `result` address.
|
||||
/// indices - an array of index levels such that the corresponding elements will go in ascending order.
|
||||
void getMany(const Float64 * levels, const size_t * indices, size_t size, Value * result)
|
||||
{
|
||||
bool is_empty = data.empty();
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
if (is_empty)
|
||||
{
|
||||
result[i] = Value{};
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr (is_decimal<Value>)
|
||||
result[indices[i]] = static_cast<typename Value::NativeType>(data.quantileInterpolated(levels[indices[i]]));
|
||||
else
|
||||
result[indices[i]] = static_cast<Value>(data.quantileInterpolated(levels[indices[i]]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The same, but in the case of an empty state, NaN is returned.
|
||||
Float64 getFloat(Float64 level)
|
||||
{
|
||||
return data.quantileInterpolated(level);
|
||||
}
|
||||
|
||||
void getManyFloat(const Float64 * levels, const size_t * indices, size_t size, Float64 * result)
|
||||
{
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
result[indices[i]] = data.quantileInterpolated(levels[indices[i]]);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename Value, bool float_return> using FuncQuantileDeterministic = AggregateFunctionQuantile<Value, QuantileReservoirSamplerDeterministic<Value>, NameQuantileDeterministic, true, std::conditional_t<float_return, Float64, void>, false, false>;
|
||||
template <typename Value, bool float_return> using FuncQuantilesDeterministic = AggregateFunctionQuantile<Value, QuantileReservoirSamplerDeterministic<Value>, NameQuantilesDeterministic, true, std::conditional_t<float_return, Float64, void>, true, false>;
|
||||
|
||||
template <template <typename, bool> class Function>
|
||||
AggregateFunctionPtr createAggregateFunctionQuantile(
|
||||
|
@ -19,8 +19,8 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename Value, bool _> using FuncQuantileExact = AggregateFunctionQuantile<Value, QuantileExact<Value>, NameQuantileExact, false, void, false>;
|
||||
template <typename Value, bool _> using FuncQuantilesExact = AggregateFunctionQuantile<Value, QuantileExact<Value>, NameQuantilesExact, false, void, true>;
|
||||
template <typename Value, bool _> using FuncQuantileExact = AggregateFunctionQuantile<Value, QuantileExact<Value>, NameQuantileExact, false, void, false, false>;
|
||||
template <typename Value, bool _> using FuncQuantilesExact = AggregateFunctionQuantile<Value, QuantileExact<Value>, NameQuantilesExact, false, void, true, false>;
|
||||
|
||||
|
||||
template <template <typename, bool> class Function>
|
||||
|
@ -19,8 +19,8 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename Value, bool _> using FuncQuantileExactExclusive = AggregateFunctionQuantile<Value, QuantileExactExclusive<Value>, NameQuantileExactExclusive, false, Float64, false>;
|
||||
template <typename Value, bool _> using FuncQuantilesExactExclusive = AggregateFunctionQuantile<Value, QuantileExactExclusive<Value>, NameQuantilesExactExclusive, false, Float64, true>;
|
||||
template <typename Value, bool _> using FuncQuantileExactExclusive = AggregateFunctionQuantile<Value, QuantileExactExclusive<Value>, NameQuantileExactExclusive, false, Float64, false, false>;
|
||||
template <typename Value, bool _> using FuncQuantilesExactExclusive = AggregateFunctionQuantile<Value, QuantileExactExclusive<Value>, NameQuantilesExactExclusive, false, Float64, true, false>;
|
||||
|
||||
template <template <typename, bool> class Function>
|
||||
AggregateFunctionPtr createAggregateFunctionQuantile(
|
||||
|
@ -19,8 +19,8 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename Value, bool _> using FuncQuantileExactHigh = AggregateFunctionQuantile<Value, QuantileExactHigh<Value>, NameQuantileExactHigh, false, void, false>;
|
||||
template <typename Value, bool _> using FuncQuantilesExactHigh = AggregateFunctionQuantile<Value, QuantileExactHigh<Value>, NameQuantilesExactHigh, false, void, true>;
|
||||
template <typename Value, bool _> using FuncQuantileExactHigh = AggregateFunctionQuantile<Value, QuantileExactHigh<Value>, NameQuantileExactHigh, false, void, false, false>;
|
||||
template <typename Value, bool _> using FuncQuantilesExactHigh = AggregateFunctionQuantile<Value, QuantileExactHigh<Value>, NameQuantilesExactHigh, false, void, true, false>;
|
||||
|
||||
|
||||
template <template <typename, bool> class Function>
|
||||
|
@ -19,8 +19,8 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename Value, bool _> using FuncQuantileExactInclusive = AggregateFunctionQuantile<Value, QuantileExactInclusive<Value>, NameQuantileExactInclusive, false, Float64, false>;
|
||||
template <typename Value, bool _> using FuncQuantilesExactInclusive = AggregateFunctionQuantile<Value, QuantileExactInclusive<Value>, NameQuantilesExactInclusive, false, Float64, true>;
|
||||
template <typename Value, bool _> using FuncQuantileExactInclusive = AggregateFunctionQuantile<Value, QuantileExactInclusive<Value>, NameQuantileExactInclusive, false, Float64, false, false>;
|
||||
template <typename Value, bool _> using FuncQuantilesExactInclusive = AggregateFunctionQuantile<Value, QuantileExactInclusive<Value>, NameQuantilesExactInclusive, false, Float64, true, false>;
|
||||
|
||||
template <template <typename, bool> class Function>
|
||||
AggregateFunctionPtr createAggregateFunctionQuantile(
|
||||
|
@ -19,8 +19,8 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename Value, bool _> using FuncQuantileExactLow = AggregateFunctionQuantile<Value, QuantileExactLow<Value>, NameQuantileExactLow, false, void, false>;
|
||||
template <typename Value, bool _> using FuncQuantilesExactLow = AggregateFunctionQuantile<Value, QuantileExactLow<Value>, NameQuantilesExactLow, false, void, true>;
|
||||
template <typename Value, bool _> using FuncQuantileExactLow = AggregateFunctionQuantile<Value, QuantileExactLow<Value>, NameQuantileExactLow, false, void, false, false>;
|
||||
template <typename Value, bool _> using FuncQuantilesExactLow = AggregateFunctionQuantile<Value, QuantileExactLow<Value>, NameQuantilesExactLow, false, void, true, false>;
|
||||
|
||||
template <template <typename, bool> class Function>
|
||||
AggregateFunctionPtr createAggregateFunctionQuantile(
|
||||
|
@ -1,26 +1,216 @@
|
||||
#include <AggregateFunctions/AggregateFunctionQuantile.h>
|
||||
#include <AggregateFunctions/QuantileExactWeighted.h>
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <DataTypes/DataTypeDate.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
#include <Core/Field.h>
|
||||
|
||||
#include <Common/HashTable/HashMap.h>
|
||||
#include <Common/NaNUtils.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename Value, bool _> using FuncQuantileExactWeighted = AggregateFunctionQuantile<Value, QuantileExactWeighted<Value>, NameQuantileExactWeighted, true, void, false>;
|
||||
template <typename Value, bool _> using FuncQuantilesExactWeighted = AggregateFunctionQuantile<Value, QuantileExactWeighted<Value>, NameQuantilesExactWeighted, true, void, true>;
|
||||
/** Calculates quantile by counting number of occurrences for each value in a hash map.
|
||||
*
|
||||
* It uses O(distinct(N)) memory. Can be naturally applied for values with weight.
|
||||
* In case of many identical values, it can be more efficient than QuantileExact even when weight is not used.
|
||||
*/
|
||||
template <typename Value>
|
||||
struct QuantileExactWeighted
|
||||
{
|
||||
struct Int128Hash
|
||||
{
|
||||
size_t operator()(Int128 x) const
|
||||
{
|
||||
return CityHash_v1_0_2::Hash128to64({x >> 64, x & 0xffffffffffffffffll});
|
||||
}
|
||||
};
|
||||
|
||||
using Weight = UInt64;
|
||||
using UnderlyingType = NativeType<Value>;
|
||||
using Hasher = HashCRC32<UnderlyingType>;
|
||||
|
||||
/// When creating, the hash table must be small.
|
||||
using Map = HashMapWithStackMemory<UnderlyingType, Weight, Hasher, 4>;
|
||||
|
||||
Map map;
|
||||
|
||||
void add(const Value & x)
|
||||
{
|
||||
/// We must skip NaNs as they are not compatible with comparison sorting.
|
||||
if (!isNaN(x))
|
||||
++map[x];
|
||||
}
|
||||
|
||||
void add(const Value & x, Weight weight)
|
||||
{
|
||||
if (!isNaN(x))
|
||||
map[x] += weight;
|
||||
}
|
||||
|
||||
void merge(const QuantileExactWeighted & rhs)
|
||||
{
|
||||
for (const auto & pair : rhs.map)
|
||||
map[pair.getKey()] += pair.getMapped();
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
map.write(buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
typename Map::Reader reader(buf);
|
||||
while (reader.next())
|
||||
{
|
||||
const auto & pair = reader.get();
|
||||
map[pair.first] = pair.second;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the value of the `level` quantile. The level must be between 0 and 1.
|
||||
Value get(Float64 level) const
|
||||
{
|
||||
size_t size = map.size();
|
||||
|
||||
if (0 == size)
|
||||
return std::numeric_limits<Value>::quiet_NaN();
|
||||
|
||||
/// Copy the data to a temporary array to get the element you need in order.
|
||||
using Pair = typename Map::value_type;
|
||||
std::unique_ptr<Pair[]> array_holder(new Pair[size]);
|
||||
Pair * array = array_holder.get();
|
||||
|
||||
/// Note: 64-bit integer weight can overflow.
|
||||
/// We do some implementation specific behaviour (return approximate or garbage results).
|
||||
/// Float64 is used as accumulator here to get approximate results.
|
||||
/// But weight can be already overflowed in computations in 'add' and 'merge' methods.
|
||||
/// It will be reasonable to change the type of weight to Float64 in the map,
|
||||
/// but we don't do that for compatibility of serialized data.
|
||||
|
||||
size_t i = 0;
|
||||
Float64 sum_weight = 0;
|
||||
for (const auto & pair : map)
|
||||
{
|
||||
sum_weight += pair.getMapped();
|
||||
array[i] = pair.getValue();
|
||||
++i;
|
||||
}
|
||||
|
||||
::sort(array, array + size, [](const Pair & a, const Pair & b) { return a.first < b.first; });
|
||||
|
||||
Float64 threshold = std::ceil(sum_weight * level);
|
||||
Float64 accumulated = 0;
|
||||
|
||||
const Pair * it = array;
|
||||
const Pair * end = array + size;
|
||||
while (it < end)
|
||||
{
|
||||
accumulated += it->second;
|
||||
|
||||
if (accumulated >= threshold)
|
||||
break;
|
||||
|
||||
++it;
|
||||
}
|
||||
|
||||
if (it == end)
|
||||
--it;
|
||||
|
||||
return it->first;
|
||||
}
|
||||
|
||||
/// Get the `size` values of `levels` quantiles. Write `size` results starting with `result` address.
|
||||
/// indices - an array of index levels such that the corresponding elements will go in ascending order.
|
||||
void getMany(const Float64 * levels, const size_t * indices, size_t num_levels, Value * result) const
|
||||
{
|
||||
size_t size = map.size();
|
||||
|
||||
if (0 == size)
|
||||
{
|
||||
for (size_t i = 0; i < num_levels; ++i)
|
||||
result[i] = Value();
|
||||
return;
|
||||
}
|
||||
|
||||
/// Copy the data to a temporary array to get the element you need in order.
|
||||
using Pair = typename Map::value_type;
|
||||
std::unique_ptr<Pair[]> array_holder(new Pair[size]);
|
||||
Pair * array = array_holder.get();
|
||||
|
||||
size_t i = 0;
|
||||
Float64 sum_weight = 0;
|
||||
for (const auto & pair : map)
|
||||
{
|
||||
sum_weight += pair.getMapped();
|
||||
array[i] = pair.getValue();
|
||||
++i;
|
||||
}
|
||||
|
||||
::sort(array, array + size, [](const Pair & a, const Pair & b) { return a.first < b.first; });
|
||||
|
||||
Float64 accumulated = 0;
|
||||
|
||||
const Pair * it = array;
|
||||
const Pair * end = array + size;
|
||||
|
||||
size_t level_index = 0;
|
||||
Float64 threshold = std::ceil(sum_weight * levels[indices[level_index]]);
|
||||
|
||||
while (it < end)
|
||||
{
|
||||
accumulated += it->second;
|
||||
|
||||
while (accumulated >= threshold)
|
||||
{
|
||||
result[indices[level_index]] = it->first;
|
||||
++level_index;
|
||||
|
||||
if (level_index == num_levels)
|
||||
return;
|
||||
|
||||
threshold = std::ceil(sum_weight * levels[indices[level_index]]);
|
||||
}
|
||||
|
||||
++it;
|
||||
}
|
||||
|
||||
while (level_index < num_levels)
|
||||
{
|
||||
result[indices[level_index]] = array[size - 1].first;
|
||||
++level_index;
|
||||
}
|
||||
}
|
||||
|
||||
/// The same, but in the case of an empty state, NaN is returned.
|
||||
Float64 getFloat(Float64) const
|
||||
{
|
||||
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Method getFloat is not implemented for QuantileExact");
|
||||
}
|
||||
|
||||
void getManyFloat(const Float64 *, const size_t *, size_t, Float64 *) const
|
||||
{
|
||||
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Method getManyFloat is not implemented for QuantileExact");
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename Value, bool _> using FuncQuantileExactWeighted = AggregateFunctionQuantile<Value, QuantileExactWeighted<Value>, NameQuantileExactWeighted, true, void, false, false>;
|
||||
template <typename Value, bool _> using FuncQuantilesExactWeighted = AggregateFunctionQuantile<Value, QuantileExactWeighted<Value>, NameQuantilesExactWeighted, true, void, true, false>;
|
||||
|
||||
template <template <typename, bool> class Function>
|
||||
AggregateFunctionPtr createAggregateFunctionQuantile(
|
||||
|
@ -1,22 +1,29 @@
|
||||
#pragma once
|
||||
|
||||
#include <AggregateFunctions/AggregateFunctionQuantile.h>
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <DataTypes/DataTypeDate.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
#include <Core/Field.h>
|
||||
#include <cmath>
|
||||
#include <base/sort.h>
|
||||
#include <Common/RadixSort.h>
|
||||
#include <IO/WriteBuffer.h>
|
||||
#include <IO/ReadBuffer.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int LOGICAL_ERROR;
|
||||
extern const int NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
class ApproxSampler
|
||||
{
|
||||
@ -474,4 +481,56 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename Value, bool _> using FuncQuantileGK = AggregateFunctionQuantile<Value, QuantileGK<Value>, NameQuantileGK, false, void, false, true>;
|
||||
template <typename Value, bool _> using FuncQuantilesGK = AggregateFunctionQuantile<Value, QuantileGK<Value>, NameQuantilesGK, false, void, true, true>;
|
||||
|
||||
template <template <typename, bool> class Function>
|
||||
AggregateFunctionPtr createAggregateFunctionQuantile(
|
||||
const std::string & name, const DataTypes & argument_types, const Array & params, const Settings *)
|
||||
{
|
||||
/// Second argument type check doesn't depend on the type of the first one.
|
||||
Function<void, true>::assertSecondArg(argument_types);
|
||||
|
||||
const DataTypePtr & argument_type = argument_types[0];
|
||||
WhichDataType which(argument_type);
|
||||
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) \
|
||||
return std::make_shared<Function<TYPE, true>>(argument_types, params);
|
||||
FOR_BASIC_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
|
||||
if (which.idx == TypeIndex::Date) return std::make_shared<Function<DataTypeDate::FieldType, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::DateTime) return std::make_shared<Function<DataTypeDateTime::FieldType, false>>(argument_types, params);
|
||||
|
||||
if (which.idx == TypeIndex::Decimal32) return std::make_shared<Function<Decimal32, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::Decimal64) return std::make_shared<Function<Decimal64, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::Decimal128) return std::make_shared<Function<Decimal128, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::Decimal256) return std::make_shared<Function<Decimal256, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::DateTime64) return std::make_shared<Function<DateTime64, false>>(argument_types, params);
|
||||
|
||||
if (which.idx == TypeIndex::Int128) return std::make_shared<Function<Int128, true>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::UInt128) return std::make_shared<Function<UInt128, true>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::Int256) return std::make_shared<Function<Int256, true>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::UInt256) return std::make_shared<Function<UInt256, true>>(argument_types, params);
|
||||
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}",
|
||||
argument_type->getName(), name);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionsQuantileApprox(AggregateFunctionFactory & factory)
|
||||
{
|
||||
/// For aggregate functions returning array we cannot return NULL on empty set.
|
||||
AggregateFunctionProperties properties = { .returns_default_when_only_null = true };
|
||||
|
||||
factory.registerFunction(NameQuantileGK::name, createAggregateFunctionQuantile<FuncQuantileGK>);
|
||||
factory.registerFunction(NameQuantilesGK::name, {createAggregateFunctionQuantile<FuncQuantilesGK>, properties});
|
||||
|
||||
/// 'median' is an alias for 'quantile'
|
||||
factory.registerAlias("medianGK", NameQuantileGK::name);
|
||||
}
|
||||
|
||||
}
|
@ -1,58 +1,353 @@
|
||||
#include <AggregateFunctions/AggregateFunctionQuantile.h>
|
||||
#include <AggregateFunctions/QuantileInterpolatedWeighted.h>
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <DataTypes/DataTypeDate.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
#include <Core/Field.h>
|
||||
#include <Common/HashTable/HashMap.h>
|
||||
#include <Common/NaNUtils.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename Value, bool _> using FuncQuantileInterpolatedWeighted = AggregateFunctionQuantile<Value, QuantileInterpolatedWeighted<Value>, NameQuantileInterpolatedWeighted, true, void, false>;
|
||||
template <typename Value, bool _> using FuncQuantilesInterpolatedWeighted = AggregateFunctionQuantile<Value, QuantileInterpolatedWeighted<Value>, NameQuantilesInterpolatedWeighted, true, void, true>;
|
||||
|
||||
template <template <typename, bool> class Function>
|
||||
AggregateFunctionPtr createAggregateFunctionQuantile(
|
||||
const std::string & name, const DataTypes & argument_types, const Array & params, const Settings *)
|
||||
/** Approximates Quantile by:
|
||||
* - sorting input values and weights
|
||||
* - building a cumulative distribution based on weights
|
||||
* - performing linear interpolation between the weights and values
|
||||
*
|
||||
*/
|
||||
template <typename Value>
|
||||
struct QuantileInterpolatedWeighted
|
||||
{
|
||||
struct Int128Hash
|
||||
{
|
||||
/// Second argument type check doesn't depend on the type of the first one.
|
||||
Function<void, true>::assertSecondArg(argument_types);
|
||||
size_t operator()(Int128 x) const
|
||||
{
|
||||
return CityHash_v1_0_2::Hash128to64({x >> 64, x & 0xffffffffffffffffll});
|
||||
}
|
||||
};
|
||||
|
||||
const DataTypePtr & argument_type = argument_types[0];
|
||||
WhichDataType which(argument_type);
|
||||
using Weight = UInt64;
|
||||
using UnderlyingType = NativeType<Value>;
|
||||
using Hasher = HashCRC32<UnderlyingType>;
|
||||
|
||||
/// When creating, the hash table must be small.
|
||||
using Map = HashMapWithStackMemory<UnderlyingType, Weight, Hasher, 4>;
|
||||
|
||||
Map map;
|
||||
|
||||
void add(const Value & x)
|
||||
{
|
||||
/// We must skip NaNs as they are not compatible with comparison sorting.
|
||||
if (!isNaN(x))
|
||||
++map[x];
|
||||
}
|
||||
|
||||
void add(const Value & x, Weight weight)
|
||||
{
|
||||
if (!isNaN(x))
|
||||
map[x] += weight;
|
||||
}
|
||||
|
||||
void merge(const QuantileInterpolatedWeighted & rhs)
|
||||
{
|
||||
for (const auto & pair : rhs.map)
|
||||
map[pair.getKey()] += pair.getMapped();
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
map.write(buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
typename Map::Reader reader(buf);
|
||||
while (reader.next())
|
||||
{
|
||||
const auto & pair = reader.get();
|
||||
map[pair.first] = pair.second;
|
||||
}
|
||||
}
|
||||
|
||||
Value get(Float64 level) const
|
||||
{
|
||||
return getImpl<Value>(level);
|
||||
}
|
||||
|
||||
void getMany(const Float64 * levels, const size_t * indices, size_t size, Value * result) const
|
||||
{
|
||||
getManyImpl<Value>(levels, indices, size, result);
|
||||
}
|
||||
|
||||
/// The same, but in the case of an empty state, NaN is returned.
|
||||
Float64 getFloat(Float64) const
|
||||
{
|
||||
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Method getFloat is not implemented for QuantileInterpolatedWeighted");
|
||||
}
|
||||
|
||||
void getManyFloat(const Float64 *, const size_t *, size_t, Float64 *) const
|
||||
{
|
||||
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Method getManyFloat is not implemented for QuantileInterpolatedWeighted");
|
||||
}
|
||||
|
||||
private:
|
||||
using Pair = typename std::pair<UnderlyingType, Float64>;
|
||||
|
||||
/// Get the value of the `level` quantile. The level must be between 0 and 1.
|
||||
template <typename T>
|
||||
T getImpl(Float64 level) const
|
||||
{
|
||||
size_t size = map.size();
|
||||
|
||||
if (0 == size)
|
||||
return std::numeric_limits<Value>::quiet_NaN();
|
||||
|
||||
/// Maintain a vector of pair of values and weights for easier sorting and for building
|
||||
/// a cumulative distribution using the provided weights.
|
||||
std::vector<Pair> value_weight_pairs;
|
||||
value_weight_pairs.reserve(size);
|
||||
|
||||
/// Note: weight provided must be a 64-bit integer
|
||||
/// Float64 is used as accumulator here to get approximate results.
|
||||
/// But weight used in the internal array is stored as Float64 as we
|
||||
/// do some quantile estimation operation which involves division and
|
||||
/// require Float64 level of precision.
|
||||
|
||||
Float64 sum_weight = 0;
|
||||
for (const auto & pair : map)
|
||||
{
|
||||
sum_weight += pair.getMapped();
|
||||
auto value = pair.getKey();
|
||||
auto weight = pair.getMapped();
|
||||
value_weight_pairs.push_back({value, weight});
|
||||
}
|
||||
|
||||
::sort(value_weight_pairs.begin(), value_weight_pairs.end(), [](const Pair & a, const Pair & b) { return a.first < b.first; });
|
||||
|
||||
Float64 accumulated = 0;
|
||||
|
||||
/// vector for populating and storing the cumulative sum using the provided weights.
|
||||
/// example: [0,1,2,3,4,5] -> [0,1,3,6,10,15]
|
||||
std::vector<Float64> weights_cum_sum;
|
||||
weights_cum_sum.reserve(size);
|
||||
|
||||
for (size_t idx = 0; idx < size; ++idx)
|
||||
{
|
||||
accumulated += value_weight_pairs[idx].second;
|
||||
weights_cum_sum.push_back(accumulated);
|
||||
}
|
||||
|
||||
/// The following estimation of quantile is general and the idea is:
|
||||
/// https://en.wikipedia.org/wiki/Percentile#The_weighted_percentile_method
|
||||
|
||||
/// calculates a simple cumulative distribution based on weights
|
||||
if (sum_weight != 0)
|
||||
{
|
||||
for (size_t idx = 0; idx < size; ++idx)
|
||||
value_weight_pairs[idx].second = (weights_cum_sum[idx] - 0.5 * value_weight_pairs[idx].second) / sum_weight;
|
||||
}
|
||||
|
||||
/// perform linear interpolation
|
||||
size_t idx = 0;
|
||||
if (size >= 2)
|
||||
{
|
||||
if (level >= value_weight_pairs[size - 2].second)
|
||||
{
|
||||
idx = size - 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
size_t start = 0, end = size - 1;
|
||||
while (start <= end)
|
||||
{
|
||||
size_t mid = start + (end - start) / 2;
|
||||
if (mid > size)
|
||||
break;
|
||||
if (level > value_weight_pairs[mid + 1].second)
|
||||
start = mid + 1;
|
||||
else
|
||||
{
|
||||
idx = mid;
|
||||
end = mid - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t l = idx;
|
||||
size_t u = idx + 1 < size ? idx + 1 : idx;
|
||||
|
||||
Float64 xl = value_weight_pairs[l].second, xr = value_weight_pairs[u].second;
|
||||
UnderlyingType yl = value_weight_pairs[l].first, yr = value_weight_pairs[u].first;
|
||||
|
||||
if (level < xl)
|
||||
yr = yl;
|
||||
if (level > xr)
|
||||
yl = yr;
|
||||
|
||||
return static_cast<T>(interpolate(level, xl, xr, yl, yr));
|
||||
}
|
||||
|
||||
/// Get the `size` values of `levels` quantiles. Write `size` results starting with `result` address.
|
||||
/// indices - an array of index levels such that the corresponding elements will go in ascending order.
|
||||
template <typename T>
|
||||
void getManyImpl(const Float64 * levels, const size_t * indices, size_t num_levels, Value * result) const
|
||||
{
|
||||
size_t size = map.size();
|
||||
|
||||
if (0 == size)
|
||||
{
|
||||
for (size_t i = 0; i < num_levels; ++i)
|
||||
result[i] = Value();
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<Pair> value_weight_pairs;
|
||||
value_weight_pairs.reserve(size);
|
||||
|
||||
Float64 sum_weight = 0;
|
||||
for (const auto & pair : map)
|
||||
{
|
||||
sum_weight += pair.getMapped();
|
||||
auto value = pair.getKey();
|
||||
auto weight = pair.getMapped();
|
||||
value_weight_pairs.push_back({value, weight});
|
||||
}
|
||||
|
||||
::sort(value_weight_pairs.begin(), value_weight_pairs.end(), [](const Pair & a, const Pair & b) { return a.first < b.first; });
|
||||
|
||||
Float64 accumulated = 0;
|
||||
|
||||
/// vector for populating and storing the cumulative sum using the provided weights.
|
||||
/// example: [0,1,2,3,4,5] -> [0,1,3,6,10,15]
|
||||
std::vector<Float64> weights_cum_sum;
|
||||
weights_cum_sum.reserve(size);
|
||||
|
||||
for (size_t idx = 0; idx < size; ++idx)
|
||||
{
|
||||
accumulated += value_weight_pairs[idx].second;
|
||||
weights_cum_sum.emplace_back(accumulated);
|
||||
}
|
||||
|
||||
|
||||
/// The following estimation of quantile is general and the idea is:
|
||||
/// https://en.wikipedia.org/wiki/Percentile#The_weighted_percentile_method
|
||||
|
||||
/// calculates a simple cumulative distribution based on weights
|
||||
if (sum_weight != 0)
|
||||
{
|
||||
for (size_t idx = 0; idx < size; ++idx)
|
||||
value_weight_pairs[idx].second = (weights_cum_sum[idx] - 0.5 * value_weight_pairs[idx].second) / sum_weight;
|
||||
}
|
||||
|
||||
for (size_t level_index = 0; level_index < num_levels; ++level_index)
|
||||
{
|
||||
/// perform linear interpolation for every level
|
||||
auto level = levels[indices[level_index]];
|
||||
|
||||
size_t idx = 0;
|
||||
if (size >= 2)
|
||||
{
|
||||
if (level >= value_weight_pairs[size - 2].second)
|
||||
{
|
||||
idx = size - 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
size_t start = 0, end = size - 1;
|
||||
while (start <= end)
|
||||
{
|
||||
size_t mid = start + (end - start) / 2;
|
||||
if (mid > size)
|
||||
break;
|
||||
if (level > value_weight_pairs[mid + 1].second)
|
||||
start = mid + 1;
|
||||
else
|
||||
{
|
||||
idx = mid;
|
||||
end = mid - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t l = idx;
|
||||
size_t u = idx + 1 < size ? idx + 1 : idx;
|
||||
|
||||
Float64 xl = value_weight_pairs[l].second, xr = value_weight_pairs[u].second;
|
||||
UnderlyingType yl = value_weight_pairs[l].first, yr = value_weight_pairs[u].first;
|
||||
|
||||
if (level < xl)
|
||||
yr = yl;
|
||||
if (level > xr)
|
||||
yl = yr;
|
||||
|
||||
result[indices[level_index]] = static_cast<T>(interpolate(level, xl, xr, yl, yr));
|
||||
}
|
||||
}
|
||||
|
||||
/// This ignores overflows or NaN's that might arise during add, sub and mul operations and doesn't aim to provide exact
|
||||
/// results since `the quantileInterpolatedWeighted` function itself relies mainly on approximation.
|
||||
UnderlyingType NO_SANITIZE_UNDEFINED interpolate(Float64 level, Float64 xl, Float64 xr, UnderlyingType yl, UnderlyingType yr) const
|
||||
{
|
||||
UnderlyingType dy = yr - yl;
|
||||
Float64 dx = xr - xl;
|
||||
dx = dx == 0 ? 1 : dx; /// to handle NaN behavior that might arise during integer division below.
|
||||
|
||||
/// yl + (dy / dx) * (level - xl)
|
||||
return static_cast<UnderlyingType>(yl + (dy / dx) * (level - xl));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename Value, bool _> using FuncQuantileInterpolatedWeighted = AggregateFunctionQuantile<Value, QuantileInterpolatedWeighted<Value>, NameQuantileInterpolatedWeighted, true, void, false, false>;
|
||||
template <typename Value, bool _> using FuncQuantilesInterpolatedWeighted = AggregateFunctionQuantile<Value, QuantileInterpolatedWeighted<Value>, NameQuantilesInterpolatedWeighted, true, void, true, false>;
|
||||
|
||||
template <template <typename, bool> class Function>
|
||||
AggregateFunctionPtr createAggregateFunctionQuantile(
|
||||
const std::string & name, const DataTypes & argument_types, const Array & params, const Settings *)
|
||||
{
|
||||
/// Second argument type check doesn't depend on the type of the first one.
|
||||
Function<void, true>::assertSecondArg(argument_types);
|
||||
|
||||
const DataTypePtr & argument_type = argument_types[0];
|
||||
WhichDataType which(argument_type);
|
||||
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) return std::make_shared<Function<TYPE, true>>(argument_types, params);
|
||||
FOR_BASIC_NUMERIC_TYPES(DISPATCH)
|
||||
if (which.idx == TypeIndex::TYPE) return std::make_shared<Function<TYPE, true>>(argument_types, params);
|
||||
FOR_BASIC_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
if (which.idx == TypeIndex::Date) return std::make_shared<Function<DataTypeDate::FieldType, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::DateTime) return std::make_shared<Function<DataTypeDateTime::FieldType, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::Date) return std::make_shared<Function<DataTypeDate::FieldType, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::DateTime) return std::make_shared<Function<DataTypeDateTime::FieldType, false>>(argument_types, params);
|
||||
|
||||
if (which.idx == TypeIndex::Decimal32) return std::make_shared<Function<Decimal32, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::Decimal64) return std::make_shared<Function<Decimal64, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::Decimal128) return std::make_shared<Function<Decimal128, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::Decimal256) return std::make_shared<Function<Decimal256, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::DateTime64) return std::make_shared<Function<DateTime64, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::Decimal32) return std::make_shared<Function<Decimal32, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::Decimal64) return std::make_shared<Function<Decimal64, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::Decimal128) return std::make_shared<Function<Decimal128, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::Decimal256) return std::make_shared<Function<Decimal256, false>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::DateTime64) return std::make_shared<Function<DateTime64, false>>(argument_types, params);
|
||||
|
||||
if (which.idx == TypeIndex::Int128) return std::make_shared<Function<Int128, true>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::UInt128) return std::make_shared<Function<UInt128, true>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::Int256) return std::make_shared<Function<Int256, true>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::UInt256) return std::make_shared<Function<UInt256, true>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::Int128) return std::make_shared<Function<Int128, true>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::UInt128) return std::make_shared<Function<UInt128, true>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::Int256) return std::make_shared<Function<Int256, true>>(argument_types, params);
|
||||
if (which.idx == TypeIndex::UInt256) return std::make_shared<Function<UInt256, true>>(argument_types, params);
|
||||
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}",
|
||||
argument_type->getName(), name);
|
||||
}
|
||||
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}",
|
||||
argument_type->getName(), name);
|
||||
}
|
||||
}
|
||||
|
||||
void registerAggregateFunctionsQuantileInterpolatedWeighted(AggregateFunctionFactory & factory)
|
||||
|
@ -19,8 +19,8 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename Value, bool float_return> using FuncQuantileTDigest = AggregateFunctionQuantile<Value, QuantileTDigest<Value>, NameQuantileTDigest, false, std::conditional_t<float_return, Float32, void>, false>;
|
||||
template <typename Value, bool float_return> using FuncQuantilesTDigest = AggregateFunctionQuantile<Value, QuantileTDigest<Value>, NameQuantilesTDigest, false, std::conditional_t<float_return, Float32, void>, true>;
|
||||
template <typename Value, bool float_return> using FuncQuantileTDigest = AggregateFunctionQuantile<Value, QuantileTDigest<Value>, NameQuantileTDigest, false, std::conditional_t<float_return, Float32, void>, false, false>;
|
||||
template <typename Value, bool float_return> using FuncQuantilesTDigest = AggregateFunctionQuantile<Value, QuantileTDigest<Value>, NameQuantilesTDigest, false, std::conditional_t<float_return, Float32, void>, true, false>;
|
||||
|
||||
template <template <typename, bool> class Function>
|
||||
AggregateFunctionPtr createAggregateFunctionQuantile(
|
||||
|
@ -19,8 +19,8 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename Value, bool float_return> using FuncQuantileTDigestWeighted = AggregateFunctionQuantile<Value, QuantileTDigest<Value>, NameQuantileTDigestWeighted, true, std::conditional_t<float_return, Float32, void>, false>;
|
||||
template <typename Value, bool float_return> using FuncQuantilesTDigestWeighted = AggregateFunctionQuantile<Value, QuantileTDigest<Value>, NameQuantilesTDigestWeighted, true, std::conditional_t<float_return, Float32, void>, true>;
|
||||
template <typename Value, bool float_return> using FuncQuantileTDigestWeighted = AggregateFunctionQuantile<Value, QuantileTDigest<Value>, NameQuantileTDigestWeighted, true, std::conditional_t<float_return, Float32, void>, false, false>;
|
||||
template <typename Value, bool float_return> using FuncQuantilesTDigestWeighted = AggregateFunctionQuantile<Value, QuantileTDigest<Value>, NameQuantilesTDigestWeighted, true, std::conditional_t<float_return, Float32, void>, true, false>;
|
||||
|
||||
template <template <typename, bool> class Function>
|
||||
AggregateFunctionPtr createAggregateFunctionQuantile(
|
||||
|
@ -19,8 +19,8 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename Value, bool _> using FuncQuantileTiming = AggregateFunctionQuantile<Value, QuantileTiming<Value>, NameQuantileTiming, false, Float32, false>;
|
||||
template <typename Value, bool _> using FuncQuantilesTiming = AggregateFunctionQuantile<Value, QuantileTiming<Value>, NameQuantilesTiming, false, Float32, true>;
|
||||
template <typename Value, bool _> using FuncQuantileTiming = AggregateFunctionQuantile<Value, QuantileTiming<Value>, NameQuantileTiming, false, Float32, false, false>;
|
||||
template <typename Value, bool _> using FuncQuantilesTiming = AggregateFunctionQuantile<Value, QuantileTiming<Value>, NameQuantilesTiming, false, Float32, true, false>;
|
||||
|
||||
template <template <typename, bool> class Function>
|
||||
AggregateFunctionPtr createAggregateFunctionQuantile(
|
||||
|
@ -19,8 +19,8 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename Value, bool _> using FuncQuantileTimingWeighted = AggregateFunctionQuantile<Value, QuantileTiming<Value>, NameQuantileTimingWeighted, true, Float32, false>;
|
||||
template <typename Value, bool _> using FuncQuantilesTimingWeighted = AggregateFunctionQuantile<Value, QuantileTiming<Value>, NameQuantilesTimingWeighted, true, Float32, true>;
|
||||
template <typename Value, bool _> using FuncQuantileTimingWeighted = AggregateFunctionQuantile<Value, QuantileTiming<Value>, NameQuantileTimingWeighted, true, Float32, false, false>;
|
||||
template <typename Value, bool _> using FuncQuantilesTimingWeighted = AggregateFunctionQuantile<Value, QuantileTiming<Value>, NameQuantilesTimingWeighted, true, Float32, true, false>;
|
||||
|
||||
template <template <typename, bool> class Function>
|
||||
AggregateFunctionPtr createAggregateFunctionQuantile(
|
||||
|
@ -1,7 +1,13 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionRankCorrelation.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/StatCommon.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <Common/PODArray_fwd.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
|
||||
|
||||
namespace ErrorCodes
|
||||
@ -16,6 +22,83 @@ struct Settings;
|
||||
namespace
|
||||
{
|
||||
|
||||
struct RankCorrelationData : public StatisticalSample<Float64, Float64>
|
||||
{
|
||||
Float64 getResult()
|
||||
{
|
||||
RanksArray ranks_x;
|
||||
std::tie(ranks_x, std::ignore) = computeRanksAndTieCorrection(this->x);
|
||||
|
||||
RanksArray ranks_y;
|
||||
std::tie(ranks_y, std::ignore) = computeRanksAndTieCorrection(this->y);
|
||||
|
||||
/// Sizes can be non-equal due to skipped NaNs.
|
||||
const Float64 size = static_cast<Float64>(std::min(this->size_x, this->size_y));
|
||||
|
||||
/// Count d^2 sum
|
||||
Float64 answer = 0;
|
||||
for (size_t j = 0; j < size; ++j)
|
||||
answer += (ranks_x[j] - ranks_y[j]) * (ranks_x[j] - ranks_y[j]);
|
||||
|
||||
answer *= 6;
|
||||
answer /= size * (size * size - 1);
|
||||
answer = 1 - answer;
|
||||
return answer;
|
||||
}
|
||||
};
|
||||
|
||||
class AggregateFunctionRankCorrelation :
|
||||
public IAggregateFunctionDataHelper<RankCorrelationData, AggregateFunctionRankCorrelation>
|
||||
{
|
||||
public:
|
||||
explicit AggregateFunctionRankCorrelation(const DataTypes & arguments)
|
||||
: IAggregateFunctionDataHelper<RankCorrelationData, AggregateFunctionRankCorrelation> ({arguments}, {}, std::make_shared<DataTypeNumber<Float64>>())
|
||||
{}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return "rankCorr";
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return true; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
Float64 new_x = columns[0]->getFloat64(row_num);
|
||||
Float64 new_y = columns[1]->getFloat64(row_num);
|
||||
this->data(place).addX(new_x, arena);
|
||||
this->data(place).addY(new_y, arena);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
auto & a = this->data(place);
|
||||
const auto & b = this->data(rhs);
|
||||
|
||||
a.merge(b, arena);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
this->data(place).read(buf, arena);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto answer = this->data(place).getResult();
|
||||
|
||||
auto & column = static_cast<ColumnVector<Float64> &>(to);
|
||||
column.getData().push_back(answer);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionRankCorrelation(
|
||||
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
|
@ -1,98 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/StatCommon.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <Common/PODArray_fwd.h>
|
||||
#include <base/types.h>
|
||||
#include <DataTypes/DataTypesDecimal.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
|
||||
struct RankCorrelationData : public StatisticalSample<Float64, Float64>
|
||||
{
|
||||
Float64 getResult()
|
||||
{
|
||||
RanksArray ranks_x;
|
||||
std::tie(ranks_x, std::ignore) = computeRanksAndTieCorrection(this->x);
|
||||
|
||||
RanksArray ranks_y;
|
||||
std::tie(ranks_y, std::ignore) = computeRanksAndTieCorrection(this->y);
|
||||
|
||||
/// Sizes can be non-equal due to skipped NaNs.
|
||||
const Float64 size = static_cast<Float64>(std::min(this->size_x, this->size_y));
|
||||
|
||||
/// Count d^2 sum
|
||||
Float64 answer = 0;
|
||||
for (size_t j = 0; j < size; ++j)
|
||||
answer += (ranks_x[j] - ranks_y[j]) * (ranks_x[j] - ranks_y[j]);
|
||||
|
||||
answer *= 6;
|
||||
answer /= size * (size * size - 1);
|
||||
answer = 1 - answer;
|
||||
return answer;
|
||||
}
|
||||
};
|
||||
|
||||
class AggregateFunctionRankCorrelation :
|
||||
public IAggregateFunctionDataHelper<RankCorrelationData, AggregateFunctionRankCorrelation>
|
||||
{
|
||||
public:
|
||||
explicit AggregateFunctionRankCorrelation(const DataTypes & arguments)
|
||||
:IAggregateFunctionDataHelper<RankCorrelationData, AggregateFunctionRankCorrelation> ({arguments}, {}, std::make_shared<DataTypeNumber<Float64>>())
|
||||
{}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return "rankCorr";
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return true; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
Float64 new_x = columns[0]->getFloat64(row_num);
|
||||
Float64 new_y = columns[1]->getFloat64(row_num);
|
||||
this->data(place).addX(new_x, arena);
|
||||
this->data(place).addY(new_y, arena);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
auto & a = this->data(place);
|
||||
const auto & b = this->data(rhs);
|
||||
|
||||
a.merge(b, arena);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
this->data(place).read(buf, arena);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto answer = this->data(place).getResult();
|
||||
|
||||
auto & column = static_cast<ColumnVector<Float64> &>(to);
|
||||
column.getData().push_back(answer);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}
|
@ -1,21 +1,150 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionRetention.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
|
||||
#include <unordered_set>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <bitset>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
struct AggregateFunctionRetentionData
|
||||
{
|
||||
static constexpr auto max_events = 32;
|
||||
|
||||
using Events = std::bitset<max_events>;
|
||||
|
||||
Events events;
|
||||
|
||||
void add(UInt8 event)
|
||||
{
|
||||
events.set(event);
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionRetentionData & other)
|
||||
{
|
||||
events |= other.events;
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
UInt32 event_value = static_cast<UInt32>(events.to_ulong());
|
||||
writeBinary(event_value, buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
UInt32 event_value;
|
||||
readBinary(event_value, buf);
|
||||
events = event_value;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* The max size of events is 32, that's enough for retention analytics
|
||||
*
|
||||
* Usage:
|
||||
* - retention(cond1, cond2, cond3, ....)
|
||||
* - returns [cond1_flag, cond1_flag && cond2_flag, cond1_flag && cond3_flag, ...]
|
||||
*/
|
||||
class AggregateFunctionRetention final
|
||||
: public IAggregateFunctionDataHelper<AggregateFunctionRetentionData, AggregateFunctionRetention>
|
||||
{
|
||||
private:
|
||||
UInt8 events_size;
|
||||
|
||||
public:
|
||||
String getName() const override
|
||||
{
|
||||
return "retention";
|
||||
}
|
||||
|
||||
explicit AggregateFunctionRetention(const DataTypes & arguments)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionRetentionData, AggregateFunctionRetention>(arguments, {}, std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt8>()))
|
||||
{
|
||||
for (const auto i : collections::range(0, arguments.size()))
|
||||
{
|
||||
const auto * cond_arg = arguments[i].get();
|
||||
if (!isUInt8(cond_arg))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Illegal type {} of argument {} of aggregate function {}, must be UInt8",
|
||||
cond_arg->getName(), i, getName());
|
||||
}
|
||||
|
||||
events_size = static_cast<UInt8>(arguments.size());
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override
|
||||
{
|
||||
for (const auto i : collections::range(0, events_size))
|
||||
{
|
||||
auto event = assert_cast<const ColumnVector<UInt8> *>(columns[i])->getData()[row_num];
|
||||
if (event)
|
||||
{
|
||||
this->data(place).add(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto & data_to = assert_cast<ColumnUInt8 &>(assert_cast<ColumnArray &>(to).getData()).getData();
|
||||
auto & offsets_to = assert_cast<ColumnArray &>(to).getOffsets();
|
||||
|
||||
ColumnArray::Offset current_offset = data_to.size();
|
||||
data_to.resize(current_offset + events_size);
|
||||
|
||||
const bool first_flag = this->data(place).events.test(0);
|
||||
data_to[current_offset] = first_flag;
|
||||
++current_offset;
|
||||
|
||||
for (size_t i = 1; i < events_size; ++i)
|
||||
{
|
||||
data_to[current_offset] = (first_flag && this->data(place).events.test(i));
|
||||
++current_offset;
|
||||
}
|
||||
|
||||
offsets_to.push_back(current_offset);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionRetention(const std::string & name, const DataTypes & arguments, const Array & params, const Settings *)
|
||||
{
|
||||
assertNoParameters(name, params);
|
||||
|
@ -1,143 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <unordered_set>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <base/range.h>
|
||||
#include <bitset>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
}
|
||||
|
||||
struct AggregateFunctionRetentionData
|
||||
{
|
||||
static constexpr auto max_events = 32;
|
||||
|
||||
using Events = std::bitset<max_events>;
|
||||
|
||||
Events events;
|
||||
|
||||
void add(UInt8 event)
|
||||
{
|
||||
events.set(event);
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionRetentionData & other)
|
||||
{
|
||||
events |= other.events;
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
UInt32 event_value = static_cast<UInt32>(events.to_ulong());
|
||||
writeBinary(event_value, buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
UInt32 event_value;
|
||||
readBinary(event_value, buf);
|
||||
events = event_value;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* The max size of events is 32, that's enough for retention analytics
|
||||
*
|
||||
* Usage:
|
||||
* - retention(cond1, cond2, cond3, ....)
|
||||
* - returns [cond1_flag, cond1_flag && cond2_flag, cond1_flag && cond3_flag, ...]
|
||||
*/
|
||||
class AggregateFunctionRetention final
|
||||
: public IAggregateFunctionDataHelper<AggregateFunctionRetentionData, AggregateFunctionRetention>
|
||||
{
|
||||
private:
|
||||
UInt8 events_size;
|
||||
|
||||
public:
|
||||
String getName() const override
|
||||
{
|
||||
return "retention";
|
||||
}
|
||||
|
||||
explicit AggregateFunctionRetention(const DataTypes & arguments)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionRetentionData, AggregateFunctionRetention>(arguments, {}, std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt8>()))
|
||||
{
|
||||
for (const auto i : collections::range(0, arguments.size()))
|
||||
{
|
||||
const auto * cond_arg = arguments[i].get();
|
||||
if (!isUInt8(cond_arg))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Illegal type {} of argument {} of aggregate function {}, must be UInt8",
|
||||
cond_arg->getName(), i, getName());
|
||||
}
|
||||
|
||||
events_size = static_cast<UInt8>(arguments.size());
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override
|
||||
{
|
||||
for (const auto i : collections::range(0, events_size))
|
||||
{
|
||||
auto event = assert_cast<const ColumnVector<UInt8> *>(columns[i])->getData()[row_num];
|
||||
if (event)
|
||||
{
|
||||
this->data(place).add(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto & data_to = assert_cast<ColumnUInt8 &>(assert_cast<ColumnArray &>(to).getData()).getData();
|
||||
auto & offsets_to = assert_cast<ColumnArray &>(to).getOffsets();
|
||||
|
||||
ColumnArray::Offset current_offset = data_to.size();
|
||||
data_to.resize(current_offset + events_size);
|
||||
|
||||
const bool first_flag = this->data(place).events.test(0);
|
||||
data_to[current_offset] = first_flag;
|
||||
++current_offset;
|
||||
|
||||
for (size_t i = 1; i < events_size; ++i)
|
||||
{
|
||||
data_to[current_offset] = (first_flag && this->data(place).events.test(i));
|
||||
++current_offset;
|
||||
}
|
||||
|
||||
offsets_to.push_back(current_offset);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -1,15 +1,22 @@
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionSequenceMatch.h>
|
||||
|
||||
#include <DataTypes/DataTypeDate.h>
|
||||
#include <DataTypes/DataTypeDate32.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
|
||||
#include <base/range.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <bitset>
|
||||
#include <stack>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
@ -18,11 +25,689 @@ namespace ErrorCodes
|
||||
extern const int TOO_MANY_ARGUMENTS_FOR_FUNCTION;
|
||||
extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int TOO_SLOW;
|
||||
extern const int SYNTAX_ERROR;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
extern const int LOGICAL_ERROR;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
/// helper type for comparing `std::pair`s using solely the .first member
|
||||
template <template <typename> class Comparator>
|
||||
struct ComparePairFirst final
|
||||
{
|
||||
template <typename T1, typename T2>
|
||||
bool operator()(const std::pair<T1, T2> & lhs, const std::pair<T1, T2> & rhs) const
|
||||
{
|
||||
return Comparator<T1>{}(lhs.first, rhs.first);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr size_t max_events = 32;
|
||||
|
||||
template <typename T>
|
||||
struct AggregateFunctionSequenceMatchData final
|
||||
{
|
||||
using Timestamp = T;
|
||||
using Events = std::bitset<max_events>;
|
||||
using TimestampEvents = std::pair<Timestamp, Events>;
|
||||
using Comparator = ComparePairFirst<std::less>;
|
||||
|
||||
bool sorted = true;
|
||||
PODArrayWithStackMemory<TimestampEvents, 64> events_list;
|
||||
/// sequenceMatch conditions met at least once in events_list
|
||||
Events conditions_met;
|
||||
|
||||
void add(const Timestamp timestamp, const Events & events)
|
||||
{
|
||||
/// store information exclusively for rows with at least one event
|
||||
if (events.any())
|
||||
{
|
||||
events_list.emplace_back(timestamp, events);
|
||||
sorted = false;
|
||||
conditions_met |= events;
|
||||
}
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionSequenceMatchData & other)
|
||||
{
|
||||
if (other.events_list.empty())
|
||||
return;
|
||||
|
||||
events_list.insert(std::begin(other.events_list), std::end(other.events_list));
|
||||
sorted = false;
|
||||
conditions_met |= other.conditions_met;
|
||||
}
|
||||
|
||||
void sort()
|
||||
{
|
||||
if (sorted)
|
||||
return;
|
||||
|
||||
::sort(std::begin(events_list), std::end(events_list), Comparator{});
|
||||
sorted = true;
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
writeBinary(sorted, buf);
|
||||
writeBinary(events_list.size(), buf);
|
||||
|
||||
for (const auto & events : events_list)
|
||||
{
|
||||
writeBinary(events.first, buf);
|
||||
writeBinary(events.second.to_ulong(), buf);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
readBinary(sorted, buf);
|
||||
|
||||
size_t size;
|
||||
readBinary(size, buf);
|
||||
|
||||
/// If we lose these flags, functionality is broken
|
||||
/// If we serialize/deserialize these flags, we have compatibility issues
|
||||
/// If we set these flags to 1, we have a minor performance penalty, which seems acceptable
|
||||
conditions_met.set();
|
||||
|
||||
events_list.clear();
|
||||
events_list.reserve(size);
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
Timestamp timestamp;
|
||||
readBinary(timestamp, buf);
|
||||
|
||||
UInt64 events;
|
||||
readBinary(events, buf);
|
||||
|
||||
events_list.emplace_back(timestamp, Events{events});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// Max number of iterations to match the pattern against a sequence, exception thrown when exceeded
|
||||
constexpr auto sequence_match_max_iterations = 1000000;
|
||||
|
||||
|
||||
template <typename T, typename Data, typename Derived>
|
||||
class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<Data, Derived>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionSequenceBase(const DataTypes & arguments, const Array & params, const String & pattern_, const DataTypePtr & result_type_)
|
||||
: IAggregateFunctionDataHelper<Data, Derived>(arguments, params, result_type_)
|
||||
, pattern(pattern_)
|
||||
{
|
||||
arg_count = arguments.size();
|
||||
parsePattern();
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override
|
||||
{
|
||||
const auto timestamp = assert_cast<const ColumnVector<T> *>(columns[0])->getData()[row_num];
|
||||
|
||||
typename Data::Events events;
|
||||
for (const auto i : collections::range(1, arg_count))
|
||||
{
|
||||
const auto event = assert_cast<const ColumnUInt8 *>(columns[i])->getData()[row_num];
|
||||
events.set(i - 1, event);
|
||||
}
|
||||
|
||||
this->data(place).add(timestamp, events);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
bool haveSameStateRepresentationImpl(const IAggregateFunction & rhs) const override
|
||||
{
|
||||
return this->getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);
|
||||
}
|
||||
|
||||
private:
|
||||
enum class PatternActionType
|
||||
{
|
||||
SpecificEvent,
|
||||
AnyEvent,
|
||||
KleeneStar,
|
||||
TimeLessOrEqual,
|
||||
TimeLess,
|
||||
TimeGreaterOrEqual,
|
||||
TimeGreater,
|
||||
TimeEqual
|
||||
};
|
||||
|
||||
struct PatternAction final
|
||||
{
|
||||
PatternActionType type;
|
||||
std::uint64_t extra;
|
||||
|
||||
PatternAction() = default;
|
||||
explicit PatternAction(const PatternActionType type_, const std::uint64_t extra_ = 0) : type{type_}, extra{extra_} {}
|
||||
};
|
||||
|
||||
using PatternActions = PODArrayWithStackMemory<PatternAction, 64>;
|
||||
|
||||
Derived & derived() { return static_cast<Derived &>(*this); }
|
||||
|
||||
void parsePattern()
|
||||
{
|
||||
actions.clear();
|
||||
actions.emplace_back(PatternActionType::KleeneStar);
|
||||
|
||||
dfa_states.clear();
|
||||
dfa_states.emplace_back(true);
|
||||
|
||||
pattern_has_time = false;
|
||||
|
||||
const char * pos = pattern.data();
|
||||
const char * begin = pos;
|
||||
const char * end = pos + pattern.size();
|
||||
|
||||
auto throw_exception = [&](const std::string & msg)
|
||||
{
|
||||
throw Exception(ErrorCodes::SYNTAX_ERROR, "{} '{}' at position {}", msg, std::string(pos, end), toString(pos - begin));
|
||||
};
|
||||
|
||||
auto match = [&pos, end](const char * str) mutable
|
||||
{
|
||||
size_t length = strlen(str);
|
||||
if (pos + length <= end && 0 == memcmp(pos, str, length))
|
||||
{
|
||||
pos += length;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
while (pos < end)
|
||||
{
|
||||
if (match("(?"))
|
||||
{
|
||||
if (match("t"))
|
||||
{
|
||||
PatternActionType type;
|
||||
|
||||
if (match("<="))
|
||||
type = PatternActionType::TimeLessOrEqual;
|
||||
else if (match("<"))
|
||||
type = PatternActionType::TimeLess;
|
||||
else if (match(">="))
|
||||
type = PatternActionType::TimeGreaterOrEqual;
|
||||
else if (match(">"))
|
||||
type = PatternActionType::TimeGreater;
|
||||
else if (match("=="))
|
||||
type = PatternActionType::TimeEqual;
|
||||
else
|
||||
throw_exception("Unknown time condition");
|
||||
|
||||
UInt64 duration = 0;
|
||||
const auto * prev_pos = pos;
|
||||
pos = tryReadIntText(duration, pos, end);
|
||||
if (pos == prev_pos)
|
||||
throw_exception("Could not parse number");
|
||||
|
||||
if (actions.back().type != PatternActionType::SpecificEvent &&
|
||||
actions.back().type != PatternActionType::AnyEvent &&
|
||||
actions.back().type != PatternActionType::KleeneStar)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Temporal condition should be preceded by an event condition");
|
||||
|
||||
pattern_has_time = true;
|
||||
actions.emplace_back(type, duration);
|
||||
}
|
||||
else
|
||||
{
|
||||
UInt64 event_number = 0;
|
||||
const auto * prev_pos = pos;
|
||||
pos = tryReadIntText(event_number, pos, end);
|
||||
if (pos == prev_pos)
|
||||
throw_exception("Could not parse number");
|
||||
|
||||
if (event_number > arg_count - 1)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Event number {} is out of range", event_number);
|
||||
|
||||
actions.emplace_back(PatternActionType::SpecificEvent, event_number - 1);
|
||||
dfa_states.back().transition = DFATransition::SpecificEvent;
|
||||
dfa_states.back().event = static_cast<uint32_t>(event_number - 1);
|
||||
dfa_states.emplace_back();
|
||||
conditions_in_pattern.set(event_number - 1);
|
||||
}
|
||||
|
||||
if (!match(")"))
|
||||
throw_exception("Expected closing parenthesis, found");
|
||||
|
||||
}
|
||||
else if (match(".*"))
|
||||
{
|
||||
actions.emplace_back(PatternActionType::KleeneStar);
|
||||
dfa_states.back().has_kleene = true;
|
||||
}
|
||||
else if (match("."))
|
||||
{
|
||||
actions.emplace_back(PatternActionType::AnyEvent);
|
||||
dfa_states.back().transition = DFATransition::AnyEvent;
|
||||
dfa_states.emplace_back();
|
||||
}
|
||||
else
|
||||
throw_exception("Could not parse pattern, unexpected starting symbol");
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
/// Uses a DFA based approach in order to better handle patterns without
|
||||
/// time assertions.
|
||||
///
|
||||
/// NOTE: This implementation relies on the assumption that the pattern is *small*.
|
||||
///
|
||||
/// This algorithm performs in O(mn) (with m the number of DFA states and N the number
|
||||
/// of events) with a memory consumption and memory allocations in O(m). It means that
|
||||
/// if n >>> m (which is expected to be the case), this algorithm can be considered linear.
|
||||
template <typename EventEntry>
|
||||
bool dfaMatch(EventEntry & events_it, const EventEntry events_end) const
|
||||
{
|
||||
using ActiveStates = std::vector<bool>;
|
||||
|
||||
/// Those two vectors keep track of which states should be considered for the current
|
||||
/// event as well as the states which should be considered for the next event.
|
||||
ActiveStates active_states(dfa_states.size(), false);
|
||||
ActiveStates next_active_states(dfa_states.size(), false);
|
||||
active_states[0] = true;
|
||||
|
||||
/// Keeps track of dead-ends in order not to iterate over all the events to realize that
|
||||
/// the match failed.
|
||||
size_t n_active = 1;
|
||||
|
||||
for (/* empty */; events_it != events_end && n_active > 0 && !active_states.back(); ++events_it)
|
||||
{
|
||||
n_active = 0;
|
||||
next_active_states.assign(dfa_states.size(), false);
|
||||
|
||||
for (size_t state = 0; state < dfa_states.size(); ++state)
|
||||
{
|
||||
if (!active_states[state])
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
switch (dfa_states[state].transition)
|
||||
{
|
||||
case DFATransition::None:
|
||||
break;
|
||||
case DFATransition::AnyEvent:
|
||||
next_active_states[state + 1] = true;
|
||||
++n_active;
|
||||
break;
|
||||
case DFATransition::SpecificEvent:
|
||||
if (events_it->second.test(dfa_states[state].event))
|
||||
{
|
||||
next_active_states[state + 1] = true;
|
||||
++n_active;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if (dfa_states[state].has_kleene)
|
||||
{
|
||||
next_active_states[state] = true;
|
||||
++n_active;
|
||||
}
|
||||
}
|
||||
swap(active_states, next_active_states);
|
||||
}
|
||||
|
||||
return active_states.back();
|
||||
}
|
||||
|
||||
template <typename EventEntry>
|
||||
bool backtrackingMatch(EventEntry & events_it, const EventEntry events_end) const
|
||||
{
|
||||
const auto action_begin = std::begin(actions);
|
||||
const auto action_end = std::end(actions);
|
||||
auto action_it = action_begin;
|
||||
|
||||
const auto events_begin = events_it;
|
||||
auto base_it = events_it;
|
||||
|
||||
/// an iterator to action plus an iterator to row in events list plus timestamp at the start of sequence
|
||||
using backtrack_info = std::tuple<decltype(action_it), EventEntry, EventEntry>;
|
||||
std::stack<backtrack_info> back_stack;
|
||||
|
||||
/// backtrack if possible
|
||||
const auto do_backtrack = [&]
|
||||
{
|
||||
while (!back_stack.empty())
|
||||
{
|
||||
auto & top = back_stack.top();
|
||||
|
||||
action_it = std::get<0>(top);
|
||||
events_it = std::next(std::get<1>(top));
|
||||
base_it = std::get<2>(top);
|
||||
|
||||
back_stack.pop();
|
||||
|
||||
if (events_it != events_end)
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
size_t i = 0;
|
||||
while (action_it != action_end && events_it != events_end)
|
||||
{
|
||||
if (action_it->type == PatternActionType::SpecificEvent)
|
||||
{
|
||||
if (events_it->second.test(action_it->extra))
|
||||
{
|
||||
/// move to the next action and events
|
||||
base_it = events_it;
|
||||
++action_it, ++events_it;
|
||||
}
|
||||
else if (!do_backtrack())
|
||||
/// backtracking failed, bail out
|
||||
break;
|
||||
}
|
||||
else if (action_it->type == PatternActionType::AnyEvent)
|
||||
{
|
||||
base_it = events_it;
|
||||
++action_it, ++events_it;
|
||||
}
|
||||
else if (action_it->type == PatternActionType::KleeneStar)
|
||||
{
|
||||
back_stack.emplace(action_it, events_it, base_it);
|
||||
base_it = events_it;
|
||||
++action_it;
|
||||
}
|
||||
else if (action_it->type == PatternActionType::TimeLessOrEqual)
|
||||
{
|
||||
if (events_it->first <= base_it->first + action_it->extra)
|
||||
{
|
||||
/// condition satisfied, move onto next action
|
||||
back_stack.emplace(action_it, events_it, base_it);
|
||||
base_it = events_it;
|
||||
++action_it;
|
||||
}
|
||||
else if (!do_backtrack())
|
||||
break;
|
||||
}
|
||||
else if (action_it->type == PatternActionType::TimeLess)
|
||||
{
|
||||
if (events_it->first < base_it->first + action_it->extra)
|
||||
{
|
||||
back_stack.emplace(action_it, events_it, base_it);
|
||||
base_it = events_it;
|
||||
++action_it;
|
||||
}
|
||||
else if (!do_backtrack())
|
||||
break;
|
||||
}
|
||||
else if (action_it->type == PatternActionType::TimeGreaterOrEqual)
|
||||
{
|
||||
if (events_it->first >= base_it->first + action_it->extra)
|
||||
{
|
||||
back_stack.emplace(action_it, events_it, base_it);
|
||||
base_it = events_it;
|
||||
++action_it;
|
||||
}
|
||||
else if (++events_it == events_end && !do_backtrack())
|
||||
break;
|
||||
}
|
||||
else if (action_it->type == PatternActionType::TimeGreater)
|
||||
{
|
||||
if (events_it->first > base_it->first + action_it->extra)
|
||||
{
|
||||
back_stack.emplace(action_it, events_it, base_it);
|
||||
base_it = events_it;
|
||||
++action_it;
|
||||
}
|
||||
else if (++events_it == events_end && !do_backtrack())
|
||||
break;
|
||||
}
|
||||
else if (action_it->type == PatternActionType::TimeEqual)
|
||||
{
|
||||
if (events_it->first == base_it->first + action_it->extra)
|
||||
{
|
||||
back_stack.emplace(action_it, events_it, base_it);
|
||||
base_it = events_it;
|
||||
++action_it;
|
||||
}
|
||||
else if (++events_it == events_end && !do_backtrack())
|
||||
break;
|
||||
}
|
||||
else
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unknown PatternActionType");
|
||||
|
||||
if (++i > sequence_match_max_iterations)
|
||||
throw Exception(ErrorCodes::TOO_SLOW, "Pattern application proves too difficult, exceeding max iterations ({})",
|
||||
sequence_match_max_iterations);
|
||||
}
|
||||
|
||||
/// if there are some actions remaining
|
||||
if (action_it != action_end)
|
||||
{
|
||||
/// match multiple empty strings at end
|
||||
while (action_it->type == PatternActionType::KleeneStar ||
|
||||
action_it->type == PatternActionType::TimeLessOrEqual ||
|
||||
action_it->type == PatternActionType::TimeLess ||
|
||||
(action_it->type == PatternActionType::TimeGreaterOrEqual && action_it->extra == 0))
|
||||
++action_it;
|
||||
}
|
||||
|
||||
if (events_it == events_begin)
|
||||
++events_it;
|
||||
|
||||
return action_it == action_end;
|
||||
}
|
||||
|
||||
/// Splits the pattern into deterministic parts separated by non-deterministic fragments
|
||||
/// (time constraints and Kleene stars), and tries to match the deterministic parts in their specified order,
|
||||
/// ignoring the non-deterministic fragments.
|
||||
/// This function can quickly check that a full match is not possible if some deterministic fragment is missing.
|
||||
template <typename EventEntry>
|
||||
bool couldMatchDeterministicParts(const EventEntry events_begin, const EventEntry events_end, bool limit_iterations = true) const
|
||||
{
|
||||
size_t events_processed = 0;
|
||||
auto events_it = events_begin;
|
||||
|
||||
const auto actions_end = std::end(actions);
|
||||
auto actions_it = std::begin(actions);
|
||||
auto det_part_begin = actions_it;
|
||||
|
||||
auto match_deterministic_part = [&events_it, events_end, &events_processed, det_part_begin, actions_it, limit_iterations]()
|
||||
{
|
||||
auto events_it_init = events_it;
|
||||
auto det_part_it = det_part_begin;
|
||||
|
||||
while (det_part_it != actions_it && events_it != events_end)
|
||||
{
|
||||
/// matching any event
|
||||
if (det_part_it->type == PatternActionType::AnyEvent)
|
||||
++events_it, ++det_part_it;
|
||||
|
||||
/// matching specific event
|
||||
else
|
||||
{
|
||||
if (events_it->second.test(det_part_it->extra))
|
||||
++events_it, ++det_part_it;
|
||||
|
||||
/// abandon current matching, try to match the deterministic fragment further in the list
|
||||
else
|
||||
{
|
||||
events_it = ++events_it_init;
|
||||
det_part_it = det_part_begin;
|
||||
}
|
||||
}
|
||||
|
||||
if (limit_iterations && ++events_processed > sequence_match_max_iterations)
|
||||
throw Exception(ErrorCodes::TOO_SLOW, "Pattern application proves too difficult, exceeding max iterations ({})",
|
||||
sequence_match_max_iterations);
|
||||
}
|
||||
|
||||
return det_part_it == actions_it;
|
||||
};
|
||||
|
||||
for (; actions_it != actions_end; ++actions_it)
|
||||
if (actions_it->type != PatternActionType::SpecificEvent && actions_it->type != PatternActionType::AnyEvent)
|
||||
{
|
||||
if (!match_deterministic_part())
|
||||
return false;
|
||||
det_part_begin = std::next(actions_it);
|
||||
}
|
||||
|
||||
return match_deterministic_part();
|
||||
}
|
||||
|
||||
private:
|
||||
enum class DFATransition : char
|
||||
{
|
||||
/// .-------.
|
||||
/// | |
|
||||
/// `-------'
|
||||
None,
|
||||
/// .-------. (?[0-9])
|
||||
/// | | ----------
|
||||
/// `-------'
|
||||
SpecificEvent,
|
||||
/// .-------. .
|
||||
/// | | ----------
|
||||
/// `-------'
|
||||
AnyEvent,
|
||||
};
|
||||
|
||||
struct DFAState
|
||||
{
|
||||
explicit DFAState(bool has_kleene_ = false)
|
||||
: has_kleene{has_kleene_}, event{0}, transition{DFATransition::None}
|
||||
{}
|
||||
|
||||
/// .-------.
|
||||
/// | | - - -
|
||||
/// `-------'
|
||||
/// |_^
|
||||
bool has_kleene;
|
||||
/// In the case of a state transitions with a `SpecificEvent`,
|
||||
/// `event` contains the value of the event.
|
||||
uint32_t event;
|
||||
/// The kind of transition out of this state.
|
||||
DFATransition transition;
|
||||
};
|
||||
|
||||
using DFAStates = std::vector<DFAState>;
|
||||
|
||||
protected:
|
||||
/// `True` if the parsed pattern contains time assertions (?t...), `false` otherwise.
|
||||
bool pattern_has_time;
|
||||
/// sequenceMatch conditions met at least once in the pattern
|
||||
std::bitset<max_events> conditions_in_pattern;
|
||||
|
||||
private:
|
||||
std::string pattern;
|
||||
size_t arg_count;
|
||||
PatternActions actions;
|
||||
|
||||
DFAStates dfa_states;
|
||||
};
|
||||
|
||||
template <typename T, typename Data>
|
||||
class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionSequenceMatch(const DataTypes & arguments, const Array & params, const String & pattern_)
|
||||
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>(arguments, params, pattern_, std::make_shared<DataTypeUInt8>()) {}
|
||||
|
||||
using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>::AggregateFunctionSequenceBase;
|
||||
|
||||
String getName() const override { return "sequenceMatch"; }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto & output = assert_cast<ColumnUInt8 &>(to).getData();
|
||||
if ((this->conditions_in_pattern & this->data(place).conditions_met) != this->conditions_in_pattern)
|
||||
{
|
||||
output.push_back(false);
|
||||
return;
|
||||
}
|
||||
this->data(place).sort();
|
||||
|
||||
const auto & data_ref = this->data(place);
|
||||
|
||||
const auto events_begin = std::begin(data_ref.events_list);
|
||||
const auto events_end = std::end(data_ref.events_list);
|
||||
auto events_it = events_begin;
|
||||
|
||||
bool match = (this->pattern_has_time ?
|
||||
(this->couldMatchDeterministicParts(events_begin, events_end) && this->backtrackingMatch(events_it, events_end)) :
|
||||
this->dfaMatch(events_it, events_end));
|
||||
output.push_back(match);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Data>
|
||||
class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionSequenceCount(const DataTypes & arguments, const Array & params, const String & pattern_)
|
||||
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>(arguments, params, pattern_, std::make_shared<DataTypeUInt64>()) {}
|
||||
|
||||
using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>::AggregateFunctionSequenceBase;
|
||||
|
||||
String getName() const override { return "sequenceCount"; }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto & output = assert_cast<ColumnUInt64 &>(to).getData();
|
||||
if ((this->conditions_in_pattern & this->data(place).conditions_met) != this->conditions_in_pattern)
|
||||
{
|
||||
output.push_back(0);
|
||||
return;
|
||||
}
|
||||
this->data(place).sort();
|
||||
output.push_back(count(place));
|
||||
}
|
||||
|
||||
private:
|
||||
UInt64 count(ConstAggregateDataPtr __restrict place) const
|
||||
{
|
||||
const auto & data_ref = this->data(place);
|
||||
|
||||
const auto events_begin = std::begin(data_ref.events_list);
|
||||
const auto events_end = std::end(data_ref.events_list);
|
||||
auto events_it = events_begin;
|
||||
|
||||
size_t count = 0;
|
||||
// check if there is a chance of matching the sequence at least once
|
||||
if (this->couldMatchDeterministicParts(events_begin, events_end))
|
||||
{
|
||||
while (events_it != events_end && this->backtrackingMatch(events_it, events_end))
|
||||
++count;
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <template <typename, typename> typename AggregateFunction, template <typename> typename Data>
|
||||
AggregateFunctionPtr createAggregateFunctionSequenceBase(
|
||||
const std::string & name, const DataTypes & argument_types, const Array & params, const Settings *)
|
||||
|
@ -1,702 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <base/range.h>
|
||||
#include <base/sort.h>
|
||||
#include <Common/PODArray.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <bitset>
|
||||
#include <stack>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int TOO_SLOW;
|
||||
extern const int SYNTAX_ERROR;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
extern const int LOGICAL_ERROR;
|
||||
}
|
||||
|
||||
/// helper type for comparing `std::pair`s using solely the .first member
|
||||
template <template <typename> class Comparator>
|
||||
struct ComparePairFirst final
|
||||
{
|
||||
template <typename T1, typename T2>
|
||||
bool operator()(const std::pair<T1, T2> & lhs, const std::pair<T1, T2> & rhs) const
|
||||
{
|
||||
return Comparator<T1>{}(lhs.first, rhs.first);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr size_t max_events = 32;
|
||||
|
||||
template <typename T>
|
||||
struct AggregateFunctionSequenceMatchData final
|
||||
{
|
||||
using Timestamp = T;
|
||||
using Events = std::bitset<max_events>;
|
||||
using TimestampEvents = std::pair<Timestamp, Events>;
|
||||
using Comparator = ComparePairFirst<std::less>;
|
||||
|
||||
bool sorted = true;
|
||||
PODArrayWithStackMemory<TimestampEvents, 64> events_list;
|
||||
/// sequenceMatch conditions met at least once in events_list
|
||||
Events conditions_met;
|
||||
|
||||
void add(const Timestamp timestamp, const Events & events)
|
||||
{
|
||||
/// store information exclusively for rows with at least one event
|
||||
if (events.any())
|
||||
{
|
||||
events_list.emplace_back(timestamp, events);
|
||||
sorted = false;
|
||||
conditions_met |= events;
|
||||
}
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionSequenceMatchData & other)
|
||||
{
|
||||
if (other.events_list.empty())
|
||||
return;
|
||||
|
||||
events_list.insert(std::begin(other.events_list), std::end(other.events_list));
|
||||
sorted = false;
|
||||
conditions_met |= other.conditions_met;
|
||||
}
|
||||
|
||||
void sort()
|
||||
{
|
||||
if (sorted)
|
||||
return;
|
||||
|
||||
::sort(std::begin(events_list), std::end(events_list), Comparator{});
|
||||
sorted = true;
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
writeBinary(sorted, buf);
|
||||
writeBinary(events_list.size(), buf);
|
||||
|
||||
for (const auto & events : events_list)
|
||||
{
|
||||
writeBinary(events.first, buf);
|
||||
writeBinary(events.second.to_ulong(), buf);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
readBinary(sorted, buf);
|
||||
|
||||
size_t size;
|
||||
readBinary(size, buf);
|
||||
|
||||
/// If we lose these flags, functionality is broken
|
||||
/// If we serialize/deserialize these flags, we have compatibility issues
|
||||
/// If we set these flags to 1, we have a minor performance penalty, which seems acceptable
|
||||
conditions_met.set();
|
||||
|
||||
events_list.clear();
|
||||
events_list.reserve(size);
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
Timestamp timestamp;
|
||||
readBinary(timestamp, buf);
|
||||
|
||||
UInt64 events;
|
||||
readBinary(events, buf);
|
||||
|
||||
events_list.emplace_back(timestamp, Events{events});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// Max number of iterations to match the pattern against a sequence, exception thrown when exceeded
|
||||
constexpr auto sequence_match_max_iterations = 1000000;
|
||||
|
||||
|
||||
template <typename T, typename Data, typename Derived>
|
||||
class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<Data, Derived>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionSequenceBase(const DataTypes & arguments, const Array & params, const String & pattern_, const DataTypePtr & result_type_)
|
||||
: IAggregateFunctionDataHelper<Data, Derived>(arguments, params, result_type_)
|
||||
, pattern(pattern_)
|
||||
{
|
||||
arg_count = arguments.size();
|
||||
parsePattern();
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override
|
||||
{
|
||||
const auto timestamp = assert_cast<const ColumnVector<T> *>(columns[0])->getData()[row_num];
|
||||
|
||||
typename Data::Events events;
|
||||
for (const auto i : collections::range(1, arg_count))
|
||||
{
|
||||
const auto event = assert_cast<const ColumnUInt8 *>(columns[i])->getData()[row_num];
|
||||
events.set(i - 1, event);
|
||||
}
|
||||
|
||||
this->data(place).add(timestamp, events);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
bool haveSameStateRepresentationImpl(const IAggregateFunction & rhs) const override
|
||||
{
|
||||
return this->getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);
|
||||
}
|
||||
|
||||
private:
|
||||
enum class PatternActionType
|
||||
{
|
||||
SpecificEvent,
|
||||
AnyEvent,
|
||||
KleeneStar,
|
||||
TimeLessOrEqual,
|
||||
TimeLess,
|
||||
TimeGreaterOrEqual,
|
||||
TimeGreater,
|
||||
TimeEqual
|
||||
};
|
||||
|
||||
struct PatternAction final
|
||||
{
|
||||
PatternActionType type;
|
||||
std::uint64_t extra;
|
||||
|
||||
PatternAction() = default;
|
||||
explicit PatternAction(const PatternActionType type_, const std::uint64_t extra_ = 0) : type{type_}, extra{extra_} {}
|
||||
};
|
||||
|
||||
using PatternActions = PODArrayWithStackMemory<PatternAction, 64>;
|
||||
|
||||
Derived & derived() { return static_cast<Derived &>(*this); }
|
||||
|
||||
void parsePattern()
|
||||
{
|
||||
actions.clear();
|
||||
actions.emplace_back(PatternActionType::KleeneStar);
|
||||
|
||||
dfa_states.clear();
|
||||
dfa_states.emplace_back(true);
|
||||
|
||||
pattern_has_time = false;
|
||||
|
||||
const char * pos = pattern.data();
|
||||
const char * begin = pos;
|
||||
const char * end = pos + pattern.size();
|
||||
|
||||
auto throw_exception = [&](const std::string & msg)
|
||||
{
|
||||
throw Exception(ErrorCodes::SYNTAX_ERROR, "{} '{}' at position {}", msg, std::string(pos, end), toString(pos - begin));
|
||||
};
|
||||
|
||||
auto match = [&pos, end](const char * str) mutable
|
||||
{
|
||||
size_t length = strlen(str);
|
||||
if (pos + length <= end && 0 == memcmp(pos, str, length))
|
||||
{
|
||||
pos += length;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
while (pos < end)
|
||||
{
|
||||
if (match("(?"))
|
||||
{
|
||||
if (match("t"))
|
||||
{
|
||||
PatternActionType type;
|
||||
|
||||
if (match("<="))
|
||||
type = PatternActionType::TimeLessOrEqual;
|
||||
else if (match("<"))
|
||||
type = PatternActionType::TimeLess;
|
||||
else if (match(">="))
|
||||
type = PatternActionType::TimeGreaterOrEqual;
|
||||
else if (match(">"))
|
||||
type = PatternActionType::TimeGreater;
|
||||
else if (match("=="))
|
||||
type = PatternActionType::TimeEqual;
|
||||
else
|
||||
throw_exception("Unknown time condition");
|
||||
|
||||
UInt64 duration = 0;
|
||||
const auto * prev_pos = pos;
|
||||
pos = tryReadIntText(duration, pos, end);
|
||||
if (pos == prev_pos)
|
||||
throw_exception("Could not parse number");
|
||||
|
||||
if (actions.back().type != PatternActionType::SpecificEvent &&
|
||||
actions.back().type != PatternActionType::AnyEvent &&
|
||||
actions.back().type != PatternActionType::KleeneStar)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Temporal condition should be preceded by an event condition");
|
||||
|
||||
pattern_has_time = true;
|
||||
actions.emplace_back(type, duration);
|
||||
}
|
||||
else
|
||||
{
|
||||
UInt64 event_number = 0;
|
||||
const auto * prev_pos = pos;
|
||||
pos = tryReadIntText(event_number, pos, end);
|
||||
if (pos == prev_pos)
|
||||
throw_exception("Could not parse number");
|
||||
|
||||
if (event_number > arg_count - 1)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Event number {} is out of range", event_number);
|
||||
|
||||
actions.emplace_back(PatternActionType::SpecificEvent, event_number - 1);
|
||||
dfa_states.back().transition = DFATransition::SpecificEvent;
|
||||
dfa_states.back().event = static_cast<uint32_t>(event_number - 1);
|
||||
dfa_states.emplace_back();
|
||||
conditions_in_pattern.set(event_number - 1);
|
||||
}
|
||||
|
||||
if (!match(")"))
|
||||
throw_exception("Expected closing parenthesis, found");
|
||||
|
||||
}
|
||||
else if (match(".*"))
|
||||
{
|
||||
actions.emplace_back(PatternActionType::KleeneStar);
|
||||
dfa_states.back().has_kleene = true;
|
||||
}
|
||||
else if (match("."))
|
||||
{
|
||||
actions.emplace_back(PatternActionType::AnyEvent);
|
||||
dfa_states.back().transition = DFATransition::AnyEvent;
|
||||
dfa_states.emplace_back();
|
||||
}
|
||||
else
|
||||
throw_exception("Could not parse pattern, unexpected starting symbol");
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
/// Uses a DFA based approach in order to better handle patterns without
|
||||
/// time assertions.
|
||||
///
|
||||
/// NOTE: This implementation relies on the assumption that the pattern is *small*.
|
||||
///
|
||||
/// This algorithm performs in O(mn) (with m the number of DFA states and N the number
|
||||
/// of events) with a memory consumption and memory allocations in O(m). It means that
|
||||
/// if n >>> m (which is expected to be the case), this algorithm can be considered linear.
|
||||
template <typename EventEntry>
|
||||
bool dfaMatch(EventEntry & events_it, const EventEntry events_end) const
|
||||
{
|
||||
using ActiveStates = std::vector<bool>;
|
||||
|
||||
/// Those two vectors keep track of which states should be considered for the current
|
||||
/// event as well as the states which should be considered for the next event.
|
||||
ActiveStates active_states(dfa_states.size(), false);
|
||||
ActiveStates next_active_states(dfa_states.size(), false);
|
||||
active_states[0] = true;
|
||||
|
||||
/// Keeps track of dead-ends in order not to iterate over all the events to realize that
|
||||
/// the match failed.
|
||||
size_t n_active = 1;
|
||||
|
||||
for (/* empty */; events_it != events_end && n_active > 0 && !active_states.back(); ++events_it)
|
||||
{
|
||||
n_active = 0;
|
||||
next_active_states.assign(dfa_states.size(), false);
|
||||
|
||||
for (size_t state = 0; state < dfa_states.size(); ++state)
|
||||
{
|
||||
if (!active_states[state])
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
switch (dfa_states[state].transition)
|
||||
{
|
||||
case DFATransition::None:
|
||||
break;
|
||||
case DFATransition::AnyEvent:
|
||||
next_active_states[state + 1] = true;
|
||||
++n_active;
|
||||
break;
|
||||
case DFATransition::SpecificEvent:
|
||||
if (events_it->second.test(dfa_states[state].event))
|
||||
{
|
||||
next_active_states[state + 1] = true;
|
||||
++n_active;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if (dfa_states[state].has_kleene)
|
||||
{
|
||||
next_active_states[state] = true;
|
||||
++n_active;
|
||||
}
|
||||
}
|
||||
swap(active_states, next_active_states);
|
||||
}
|
||||
|
||||
return active_states.back();
|
||||
}
|
||||
|
||||
template <typename EventEntry>
|
||||
bool backtrackingMatch(EventEntry & events_it, const EventEntry events_end) const
|
||||
{
|
||||
const auto action_begin = std::begin(actions);
|
||||
const auto action_end = std::end(actions);
|
||||
auto action_it = action_begin;
|
||||
|
||||
const auto events_begin = events_it;
|
||||
auto base_it = events_it;
|
||||
|
||||
/// an iterator to action plus an iterator to row in events list plus timestamp at the start of sequence
|
||||
using backtrack_info = std::tuple<decltype(action_it), EventEntry, EventEntry>;
|
||||
std::stack<backtrack_info> back_stack;
|
||||
|
||||
/// backtrack if possible
|
||||
const auto do_backtrack = [&]
|
||||
{
|
||||
while (!back_stack.empty())
|
||||
{
|
||||
auto & top = back_stack.top();
|
||||
|
||||
action_it = std::get<0>(top);
|
||||
events_it = std::next(std::get<1>(top));
|
||||
base_it = std::get<2>(top);
|
||||
|
||||
back_stack.pop();
|
||||
|
||||
if (events_it != events_end)
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
size_t i = 0;
|
||||
while (action_it != action_end && events_it != events_end)
|
||||
{
|
||||
if (action_it->type == PatternActionType::SpecificEvent)
|
||||
{
|
||||
if (events_it->second.test(action_it->extra))
|
||||
{
|
||||
/// move to the next action and events
|
||||
base_it = events_it;
|
||||
++action_it, ++events_it;
|
||||
}
|
||||
else if (!do_backtrack())
|
||||
/// backtracking failed, bail out
|
||||
break;
|
||||
}
|
||||
else if (action_it->type == PatternActionType::AnyEvent)
|
||||
{
|
||||
base_it = events_it;
|
||||
++action_it, ++events_it;
|
||||
}
|
||||
else if (action_it->type == PatternActionType::KleeneStar)
|
||||
{
|
||||
back_stack.emplace(action_it, events_it, base_it);
|
||||
base_it = events_it;
|
||||
++action_it;
|
||||
}
|
||||
else if (action_it->type == PatternActionType::TimeLessOrEqual)
|
||||
{
|
||||
if (events_it->first <= base_it->first + action_it->extra)
|
||||
{
|
||||
/// condition satisfied, move onto next action
|
||||
back_stack.emplace(action_it, events_it, base_it);
|
||||
base_it = events_it;
|
||||
++action_it;
|
||||
}
|
||||
else if (!do_backtrack())
|
||||
break;
|
||||
}
|
||||
else if (action_it->type == PatternActionType::TimeLess)
|
||||
{
|
||||
if (events_it->first < base_it->first + action_it->extra)
|
||||
{
|
||||
back_stack.emplace(action_it, events_it, base_it);
|
||||
base_it = events_it;
|
||||
++action_it;
|
||||
}
|
||||
else if (!do_backtrack())
|
||||
break;
|
||||
}
|
||||
else if (action_it->type == PatternActionType::TimeGreaterOrEqual)
|
||||
{
|
||||
if (events_it->first >= base_it->first + action_it->extra)
|
||||
{
|
||||
back_stack.emplace(action_it, events_it, base_it);
|
||||
base_it = events_it;
|
||||
++action_it;
|
||||
}
|
||||
else if (++events_it == events_end && !do_backtrack())
|
||||
break;
|
||||
}
|
||||
else if (action_it->type == PatternActionType::TimeGreater)
|
||||
{
|
||||
if (events_it->first > base_it->first + action_it->extra)
|
||||
{
|
||||
back_stack.emplace(action_it, events_it, base_it);
|
||||
base_it = events_it;
|
||||
++action_it;
|
||||
}
|
||||
else if (++events_it == events_end && !do_backtrack())
|
||||
break;
|
||||
}
|
||||
else if (action_it->type == PatternActionType::TimeEqual)
|
||||
{
|
||||
if (events_it->first == base_it->first + action_it->extra)
|
||||
{
|
||||
back_stack.emplace(action_it, events_it, base_it);
|
||||
base_it = events_it;
|
||||
++action_it;
|
||||
}
|
||||
else if (++events_it == events_end && !do_backtrack())
|
||||
break;
|
||||
}
|
||||
else
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unknown PatternActionType");
|
||||
|
||||
if (++i > sequence_match_max_iterations)
|
||||
throw Exception(ErrorCodes::TOO_SLOW, "Pattern application proves too difficult, exceeding max iterations ({})",
|
||||
sequence_match_max_iterations);
|
||||
}
|
||||
|
||||
/// if there are some actions remaining
|
||||
if (action_it != action_end)
|
||||
{
|
||||
/// match multiple empty strings at end
|
||||
while (action_it->type == PatternActionType::KleeneStar ||
|
||||
action_it->type == PatternActionType::TimeLessOrEqual ||
|
||||
action_it->type == PatternActionType::TimeLess ||
|
||||
(action_it->type == PatternActionType::TimeGreaterOrEqual && action_it->extra == 0))
|
||||
++action_it;
|
||||
}
|
||||
|
||||
if (events_it == events_begin)
|
||||
++events_it;
|
||||
|
||||
return action_it == action_end;
|
||||
}
|
||||
|
||||
/// Splits the pattern into deterministic parts separated by non-deterministic fragments
|
||||
/// (time constraints and Kleene stars), and tries to match the deterministic parts in their specified order,
|
||||
/// ignoring the non-deterministic fragments.
|
||||
/// This function can quickly check that a full match is not possible if some deterministic fragment is missing.
|
||||
template <typename EventEntry>
|
||||
bool couldMatchDeterministicParts(const EventEntry events_begin, const EventEntry events_end, bool limit_iterations = true) const
|
||||
{
|
||||
size_t events_processed = 0;
|
||||
auto events_it = events_begin;
|
||||
|
||||
const auto actions_end = std::end(actions);
|
||||
auto actions_it = std::begin(actions);
|
||||
auto det_part_begin = actions_it;
|
||||
|
||||
auto match_deterministic_part = [&events_it, events_end, &events_processed, det_part_begin, actions_it, limit_iterations]()
|
||||
{
|
||||
auto events_it_init = events_it;
|
||||
auto det_part_it = det_part_begin;
|
||||
|
||||
while (det_part_it != actions_it && events_it != events_end)
|
||||
{
|
||||
/// matching any event
|
||||
if (det_part_it->type == PatternActionType::AnyEvent)
|
||||
++events_it, ++det_part_it;
|
||||
|
||||
/// matching specific event
|
||||
else
|
||||
{
|
||||
if (events_it->second.test(det_part_it->extra))
|
||||
++events_it, ++det_part_it;
|
||||
|
||||
/// abandon current matching, try to match the deterministic fragment further in the list
|
||||
else
|
||||
{
|
||||
events_it = ++events_it_init;
|
||||
det_part_it = det_part_begin;
|
||||
}
|
||||
}
|
||||
|
||||
if (limit_iterations && ++events_processed > sequence_match_max_iterations)
|
||||
throw Exception(ErrorCodes::TOO_SLOW, "Pattern application proves too difficult, exceeding max iterations ({})",
|
||||
sequence_match_max_iterations);
|
||||
}
|
||||
|
||||
return det_part_it == actions_it;
|
||||
};
|
||||
|
||||
for (; actions_it != actions_end; ++actions_it)
|
||||
if (actions_it->type != PatternActionType::SpecificEvent && actions_it->type != PatternActionType::AnyEvent)
|
||||
{
|
||||
if (!match_deterministic_part())
|
||||
return false;
|
||||
det_part_begin = std::next(actions_it);
|
||||
}
|
||||
|
||||
return match_deterministic_part();
|
||||
}
|
||||
|
||||
private:
|
||||
enum class DFATransition : char
|
||||
{
|
||||
/// .-------.
|
||||
/// | |
|
||||
/// `-------'
|
||||
None,
|
||||
/// .-------. (?[0-9])
|
||||
/// | | ----------
|
||||
/// `-------'
|
||||
SpecificEvent,
|
||||
/// .-------. .
|
||||
/// | | ----------
|
||||
/// `-------'
|
||||
AnyEvent,
|
||||
};
|
||||
|
||||
struct DFAState
|
||||
{
|
||||
explicit DFAState(bool has_kleene_ = false)
|
||||
: has_kleene{has_kleene_}, event{0}, transition{DFATransition::None}
|
||||
{}
|
||||
|
||||
/// .-------.
|
||||
/// | | - - -
|
||||
/// `-------'
|
||||
/// |_^
|
||||
bool has_kleene;
|
||||
/// In the case of a state transitions with a `SpecificEvent`,
|
||||
/// `event` contains the value of the event.
|
||||
uint32_t event;
|
||||
/// The kind of transition out of this state.
|
||||
DFATransition transition;
|
||||
};
|
||||
|
||||
using DFAStates = std::vector<DFAState>;
|
||||
|
||||
protected:
|
||||
/// `True` if the parsed pattern contains time assertions (?t...), `false` otherwise.
|
||||
bool pattern_has_time;
|
||||
/// sequenceMatch conditions met at least once in the pattern
|
||||
std::bitset<max_events> conditions_in_pattern;
|
||||
|
||||
private:
|
||||
std::string pattern;
|
||||
size_t arg_count;
|
||||
PatternActions actions;
|
||||
|
||||
DFAStates dfa_states;
|
||||
};
|
||||
|
||||
template <typename T, typename Data>
|
||||
class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionSequenceMatch(const DataTypes & arguments, const Array & params, const String & pattern_)
|
||||
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>(arguments, params, pattern_, std::make_shared<DataTypeUInt8>()) {}
|
||||
|
||||
using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>::AggregateFunctionSequenceBase;
|
||||
|
||||
String getName() const override { return "sequenceMatch"; }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto & output = assert_cast<ColumnUInt8 &>(to).getData();
|
||||
if ((this->conditions_in_pattern & this->data(place).conditions_met) != this->conditions_in_pattern)
|
||||
{
|
||||
output.push_back(false);
|
||||
return;
|
||||
}
|
||||
this->data(place).sort();
|
||||
|
||||
const auto & data_ref = this->data(place);
|
||||
|
||||
const auto events_begin = std::begin(data_ref.events_list);
|
||||
const auto events_end = std::end(data_ref.events_list);
|
||||
auto events_it = events_begin;
|
||||
|
||||
bool match = (this->pattern_has_time ?
|
||||
(this->couldMatchDeterministicParts(events_begin, events_end) && this->backtrackingMatch(events_it, events_end)) :
|
||||
this->dfaMatch(events_it, events_end));
|
||||
output.push_back(match);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Data>
|
||||
class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionSequenceCount(const DataTypes & arguments, const Array & params, const String & pattern_)
|
||||
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>(arguments, params, pattern_, std::make_shared<DataTypeUInt64>()) {}
|
||||
|
||||
using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>::AggregateFunctionSequenceBase;
|
||||
|
||||
String getName() const override { return "sequenceCount"; }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto & output = assert_cast<ColumnUInt64 &>(to).getData();
|
||||
if ((this->conditions_in_pattern & this->data(place).conditions_met) != this->conditions_in_pattern)
|
||||
{
|
||||
output.push_back(0);
|
||||
return;
|
||||
}
|
||||
this->data(place).sort();
|
||||
output.push_back(count(place));
|
||||
}
|
||||
|
||||
private:
|
||||
UInt64 count(ConstAggregateDataPtr __restrict place) const
|
||||
{
|
||||
const auto & data_ref = this->data(place);
|
||||
|
||||
const auto events_begin = std::begin(data_ref.events_list);
|
||||
const auto events_end = std::end(data_ref.events_list);
|
||||
auto events_it = events_begin;
|
||||
|
||||
size_t count = 0;
|
||||
// check if there is a chance of matching the sequence at least once
|
||||
if (this->couldMatchDeterministicParts(events_begin, events_end))
|
||||
{
|
||||
while (events_it != events_end && this->backtrackingMatch(events_it, events_end))
|
||||
++count;
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -1,14 +1,25 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionSequenceNextNode.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <Core/Settings.h>
|
||||
#include <DataTypes/DataTypeDate.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Common/CurrentThread.h>
|
||||
#include <base/range.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <IO/WriteBufferFromString.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnNullable.h>
|
||||
|
||||
#include <Common/ArenaAllocator.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
#include <bitset>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -24,11 +35,409 @@ namespace ErrorCodes
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
extern const int UNKNOWN_AGGREGATE_FUNCTION;
|
||||
extern const int TOO_LARGE_ARRAY_SIZE;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
enum class SequenceDirection
|
||||
{
|
||||
Forward,
|
||||
Backward,
|
||||
};
|
||||
|
||||
enum SequenceBase
|
||||
{
|
||||
Head,
|
||||
Tail,
|
||||
FirstMatch,
|
||||
LastMatch,
|
||||
};
|
||||
|
||||
/// This is for security
|
||||
const UInt64 max_node_size_deserialize = 0xFFFFFF;
|
||||
|
||||
/// NodeBase used to implement a linked list for storage of SequenceNextNodeImpl
|
||||
template <typename Node, size_t MaxEventsSize>
|
||||
struct NodeBase
|
||||
{
|
||||
UInt64 size; /// size of payload
|
||||
|
||||
DataTypeDateTime::FieldType event_time;
|
||||
std::bitset<MaxEventsSize> events_bitset;
|
||||
bool can_be_base;
|
||||
|
||||
char * data() { return reinterpret_cast<char *>(this) + sizeof(Node); }
|
||||
|
||||
const char * data() const { return reinterpret_cast<const char *>(this) + sizeof(Node); }
|
||||
|
||||
Node * clone(Arena * arena) const
|
||||
{
|
||||
return reinterpret_cast<Node *>(
|
||||
const_cast<char *>(arena->alignedInsert(reinterpret_cast<const char *>(this), sizeof(Node) + size, alignof(Node))));
|
||||
}
|
||||
|
||||
void write(WriteBuffer & buf) const
|
||||
{
|
||||
writeVarUInt(size, buf);
|
||||
buf.write(data(), size);
|
||||
|
||||
writeBinary(event_time, buf);
|
||||
UInt64 ulong_bitset = events_bitset.to_ulong();
|
||||
writeBinary(ulong_bitset, buf);
|
||||
writeBinary(can_be_base, buf);
|
||||
}
|
||||
|
||||
static Node * read(ReadBuffer & buf, Arena * arena)
|
||||
{
|
||||
UInt64 size;
|
||||
readVarUInt(size, buf);
|
||||
if (unlikely(size > max_node_size_deserialize))
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large node state size");
|
||||
|
||||
Node * node = reinterpret_cast<Node *>(arena->alignedAlloc(sizeof(Node) + size, alignof(Node)));
|
||||
node->size = size;
|
||||
buf.readStrict(node->data(), size);
|
||||
|
||||
readBinary(node->event_time, buf);
|
||||
UInt64 ulong_bitset;
|
||||
readBinary(ulong_bitset, buf);
|
||||
node->events_bitset = ulong_bitset;
|
||||
readBinary(node->can_be_base, buf);
|
||||
|
||||
return node;
|
||||
}
|
||||
};
|
||||
|
||||
/// It stores String, timestamp, bitset of matched events.
|
||||
template <size_t MaxEventsSize>
|
||||
struct NodeString : public NodeBase<NodeString<MaxEventsSize>, MaxEventsSize>
|
||||
{
|
||||
using Node = NodeString<MaxEventsSize>;
|
||||
|
||||
static Node * allocate(const IColumn & column, size_t row_num, Arena * arena)
|
||||
{
|
||||
StringRef string = assert_cast<const ColumnString &>(column).getDataAt(row_num);
|
||||
|
||||
Node * node = reinterpret_cast<Node *>(arena->alignedAlloc(sizeof(Node) + string.size, alignof(Node)));
|
||||
node->size = string.size;
|
||||
memcpy(node->data(), string.data, string.size);
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
void insertInto(IColumn & column)
|
||||
{
|
||||
assert_cast<ColumnString &>(column).insertData(this->data(), this->size);
|
||||
}
|
||||
|
||||
bool compare(const Node * rhs) const
|
||||
{
|
||||
auto cmp = strncmp(this->data(), rhs->data(), std::min(this->size, rhs->size));
|
||||
return (cmp == 0) ? this->size < rhs->size : cmp < 0;
|
||||
}
|
||||
};
|
||||
|
||||
/// TODO : Support other types than string
|
||||
template <typename Node>
|
||||
struct SequenceNextNodeGeneralData
|
||||
{
|
||||
using Allocator = MixedAlignedArenaAllocator<alignof(Node *), 4096>;
|
||||
using Array = PODArray<Node *, 32, Allocator>;
|
||||
|
||||
Array value;
|
||||
bool sorted = false;
|
||||
|
||||
struct Comparator final
|
||||
{
|
||||
bool operator()(const Node * lhs, const Node * rhs) const
|
||||
{
|
||||
return lhs->event_time == rhs->event_time ? lhs->compare(rhs) : lhs->event_time < rhs->event_time;
|
||||
}
|
||||
};
|
||||
|
||||
void sort()
|
||||
{
|
||||
if (!sorted)
|
||||
{
|
||||
std::stable_sort(std::begin(value), std::end(value), Comparator{});
|
||||
sorted = true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Implementation of sequenceFirstNode
|
||||
template <typename T, typename Node>
|
||||
class SequenceNextNodeImpl final
|
||||
: public IAggregateFunctionDataHelper<SequenceNextNodeGeneralData<Node>, SequenceNextNodeImpl<T, Node>>
|
||||
{
|
||||
using Self = SequenceNextNodeImpl<T, Node>;
|
||||
|
||||
using Data = SequenceNextNodeGeneralData<Node>;
|
||||
static Data & data(AggregateDataPtr __restrict place) { return *reinterpret_cast<Data *>(place); }
|
||||
static const Data & data(ConstAggregateDataPtr __restrict place) { return *reinterpret_cast<const Data *>(place); }
|
||||
|
||||
static constexpr size_t base_cond_column_idx = 2;
|
||||
static constexpr size_t event_column_idx = 1;
|
||||
|
||||
SequenceBase seq_base_kind;
|
||||
SequenceDirection seq_direction;
|
||||
const size_t min_required_args;
|
||||
|
||||
DataTypePtr & data_type;
|
||||
UInt8 events_size;
|
||||
UInt64 max_elems;
|
||||
public:
|
||||
SequenceNextNodeImpl(
|
||||
const DataTypePtr & data_type_,
|
||||
const DataTypes & arguments,
|
||||
const Array & parameters_,
|
||||
SequenceBase seq_base_kind_,
|
||||
SequenceDirection seq_direction_,
|
||||
size_t min_required_args_,
|
||||
UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
|
||||
: IAggregateFunctionDataHelper<SequenceNextNodeGeneralData<Node>, Self>(arguments, parameters_, data_type_)
|
||||
, seq_base_kind(seq_base_kind_)
|
||||
, seq_direction(seq_direction_)
|
||||
, min_required_args(min_required_args_)
|
||||
, data_type(this->argument_types[0])
|
||||
, events_size(arguments.size() - min_required_args)
|
||||
, max_elems(max_elems_)
|
||||
{
|
||||
}
|
||||
|
||||
String getName() const override { return "sequenceNextNode"; }
|
||||
|
||||
bool haveSameStateRepresentationImpl(const IAggregateFunction & rhs) const override
|
||||
{
|
||||
return this->getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);
|
||||
}
|
||||
|
||||
void insert(Data & a, const Node * v, Arena * arena) const
|
||||
{
|
||||
++a.total_values;
|
||||
a.value.push_back(v->clone(arena), arena);
|
||||
}
|
||||
|
||||
void create(AggregateDataPtr __restrict place) const override /// NOLINT
|
||||
{
|
||||
new (place) Data;
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
Node * node = Node::allocate(*columns[event_column_idx], row_num, arena);
|
||||
|
||||
const auto timestamp = assert_cast<const ColumnVector<T> *>(columns[0])->getData()[row_num];
|
||||
|
||||
/// The events_bitset variable stores matched events in the form of bitset.
|
||||
/// Each Nth-bit indicates that the Nth-event are matched.
|
||||
/// For example, event1 and event3 is matched then the values of events_bitset is 0x00000005.
|
||||
/// 0x00000000
|
||||
/// + 1 (bit of event1)
|
||||
/// + 4 (bit of event3)
|
||||
node->events_bitset.reset();
|
||||
for (UInt8 i = 0; i < events_size; ++i)
|
||||
if (assert_cast<const ColumnVector<UInt8> *>(columns[min_required_args + i])->getData()[row_num])
|
||||
node->events_bitset.set(i);
|
||||
node->event_time = static_cast<DataTypeDateTime::FieldType>(timestamp);
|
||||
|
||||
node->can_be_base = assert_cast<const ColumnVector<UInt8> *>(columns[base_cond_column_idx])->getData()[row_num];
|
||||
|
||||
data(place).value.push_back(node, arena);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
if (data(rhs).value.empty())
|
||||
return;
|
||||
|
||||
if (data(place).value.size() >= max_elems)
|
||||
return;
|
||||
|
||||
auto & a = data(place).value;
|
||||
auto & b = data(rhs).value;
|
||||
const auto a_size = a.size();
|
||||
|
||||
const UInt64 new_elems = std::min(data(rhs).value.size(), static_cast<size_t>(max_elems) - data(place).value.size());
|
||||
for (UInt64 i = 0; i < new_elems; ++i)
|
||||
a.push_back(b[i]->clone(arena), arena);
|
||||
|
||||
/// Either sort whole container or do so partially merging ranges afterwards
|
||||
using Comparator = typename SequenceNextNodeGeneralData<Node>::Comparator;
|
||||
|
||||
if (!data(place).sorted && !data(rhs).sorted)
|
||||
std::stable_sort(std::begin(a), std::end(a), Comparator{});
|
||||
else
|
||||
{
|
||||
const auto begin = std::begin(a);
|
||||
const auto middle = std::next(begin, a_size);
|
||||
const auto end = std::end(a);
|
||||
|
||||
if (!data(place).sorted)
|
||||
std::stable_sort(begin, middle, Comparator{});
|
||||
|
||||
if (!data(rhs).sorted)
|
||||
std::stable_sort(middle, end, Comparator{});
|
||||
|
||||
std::inplace_merge(begin, middle, end, Comparator{});
|
||||
}
|
||||
|
||||
data(place).sorted = true;
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
/// Temporarily do a const_cast to sort the values. It helps to reduce the computational burden on the initiator node.
|
||||
this->data(const_cast<AggregateDataPtr>(place)).sort();
|
||||
|
||||
writeBinary(data(place).sorted, buf);
|
||||
|
||||
auto & value = data(place).value;
|
||||
|
||||
size_t size = std::min(static_cast<size_t>(events_size + 1), value.size());
|
||||
switch (seq_base_kind)
|
||||
{
|
||||
case SequenceBase::Head:
|
||||
writeVarUInt(size, buf);
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
value[i]->write(buf);
|
||||
break;
|
||||
|
||||
case SequenceBase::Tail:
|
||||
writeVarUInt(size, buf);
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
value[value.size() - size + i]->write(buf);
|
||||
break;
|
||||
|
||||
case SequenceBase::FirstMatch:
|
||||
case SequenceBase::LastMatch:
|
||||
writeVarUInt(value.size(), buf);
|
||||
for (auto & node : value)
|
||||
node->write(buf);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
readBinary(data(place).sorted, buf);
|
||||
|
||||
UInt64 size;
|
||||
readVarUInt(size, buf);
|
||||
|
||||
if (unlikely(size == 0))
|
||||
return;
|
||||
|
||||
if (unlikely(size > max_node_size_deserialize))
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
|
||||
"Too large array size (maximum: {})", max_node_size_deserialize);
|
||||
|
||||
auto & value = data(place).value;
|
||||
|
||||
value.resize(size, arena);
|
||||
for (UInt64 i = 0; i < size; ++i)
|
||||
value[i] = Node::read(buf, arena);
|
||||
}
|
||||
|
||||
inline std::optional<size_t> getBaseIndex(Data & data) const
|
||||
{
|
||||
if (data.value.size() == 0)
|
||||
return {};
|
||||
|
||||
switch (seq_base_kind)
|
||||
{
|
||||
case SequenceBase::Head:
|
||||
if (data.value[0]->can_be_base)
|
||||
return 0;
|
||||
break;
|
||||
|
||||
case SequenceBase::Tail:
|
||||
if (data.value[data.value.size() - 1]->can_be_base)
|
||||
return data.value.size() - 1;
|
||||
break;
|
||||
|
||||
case SequenceBase::FirstMatch:
|
||||
for (size_t i = 0; i < data.value.size(); ++i)
|
||||
{
|
||||
if (data.value[i]->events_bitset.test(0) && data.value[i]->can_be_base)
|
||||
return i;
|
||||
}
|
||||
break;
|
||||
|
||||
case SequenceBase::LastMatch:
|
||||
for (size_t i = 0; i < data.value.size(); ++i)
|
||||
{
|
||||
auto reversed_i = data.value.size() - i - 1;
|
||||
if (data.value[reversed_i]->events_bitset.test(0) && data.value[reversed_i]->can_be_base)
|
||||
return reversed_i;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
/// This method returns an index of next node that matched the events.
|
||||
/// matched events in the chain of events are represented as a bitmask.
|
||||
/// The first matched event is 0x00000001, the second one is 0x00000002, the third one is 0x00000004, and so on.
|
||||
UInt32 getNextNodeIndex(Data & data) const
|
||||
{
|
||||
const UInt32 unmatched_idx = static_cast<UInt32>(data.value.size());
|
||||
|
||||
if (data.value.size() <= events_size)
|
||||
return unmatched_idx;
|
||||
|
||||
data.sort();
|
||||
|
||||
std::optional<size_t> base_opt = getBaseIndex(data);
|
||||
if (!base_opt.has_value())
|
||||
return unmatched_idx;
|
||||
UInt32 base = static_cast<UInt32>(base_opt.value());
|
||||
|
||||
if (events_size == 0)
|
||||
return data.value.size() > 0 ? base : unmatched_idx;
|
||||
|
||||
UInt32 i = 0;
|
||||
switch (seq_direction)
|
||||
{
|
||||
case SequenceDirection::Forward:
|
||||
for (i = 0; i < events_size && base + i < data.value.size(); ++i)
|
||||
if (!data.value[base + i]->events_bitset.test(i))
|
||||
break;
|
||||
return (i == events_size) ? base + i : unmatched_idx;
|
||||
|
||||
case SequenceDirection::Backward:
|
||||
for (i = 0; i < events_size && i < base; ++i)
|
||||
if (!data.value[base - i]->events_bitset.test(i))
|
||||
break;
|
||||
return (i == events_size) ? base - i : unmatched_idx;
|
||||
}
|
||||
UNREACHABLE();
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto & value = data(place).value;
|
||||
|
||||
UInt32 event_idx = getNextNodeIndex(this->data(place));
|
||||
if (event_idx < value.size())
|
||||
{
|
||||
ColumnNullable & to_concrete = assert_cast<ColumnNullable &>(to);
|
||||
value[event_idx]->insertInto(to_concrete.getNestedColumn());
|
||||
to_concrete.getNullMapData().push_back(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
to.insertDefault();
|
||||
}
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return true; }
|
||||
};
|
||||
|
||||
|
||||
template <typename T>
|
||||
inline AggregateFunctionPtr createAggregateFunctionSequenceNodeImpl(
|
||||
const DataTypePtr data_type, const DataTypes & argument_types, const Array & parameters, SequenceDirection direction, SequenceBase base)
|
||||
|
@ -1,432 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <IO/WriteBufferFromString.h>
|
||||
#include <IO/Operators.h>
|
||||
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnNullable.h>
|
||||
|
||||
#include <Common/ArenaAllocator.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
#include <type_traits>
|
||||
#include <bitset>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int TOO_LARGE_ARRAY_SIZE;
|
||||
}
|
||||
|
||||
enum class SequenceDirection
|
||||
{
|
||||
Forward,
|
||||
Backward,
|
||||
};
|
||||
|
||||
enum SequenceBase
|
||||
{
|
||||
Head,
|
||||
Tail,
|
||||
FirstMatch,
|
||||
LastMatch,
|
||||
};
|
||||
|
||||
/// This is for security
|
||||
static const UInt64 max_node_size_deserialize = 0xFFFFFF;
|
||||
|
||||
/// NodeBase used to implement a linked list for storage of SequenceNextNodeImpl
|
||||
template <typename Node, size_t MaxEventsSize>
|
||||
struct NodeBase
|
||||
{
|
||||
UInt64 size; /// size of payload
|
||||
|
||||
DataTypeDateTime::FieldType event_time;
|
||||
std::bitset<MaxEventsSize> events_bitset;
|
||||
bool can_be_base;
|
||||
|
||||
char * data() { return reinterpret_cast<char *>(this) + sizeof(Node); }
|
||||
|
||||
const char * data() const { return reinterpret_cast<const char *>(this) + sizeof(Node); }
|
||||
|
||||
Node * clone(Arena * arena) const
|
||||
{
|
||||
return reinterpret_cast<Node *>(
|
||||
const_cast<char *>(arena->alignedInsert(reinterpret_cast<const char *>(this), sizeof(Node) + size, alignof(Node))));
|
||||
}
|
||||
|
||||
void write(WriteBuffer & buf) const
|
||||
{
|
||||
writeVarUInt(size, buf);
|
||||
buf.write(data(), size);
|
||||
|
||||
writeBinary(event_time, buf);
|
||||
UInt64 ulong_bitset = events_bitset.to_ulong();
|
||||
writeBinary(ulong_bitset, buf);
|
||||
writeBinary(can_be_base, buf);
|
||||
}
|
||||
|
||||
static Node * read(ReadBuffer & buf, Arena * arena)
|
||||
{
|
||||
UInt64 size;
|
||||
readVarUInt(size, buf);
|
||||
if (unlikely(size > max_node_size_deserialize))
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large node state size");
|
||||
|
||||
Node * node = reinterpret_cast<Node *>(arena->alignedAlloc(sizeof(Node) + size, alignof(Node)));
|
||||
node->size = size;
|
||||
buf.readStrict(node->data(), size);
|
||||
|
||||
readBinary(node->event_time, buf);
|
||||
UInt64 ulong_bitset;
|
||||
readBinary(ulong_bitset, buf);
|
||||
node->events_bitset = ulong_bitset;
|
||||
readBinary(node->can_be_base, buf);
|
||||
|
||||
return node;
|
||||
}
|
||||
};
|
||||
|
||||
/// It stores String, timestamp, bitset of matched events.
|
||||
template <size_t MaxEventsSize>
|
||||
struct NodeString : public NodeBase<NodeString<MaxEventsSize>, MaxEventsSize>
|
||||
{
|
||||
using Node = NodeString<MaxEventsSize>;
|
||||
|
||||
static Node * allocate(const IColumn & column, size_t row_num, Arena * arena)
|
||||
{
|
||||
StringRef string = assert_cast<const ColumnString &>(column).getDataAt(row_num);
|
||||
|
||||
Node * node = reinterpret_cast<Node *>(arena->alignedAlloc(sizeof(Node) + string.size, alignof(Node)));
|
||||
node->size = string.size;
|
||||
memcpy(node->data(), string.data, string.size);
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
void insertInto(IColumn & column)
|
||||
{
|
||||
assert_cast<ColumnString &>(column).insertData(this->data(), this->size);
|
||||
}
|
||||
|
||||
bool compare(const Node * rhs) const
|
||||
{
|
||||
auto cmp = strncmp(this->data(), rhs->data(), std::min(this->size, rhs->size));
|
||||
return (cmp == 0) ? this->size < rhs->size : cmp < 0;
|
||||
}
|
||||
};
|
||||
|
||||
/// TODO : Support other types than string
|
||||
template <typename Node>
|
||||
struct SequenceNextNodeGeneralData
|
||||
{
|
||||
using Allocator = MixedAlignedArenaAllocator<alignof(Node *), 4096>;
|
||||
using Array = PODArray<Node *, 32, Allocator>;
|
||||
|
||||
Array value;
|
||||
bool sorted = false;
|
||||
|
||||
struct Comparator final
|
||||
{
|
||||
bool operator()(const Node * lhs, const Node * rhs) const
|
||||
{
|
||||
return lhs->event_time == rhs->event_time ? lhs->compare(rhs) : lhs->event_time < rhs->event_time;
|
||||
}
|
||||
};
|
||||
|
||||
void sort()
|
||||
{
|
||||
if (!sorted)
|
||||
{
|
||||
std::stable_sort(std::begin(value), std::end(value), Comparator{});
|
||||
sorted = true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Implementation of sequenceFirstNode
|
||||
template <typename T, typename Node>
|
||||
class SequenceNextNodeImpl final
|
||||
: public IAggregateFunctionDataHelper<SequenceNextNodeGeneralData<Node>, SequenceNextNodeImpl<T, Node>>
|
||||
{
|
||||
using Self = SequenceNextNodeImpl<T, Node>;
|
||||
|
||||
using Data = SequenceNextNodeGeneralData<Node>;
|
||||
static Data & data(AggregateDataPtr __restrict place) { return *reinterpret_cast<Data *>(place); }
|
||||
static const Data & data(ConstAggregateDataPtr __restrict place) { return *reinterpret_cast<const Data *>(place); }
|
||||
|
||||
static constexpr size_t base_cond_column_idx = 2;
|
||||
static constexpr size_t event_column_idx = 1;
|
||||
|
||||
SequenceBase seq_base_kind;
|
||||
SequenceDirection seq_direction;
|
||||
const size_t min_required_args;
|
||||
|
||||
DataTypePtr & data_type;
|
||||
UInt8 events_size;
|
||||
UInt64 max_elems;
|
||||
public:
|
||||
SequenceNextNodeImpl(
|
||||
const DataTypePtr & data_type_,
|
||||
const DataTypes & arguments,
|
||||
const Array & parameters_,
|
||||
SequenceBase seq_base_kind_,
|
||||
SequenceDirection seq_direction_,
|
||||
size_t min_required_args_,
|
||||
UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
|
||||
: IAggregateFunctionDataHelper<SequenceNextNodeGeneralData<Node>, Self>(arguments, parameters_, data_type_)
|
||||
, seq_base_kind(seq_base_kind_)
|
||||
, seq_direction(seq_direction_)
|
||||
, min_required_args(min_required_args_)
|
||||
, data_type(this->argument_types[0])
|
||||
, events_size(arguments.size() - min_required_args)
|
||||
, max_elems(max_elems_)
|
||||
{
|
||||
}
|
||||
|
||||
String getName() const override { return "sequenceNextNode"; }
|
||||
|
||||
bool haveSameStateRepresentationImpl(const IAggregateFunction & rhs) const override
|
||||
{
|
||||
return this->getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);
|
||||
}
|
||||
|
||||
void insert(Data & a, const Node * v, Arena * arena) const
|
||||
{
|
||||
++a.total_values;
|
||||
a.value.push_back(v->clone(arena), arena);
|
||||
}
|
||||
|
||||
void create(AggregateDataPtr __restrict place) const override /// NOLINT
|
||||
{
|
||||
new (place) Data;
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
Node * node = Node::allocate(*columns[event_column_idx], row_num, arena);
|
||||
|
||||
const auto timestamp = assert_cast<const ColumnVector<T> *>(columns[0])->getData()[row_num];
|
||||
|
||||
/// The events_bitset variable stores matched events in the form of bitset.
|
||||
/// Each Nth-bit indicates that the Nth-event are matched.
|
||||
/// For example, event1 and event3 is matched then the values of events_bitset is 0x00000005.
|
||||
/// 0x00000000
|
||||
/// + 1 (bit of event1)
|
||||
/// + 4 (bit of event3)
|
||||
node->events_bitset.reset();
|
||||
for (UInt8 i = 0; i < events_size; ++i)
|
||||
if (assert_cast<const ColumnVector<UInt8> *>(columns[min_required_args + i])->getData()[row_num])
|
||||
node->events_bitset.set(i);
|
||||
node->event_time = static_cast<DataTypeDateTime::FieldType>(timestamp);
|
||||
|
||||
node->can_be_base = assert_cast<const ColumnVector<UInt8> *>(columns[base_cond_column_idx])->getData()[row_num];
|
||||
|
||||
data(place).value.push_back(node, arena);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
if (data(rhs).value.empty())
|
||||
return;
|
||||
|
||||
if (data(place).value.size() >= max_elems)
|
||||
return;
|
||||
|
||||
auto & a = data(place).value;
|
||||
auto & b = data(rhs).value;
|
||||
const auto a_size = a.size();
|
||||
|
||||
const UInt64 new_elems = std::min(data(rhs).value.size(), static_cast<size_t>(max_elems) - data(place).value.size());
|
||||
for (UInt64 i = 0; i < new_elems; ++i)
|
||||
a.push_back(b[i]->clone(arena), arena);
|
||||
|
||||
/// Either sort whole container or do so partially merging ranges afterwards
|
||||
using Comparator = typename SequenceNextNodeGeneralData<Node>::Comparator;
|
||||
|
||||
if (!data(place).sorted && !data(rhs).sorted)
|
||||
std::stable_sort(std::begin(a), std::end(a), Comparator{});
|
||||
else
|
||||
{
|
||||
const auto begin = std::begin(a);
|
||||
const auto middle = std::next(begin, a_size);
|
||||
const auto end = std::end(a);
|
||||
|
||||
if (!data(place).sorted)
|
||||
std::stable_sort(begin, middle, Comparator{});
|
||||
|
||||
if (!data(rhs).sorted)
|
||||
std::stable_sort(middle, end, Comparator{});
|
||||
|
||||
std::inplace_merge(begin, middle, end, Comparator{});
|
||||
}
|
||||
|
||||
data(place).sorted = true;
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
/// Temporarily do a const_cast to sort the values. It helps to reduce the computational burden on the initiator node.
|
||||
this->data(const_cast<AggregateDataPtr>(place)).sort();
|
||||
|
||||
writeBinary(data(place).sorted, buf);
|
||||
|
||||
auto & value = data(place).value;
|
||||
|
||||
size_t size = std::min(static_cast<size_t>(events_size + 1), value.size());
|
||||
switch (seq_base_kind)
|
||||
{
|
||||
case SequenceBase::Head:
|
||||
writeVarUInt(size, buf);
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
value[i]->write(buf);
|
||||
break;
|
||||
|
||||
case SequenceBase::Tail:
|
||||
writeVarUInt(size, buf);
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
value[value.size() - size + i]->write(buf);
|
||||
break;
|
||||
|
||||
case SequenceBase::FirstMatch:
|
||||
case SequenceBase::LastMatch:
|
||||
writeVarUInt(value.size(), buf);
|
||||
for (auto & node : value)
|
||||
node->write(buf);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
readBinary(data(place).sorted, buf);
|
||||
|
||||
UInt64 size;
|
||||
readVarUInt(size, buf);
|
||||
|
||||
if (unlikely(size == 0))
|
||||
return;
|
||||
|
||||
if (unlikely(size > max_node_size_deserialize))
|
||||
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
|
||||
"Too large array size (maximum: {})", max_node_size_deserialize);
|
||||
|
||||
auto & value = data(place).value;
|
||||
|
||||
value.resize(size, arena);
|
||||
for (UInt64 i = 0; i < size; ++i)
|
||||
value[i] = Node::read(buf, arena);
|
||||
}
|
||||
|
||||
inline std::optional<size_t> getBaseIndex(Data & data) const
|
||||
{
|
||||
if (data.value.size() == 0)
|
||||
return {};
|
||||
|
||||
switch (seq_base_kind)
|
||||
{
|
||||
case SequenceBase::Head:
|
||||
if (data.value[0]->can_be_base)
|
||||
return 0;
|
||||
break;
|
||||
|
||||
case SequenceBase::Tail:
|
||||
if (data.value[data.value.size() - 1]->can_be_base)
|
||||
return data.value.size() - 1;
|
||||
break;
|
||||
|
||||
case SequenceBase::FirstMatch:
|
||||
for (size_t i = 0; i < data.value.size(); ++i)
|
||||
{
|
||||
if (data.value[i]->events_bitset.test(0) && data.value[i]->can_be_base)
|
||||
return i;
|
||||
}
|
||||
break;
|
||||
|
||||
case SequenceBase::LastMatch:
|
||||
for (size_t i = 0; i < data.value.size(); ++i)
|
||||
{
|
||||
auto reversed_i = data.value.size() - i - 1;
|
||||
if (data.value[reversed_i]->events_bitset.test(0) && data.value[reversed_i]->can_be_base)
|
||||
return reversed_i;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
/// This method returns an index of next node that matched the events.
|
||||
/// matched events in the chain of events are represented as a bitmask.
|
||||
/// The first matched event is 0x00000001, the second one is 0x00000002, the third one is 0x00000004, and so on.
|
||||
UInt32 getNextNodeIndex(Data & data) const
|
||||
{
|
||||
const UInt32 unmatched_idx = static_cast<UInt32>(data.value.size());
|
||||
|
||||
if (data.value.size() <= events_size)
|
||||
return unmatched_idx;
|
||||
|
||||
data.sort();
|
||||
|
||||
std::optional<size_t> base_opt = getBaseIndex(data);
|
||||
if (!base_opt.has_value())
|
||||
return unmatched_idx;
|
||||
UInt32 base = static_cast<UInt32>(base_opt.value());
|
||||
|
||||
if (events_size == 0)
|
||||
return data.value.size() > 0 ? base : unmatched_idx;
|
||||
|
||||
UInt32 i = 0;
|
||||
switch (seq_direction)
|
||||
{
|
||||
case SequenceDirection::Forward:
|
||||
for (i = 0; i < events_size && base + i < data.value.size(); ++i)
|
||||
if (!data.value[base + i]->events_bitset.test(i))
|
||||
break;
|
||||
return (i == events_size) ? base + i : unmatched_idx;
|
||||
|
||||
case SequenceDirection::Backward:
|
||||
for (i = 0; i < events_size && i < base; ++i)
|
||||
if (!data.value[base - i]->events_bitset.test(i))
|
||||
break;
|
||||
return (i == events_size) ? base - i : unmatched_idx;
|
||||
}
|
||||
UNREACHABLE();
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto & value = data(place).value;
|
||||
|
||||
UInt32 event_idx = getNextNodeIndex(this->data(place));
|
||||
if (event_idx < value.size())
|
||||
{
|
||||
ColumnNullable & to_concrete = assert_cast<ColumnNullable &>(to);
|
||||
value[event_idx]->insertInto(to_concrete.getNestedColumn());
|
||||
to_concrete.getNullMapData().push_back(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
to.insertDefault();
|
||||
}
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return true; }
|
||||
};
|
||||
|
||||
}
|
@ -1,10 +1,21 @@
|
||||
#include <AggregateFunctions/AggregateFunctionSimpleLinearRegression.h>
|
||||
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <limits>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
@ -15,6 +26,161 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
struct AggregateFunctionSimpleLinearRegressionData final
|
||||
{
|
||||
size_t count = 0;
|
||||
Float64 sum_x = 0;
|
||||
Float64 sum_y = 0;
|
||||
Float64 sum_xx = 0;
|
||||
Float64 sum_xy = 0;
|
||||
|
||||
void add(Float64 x, Float64 y)
|
||||
{
|
||||
count += 1;
|
||||
sum_x += x;
|
||||
sum_y += y;
|
||||
sum_xx += x * x;
|
||||
sum_xy += x * y;
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionSimpleLinearRegressionData & other)
|
||||
{
|
||||
count += other.count;
|
||||
sum_x += other.sum_x;
|
||||
sum_y += other.sum_y;
|
||||
sum_xx += other.sum_xx;
|
||||
sum_xy += other.sum_xy;
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
writeBinary(count, buf);
|
||||
writeBinary(sum_x, buf);
|
||||
writeBinary(sum_y, buf);
|
||||
writeBinary(sum_xx, buf);
|
||||
writeBinary(sum_xy, buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
readBinary(count, buf);
|
||||
readBinary(sum_x, buf);
|
||||
readBinary(sum_y, buf);
|
||||
readBinary(sum_xx, buf);
|
||||
readBinary(sum_xy, buf);
|
||||
}
|
||||
|
||||
Float64 getK() const
|
||||
{
|
||||
Float64 divisor = sum_xx * count - sum_x * sum_x;
|
||||
|
||||
if (divisor == 0)
|
||||
return std::numeric_limits<Float64>::quiet_NaN();
|
||||
|
||||
return (sum_xy * count - sum_x * sum_y) / divisor;
|
||||
}
|
||||
|
||||
Float64 getB(Float64 k) const
|
||||
{
|
||||
if (count == 0)
|
||||
return std::numeric_limits<Float64>::quiet_NaN();
|
||||
|
||||
return (sum_y - k * sum_x) / count;
|
||||
}
|
||||
};
|
||||
|
||||
/// Calculates simple linear regression parameters.
|
||||
/// Result is a tuple (k, b) for y = k * x + b equation, solved by least squares approximation.
|
||||
class AggregateFunctionSimpleLinearRegression final : public IAggregateFunctionDataHelper<
|
||||
AggregateFunctionSimpleLinearRegressionData,
|
||||
AggregateFunctionSimpleLinearRegression>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionSimpleLinearRegression(
|
||||
const DataTypes & arguments,
|
||||
const Array & params
|
||||
):
|
||||
IAggregateFunctionDataHelper<
|
||||
AggregateFunctionSimpleLinearRegressionData,
|
||||
AggregateFunctionSimpleLinearRegression
|
||||
> {arguments, params, createResultType()}
|
||||
{
|
||||
// notice: arguments has been checked before
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return "simpleLinearRegression";
|
||||
}
|
||||
|
||||
void add(
|
||||
AggregateDataPtr __restrict place,
|
||||
const IColumn ** columns,
|
||||
size_t row_num,
|
||||
Arena *
|
||||
) const override
|
||||
{
|
||||
Float64 x = columns[0]->getFloat64(row_num);
|
||||
Float64 y = columns[1]->getFloat64(row_num);
|
||||
|
||||
this->data(place).add(x, y);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
DataTypes types
|
||||
{
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
};
|
||||
|
||||
Strings names
|
||||
{
|
||||
"k",
|
||||
"b",
|
||||
};
|
||||
|
||||
return std::make_shared<DataTypeTuple>(
|
||||
std::move(types),
|
||||
std::move(names)
|
||||
);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void insertResultInto(
|
||||
AggregateDataPtr __restrict place,
|
||||
IColumn & to,
|
||||
Arena *) const override
|
||||
{
|
||||
Float64 k = this->data(place).getK();
|
||||
Float64 b = this->data(place).getB(k);
|
||||
|
||||
auto & col_tuple = assert_cast<ColumnTuple &>(to);
|
||||
auto & col_k = assert_cast<ColumnVector<Float64> &>(col_tuple.getColumn(0));
|
||||
auto & col_b = assert_cast<ColumnVector<Float64> &>(col_tuple.getColumn(1));
|
||||
|
||||
col_k.getData().push_back(k);
|
||||
col_b.getData().push_back(b);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionSimpleLinearRegression(
|
||||
const String & name,
|
||||
const DataTypes & arguments,
|
||||
@ -24,51 +190,12 @@ AggregateFunctionPtr createAggregateFunctionSimpleLinearRegression(
|
||||
assertNoParameters(name, params);
|
||||
assertBinary(name, arguments);
|
||||
|
||||
const IDataType * x_arg = arguments.front().get();
|
||||
WhichDataType which_x = x_arg;
|
||||
if (!isNumber(arguments[0]) || !isNumber(arguments[1]))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Illegal types ({}, {}) of arguments of aggregate function {}, must "
|
||||
"be Native Ints, Native UInts or Floats", arguments[0]->getName(), arguments[1]->getName(), name);
|
||||
|
||||
const IDataType * y_arg = arguments.back().get();
|
||||
WhichDataType which_y = y_arg;
|
||||
|
||||
|
||||
#define FOR_LEASTSQR_TYPES_2(M, T) \
|
||||
M(T, UInt8) \
|
||||
M(T, UInt16) \
|
||||
M(T, UInt32) \
|
||||
M(T, UInt64) \
|
||||
M(T, Int8) \
|
||||
M(T, Int16) \
|
||||
M(T, Int32) \
|
||||
M(T, Int64) \
|
||||
M(T, Float32) \
|
||||
M(T, Float64)
|
||||
#define FOR_LEASTSQR_TYPES(M) \
|
||||
FOR_LEASTSQR_TYPES_2(M, UInt8) \
|
||||
FOR_LEASTSQR_TYPES_2(M, UInt16) \
|
||||
FOR_LEASTSQR_TYPES_2(M, UInt32) \
|
||||
FOR_LEASTSQR_TYPES_2(M, UInt64) \
|
||||
FOR_LEASTSQR_TYPES_2(M, Int8) \
|
||||
FOR_LEASTSQR_TYPES_2(M, Int16) \
|
||||
FOR_LEASTSQR_TYPES_2(M, Int32) \
|
||||
FOR_LEASTSQR_TYPES_2(M, Int64) \
|
||||
FOR_LEASTSQR_TYPES_2(M, Float32) \
|
||||
FOR_LEASTSQR_TYPES_2(M, Float64)
|
||||
#define DISPATCH(T1, T2) \
|
||||
if (which_x.idx == TypeIndex::T1 && which_y.idx == TypeIndex::T2) \
|
||||
return std::make_shared<AggregateFunctionSimpleLinearRegression<T1, T2>>(/* NOLINT */ \
|
||||
arguments, \
|
||||
params \
|
||||
);
|
||||
|
||||
FOR_LEASTSQR_TYPES(DISPATCH)
|
||||
|
||||
#undef FOR_LEASTSQR_TYPES_2
|
||||
#undef FOR_LEASTSQR_TYPES
|
||||
#undef DISPATCH
|
||||
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Illegal types ({}, {}) of arguments of aggregate function {}, must "
|
||||
"be Native Ints, Native UInts or Floats", x_arg->getName(), y_arg->getName(), name);
|
||||
return std::make_shared<AggregateFunctionSimpleLinearRegression>(arguments, params);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,182 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <limits>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct AggregateFunctionSimpleLinearRegressionData final
|
||||
{
|
||||
size_t count = 0;
|
||||
T sum_x = 0;
|
||||
T sum_y = 0;
|
||||
T sum_xx = 0;
|
||||
T sum_xy = 0;
|
||||
|
||||
void add(T x, T y)
|
||||
{
|
||||
count += 1;
|
||||
sum_x += x;
|
||||
sum_y += y;
|
||||
sum_xx += x * x;
|
||||
sum_xy += x * y;
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionSimpleLinearRegressionData & other)
|
||||
{
|
||||
count += other.count;
|
||||
sum_x += other.sum_x;
|
||||
sum_y += other.sum_y;
|
||||
sum_xx += other.sum_xx;
|
||||
sum_xy += other.sum_xy;
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
writeBinary(count, buf);
|
||||
writeBinary(sum_x, buf);
|
||||
writeBinary(sum_y, buf);
|
||||
writeBinary(sum_xx, buf);
|
||||
writeBinary(sum_xy, buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
readBinary(count, buf);
|
||||
readBinary(sum_x, buf);
|
||||
readBinary(sum_y, buf);
|
||||
readBinary(sum_xx, buf);
|
||||
readBinary(sum_xy, buf);
|
||||
}
|
||||
|
||||
T getK() const
|
||||
{
|
||||
T divisor = sum_xx * count - sum_x * sum_x;
|
||||
|
||||
if (divisor == 0)
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
|
||||
return (sum_xy * count - sum_x * sum_y) / divisor;
|
||||
}
|
||||
|
||||
T getB(T k) const
|
||||
{
|
||||
if (count == 0)
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
|
||||
return (sum_y - k * sum_x) / count;
|
||||
}
|
||||
};
|
||||
|
||||
/// Calculates simple linear regression parameters.
|
||||
/// Result is a tuple (k, b) for y = k * x + b equation, solved by least squares approximation.
|
||||
template <typename X, typename Y, typename Ret = Float64>
|
||||
class AggregateFunctionSimpleLinearRegression final : public IAggregateFunctionDataHelper<
|
||||
AggregateFunctionSimpleLinearRegressionData<Ret>,
|
||||
AggregateFunctionSimpleLinearRegression<X, Y, Ret>
|
||||
>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionSimpleLinearRegression(
|
||||
const DataTypes & arguments,
|
||||
const Array & params
|
||||
):
|
||||
IAggregateFunctionDataHelper<
|
||||
AggregateFunctionSimpleLinearRegressionData<Ret>,
|
||||
AggregateFunctionSimpleLinearRegression<X, Y, Ret>
|
||||
> {arguments, params, createResultType()}
|
||||
{
|
||||
// notice: arguments has been checked before
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return "simpleLinearRegression";
|
||||
}
|
||||
|
||||
void add(
|
||||
AggregateDataPtr __restrict place,
|
||||
const IColumn ** columns,
|
||||
size_t row_num,
|
||||
Arena *
|
||||
) const override
|
||||
{
|
||||
auto col_x = assert_cast<const ColumnVector<X> *>(columns[0]);
|
||||
auto col_y = assert_cast<const ColumnVector<Y> *>(columns[1]);
|
||||
|
||||
X x = col_x->getData()[row_num];
|
||||
Y y = col_y->getData()[row_num];
|
||||
|
||||
this->data(place).add(x, y);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
DataTypes types
|
||||
{
|
||||
std::make_shared<DataTypeNumber<Ret>>(),
|
||||
std::make_shared<DataTypeNumber<Ret>>(),
|
||||
};
|
||||
|
||||
Strings names
|
||||
{
|
||||
"k",
|
||||
"b",
|
||||
};
|
||||
|
||||
return std::make_shared<DataTypeTuple>(
|
||||
std::move(types),
|
||||
std::move(names)
|
||||
);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void insertResultInto(
|
||||
AggregateDataPtr __restrict place,
|
||||
IColumn & to,
|
||||
Arena *) const override
|
||||
{
|
||||
Ret k = this->data(place).getK();
|
||||
Ret b = this->data(place).getB(k);
|
||||
|
||||
auto & col_tuple = assert_cast<ColumnTuple &>(to);
|
||||
auto & col_k = assert_cast<ColumnVector<Ret> &>(col_tuple.getColumn(0));
|
||||
auto & col_b = assert_cast<ColumnVector<Ret> &>(col_tuple.getColumn(1));
|
||||
|
||||
col_k.getData().push_back(k);
|
||||
col_b.getData().push_back(b);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -1,8 +1,18 @@
|
||||
#include <AggregateFunctions/AggregateFunctionSparkbar.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
|
||||
#include <array>
|
||||
#include <string_view>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Common/PODArray.h>
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <Common/HashTable/HashMap.h>
|
||||
#include <Columns/IColumn.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -13,11 +23,309 @@ namespace ErrorCodes
|
||||
{
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
template<typename X, typename Y>
|
||||
struct AggregateFunctionSparkbarData
|
||||
{
|
||||
/// TODO: calculate histogram instead of storing all points
|
||||
using Points = HashMap<X, Y>;
|
||||
Points points;
|
||||
|
||||
X min_x = std::numeric_limits<X>::max();
|
||||
X max_x = std::numeric_limits<X>::lowest();
|
||||
|
||||
Y min_y = std::numeric_limits<Y>::max();
|
||||
Y max_y = std::numeric_limits<Y>::lowest();
|
||||
|
||||
Y insert(const X & x, const Y & y)
|
||||
{
|
||||
if (isNaN(y) || y <= 0)
|
||||
return 0;
|
||||
|
||||
auto [it, inserted] = points.insert({x, y});
|
||||
if (!inserted)
|
||||
{
|
||||
if constexpr (std::is_floating_point_v<Y>)
|
||||
{
|
||||
it->getMapped() += y;
|
||||
return it->getMapped();
|
||||
}
|
||||
else
|
||||
{
|
||||
Y res;
|
||||
bool has_overfllow = common::addOverflow(it->getMapped(), y, res);
|
||||
it->getMapped() = has_overfllow ? std::numeric_limits<Y>::max() : res;
|
||||
}
|
||||
}
|
||||
return it->getMapped();
|
||||
}
|
||||
|
||||
void add(X x, Y y)
|
||||
{
|
||||
auto new_y = insert(x, y);
|
||||
|
||||
min_x = std::min(x, min_x);
|
||||
max_x = std::max(x, max_x);
|
||||
|
||||
min_y = std::min(y, min_y);
|
||||
max_y = std::max(new_y, max_y);
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionSparkbarData & other)
|
||||
{
|
||||
if (other.points.empty())
|
||||
return;
|
||||
|
||||
for (auto & point : other.points)
|
||||
{
|
||||
auto new_y = insert(point.getKey(), point.getMapped());
|
||||
max_y = std::max(new_y, max_y);
|
||||
}
|
||||
|
||||
min_x = std::min(other.min_x, min_x);
|
||||
max_x = std::max(other.max_x, max_x);
|
||||
|
||||
min_y = std::min(other.min_y, min_y);
|
||||
max_y = std::max(other.max_y, max_y);
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
writeBinary(min_x, buf);
|
||||
writeBinary(max_x, buf);
|
||||
writeBinary(min_y, buf);
|
||||
writeBinary(max_y, buf);
|
||||
writeVarUInt(points.size(), buf);
|
||||
|
||||
for (const auto & elem : points)
|
||||
{
|
||||
writeBinary(elem.getKey(), buf);
|
||||
writeBinary(elem.getMapped(), buf);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
readBinary(min_x, buf);
|
||||
readBinary(max_x, buf);
|
||||
readBinary(min_y, buf);
|
||||
readBinary(max_y, buf);
|
||||
size_t size;
|
||||
readVarUInt(size, buf);
|
||||
|
||||
X x;
|
||||
Y y;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
readBinary(x, buf);
|
||||
readBinary(y, buf);
|
||||
insert(x, y);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<typename X, typename Y>
|
||||
class AggregateFunctionSparkbar final
|
||||
: public IAggregateFunctionDataHelper<AggregateFunctionSparkbarData<X, Y>, AggregateFunctionSparkbar<X, Y>>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr size_t BAR_LEVELS = 8;
|
||||
const size_t width = 0;
|
||||
|
||||
/// Range for x specified in parameters.
|
||||
const bool is_specified_range_x = false;
|
||||
const X begin_x = std::numeric_limits<X>::min();
|
||||
const X end_x = std::numeric_limits<X>::max();
|
||||
|
||||
size_t updateFrame(ColumnString::Chars & frame, Y value) const
|
||||
{
|
||||
static constexpr std::array<std::string_view, BAR_LEVELS + 1> bars{" ", "▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"};
|
||||
const auto & bar = (isNaN(value) || value < 1 || static_cast<Y>(BAR_LEVELS) < value) ? bars[0] : bars[static_cast<UInt8>(value)];
|
||||
frame.insert(bar.begin(), bar.end());
|
||||
return bar.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* The minimum value of y is rendered as the lowest height "▁",
|
||||
* the maximum value of y is rendered as the highest height "█", and the middle value will be rendered proportionally.
|
||||
* If a bucket has no y value, it will be rendered as " ".
|
||||
*/
|
||||
void render(ColumnString & to_column, const AggregateFunctionSparkbarData<X, Y> & data) const
|
||||
{
|
||||
auto & values = to_column.getChars();
|
||||
auto & offsets = to_column.getOffsets();
|
||||
|
||||
if (data.points.empty())
|
||||
{
|
||||
values.push_back('\0');
|
||||
offsets.push_back(offsets.empty() ? 1 : offsets.back() + 1);
|
||||
return;
|
||||
}
|
||||
|
||||
auto from_x = is_specified_range_x ? begin_x : data.min_x;
|
||||
auto to_x = is_specified_range_x ? end_x : data.max_x;
|
||||
|
||||
if (from_x >= to_x)
|
||||
{
|
||||
size_t sz = updateFrame(values, 8);
|
||||
values.push_back('\0');
|
||||
offsets.push_back(offsets.empty() ? sz + 1 : offsets.back() + sz + 1);
|
||||
return;
|
||||
}
|
||||
|
||||
PaddedPODArray<Y> histogram(width, 0);
|
||||
PaddedPODArray<UInt64> count_histogram(width, 0); /// The number of points in each bucket
|
||||
|
||||
for (const auto & point : data.points)
|
||||
{
|
||||
if (point.getKey() < from_x || to_x < point.getKey())
|
||||
continue;
|
||||
|
||||
X delta = to_x - from_x;
|
||||
if (delta < std::numeric_limits<X>::max())
|
||||
delta = delta + 1;
|
||||
|
||||
X value = point.getKey() - from_x;
|
||||
Float64 w = histogram.size();
|
||||
size_t index = std::min<size_t>(static_cast<size_t>(w / delta * value), histogram.size() - 1);
|
||||
|
||||
Y res;
|
||||
bool has_overfllow = false;
|
||||
if constexpr (std::is_floating_point_v<Y>)
|
||||
res = histogram[index] + point.getMapped();
|
||||
else
|
||||
has_overfllow = common::addOverflow(histogram[index], point.getMapped(), res);
|
||||
|
||||
if (unlikely(has_overfllow))
|
||||
{
|
||||
/// In case of overflow, just saturate
|
||||
/// Do not count new values, because we do not know how many of them were added
|
||||
histogram[index] = std::numeric_limits<Y>::max();
|
||||
}
|
||||
else
|
||||
{
|
||||
histogram[index] = res;
|
||||
count_histogram[index] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < histogram.size(); ++i)
|
||||
{
|
||||
if (count_histogram[i] > 0)
|
||||
histogram[i] /= count_histogram[i];
|
||||
}
|
||||
|
||||
Y y_max = 0;
|
||||
for (auto & y : histogram)
|
||||
{
|
||||
if (isNaN(y) || y <= 0)
|
||||
continue;
|
||||
y_max = std::max(y_max, y);
|
||||
}
|
||||
|
||||
if (y_max == 0)
|
||||
{
|
||||
values.push_back('\0');
|
||||
offsets.push_back(offsets.empty() ? 1 : offsets.back() + 1);
|
||||
return;
|
||||
}
|
||||
|
||||
/// Scale the histogram to the range [0, BAR_LEVELS]
|
||||
for (auto & y : histogram)
|
||||
{
|
||||
if (isNaN(y) || y <= 0)
|
||||
{
|
||||
y = 0;
|
||||
continue;
|
||||
}
|
||||
|
||||
constexpr auto levels_num = static_cast<Y>(BAR_LEVELS - 1);
|
||||
if constexpr (std::is_floating_point_v<Y>)
|
||||
{
|
||||
y = y / (y_max / levels_num) + 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
Y scaled;
|
||||
bool has_overfllow = common::mulOverflow<Y>(y, levels_num, scaled);
|
||||
|
||||
if (has_overfllow)
|
||||
y = y / (y_max / levels_num) + 1;
|
||||
else
|
||||
y = scaled / y_max + 1;
|
||||
}
|
||||
}
|
||||
|
||||
size_t sz = 0;
|
||||
for (const auto & y : histogram)
|
||||
sz += updateFrame(values, y);
|
||||
|
||||
values.push_back('\0');
|
||||
offsets.push_back(offsets.empty() ? sz + 1 : offsets.back() + sz + 1);
|
||||
}
|
||||
|
||||
public:
|
||||
AggregateFunctionSparkbar(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionSparkbarData<X, Y>, AggregateFunctionSparkbar>(arguments, params, std::make_shared<DataTypeString>())
|
||||
, width(params.empty() ? 0 : params.at(0).safeGet<UInt64>())
|
||||
, is_specified_range_x(params.size() >= 3)
|
||||
, begin_x(is_specified_range_x ? static_cast<X>(params.at(1).safeGet<X>()) : std::numeric_limits<X>::min())
|
||||
, end_x(is_specified_range_x ? static_cast<X>(params.at(2).safeGet<X>()) : std::numeric_limits<X>::max())
|
||||
{
|
||||
if (width < 2 || 1024 < width)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter width must be in range [2, 1024]");
|
||||
|
||||
if (begin_x >= end_x)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter `min_x` must be less than `max_x`");
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return "sparkbar";
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * /*arena*/) const override
|
||||
{
|
||||
X x = assert_cast<const ColumnVector<X> *>(columns[0])->getData()[row_num];
|
||||
if (begin_x <= x && x <= end_x)
|
||||
{
|
||||
Y y = assert_cast<const ColumnVector<Y> *>(columns[1])->getData()[row_num];
|
||||
this->data(place).add(x, y);
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr __restrict rhs, Arena * /*arena*/) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * /*arena*/) const override
|
||||
{
|
||||
auto & to_column = assert_cast<ColumnString &>(to);
|
||||
const auto & data = this->data(place);
|
||||
render(to_column, data);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <template <typename, typename> class AggregateFunctionTemplate, typename Data, typename ... TArgs>
|
||||
IAggregateFunction * createWithUIntegerOrTimeType(const std::string & name, const IDataType & argument_type, TArgs && ... args)
|
||||
{
|
||||
|
@ -1,323 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <base/arithmeticOverflow.h>
|
||||
|
||||
#include <array>
|
||||
#include <string_view>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <base/range.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Common/PODArray.h>
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <Common/HashTable/HashMap.h>
|
||||
#include <Columns/IColumn.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
template<typename X, typename Y>
|
||||
struct AggregateFunctionSparkbarData
|
||||
{
|
||||
/// TODO: calculate histogram instead of storing all points
|
||||
using Points = HashMap<X, Y>;
|
||||
Points points;
|
||||
|
||||
X min_x = std::numeric_limits<X>::max();
|
||||
X max_x = std::numeric_limits<X>::lowest();
|
||||
|
||||
Y min_y = std::numeric_limits<Y>::max();
|
||||
Y max_y = std::numeric_limits<Y>::lowest();
|
||||
|
||||
Y insert(const X & x, const Y & y)
|
||||
{
|
||||
if (isNaN(y) || y <= 0)
|
||||
return 0;
|
||||
|
||||
auto [it, inserted] = points.insert({x, y});
|
||||
if (!inserted)
|
||||
{
|
||||
if constexpr (std::is_floating_point_v<Y>)
|
||||
{
|
||||
it->getMapped() += y;
|
||||
return it->getMapped();
|
||||
}
|
||||
else
|
||||
{
|
||||
Y res;
|
||||
bool has_overfllow = common::addOverflow(it->getMapped(), y, res);
|
||||
it->getMapped() = has_overfllow ? std::numeric_limits<Y>::max() : res;
|
||||
}
|
||||
}
|
||||
return it->getMapped();
|
||||
}
|
||||
|
||||
void add(X x, Y y)
|
||||
{
|
||||
auto new_y = insert(x, y);
|
||||
|
||||
min_x = std::min(x, min_x);
|
||||
max_x = std::max(x, max_x);
|
||||
|
||||
min_y = std::min(y, min_y);
|
||||
max_y = std::max(new_y, max_y);
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionSparkbarData & other)
|
||||
{
|
||||
if (other.points.empty())
|
||||
return;
|
||||
|
||||
for (auto & point : other.points)
|
||||
{
|
||||
auto new_y = insert(point.getKey(), point.getMapped());
|
||||
max_y = std::max(new_y, max_y);
|
||||
}
|
||||
|
||||
min_x = std::min(other.min_x, min_x);
|
||||
max_x = std::max(other.max_x, max_x);
|
||||
|
||||
min_y = std::min(other.min_y, min_y);
|
||||
max_y = std::max(other.max_y, max_y);
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
writeBinary(min_x, buf);
|
||||
writeBinary(max_x, buf);
|
||||
writeBinary(min_y, buf);
|
||||
writeBinary(max_y, buf);
|
||||
writeVarUInt(points.size(), buf);
|
||||
|
||||
for (const auto & elem : points)
|
||||
{
|
||||
writeBinary(elem.getKey(), buf);
|
||||
writeBinary(elem.getMapped(), buf);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
readBinary(min_x, buf);
|
||||
readBinary(max_x, buf);
|
||||
readBinary(min_y, buf);
|
||||
readBinary(max_y, buf);
|
||||
size_t size;
|
||||
readVarUInt(size, buf);
|
||||
|
||||
X x;
|
||||
Y y;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
readBinary(x, buf);
|
||||
readBinary(y, buf);
|
||||
insert(x, y);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<typename X, typename Y>
|
||||
class AggregateFunctionSparkbar final
|
||||
: public IAggregateFunctionDataHelper<AggregateFunctionSparkbarData<X, Y>, AggregateFunctionSparkbar<X, Y>>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr size_t BAR_LEVELS = 8;
|
||||
const size_t width = 0;
|
||||
|
||||
/// Range for x specified in parameters.
|
||||
const bool is_specified_range_x = false;
|
||||
const X begin_x = std::numeric_limits<X>::min();
|
||||
const X end_x = std::numeric_limits<X>::max();
|
||||
|
||||
size_t updateFrame(ColumnString::Chars & frame, Y value) const
|
||||
{
|
||||
static constexpr std::array<std::string_view, BAR_LEVELS + 1> bars{" ", "▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"};
|
||||
const auto & bar = (isNaN(value) || value < 1 || static_cast<Y>(BAR_LEVELS) < value) ? bars[0] : bars[static_cast<UInt8>(value)];
|
||||
frame.insert(bar.begin(), bar.end());
|
||||
return bar.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* The minimum value of y is rendered as the lowest height "▁",
|
||||
* the maximum value of y is rendered as the highest height "█", and the middle value will be rendered proportionally.
|
||||
* If a bucket has no y value, it will be rendered as " ".
|
||||
*/
|
||||
void render(ColumnString & to_column, const AggregateFunctionSparkbarData<X, Y> & data) const
|
||||
{
|
||||
auto & values = to_column.getChars();
|
||||
auto & offsets = to_column.getOffsets();
|
||||
|
||||
if (data.points.empty())
|
||||
{
|
||||
values.push_back('\0');
|
||||
offsets.push_back(offsets.empty() ? 1 : offsets.back() + 1);
|
||||
return;
|
||||
}
|
||||
|
||||
auto from_x = is_specified_range_x ? begin_x : data.min_x;
|
||||
auto to_x = is_specified_range_x ? end_x : data.max_x;
|
||||
|
||||
if (from_x >= to_x)
|
||||
{
|
||||
size_t sz = updateFrame(values, 8);
|
||||
values.push_back('\0');
|
||||
offsets.push_back(offsets.empty() ? sz + 1 : offsets.back() + sz + 1);
|
||||
return;
|
||||
}
|
||||
|
||||
PaddedPODArray<Y> histogram(width, 0);
|
||||
PaddedPODArray<UInt64> count_histogram(width, 0); /// The number of points in each bucket
|
||||
|
||||
for (const auto & point : data.points)
|
||||
{
|
||||
if (point.getKey() < from_x || to_x < point.getKey())
|
||||
continue;
|
||||
|
||||
X delta = to_x - from_x;
|
||||
if (delta < std::numeric_limits<X>::max())
|
||||
delta = delta + 1;
|
||||
|
||||
X value = point.getKey() - from_x;
|
||||
Float64 w = histogram.size();
|
||||
size_t index = std::min<size_t>(static_cast<size_t>(w / delta * value), histogram.size() - 1);
|
||||
|
||||
Y res;
|
||||
bool has_overfllow = false;
|
||||
if constexpr (std::is_floating_point_v<Y>)
|
||||
res = histogram[index] + point.getMapped();
|
||||
else
|
||||
has_overfllow = common::addOverflow(histogram[index], point.getMapped(), res);
|
||||
|
||||
if (unlikely(has_overfllow))
|
||||
{
|
||||
/// In case of overflow, just saturate
|
||||
/// Do not count new values, because we do not know how many of them were added
|
||||
histogram[index] = std::numeric_limits<Y>::max();
|
||||
}
|
||||
else
|
||||
{
|
||||
histogram[index] = res;
|
||||
count_histogram[index] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < histogram.size(); ++i)
|
||||
{
|
||||
if (count_histogram[i] > 0)
|
||||
histogram[i] /= count_histogram[i];
|
||||
}
|
||||
|
||||
Y y_max = 0;
|
||||
for (auto & y : histogram)
|
||||
{
|
||||
if (isNaN(y) || y <= 0)
|
||||
continue;
|
||||
y_max = std::max(y_max, y);
|
||||
}
|
||||
|
||||
if (y_max == 0)
|
||||
{
|
||||
values.push_back('\0');
|
||||
offsets.push_back(offsets.empty() ? 1 : offsets.back() + 1);
|
||||
return;
|
||||
}
|
||||
|
||||
/// Scale the histogram to the range [0, BAR_LEVELS]
|
||||
for (auto & y : histogram)
|
||||
{
|
||||
if (isNaN(y) || y <= 0)
|
||||
{
|
||||
y = 0;
|
||||
continue;
|
||||
}
|
||||
|
||||
constexpr auto levels_num = static_cast<Y>(BAR_LEVELS - 1);
|
||||
if constexpr (std::is_floating_point_v<Y>)
|
||||
{
|
||||
y = y / (y_max / levels_num) + 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
Y scaled;
|
||||
bool has_overfllow = common::mulOverflow<Y>(y, levels_num, scaled);
|
||||
|
||||
if (has_overfllow)
|
||||
y = y / (y_max / levels_num) + 1;
|
||||
else
|
||||
y = scaled / y_max + 1;
|
||||
}
|
||||
}
|
||||
|
||||
size_t sz = 0;
|
||||
for (const auto & y : histogram)
|
||||
sz += updateFrame(values, y);
|
||||
|
||||
values.push_back('\0');
|
||||
offsets.push_back(offsets.empty() ? sz + 1 : offsets.back() + sz + 1);
|
||||
}
|
||||
|
||||
public:
|
||||
AggregateFunctionSparkbar(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionSparkbarData<X, Y>, AggregateFunctionSparkbar>(arguments, params, std::make_shared<DataTypeString>())
|
||||
, width(params.empty() ? 0 : params.at(0).safeGet<UInt64>())
|
||||
, is_specified_range_x(params.size() >= 3)
|
||||
, begin_x(is_specified_range_x ? static_cast<X>(params.at(1).safeGet<X>()) : std::numeric_limits<X>::min())
|
||||
, end_x(is_specified_range_x ? static_cast<X>(params.at(2).safeGet<X>()) : std::numeric_limits<X>::max())
|
||||
{
|
||||
if (width < 2 || 1024 < width)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter width must be in range [2, 1024]");
|
||||
|
||||
if (begin_x >= end_x)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter `min_x` must be less than `max_x`");
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return "sparkbar";
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * /*arena*/) const override
|
||||
{
|
||||
X x = assert_cast<const ColumnVector<X> *>(columns[0])->getData()[row_num];
|
||||
if (begin_x <= x && x <= end_x)
|
||||
{
|
||||
Y y = assert_cast<const ColumnVector<Y> *>(columns[1])->getData()[row_num];
|
||||
this->data(place).add(x, y);
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr __restrict rhs, Arena * /*arena*/) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * /*arena*/) const override
|
||||
{
|
||||
auto & to_column = assert_cast<ColumnString &>(to);
|
||||
const auto & data = this->data(place);
|
||||
render(to_column, data);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -1,7 +1,15 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <AggregateFunctions/AggregateFunctionStatistics.h>
|
||||
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <cmath>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -16,6 +24,430 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
/// This function returns true if both values are large and comparable.
|
||||
/// It is used to calculate the mean value by merging two sources.
|
||||
/// It means that if the sizes of both sources are large and comparable, then we must apply a special
|
||||
/// formula guaranteeing more stability.
|
||||
bool areComparable(UInt64 a, UInt64 b)
|
||||
{
|
||||
const Float64 sensitivity = 0.001;
|
||||
const UInt64 threshold = 10000;
|
||||
|
||||
if ((a == 0) || (b == 0))
|
||||
return false;
|
||||
|
||||
auto res = std::minmax(a, b);
|
||||
return (((1 - static_cast<Float64>(res.first) / res.second) < sensitivity) && (res.first > threshold));
|
||||
}
|
||||
|
||||
|
||||
/** Statistical aggregate functions
|
||||
* varSamp - sample variance
|
||||
* stddevSamp - mean sample quadratic deviation
|
||||
* varPop - variance
|
||||
* stddevPop - standard deviation
|
||||
* covarSamp - selective covariance
|
||||
* covarPop - covariance
|
||||
* corr - correlation
|
||||
*/
|
||||
|
||||
/** Parallel and incremental algorithm for calculating variance.
|
||||
* Source: "Updating formulae and a pairwise algorithm for computing sample variances"
|
||||
* (Chan et al., Stanford University, 12.1979)
|
||||
*/
|
||||
struct AggregateFunctionVarianceData
|
||||
{
|
||||
void update(const IColumn & column, size_t row_num)
|
||||
{
|
||||
Float64 val = column.getFloat64(row_num);
|
||||
Float64 delta = val - mean;
|
||||
|
||||
++count;
|
||||
mean += delta / count;
|
||||
m2 += delta * (val - mean);
|
||||
}
|
||||
|
||||
void mergeWith(const AggregateFunctionVarianceData & source)
|
||||
{
|
||||
UInt64 total_count = count + source.count;
|
||||
if (total_count == 0)
|
||||
return;
|
||||
|
||||
Float64 factor = static_cast<Float64>(count * source.count) / total_count;
|
||||
Float64 delta = mean - source.mean;
|
||||
|
||||
if (areComparable(count, source.count))
|
||||
mean = (source.count * source.mean + count * mean) / total_count;
|
||||
else
|
||||
mean = source.mean + delta * (static_cast<Float64>(count) / total_count);
|
||||
|
||||
m2 += source.m2 + delta * delta * factor;
|
||||
count = total_count;
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
writeVarUInt(count, buf);
|
||||
writeBinary(mean, buf);
|
||||
writeBinary(m2, buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
readVarUInt(count, buf);
|
||||
readBinary(mean, buf);
|
||||
readBinary(m2, buf);
|
||||
}
|
||||
|
||||
UInt64 count = 0;
|
||||
Float64 mean = 0.0;
|
||||
Float64 m2 = 0.0;
|
||||
};
|
||||
|
||||
enum class VarKind
|
||||
{
|
||||
varSampStable,
|
||||
stddevSampStable,
|
||||
varPopStable,
|
||||
stddevPopStable,
|
||||
};
|
||||
|
||||
/** The main code for the implementation of varSamp, stddevSamp, varPop, stddevPop.
|
||||
*/
|
||||
class AggregateFunctionVariance final
|
||||
: public IAggregateFunctionDataHelper<AggregateFunctionVarianceData, AggregateFunctionVariance>
|
||||
{
|
||||
private:
|
||||
VarKind kind;
|
||||
|
||||
static Float64 getVarSamp(Float64 m2, UInt64 count)
|
||||
{
|
||||
if (count < 2)
|
||||
return std::numeric_limits<Float64>::infinity();
|
||||
else
|
||||
return m2 / (count - 1);
|
||||
}
|
||||
|
||||
static Float64 getStddevSamp(Float64 m2, UInt64 count)
|
||||
{
|
||||
return sqrt(getVarSamp(m2, count));
|
||||
}
|
||||
|
||||
static Float64 getVarPop(Float64 m2, UInt64 count)
|
||||
{
|
||||
if (count == 0)
|
||||
return std::numeric_limits<Float64>::infinity();
|
||||
else if (count == 1)
|
||||
return 0.0;
|
||||
else
|
||||
return m2 / count;
|
||||
}
|
||||
|
||||
static Float64 getStddevPop(Float64 m2, UInt64 count)
|
||||
{
|
||||
return sqrt(getVarPop(m2, count));
|
||||
}
|
||||
|
||||
Float64 getResult(ConstAggregateDataPtr __restrict place) const
|
||||
{
|
||||
const auto & data = this->data(place);
|
||||
switch (kind)
|
||||
{
|
||||
case VarKind::varSampStable: return getVarSamp(data.m2, data.count);
|
||||
case VarKind::stddevSampStable: return getStddevSamp(data.m2, data.count);
|
||||
case VarKind::varPopStable: return getVarPop(data.m2, data.count);
|
||||
case VarKind::stddevPopStable: return getStddevPop(data.m2, data.count);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
explicit AggregateFunctionVariance(VarKind kind_, const DataTypePtr & arg)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionVarianceData, AggregateFunctionVariance>({arg}, {}, std::make_shared<DataTypeFloat64>()),
|
||||
kind(kind_)
|
||||
{
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
case VarKind::varSampStable: return "varSampStable";
|
||||
case VarKind::stddevSampStable: return "stddevSampStable";
|
||||
case VarKind::varPopStable: return "varPopStable";
|
||||
case VarKind::stddevPopStable: return "stddevPopStable";
|
||||
}
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
this->data(place).update(*columns[0], row_num);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).mergeWith(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
assert_cast<ColumnFloat64 &>(to).getData().push_back(getResult(place));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/** If `compute_marginal_moments` flag is set this class provides the successor
|
||||
* CovarianceData support of marginal moments for calculating the correlation.
|
||||
*/
|
||||
template <bool compute_marginal_moments>
|
||||
struct BaseCovarianceData
|
||||
{
|
||||
void incrementMarginalMoments(Float64, Float64) {}
|
||||
void mergeWith(const BaseCovarianceData &) {}
|
||||
void serialize(WriteBuffer &) const {}
|
||||
void deserialize(const ReadBuffer &) {}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct BaseCovarianceData<true>
|
||||
{
|
||||
void incrementMarginalMoments(Float64 left_incr, Float64 right_incr)
|
||||
{
|
||||
left_m2 += left_incr;
|
||||
right_m2 += right_incr;
|
||||
}
|
||||
|
||||
void mergeWith(const BaseCovarianceData & source)
|
||||
{
|
||||
left_m2 += source.left_m2;
|
||||
right_m2 += source.right_m2;
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
writeBinary(left_m2, buf);
|
||||
writeBinary(right_m2, buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
readBinary(left_m2, buf);
|
||||
readBinary(right_m2, buf);
|
||||
}
|
||||
|
||||
Float64 left_m2 = 0.0;
|
||||
Float64 right_m2 = 0.0;
|
||||
};
|
||||
|
||||
/** Parallel and incremental algorithm for calculating covariance.
|
||||
* Source: "Numerically Stable, Single-Pass, Parallel Statistics Algorithms"
|
||||
* (J. Bennett et al., Sandia National Laboratories,
|
||||
* 2009 IEEE International Conference on Cluster Computing)
|
||||
*/
|
||||
template <bool compute_marginal_moments>
|
||||
struct CovarianceData : public BaseCovarianceData<compute_marginal_moments>
|
||||
{
|
||||
using Base = BaseCovarianceData<compute_marginal_moments>;
|
||||
|
||||
void update(const IColumn & column_left, const IColumn & column_right, size_t row_num)
|
||||
{
|
||||
Float64 left_val = column_left.getFloat64(row_num);
|
||||
Float64 left_delta = left_val - left_mean;
|
||||
|
||||
Float64 right_val = column_right.getFloat64(row_num);
|
||||
Float64 right_delta = right_val - right_mean;
|
||||
|
||||
Float64 old_right_mean = right_mean;
|
||||
|
||||
++count;
|
||||
|
||||
left_mean += left_delta / count;
|
||||
right_mean += right_delta / count;
|
||||
co_moment += (left_val - left_mean) * (right_val - old_right_mean);
|
||||
|
||||
/// Update the marginal moments, if any.
|
||||
if (compute_marginal_moments)
|
||||
{
|
||||
Float64 left_incr = left_delta * (left_val - left_mean);
|
||||
Float64 right_incr = right_delta * (right_val - right_mean);
|
||||
Base::incrementMarginalMoments(left_incr, right_incr);
|
||||
}
|
||||
}
|
||||
|
||||
void mergeWith(const CovarianceData & source)
|
||||
{
|
||||
UInt64 total_count = count + source.count;
|
||||
if (total_count == 0)
|
||||
return;
|
||||
|
||||
Float64 factor = static_cast<Float64>(count * source.count) / total_count;
|
||||
Float64 left_delta = left_mean - source.left_mean;
|
||||
Float64 right_delta = right_mean - source.right_mean;
|
||||
|
||||
if (areComparable(count, source.count))
|
||||
{
|
||||
left_mean = (source.count * source.left_mean + count * left_mean) / total_count;
|
||||
right_mean = (source.count * source.right_mean + count * right_mean) / total_count;
|
||||
}
|
||||
else
|
||||
{
|
||||
left_mean = source.left_mean + left_delta * (static_cast<Float64>(count) / total_count);
|
||||
right_mean = source.right_mean + right_delta * (static_cast<Float64>(count) / total_count);
|
||||
}
|
||||
|
||||
co_moment += source.co_moment + left_delta * right_delta * factor;
|
||||
count = total_count;
|
||||
|
||||
/// Update the marginal moments, if any.
|
||||
if (compute_marginal_moments)
|
||||
{
|
||||
Float64 left_incr = left_delta * left_delta * factor;
|
||||
Float64 right_incr = right_delta * right_delta * factor;
|
||||
Base::mergeWith(source);
|
||||
Base::incrementMarginalMoments(left_incr, right_incr);
|
||||
}
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
writeVarUInt(count, buf);
|
||||
writeBinary(left_mean, buf);
|
||||
writeBinary(right_mean, buf);
|
||||
writeBinary(co_moment, buf);
|
||||
Base::serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
readVarUInt(count, buf);
|
||||
readBinary(left_mean, buf);
|
||||
readBinary(right_mean, buf);
|
||||
readBinary(co_moment, buf);
|
||||
Base::deserialize(buf);
|
||||
}
|
||||
|
||||
UInt64 count = 0;
|
||||
Float64 left_mean = 0.0;
|
||||
Float64 right_mean = 0.0;
|
||||
Float64 co_moment = 0.0;
|
||||
};
|
||||
|
||||
enum class CovarKind
|
||||
{
|
||||
covarSampStable,
|
||||
covarPopStable,
|
||||
corrStable,
|
||||
};
|
||||
|
||||
template <bool compute_marginal_moments>
|
||||
class AggregateFunctionCovariance final
|
||||
: public IAggregateFunctionDataHelper<
|
||||
CovarianceData<compute_marginal_moments>,
|
||||
AggregateFunctionCovariance<compute_marginal_moments>>
|
||||
{
|
||||
private:
|
||||
CovarKind kind;
|
||||
|
||||
static Float64 getCovarSamp(Float64 co_moment, UInt64 count)
|
||||
{
|
||||
if (count < 2)
|
||||
return std::numeric_limits<Float64>::infinity();
|
||||
else
|
||||
return co_moment / (count - 1);
|
||||
}
|
||||
|
||||
static Float64 getCovarPop(Float64 co_moment, UInt64 count)
|
||||
{
|
||||
if (count == 0)
|
||||
return std::numeric_limits<Float64>::infinity();
|
||||
else if (count == 1)
|
||||
return 0.0;
|
||||
else
|
||||
return co_moment / count;
|
||||
}
|
||||
|
||||
static Float64 getCorr(Float64 co_moment, Float64 left_m2, Float64 right_m2, UInt64 count)
|
||||
{
|
||||
if (count < 2)
|
||||
return std::numeric_limits<Float64>::infinity();
|
||||
else
|
||||
return co_moment / sqrt(left_m2 * right_m2);
|
||||
}
|
||||
|
||||
Float64 getResult(ConstAggregateDataPtr __restrict place) const
|
||||
{
|
||||
const auto & data = this->data(place);
|
||||
switch (kind)
|
||||
{
|
||||
case CovarKind::covarSampStable: return getCovarSamp(data.co_moment, data.count);
|
||||
case CovarKind::covarPopStable: return getCovarPop(data.co_moment, data.count);
|
||||
|
||||
case CovarKind::corrStable:
|
||||
if constexpr (compute_marginal_moments)
|
||||
return getCorr(data.co_moment, data.left_m2, data.right_m2, data.count);
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
explicit AggregateFunctionCovariance(CovarKind kind_, const DataTypes & args) : IAggregateFunctionDataHelper<
|
||||
CovarianceData<compute_marginal_moments>,
|
||||
AggregateFunctionCovariance<compute_marginal_moments>>(args, {}, std::make_shared<DataTypeFloat64>()),
|
||||
kind(kind_)
|
||||
{
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
case CovarKind::covarSampStable: return "covarSampStable";
|
||||
case CovarKind::covarPopStable: return "covarPopStable";
|
||||
case CovarKind::corrStable: return "corrStable";
|
||||
}
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
this->data(place).update(*columns[0], *columns[1], row_num);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).mergeWith(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
assert_cast<ColumnFloat64 &>(to).getData().push_back(getResult(place));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <template <typename> typename FunctionTemplate>
|
||||
AggregateFunctionPtr createAggregateFunctionStatisticsUnary(
|
||||
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
@ -51,13 +483,54 @@ AggregateFunctionPtr createAggregateFunctionStatisticsBinary(
|
||||
|
||||
void registerAggregateFunctionsStatisticsStable(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("varSampStable", createAggregateFunctionStatisticsUnary<AggregateFunctionVarSampStable>);
|
||||
factory.registerFunction("varPopStable", createAggregateFunctionStatisticsUnary<AggregateFunctionVarPopStable>);
|
||||
factory.registerFunction("stddevSampStable", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevSampStable>);
|
||||
factory.registerFunction("stddevPopStable", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevPopStable>);
|
||||
factory.registerFunction("covarSampStable", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarSampStable>);
|
||||
factory.registerFunction("covarPopStable", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarPopStable>);
|
||||
factory.registerFunction("corrStable", createAggregateFunctionStatisticsBinary<AggregateFunctionCorrStable>);
|
||||
factory.registerFunction("varSampStable", [](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
assertUnary(name, argument_types);
|
||||
return std::make_shared<AggregateFunctionVariance>(VarKind::varSampStable, argument_types[0]);
|
||||
});
|
||||
|
||||
factory.registerFunction("varPopStable", [](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
assertUnary(name, argument_types);
|
||||
return std::make_shared<AggregateFunctionVariance>(VarKind::varPopStable, argument_types[0]);
|
||||
});
|
||||
|
||||
factory.registerFunction("stddevSampStable", [](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
assertUnary(name, argument_types);
|
||||
return std::make_shared<AggregateFunctionVariance>(VarKind::stddevSampStable, argument_types[0]);
|
||||
});
|
||||
|
||||
factory.registerFunction("stddevPopStable", [](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
assertUnary(name, argument_types);
|
||||
return std::make_shared<AggregateFunctionVariance>(VarKind::stddevPopStable, argument_types[0]);
|
||||
});
|
||||
|
||||
factory.registerFunction("covarSampStable", [](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
assertBinary(name, argument_types);
|
||||
return std::make_shared<AggregateFunctionCovariance<false>>(CovarKind::covarSampStable, argument_types);
|
||||
});
|
||||
|
||||
factory.registerFunction("covarPopStable", [](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
assertBinary(name, argument_types);
|
||||
return std::make_shared<AggregateFunctionCovariance<false>>(CovarKind::covarPopStable, argument_types);
|
||||
});
|
||||
|
||||
factory.registerFunction("corrStable", [](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
assertBinary(name, argument_types);
|
||||
return std::make_shared<AggregateFunctionCovariance<true>>(CovarKind::corrStable, argument_types);
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,468 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <cmath>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
/// This function returns true if both values are large and comparable.
|
||||
/// It is used to calculate the mean value by merging two sources.
|
||||
/// It means that if the sizes of both sources are large and comparable, then we must apply a special
|
||||
/// formula guaranteeing more stability.
|
||||
bool areComparable(UInt64 a, UInt64 b)
|
||||
{
|
||||
const Float64 sensitivity = 0.001;
|
||||
const UInt64 threshold = 10000;
|
||||
|
||||
if ((a == 0) || (b == 0))
|
||||
return false;
|
||||
|
||||
auto res = std::minmax(a, b);
|
||||
return (((1 - static_cast<Float64>(res.first) / res.second) < sensitivity) && (res.first > threshold));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/** Statistical aggregate functions
|
||||
* varSamp - sample variance
|
||||
* stddevSamp - mean sample quadratic deviation
|
||||
* varPop - variance
|
||||
* stddevPop - standard deviation
|
||||
* covarSamp - selective covariance
|
||||
* covarPop - covariance
|
||||
* corr - correlation
|
||||
*/
|
||||
|
||||
/** Parallel and incremental algorithm for calculating variance.
|
||||
* Source: "Updating formulae and a pairwise algorithm for computing sample variances"
|
||||
* (Chan et al., Stanford University, 12.1979)
|
||||
*/
|
||||
template <typename T, typename Op>
|
||||
class AggregateFunctionVarianceData
|
||||
{
|
||||
public:
|
||||
void update(const IColumn & column, size_t row_num)
|
||||
{
|
||||
T received = assert_cast<const ColumnVector<T> &>(column).getData()[row_num];
|
||||
Float64 val = static_cast<Float64>(received);
|
||||
Float64 delta = val - mean;
|
||||
|
||||
++count;
|
||||
mean += delta / count;
|
||||
m2 += delta * (val - mean);
|
||||
}
|
||||
|
||||
void mergeWith(const AggregateFunctionVarianceData & source)
|
||||
{
|
||||
UInt64 total_count = count + source.count;
|
||||
if (total_count == 0)
|
||||
return;
|
||||
|
||||
Float64 factor = static_cast<Float64>(count * source.count) / total_count;
|
||||
Float64 delta = mean - source.mean;
|
||||
|
||||
if (detail::areComparable(count, source.count))
|
||||
mean = (source.count * source.mean + count * mean) / total_count;
|
||||
else
|
||||
mean = source.mean + delta * (static_cast<Float64>(count) / total_count);
|
||||
|
||||
m2 += source.m2 + delta * delta * factor;
|
||||
count = total_count;
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
writeVarUInt(count, buf);
|
||||
writeBinary(mean, buf);
|
||||
writeBinary(m2, buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
readVarUInt(count, buf);
|
||||
readBinary(mean, buf);
|
||||
readBinary(m2, buf);
|
||||
}
|
||||
|
||||
void publish(IColumn & to) const
|
||||
{
|
||||
assert_cast<ColumnFloat64 &>(to).getData().push_back(Op::apply(m2, count));
|
||||
}
|
||||
|
||||
private:
|
||||
UInt64 count = 0;
|
||||
Float64 mean = 0.0;
|
||||
Float64 m2 = 0.0;
|
||||
};
|
||||
|
||||
/** The main code for the implementation of varSamp, stddevSamp, varPop, stddevPop.
|
||||
*/
|
||||
template <typename T, typename Op>
|
||||
class AggregateFunctionVariance final
|
||||
: public IAggregateFunctionDataHelper<AggregateFunctionVarianceData<T, Op>, AggregateFunctionVariance<T, Op>>
|
||||
{
|
||||
public:
|
||||
explicit AggregateFunctionVariance(const DataTypePtr & arg)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionVarianceData<T, Op>, AggregateFunctionVariance<T, Op>>({arg}, {}, std::make_shared<DataTypeFloat64>())
|
||||
{}
|
||||
|
||||
String getName() const override { return Op::name; }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
this->data(place).update(*columns[0], row_num);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).mergeWith(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
this->data(place).publish(to);
|
||||
}
|
||||
};
|
||||
|
||||
/** Implementing the varSamp function.
|
||||
*/
|
||||
struct AggregateFunctionVarSampImpl
|
||||
{
|
||||
static constexpr auto name = "varSampStable";
|
||||
|
||||
static inline Float64 apply(Float64 m2, UInt64 count)
|
||||
{
|
||||
if (count < 2)
|
||||
return std::numeric_limits<Float64>::infinity();
|
||||
else
|
||||
return m2 / (count - 1);
|
||||
}
|
||||
};
|
||||
|
||||
/** Implementing the stddevSamp function.
|
||||
*/
|
||||
struct AggregateFunctionStdDevSampImpl
|
||||
{
|
||||
static constexpr auto name = "stddevSampStable";
|
||||
|
||||
static inline Float64 apply(Float64 m2, UInt64 count)
|
||||
{
|
||||
return sqrt(AggregateFunctionVarSampImpl::apply(m2, count));
|
||||
}
|
||||
};
|
||||
|
||||
/** Implementing the varPop function.
|
||||
*/
|
||||
struct AggregateFunctionVarPopImpl
|
||||
{
|
||||
static constexpr auto name = "varPopStable";
|
||||
|
||||
static inline Float64 apply(Float64 m2, UInt64 count)
|
||||
{
|
||||
if (count == 0)
|
||||
return std::numeric_limits<Float64>::infinity();
|
||||
else if (count == 1)
|
||||
return 0.0;
|
||||
else
|
||||
return m2 / count;
|
||||
}
|
||||
};
|
||||
|
||||
/** Implementing the stddevPop function.
|
||||
*/
|
||||
struct AggregateFunctionStdDevPopImpl
|
||||
{
|
||||
static constexpr auto name = "stddevPopStable";
|
||||
|
||||
static inline Float64 apply(Float64 m2, UInt64 count)
|
||||
{
|
||||
return sqrt(AggregateFunctionVarPopImpl::apply(m2, count));
|
||||
}
|
||||
};
|
||||
|
||||
/** If `compute_marginal_moments` flag is set this class provides the successor
|
||||
* CovarianceData support of marginal moments for calculating the correlation.
|
||||
*/
|
||||
template <bool compute_marginal_moments>
|
||||
class BaseCovarianceData
|
||||
{
|
||||
protected:
|
||||
void incrementMarginalMoments(Float64, Float64) {}
|
||||
void mergeWith(const BaseCovarianceData &) {}
|
||||
void serialize(WriteBuffer &) const {}
|
||||
void deserialize(const ReadBuffer &) {}
|
||||
};
|
||||
|
||||
template <>
|
||||
class BaseCovarianceData<true>
|
||||
{
|
||||
protected:
|
||||
void incrementMarginalMoments(Float64 left_incr, Float64 right_incr)
|
||||
{
|
||||
left_m2 += left_incr;
|
||||
right_m2 += right_incr;
|
||||
}
|
||||
|
||||
void mergeWith(const BaseCovarianceData & source)
|
||||
{
|
||||
left_m2 += source.left_m2;
|
||||
right_m2 += source.right_m2;
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
writeBinary(left_m2, buf);
|
||||
writeBinary(right_m2, buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
readBinary(left_m2, buf);
|
||||
readBinary(right_m2, buf);
|
||||
}
|
||||
|
||||
Float64 left_m2 = 0.0;
|
||||
Float64 right_m2 = 0.0;
|
||||
};
|
||||
|
||||
/** Parallel and incremental algorithm for calculating covariance.
|
||||
* Source: "Numerically Stable, Single-Pass, Parallel Statistics Algorithms"
|
||||
* (J. Bennett et al., Sandia National Laboratories,
|
||||
* 2009 IEEE International Conference on Cluster Computing)
|
||||
*/
|
||||
template <typename T, typename U, typename Op, bool compute_marginal_moments>
|
||||
class CovarianceData : public BaseCovarianceData<compute_marginal_moments>
|
||||
{
|
||||
private:
|
||||
using Base = BaseCovarianceData<compute_marginal_moments>;
|
||||
|
||||
public:
|
||||
void update(const IColumn & column_left, const IColumn & column_right, size_t row_num)
|
||||
{
|
||||
T left_received = assert_cast<const ColumnVector<T> &>(column_left).getData()[row_num];
|
||||
Float64 left_val = static_cast<Float64>(left_received);
|
||||
Float64 left_delta = left_val - left_mean;
|
||||
|
||||
U right_received = assert_cast<const ColumnVector<U> &>(column_right).getData()[row_num];
|
||||
Float64 right_val = static_cast<Float64>(right_received);
|
||||
Float64 right_delta = right_val - right_mean;
|
||||
|
||||
Float64 old_right_mean = right_mean;
|
||||
|
||||
++count;
|
||||
|
||||
left_mean += left_delta / count;
|
||||
right_mean += right_delta / count;
|
||||
co_moment += (left_val - left_mean) * (right_val - old_right_mean);
|
||||
|
||||
/// Update the marginal moments, if any.
|
||||
if (compute_marginal_moments)
|
||||
{
|
||||
Float64 left_incr = left_delta * (left_val - left_mean);
|
||||
Float64 right_incr = right_delta * (right_val - right_mean);
|
||||
Base::incrementMarginalMoments(left_incr, right_incr);
|
||||
}
|
||||
}
|
||||
|
||||
void mergeWith(const CovarianceData & source)
|
||||
{
|
||||
UInt64 total_count = count + source.count;
|
||||
if (total_count == 0)
|
||||
return;
|
||||
|
||||
Float64 factor = static_cast<Float64>(count * source.count) / total_count;
|
||||
Float64 left_delta = left_mean - source.left_mean;
|
||||
Float64 right_delta = right_mean - source.right_mean;
|
||||
|
||||
if (detail::areComparable(count, source.count))
|
||||
{
|
||||
left_mean = (source.count * source.left_mean + count * left_mean) / total_count;
|
||||
right_mean = (source.count * source.right_mean + count * right_mean) / total_count;
|
||||
}
|
||||
else
|
||||
{
|
||||
left_mean = source.left_mean + left_delta * (static_cast<Float64>(count) / total_count);
|
||||
right_mean = source.right_mean + right_delta * (static_cast<Float64>(count) / total_count);
|
||||
}
|
||||
|
||||
co_moment += source.co_moment + left_delta * right_delta * factor;
|
||||
count = total_count;
|
||||
|
||||
/// Update the marginal moments, if any.
|
||||
if (compute_marginal_moments)
|
||||
{
|
||||
Float64 left_incr = left_delta * left_delta * factor;
|
||||
Float64 right_incr = right_delta * right_delta * factor;
|
||||
Base::mergeWith(source);
|
||||
Base::incrementMarginalMoments(left_incr, right_incr);
|
||||
}
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
writeVarUInt(count, buf);
|
||||
writeBinary(left_mean, buf);
|
||||
writeBinary(right_mean, buf);
|
||||
writeBinary(co_moment, buf);
|
||||
Base::serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
readVarUInt(count, buf);
|
||||
readBinary(left_mean, buf);
|
||||
readBinary(right_mean, buf);
|
||||
readBinary(co_moment, buf);
|
||||
Base::deserialize(buf);
|
||||
}
|
||||
|
||||
void publish(IColumn & to) const
|
||||
{
|
||||
if constexpr (compute_marginal_moments)
|
||||
assert_cast<ColumnFloat64 &>(to).getData().push_back(Op::apply(co_moment, Base::left_m2, Base::right_m2, count));
|
||||
else
|
||||
assert_cast<ColumnFloat64 &>(to).getData().push_back(Op::apply(co_moment, count));
|
||||
}
|
||||
|
||||
private:
|
||||
UInt64 count = 0;
|
||||
Float64 left_mean = 0.0;
|
||||
Float64 right_mean = 0.0;
|
||||
Float64 co_moment = 0.0;
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op, bool compute_marginal_moments = false>
|
||||
class AggregateFunctionCovariance final
|
||||
: public IAggregateFunctionDataHelper<
|
||||
CovarianceData<T, U, Op, compute_marginal_moments>,
|
||||
AggregateFunctionCovariance<T, U, Op, compute_marginal_moments>>
|
||||
{
|
||||
public:
|
||||
explicit AggregateFunctionCovariance(const DataTypes & args) : IAggregateFunctionDataHelper<
|
||||
CovarianceData<T, U, Op, compute_marginal_moments>,
|
||||
AggregateFunctionCovariance<T, U, Op, compute_marginal_moments>>(args, {}, std::make_shared<DataTypeFloat64>())
|
||||
{}
|
||||
|
||||
String getName() const override { return Op::name; }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
this->data(place).update(*columns[0], *columns[1], row_num);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).mergeWith(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
this->data(place).publish(to);
|
||||
}
|
||||
};
|
||||
|
||||
/** Implementing the covarSamp function.
|
||||
*/
|
||||
struct AggregateFunctionCovarSampImpl
|
||||
{
|
||||
static constexpr auto name = "covarSampStable";
|
||||
|
||||
static inline Float64 apply(Float64 co_moment, UInt64 count)
|
||||
{
|
||||
if (count < 2)
|
||||
return std::numeric_limits<Float64>::infinity();
|
||||
else
|
||||
return co_moment / (count - 1);
|
||||
}
|
||||
};
|
||||
|
||||
/** Implementing the covarPop function.
|
||||
*/
|
||||
struct AggregateFunctionCovarPopImpl
|
||||
{
|
||||
static constexpr auto name = "covarPopStable";
|
||||
|
||||
static inline Float64 apply(Float64 co_moment, UInt64 count)
|
||||
{
|
||||
if (count == 0)
|
||||
return std::numeric_limits<Float64>::infinity();
|
||||
else if (count == 1)
|
||||
return 0.0;
|
||||
else
|
||||
return co_moment / count;
|
||||
}
|
||||
};
|
||||
|
||||
/** `corr` function implementation.
|
||||
*/
|
||||
struct AggregateFunctionCorrImpl
|
||||
{
|
||||
static constexpr auto name = "corrStable";
|
||||
|
||||
static inline Float64 apply(Float64 co_moment, Float64 left_m2, Float64 right_m2, UInt64 count)
|
||||
{
|
||||
if (count < 2)
|
||||
return std::numeric_limits<Float64>::infinity();
|
||||
else
|
||||
return co_moment / sqrt(left_m2 * right_m2);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using AggregateFunctionVarSampStable = AggregateFunctionVariance<T, AggregateFunctionVarSampImpl>;
|
||||
|
||||
template <typename T>
|
||||
using AggregateFunctionStddevSampStable = AggregateFunctionVariance<T, AggregateFunctionStdDevSampImpl>;
|
||||
|
||||
template <typename T>
|
||||
using AggregateFunctionVarPopStable = AggregateFunctionVariance<T, AggregateFunctionVarPopImpl>;
|
||||
|
||||
template <typename T>
|
||||
using AggregateFunctionStddevPopStable = AggregateFunctionVariance<T, AggregateFunctionStdDevPopImpl>;
|
||||
|
||||
template <typename T, typename U>
|
||||
using AggregateFunctionCovarSampStable = AggregateFunctionCovariance<T, U, AggregateFunctionCovarSampImpl>;
|
||||
|
||||
template <typename T, typename U>
|
||||
using AggregateFunctionCovarPopStable = AggregateFunctionCovariance<T, U, AggregateFunctionCovarPopImpl>;
|
||||
|
||||
template <typename T, typename U>
|
||||
using AggregateFunctionCorrStable = AggregateFunctionCovariance<T, U, AggregateFunctionCorrImpl, true>;
|
||||
|
||||
}
|
@ -1,7 +1,8 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionSumCount.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <AggregateFunctions/AggregateFunctionAvg.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -16,6 +17,59 @@ namespace ErrorCodes
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
class AggregateFunctionSumCount final : public AggregateFunctionAvg<T>
|
||||
{
|
||||
public:
|
||||
using Base = AggregateFunctionAvg<T>;
|
||||
|
||||
explicit AggregateFunctionSumCount(const DataTypes & argument_types_, UInt32 num_scale_ = 0)
|
||||
: Base(argument_types_, createResultType(num_scale_), num_scale_)
|
||||
{}
|
||||
|
||||
static DataTypePtr createResultType(UInt32 num_scale_)
|
||||
{
|
||||
auto second_elem = std::make_shared<DataTypeUInt64>();
|
||||
return std::make_shared<DataTypeTuple>(DataTypes{getReturnTypeFirstElement(num_scale_), std::move(second_elem)});
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const final
|
||||
{
|
||||
assert_cast<ColumnVectorOrDecimal<AvgFieldType<T>> &>((assert_cast<ColumnTuple &>(to)).getColumn(0)).getData().push_back(
|
||||
this->data(place).numerator);
|
||||
|
||||
assert_cast<ColumnUInt64 &>((assert_cast<ColumnTuple &>(to)).getColumn(1)).getData().push_back(
|
||||
this->data(place).denominator);
|
||||
}
|
||||
|
||||
String getName() const final { return "sumCount"; }
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
bool isCompilable() const override
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
private:
|
||||
static auto getReturnTypeFirstElement(UInt32 num_scale_)
|
||||
{
|
||||
using FieldType = AvgFieldType<T>;
|
||||
|
||||
if constexpr (!is_decimal<T>)
|
||||
return std::make_shared<DataTypeNumber<FieldType>>();
|
||||
else
|
||||
{
|
||||
using DataType = DataTypeDecimal<FieldType>;
|
||||
return std::make_shared<DataType>(DataType::maxPrecision(), num_scale_);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
bool allowType(const DataTypePtr& type) noexcept
|
||||
{
|
||||
const WhichDataType t(type);
|
||||
|
@ -1,61 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <AggregateFunctions/AggregateFunctionAvg.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
template <typename T>
|
||||
class AggregateFunctionSumCount final : public AggregateFunctionAvg<T>
|
||||
{
|
||||
public:
|
||||
using Base = AggregateFunctionAvg<T>;
|
||||
|
||||
explicit AggregateFunctionSumCount(const DataTypes & argument_types_, UInt32 num_scale_ = 0)
|
||||
: Base(argument_types_, createResultType(num_scale_), num_scale_)
|
||||
{}
|
||||
|
||||
static DataTypePtr createResultType(UInt32 num_scale_)
|
||||
{
|
||||
auto second_elem = std::make_shared<DataTypeUInt64>();
|
||||
return std::make_shared<DataTypeTuple>(DataTypes{getReturnTypeFirstElement(num_scale_), std::move(second_elem)});
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const final
|
||||
{
|
||||
assert_cast<ColumnVectorOrDecimal<AvgFieldType<T>> &>((assert_cast<ColumnTuple &>(to)).getColumn(0)).getData().push_back(
|
||||
this->data(place).numerator);
|
||||
|
||||
assert_cast<ColumnUInt64 &>((assert_cast<ColumnTuple &>(to)).getColumn(1)).getData().push_back(
|
||||
this->data(place).denominator);
|
||||
}
|
||||
|
||||
String getName() const final { return "sumCount"; }
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
bool isCompilable() const override
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
private:
|
||||
static auto getReturnTypeFirstElement(UInt32 num_scale_)
|
||||
{
|
||||
using FieldType = AvgFieldType<T>;
|
||||
|
||||
if constexpr (!is_decimal<T>)
|
||||
return std::make_shared<DataTypeNumber<FieldType>>();
|
||||
else
|
||||
{
|
||||
using DataType = DataTypeDecimal<FieldType>;
|
||||
return std::make_shared<DataType>(DataType::maxPrecision(), num_scale_);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -1,24 +1,676 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionSumMap.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
#include <IO/ReadHelpers.h>
|
||||
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
|
||||
#include <Common/FieldVisitorSum.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <map>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int LOGICAL_ERROR;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
struct AggregateFunctionMapData
|
||||
{
|
||||
// Map needs to be ordered to maintain function properties
|
||||
std::map<Field, Array> merged_maps;
|
||||
};
|
||||
|
||||
/** Aggregate function, that takes at least two arguments: keys and values, and as a result, builds a tuple of at least 2 arrays -
|
||||
* ordered keys and variable number of argument values aggregated by corresponding keys.
|
||||
*
|
||||
* sumMap function is the most useful when using SummingMergeTree to sum Nested columns, which name ends in "Map".
|
||||
*
|
||||
* Example: sumMap(k, v...) of:
|
||||
* k v
|
||||
* [1,2,3] [10,10,10]
|
||||
* [3,4,5] [10,10,10]
|
||||
* [4,5,6] [10,10,10]
|
||||
* [6,7,8] [10,10,10]
|
||||
* [7,5,3] [5,15,25]
|
||||
* [8,9,10] [20,20,20]
|
||||
* will return:
|
||||
* ([1,2,3,4,5,6,7,8,9,10],[10,10,45,20,35,20,15,30,20,20])
|
||||
*
|
||||
* minMap and maxMap share the same idea, but calculate min and max correspondingly.
|
||||
*
|
||||
* NOTE: The implementation of these functions are "amateur grade" - not efficient and low quality.
|
||||
*/
|
||||
|
||||
template <typename Derived, typename Visitor, bool overflow, bool tuple_argument, bool compact>
|
||||
class AggregateFunctionMapBase : public IAggregateFunctionDataHelper<
|
||||
AggregateFunctionMapData, Derived>
|
||||
{
|
||||
private:
|
||||
static constexpr auto STATE_VERSION_1_MIN_REVISION = 54452;
|
||||
|
||||
DataTypePtr keys_type;
|
||||
SerializationPtr keys_serialization;
|
||||
DataTypes values_types;
|
||||
Serializations values_serializations;
|
||||
Serializations promoted_values_serializations;
|
||||
|
||||
public:
|
||||
using Base = IAggregateFunctionDataHelper<AggregateFunctionMapData, Derived>;
|
||||
|
||||
AggregateFunctionMapBase(const DataTypePtr & keys_type_,
|
||||
const DataTypes & values_types_, const DataTypes & argument_types_)
|
||||
: Base(argument_types_, {} /* parameters */, createResultType(keys_type_, values_types_))
|
||||
, keys_type(keys_type_)
|
||||
, keys_serialization(keys_type->getDefaultSerialization())
|
||||
, values_types(values_types_)
|
||||
{
|
||||
values_serializations.reserve(values_types.size());
|
||||
promoted_values_serializations.reserve(values_types.size());
|
||||
for (const auto & type : values_types)
|
||||
{
|
||||
values_serializations.emplace_back(type->getDefaultSerialization());
|
||||
if (type->canBePromoted())
|
||||
{
|
||||
if (type->isNullable())
|
||||
promoted_values_serializations.emplace_back(
|
||||
makeNullable(removeNullable(type)->promoteNumericType())->getDefaultSerialization());
|
||||
else
|
||||
promoted_values_serializations.emplace_back(type->promoteNumericType()->getDefaultSerialization());
|
||||
}
|
||||
else
|
||||
{
|
||||
promoted_values_serializations.emplace_back(type->getDefaultSerialization());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool isVersioned() const override { return true; }
|
||||
|
||||
size_t getDefaultVersion() const override { return 1; }
|
||||
|
||||
size_t getVersionFromRevision(size_t revision) const override
|
||||
{
|
||||
if (revision >= STATE_VERSION_1_MIN_REVISION)
|
||||
return 1;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
static DataTypePtr createResultType(
|
||||
const DataTypePtr & keys_type_,
|
||||
const DataTypes & values_types_)
|
||||
{
|
||||
DataTypes types;
|
||||
types.emplace_back(std::make_shared<DataTypeArray>(keys_type_));
|
||||
|
||||
for (const auto & value_type : values_types_)
|
||||
{
|
||||
if constexpr (std::is_same_v<Visitor, FieldVisitorSum>)
|
||||
{
|
||||
if (!value_type->isSummable())
|
||||
throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Values for -Map cannot be summed, passed type {}",
|
||||
value_type->getName()};
|
||||
}
|
||||
|
||||
DataTypePtr result_type;
|
||||
|
||||
if constexpr (overflow)
|
||||
{
|
||||
if (value_type->onlyNull())
|
||||
throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Cannot calculate -Map of type {}",
|
||||
value_type->getName()};
|
||||
|
||||
// Overflow, meaning that the returned type is the same as
|
||||
// the input type. Nulls are skipped.
|
||||
result_type = removeNullable(value_type);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto value_type_without_nullable = removeNullable(value_type);
|
||||
|
||||
// No overflow, meaning we promote the types if necessary.
|
||||
if (!value_type_without_nullable->canBePromoted())
|
||||
throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Values for -Map are expected to be Numeric, Float or Decimal, passed type {}",
|
||||
value_type->getName()};
|
||||
|
||||
WhichDataType value_type_to_check(value_type_without_nullable);
|
||||
|
||||
/// Do not promote decimal because of implementation issues of this function design
|
||||
/// Currently we cannot get result column type in case of decimal we cannot get decimal scale
|
||||
/// in method void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
/// If we decide to make this function more efficient we should promote decimal type during summ
|
||||
if (value_type_to_check.isDecimal())
|
||||
result_type = value_type_without_nullable;
|
||||
else
|
||||
result_type = value_type_without_nullable->promoteNumericType();
|
||||
}
|
||||
|
||||
types.emplace_back(std::make_shared<DataTypeArray>(result_type));
|
||||
}
|
||||
|
||||
return std::make_shared<DataTypeTuple>(types);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
static auto getArgumentColumns(const IColumn ** columns)
|
||||
{
|
||||
if constexpr (tuple_argument)
|
||||
{
|
||||
return assert_cast<const ColumnTuple *>(columns[0])->getColumns();
|
||||
}
|
||||
else
|
||||
{
|
||||
return columns;
|
||||
}
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns_, const size_t row_num, Arena *) const override
|
||||
{
|
||||
const auto & columns = getArgumentColumns(columns_);
|
||||
|
||||
// Column 0 contains array of keys of known type
|
||||
const ColumnArray & array_column0 = assert_cast<const ColumnArray &>(*columns[0]);
|
||||
const IColumn::Offsets & offsets0 = array_column0.getOffsets();
|
||||
const IColumn & key_column = array_column0.getData();
|
||||
const size_t keys_vec_offset = offsets0[row_num - 1];
|
||||
const size_t keys_vec_size = (offsets0[row_num] - keys_vec_offset);
|
||||
|
||||
// Columns 1..n contain arrays of numeric values to sum
|
||||
auto & merged_maps = this->data(place).merged_maps;
|
||||
for (size_t col = 0, size = values_types.size(); col < size; ++col)
|
||||
{
|
||||
const auto & array_column = assert_cast<const ColumnArray &>(*columns[col + 1]);
|
||||
const IColumn & value_column = array_column.getData();
|
||||
const IColumn::Offsets & offsets = array_column.getOffsets();
|
||||
const size_t values_vec_offset = offsets[row_num - 1];
|
||||
const size_t values_vec_size = (offsets[row_num] - values_vec_offset);
|
||||
|
||||
// Expect key and value arrays to be of same length
|
||||
if (keys_vec_size != values_vec_size)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Sizes of keys and values arrays do not match");
|
||||
|
||||
// Insert column values for all keys
|
||||
for (size_t i = 0; i < keys_vec_size; ++i)
|
||||
{
|
||||
Field value = value_column[values_vec_offset + i];
|
||||
Field key = key_column[keys_vec_offset + i];
|
||||
|
||||
if (!keepKey(key))
|
||||
continue;
|
||||
|
||||
auto [it, inserted] = merged_maps.emplace(key, Array());
|
||||
|
||||
if (inserted)
|
||||
{
|
||||
it->second.resize(size);
|
||||
it->second[col] = value;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (!value.isNull())
|
||||
{
|
||||
if (it->second[col].isNull())
|
||||
it->second[col] = value;
|
||||
else
|
||||
applyVisitor(Visitor(value), it->second[col]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
auto & merged_maps = this->data(place).merged_maps;
|
||||
const auto & rhs_maps = this->data(rhs).merged_maps;
|
||||
|
||||
for (const auto & elem : rhs_maps)
|
||||
{
|
||||
const auto & it = merged_maps.find(elem.first);
|
||||
if (it != merged_maps.end())
|
||||
{
|
||||
for (size_t col = 0; col < values_types.size(); ++col)
|
||||
if (!elem.second[col].isNull())
|
||||
applyVisitor(Visitor(elem.second[col]), it->second[col]);
|
||||
}
|
||||
else
|
||||
merged_maps[elem.first] = elem.second;
|
||||
}
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version) const override
|
||||
{
|
||||
if (!version)
|
||||
version = getDefaultVersion();
|
||||
|
||||
const auto & merged_maps = this->data(place).merged_maps;
|
||||
size_t size = merged_maps.size();
|
||||
writeVarUInt(size, buf);
|
||||
|
||||
std::function<void(size_t, const Array &)> serialize;
|
||||
switch (*version)
|
||||
{
|
||||
case 0:
|
||||
{
|
||||
serialize = [&](size_t col_idx, const Array & values)
|
||||
{
|
||||
values_serializations[col_idx]->serializeBinary(values[col_idx], buf, {});
|
||||
};
|
||||
break;
|
||||
}
|
||||
case 1:
|
||||
{
|
||||
serialize = [&](size_t col_idx, const Array & values)
|
||||
{
|
||||
Field value = values[col_idx];
|
||||
|
||||
/// Compatibility with previous versions.
|
||||
if (value.getType() == Field::Types::Decimal32)
|
||||
{
|
||||
auto source = value.get<DecimalField<Decimal32>>();
|
||||
value = DecimalField<Decimal128>(source.getValue(), source.getScale());
|
||||
}
|
||||
else if (value.getType() == Field::Types::Decimal64)
|
||||
{
|
||||
auto source = value.get<DecimalField<Decimal64>>();
|
||||
value = DecimalField<Decimal128>(source.getValue(), source.getScale());
|
||||
}
|
||||
|
||||
promoted_values_serializations[col_idx]->serializeBinary(value, buf, {});
|
||||
};
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unknown version {}, of -Map aggregate function serialization state", *version);
|
||||
}
|
||||
|
||||
for (const auto & elem : merged_maps)
|
||||
{
|
||||
keys_serialization->serializeBinary(elem.first, buf, {});
|
||||
for (size_t col = 0; col < values_types.size(); ++col)
|
||||
serialize(col, elem.second);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> version, Arena *) const override
|
||||
{
|
||||
if (!version)
|
||||
version = getDefaultVersion();
|
||||
|
||||
auto & merged_maps = this->data(place).merged_maps;
|
||||
size_t size = 0;
|
||||
readVarUInt(size, buf);
|
||||
|
||||
std::function<void(size_t, Array &)> deserialize;
|
||||
switch (*version)
|
||||
{
|
||||
case 0:
|
||||
{
|
||||
deserialize = [&](size_t col_idx, Array & values)
|
||||
{
|
||||
values_serializations[col_idx]->deserializeBinary(values[col_idx], buf, {});
|
||||
};
|
||||
break;
|
||||
}
|
||||
case 1:
|
||||
{
|
||||
deserialize = [&](size_t col_idx, Array & values)
|
||||
{
|
||||
Field & value = values[col_idx];
|
||||
promoted_values_serializations[col_idx]->deserializeBinary(value, buf, {});
|
||||
|
||||
/// Compatibility with previous versions.
|
||||
if (value.getType() == Field::Types::Decimal128)
|
||||
{
|
||||
auto source = value.get<DecimalField<Decimal128>>();
|
||||
WhichDataType value_type(values_types[col_idx]);
|
||||
if (value_type.isDecimal32())
|
||||
{
|
||||
value = DecimalField<Decimal32>(source.getValue(), source.getScale());
|
||||
}
|
||||
else if (value_type.isDecimal64())
|
||||
{
|
||||
value = DecimalField<Decimal64>(source.getValue(), source.getScale());
|
||||
}
|
||||
}
|
||||
};
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unexpected version {} of -Map aggregate function serialization state", *version);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
Field key;
|
||||
keys_serialization->deserializeBinary(key, buf, {});
|
||||
|
||||
Array values;
|
||||
values.resize(values_types.size());
|
||||
|
||||
for (size_t col = 0; col < values_types.size(); ++col)
|
||||
deserialize(col, values);
|
||||
|
||||
merged_maps[key] = values;
|
||||
}
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
size_t num_columns = values_types.size();
|
||||
|
||||
// Final step does compaction of keys that have zero values, this mutates the state
|
||||
auto & merged_maps = this->data(place).merged_maps;
|
||||
|
||||
// Remove keys which are zeros or empty. This should be enabled only for sumMap.
|
||||
if constexpr (compact)
|
||||
{
|
||||
for (auto it = merged_maps.cbegin(); it != merged_maps.cend();)
|
||||
{
|
||||
// Key is not compacted if it has at least one non-zero value
|
||||
bool erase = true;
|
||||
for (size_t col = 0; col < num_columns; ++col)
|
||||
{
|
||||
if (!it->second[col].isNull() && it->second[col] != values_types[col]->getDefault())
|
||||
{
|
||||
erase = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (erase)
|
||||
it = merged_maps.erase(it);
|
||||
else
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
size_t size = merged_maps.size();
|
||||
|
||||
auto & to_tuple = assert_cast<ColumnTuple &>(to);
|
||||
auto & to_keys_arr = assert_cast<ColumnArray &>(to_tuple.getColumn(0));
|
||||
auto & to_keys_col = to_keys_arr.getData();
|
||||
|
||||
// Advance column offsets
|
||||
auto & to_keys_offsets = to_keys_arr.getOffsets();
|
||||
to_keys_offsets.push_back(to_keys_offsets.back() + size);
|
||||
to_keys_col.reserve(size);
|
||||
|
||||
for (size_t col = 0; col < num_columns; ++col)
|
||||
{
|
||||
auto & to_values_arr = assert_cast<ColumnArray &>(to_tuple.getColumn(col + 1));
|
||||
auto & to_values_offsets = to_values_arr.getOffsets();
|
||||
to_values_offsets.push_back(to_values_offsets.back() + size);
|
||||
to_values_arr.getData().reserve(size);
|
||||
}
|
||||
|
||||
// Write arrays of keys and values
|
||||
for (const auto & elem : merged_maps)
|
||||
{
|
||||
// Write array of keys into column
|
||||
to_keys_col.insert(elem.first);
|
||||
|
||||
// Write 0..n arrays of values
|
||||
for (size_t col = 0; col < num_columns; ++col)
|
||||
{
|
||||
auto & to_values_col = assert_cast<ColumnArray &>(to_tuple.getColumn(col + 1)).getData();
|
||||
if (elem.second[col].isNull())
|
||||
to_values_col.insertDefault();
|
||||
else
|
||||
to_values_col.insert(elem.second[col]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool keepKey(const Field & key) const { return static_cast<const Derived &>(*this).keepKey(key); }
|
||||
String getName() const override { return Derived::getNameImpl(); }
|
||||
};
|
||||
|
||||
template <bool overflow, bool tuple_argument>
|
||||
class AggregateFunctionSumMap final :
|
||||
public AggregateFunctionMapBase<AggregateFunctionSumMap<overflow, tuple_argument>, FieldVisitorSum, overflow, tuple_argument, true>
|
||||
{
|
||||
private:
|
||||
using Self = AggregateFunctionSumMap<overflow, tuple_argument>;
|
||||
using Base = AggregateFunctionMapBase<Self, FieldVisitorSum, overflow, tuple_argument, true>;
|
||||
|
||||
public:
|
||||
AggregateFunctionSumMap(const DataTypePtr & keys_type_,
|
||||
DataTypes & values_types_, const DataTypes & argument_types_,
|
||||
const Array & params_)
|
||||
: Base{keys_type_, values_types_, argument_types_}
|
||||
{
|
||||
// The constructor accepts parameters to have a uniform interface with
|
||||
// sumMapFiltered, but this function doesn't have any parameters.
|
||||
assertNoParameters(getNameImpl(), params_);
|
||||
}
|
||||
|
||||
static String getNameImpl()
|
||||
{
|
||||
if constexpr (overflow)
|
||||
{
|
||||
return "sumMapWithOverflow";
|
||||
}
|
||||
else
|
||||
{
|
||||
return "sumMap";
|
||||
}
|
||||
}
|
||||
|
||||
bool keepKey(const Field &) const { return true; }
|
||||
};
|
||||
|
||||
|
||||
template <bool overflow, bool tuple_argument>
|
||||
class AggregateFunctionSumMapFiltered final :
|
||||
public AggregateFunctionMapBase<
|
||||
AggregateFunctionSumMapFiltered<overflow, tuple_argument>,
|
||||
FieldVisitorSum,
|
||||
overflow,
|
||||
tuple_argument,
|
||||
true>
|
||||
{
|
||||
private:
|
||||
using Self = AggregateFunctionSumMapFiltered<overflow, tuple_argument>;
|
||||
using Base = AggregateFunctionMapBase<Self, FieldVisitorSum, overflow, tuple_argument, true>;
|
||||
|
||||
using ContainerT = std::set<Field>;
|
||||
ContainerT keys_to_keep;
|
||||
|
||||
public:
|
||||
AggregateFunctionSumMapFiltered(const DataTypePtr & keys_type_,
|
||||
const DataTypes & values_types_, const DataTypes & argument_types_,
|
||||
const Array & params_)
|
||||
: Base{keys_type_, values_types_, argument_types_}
|
||||
{
|
||||
if (params_.size() != 1)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
|
||||
"Aggregate function '{}' requires exactly one parameter "
|
||||
"of Array type", getNameImpl());
|
||||
|
||||
Array keys_to_keep_values;
|
||||
if (!params_.front().tryGet<Array>(keys_to_keep_values))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Aggregate function {} requires an Array as a parameter",
|
||||
getNameImpl());
|
||||
|
||||
this->parameters = params_;
|
||||
|
||||
for (const Field & f : keys_to_keep_values)
|
||||
keys_to_keep.emplace(f);
|
||||
}
|
||||
|
||||
static String getNameImpl()
|
||||
{
|
||||
if constexpr (overflow)
|
||||
{
|
||||
return "sumMapFilteredWithOverflow";
|
||||
}
|
||||
else
|
||||
{
|
||||
return "sumMapFiltered";
|
||||
}
|
||||
}
|
||||
|
||||
bool keepKey(const Field & key) const { return keys_to_keep.contains(key); }
|
||||
};
|
||||
|
||||
|
||||
/** Implements `Max` operation.
|
||||
* Returns true if changed
|
||||
*/
|
||||
class FieldVisitorMax : public StaticVisitor<bool>
|
||||
{
|
||||
private:
|
||||
const Field & rhs;
|
||||
|
||||
template <typename FieldType>
|
||||
bool compareImpl(FieldType & x) const
|
||||
{
|
||||
auto val = rhs.get<FieldType>();
|
||||
if (val > x)
|
||||
{
|
||||
x = val;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
public:
|
||||
explicit FieldVisitorMax(const Field & rhs_) : rhs(rhs_) {}
|
||||
|
||||
bool operator() (Null &) const
|
||||
{
|
||||
/// Do not update current value, skip nulls
|
||||
return false;
|
||||
}
|
||||
|
||||
bool operator() (AggregateFunctionStateData &) const { throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot compare AggregateFunctionStates"); }
|
||||
|
||||
bool operator() (Array & x) const { return compareImpl<Array>(x); }
|
||||
bool operator() (Tuple & x) const { return compareImpl<Tuple>(x); }
|
||||
template <typename T>
|
||||
bool operator() (DecimalField<T> & x) const { return compareImpl<DecimalField<T>>(x); }
|
||||
template <typename T>
|
||||
bool operator() (T & x) const { return compareImpl<T>(x); }
|
||||
};
|
||||
|
||||
/** Implements `Min` operation.
|
||||
* Returns true if changed
|
||||
*/
|
||||
class FieldVisitorMin : public StaticVisitor<bool>
|
||||
{
|
||||
private:
|
||||
const Field & rhs;
|
||||
|
||||
template <typename FieldType>
|
||||
bool compareImpl(FieldType & x) const
|
||||
{
|
||||
auto val = rhs.get<FieldType>();
|
||||
if (val < x)
|
||||
{
|
||||
x = val;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
public:
|
||||
explicit FieldVisitorMin(const Field & rhs_) : rhs(rhs_) {}
|
||||
|
||||
|
||||
bool operator() (Null &) const
|
||||
{
|
||||
/// Do not update current value, skip nulls
|
||||
return false;
|
||||
}
|
||||
|
||||
bool operator() (AggregateFunctionStateData &) const { throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot sum AggregateFunctionStates"); }
|
||||
|
||||
bool operator() (Array & x) const { return compareImpl<Array>(x); }
|
||||
bool operator() (Tuple & x) const { return compareImpl<Tuple>(x); }
|
||||
template <typename T>
|
||||
bool operator() (DecimalField<T> & x) const { return compareImpl<DecimalField<T>>(x); }
|
||||
template <typename T>
|
||||
bool operator() (T & x) const { return compareImpl<T>(x); }
|
||||
};
|
||||
|
||||
|
||||
template <bool tuple_argument>
|
||||
class AggregateFunctionMinMap final :
|
||||
public AggregateFunctionMapBase<AggregateFunctionMinMap<tuple_argument>, FieldVisitorMin, true, tuple_argument, false>
|
||||
{
|
||||
private:
|
||||
using Self = AggregateFunctionMinMap<tuple_argument>;
|
||||
using Base = AggregateFunctionMapBase<Self, FieldVisitorMin, true, tuple_argument, false>;
|
||||
|
||||
public:
|
||||
AggregateFunctionMinMap(const DataTypePtr & keys_type_,
|
||||
DataTypes & values_types_, const DataTypes & argument_types_,
|
||||
const Array & params_)
|
||||
: Base{keys_type_, values_types_, argument_types_}
|
||||
{
|
||||
// The constructor accepts parameters to have a uniform interface with
|
||||
// sumMapFiltered, but this function doesn't have any parameters.
|
||||
assertNoParameters(getNameImpl(), params_);
|
||||
}
|
||||
|
||||
static String getNameImpl() { return "minMap"; }
|
||||
|
||||
bool keepKey(const Field &) const { return true; }
|
||||
};
|
||||
|
||||
template <bool tuple_argument>
|
||||
class AggregateFunctionMaxMap final :
|
||||
public AggregateFunctionMapBase<AggregateFunctionMaxMap<tuple_argument>, FieldVisitorMax, true, tuple_argument, false>
|
||||
{
|
||||
private:
|
||||
using Self = AggregateFunctionMaxMap<tuple_argument>;
|
||||
using Base = AggregateFunctionMapBase<Self, FieldVisitorMax, true, tuple_argument, false>;
|
||||
|
||||
public:
|
||||
AggregateFunctionMaxMap(const DataTypePtr & keys_type_,
|
||||
DataTypes & values_types_, const DataTypes & argument_types_,
|
||||
const Array & params_)
|
||||
: Base{keys_type_, values_types_, argument_types_}
|
||||
{
|
||||
// The constructor accepts parameters to have a uniform interface with
|
||||
// sumMapFiltered, but this function doesn't have any parameters.
|
||||
assertNoParameters(getNameImpl(), params_);
|
||||
}
|
||||
|
||||
static String getNameImpl() { return "maxMap"; }
|
||||
|
||||
bool keepKey(const Field &) const { return true; }
|
||||
};
|
||||
|
||||
|
||||
auto parseArguments(const std::string & name, const DataTypes & arguments)
|
||||
{
|
||||
DataTypes args;
|
||||
@ -69,77 +721,6 @@ auto parseArguments(const std::string & name, const DataTypes & arguments)
|
||||
return std::tuple<DataTypePtr, DataTypes, bool>{std::move(keys_type), std::move(values_types), tuple_argument};
|
||||
}
|
||||
|
||||
// This function instantiates a particular overload of the sumMap family of
|
||||
// functions.
|
||||
// The template parameter MappedFunction<bool template_argument> is an aggregate
|
||||
// function template that allows to choose the aggregate function variant that
|
||||
// accepts either normal arguments or tuple argument.
|
||||
template<template <bool tuple_argument> typename MappedFunction>
|
||||
AggregateFunctionPtr createAggregateFunctionMap(const std::string & name, const DataTypes & arguments, const Array & params, const Settings *)
|
||||
{
|
||||
auto [keys_type, values_types, tuple_argument] = parseArguments(name, arguments);
|
||||
|
||||
AggregateFunctionPtr res;
|
||||
if (tuple_argument)
|
||||
{
|
||||
res.reset(createWithNumericBasedType<MappedFunction<true>::template F>(*keys_type, keys_type, values_types, arguments, params));
|
||||
if (!res)
|
||||
res.reset(createWithDecimalType<MappedFunction<true>::template F>(*keys_type, keys_type, values_types, arguments, params));
|
||||
if (!res)
|
||||
res.reset(createWithStringType<MappedFunction<true>::template F>(*keys_type, keys_type, values_types, arguments, params));
|
||||
}
|
||||
else
|
||||
{
|
||||
res.reset(createWithNumericBasedType<MappedFunction<false>::template F>(*keys_type, keys_type, values_types, arguments, params));
|
||||
if (!res)
|
||||
res.reset(createWithDecimalType<MappedFunction<false>::template F>(*keys_type, keys_type, values_types, arguments, params));
|
||||
if (!res)
|
||||
res.reset(createWithStringType<MappedFunction<false>::template F>(*keys_type, keys_type, values_types, arguments, params));
|
||||
}
|
||||
if (!res)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type of argument for aggregate function {}", name);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// This template chooses the sumMap variant with given filtering and overflow
|
||||
// handling.
|
||||
template <bool filtered, bool overflow>
|
||||
struct SumMapVariants
|
||||
{
|
||||
// SumMapVariants chooses the `overflow` and `filtered` parameters of the
|
||||
// aggregate functions. The `tuple_argument` and the value type `T` are left
|
||||
// as free parameters.
|
||||
// DispatchOnTupleArgument chooses `tuple_argument`, and the value type `T`
|
||||
// is left free.
|
||||
template <bool tuple_argument>
|
||||
struct DispatchOnTupleArgument
|
||||
{
|
||||
template <typename T>
|
||||
using F = std::conditional_t<filtered,
|
||||
AggregateFunctionSumMapFiltered<T, overflow, tuple_argument>,
|
||||
AggregateFunctionSumMap<T, overflow, tuple_argument>>;
|
||||
};
|
||||
};
|
||||
|
||||
// This template gives an aggregate function template that is narrowed
|
||||
// to accept either tuple argumen or normal arguments.
|
||||
template <bool tuple_argument>
|
||||
struct MinMapDispatchOnTupleArgument
|
||||
{
|
||||
template <typename T>
|
||||
using F = AggregateFunctionMinMap<T, tuple_argument>;
|
||||
};
|
||||
|
||||
// This template gives an aggregate function template that is narrowed
|
||||
// to accept either tuple argumen or normal arguments.
|
||||
template <bool tuple_argument>
|
||||
struct MaxMapDispatchOnTupleArgument
|
||||
{
|
||||
template <typename T>
|
||||
using F = AggregateFunctionMaxMap<T, tuple_argument>;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory)
|
||||
@ -147,26 +728,61 @@ void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory)
|
||||
// these functions used to be called *Map, with now these names occupied by
|
||||
// Map combinator, which redirects calls here if was called with
|
||||
// array or tuple arguments.
|
||||
factory.registerFunction("sumMappedArrays", createAggregateFunctionMap<
|
||||
SumMapVariants<false, false>::DispatchOnTupleArgument>);
|
||||
factory.registerFunction("sumMappedArrays", [](const std::string & name, const DataTypes & arguments, const Array & params, const Settings *) -> AggregateFunctionPtr
|
||||
{
|
||||
auto [keys_type, values_types, tuple_argument] = parseArguments(name, arguments);
|
||||
if (tuple_argument)
|
||||
return std::make_shared<AggregateFunctionSumMap<false, true>>(keys_type, values_types, arguments, params);
|
||||
else
|
||||
return std::make_shared<AggregateFunctionSumMap<false, false>>(keys_type, values_types, arguments, params);
|
||||
});
|
||||
|
||||
factory.registerFunction("minMappedArrays",
|
||||
createAggregateFunctionMap<MinMapDispatchOnTupleArgument>);
|
||||
factory.registerFunction("minMappedArrays", [](const std::string & name, const DataTypes & arguments, const Array & params, const Settings *) -> AggregateFunctionPtr
|
||||
{
|
||||
auto [keys_type, values_types, tuple_argument] = parseArguments(name, arguments);
|
||||
if (tuple_argument)
|
||||
return std::make_shared<AggregateFunctionMinMap<true>>(keys_type, values_types, arguments, params);
|
||||
else
|
||||
return std::make_shared<AggregateFunctionMinMap<false>>(keys_type, values_types, arguments, params);
|
||||
});
|
||||
|
||||
factory.registerFunction("maxMappedArrays",
|
||||
createAggregateFunctionMap<MaxMapDispatchOnTupleArgument>);
|
||||
factory.registerFunction("maxMappedArrays", [](const std::string & name, const DataTypes & arguments, const Array & params, const Settings *) -> AggregateFunctionPtr
|
||||
{
|
||||
auto [keys_type, values_types, tuple_argument] = parseArguments(name, arguments);
|
||||
if (tuple_argument)
|
||||
return std::make_shared<AggregateFunctionMaxMap<true>>(keys_type, values_types, arguments, params);
|
||||
else
|
||||
return std::make_shared<AggregateFunctionMaxMap<false>>(keys_type, values_types, arguments, params);
|
||||
});
|
||||
|
||||
// these functions could be renamed to *MappedArrays too, but it would
|
||||
// break backward compatibility
|
||||
factory.registerFunction("sumMapWithOverflow", createAggregateFunctionMap<
|
||||
SumMapVariants<false, true>::DispatchOnTupleArgument>);
|
||||
factory.registerFunction("sumMapWithOverflow", [](const std::string & name, const DataTypes & arguments, const Array & params, const Settings *) -> AggregateFunctionPtr
|
||||
{
|
||||
auto [keys_type, values_types, tuple_argument] = parseArguments(name, arguments);
|
||||
if (tuple_argument)
|
||||
return std::make_shared<AggregateFunctionSumMap<true, true>>(keys_type, values_types, arguments, params);
|
||||
else
|
||||
return std::make_shared<AggregateFunctionSumMap<true, false>>(keys_type, values_types, arguments, params);
|
||||
});
|
||||
|
||||
factory.registerFunction("sumMapFiltered", createAggregateFunctionMap<
|
||||
SumMapVariants<true, false>::DispatchOnTupleArgument>);
|
||||
factory.registerFunction("sumMapFiltered", [](const std::string & name, const DataTypes & arguments, const Array & params, const Settings *) -> AggregateFunctionPtr
|
||||
{
|
||||
auto [keys_type, values_types, tuple_argument] = parseArguments(name, arguments);
|
||||
if (tuple_argument)
|
||||
return std::make_shared<AggregateFunctionSumMapFiltered<false, true>>(keys_type, values_types, arguments, params);
|
||||
else
|
||||
return std::make_shared<AggregateFunctionSumMapFiltered<false, false>>(keys_type, values_types, arguments, params);
|
||||
});
|
||||
|
||||
factory.registerFunction("sumMapFilteredWithOverflow",
|
||||
createAggregateFunctionMap<
|
||||
SumMapVariants<true, true>::DispatchOnTupleArgument>);
|
||||
factory.registerFunction("sumMapFilteredWithOverflow", [](const std::string & name, const DataTypes & arguments, const Array & params, const Settings *) -> AggregateFunctionPtr
|
||||
{
|
||||
auto [keys_type, values_types, tuple_argument] = parseArguments(name, arguments);
|
||||
if (tuple_argument)
|
||||
return std::make_shared<AggregateFunctionSumMapFiltered<true, true>>(keys_type, values_types, arguments, params);
|
||||
else
|
||||
return std::make_shared<AggregateFunctionSumMapFiltered<true, false>>(keys_type, values_types, arguments, params);
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,656 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnDecimal.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
|
||||
#include <Common/FieldVisitorSum.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <map>
|
||||
#include <Common/ClickHouseRevision.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int BAD_ARGUMENTS;
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int LOGICAL_ERROR;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct AggregateFunctionMapData
|
||||
{
|
||||
// Map needs to be ordered to maintain function properties
|
||||
std::map<T, Array> merged_maps;
|
||||
};
|
||||
|
||||
/** Aggregate function, that takes at least two arguments: keys and values, and as a result, builds a tuple of at least 2 arrays -
|
||||
* ordered keys and variable number of argument values aggregated by corresponding keys.
|
||||
*
|
||||
* sumMap function is the most useful when using SummingMergeTree to sum Nested columns, which name ends in "Map".
|
||||
*
|
||||
* Example: sumMap(k, v...) of:
|
||||
* k v
|
||||
* [1,2,3] [10,10,10]
|
||||
* [3,4,5] [10,10,10]
|
||||
* [4,5,6] [10,10,10]
|
||||
* [6,7,8] [10,10,10]
|
||||
* [7,5,3] [5,15,25]
|
||||
* [8,9,10] [20,20,20]
|
||||
* will return:
|
||||
* ([1,2,3,4,5,6,7,8,9,10],[10,10,45,20,35,20,15,30,20,20])
|
||||
*
|
||||
* minMap and maxMap share the same idea, but calculate min and max correspondingly.
|
||||
*
|
||||
* NOTE: The implementation of these functions are "amateur grade" - not efficient and low quality.
|
||||
*/
|
||||
|
||||
template <typename T, typename Derived, typename Visitor, bool overflow, bool tuple_argument, bool compact>
|
||||
class AggregateFunctionMapBase : public IAggregateFunctionDataHelper<
|
||||
AggregateFunctionMapData<NearestFieldType<T>>, Derived>
|
||||
{
|
||||
private:
|
||||
static constexpr auto STATE_VERSION_1_MIN_REVISION = 54452;
|
||||
|
||||
DataTypePtr keys_type;
|
||||
SerializationPtr keys_serialization;
|
||||
DataTypes values_types;
|
||||
Serializations values_serializations;
|
||||
Serializations promoted_values_serializations;
|
||||
|
||||
public:
|
||||
using Base = IAggregateFunctionDataHelper<
|
||||
AggregateFunctionMapData<NearestFieldType<T>>, Derived>;
|
||||
|
||||
AggregateFunctionMapBase(const DataTypePtr & keys_type_,
|
||||
const DataTypes & values_types_, const DataTypes & argument_types_)
|
||||
: Base(argument_types_, {} /* parameters */, createResultType(keys_type_, values_types_, getName()))
|
||||
, keys_type(keys_type_)
|
||||
, keys_serialization(keys_type->getDefaultSerialization())
|
||||
, values_types(values_types_)
|
||||
{
|
||||
values_serializations.reserve(values_types.size());
|
||||
promoted_values_serializations.reserve(values_types.size());
|
||||
for (const auto & type : values_types)
|
||||
{
|
||||
values_serializations.emplace_back(type->getDefaultSerialization());
|
||||
if (type->canBePromoted())
|
||||
{
|
||||
if (type->isNullable())
|
||||
promoted_values_serializations.emplace_back(
|
||||
makeNullable(removeNullable(type)->promoteNumericType())->getDefaultSerialization());
|
||||
else
|
||||
promoted_values_serializations.emplace_back(type->promoteNumericType()->getDefaultSerialization());
|
||||
}
|
||||
else
|
||||
{
|
||||
promoted_values_serializations.emplace_back(type->getDefaultSerialization());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool isVersioned() const override { return true; }
|
||||
|
||||
size_t getDefaultVersion() const override { return 1; }
|
||||
|
||||
size_t getVersionFromRevision(size_t revision) const override
|
||||
{
|
||||
if (revision >= STATE_VERSION_1_MIN_REVISION)
|
||||
return 1;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
static DataTypePtr createResultType(
|
||||
const DataTypePtr & keys_type_,
|
||||
const DataTypes & values_types_,
|
||||
const String & name_)
|
||||
{
|
||||
DataTypes types;
|
||||
types.emplace_back(std::make_shared<DataTypeArray>(keys_type_));
|
||||
|
||||
for (const auto & value_type : values_types_)
|
||||
{
|
||||
if constexpr (std::is_same_v<Visitor, FieldVisitorSum>)
|
||||
{
|
||||
if (!value_type->isSummable())
|
||||
throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Values for {} cannot be summed, passed type {}",
|
||||
name_, value_type->getName()};
|
||||
}
|
||||
|
||||
DataTypePtr result_type;
|
||||
|
||||
if constexpr (overflow)
|
||||
{
|
||||
if (value_type->onlyNull())
|
||||
throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Cannot calculate {} of type {}",
|
||||
name_, value_type->getName()};
|
||||
|
||||
// Overflow, meaning that the returned type is the same as
|
||||
// the input type. Nulls are skipped.
|
||||
result_type = removeNullable(value_type);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto value_type_without_nullable = removeNullable(value_type);
|
||||
|
||||
// No overflow, meaning we promote the types if necessary.
|
||||
if (!value_type_without_nullable->canBePromoted())
|
||||
throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Values for {} are expected to be Numeric, Float or Decimal, passed type {}",
|
||||
name_, value_type->getName()};
|
||||
|
||||
WhichDataType value_type_to_check(value_type_without_nullable);
|
||||
|
||||
/// Do not promote decimal because of implementation issues of this function design
|
||||
/// Currently we cannot get result column type in case of decimal we cannot get decimal scale
|
||||
/// in method void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
/// If we decide to make this function more efficient we should promote decimal type during summ
|
||||
if (value_type_to_check.isDecimal())
|
||||
result_type = value_type_without_nullable;
|
||||
else
|
||||
result_type = value_type_without_nullable->promoteNumericType();
|
||||
}
|
||||
|
||||
types.emplace_back(std::make_shared<DataTypeArray>(result_type));
|
||||
}
|
||||
|
||||
return std::make_shared<DataTypeTuple>(types);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
static const auto & getArgumentColumns(const IColumn**& columns)
|
||||
{
|
||||
if constexpr (tuple_argument)
|
||||
{
|
||||
return assert_cast<const ColumnTuple *>(columns[0])->getColumns();
|
||||
}
|
||||
else
|
||||
{
|
||||
return columns;
|
||||
}
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns_, const size_t row_num, Arena *) const override
|
||||
{
|
||||
const auto & columns = getArgumentColumns(columns_);
|
||||
|
||||
// Column 0 contains array of keys of known type
|
||||
const ColumnArray & array_column0 = assert_cast<const ColumnArray &>(*columns[0]);
|
||||
const IColumn::Offsets & offsets0 = array_column0.getOffsets();
|
||||
const IColumn & key_column = array_column0.getData();
|
||||
const size_t keys_vec_offset = offsets0[row_num - 1];
|
||||
const size_t keys_vec_size = (offsets0[row_num] - keys_vec_offset);
|
||||
|
||||
// Columns 1..n contain arrays of numeric values to sum
|
||||
auto & merged_maps = this->data(place).merged_maps;
|
||||
for (size_t col = 0, size = values_types.size(); col < size; ++col)
|
||||
{
|
||||
const auto & array_column = assert_cast<const ColumnArray &>(*columns[col + 1]);
|
||||
const IColumn & value_column = array_column.getData();
|
||||
const IColumn::Offsets & offsets = array_column.getOffsets();
|
||||
const size_t values_vec_offset = offsets[row_num - 1];
|
||||
const size_t values_vec_size = (offsets[row_num] - values_vec_offset);
|
||||
|
||||
// Expect key and value arrays to be of same length
|
||||
if (keys_vec_size != values_vec_size)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Sizes of keys and values arrays do not match");
|
||||
|
||||
// Insert column values for all keys
|
||||
for (size_t i = 0; i < keys_vec_size; ++i)
|
||||
{
|
||||
auto value = value_column[values_vec_offset + i];
|
||||
T key = static_cast<T>(key_column[keys_vec_offset + i].get<T>());
|
||||
|
||||
if (!keepKey(key))
|
||||
continue;
|
||||
|
||||
decltype(merged_maps.begin()) it;
|
||||
if constexpr (is_decimal<T>)
|
||||
{
|
||||
// FIXME why is storing NearestFieldType not enough, and we
|
||||
// have to check for decimals again here?
|
||||
UInt32 scale = static_cast<const ColumnDecimal<T> &>(key_column).getScale();
|
||||
it = merged_maps.find(DecimalField<T>(key, scale));
|
||||
}
|
||||
else
|
||||
it = merged_maps.find(key);
|
||||
|
||||
if (it != merged_maps.end())
|
||||
{
|
||||
if (!value.isNull())
|
||||
{
|
||||
if (it->second[col].isNull())
|
||||
it->second[col] = value;
|
||||
else
|
||||
applyVisitor(Visitor(value), it->second[col]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Create a value array for this key
|
||||
Array new_values;
|
||||
new_values.resize(size);
|
||||
new_values[col] = value;
|
||||
|
||||
if constexpr (is_decimal<T>)
|
||||
{
|
||||
UInt32 scale = static_cast<const ColumnDecimal<T> &>(key_column).getScale();
|
||||
merged_maps.emplace(DecimalField<T>(key, scale), std::move(new_values));
|
||||
}
|
||||
else
|
||||
{
|
||||
merged_maps.emplace(key, std::move(new_values));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
auto & merged_maps = this->data(place).merged_maps;
|
||||
const auto & rhs_maps = this->data(rhs).merged_maps;
|
||||
|
||||
for (const auto & elem : rhs_maps)
|
||||
{
|
||||
const auto & it = merged_maps.find(elem.first);
|
||||
if (it != merged_maps.end())
|
||||
{
|
||||
for (size_t col = 0; col < values_types.size(); ++col)
|
||||
if (!elem.second[col].isNull())
|
||||
applyVisitor(Visitor(elem.second[col]), it->second[col]);
|
||||
}
|
||||
else
|
||||
merged_maps[elem.first] = elem.second;
|
||||
}
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version) const override
|
||||
{
|
||||
if (!version)
|
||||
version = getDefaultVersion();
|
||||
|
||||
const auto & merged_maps = this->data(place).merged_maps;
|
||||
size_t size = merged_maps.size();
|
||||
writeVarUInt(size, buf);
|
||||
|
||||
std::function<void(size_t, const Array &)> serialize;
|
||||
switch (*version)
|
||||
{
|
||||
case 0:
|
||||
{
|
||||
serialize = [&](size_t col_idx, const Array & values){ values_serializations[col_idx]->serializeBinary(values[col_idx], buf, {}); };
|
||||
break;
|
||||
}
|
||||
case 1:
|
||||
{
|
||||
serialize = [&](size_t col_idx, const Array & values){ promoted_values_serializations[col_idx]->serializeBinary(values[col_idx], buf, {}); };
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto & elem : merged_maps)
|
||||
{
|
||||
keys_serialization->serializeBinary(elem.first, buf, {});
|
||||
for (size_t col = 0; col < values_types.size(); ++col)
|
||||
serialize(col, elem.second);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> version, Arena *) const override
|
||||
{
|
||||
if (!version)
|
||||
version = getDefaultVersion();
|
||||
|
||||
auto & merged_maps = this->data(place).merged_maps;
|
||||
size_t size = 0;
|
||||
readVarUInt(size, buf);
|
||||
|
||||
std::function<void(size_t, Array &)> deserialize;
|
||||
switch (*version)
|
||||
{
|
||||
case 0:
|
||||
{
|
||||
deserialize = [&](size_t col_idx, Array & values){ values_serializations[col_idx]->deserializeBinary(values[col_idx], buf, {}); };
|
||||
break;
|
||||
}
|
||||
case 1:
|
||||
{
|
||||
deserialize = [&](size_t col_idx, Array & values){ promoted_values_serializations[col_idx]->deserializeBinary(values[col_idx], buf, {}); };
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
Field key;
|
||||
keys_serialization->deserializeBinary(key, buf, {});
|
||||
|
||||
Array values;
|
||||
values.resize(values_types.size());
|
||||
|
||||
for (size_t col = 0; col < values_types.size(); ++col)
|
||||
deserialize(col, values);
|
||||
|
||||
if constexpr (is_decimal<T>)
|
||||
merged_maps[key.get<DecimalField<T>>()] = values;
|
||||
else
|
||||
merged_maps[key.get<T>()] = values;
|
||||
}
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
size_t num_columns = values_types.size();
|
||||
|
||||
// Final step does compaction of keys that have zero values, this mutates the state
|
||||
auto & merged_maps = this->data(place).merged_maps;
|
||||
|
||||
// Remove keys which are zeros or empty. This should be enabled only for sumMap.
|
||||
if constexpr (compact)
|
||||
{
|
||||
for (auto it = merged_maps.cbegin(); it != merged_maps.cend();)
|
||||
{
|
||||
// Key is not compacted if it has at least one non-zero value
|
||||
bool erase = true;
|
||||
for (size_t col = 0; col < num_columns; ++col)
|
||||
{
|
||||
if (!it->second[col].isNull() && it->second[col] != values_types[col]->getDefault())
|
||||
{
|
||||
erase = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (erase)
|
||||
it = merged_maps.erase(it);
|
||||
else
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
size_t size = merged_maps.size();
|
||||
|
||||
auto & to_tuple = assert_cast<ColumnTuple &>(to);
|
||||
auto & to_keys_arr = assert_cast<ColumnArray &>(to_tuple.getColumn(0));
|
||||
auto & to_keys_col = to_keys_arr.getData();
|
||||
|
||||
// Advance column offsets
|
||||
auto & to_keys_offsets = to_keys_arr.getOffsets();
|
||||
to_keys_offsets.push_back(to_keys_offsets.back() + size);
|
||||
to_keys_col.reserve(size);
|
||||
|
||||
for (size_t col = 0; col < num_columns; ++col)
|
||||
{
|
||||
auto & to_values_arr = assert_cast<ColumnArray &>(to_tuple.getColumn(col + 1));
|
||||
auto & to_values_offsets = to_values_arr.getOffsets();
|
||||
to_values_offsets.push_back(to_values_offsets.back() + size);
|
||||
to_values_arr.getData().reserve(size);
|
||||
}
|
||||
|
||||
// Write arrays of keys and values
|
||||
for (const auto & elem : merged_maps)
|
||||
{
|
||||
// Write array of keys into column
|
||||
to_keys_col.insert(elem.first);
|
||||
|
||||
// Write 0..n arrays of values
|
||||
for (size_t col = 0; col < num_columns; ++col)
|
||||
{
|
||||
auto & to_values_col = assert_cast<ColumnArray &>(to_tuple.getColumn(col + 1)).getData();
|
||||
if (elem.second[col].isNull())
|
||||
to_values_col.insertDefault();
|
||||
else
|
||||
to_values_col.insert(elem.second[col]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool keepKey(const T & key) const { return static_cast<const Derived &>(*this).keepKey(key); }
|
||||
String getName() const override { return Derived::getNameImpl(); }
|
||||
};
|
||||
|
||||
template <typename T, bool overflow, bool tuple_argument>
|
||||
class AggregateFunctionSumMap final :
|
||||
public AggregateFunctionMapBase<T, AggregateFunctionSumMap<T, overflow, tuple_argument>, FieldVisitorSum, overflow, tuple_argument, true>
|
||||
{
|
||||
private:
|
||||
using Self = AggregateFunctionSumMap<T, overflow, tuple_argument>;
|
||||
using Base = AggregateFunctionMapBase<T, Self, FieldVisitorSum, overflow, tuple_argument, true>;
|
||||
|
||||
public:
|
||||
AggregateFunctionSumMap(const DataTypePtr & keys_type_,
|
||||
DataTypes & values_types_, const DataTypes & argument_types_,
|
||||
const Array & params_)
|
||||
: Base{keys_type_, values_types_, argument_types_}
|
||||
{
|
||||
// The constructor accepts parameters to have a uniform interface with
|
||||
// sumMapFiltered, but this function doesn't have any parameters.
|
||||
assertNoParameters(getNameImpl(), params_);
|
||||
}
|
||||
|
||||
static String getNameImpl()
|
||||
{
|
||||
if constexpr (overflow)
|
||||
{
|
||||
return "sumMapWithOverflow";
|
||||
}
|
||||
else
|
||||
{
|
||||
return "sumMap";
|
||||
}
|
||||
}
|
||||
|
||||
bool keepKey(const T &) const { return true; }
|
||||
};
|
||||
|
||||
|
||||
template <typename T, bool overflow, bool tuple_argument>
|
||||
class AggregateFunctionSumMapFiltered final :
|
||||
public AggregateFunctionMapBase<T,
|
||||
AggregateFunctionSumMapFiltered<T, overflow, tuple_argument>,
|
||||
FieldVisitorSum,
|
||||
overflow,
|
||||
tuple_argument,
|
||||
true>
|
||||
{
|
||||
private:
|
||||
using Self = AggregateFunctionSumMapFiltered<T, overflow, tuple_argument>;
|
||||
using Base = AggregateFunctionMapBase<T, Self, FieldVisitorSum, overflow, tuple_argument, true>;
|
||||
|
||||
using ContainerT = std::unordered_set<T>;
|
||||
|
||||
ContainerT keys_to_keep;
|
||||
|
||||
public:
|
||||
AggregateFunctionSumMapFiltered(const DataTypePtr & keys_type_,
|
||||
const DataTypes & values_types_, const DataTypes & argument_types_,
|
||||
const Array & params_)
|
||||
: Base{keys_type_, values_types_, argument_types_}
|
||||
{
|
||||
if (params_.size() != 1)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
|
||||
"Aggregate function '{}' requires exactly one parameter "
|
||||
"of Array type", getNameImpl());
|
||||
|
||||
Array keys_to_keep_values;
|
||||
if (!params_.front().tryGet<Array>(keys_to_keep_values))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Aggregate function {} requires an Array as a parameter",
|
||||
getNameImpl());
|
||||
|
||||
this->parameters = params_;
|
||||
|
||||
keys_to_keep.reserve(keys_to_keep_values.size());
|
||||
|
||||
for (const Field & f : keys_to_keep_values)
|
||||
keys_to_keep.emplace(f.safeGet<T>());
|
||||
}
|
||||
|
||||
static String getNameImpl()
|
||||
{
|
||||
if constexpr (overflow)
|
||||
{
|
||||
return "sumMapFilteredWithOverflow";
|
||||
}
|
||||
else
|
||||
{
|
||||
return "sumMapFiltered";
|
||||
}
|
||||
}
|
||||
|
||||
bool keepKey(const T & key) const { return keys_to_keep.count(key); }
|
||||
};
|
||||
|
||||
|
||||
/** Implements `Max` operation.
|
||||
* Returns true if changed
|
||||
*/
|
||||
class FieldVisitorMax : public StaticVisitor<bool>
|
||||
{
|
||||
private:
|
||||
const Field & rhs;
|
||||
|
||||
template <typename FieldType>
|
||||
bool compareImpl(FieldType & x) const
|
||||
{
|
||||
auto val = rhs.get<FieldType>();
|
||||
if (val > x)
|
||||
{
|
||||
x = val;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
public:
|
||||
explicit FieldVisitorMax(const Field & rhs_) : rhs(rhs_) {}
|
||||
|
||||
bool operator() (Null &) const
|
||||
{
|
||||
/// Do not update current value, skip nulls
|
||||
return false;
|
||||
}
|
||||
|
||||
bool operator() (AggregateFunctionStateData &) const { throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot compare AggregateFunctionStates"); }
|
||||
|
||||
bool operator() (Array & x) const { return compareImpl<Array>(x); }
|
||||
bool operator() (Tuple & x) const { return compareImpl<Tuple>(x); }
|
||||
template <typename T>
|
||||
bool operator() (DecimalField<T> & x) const { return compareImpl<DecimalField<T>>(x); }
|
||||
template <typename T>
|
||||
bool operator() (T & x) const { return compareImpl<T>(x); }
|
||||
};
|
||||
|
||||
/** Implements `Min` operation.
|
||||
* Returns true if changed
|
||||
*/
|
||||
class FieldVisitorMin : public StaticVisitor<bool>
|
||||
{
|
||||
private:
|
||||
const Field & rhs;
|
||||
|
||||
template <typename FieldType>
|
||||
bool compareImpl(FieldType & x) const
|
||||
{
|
||||
auto val = rhs.get<FieldType>();
|
||||
if (val < x)
|
||||
{
|
||||
x = val;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
public:
|
||||
explicit FieldVisitorMin(const Field & rhs_) : rhs(rhs_) {}
|
||||
|
||||
|
||||
bool operator() (Null &) const
|
||||
{
|
||||
/// Do not update current value, skip nulls
|
||||
return false;
|
||||
}
|
||||
|
||||
bool operator() (AggregateFunctionStateData &) const { throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot sum AggregateFunctionStates"); }
|
||||
|
||||
bool operator() (Array & x) const { return compareImpl<Array>(x); }
|
||||
bool operator() (Tuple & x) const { return compareImpl<Tuple>(x); }
|
||||
template <typename T>
|
||||
bool operator() (DecimalField<T> & x) const { return compareImpl<DecimalField<T>>(x); }
|
||||
template <typename T>
|
||||
bool operator() (T & x) const { return compareImpl<T>(x); }
|
||||
};
|
||||
|
||||
|
||||
template <typename T, bool tuple_argument>
|
||||
class AggregateFunctionMinMap final :
|
||||
public AggregateFunctionMapBase<T, AggregateFunctionMinMap<T, tuple_argument>, FieldVisitorMin, true, tuple_argument, false>
|
||||
{
|
||||
private:
|
||||
using Self = AggregateFunctionMinMap<T, tuple_argument>;
|
||||
using Base = AggregateFunctionMapBase<T, Self, FieldVisitorMin, true, tuple_argument, false>;
|
||||
|
||||
public:
|
||||
AggregateFunctionMinMap(const DataTypePtr & keys_type_,
|
||||
DataTypes & values_types_, const DataTypes & argument_types_,
|
||||
const Array & params_)
|
||||
: Base{keys_type_, values_types_, argument_types_}
|
||||
{
|
||||
// The constructor accepts parameters to have a uniform interface with
|
||||
// sumMapFiltered, but this function doesn't have any parameters.
|
||||
assertNoParameters(getNameImpl(), params_);
|
||||
}
|
||||
|
||||
static String getNameImpl() { return "minMap"; }
|
||||
|
||||
bool keepKey(const T &) const { return true; }
|
||||
};
|
||||
|
||||
template <typename T, bool tuple_argument>
|
||||
class AggregateFunctionMaxMap final :
|
||||
public AggregateFunctionMapBase<T, AggregateFunctionMaxMap<T, tuple_argument>, FieldVisitorMax, true, tuple_argument, false>
|
||||
{
|
||||
private:
|
||||
using Self = AggregateFunctionMaxMap<T, tuple_argument>;
|
||||
using Base = AggregateFunctionMapBase<T, Self, FieldVisitorMax, true, tuple_argument, false>;
|
||||
|
||||
public:
|
||||
AggregateFunctionMaxMap(const DataTypePtr & keys_type_,
|
||||
DataTypes & values_types_, const DataTypes & argument_types_,
|
||||
const Array & params_)
|
||||
: Base{keys_type_, values_types_, argument_types_}
|
||||
{
|
||||
// The constructor accepts parameters to have a uniform interface with
|
||||
// sumMapFiltered, but this function doesn't have any parameters.
|
||||
assertNoParameters(getNameImpl(), params_);
|
||||
}
|
||||
|
||||
static String getNameImpl() { return "maxMap"; }
|
||||
|
||||
bool keepKey(const T &) const { return true; }
|
||||
};
|
||||
|
||||
}
|
@ -1,5 +1,4 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionTopK.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <Common/FieldVisitorConvertToNumber.h>
|
||||
@ -7,6 +6,20 @@
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
#include <DataTypes/DataTypeIPv4andIPv6.h>
|
||||
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/ReadHelpersArena.h>
|
||||
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
|
||||
#include <Columns/ColumnArray.h>
|
||||
|
||||
#include <Common/SpaceSaving.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -25,6 +38,229 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
inline constexpr UInt64 TOP_K_MAX_SIZE = 0xFFFFFF;
|
||||
|
||||
template <typename T>
|
||||
struct AggregateFunctionTopKData
|
||||
{
|
||||
using Set = SpaceSaving<T, HashCRC32<T>>;
|
||||
|
||||
Set value;
|
||||
};
|
||||
|
||||
|
||||
template <typename T, bool is_weighted>
|
||||
class AggregateFunctionTopK
|
||||
: public IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>
|
||||
{
|
||||
protected:
|
||||
using State = AggregateFunctionTopKData<T>;
|
||||
UInt64 threshold;
|
||||
UInt64 reserved;
|
||||
|
||||
public:
|
||||
AggregateFunctionTopK(UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>(argument_types_, params, createResultType(argument_types_))
|
||||
, threshold(threshold_), reserved(load_factor * threshold)
|
||||
{}
|
||||
|
||||
AggregateFunctionTopK(UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params, const DataTypePtr & result_type_)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>(argument_types_, params, result_type_)
|
||||
, threshold(threshold_), reserved(load_factor * threshold)
|
||||
{}
|
||||
|
||||
String getName() const override { return is_weighted ? "topKWeighted" : "topK"; }
|
||||
|
||||
static DataTypePtr createResultType(const DataTypes & argument_types_)
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(argument_types_[0]);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
auto & set = this->data(place).value;
|
||||
if (set.capacity() != reserved)
|
||||
set.resize(reserved);
|
||||
|
||||
if constexpr (is_weighted)
|
||||
set.insert(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num], columns[1]->getUInt(row_num));
|
||||
else
|
||||
set.insert(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
auto & set = this->data(place).value;
|
||||
if (set.capacity() != reserved)
|
||||
set.resize(reserved);
|
||||
set.merge(this->data(rhs).value);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).value.write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
auto & set = this->data(place).value;
|
||||
set.resize(reserved);
|
||||
set.read(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
|
||||
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
|
||||
|
||||
const typename State::Set & set = this->data(place).value;
|
||||
auto result_vec = set.topK(threshold);
|
||||
size_t size = result_vec.size();
|
||||
|
||||
offsets_to.push_back(offsets_to.back() + size);
|
||||
|
||||
typename ColumnVector<T>::Container & data_to = assert_cast<ColumnVector<T> &>(arr_to.getData()).getData();
|
||||
size_t old_size = data_to.size();
|
||||
data_to.resize(old_size + size);
|
||||
|
||||
size_t i = 0;
|
||||
for (auto it = result_vec.begin(); it != result_vec.end(); ++it, ++i)
|
||||
data_to[old_size + i] = it->key;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// Generic implementation, it uses serialized representation as object descriptor.
|
||||
struct AggregateFunctionTopKGenericData
|
||||
{
|
||||
using Set = SpaceSaving<StringRef, StringRefHash>;
|
||||
|
||||
Set value;
|
||||
};
|
||||
|
||||
/** Template parameter with true value should be used for columns that store their elements in memory continuously.
|
||||
* For such columns topK() can be implemented more efficiently (especially for small numeric arrays).
|
||||
*/
|
||||
template <bool is_plain_column, bool is_weighted>
|
||||
class AggregateFunctionTopKGeneric
|
||||
: public IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column, is_weighted>>
|
||||
{
|
||||
private:
|
||||
using State = AggregateFunctionTopKGenericData;
|
||||
|
||||
UInt64 threshold;
|
||||
UInt64 reserved;
|
||||
|
||||
static void deserializeAndInsert(StringRef str, IColumn & data_to);
|
||||
|
||||
public:
|
||||
AggregateFunctionTopKGeneric(
|
||||
UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column, is_weighted>>(argument_types_, params, createResultType(argument_types_))
|
||||
, threshold(threshold_), reserved(load_factor * threshold) {}
|
||||
|
||||
String getName() const override { return is_weighted ? "topKWeighted" : "topK"; }
|
||||
|
||||
static DataTypePtr createResultType(const DataTypes & argument_types_)
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(argument_types_[0]);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).value.write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
auto & set = this->data(place).value;
|
||||
set.clear();
|
||||
|
||||
// Specialized here because there's no deserialiser for StringRef
|
||||
size_t size = 0;
|
||||
readVarUInt(size, buf);
|
||||
if (unlikely(size > TOP_K_MAX_SIZE))
|
||||
throw Exception(
|
||||
ErrorCodes::ARGUMENT_OUT_OF_BOUND,
|
||||
"Too large size ({}) for aggregate function '{}' state (maximum is {})",
|
||||
size,
|
||||
getName(),
|
||||
TOP_K_MAX_SIZE);
|
||||
set.resize(size);
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
auto ref = readStringBinaryInto(*arena, buf);
|
||||
UInt64 count;
|
||||
UInt64 error;
|
||||
readVarUInt(count, buf);
|
||||
readVarUInt(error, buf);
|
||||
set.insert(ref, count, error);
|
||||
arena->rollback(ref.size);
|
||||
}
|
||||
|
||||
set.readAlphaMap(buf);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
auto & set = this->data(place).value;
|
||||
if (set.capacity() != reserved)
|
||||
set.resize(reserved);
|
||||
|
||||
if constexpr (is_plain_column)
|
||||
{
|
||||
if constexpr (is_weighted)
|
||||
set.insert(columns[0]->getDataAt(row_num), columns[1]->getUInt(row_num));
|
||||
else
|
||||
set.insert(columns[0]->getDataAt(row_num));
|
||||
}
|
||||
else
|
||||
{
|
||||
const char * begin = nullptr;
|
||||
StringRef str_serialized = columns[0]->serializeValueIntoArena(row_num, *arena, begin);
|
||||
if constexpr (is_weighted)
|
||||
set.insert(str_serialized, columns[1]->getUInt(row_num));
|
||||
else
|
||||
set.insert(str_serialized);
|
||||
arena->rollback(str_serialized.size);
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
auto & set = this->data(place).value;
|
||||
if (set.capacity() != reserved)
|
||||
set.resize(reserved);
|
||||
set.merge(this->data(rhs).value);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
|
||||
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
|
||||
IColumn & data_to = arr_to.getData();
|
||||
|
||||
auto result_vec = this->data(place).value.topK(threshold);
|
||||
offsets_to.push_back(offsets_to.back() + result_vec.size());
|
||||
|
||||
for (auto & elem : result_vec)
|
||||
{
|
||||
if constexpr (is_plain_column)
|
||||
data_to.insertData(elem.key.data, elem.key.size);
|
||||
else
|
||||
data_to.deserializeAndInsertFromArena(elem.key.data);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// Substitute return type for Date and DateTime
|
||||
template <bool is_weighted>
|
||||
class AggregateFunctionTopKDate : public AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>
|
||||
|
@ -1,250 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/ReadHelpersArena.h>
|
||||
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
|
||||
#include <Columns/ColumnArray.h>
|
||||
|
||||
#include <Common/SpaceSaving.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
static inline constexpr UInt64 TOP_K_MAX_SIZE = 0xFFFFFF;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ARGUMENT_OUT_OF_BOUND;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct AggregateFunctionTopKData
|
||||
{
|
||||
using Set = SpaceSaving<T, HashCRC32<T>>;
|
||||
|
||||
Set value;
|
||||
};
|
||||
|
||||
|
||||
template <typename T, bool is_weighted>
|
||||
class AggregateFunctionTopK
|
||||
: public IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>
|
||||
{
|
||||
protected:
|
||||
using State = AggregateFunctionTopKData<T>;
|
||||
UInt64 threshold;
|
||||
UInt64 reserved;
|
||||
|
||||
public:
|
||||
AggregateFunctionTopK(UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>(argument_types_, params, createResultType(argument_types_))
|
||||
, threshold(threshold_), reserved(load_factor * threshold)
|
||||
{}
|
||||
|
||||
AggregateFunctionTopK(UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params, const DataTypePtr & result_type_)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>(argument_types_, params, result_type_)
|
||||
, threshold(threshold_), reserved(load_factor * threshold)
|
||||
{}
|
||||
|
||||
String getName() const override { return is_weighted ? "topKWeighted" : "topK"; }
|
||||
|
||||
static DataTypePtr createResultType(const DataTypes & argument_types_)
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(argument_types_[0]);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
auto & set = this->data(place).value;
|
||||
if (set.capacity() != reserved)
|
||||
set.resize(reserved);
|
||||
|
||||
if constexpr (is_weighted)
|
||||
set.insert(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num], columns[1]->getUInt(row_num));
|
||||
else
|
||||
set.insert(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
auto & set = this->data(place).value;
|
||||
if (set.capacity() != reserved)
|
||||
set.resize(reserved);
|
||||
set.merge(this->data(rhs).value);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).value.write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
auto & set = this->data(place).value;
|
||||
set.resize(reserved);
|
||||
set.read(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
|
||||
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
|
||||
|
||||
const typename State::Set & set = this->data(place).value;
|
||||
auto result_vec = set.topK(threshold);
|
||||
size_t size = result_vec.size();
|
||||
|
||||
offsets_to.push_back(offsets_to.back() + size);
|
||||
|
||||
typename ColumnVector<T>::Container & data_to = assert_cast<ColumnVector<T> &>(arr_to.getData()).getData();
|
||||
size_t old_size = data_to.size();
|
||||
data_to.resize(old_size + size);
|
||||
|
||||
size_t i = 0;
|
||||
for (auto it = result_vec.begin(); it != result_vec.end(); ++it, ++i)
|
||||
data_to[old_size + i] = it->key;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// Generic implementation, it uses serialized representation as object descriptor.
|
||||
struct AggregateFunctionTopKGenericData
|
||||
{
|
||||
using Set = SpaceSaving<StringRef, StringRefHash>;
|
||||
|
||||
Set value;
|
||||
};
|
||||
|
||||
/** Template parameter with true value should be used for columns that store their elements in memory continuously.
|
||||
* For such columns topK() can be implemented more efficiently (especially for small numeric arrays).
|
||||
*/
|
||||
template <bool is_plain_column, bool is_weighted>
|
||||
class AggregateFunctionTopKGeneric
|
||||
: public IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column, is_weighted>>
|
||||
{
|
||||
private:
|
||||
using State = AggregateFunctionTopKGenericData;
|
||||
|
||||
UInt64 threshold;
|
||||
UInt64 reserved;
|
||||
|
||||
static void deserializeAndInsert(StringRef str, IColumn & data_to);
|
||||
|
||||
public:
|
||||
AggregateFunctionTopKGeneric(
|
||||
UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column, is_weighted>>(argument_types_, params, createResultType(argument_types_))
|
||||
, threshold(threshold_), reserved(load_factor * threshold) {}
|
||||
|
||||
String getName() const override { return is_weighted ? "topKWeighted" : "topK"; }
|
||||
|
||||
static DataTypePtr createResultType(const DataTypes & argument_types_)
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(argument_types_[0]);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).value.write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
auto & set = this->data(place).value;
|
||||
set.clear();
|
||||
|
||||
// Specialized here because there's no deserialiser for StringRef
|
||||
size_t size = 0;
|
||||
readVarUInt(size, buf);
|
||||
if (unlikely(size > TOP_K_MAX_SIZE))
|
||||
throw Exception(
|
||||
ErrorCodes::ARGUMENT_OUT_OF_BOUND,
|
||||
"Too large size ({}) for aggregate function '{}' state (maximum is {})",
|
||||
size,
|
||||
getName(),
|
||||
TOP_K_MAX_SIZE);
|
||||
set.resize(size);
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
auto ref = readStringBinaryInto(*arena, buf);
|
||||
UInt64 count;
|
||||
UInt64 error;
|
||||
readVarUInt(count, buf);
|
||||
readVarUInt(error, buf);
|
||||
set.insert(ref, count, error);
|
||||
arena->rollback(ref.size);
|
||||
}
|
||||
|
||||
set.readAlphaMap(buf);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
auto & set = this->data(place).value;
|
||||
if (set.capacity() != reserved)
|
||||
set.resize(reserved);
|
||||
|
||||
if constexpr (is_plain_column)
|
||||
{
|
||||
if constexpr (is_weighted)
|
||||
set.insert(columns[0]->getDataAt(row_num), columns[1]->getUInt(row_num));
|
||||
else
|
||||
set.insert(columns[0]->getDataAt(row_num));
|
||||
}
|
||||
else
|
||||
{
|
||||
const char * begin = nullptr;
|
||||
StringRef str_serialized = columns[0]->serializeValueIntoArena(row_num, *arena, begin);
|
||||
if constexpr (is_weighted)
|
||||
set.insert(str_serialized, columns[1]->getUInt(row_num));
|
||||
else
|
||||
set.insert(str_serialized);
|
||||
arena->rollback(str_serialized.size);
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
auto & set = this->data(place).value;
|
||||
if (set.capacity() != reserved)
|
||||
set.resize(reserved);
|
||||
set.merge(this->data(rhs).value);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
|
||||
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
|
||||
IColumn & data_to = arr_to.getData();
|
||||
|
||||
auto result_vec = this->data(place).value.topK(threshold);
|
||||
offsets_to.push_back(offsets_to.back() + result_vec.size());
|
||||
|
||||
for (auto & elem : result_vec)
|
||||
{
|
||||
if constexpr (is_plain_column)
|
||||
data_to.insertData(elem.key.data, elem.key.size);
|
||||
else
|
||||
data_to.deserializeAndInsertFromArena(elem.key.data);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -1,23 +1,8 @@
|
||||
#include <AggregateFunctions/AggregateFunctionUniqCombined.h>
|
||||
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
|
||||
#include <Common/FieldVisitorConvertToNumber.h>
|
||||
|
||||
#include <DataTypes/DataTypeDate.h>
|
||||
#include <DataTypes/DataTypeDate32.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
#include <DataTypes/DataTypeIPv4andIPv6.h>
|
||||
|
||||
#include <functional>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
@ -26,118 +11,55 @@ namespace ErrorCodes
|
||||
|
||||
namespace
|
||||
{
|
||||
template <UInt8 K, typename HashValueType>
|
||||
struct WithK
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionUniqCombined(bool use_64_bit_hash,
|
||||
const std::string & name, const DataTypes & argument_types, const Array & params)
|
||||
{
|
||||
/// log2 of the number of cells in HyperLogLog.
|
||||
/// Reasonable default value, selected to be comparable in quality with "uniq" aggregate function.
|
||||
UInt8 precision = 17;
|
||||
|
||||
if (!params.empty())
|
||||
{
|
||||
template <typename T>
|
||||
using AggregateFunction = AggregateFunctionUniqCombined<T, K, HashValueType>;
|
||||
if (params.size() != 1)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires one parameter or less.",
|
||||
name);
|
||||
|
||||
template <bool is_exact, bool argument_is_tuple>
|
||||
using AggregateFunctionVariadic = AggregateFunctionUniqCombinedVariadic<is_exact, argument_is_tuple, K, HashValueType>;
|
||||
};
|
||||
|
||||
template <UInt8 K, typename HashValueType>
|
||||
AggregateFunctionPtr createAggregateFunctionWithK(const DataTypes & argument_types, const Array & params)
|
||||
{
|
||||
/// We use exact hash function if the arguments are not contiguous in memory, because only exact hash function has support for this case.
|
||||
bool use_exact_hash_function = !isAllArgumentsContiguousInMemory(argument_types);
|
||||
|
||||
if (argument_types.size() == 1)
|
||||
{
|
||||
const IDataType & argument_type = *argument_types[0];
|
||||
|
||||
AggregateFunctionPtr res(createWithNumericType<WithK<K, HashValueType>::template AggregateFunction>(*argument_types[0], argument_types, params));
|
||||
|
||||
WhichDataType which(argument_type);
|
||||
if (res)
|
||||
return res;
|
||||
else if (which.isDate())
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunction<DataTypeDate::FieldType>>(argument_types, params);
|
||||
else if (which.isDate32())
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunction<DataTypeDate32::FieldType>>(argument_types, params);
|
||||
else if (which.isDateTime())
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunction<DataTypeDateTime::FieldType>>(argument_types, params);
|
||||
else if (which.isStringOrFixedString())
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunction<String>>(argument_types, params);
|
||||
else if (which.isUUID())
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunction<DataTypeUUID::FieldType>>(argument_types, params);
|
||||
else if (which.isIPv4())
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunction<DataTypeIPv4::FieldType>>(argument_types, params);
|
||||
else if (which.isIPv6())
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunction<DataTypeIPv6::FieldType>>(argument_types, params);
|
||||
else if (which.isTuple())
|
||||
{
|
||||
if (use_exact_hash_function)
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunctionVariadic<true, true>>(argument_types, params);
|
||||
else
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunctionVariadic<false, true>>(argument_types, params);
|
||||
}
|
||||
}
|
||||
|
||||
/// "Variadic" method also works as a fallback generic case for a single argument.
|
||||
if (use_exact_hash_function)
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunctionVariadic<true, false>>(argument_types, params);
|
||||
else
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunctionVariadic<false, false>>(argument_types, params);
|
||||
UInt64 precision_param = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[0]);
|
||||
// This range is hardcoded below
|
||||
if (precision_param > 20 || precision_param < 12)
|
||||
throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Parameter for aggregate function {} is out of range: [12, 20].",
|
||||
name);
|
||||
precision = precision_param;
|
||||
}
|
||||
|
||||
template <UInt8 K>
|
||||
AggregateFunctionPtr createAggregateFunctionWithHashType(bool use_64_bit_hash, const DataTypes & argument_types, const Array & params)
|
||||
if (argument_types.empty())
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Incorrect number of arguments for aggregate function {}", name);
|
||||
|
||||
switch (precision) // NOLINT(bugprone-switch-missing-default-case)
|
||||
{
|
||||
if (use_64_bit_hash)
|
||||
return createAggregateFunctionWithK<K, UInt64>(argument_types, params);
|
||||
else
|
||||
return createAggregateFunctionWithK<K, UInt32>(argument_types, params);
|
||||
case 12:
|
||||
return createAggregateFunctionWithHashType<12>(use_64_bit_hash, argument_types, params);
|
||||
case 13:
|
||||
return createAggregateFunctionWithHashType<13>(use_64_bit_hash, argument_types, params);
|
||||
case 14:
|
||||
return createAggregateFunctionWithHashType<14>(use_64_bit_hash, argument_types, params);
|
||||
case 15:
|
||||
return createAggregateFunctionWithHashType<15>(use_64_bit_hash, argument_types, params);
|
||||
case 16:
|
||||
return createAggregateFunctionWithHashType<16>(use_64_bit_hash, argument_types, params);
|
||||
case 17:
|
||||
return createAggregateFunctionWithHashType<17>(use_64_bit_hash, argument_types, params);
|
||||
case 18:
|
||||
return createAggregateFunctionWithHashType<18>(use_64_bit_hash, argument_types, params);
|
||||
case 19:
|
||||
return createAggregateFunctionWithHashType<19>(use_64_bit_hash, argument_types, params);
|
||||
case 20:
|
||||
return createAggregateFunctionWithHashType<20>(use_64_bit_hash, argument_types, params);
|
||||
}
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionUniqCombined(bool use_64_bit_hash,
|
||||
const std::string & name, const DataTypes & argument_types, const Array & params)
|
||||
{
|
||||
/// log2 of the number of cells in HyperLogLog.
|
||||
/// Reasonable default value, selected to be comparable in quality with "uniq" aggregate function.
|
||||
UInt8 precision = 17;
|
||||
|
||||
if (!params.empty())
|
||||
{
|
||||
if (params.size() != 1)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires one parameter or less.",
|
||||
name);
|
||||
|
||||
UInt64 precision_param = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[0]);
|
||||
// This range is hardcoded below
|
||||
if (precision_param > 20 || precision_param < 12)
|
||||
throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Parameter for aggregate function {} is out of range: [12, 20].",
|
||||
name);
|
||||
precision = precision_param;
|
||||
}
|
||||
|
||||
if (argument_types.empty())
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Incorrect number of arguments for aggregate function {}", name);
|
||||
|
||||
switch (precision) // NOLINT(bugprone-switch-missing-default-case)
|
||||
{
|
||||
case 12:
|
||||
return createAggregateFunctionWithHashType<12>(use_64_bit_hash, argument_types, params);
|
||||
case 13:
|
||||
return createAggregateFunctionWithHashType<13>(use_64_bit_hash, argument_types, params);
|
||||
case 14:
|
||||
return createAggregateFunctionWithHashType<14>(use_64_bit_hash, argument_types, params);
|
||||
case 15:
|
||||
return createAggregateFunctionWithHashType<15>(use_64_bit_hash, argument_types, params);
|
||||
case 16:
|
||||
return createAggregateFunctionWithHashType<16>(use_64_bit_hash, argument_types, params);
|
||||
case 17:
|
||||
return createAggregateFunctionWithHashType<17>(use_64_bit_hash, argument_types, params);
|
||||
case 18:
|
||||
return createAggregateFunctionWithHashType<18>(use_64_bit_hash, argument_types, params);
|
||||
case 19:
|
||||
return createAggregateFunctionWithHashType<19>(use_64_bit_hash, argument_types, params);
|
||||
case 20:
|
||||
return createAggregateFunctionWithHashType<20>(use_64_bit_hash, argument_types, params);
|
||||
}
|
||||
|
||||
UNREACHABLE();
|
||||
}
|
||||
UNREACHABLE();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
|
||||
#include <Common/FieldVisitorConvertToNumber.h>
|
||||
|
||||
#include <DataTypes/DataTypeDate.h>
|
||||
#include <DataTypes/DataTypeDate32.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
#include <DataTypes/DataTypeIPv4andIPv6.h>
|
||||
|
||||
#include <base/bit_cast.h>
|
||||
|
||||
#include <Common/CombinedCardinalityEstimator.h>
|
||||
@ -16,58 +26,15 @@
|
||||
#include <AggregateFunctions/UniqVariadicHash.h>
|
||||
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
|
||||
#include <functional>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
struct Settings;
|
||||
namespace detail
|
||||
{
|
||||
/** Hash function for uniqCombined/uniqCombined64 (based on Ret).
|
||||
*/
|
||||
template <typename T, typename Ret>
|
||||
struct AggregateFunctionUniqCombinedTraits
|
||||
{
|
||||
static Ret hash(T x)
|
||||
{
|
||||
if constexpr (sizeof(T) > sizeof(UInt64))
|
||||
return static_cast<Ret>(DefaultHash64<T>(x));
|
||||
else
|
||||
return static_cast<Ret>(intHash64(x));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Ret>
|
||||
struct AggregateFunctionUniqCombinedTraits<UInt128, Ret>
|
||||
{
|
||||
static Ret hash(UInt128 x)
|
||||
{
|
||||
return static_cast<Ret>(sipHash64(x));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Ret>
|
||||
struct AggregateFunctionUniqCombinedTraits<Float32, Ret>
|
||||
{
|
||||
static Ret hash(Float32 x)
|
||||
{
|
||||
UInt64 res = bit_cast<UInt64>(x);
|
||||
return static_cast<Ret>(intHash64(res));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Ret>
|
||||
struct AggregateFunctionUniqCombinedTraits<Float64, Ret>
|
||||
{
|
||||
static Ret hash(Float64 x)
|
||||
{
|
||||
UInt64 res = bit_cast<UInt64>(x);
|
||||
return static_cast<Ret>(intHash64(res));
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
// Unlike HashTableGrower always grows to power of 2.
|
||||
struct UniqCombinedHashTableGrower : public HashTableGrowerWithPrecalculation<>
|
||||
@ -75,55 +42,40 @@ struct UniqCombinedHashTableGrower : public HashTableGrowerWithPrecalculation<>
|
||||
void increaseSize() { increaseSizeDegree(1); }
|
||||
};
|
||||
|
||||
template <typename Key, UInt8 K>
|
||||
struct AggregateFunctionUniqCombinedDataWithKey
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename T, UInt8 K, typename HashValueType>
|
||||
struct AggregateFunctionUniqCombinedData
|
||||
{
|
||||
using Key = std::conditional_t<
|
||||
std::is_same_v<T, String> || std::is_same_v<T, IPv6>,
|
||||
UInt64,
|
||||
HashValueType>;
|
||||
|
||||
// TODO(ilezhankin): pre-generate values for |UniqCombinedBiasData|,
|
||||
// at the moment gen-bias-data.py script doesn't work.
|
||||
|
||||
// We want to migrate from |HashSet| to |HyperLogLogCounter| when the sizes in memory become almost equal.
|
||||
// The size per element in |HashSet| is sizeof(Key)*2 bytes, and the overall size of |HyperLogLogCounter| is 2^K * 6 bits.
|
||||
// For Key=UInt32 we can calculate: 2^X * 4 * 2 ≤ 2^(K-3) * 6 ⇒ X ≤ K-4.
|
||||
using Set = CombinedCardinalityEstimator<Key, HashSet<Key, TrivialHash, UniqCombinedHashTableGrower>, 16, K - 5 + (sizeof(Key) == sizeof(UInt32)), K, TrivialHash, Key>;
|
||||
|
||||
Set set;
|
||||
};
|
||||
|
||||
template <typename Key>
|
||||
struct AggregateFunctionUniqCombinedDataWithKey<Key, 17>
|
||||
{
|
||||
using Set = CombinedCardinalityEstimator<Key,
|
||||
/// Note: I don't recall what is special with '17' - probably it is one of the original functions that has to be compatible.
|
||||
using Set = CombinedCardinalityEstimator<
|
||||
Key,
|
||||
HashSet<Key, TrivialHash, UniqCombinedHashTableGrower>,
|
||||
16,
|
||||
12 + (sizeof(Key) == sizeof(UInt32)),
|
||||
17,
|
||||
K - 5 + (sizeof(Key) == sizeof(UInt32)),
|
||||
K,
|
||||
TrivialHash,
|
||||
Key,
|
||||
HyperLogLogBiasEstimator<UniqCombinedBiasData>,
|
||||
std::conditional_t<K == 17, HyperLogLogBiasEstimator<UniqCombinedBiasData>, TrivialBiasEstimator>,
|
||||
HyperLogLogMode::FullFeatured>;
|
||||
|
||||
Set set;
|
||||
};
|
||||
|
||||
|
||||
template <typename T, UInt8 K, typename HashValueType>
|
||||
struct AggregateFunctionUniqCombinedData : public AggregateFunctionUniqCombinedDataWithKey<HashValueType, K>
|
||||
{
|
||||
};
|
||||
|
||||
|
||||
/// For String keys, 64 bit hash is always used (both for uniqCombined and uniqCombined64),
|
||||
/// because of backwards compatibility (64 bit hash was already used for uniqCombined).
|
||||
template <UInt8 K, typename HashValueType>
|
||||
struct AggregateFunctionUniqCombinedData<String, K, HashValueType> : public AggregateFunctionUniqCombinedDataWithKey<UInt64 /*always*/, K>
|
||||
{
|
||||
};
|
||||
|
||||
template <UInt8 K, typename HashValueType>
|
||||
struct AggregateFunctionUniqCombinedData<IPv6, K, HashValueType> : public AggregateFunctionUniqCombinedDataWithKey<UInt64 /*always*/, K>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T, UInt8 K, typename HashValueType>
|
||||
class AggregateFunctionUniqCombined final
|
||||
: public IAggregateFunctionDataHelper<AggregateFunctionUniqCombinedData<T, K, HashValueType>, AggregateFunctionUniqCombined<T, K, HashValueType>>
|
||||
@ -153,7 +105,30 @@ public:
|
||||
else
|
||||
{
|
||||
const auto & value = assert_cast<const ColumnVector<T> &>(*columns[0]).getElement(row_num);
|
||||
this->data(place).set.insert(detail::AggregateFunctionUniqCombinedTraits<T, HashValueType>::hash(value));
|
||||
|
||||
HashValueType hash;
|
||||
|
||||
if constexpr (std::is_same_v<T, UInt128>)
|
||||
{
|
||||
/// This specialization exists due to historical circumstances.
|
||||
/// Initially UInt128 was introduced only for UUID, and then the other big-integer types were added.
|
||||
hash = static_cast<HashValueType>(sipHash64(value));
|
||||
}
|
||||
else if constexpr (std::is_floating_point_v<T>)
|
||||
{
|
||||
hash = static_cast<HashValueType>(intHash64(bit_cast<UInt64>(value)));
|
||||
}
|
||||
else if constexpr (sizeof(T) > sizeof(UInt64))
|
||||
{
|
||||
hash = static_cast<HashValueType>(DefaultHash64<T>(value));
|
||||
}
|
||||
else
|
||||
{
|
||||
/// This specialization exists also for compatibility with the initial implementation.
|
||||
hash = static_cast<HashValueType>(intHash64(value));
|
||||
}
|
||||
|
||||
this->data(place).set.insert(hash);
|
||||
}
|
||||
}
|
||||
|
||||
@ -237,4 +212,83 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <UInt8 K, typename HashValueType>
|
||||
struct WithK
|
||||
{
|
||||
template <typename T>
|
||||
using AggregateFunction = AggregateFunctionUniqCombined<T, K, HashValueType>;
|
||||
|
||||
template <bool is_exact, bool argument_is_tuple>
|
||||
using AggregateFunctionVariadic = AggregateFunctionUniqCombinedVariadic<is_exact, argument_is_tuple, K, HashValueType>;
|
||||
};
|
||||
|
||||
template <UInt8 K, typename HashValueType>
|
||||
AggregateFunctionPtr createAggregateFunctionWithK(const DataTypes & argument_types, const Array & params)
|
||||
{
|
||||
/// We use exact hash function if the arguments are not contiguous in memory, because only exact hash function has support for this case.
|
||||
bool use_exact_hash_function = !isAllArgumentsContiguousInMemory(argument_types);
|
||||
|
||||
if (argument_types.size() == 1)
|
||||
{
|
||||
const IDataType & argument_type = *argument_types[0];
|
||||
|
||||
AggregateFunctionPtr res(createWithNumericType<WithK<K, HashValueType>::template AggregateFunction>(*argument_types[0], argument_types, params));
|
||||
|
||||
WhichDataType which(argument_type);
|
||||
if (res)
|
||||
return res;
|
||||
else if (which.isDate())
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunction<DataTypeDate::FieldType>>(argument_types, params);
|
||||
else if (which.isDate32())
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunction<DataTypeDate32::FieldType>>(argument_types, params);
|
||||
else if (which.isDateTime())
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunction<DataTypeDateTime::FieldType>>(argument_types, params);
|
||||
else if (which.isStringOrFixedString())
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunction<String>>(argument_types, params);
|
||||
else if (which.isUUID())
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunction<DataTypeUUID::FieldType>>(argument_types, params);
|
||||
else if (which.isIPv4())
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunction<DataTypeIPv4::FieldType>>(argument_types, params);
|
||||
else if (which.isIPv6())
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunction<DataTypeIPv6::FieldType>>(argument_types, params);
|
||||
else if (which.isTuple())
|
||||
{
|
||||
if (use_exact_hash_function)
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunctionVariadic<true, true>>(argument_types, params);
|
||||
else
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunctionVariadic<false, true>>(argument_types, params);
|
||||
}
|
||||
}
|
||||
|
||||
/// "Variadic" method also works as a fallback generic case for a single argument.
|
||||
if (use_exact_hash_function)
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunctionVariadic<true, false>>(argument_types, params);
|
||||
else
|
||||
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunctionVariadic<false, false>>(argument_types, params);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template <UInt8 K>
|
||||
AggregateFunctionPtr createAggregateFunctionWithHashType(bool use_64_bit_hash, const DataTypes & argument_types, const Array & params)
|
||||
{
|
||||
if (use_64_bit_hash)
|
||||
return createAggregateFunctionWithK<K, UInt64>(argument_types, params);
|
||||
else
|
||||
return createAggregateFunctionWithK<K, UInt32>(argument_types, params);
|
||||
}
|
||||
|
||||
/// Let's instantiate these templates in separate translation units,
|
||||
/// otherwise this translation unit becomes too large.
|
||||
extern template AggregateFunctionPtr createAggregateFunctionWithHashType<12>(bool use_64_bit_hash, const DataTypes & argument_types, const Array & params);
|
||||
extern template AggregateFunctionPtr createAggregateFunctionWithHashType<13>(bool use_64_bit_hash, const DataTypes & argument_types, const Array & params);
|
||||
extern template AggregateFunctionPtr createAggregateFunctionWithHashType<14>(bool use_64_bit_hash, const DataTypes & argument_types, const Array & params);
|
||||
extern template AggregateFunctionPtr createAggregateFunctionWithHashType<15>(bool use_64_bit_hash, const DataTypes & argument_types, const Array & params);
|
||||
extern template AggregateFunctionPtr createAggregateFunctionWithHashType<16>(bool use_64_bit_hash, const DataTypes & argument_types, const Array & params);
|
||||
extern template AggregateFunctionPtr createAggregateFunctionWithHashType<17>(bool use_64_bit_hash, const DataTypes & argument_types, const Array & params);
|
||||
extern template AggregateFunctionPtr createAggregateFunctionWithHashType<18>(bool use_64_bit_hash, const DataTypes & argument_types, const Array & params);
|
||||
extern template AggregateFunctionPtr createAggregateFunctionWithHashType<19>(bool use_64_bit_hash, const DataTypes & argument_types, const Array & params);
|
||||
extern template AggregateFunctionPtr createAggregateFunctionWithHashType<20>(bool use_64_bit_hash, const DataTypes & argument_types, const Array & params);
|
||||
|
||||
}
|
||||
|
@ -0,0 +1,6 @@
|
||||
#include <AggregateFunctions/AggregateFunctionUniqCombined.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
template AggregateFunctionPtr createAggregateFunctionWithHashType<12>(bool use_64_bit_hash, const DataTypes & argument_types, const Array & params);
|
||||
}
|
@ -0,0 +1,6 @@
|
||||
#include <AggregateFunctions/AggregateFunctionUniqCombined.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
template AggregateFunctionPtr createAggregateFunctionWithHashType<13>(bool use_64_bit_hash, const DataTypes & argument_types, const Array & params);
|
||||
}
|
@ -0,0 +1,6 @@
|
||||
#include <AggregateFunctions/AggregateFunctionUniqCombined.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
template AggregateFunctionPtr createAggregateFunctionWithHashType<14>(bool use_64_bit_hash, const DataTypes & argument_types, const Array & params);
|
||||
}
|
@ -0,0 +1,6 @@
|
||||
#include <AggregateFunctions/AggregateFunctionUniqCombined.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
template AggregateFunctionPtr createAggregateFunctionWithHashType<15>(bool use_64_bit_hash, const DataTypes & argument_types, const Array & params);
|
||||
}
|
@ -0,0 +1,6 @@
|
||||
#include <AggregateFunctions/AggregateFunctionUniqCombined.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
template AggregateFunctionPtr createAggregateFunctionWithHashType<16>(bool use_64_bit_hash, const DataTypes & argument_types, const Array & params);
|
||||
}
|
@ -0,0 +1,6 @@
|
||||
#include <AggregateFunctions/AggregateFunctionUniqCombined.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
template AggregateFunctionPtr createAggregateFunctionWithHashType<17>(bool use_64_bit_hash, const DataTypes & argument_types, const Array & params);
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user