Annoy tests (#58)

* Added diskann to contrib

* Made cmake changes cleaner

* Implemented DiskANN aggregator

* Renamed index from Simple to DiskANN

* Implemented index aggregator

* Implemented index serialization

* Added condition to DiskANN index

* Implemented condition for diskann

* removed maybe_unused_attr

* Condition

* added some metrics and comparison operators

* Added common condition

* Added tests

* enchanced some tests

* Added inddex to tests

* Updated functional tests

* Added annoy tests

Co-authored-by: Danila Mishin <mishin.dk@phystech.edu>
Co-authored-by: Hakob Saghatelyan <sagatelyan.aa@phystech.edu>
This commit is contained in:
VVMak 2022-05-12 13:49:31 +03:00 committed by GitHub
parent e8c7278c0d
commit 6adbd516cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 19219 additions and 710 deletions

3
.gitmodules vendored
View File

@ -262,6 +262,9 @@
[submodule "contrib/minizip-ng"]
path = contrib/minizip-ng
url = https://github.com/zlib-ng/minizip-ng
[submodule "contrib/gperftools"]
path = contrib/gperftools
url = https://github.com/gperftools/gperftools.git
[submodule "contrib/spotify-annoy"]
path = contrib/spotify-annoy
url = https://github.com/spotify/annoy.git

View File

View File

@ -53,6 +53,8 @@ add_contrib (sparsehash-c11-cmake sparsehash-c11)
add_contrib (abseil-cpp-cmake abseil-cpp)
add_contrib (magic-enum-cmake magic_enum)
add_contrib (boost-cmake boost)
add_contrib (gperftools-cmake gperftools)
add_contrib (diskann-cmake diskann)
add_contrib (cctz-cmake cctz)
add_contrib (consistent-hashing)
add_contrib (dragonbox-cmake dragonbox)

1
contrib/diskann vendored Submodule

@ -0,0 +1 @@
Subproject commit 76722512bb2a6207e2f1c6b17d00e42a8a09d8c9

View File

@ -0,0 +1,68 @@
list(APPEND CMAKE_MODULE_PATH "${ClickHouse_SOURCE_DIR}/contrib/diskann-cmake/modules")
set(DISKANN_PROJECT_DIR "${ClickHouse_SOURCE_DIR}/contrib/diskann")
set(DISKANN_SOURCE_DIR "${DISKANN_PROJECT_DIR}/src")
set(DISKANN_SOURCES
${DISKANN_SOURCE_DIR}/ann_exception.cpp
${DISKANN_SOURCE_DIR}/aux_utils.cpp
${DISKANN_SOURCE_DIR}/index.cpp
${DISKANN_SOURCE_DIR}/linux_aligned_file_reader.cpp
${DISKANN_SOURCE_DIR}/math_utils.cpp
${DISKANN_SOURCE_DIR}/memory_mapper.cpp
${DISKANN_SOURCE_DIR}/partition_and_pq.cpp
${DISKANN_SOURCE_DIR}/pq_flash_index.cpp
${DISKANN_SOURCE_DIR}/logger.cpp
${DISKANN_SOURCE_DIR}/utils.cpp
)
add_library(_diskann ${DISKANN_SOURCES})
target_include_directories (_diskann SYSTEM PUBLIC
"${DISKANN_PROJECT_DIR}/include"
"${ClickHouse_SOURCE_DIR}"
)
target_compile_options(_diskann PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-mavx2 -mfma -mf16c -mpopcnt>)
# MKL dependency
find_package(MKL)
if(MKL_FOUND)
message("Found MKL! ${MKL_ROOT}/include ${mkl_core_file}")
target_include_directories(_diskann PUBLIC "${MKL_ROOT}/include")
set(LIBMKL_THREAD_PATH
"${ClickHouse_SOURCE_DIR}/contrib/diskann-cmake/mkl/libmkl_intel_thread.a"
)
set(LIBMKL_ILP64_PATH
"${ClickHouse_SOURCE_DIR}/contrib/diskann-cmake/mkl/libmkl_intel_ilp64.a"
)
set(LIBMKL_CORE_PATH
"${ClickHouse_SOURCE_DIR}/contrib/diskann-cmake/mkl/libmkl_core.a"
)
set (CMAKE_STATIC_LINKER_FLAGS "${CMAKE_STATIC_LINKER_FLAGS} ${LIBMKL_THREAD_PATH}" )
set (CMAKE_STATIC_LINKER_FLAGS "${CMAKE_STATIC_LINKER_FLAGS} ${LIBMKL_ILP64_PATH}" )
set (CMAKE_STATIC_LINKER_FLAGS "${CMAKE_STATIC_LINKER_FLAGS} ${LIBMKL_CORE_PATH}" )
else()
message(FATAL_ERROR "MKL not found!")
endif()
# OpenMP dependancy
find_package(OpenMP REQUIRED)
target_link_libraries(_diskann PRIVATE OpenMP::OpenMP_C)
target_link_libraries(_diskann PRIVATE OpenMP::OpenMP_CXX)
# Boost dependancy
target_link_libraries(_diskann PRIVATE boost::headers_only)
# gperftools dependancy
target_link_libraries(_diskann PRIVATE ch_contrib::gperftools)
# aiolib dependancy
target_include_directories(_diskann SYSTEM PUBLIC "${ClickHouse_SOURCE_DIR}/contrib/diskann-cmake/aio/include")
add_library(ch_contrib::diskann ALIAS _diskann)

View File

@ -0,0 +1,316 @@
/* /usr/include/libaio.h
*
* Copyright 2000,2001,2002 Red Hat, Inc.
*
* Written by Benjamin LaHaise <bcrl@redhat.com>
*
* libaio Linux async I/O interface
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*/
#ifndef __LIBAIO_H
#define __LIBAIO_H
#ifdef __cplusplus
extern "C" {
#endif
#include <sys/types.h>
#include <string.h>
#include <signal.h>
struct timespec;
struct sockaddr;
struct iovec;
typedef struct io_context *io_context_t;
typedef enum io_iocb_cmd {
IO_CMD_PREAD = 0,
IO_CMD_PWRITE = 1,
IO_CMD_FSYNC = 2,
IO_CMD_FDSYNC = 3,
IO_CMD_POLL = 5,
IO_CMD_NOOP = 6,
IO_CMD_PREADV = 7,
IO_CMD_PWRITEV = 8,
} io_iocb_cmd_t;
/* little endian, 32 bits */
#if defined(__i386__) || (defined(__arm__) && !defined(__ARMEB__)) || \
(defined(__sh__) && defined(__LITTLE_ENDIAN__)) || \
defined(__bfin__) || \
(defined(__MIPSEL__) && !defined(__mips64)) || \
defined(__cris__) || (defined(__riscv) && __riscv_xlen == 32) || \
(defined(__GNUC__) && defined(__BYTE_ORDER__) && \
__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ && __SIZEOF_LONG__ == 4)
#define PADDED(x, y) x; unsigned y
#define PADDEDptr(x, y) x; unsigned y
#define PADDEDul(x, y) unsigned long x; unsigned y
/* little endian, 64 bits */
#elif defined(__ia64__) || defined(__x86_64__) || defined(__alpha__) || \
(defined(__mips64) && defined(__MIPSEL__)) || \
(defined(__aarch64__) && defined(__AARCH64EL__)) || \
(defined(__riscv) && __riscv_xlen == 64) || \
(defined(__GNUC__) && defined(__BYTE_ORDER__) && \
__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ && __SIZEOF_LONG__ == 8)
#define PADDED(x, y) x, y
#define PADDEDptr(x, y) x
#define PADDEDul(x, y) unsigned long x
/* big endian, 64 bits */
#elif defined(__powerpc64__) || defined(__s390x__) || \
(defined(__hppa__) && defined(__arch64__)) || \
(defined(__sparc__) && defined(__arch64__)) || \
(defined(__mips64) && defined(__MIPSEB__)) || \
(defined(__aarch64__) && defined(__AARCH64EB__)) || \
(defined(__GNUC__) && defined(__BYTE_ORDER__) && \
__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ && __SIZEOF_LONG__ == 8)
#define PADDED(x, y) unsigned y; x
#define PADDEDptr(x,y) x
#define PADDEDul(x, y) unsigned long x
/* big endian, 32 bits */
#elif defined(__PPC__) || defined(__s390__) || \
(defined(__arm__) && defined(__ARMEB__)) || \
(defined(__sh__) && defined (__BIG_ENDIAN__)) || \
defined(__sparc__) || defined(__MIPSEB__) || defined(__m68k__) || \
defined(__hppa__) || defined(__frv__) || defined(__avr32__) || \
(defined(__GNUC__) && defined(__BYTE_ORDER__) && \
__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ && __SIZEOF_LONG__ == 4)
#define PADDED(x, y) unsigned y; x
#define PADDEDptr(x, y) unsigned y; x
#define PADDEDul(x, y) unsigned y; unsigned long x
#else
#error endian?
#endif
struct io_iocb_poll {
PADDED(int events, __pad1);
}; /* result code is the set of result flags or -'ve errno */
struct io_iocb_sockaddr {
struct sockaddr *addr;
int len;
}; /* result code is the length of the sockaddr, or -'ve errno */
struct io_iocb_common {
PADDEDptr(void *buf, __pad1);
PADDEDul(nbytes, __pad2);
long long offset;
long long __pad3;
unsigned flags;
unsigned resfd;
}; /* result code is the amount read or -'ve errno */
struct io_iocb_vector {
const struct iovec *vec;
int nr;
long long offset;
}; /* result code is the amount read or -'ve errno */
struct iocb {
PADDEDptr(void *data, __pad1); /* Return in the io completion event */
/* key: For use in identifying io requests */
/* aio_rw_flags: RWF_* flags (such as RWF_NOWAIT) */
PADDED(unsigned key, aio_rw_flags);
short aio_lio_opcode;
short aio_reqprio;
int aio_fildes;
union {
struct io_iocb_common c;
struct io_iocb_vector v;
struct io_iocb_poll poll;
struct io_iocb_sockaddr saddr;
} u;
};
struct io_event {
PADDEDptr(void *data, __pad1);
PADDEDptr(struct iocb *obj, __pad2);
PADDEDul(res, __pad3);
PADDEDul(res2, __pad4);
};
struct io_sigset {
unsigned long ss;
unsigned long ss_len;
};
struct io_sigset_compat {
PADDEDptr(unsigned long ss, __ss_pad);
unsigned long ss_len;
};
#undef PADDED
#undef PADDEDptr
#undef PADDEDul
typedef void (*io_callback_t)(io_context_t ctx, struct iocb *iocb, long res, long res2);
/* library wrappers */
extern int io_queue_init(int maxevents, io_context_t *ctxp);
/*extern int io_queue_grow(io_context_t ctx, int new_maxevents);*/
extern int io_queue_release(io_context_t ctx);
/*extern int io_queue_wait(io_context_t ctx, struct timespec *timeout);*/
extern int io_queue_run(io_context_t ctx);
/* Actual syscalls */
extern int io_setup(int maxevents, io_context_t *ctxp);
extern int io_destroy(io_context_t ctx);
extern int io_submit(io_context_t ctx, long nr, struct iocb *ios[]);
extern int io_cancel(io_context_t ctx, struct iocb *iocb, struct io_event *evt);
extern int io_getevents(io_context_t ctx, long min_nr, long nr, struct io_event *events, struct timespec *timeout);
extern int io_pgetevents(io_context_t ctx, long min_nr, long nr,
struct io_event *events, struct timespec *timeout,
sigset_t *sigmask);
static inline void io_set_callback(struct iocb *iocb, io_callback_t cb)
{
iocb->data = (void *)cb;
}
static inline void io_prep_pread(struct iocb *iocb, int fd, void *buf, size_t count, long long offset)
{
memset(iocb, 0, sizeof(*iocb));
iocb->aio_fildes = fd;
iocb->aio_lio_opcode = IO_CMD_PREAD;
iocb->aio_reqprio = 0;
iocb->u.c.buf = buf;
iocb->u.c.nbytes = count;
iocb->u.c.offset = offset;
}
static inline void io_prep_pwrite(struct iocb *iocb, int fd, void *buf, size_t count, long long offset)
{
memset(iocb, 0, sizeof(*iocb));
iocb->aio_fildes = fd;
iocb->aio_lio_opcode = IO_CMD_PWRITE;
iocb->aio_reqprio = 0;
iocb->u.c.buf = buf;
iocb->u.c.nbytes = count;
iocb->u.c.offset = offset;
}
static inline void io_prep_preadv(struct iocb *iocb, int fd, const struct iovec *iov, int iovcnt, long long offset)
{
memset(iocb, 0, sizeof(*iocb));
iocb->aio_fildes = fd;
iocb->aio_lio_opcode = IO_CMD_PREADV;
iocb->aio_reqprio = 0;
iocb->u.c.buf = (void *)iov;
iocb->u.c.nbytes = iovcnt;
iocb->u.c.offset = offset;
}
static inline void io_prep_pwritev(struct iocb *iocb, int fd, const struct iovec *iov, int iovcnt, long long offset)
{
memset(iocb, 0, sizeof(*iocb));
iocb->aio_fildes = fd;
iocb->aio_lio_opcode = IO_CMD_PWRITEV;
iocb->aio_reqprio = 0;
iocb->u.c.buf = (void *)iov;
iocb->u.c.nbytes = iovcnt;
iocb->u.c.offset = offset;
}
static inline void io_prep_preadv2(struct iocb *iocb, int fd, const struct iovec *iov, int iovcnt, long long offset, int flags)
{
memset(iocb, 0, sizeof(*iocb));
iocb->aio_fildes = fd;
iocb->aio_lio_opcode = IO_CMD_PREADV;
iocb->aio_reqprio = 0;
iocb->aio_rw_flags = flags;
iocb->u.c.buf = (void *)iov;
iocb->u.c.nbytes = iovcnt;
iocb->u.c.offset = offset;
}
static inline void io_prep_pwritev2(struct iocb *iocb, int fd, const struct iovec *iov, int iovcnt, long long offset, int flags)
{
memset(iocb, 0, sizeof(*iocb));
iocb->aio_fildes = fd;
iocb->aio_lio_opcode = IO_CMD_PWRITEV;
iocb->aio_reqprio = 0;
iocb->aio_rw_flags = flags;
iocb->u.c.buf = (void *)iov;
iocb->u.c.nbytes = iovcnt;
iocb->u.c.offset = offset;
}
static inline void io_prep_poll(struct iocb *iocb, int fd, int events)
{
memset(iocb, 0, sizeof(*iocb));
iocb->aio_fildes = fd;
iocb->aio_lio_opcode = IO_CMD_POLL;
iocb->aio_reqprio = 0;
iocb->u.poll.events = events;
}
static inline int io_poll(io_context_t ctx, struct iocb *iocb, io_callback_t cb, int fd, int events)
{
io_prep_poll(iocb, fd, events);
io_set_callback(iocb, cb);
return io_submit(ctx, 1, &iocb);
}
static inline void io_prep_fsync(struct iocb *iocb, int fd)
{
memset(iocb, 0, sizeof(*iocb));
iocb->aio_fildes = fd;
iocb->aio_lio_opcode = IO_CMD_FSYNC;
iocb->aio_reqprio = 0;
}
static inline int io_fsync(io_context_t ctx, struct iocb *iocb, io_callback_t cb, int fd)
{
io_prep_fsync(iocb, fd);
io_set_callback(iocb, cb);
return io_submit(ctx, 1, &iocb);
}
static inline void io_prep_fdsync(struct iocb *iocb, int fd)
{
memset(iocb, 0, sizeof(*iocb));
iocb->aio_fildes = fd;
iocb->aio_lio_opcode = IO_CMD_FDSYNC;
iocb->aio_reqprio = 0;
}
static inline int io_fdsync(io_context_t ctx, struct iocb *iocb, io_callback_t cb, int fd)
{
io_prep_fdsync(iocb, fd);
io_set_callback(iocb, cb);
return io_submit(ctx, 1, &iocb);
}
static inline void io_set_eventfd(struct iocb *iocb, int eventfd)
{
iocb->u.c.flags |= (1 << 0) /* IOCB_FLAG_RESFD */;
iocb->u.c.resfd = eventfd;
}
#ifdef __cplusplus
}
#endif
#endif /* __LIBAIO_H */

Binary file not shown.

View File

@ -0,0 +1,15 @@
find_path(LIBAIO_INCLUDE_DIR NAMES libaio.h)
mark_as_advanced(LIBAIO_INCLUDE_DIR)
find_library(LIBAIO_LIBRARY NAMES aio)
mark_as_advanced(LIBAIO_LIBRARY)
include(FindPackageHandleStandardArgs)
FIND_PACKAGE_HANDLE_STANDARD_ARGS(
LIBAIO
REQUIRED_VARS LIBAIO_LIBRARY LIBAIO_INCLUDE_DIR)
if(LIBAIO_FOUND)
set(LIBAIO_LIBRARIES ${LIBAIO_LIBRARY})
set(LIBAIO_INCLUDE_DIRS ${LIBAIO_INCLUDE_DIR})
endif()

View File

@ -0,0 +1,148 @@
# - Try to find Google performance tools (gperftools)
# Input variables:
# GPERFTOOLS_ROOT_DIR - The gperftools install directory;
# if not set the GPERFTOOLS_DIR environment variable will be used
# GPERFTOOLS_INCLUDE_DIR - The gperftools include directory
# GPERFTOOLS_LIBRARY - The gperftools library directory
# Components: profiler, and tcmalloc or tcmalloc_minimal
# Output variables:
# GPERFTOOLS_FOUND - System has gperftools
# GPERFTOOLS_INCLUDE_DIRS - The gperftools include directories
# GPERFTOOLS_LIBRARIES - The libraries needed to use gperftools
# GPERFTOOLS_VERSION - The version string for gperftools
include(FindPackageHandleStandardArgs)
if(NOT DEFINED GPERFTOOLS_FOUND)
# If not set already, set GPERFTOOLS_ROOT_DIR from environment
if (DEFINED ENV{GPERFTOOLS_DIR} AND NOT DEFINED GPERFTOOLS_ROOT_DIR)
set(GPERFTOOLS_ROOT_DIR $ENV{GPERFTOOLS_DIR})
endif()
# Check to see if libunwind is required
set(GPERFTOOLS_DISABLE_PROFILER FALSE)
if((";${Gperftools_FIND_COMPONENTS};" MATCHES ";profiler;") AND
(CMAKE_SYSTEM_NAME MATCHES "Linux" OR
CMAKE_SYSTEM_NAME MATCHES "BlueGeneQ" OR
CMAKE_SYSTEM_NAME MATCHES "BlueGeneP") AND
(CMAKE_SIZEOF_VOID_P EQUAL 8))
# Libunwind is required by profiler on this platform
if(Gperftools_FIND_REQUIRED_profiler OR Gperftools_FIND_REQUIRED_tcmalloc_and_profiler)
find_package(Libunwind 0.99 REQUIRED)
else()
find_package(Libunwind)
if(NOT LIBUNWIND_FOUND OR LIBUNWIND_VERSION VERSION_LESS 0.99)
set(GPERFTOOLS_DISABLE_PROFILER TRUE)
endif()
endif()
endif()
# Check for invalid components
foreach(_comp ${Gperftools_FIND_COMPONENTS})
if((NOT _comp STREQUAL "tcmalloc_and_profiler") AND
(NOT _comp STREQUAL "tcmalloc") AND
(NOT _comp STREQUAL "tcmalloc_minimal") AND
(NOT _comp STREQUAL "profiler"))
message(FATAL_ERROR "Invalid component specified for Gperftools: ${_comp}")
endif()
endforeach()
# Check for valid component combinations
if(";${Gperftools_FIND_COMPONENTS};" MATCHES ";tcmalloc_and_profiler;" AND
(";${Gperftools_FIND_COMPONENTS};" MATCHES ";tcmalloc;" OR
";${Gperftools_FIND_COMPONENTS};" MATCHES ";tcmalloc_minimal;" OR
";${Gperftools_FIND_COMPONENTS};" MATCHES ";profiler;"))
message("ERROR: Invalid component selection for Gperftools: ${Gperftools_FIND_COMPONENTS}")
message("ERROR: Gperftools cannot link both tcmalloc_and_profiler with the tcmalloc, tcmalloc_minimal, or profiler libraries")
message(FATAL_ERROR "Gperftools component list is invalid")
endif()
if(";${Gperftools_FIND_COMPONENTS};" MATCHES ";tcmalloc;" AND ";${Gperftools_FIND_COMPONENTS};" MATCHES ";tcmalloc_minimal;")
message("ERROR: Invalid component selection for Gperftools: ${Gperftools_FIND_COMPONENTS}")
message("ERROR: Gperftools cannot link both tcmalloc and tcmalloc_minimal")
message(FATAL_ERROR "Gperftools component list is invalid")
endif()
# Set default sarch paths for gperftools
if(GPERFTOOLS_ROOT_DIR)
set(GPERFTOOLS_INCLUDE_DIR ${GPERFTOOLS_ROOT_DIR}/include CACHE PATH "The include directory for gperftools")
if(CMAKE_SIZEOF_VOID_P EQUAL 8 AND CMAKE_SYSTEM_NAME STREQUAL "Linux")
set(GPERFTOOLS_LIBRARY ${GPERFTOOLS_ROOT_DIR}/lib64;${GPERFTOOLS_ROOT_DIR}/lib CACHE PATH "The library directory for gperftools")
else()
set(GPERFTOOLS_LIBRARY ${GPERFTOOLS_ROOT_DIR}/lib CACHE PATH "The library directory for gperftools")
endif()
endif()
find_path(GPERFTOOLS_INCLUDE_DIRS NAMES gperftools/malloc_extension.h
HINTS ${GPERFTOOLS_INCLUDE_DIR})
# Search for component libraries
foreach(_comp ${Gperftools_FIND_COMPONENTS})
find_library(GPERFTOOLS_${_comp}_LIBRARY ${_comp}
HINTS ${GPERFTOOLS_LIBRARY})
if(GPERFTOOLS_${_comp}_LIBRARY)
set(Gperftools_${_comp}_FOUND TRUE)
else()
set(Gperftools_${_comp}_FOUND FALSE)
endif()
# Exclude profiler from the found list if libunwind is required but not found
if(Gperftools_${_comp}_FOUND AND ${_comp} MATCHES "profiler" AND GPERFTOOLS_DISABLE_PROFILER)
set(Gperftools_${_comp}_FOUND FALSE)
set(GPERFTOOLS_${_comp}_LIBRARY "GPERFTOOLS_${_comp}_LIBRARY-NOTFOUND")
message("WARNING: Gperftools '${_comp}' requires libunwind 0.99 or later.")
message("WARNING: Gperftools '${_comp}' will be disabled.")
endif()
if(";${Gperftools_FIND_COMPONENTS};" MATCHES ";${_comp};" AND Gperftools_${_comp}_FOUND)
list(APPEND GPERFTOOLS_LIBRARIES "${GPERFTOOLS_${_comp}_LIBRARY}")
endif()
endforeach()
# Set gperftools libraries if not set based on component list
if(NOT GPERFTOOLS_LIBRARIES)
if(Gperftools_tcmalloc_and_profiler_FOUND)
set(GPERFTOOLS_LIBRARIES "${GPERFTOOLS_tcmalloc_and_profiler_LIBRARY}")
elseif(Gperftools_tcmalloc_FOUND AND GPERFTOOLS_profiler_FOUND)
set(GPERFTOOLS_LIBRARIES "${GPERFTOOLS_tcmalloc_LIBRARY}" "${GPERFTOOLS_profiler_LIBRARY}")
elseif(Gperftools_profiler_FOUND)
set(GPERFTOOLS_LIBRARIES "${GPERFTOOLS_profiler_LIBRARY}")
elseif(Gperftools_tcmalloc_FOUND)
set(GPERFTOOLS_LIBRARIES "${GPERFTOOLS_tcmalloc_LIBRARY}")
elseif(Gperftools_tcmalloc_minimal_FOUND)
set(GPERFTOOLS_LIBRARIES "${GPERFTOOLS_tcmalloc_minimal_LIBRARY}")
endif()
endif()
# handle the QUIETLY and REQUIRED arguments and set GPERFTOOLS_FOUND to TRUE
# if all listed variables are TRUE
find_package_handle_standard_args(Gperftools
FOUND_VAR GPERFTOOLS_FOUND
REQUIRED_VARS GPERFTOOLS_LIBRARIES GPERFTOOLS_INCLUDE_DIRS
HANDLE_COMPONENTS)
mark_as_advanced(GPERFTOOLS_INCLUDE_DIR GPERFTOOLS_LIBRARY
GPERFTOOLS_INCLUDE_DIRS GPERFTOOLS_LIBRARIES)
# Add linker flags that instruct the compiler to exclude built in memory
# allocation functions. This works for GNU, Intel, and Clang. Other compilers
# may need to be added in the future.
if(GPERFTOOLS_LIBRARIES MATCHES "tcmalloc")
if((CMAKE_CXX_COMPILER_ID MATCHES "GNU") OR
(CMAKE_CXX_COMPILER_ID MATCHES "AppleClang") OR
(CMAKE_CXX_COMPILER_ID MATCHES "Clang") OR
((CMAKE_CXX_COMPILER_ID MATCHES "Intel") AND (NOT CMAKE_CXX_PLATFORM_ID MATCHES "Windows")))
list(APPEND GPERFTOOLS_LIBRARIES "-fno-builtin-malloc"
"-fno-builtin-calloc" "-fno-builtin-realloc" "-fno-builtin-free")
endif()
endif()
# Add libunwind flags to gperftools if the profiler is being used
if(GPERFTOOLS_LIBRARIES MATCHES "profiler" AND LIBUNWIND_FOUND)
list(APPEND GPERFTOOLS_INCLUDE_DIRS "${LIBUNWIND_INCLUDE_DIR}")
list(APPEND GPERFTOOLS_LIBRARIES "${LIBUNWIND_LIBRARIES}")
endif()
unset(GPERFTOOLS_DISABLE_PROFILER)
endif()

1
contrib/gperftools vendored Submodule

@ -0,0 +1 @@
Subproject commit fe85bbdf4cb891a67a8e2109c1c22a33aa958c7e

View File

@ -0,0 +1,15 @@
set (SRCS
../gperftools/src/gperftools/heap-checker.h
../gperftools/src/gperftools/heap-profiler.h
../gperftools/src/gperftools/malloc_extension.h
../gperftools/src/gperftools/malloc_extension_c.h
../gperftools/src/gperftools/malloc_hook.h
../gperftools/src/gperftools/malloc_hook_c.h
../gperftools/src/gperftools/nallocx.h
../gperftools/src/gperftools/profiler.h
../gperftools/src/gperftools/stacktrace.h
)
add_library(_gperftools ${SRCS})
target_include_directories(_gperftools SYSTEM PUBLIC ../gperftools/src)
add_library(ch_contrib::gperftools ALIAS _gperftools)

View File

@ -490,6 +490,14 @@ if (TARGET ch_contrib::libpqxx)
dbms_target_link_libraries(PUBLIC ch_contrib::libpqxx)
endif()
if (TARGET ch_contrib::gperftools)
dbms_target_link_libraries(PUBLIC ch_contrib::gperftools)
endif ()
if (TARGET ch_contrib::diskann)
dbms_target_link_libraries(PUBLIC ch_contrib::diskann)
endif ()
if (TARGET ch_contrib::datasketches)
target_link_libraries (clickhouse_aggregate_functions PRIVATE ch_contrib::datasketches)
endif ()

View File

@ -1,534 +0,0 @@
#include <cstddef>
#include <optional>
#include <Parsers/ASTFunction.h>
#include "Core/Block.h"
#include "Core/Field.h"
#include "IO/ReadBuffer.h"
#include "Interpreters/Context_fwd.h"
#include "Parsers/ASTExpressionList.h"
#include "Parsers/ASTFunctionWithKeyValueArguments.h"
#include "Parsers/ASTIdentifier.h"
#include "Parsers/ASTIdentifier_fwd.h"
#include "Parsers/ASTLiteral.h"
#include "Parsers/ASTOrderByElement.h"
#include "Parsers/ASTSelectQuery.h"
#include "Parsers/ASTSetQuery.h"
#include "Parsers/ASTTablesInSelectQuery.h"
#include "Parsers/Access/ASTCreateUserQuery.h"
#include "Parsers/Access/ASTRolesOrUsersSet.h"
#include "Parsers/Access/ASTSettingsProfileElement.h"
#include "Parsers/IAST_fwd.h"
#include <Storages/MergeTree/CommonCondition.h>
#include <Storages/MergeTree/KeyCondition.h>
#include "Storages/SelectQueryInfo.h"
#include "base/logger_useful.h"
#include "base/types.h"
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
namespace Condition
{
CommonCondition::CommonCondition(const SelectQueryInfo & query_info,
ContextPtr context)
{
buildRPN(query_info, context);
index_is_useful = matchAllRPNS();
}
bool CommonCondition::alwaysUnknownOrTrue() const
{
return !index_is_useful;
}
float CommonCondition::getComparisonDistance() const
{
if (where_query_type)
{
return ann_expr->distance;
}
throw Exception(ErrorCodes::LOGICAL_ERROR, "Not supported method for this query type");
}
std::vector<float> CommonCondition::getTargetVector() const
{
return ann_expr->target;
}
String CommonCondition::getColumnName() const
{
return ann_expr->column_name;
}
String CommonCondition::getMetric() const
{
return ann_expr->metric_name;
}
size_t CommonCondition::getSpaceDim() const
{
return ann_expr->target.size();
}
float CommonCondition::getPForLpDistance() const
{
return ann_expr->p_for_lp_dist;
}
bool CommonCondition::queryHasWhereClause() const
{
return where_query_type;
}
bool CommonCondition::queryHasOrderByClause() const
{
return order_by_query_type && has_limit;
}
std::optional<UInt64> CommonCondition::getLimitLength() const
{
return has_limit ? std::optional<UInt64>(limit_expr->length) : std::nullopt;
}
String CommonCondition::getSettingsStr() const
{
return ann_index_params;
}
void CommonCondition::buildRPN(const SelectQueryInfo & query, ContextPtr context)
{
block_with_constants = KeyCondition::getBlockWithConstants(query.query, query.syntax_analyzer_result, context);
const auto & select = query.query->as<ASTSelectQuery &>();
if (select.prewhere())
{
traverseAST(select.prewhere(), rpn_prewhere_clause);
}
if (select.where())
{
traverseAST(select.where(), rpn_where_clause);
}
if (select.limitLength())
{
traverseAST(select.limitLength(), rpn_limit_clause);
}
if (select.settings())
{
parseSettings(select.settings());
}
if (select.orderBy())
{
if (const auto * expr_list = select.orderBy()->as<ASTExpressionList>())
{
if (const auto * order_by_element = expr_list->children.front()->as<ASTOrderByElement>())
{
traverseAST(order_by_element->children.front(), rpn_order_by_clause);
}
}
}
std::reverse(rpn_prewhere_clause.begin(), rpn_prewhere_clause.end());
std::reverse(rpn_where_clause.begin(), rpn_where_clause.end());
std::reverse(rpn_order_by_clause.begin(), rpn_order_by_clause.end());
}
void CommonCondition::traverseAST(const ASTPtr & node, RPN & rpn)
{
if (const auto * func = node->as<ASTFunction>())
{
const ASTs & args = func->arguments->children;
for (const auto& arg : args)
{
traverseAST(arg, rpn);
}
}
RPNElement element;
if (!traverseAtomAST(node, element))
{
element.function = RPNElement::FUNCTION_UNKNOWN;
}
rpn.emplace_back(std::move(element));
}
bool CommonCondition::traverseAtomAST(const ASTPtr & node, RPNElement & out)
{
if (const auto * order_by_element = node->as<ASTOrderByElement>())
{
out.function = RPNElement::FUNCTION_ORDER_BY_ELEMENT;
out.func_name = "order by elemnet";
return true;
}
if (const auto * function = node->as<ASTFunction>())
{
// Set the name
out.func_name = function->name;
// TODO: Add support for LpDistance
if (function->name == "L1Distance" ||
function->name == "L2Distance" ||
function->name == "LinfDistance" ||
function->name == "cosineDistance" ||
function->name == "dotProduct" ||
function->name == "LpDistance")
{
out.function = RPNElement::FUNCTION_DISTANCE;
}
else if (function->name == "tuple")
{
out.function = RPNElement::FUNCTION_TUPLE;
}
else if (function->name == "less" ||
function->name == "greater" ||
function->name == "lessOrEquals" ||
function->name == "greaterOrEquals")
{
out.function = RPNElement::FUNCTION_COMPARISON;
}
else
{
return false;
}
return true;
}
// Match identifier
else if (const auto * identifier = node->as<ASTIdentifier>())
{
out.function = RPNElement::FUNCTION_IDENTIFIER;
out.identifier.emplace(identifier->name());
out.func_name = "column identifier";
return true;
}
// Check if we have constants behind the node
{
Field const_value;
DataTypePtr const_type;
if (KeyCondition::getConstant(node, block_with_constants, const_value, const_type))
{
/// Check constant type (use Float64 because all Fields implementation contains Float64 (for Float32 too))
if (const_value.getType() == Field::Types::Float64)
{
out.function = RPNElement::FUNCTION_FLOAT_LITERAL;
out.float_literal.emplace(const_value.get<Float32>());
out.func_name = "Float literal";
return true;
}
if (const_value.getType() == Field::Types::UInt64)
{
out.function = RPNElement::FUNCTION_INT_LITERAL;
out.int_literal.emplace(const_value.get<UInt64>());
out.func_name = "Int literal";
return true;
}
if (const_value.getType() == Field::Types::Int64)
{
out.function = RPNElement::FUNCTION_INT_LITERAL;
out.int_literal.emplace(const_value.get<Int64>());
out.func_name = "Int literal";
return true;
}
if (const_value.getType() == Field::Types::String)
{
out.function = RPNElement::FUNCTION_STRING;
out.identifier.emplace(const_value.get<String>());
out.func_name = "setting string";
return true;
}
if (const_value.getType() == Field::Types::Tuple)
{
out.function = RPNElement::FUNCTION_LITERAL_TUPLE;
out.tuple_literal = const_value.get<Tuple>();
out.func_name = "Tuple literal";
return true;
}
}
}
return false;
}
bool CommonCondition::matchAllRPNS()
{
ANNExpression expr_prewhere;
ANNExpression expr_where;
ANNExpression expr_order_by;
LimitExpression expr_limit;
bool prewhere_is_valid = matchRPNWhere(rpn_prewhere_clause, expr_prewhere);
bool where_is_valid = matchRPNWhere(rpn_where_clause, expr_where);
bool limit_is_valid = matchRPNLimit(rpn_limit_clause, expr_limit);
bool order_by_is_valid = matchRPNOrderBy(rpn_order_by_clause, expr_order_by);
// Unxpected situation
if (prewhere_is_valid && where_is_valid)
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "Have both where and prewhere valid clauses - is not supported");
}
if (prewhere_is_valid || where_is_valid)
{
ann_expr = std::move(where_is_valid ? expr_where : expr_prewhere);
where_query_type = true;
}
if (order_by_is_valid)
{
ann_expr = std::move(expr_order_by);
order_by_query_type = true;
}
if (limit_is_valid)
{
limit_expr = std::move(expr_limit);
has_limit = true;
}
if (where_query_type && (has_limit && order_by_query_type))
{
throw Exception(ErrorCodes::LOGICAL_ERROR,
"The query with Valid Where Clause and valid OrderBy clause - is not supported");
}
return where_query_type || (has_limit && order_by_query_type);
}
bool CommonCondition::matchRPNLimit(RPN & rpn, LimitExpression & expr)
{
if (rpn.size() != 1)
{
return false;
}
if (rpn.front().function == RPNElement::FUNCTION_INT_LITERAL)
{
expr.length = rpn.front().int_literal.value();
return true;
}
return false;
}
void CommonCondition::parseSettings(const ASTPtr & node)
{
if (const auto * set = node->as<ASTSetQuery>())
{
for (const auto & change : set->changes)
{
if (change.name == "ann_index_params")
{
ann_index_params = change.value.get<String>();
return;
}
}
}
ann_index_params = "";
}
bool CommonCondition::matchRPNOrderBy(RPN & rpn, ANNExpression & expr)
{
if (rpn.size() < 3)
{
return false;
}
auto iter = rpn.begin();
auto end = rpn.end();
bool identifier_found = false;
return CommonCondition::matchMainParts(iter, end, expr, identifier_found);
}
bool CommonCondition::matchMainParts(RPN::iterator & iter, RPN::iterator & end,
ANNExpression & expr, bool & identifier_found)
{
if (iter->function != RPNElement::FUNCTION_DISTANCE)
{
return false;
}
expr.metric_name = iter->func_name;
++iter;
if (expr.metric_name == "LpDistance")
{
if (iter->function != RPNElement::FUNCTION_FLOAT_LITERAL &&
iter->function != RPNElement::FUNCTION_INT_LITERAL)
{
return false;
}
expr.p_for_lp_dist = getFloatOrIntLiteralOrPanic(iter);
++iter;
}
if (iter->function == RPNElement::FUNCTION_IDENTIFIER)
{
identifier_found = true;
expr.column_name = getIdentifierOrPanic(iter);
++iter;
}
if (iter->function == RPNElement::FUNCTION_TUPLE)
{
++iter;
}
if (iter->function == RPNElement::FUNCTION_LITERAL_TUPLE)
{
for (const auto & value : iter->tuple_literal.value())
{
expr.target.emplace_back(value.get<float>());
}
++iter;
}
while (iter != end)
{
if (iter->function == RPNElement::FUNCTION_FLOAT_LITERAL ||
iter->function == RPNElement::FUNCTION_INT_LITERAL)
{
expr.target.emplace_back(getFloatOrIntLiteralOrPanic(iter));
}
else if (iter->function == RPNElement::FUNCTION_IDENTIFIER)
{
if (identifier_found)
{
return false;
}
expr.column_name = getIdentifierOrPanic(iter);
identifier_found = true;
}
else
{
return false;
}
++iter;
}
return true;
}
bool CommonCondition::matchRPNWhere(RPN & rpn, ANNExpression & expr)
{
const size_t minimal_elemets_count = 6;// At least 6 AST nodes in querry
if (rpn.size() < minimal_elemets_count)
{
return false;
}
auto iter = rpn.begin();
bool identifier_found = false;
// Query starts from operator less
if (iter->function != RPNElement::FUNCTION_COMPARISON)
{
return false;
}
const bool greater_case = iter->func_name == "greater" || iter->func_name == "greaterOrEquals";
const bool less_case = iter->func_name == "less" || iter->func_name == "lessOrEquals";
++iter;
if (less_case)
{
if (iter->function != RPNElement::FUNCTION_FLOAT_LITERAL)
{
return false;
}
expr.distance = getFloatOrIntLiteralOrPanic(iter);
++iter;
}
else if (!greater_case)
{
return false;
}
auto end = rpn.end();
if (!matchMainParts(iter, end, expr, identifier_found))
{
return false;
}
// Final checks of correctness
if (!identifier_found || expr.target.empty())
{
return false;
}
if (greater_case)
{
if (expr.target.size() < 2)
{
return false;
}
expr.distance = expr.target.back();
expr.target.pop_back();
}
// Querry is ok
return true;
}
String CommonCondition::getIdentifierOrPanic(RPN::iterator& iter)
{
String identifier;
try
{
identifier = std::move(iter->identifier.value());
}
catch (...)
{
CommonCondition::panicIfWrongBuiltRPN();
}
return identifier;
}
float CommonCondition::getFloatOrIntLiteralOrPanic(RPN::iterator& iter)
{
if (iter->float_literal.has_value())
{
return iter->float_literal.value();
}
if (iter->int_literal.has_value())
{
return static_cast<float>(iter->int_literal.value());
}
CommonCondition::panicIfWrongBuiltRPN();
}
void CommonCondition::panicIfWrongBuiltRPN()
{
LOG_DEBUG(&Poco::Logger::get("CommonCondition"), "Wrong parsing of AST");
throw Exception(
"Wrong parsed AST in buildRPN\n", DB::ErrorCodes::LOGICAL_ERROR);
}
}
}

View File

@ -1,175 +0,0 @@
#pragma once
#include <Storages/MergeTree/KeyCondition.h>
#include "base/types.h"
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <variant>
#include <vector>
namespace DB
{
namespace Condition
{
class CommonCondition
{
public:
CommonCondition(const SelectQueryInfo & query_info,
ContextPtr context);
bool alwaysUnknownOrTrue() const;
float getComparisonDistance() const;
std::vector<float> getTargetVector() const;
size_t getSpaceDim() const;
String getColumnName() const;
String getMetric() const;
float getPForLpDistance() const;
bool queryHasOrderByClause() const;
bool queryHasWhereClause() const;
std::optional<UInt64> getLimitLength() const;
String getSettingsStr() const;
private:
// Type of the vector to use as a target in the distance function
using Target = std::vector<float>;
// Extracted data from the query like WHERE L2Distance(column_name, target) < distance
struct ANNExpression
{
Target target;
float distance = -1.0;
String metric_name = "Unknown"; // Metric name, maybe some Enum for all indices
String column_name = "Unknown"; // Coloumn name stored in IndexGranule
float p_for_lp_dist = -1.0; // The P parametr for Lp Distance
};
struct LimitExpression
{
Int64 length;
};
using ANNExprOpt = std::optional<ANNExpression>;
using LimitExprOpt = std::optional<LimitExpression>;
struct RPNElement
{
enum Function
{
// l2 dist
FUNCTION_DISTANCE,
//tuple(10, 15)
FUNCTION_TUPLE,
// Operator <, >
FUNCTION_COMPARISON,
// Numeric float value
FUNCTION_FLOAT_LITERAL,
// Numeric int value
FUNCTION_INT_LITERAL,
// Column identifier
FUNCTION_IDENTIFIER,
// Unknown, can be any value
FUNCTION_UNKNOWN,
FUNCTION_STRING,
FUNCTION_LITERAL_TUPLE,
FUNCTION_ORDER_BY_ELEMENT,
};
explicit RPNElement(Function function_ = FUNCTION_UNKNOWN)
: function(function_), func_name("Unknown"), float_literal(std::nullopt), identifier(std::nullopt) {}
Function function;
String func_name;
std::optional<float> float_literal;
std::optional<String> identifier;
std::optional<int64_t> int_literal{std::nullopt};
std::optional<Tuple> tuple_literal{std::nullopt};
UInt32 dim{0};
};
using RPN = std::vector<RPNElement>;
void buildRPN(const SelectQueryInfo & query, ContextPtr context);
// Util functions for the traversal of AST
void traverseAST(const ASTPtr & node, RPN & rpn);
// Return true if we can identify our node type
bool traverseAtomAST(const ASTPtr & node, RPNElement & out);
// Checks that at least one rpn is matching for index
// New RPNs for other query types can be added here
bool matchAllRPNS();
/* Returns true and stores ANNExpr if the query matches the template:
* WHERE DistFunc(column_name, tuple(float_1, float_2, ..., float_dim)) < float_literal */
static bool matchRPNWhere(RPN & rpn, ANNExpression & expr);
/* Returns true and stores OrderByExpr if the query has valid OrderBy section*/
static bool matchRPNOrderBy(RPN & rpn, ANNExpression & expr);
/* Returns true if we have valid limit clause in query*/
static bool matchRPNLimit(RPN & rpn, LimitExpression & expr);
/* Getting settings for ann_index_param */
void parseSettings(const ASTPtr & node);
/* Matches dist function, target vector, coloumn name */
static bool matchMainParts(RPN::iterator & iter, RPN::iterator & end, ANNExpression & expr, bool & identifier_found);
// Util methods
static void panicIfWrongBuiltRPN [[noreturn]] ();
static String getIdentifierOrPanic(RPN::iterator& iter);
static float getFloatOrIntLiteralOrPanic(RPN::iterator& iter);
// Here we store RPN-s for different types of Queries
RPN rpn_prewhere_clause;
RPN rpn_where_clause;
RPN rpn_limit_clause;
RPN rpn_order_by_clause;
Block block_with_constants;
ANNExprOpt ann_expr{std::nullopt};
LimitExprOpt limit_expr{std::nullopt};
String ann_index_params; // Empty string if no params
bool order_by_query_type{false};
bool where_query_type{false};
bool has_limit{false};
// true if we had extracted ANNExpression from query
bool index_is_useful{false};
};
}
}

View File

@ -0,0 +1,381 @@
#include <filesystem>
#include <memory>
#include <parameters.h>
#include <utils.h>
#include <Interpreters/ExpressionActions.h>
#include <Interpreters/ExpressionAnalyzer.h>
#include <Interpreters/TreeRewriter.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteBufferFromFile.h>
#include <IO/WriteHelpers.h>
#include "KeyCondition.h"
#include "Parsers/ASTIdentifier.h"
#include "Parsers/ASTSelectQuery.h"
#include "Parsers/IAST_fwd.h"
#include <Parsers/ASTFunction.h>
#include <Poco/Logger.h>
#include <Storages/MergeTree/MergeTreeIndexDiskANN.h>
#include <base/logger_useful.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
namespace detail {
void saveDataPoints(uint32_t dimensions, std::vector<DiskANNValue> datapoints, WriteBuffer & out) {
uint32_t num_of_points = static_cast<uint32_t>(datapoints.size()) / dimensions;
out.write(reinterpret_cast<const char*>(&num_of_points), sizeof(num_of_points));
out.write(reinterpret_cast<const char*>(&dimensions), sizeof(dimensions));
for (float data_point : datapoints) {
out.write(reinterpret_cast<const char*>(&data_point), sizeof(Float32));
}
LOG_DEBUG(&Poco::Logger::get("DiskANN"), "Saved {} points", num_of_points);
}
DiskANNIndexPtr constructIndexFromDatapoints(uint32_t dimensions, std::vector<DiskANNValue> datapoints) {
if (datapoints.empty()) {
throw Exception("Trying to construct index with no datapoints", ErrorCodes::LOGICAL_ERROR);
}
String datapoints_filename = "diskann_datapoints.bin";
WriteBufferFromFile write_buffer(datapoints_filename);
detail::saveDataPoints(dimensions, datapoints, write_buffer);
write_buffer.close();
return std::make_shared<DiskANNIndex>(
diskann::Metric::L2,
datapoints_filename.c_str()
);
}
}
MergeTreeIndexGranuleDiskANN::MergeTreeIndexGranuleDiskANN(const String & index_name_, const Block & index_sample_block_)
: index_name(index_name_)
, index_sample_block(index_sample_block_)
{}
MergeTreeIndexGranuleDiskANN::MergeTreeIndexGranuleDiskANN(
const String & index_name_,
const Block & index_sample_block_,
DiskANNIndexPtr base_index_,
uint32_t dimensions_,
std::vector<DiskANNValue> datapoints_
)
: dimensions(dimensions_)
, datapoints(std::move(datapoints_))
, index_name(index_name_)
, index_sample_block(index_sample_block_)
, base_index(base_index_)
{
}
uint64_t MergeTreeIndexGranuleDiskANN::calculateIndexSize() const {
uint64_t index_size = 0;
index_size += sizeof(uint64_t) + 2 * sizeof(unsigned);
std::cout << base_index->_nd << " " << base_index->_final_graph.size() << std::endl;
for (unsigned i = 0; i < base_index->_nd + base_index->_num_frozen_pts; i++) {
unsigned gk = static_cast<unsigned>(base_index->_final_graph[i].size());
index_size += sizeof(unsigned) + gk * sizeof(unsigned);
}
return index_size;
}
void MergeTreeIndexGranuleDiskANN::serializeBinary(WriteBuffer & out) const {
LOG_DEBUG(&Poco::Logger::get("DiskANN"), "Saving Vamana index: saving datapoints...");
if (!dimensions.has_value()) {
throw Exception("Dimensions parameter was not got, despite having data", ErrorCodes::LOGICAL_ERROR);
}
detail::saveDataPoints(dimensions.value(), datapoints, out);
LOG_DEBUG(&Poco::Logger::get("DiskANN"), "Datapoints saved.");
LOG_DEBUG(&Poco::Logger::get("DiskANN"), "Saving Vamana index itself...");
uint64_t total_gr_edges = 0;
uint64_t index_size = calculateIndexSize();
LOG_DEBUG(&Poco::Logger::get("DiskANN"), "Index size: {}", index_size);
out.write(reinterpret_cast<char*>(&index_size), sizeof(uint64_t));
out.write(reinterpret_cast<char*>(&base_index->_width), sizeof(unsigned));
out.write(reinterpret_cast<char*>(&base_index->_ep), sizeof(unsigned));
for (size_t i = 0; i < base_index->_nd + base_index->_num_frozen_pts; i++) {
unsigned gk = static_cast<unsigned>(base_index->_final_graph[i].size());
out.write(reinterpret_cast<char*>(&gk), sizeof(unsigned));
out.write(reinterpret_cast<char*>(base_index->_final_graph[i].data()), gk * sizeof(unsigned));
total_gr_edges += gk;
}
LOG_DEBUG(
&Poco::Logger::get("DiskANN"),
"Saving Vamana index done! Avg degree: {}",
(static_cast<float>(total_gr_edges)) / (static_cast<float>(base_index->_nd + base_index->_num_frozen_pts))
);
}
void MergeTreeIndexGranuleDiskANN::deserializeBinary(ReadBuffer & in, MergeTreeIndexVersion /*version*/) {
LOG_DEBUG(&Poco::Logger::get("DiskANN"), "Loading datapoints in deserialize...");
uint32_t num_of_points = 0;
uint32_t dims = 0;
in.read(reinterpret_cast<char*>(&num_of_points), sizeof(num_of_points));
in.read(reinterpret_cast<char*>(&dims), sizeof(dims));
LOG_DEBUG(&Poco::Logger::get("DiskANN"), "num_of_points={}, dims={}", num_of_points, dims);
dimensions = dims;
datapoints.resize(num_of_points * dims);
in.read(reinterpret_cast<char*>(datapoints.data()), sizeof(DiskANNValue) * num_of_points * dims);
if (num_of_points * dims != datapoints.size()) {
LOG_ERROR(
&Poco::Logger::get("DiskANN"),
"num_of_points * dims != datapoints.size(); {} * {} != {}.",
num_of_points, dims, datapoints.size());
throw Exception("Bad datapoints read", ErrorCodes::LOGICAL_ERROR);
}
base_index = detail::constructIndexFromDatapoints(dims, datapoints);
LOG_DEBUG(&Poco::Logger::get("DiskANN"), "Got datapoints: {}. Constructed the index object", datapoints.size());
LOG_DEBUG(&Poco::Logger::get("DiskANN"), "Loading Vamana index...");
uint64_t expected_file_size;
in.read(reinterpret_cast<char*>(&expected_file_size), sizeof(uint64_t));
LOG_DEBUG(&Poco::Logger::get("DiskANN"), "Expected index size: {}", expected_file_size);
in.read(reinterpret_cast<char*>(&base_index->_width), sizeof(unsigned));
in.read(reinterpret_cast<char*>(&base_index->_ep), sizeof(unsigned));
assert(base_index->_final_graph.empty());
size_t cc = 0;
unsigned nodes = 0;
while (!in.eof()) {
unsigned k;
in.read(reinterpret_cast<char*>(&k), sizeof(unsigned));
if (in.eof())
break;
cc += k;
++nodes;
std::vector<unsigned> tmp(k);
in.read(reinterpret_cast<char*>(tmp.data()), k * sizeof(unsigned));
base_index->_final_graph.emplace_back(tmp);
if (nodes >= datapoints.size() / dims) {
break;
}
}
assert(nodes == base_index->_final_graph.size());
if (base_index->_final_graph.size() != base_index->_nd) {
LOG_ERROR(
&Poco::Logger::get("DiskANN"), "Mismatch in "
"number of points. Graph has {} points and loaded dataset has {} points.",
base_index->_final_graph.size(), base_index->_nd
);
throw Exception("Number of points mismatch", ErrorCodes::LOGICAL_ERROR);
}
LOG_DEBUG(&Poco::Logger::get("DiskANN"), "..done. Index has {} nodes and {} out-edges", nodes, cc);
}
MergeTreeIndexAggregatorDiskANN::MergeTreeIndexAggregatorDiskANN(const String & index_name_, const Block & index_sample_block_)
: index_name(index_name_)
, index_sample_block(index_sample_block_)
{}
MergeTreeIndexGranulePtr MergeTreeIndexAggregatorDiskANN::getGranuleAndReset()
{
if (accumulated_data.empty()) {
return std::make_shared<MergeTreeIndexGranuleDiskANN>(index_name, index_sample_block);
}
if (!dimensions.has_value()) {
throw Exception("Dimensions parameter was not got, despite having data", ErrorCodes::LOGICAL_ERROR);
}
auto base_index = detail::constructIndexFromDatapoints(dimensions.value(), accumulated_data);
diskann::Parameters paras;
paras.Set<unsigned>("R", 100);
paras.Set<unsigned>("L", 150);
paras.Set<unsigned>("C", 750);
paras.Set<float>("alpha", 1.2);
paras.Set<bool>("saturate_graph", true);
paras.Set<unsigned>("num_threads", 1);
LOG_DEBUG(&Poco::Logger::get("DiskANN"), "Index parameters set");
LOG_DEBUG(&Poco::Logger::get("DiskANN"), "Starting to build DiskANN index");
base_index->build(paras);
LOG_DEBUG(&Poco::Logger::get("DiskANN"), "DiskANN index has been successfully built!");
return std::make_shared<MergeTreeIndexGranuleDiskANN>(index_name, index_sample_block, base_index, dimensions.value(), std::move(accumulated_data));
}
void MergeTreeIndexAggregatorDiskANN::flattenAccumulatedData(std::vector<std::vector<DiskANNValue>> data) {
if (data.empty()) {
throw Exception("Dimensionality must be possitive!", ErrorCodes::LOGICAL_ERROR);
}
dimensions = data.size();
accumulated_data.clear();
for (size_t current_element = 0; current_element < data[0].size(); ++current_element) {
for (size_t dim = 0; dim < dimensions; ++dim) {
accumulated_data.push_back(data[dim][current_element]);
}
}
LOG_DEBUG(&Poco::Logger::get("DiskANN"), "Flattened the data, size: {};", accumulated_data.size());
}
void MergeTreeIndexAggregatorDiskANN::update(const Block & block, size_t * pos, size_t limit)
{
if (*pos >= block.rows())
throw Exception(
"The provided position is not less than the number of block rows. Position: "
+ toString(*pos) + ", Block rows: " + toString(block.rows()) + ".", ErrorCodes::LOGICAL_ERROR);
size_t rows_read = std::min(limit, block.rows() - *pos);
if (index_sample_block.columns() > 1) {
throw Exception("Only one column is supported", ErrorCodes::LOGICAL_ERROR);
}
auto index_column_name = index_sample_block.getByPosition(0).name;
const auto & column = block.getByName(index_column_name).column->cut(*pos, rows_read);
std::vector<std::vector<Float32>> coords_vector;
const auto * vectors = typeid_cast<const ColumnTuple *>(column.get());
for (const auto & inner_column : vectors->getColumns()) {
const auto * coords = typeid_cast<const ColumnFloat32 *>(inner_column.get());
auto v = std::vector<Float32>(coords->getData().begin(), coords->getData().end());
coords_vector.push_back(std::move(v));
}
LOG_DEBUG(&Poco::Logger::get("DiskANN"), "Got data, dimensions: {};", coords_vector.size());
flattenAccumulatedData(std::move(coords_vector));
*pos += rows_read;
}
MergeTreeIndexConditionDiskANN::MergeTreeIndexConditionDiskANN(
const IndexDescription & /*index*/,
const SelectQueryInfo & query,
ContextPtr context)
: common_condition(query, context)
{
LOG_DEBUG(&Poco::Logger::get("DiskANN"), "Built DiskANN Condition");
}
bool MergeTreeIndexConditionDiskANN::alwaysUnknownOrTrue() const
{
return common_condition.alwaysUnknownOrTrue();
}
bool MergeTreeIndexConditionDiskANN::mayBeTrueOnGranule(MergeTreeIndexGranulePtr idx_granule) const
{
std::vector<float> target_vec = common_condition.getTargetVector();
float min_distance = common_condition.getComparisonDistance();
// Number of target vectors
size_t n = 5;
// Number of NN to search
size_t k = n;
// Will be populated by diskann
std::vector<float> distances(n);
std::vector<uint64_t> indicies(n);
std::vector<unsigned> init_ids{};
auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleDiskANN>(idx_granule);
auto disk_ann_index = std::dynamic_pointer_cast<DiskANNIndex>(granule->base_index);
target_vec.resize(ROUND_UP(target_vec.size(), 8));
LOG_DEBUG(&Poco::Logger::get("DiskANN"), "Searching for vector of dim {}", target_vec.size());
if (target_vec.empty()) {
return true;
}
disk_ann_index->search(target_vec.data(), k, n, init_ids, indicies.data(), distances.data());
float distance = *std::min_element(distances.begin(), distances.end());
LOG_DEBUG(&Poco::Logger::get("DiskANN"), "Maybe true on granule distances: {} <? {}", distance, min_distance);
/*
When using L2, DiskANN returns not the exact distance, but distance squared, that's why
we have to rise given distance to the power of 2.
Also, I don't know why, but DiskANN is not able to give precise answer to ANN task,
maybe it depends on hyperparameters and fine-tuning is needed. Nevertheless, temporary
ERROR_COEF is added to minimise the likelihood of false negative result
*/
/*
const static float ERROR_COEF = 10.f;
return distance < min_distance * min_distance * ERROR_COEF;
*/
return true;
}
MergeTreeIndexGranulePtr MergeTreeIndexDiskANN::createIndexGranule() const
{
return std::make_shared<MergeTreeIndexGranuleDiskANN>(index.name, index.sample_block);
}
MergeTreeIndexAggregatorPtr MergeTreeIndexDiskANN::createIndexAggregator() const
{
return std::make_shared<MergeTreeIndexAggregatorDiskANN>(index.name, index.sample_block);
}
MergeTreeIndexConditionPtr MergeTreeIndexDiskANN::createIndexCondition(
const SelectQueryInfo & query, ContextPtr context) const
{
return std::make_shared<MergeTreeIndexConditionDiskANN>(index, query, context);
};
MergeTreeIndexFormat MergeTreeIndexDiskANN::getDeserializedFormat(const DiskPtr disk, const std::string & relative_path_prefix) const
{
if (disk->exists(relative_path_prefix + ".idx2"))
return {2, ".idx2"};
else if (disk->exists(relative_path_prefix + ".idx"))
return {1, ".idx"};
return {0 /* unknown */, ""};
}
MergeTreeIndexPtr diskANNIndexCreator(
const IndexDescription & index)
{
return std::make_shared<MergeTreeIndexDiskANN>(index);
}
void diskANNIndexValidator(const IndexDescription & /* index */, bool /* attach */)
{}
}

View File

@ -0,0 +1,111 @@
#pragma once
#include <Storages/MergeTree/MergeTreeIndices.h>
#include <Storages/MergeTree/MergeTreeData.h>
#include <Storages/MergeTree/KeyCondition.h>
#include <Storages/MergeTree/CommonCondition.h>
#include <memory>
#include <random>
#include <string_view>
#include "IO/WriteBuffer.h"
#include "index.h"
namespace DB
{
using DiskANNIndex = diskann::Index<Float32>;
using DiskANNIndexPtr = std::shared_ptr<DiskANNIndex>;
// !TODO: Working only with Float32 type
using DiskANNValue = Float32;
struct MergeTreeIndexGranuleDiskANN final : public IMergeTreeIndexGranule
{
MergeTreeIndexGranuleDiskANN(const String & index_name_, const Block & index_sample_block_);
MergeTreeIndexGranuleDiskANN(
const String & index_name_, const Block & index_sample_block_,
DiskANNIndexPtr base_index_, uint32_t dimensions, std::vector<DiskANNValue> datapoints
);
~MergeTreeIndexGranuleDiskANN() override = default;
void serializeBinary(WriteBuffer & out) const override;
uint64_t calculateIndexSize() const;
void deserializeBinary(ReadBuffer & in, MergeTreeIndexVersion version) override;
bool empty() const override { return false; }
std::optional<uint32_t> dimensions;
std::vector<DiskANNValue> datapoints;
String index_name;
Block index_sample_block;
DiskANNIndexPtr base_index;
};
struct MergeTreeIndexAggregatorDiskANN final : IMergeTreeIndexAggregator
{
MergeTreeIndexAggregatorDiskANN(const String & index_name_, const Block & index_sample_block);
~MergeTreeIndexAggregatorDiskANN() override = default;
bool empty() const override { return accumulated_data.empty(); }
MergeTreeIndexGranulePtr getGranuleAndReset() override;
void update(const Block & block, size_t * pos, size_t limit) override;
private:
void flattenAccumulatedData(std::vector<std::vector<DiskANNValue>> data);
private:
String index_name;
Block index_sample_block;
std::optional<uint32_t> dimensions;
std::vector<DiskANNValue> accumulated_data;
};
class MergeTreeIndexConditionDiskANN final : public IMergeTreeIndexCondition
{
public:
MergeTreeIndexConditionDiskANN(
const IndexDescription & index,
const SelectQueryInfo & query,
ContextPtr context
);
bool alwaysUnknownOrTrue() const override;
bool mayBeTrueOnGranule(MergeTreeIndexGranulePtr idx_granule) const override;
~MergeTreeIndexConditionDiskANN() override = default;
private:
Condition::Common::CommonCondition common_condition;
};
class MergeTreeIndexDiskANN : public IMergeTreeIndex
{
public:
explicit MergeTreeIndexDiskANN(const IndexDescription & index_)
: IMergeTreeIndex(index_)
{}
~MergeTreeIndexDiskANN() override = default;
MergeTreeIndexGranulePtr createIndexGranule() const override;
MergeTreeIndexAggregatorPtr createIndexAggregator() const override;
MergeTreeIndexConditionPtr createIndexCondition(
const SelectQueryInfo & query, ContextPtr context) const override;
bool mayBenefitFromIndexForIn(const ASTPtr & /*node*/) const override { return true; }
const char* getSerializedFileExtension() const override { return ".idx2"; }
MergeTreeIndexFormat getDeserializedFormat(const DiskPtr disk, const std::string & path_prefix) const override;
};
}

View File

@ -84,6 +84,9 @@ void MergeTreeIndexFactory::validate(const IndexDescription & index, bool attach
MergeTreeIndexFactory::MergeTreeIndexFactory()
{
registerCreator("diskann", diskANNIndexCreator);
registerValidator("diskann", diskANNIndexValidator);
registerCreator("minmax", minmaxIndexCreator);
registerValidator("minmax", minmaxIndexValidator);

View File

@ -208,6 +208,9 @@ private:
Validators validators;
};
MergeTreeIndexPtr diskANNIndexCreator(const IndexDescription & index);
void diskANNIndexValidator(const IndexDescription & index, bool attach);
MergeTreeIndexPtr minmaxIndexCreator(const IndexDescription & index);
void minmaxIndexValidator(const IndexDescription & index, bool attach);

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,113 @@
import pytest
from helpers.cluster import ClickHouseCluster
from helpers.client import QueryRuntimeException
from helpers.test_tools import assert_eq_with_retry
cluster = ClickHouseCluster(__file__)
node = cluster.add_instance("node")
@pytest.fixture(scope="module")
def start_cluster():
try:
cluster.start()
yield cluster
finally:
cluster.shutdown()
# big boi
embedding = (
0.06024,0.01686,0.005325,-0.009544,0.004116,-0.01059,-0.002037,0.0376,
-0.02249,0.02478,0.01913,0.0183,0.04434,-0.0577,0.05887,0.0167,0.07904,
0.0228,0.005283,-0.002245,-0.0008516,-0.02882,-0.0198,0.02747,-0.002686,
-0.00536,0.01132,-0.00846,-0.01865,-0.015526,0.0487,0.02243,-0.02321,
0.03735,-0.05173,-0.0252,-0.00805,0.003412,0.007347,0.02107,-0.04788,
-0.00452,-0.012054,-0.03558,0.03452,-0.0546,-0.03079,0.007328,0.01706,
-0.002962,0.00498,0.003387,0.01302,-0.02638,-0.02522,-0.02483,0.10315,
-0.02751,-0.001883,0.002201,0.04135,0.0008435,-0.01659,0.00714,-0.01947,
-0.05148,-0.04926,0.0841,-0.02086,0.02118,-0.04044,-0.003202,-0.00819,
-0.00533,0.0203,0.03656,-0.01473,0.01258,0.000288,-0.0363,0.005714,0.0348,
-0.01167,0.02007,0.002924,0.0451,-0.001959,-0.008896,0.01941,-0.001381,
0.01935,-0.02231,-0.708,0.02777,0.0373,-0.003986,0.00877,-0.001664,-0.05478,
0.02263,-0.02516,-0.02475,0.0544,-0.00831,0.03105,0.006763,-0.0461,0.013016,
-0.031,-0.03116,0.0461,0.0337,-0.0005875,0.00904,0.02382,-0.01491,0.01015,
0.00158,0.05084,-0.015305,-0.00772,0.0362,0.0334,0.01004,0.01723,0.03235,
0.02106,0.04752,-0.03262,0.01752,-0.00647,0.02199,-0.0373,0.09296,-0.02605,
-0.00794,0.0574,-0.02608,0.01345,-0.011566,0.01906,-0.03476,-0.00737,
-0.002981,-0.01936,0.008125,-0.00888,0.079,-0.00604,-0.002335,0.0366,
-0.02048,0.0789,-0.03958,-0.01469,0.02269,-0.013565,0.028,-0.03482,
-0.0006204,-0.0421,0.003387,-0.01933,-0.05478,0.01213,-0.0411,-0.02231,
-0.00218,0.01604,0.01663,0.001544,-0.0342,0.002089,-0.02547,-0.01056,
-0.06604,-0.006977,-0.010376,0.073,-0.01637,-0.02267,-0.0318,0.012085,
0.003386,0.01008,-0.008766,0.0008154,-0.00428,-0.01167,-0.0083,-0.009926,
-0.00000775,0.01651,-0.001198,-0.01701,0.01174,0.001672,-0.05197,-0.02385,
0.01889,0.0348,-0.02522,0.0536,-0.02635,0.02159,0.00805,-0.01534,-0.05832,
0.02022,0.01154,-0.033,0.1327,-0.04324,0.02092,-0.00683,-0.02008,0.008736,
-0.01342,0.0803,0.006012,0.11017,0.02551,0.01124,0.02985,0.01203,0.01788,
0.003204,0.0402,0.05145,-0.004402,-0.0342,0.01015,0.003265,-0.02858,0.01515,
-0.02948,-0.00684,-0.01243,0.002657,0.01743,0.00338,-0.00969,-0.0235,
-0.04434,0.03458,-0.05057,-0.02028,-0.0113,-0.01645,0.02066,0.02963,0.0312,
0.02374,-0.00571,0.01656,0.0004306,0.01631,-0.00118,-0.06586,-0.0466,
0.02402,0.0010395,-0.0394,0.002724,0.06775,0.00805,0.02908,-0.002623,
0.02457,0.04343,-0.02614,0.001141,-0.0151,0.03436,0.02481,0.002968,-0.0231,
0.005814,-0.01952,0.0003521,-0.00462,0.01246,0.002914,0.006153,0.004726,
-0.008766,0.06186,-0.04285,-0.02795,-0.0195,0.02283,-0.01532,-0.03906,
0.002748,0.01968,0.01927,-0.02249,0.00863,0.00987,-0.04395,-0.07904,-0.0738,
0.05325,0.01749,-0.002647,0.004536,-0.01665,-0.004314,-0.001041,0.00579,
-0.00928,-0.001073,0.09283,-0.02007,0.00432,0.02092,0.03033,-0.0007863,
-0.0231,-0.1035,0.00817,0.1106,-0.005802,0.01897,0.004032,-0.03586,0.01208,
-0.06464,-0.01122,0.05148,-0.0217,-0.01566,-0.003944,-0.001542,-0.02379,
0.03598,-0.009705,0.03702,0.0321,-0.01825,0.01926,0.02225,-0.02588,0.01026,
-0.010605,0.05063,0.04077,0.005386,0.001807,0.04764,-0.0485,0.02492,0.04214,
-0.02666,-0.00834,0.01569,-0.02435,0.03268,0.01855,0.0464,0.074,-0.0323,
0.02478,0.02812,-0.05862,-0.01484,-0.02225,0.07074,-0.083,0.00886,0.002829,
-0.0373,0.0133,0.02077,0.00789,-0.006886,0.09766,0.01718,-0.01507,0.004738,
-0.03513,-0.02435,-0.02284,0.007626,0.006992,0.02643,0.0242,-0.0139,
-0.01314,0.02092,-0.0473,-0.02531,0.02649,-0.03482,0.01254,-0.01962,
-0.006905,-0.03497,-0.03674,-0.09265,-0.01799,0.01627,0.0000971,-0.0395,
-0.0337,-0.01645,-0.01813,-0.0163,0.03937,0.01613,0.0967,0.03467,0.008644,
-0.004112,-0.003628,0.02989,-0.00684,-0.001299,0.002989,-0.0436,0.008316,
0.02017,0.00948,-0.03998,-0.05066,0.02573,-0.005447,-0.02568,-0.0223,
-0.02321,0.0241,0.005386,0.0535,-0.05148,0.01955,-0.00326,-0.005287,-0.0311,
0.01846,-0.009895,0.02252,0.0754,0.02232,-0.000737,-0.03012,0.01865,
-0.03506,0.012535,-0.00781,-0.01258,0.000363,-0.00882,0.03604,0.02089,
-0.02872,0.04346,0.01015,0.02193,0.00512,0.01068,-0.01743,0.0000452,
0.02278,0.01685,0.01034,0.03096,0.00968,0.007385,0.0209,-0.015114,0.04517,
0.0466,-0.003426,-0.0418,-0.00539,-0.01247,-0.02144,-0.006763,-0.02197,
0.001221,-0.00834,-0.00472,-0.02126,0.01529,0.02715,0.005226,-0.01617,
0.05203,0.0003045,-0.02583,-0.04303,-0.01749,-0.02094,-0.0336,0.04303,
0.0175,-0.02184,0.02324,-0.01173,-0.001993,-0.006622,-0.02744,0.009125,
-0.00701,-0.04028,0.01695,-0.03084,0.02217,0.00815,0.0363,-0.00158,-0.02916
)
def test_condition(start_cluster):
node.query(
f"""
CREATE TABLE annoy_test (
id UInt32,
url String,
embedding Tuple({"Float64, " * (len(embedding) - 1)}Float64)
) Engine=MergeTree ORDER BY id;
"""
)
assert len(embedding) == 512, "Laion is 512 dim embedding"
node.query(
"""
INSERT INTO annoy_test FROM INFILE './test_annoy_index/laion_cut.csv';
"""
)
select_query = f"""
select url from annoy_test where
L2Distance(embedding, {str(embedding)}) < 0.0001
"""
image_url = "http://media.vector4free.com/mini/pixel77-free-vector-bird-1004-270x200.jpg"
assert node.query(select_query).split("\n")[0] == image_url, "Didn't find correct nearest picture"

View File

@ -0,0 +1 @@
"rows_read": 1000,

View File

@ -0,0 +1,30 @@
#!/usr/bin/env bash
CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# shellcheck source=../shell_config.sh
. "$CURDIR"/../shell_config.sh
$CLICKHOUSE_CLIENT --query="DROP TABLE IF EXISTS t_annoy_test;"
$CLICKHOUSE_CLIENT -n --query="
CREATE TABLE t_annoy_test
(
id Int64,
number Tuple(Float32, Float32, Float32),
INDEX x (number) TYPE annoy GRANULARITY 1
) ENGINE = MergeTree()
ORDER BY id
"
$CLICKHOUSE_CLIENT --query="
INSERT INTO t_annoy_test SELECT
number AS id,
(toFloat32(number), toFloat32(number), toFloat32(number))
FROM system.numbers
LIMIT 1000;"
# simple select
$CLICKHOUSE_CLIENT --query="SELECT * from t_annoy_test FORMAT JSON" | grep "rows_read"
$CLICKHOUSE_CLIENT --query="DROP TABLE t_annoy_test;"