Merge branch 'master' into refactor_databases_tmp_merge

This commit is contained in:
Alexander Tokmakov 2019-12-27 15:52:19 +03:00
commit 444167cc1d
156 changed files with 5654 additions and 1931 deletions

3
.gitmodules vendored
View File

@ -128,6 +128,9 @@
[submodule "contrib/icu"]
path = contrib/icu
url = https://github.com/unicode-org/icu.git
[submodule "contrib/flatbuffers"]
path = contrib/flatbuffers
url = https://github.com/google/flatbuffers.git
[submodule "contrib/libc-headers"]
path = contrib/libc-headers
url = https://github.com/ClickHouse-Extras/libc-headers.git

View File

@ -54,10 +54,12 @@ elseif(NOT MISSING_INTERNAL_PARQUET_LIBRARY AND NOT OS_FREEBSD)
endif()
if(${USE_STATIC_LIBRARIES})
set(FLATBUFFERS_LIBRARY flatbuffers)
set(ARROW_LIBRARY arrow_static)
set(PARQUET_LIBRARY parquet_static)
set(THRIFT_LIBRARY thrift_static)
else()
set(FLATBUFFERS_LIBRARY flatbuffers_shared)
set(ARROW_LIBRARY arrow_shared)
set(PARQUET_LIBRARY parquet_shared)
if(USE_INTERNAL_PARQUET_LIBRARY_NATIVE_CMAKE)
@ -74,7 +76,7 @@ endif()
endif()
if(USE_PARQUET)
message(STATUS "Using Parquet: ${ARROW_LIBRARY}:${ARROW_INCLUDE_DIR} ; ${PARQUET_LIBRARY}:${PARQUET_INCLUDE_DIR} ; ${THRIFT_LIBRARY}")
message(STATUS "Using Parquet: ${ARROW_LIBRARY}:${ARROW_INCLUDE_DIR} ; ${PARQUET_LIBRARY}:${PARQUET_INCLUDE_DIR} ; ${THRIFT_LIBRARY} ; ${FLATBUFFERS_LIBRARY}")
else()
message(STATUS "Building without Parquet support")
endif()

View File

@ -159,6 +159,8 @@ if (USE_INTERNAL_PARQUET_LIBRARY_NATIVE_CMAKE)
set (ARROW_PARQUET ON CACHE INTERNAL "")
set (ARROW_VERBOSE_THIRDPARTY_BUILD ON CACHE INTERNAL "")
set (ARROW_BUILD_SHARED 1 CACHE INTERNAL "")
set (ARROW_BUILD_UTILITIES OFF CACHE INTERNAL "")
set (ARROW_BUILD_INTEGRATION OFF CACHE INTERNAL "")
set (ARROW_BOOST_HEADER_ONLY ON CACHE INTERNAL "")
set (Boost_FOUND 1 CACHE INTERNAL "")
if (MAKE_STATIC_LIBRARIES)

2
contrib/arrow vendored

@ -1 +1 @@
Subproject commit 87ac6fddaf21d0b4ee8b8090533ff293db0da1b4
Subproject commit b789226ccb2124285792107c758bb3b40b3d082a

View File

@ -1,46 +1,48 @@
include(ExternalProject)
# === thrift
set(LIBRARY_DIR ${ClickHouse_SOURCE_DIR}/contrib/thrift/lib/cpp)
# contrib/thrift/lib/cpp/CMakeLists.txt
set(thriftcpp_SOURCES
${LIBRARY_DIR}/src/thrift/TApplicationException.cpp
${LIBRARY_DIR}/src/thrift/TOutput.cpp
${LIBRARY_DIR}/src/thrift/async/TAsyncChannel.cpp
${LIBRARY_DIR}/src/thrift/async/TAsyncProtocolProcessor.cpp
${LIBRARY_DIR}/src/thrift/async/TConcurrentClientSyncInfo.h
${LIBRARY_DIR}/src/thrift/async/TConcurrentClientSyncInfo.cpp
${LIBRARY_DIR}/src/thrift/concurrency/ThreadManager.cpp
${LIBRARY_DIR}/src/thrift/concurrency/TimerManager.cpp
${LIBRARY_DIR}/src/thrift/concurrency/Util.cpp
${LIBRARY_DIR}/src/thrift/processor/PeekProcessor.cpp
${LIBRARY_DIR}/src/thrift/protocol/TBase64Utils.cpp
${LIBRARY_DIR}/src/thrift/protocol/TDebugProtocol.cpp
${LIBRARY_DIR}/src/thrift/protocol/TJSONProtocol.cpp
${LIBRARY_DIR}/src/thrift/protocol/TMultiplexedProtocol.cpp
${LIBRARY_DIR}/src/thrift/protocol/TProtocol.cpp
${LIBRARY_DIR}/src/thrift/transport/TTransportException.cpp
${LIBRARY_DIR}/src/thrift/transport/TFDTransport.cpp
${LIBRARY_DIR}/src/thrift/transport/TSimpleFileTransport.cpp
${LIBRARY_DIR}/src/thrift/transport/THttpTransport.cpp
${LIBRARY_DIR}/src/thrift/transport/THttpClient.cpp
${LIBRARY_DIR}/src/thrift/transport/THttpServer.cpp
${LIBRARY_DIR}/src/thrift/transport/TSocket.cpp
${LIBRARY_DIR}/src/thrift/transport/TSocketPool.cpp
${LIBRARY_DIR}/src/thrift/transport/TServerSocket.cpp
${LIBRARY_DIR}/src/thrift/transport/TTransportUtils.cpp
${LIBRARY_DIR}/src/thrift/transport/TBufferTransports.cpp
${LIBRARY_DIR}/src/thrift/server/TConnectedClient.cpp
${LIBRARY_DIR}/src/thrift/server/TServerFramework.cpp
${LIBRARY_DIR}/src/thrift/server/TSimpleServer.cpp
${LIBRARY_DIR}/src/thrift/server/TThreadPoolServer.cpp
${LIBRARY_DIR}/src/thrift/server/TThreadedServer.cpp
)
set( thriftcpp_threads_SOURCES
${LIBRARY_DIR}/src/thrift/concurrency/ThreadFactory.cpp
${LIBRARY_DIR}/src/thrift/concurrency/Thread.cpp
${LIBRARY_DIR}/src/thrift/concurrency/Monitor.cpp
${LIBRARY_DIR}/src/thrift/concurrency/Mutex.cpp
)
${LIBRARY_DIR}/src/thrift/TApplicationException.cpp
${LIBRARY_DIR}/src/thrift/TOutput.cpp
${LIBRARY_DIR}/src/thrift/async/TAsyncChannel.cpp
${LIBRARY_DIR}/src/thrift/async/TAsyncProtocolProcessor.cpp
${LIBRARY_DIR}/src/thrift/async/TConcurrentClientSyncInfo.h
${LIBRARY_DIR}/src/thrift/async/TConcurrentClientSyncInfo.cpp
${LIBRARY_DIR}/src/thrift/concurrency/ThreadManager.cpp
${LIBRARY_DIR}/src/thrift/concurrency/TimerManager.cpp
${LIBRARY_DIR}/src/thrift/concurrency/Util.cpp
${LIBRARY_DIR}/src/thrift/processor/PeekProcessor.cpp
${LIBRARY_DIR}/src/thrift/protocol/TBase64Utils.cpp
${LIBRARY_DIR}/src/thrift/protocol/TDebugProtocol.cpp
${LIBRARY_DIR}/src/thrift/protocol/TJSONProtocol.cpp
${LIBRARY_DIR}/src/thrift/protocol/TMultiplexedProtocol.cpp
${LIBRARY_DIR}/src/thrift/protocol/TProtocol.cpp
${LIBRARY_DIR}/src/thrift/transport/TTransportException.cpp
${LIBRARY_DIR}/src/thrift/transport/TFDTransport.cpp
${LIBRARY_DIR}/src/thrift/transport/TSimpleFileTransport.cpp
${LIBRARY_DIR}/src/thrift/transport/THttpTransport.cpp
${LIBRARY_DIR}/src/thrift/transport/THttpClient.cpp
${LIBRARY_DIR}/src/thrift/transport/THttpServer.cpp
${LIBRARY_DIR}/src/thrift/transport/TSocket.cpp
${LIBRARY_DIR}/src/thrift/transport/TSocketPool.cpp
${LIBRARY_DIR}/src/thrift/transport/TServerSocket.cpp
${LIBRARY_DIR}/src/thrift/transport/TTransportUtils.cpp
${LIBRARY_DIR}/src/thrift/transport/TBufferTransports.cpp
${LIBRARY_DIR}/src/thrift/server/TConnectedClient.cpp
${LIBRARY_DIR}/src/thrift/server/TServerFramework.cpp
${LIBRARY_DIR}/src/thrift/server/TSimpleServer.cpp
${LIBRARY_DIR}/src/thrift/server/TThreadPoolServer.cpp
${LIBRARY_DIR}/src/thrift/server/TThreadedServer.cpp
)
set(thriftcpp_threads_SOURCES
${LIBRARY_DIR}/src/thrift/concurrency/ThreadFactory.cpp
${LIBRARY_DIR}/src/thrift/concurrency/Thread.cpp
${LIBRARY_DIR}/src/thrift/concurrency/Monitor.cpp
${LIBRARY_DIR}/src/thrift/concurrency/Mutex.cpp
)
add_library(${THRIFT_LIBRARY} ${thriftcpp_SOURCES} ${thriftcpp_threads_SOURCES})
set_target_properties(${THRIFT_LIBRARY} PROPERTIES CXX_STANDARD 14) # REMOVE after https://github.com/apache/thrift/pull/1641
target_include_directories(${THRIFT_LIBRARY} SYSTEM PUBLIC ${ClickHouse_SOURCE_DIR}/contrib/thrift/lib/cpp/src PRIVATE ${Boost_INCLUDE_DIRS})
@ -70,22 +72,94 @@ add_custom_command(OUTPUT orc_proto.pb.h orc_proto.pb.cc
--cpp_out="${CMAKE_CURRENT_BINARY_DIR}"
"${PROTO_DIR}/orc_proto.proto")
# arrow-cmake cmake file calling orc cmake subroutine which detects certain compiler features.
# === flatbuffers
##############################################################
# fbs - Step 1: build flatbuffers lib and flatc compiler
##############################################################
set(FLATBUFFERS_SRC_DIR ${ClickHouse_SOURCE_DIR}/contrib/flatbuffers)
set(FLATBUFFERS_BINARY_DIR ${ClickHouse_BINARY_DIR}/contrib/flatbuffers)
set(FLATBUFFERS_INCLUDE_DIR ${FLATBUFFERS_SRC_DIR}/include)
set(FLATBUFFERS_COMPILER "${FLATBUFFERS_BINARY_DIR}/flatc")
# set flatbuffers CMake options
if (${USE_STATIC_LIBRARIES})
set(FLATBUFFERS_BUILD_FLATLIB ON CACHE BOOL "Enable the build of the flatbuffers library")
set(FLATBUFFERS_BUILD_SHAREDLIB OFF CACHE BOOL "Disable the build of the flatbuffers shared library")
else ()
set(FLATBUFFERS_BUILD_SHAREDLIB ON CACHE BOOL "Enable the build of the flatbuffers shared library")
set(FLATBUFFERS_BUILD_FLATLIB OFF CACHE BOOL "Disable the build of the flatbuffers library")
endif ()
set(FLATBUFFERS_BUILD_FLATC ON CACHE BOOL "Build flatbuffers compiler")
set(FLATBUFFERS_BUILD_TESTS OFF CACHE BOOL "Skip flatbuffers tests")
add_subdirectory(${FLATBUFFERS_SRC_DIR} "${FLATBUFFERS_BINARY_DIR}")
###################################
# fbs - Step 2: compile *.fbs files
###################################
set(ARROW_IPC_SRC_DIR ${ARROW_SRC_DIR}/arrow/ipc)
set(ARROW_FORMAT_SRC_DIR ${ARROW_SRC_DIR}/../../format)
set(ARROW_GENERATED_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/arrow_gen_headers)
set(FLATBUFFERS_COMPILED_OUT_DIR ${ARROW_GENERATED_INCLUDE_DIR}/arrow/ipc)
set(FBS_OUTPUT_FILES
"${FLATBUFFERS_COMPILED_OUT_DIR}/File_generated.h"
"${FLATBUFFERS_COMPILED_OUT_DIR}/Message_generated.h"
"${FLATBUFFERS_COMPILED_OUT_DIR}/feather_generated.h"
"${FLATBUFFERS_COMPILED_OUT_DIR}/Schema_generated.h"
"${FLATBUFFERS_COMPILED_OUT_DIR}/SparseTensor_generated.h"
"${FLATBUFFERS_COMPILED_OUT_DIR}/Tensor_generated.h")
set(FBS_SRC
${ARROW_FORMAT_SRC_DIR}/Message.fbs
${ARROW_FORMAT_SRC_DIR}/File.fbs
${ARROW_FORMAT_SRC_DIR}/Schema.fbs
${ARROW_FORMAT_SRC_DIR}/Tensor.fbs
${ARROW_FORMAT_SRC_DIR}/SparseTensor.fbs
${ARROW_IPC_SRC_DIR}/feather.fbs)
foreach (FIL ${FBS_SRC})
get_filename_component(ABS_FIL ${FIL} ABSOLUTE)
list(APPEND ABS_FBS_SRC ${ABS_FIL})
endforeach ()
message(STATUS "FLATBUFFERS_LIBRARY: ${FLATBUFFERS_LIBRARY}, FLATBUFFERS_COMPILER: ${FLATBUFFERS_COMPILER}")
message(STATUS "FLATBUFFERS_COMPILED_OUT_DIR: ${FLATBUFFERS_COMPILED_OUT_DIR}")
message(STATUS "flatc: ${FLATBUFFERS_COMPILER} -c -o ${FLATBUFFERS_COMPILED_OUT_DIR}/ ${ABS_FBS_SRC}")
add_custom_command(OUTPUT ${FBS_OUTPUT_FILES}
COMMAND ${FLATBUFFERS_COMPILER}
-c
-o
${FLATBUFFERS_COMPILED_OUT_DIR}/
${ABS_FBS_SRC}
DEPENDS flatc ${ABS_FBS_SRC}
COMMENT "Running flatc compiler on ${ABS_FBS_SRC}"
VERBATIM)
add_custom_target(metadata_fbs DEPENDS ${FBS_OUTPUT_FILES})
add_dependencies(metadata_fbs flatc)
# arrow-cmake cmake file calling orc cmake subroutine which detects certain compiler features.
# Apple Clang compiler failed to compile this code without specifying c++11 standard.
# As result these compiler features detected as absent. In result it failed to compile orc itself.
# In orc makefile there is code that sets flags, but arrow-cmake ignores these flags.
if (CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
set (CXX11_FLAGS "-std=c++0x")
endif()
set(CXX11_FLAGS "-std=c++0x")
endif ()
include(${ClickHouse_SOURCE_DIR}/contrib/orc/cmake_modules/CheckSourceCompiles.cmake)
include(orc_check.cmake)
configure_file("${ORC_INCLUDE_DIR}/orc/orc-config.hh.in" "${ORC_BUILD_INCLUDE_DIR}/orc/orc-config.hh")
configure_file("${ORC_INCLUDE_DIR}/orc/orc-config.hh.in" "${ORC_BUILD_INCLUDE_DIR}/orc/orc-config.hh")
configure_file("${ORC_SOURCE_SRC_DIR}/Adaptor.hh.in" "${ORC_BUILD_INCLUDE_DIR}/Adaptor.hh")
set(ORC_SRCS
${ARROW_SRC_DIR}/arrow/adapters/orc/adapter.cc
${ARROW_SRC_DIR}/arrow/adapters/orc/adapter_util.cc
${ORC_SOURCE_SRC_DIR}/Exceptions.cc
${ORC_SOURCE_SRC_DIR}/OrcFile.cc
${ORC_SOURCE_SRC_DIR}/Reader.cc
@ -119,126 +193,160 @@ set(ORC_SRCS
# === arrow
set(LIBRARY_DIR ${ClickHouse_SOURCE_DIR}/contrib/arrow/cpp/src/arrow)
configure_file("${LIBRARY_DIR}/util/config.h.cmake" "${CMAKE_CURRENT_SOURCE_DIR}/cpp/src/arrow/util/config.h")
# arrow/cpp/src/arrow/CMakeLists.txt
set(ARROW_SRCS
${LIBRARY_DIR}/array.cc
${LIBRARY_DIR}/array.cc
${LIBRARY_DIR}/buffer.cc
${LIBRARY_DIR}/builder.cc
${LIBRARY_DIR}/compare.cc
${LIBRARY_DIR}/extension_type.cc
${LIBRARY_DIR}/memory_pool.cc
${LIBRARY_DIR}/pretty_print.cc
${LIBRARY_DIR}/record_batch.cc
${LIBRARY_DIR}/result.cc
${LIBRARY_DIR}/scalar.cc
${LIBRARY_DIR}/sparse_tensor.cc
${LIBRARY_DIR}/status.cc
${LIBRARY_DIR}/table_builder.cc
${LIBRARY_DIR}/table.cc
${LIBRARY_DIR}/tensor.cc
${LIBRARY_DIR}/type.cc
${LIBRARY_DIR}/visitor.cc
${LIBRARY_DIR}/builder.cc
${LIBRARY_DIR}/array/builder_adaptive.cc
${LIBRARY_DIR}/array/builder_base.cc
${LIBRARY_DIR}/array/builder_binary.cc
${LIBRARY_DIR}/array/builder_decimal.cc
${LIBRARY_DIR}/array/builder_dict.cc
${LIBRARY_DIR}/array/builder_nested.cc
${LIBRARY_DIR}/array/builder_primitive.cc
${LIBRARY_DIR}/array/builder_adaptive.cc
${LIBRARY_DIR}/array/builder_base.cc
${LIBRARY_DIR}/array/builder_binary.cc
${LIBRARY_DIR}/array/builder_decimal.cc
${LIBRARY_DIR}/array/builder_dict.cc
${LIBRARY_DIR}/array/builder_nested.cc
${LIBRARY_DIR}/array/builder_primitive.cc
${LIBRARY_DIR}/array/builder_union.cc
${LIBRARY_DIR}/array/concatenate.cc
${LIBRARY_DIR}/array/dict_internal.cc
${LIBRARY_DIR}/array/diff.cc
${LIBRARY_DIR}/buffer.cc
${LIBRARY_DIR}/compare.cc
${LIBRARY_DIR}/memory_pool.cc
${LIBRARY_DIR}/pretty_print.cc
${LIBRARY_DIR}/record_batch.cc
${LIBRARY_DIR}/status.cc
${LIBRARY_DIR}/table.cc
${LIBRARY_DIR}/table_builder.cc
${LIBRARY_DIR}/tensor.cc
${LIBRARY_DIR}/sparse_tensor.cc
${LIBRARY_DIR}/type.cc
${LIBRARY_DIR}/visitor.cc
${LIBRARY_DIR}/csv/converter.cc
${LIBRARY_DIR}/csv/chunker.cc
${LIBRARY_DIR}/csv/column_builder.cc
${LIBRARY_DIR}/csv/options.cc
${LIBRARY_DIR}/csv/parser.cc
${LIBRARY_DIR}/csv/reader.cc
${LIBRARY_DIR}/csv/converter.cc
${LIBRARY_DIR}/csv/chunker.cc
${LIBRARY_DIR}/csv/column-builder.cc
${LIBRARY_DIR}/csv/options.cc
${LIBRARY_DIR}/csv/parser.cc
${LIBRARY_DIR}/csv/reader.cc
${LIBRARY_DIR}/ipc/dictionary.cc
${LIBRARY_DIR}/ipc/feather.cc
${LIBRARY_DIR}/ipc/message.cc
${LIBRARY_DIR}/ipc/metadata_internal.cc
${LIBRARY_DIR}/ipc/options.cc
${LIBRARY_DIR}/ipc/reader.cc
${LIBRARY_DIR}/ipc/writer.cc
${LIBRARY_DIR}/io/buffered.cc
${LIBRARY_DIR}/io/compressed.cc
${LIBRARY_DIR}/io/file.cc
${LIBRARY_DIR}/io/interfaces.cc
${LIBRARY_DIR}/io/memory.cc
${LIBRARY_DIR}/io/readahead.cc
${LIBRARY_DIR}/io/buffered.cc
${LIBRARY_DIR}/io/compressed.cc
${LIBRARY_DIR}/io/file.cc
${LIBRARY_DIR}/io/interfaces.cc
${LIBRARY_DIR}/io/memory.cc
${LIBRARY_DIR}/io/readahead.cc
${LIBRARY_DIR}/io/slow.cc
${LIBRARY_DIR}/util/bit-util.cc
${LIBRARY_DIR}/util/compression.cc
${LIBRARY_DIR}/util/cpu-info.cc
${LIBRARY_DIR}/util/decimal.cc
${LIBRARY_DIR}/util/int-util.cc
${LIBRARY_DIR}/util/io-util.cc
${LIBRARY_DIR}/util/logging.cc
${LIBRARY_DIR}/util/key_value_metadata.cc
${LIBRARY_DIR}/util/task-group.cc
${LIBRARY_DIR}/util/thread-pool.cc
${LIBRARY_DIR}/util/trie.cc
${LIBRARY_DIR}/util/utf8.cc
${ORC_SRCS}
)
${LIBRARY_DIR}/util/basic_decimal.cc
${LIBRARY_DIR}/util/bit_util.cc
${LIBRARY_DIR}/util/compression.cc
${LIBRARY_DIR}/util/compression_lz4.cc
${LIBRARY_DIR}/util/compression_snappy.cc
${LIBRARY_DIR}/util/compression_zlib.cc
${LIBRARY_DIR}/util/compression_zstd.cc
${LIBRARY_DIR}/util/cpu_info.cc
${LIBRARY_DIR}/util/decimal.cc
${LIBRARY_DIR}/util/int_util.cc
${LIBRARY_DIR}/util/io_util.cc
${LIBRARY_DIR}/util/key_value_metadata.cc
${LIBRARY_DIR}/util/logging.cc
${LIBRARY_DIR}/util/memory.cc
${LIBRARY_DIR}/util/string_builder.cc
${LIBRARY_DIR}/util/string.cc
${LIBRARY_DIR}/util/task_group.cc
${LIBRARY_DIR}/util/thread_pool.cc
${LIBRARY_DIR}/util/trie.cc
${LIBRARY_DIR}/util/utf8.cc
${LIBRARY_DIR}/vendored/base64.cpp
${ORC_SRCS}
)
set(ARROW_SRCS ${ARROW_SRCS}
${LIBRARY_DIR}/compute/context.cc
${LIBRARY_DIR}/compute/kernels/boolean.cc
${LIBRARY_DIR}/compute/kernels/cast.cc
${LIBRARY_DIR}/compute/kernels/hash.cc
${LIBRARY_DIR}/compute/kernels/util-internal.cc
)
${LIBRARY_DIR}/compute/context.cc
${LIBRARY_DIR}/compute/kernels/boolean.cc
${LIBRARY_DIR}/compute/kernels/cast.cc
${LIBRARY_DIR}/compute/kernels/hash.cc
${LIBRARY_DIR}/compute/kernels/util_internal.cc
)
if (LZ4_INCLUDE_DIR AND LZ4_LIBRARY)
set(ARROW_WITH_LZ4 1)
endif()
endif ()
if(SNAPPY_INCLUDE_DIR AND SNAPPY_LIBRARY)
if (SNAPPY_INCLUDE_DIR AND SNAPPY_LIBRARY)
set(ARROW_WITH_SNAPPY 1)
endif()
endif ()
if(ZLIB_INCLUDE_DIR AND ZLIB_LIBRARIES)
if (ZLIB_INCLUDE_DIR AND ZLIB_LIBRARIES)
set(ARROW_WITH_ZLIB 1)
endif()
endif ()
if (ZSTD_INCLUDE_DIR AND ZSTD_LIBRARY)
set(ARROW_WITH_ZSTD 1)
endif()
endif ()
if (ARROW_WITH_LZ4)
add_definitions(-DARROW_WITH_LZ4)
SET(ARROW_SRCS ${LIBRARY_DIR}/util/compression_lz4.cc ${ARROW_SRCS})
endif()
add_definitions(-DARROW_WITH_LZ4)
SET(ARROW_SRCS ${LIBRARY_DIR}/util/compression_lz4.cc ${ARROW_SRCS})
endif ()
if (ARROW_WITH_SNAPPY)
add_definitions(-DARROW_WITH_SNAPPY)
SET(ARROW_SRCS ${LIBRARY_DIR}/util/compression_snappy.cc ${ARROW_SRCS})
endif()
add_definitions(-DARROW_WITH_SNAPPY)
SET(ARROW_SRCS ${LIBRARY_DIR}/util/compression_snappy.cc ${ARROW_SRCS})
endif ()
if (ARROW_WITH_ZLIB)
add_definitions(-DARROW_WITH_ZLIB)
SET(ARROW_SRCS ${LIBRARY_DIR}/util/compression_zlib.cc ${ARROW_SRCS})
endif()
add_definitions(-DARROW_WITH_ZLIB)
SET(ARROW_SRCS ${LIBRARY_DIR}/util/compression_zlib.cc ${ARROW_SRCS})
endif ()
if (ARROW_WITH_ZSTD)
add_definitions(-DARROW_WITH_ZSTD)
SET(ARROW_SRCS ${LIBRARY_DIR}/util/compression_zstd.cc ${ARROW_SRCS})
endif()
add_definitions(-DARROW_WITH_ZSTD)
SET(ARROW_SRCS ${LIBRARY_DIR}/util/compression_zstd.cc ${ARROW_SRCS})
endif ()
add_library(${ARROW_LIBRARY} ${ARROW_SRCS})
# Arrow dependencies
add_dependencies(${ARROW_LIBRARY} ${FLATBUFFERS_LIBRARY} metadata_fbs)
target_link_libraries(${ARROW_LIBRARY} PRIVATE boost_system_internal boost_filesystem_internal boost_regex_internal)
target_link_libraries(${ARROW_LIBRARY} PRIVATE ${FLATBUFFERS_LIBRARY})
if (USE_INTERNAL_PROTOBUF_LIBRARY)
add_dependencies(${ARROW_LIBRARY} protoc)
endif()
endif ()
target_include_directories(${ARROW_LIBRARY} SYSTEM PUBLIC ${ClickHouse_SOURCE_DIR}/contrib/arrow/cpp/src PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cpp/src ${Boost_INCLUDE_DIRS})
target_link_libraries(${ARROW_LIBRARY} PRIVATE ${DOUBLE_CONVERSION_LIBRARIES} ${Protobuf_LIBRARY})
if (ARROW_WITH_LZ4)
target_link_libraries(${ARROW_LIBRARY} PRIVATE ${LZ4_LIBRARY})
endif()
endif ()
if (ARROW_WITH_SNAPPY)
target_link_libraries(${ARROW_LIBRARY} PRIVATE ${SNAPPY_LIBRARY})
endif()
endif ()
if (ARROW_WITH_ZLIB)
target_link_libraries(${ARROW_LIBRARY} PRIVATE ${ZLIB_LIBRARIES})
endif()
endif ()
if (ARROW_WITH_ZSTD)
target_link_libraries(${ARROW_LIBRARY} PRIVATE ${ZSTD_LIBRARY})
endif()
endif ()
target_include_directories(${ARROW_LIBRARY} PRIVATE SYSTEM ${ORC_INCLUDE_DIR})
target_include_directories(${ARROW_LIBRARY} PRIVATE SYSTEM ${ORC_SOURCE_SRC_DIR})
@ -248,52 +356,56 @@ target_include_directories(${ARROW_LIBRARY} PRIVATE SYSTEM ${ORC_BUILD_SRC_DIR})
target_include_directories(${ARROW_LIBRARY} PRIVATE SYSTEM ${ORC_BUILD_INCLUDE_DIR})
target_include_directories(${ARROW_LIBRARY} PRIVATE SYSTEM ${ORC_ADDITION_SOURCE_DIR})
target_include_directories(${ARROW_LIBRARY} PRIVATE SYSTEM ${ARROW_SRC_DIR})
target_include_directories(${ARROW_LIBRARY} PRIVATE SYSTEM ${FLATBUFFERS_INCLUDE_DIR})
target_include_directories(${ARROW_LIBRARY} PRIVATE SYSTEM ${ARROW_GENERATED_INCLUDE_DIR})
# === parquet
set(LIBRARY_DIR ${ClickHouse_SOURCE_DIR}/contrib/arrow/cpp/src/parquet)
# arrow/cpp/src/parquet/CMakeLists.txt
set(PARQUET_SRCS
${LIBRARY_DIR}/arrow/reader.cc
${LIBRARY_DIR}/arrow/record_reader.cc
${LIBRARY_DIR}/arrow/schema.cc
${LIBRARY_DIR}/arrow/writer.cc
${LIBRARY_DIR}/bloom_filter.cc
${LIBRARY_DIR}/column_reader.cc
${LIBRARY_DIR}/column_scanner.cc
${LIBRARY_DIR}/column_writer.cc
${LIBRARY_DIR}/file_reader.cc
${LIBRARY_DIR}/file_writer.cc
${LIBRARY_DIR}/metadata.cc
${LIBRARY_DIR}/murmur3.cc
${LIBRARY_DIR}/printer.cc
${LIBRARY_DIR}/schema.cc
${LIBRARY_DIR}/statistics.cc
${LIBRARY_DIR}/types.cc
${LIBRARY_DIR}/util/comparison.cc
${LIBRARY_DIR}/util/memory.cc
)
${LIBRARY_DIR}/arrow/reader.cc
${LIBRARY_DIR}/arrow/reader_internal.cc
${LIBRARY_DIR}/arrow/schema.cc
${LIBRARY_DIR}/arrow/writer.cc
${LIBRARY_DIR}/bloom_filter.cc
${LIBRARY_DIR}/column_reader.cc
${LIBRARY_DIR}/column_scanner.cc
${LIBRARY_DIR}/column_writer.cc
${LIBRARY_DIR}/deprecated_io.cc
${LIBRARY_DIR}/encoding.cc
${LIBRARY_DIR}/file_reader.cc
${LIBRARY_DIR}/file_writer.cc
${LIBRARY_DIR}/metadata.cc
${LIBRARY_DIR}/murmur3.cc
${LIBRARY_DIR}/platform.cc
${LIBRARY_DIR}/printer.cc
${LIBRARY_DIR}/properties.cc
${LIBRARY_DIR}/schema.cc
${LIBRARY_DIR}/statistics.cc
${LIBRARY_DIR}/types.cc
)
#list(TRANSFORM PARQUET_SRCS PREPEND ${LIBRARY_DIR}/) # cmake 3.12
list(APPEND PARQUET_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/src/parquet/parquet_constants.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/src/parquet/parquet_types.cpp
)
${CMAKE_CURRENT_SOURCE_DIR}/cpp/src/parquet/parquet_constants.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/src/parquet/parquet_types.cpp
)
add_library(${PARQUET_LIBRARY} ${PARQUET_SRCS})
target_include_directories(${PARQUET_LIBRARY} SYSTEM PUBLIC ${ClickHouse_SOURCE_DIR}/contrib/arrow/cpp/src ${CMAKE_CURRENT_SOURCE_DIR}/cpp/src)
include(${ClickHouse_SOURCE_DIR}/contrib/thrift/build/cmake/ConfigureChecks.cmake) # makes config.h
target_link_libraries(${PARQUET_LIBRARY} PUBLIC ${ARROW_LIBRARY} PRIVATE ${THRIFT_LIBRARY} ${Boost_REGEX_LIBRARY})
target_include_directories(${PARQUET_LIBRARY} PRIVATE ${Boost_INCLUDE_DIRS})
if(SANITIZE STREQUAL "undefined")
if (SANITIZE STREQUAL "undefined")
target_compile_options(${PARQUET_LIBRARY} PRIVATE -fno-sanitize=undefined)
target_compile_options(${ARROW_LIBRARY} PRIVATE -fno-sanitize=undefined)
endif()
endif ()
# === tools
set(TOOLS_DIR ${ClickHouse_SOURCE_DIR}/contrib/arrow/cpp/tools/parquet)
set(PARQUET_TOOLS parquet-dump-schema parquet-reader parquet-scan)
foreach(TOOL ${PARQUET_TOOLS})
set(PARQUET_TOOLS parquet_dump_schema parquet_reader parquet_scan)
foreach (TOOL ${PARQUET_TOOLS})
add_executable(${TOOL} ${TOOLS_DIR}/${TOOL}.cc)
target_link_libraries(${TOOL} PRIVATE ${PARQUET_LIBRARY})
endforeach()
endforeach ()

View File

@ -0,0 +1,24 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#define ARROW_VERSION_MAJOR
#define ARROW_VERSION_MINOR
#define ARROW_VERSION_PATCH
#define ARROW_VERSION ((ARROW_VERSION_MAJOR * 1000) + ARROW_VERSION_MINOR) * 1000 + ARROW_VERSION_PATCH
/* #undef DOUBLE_CONVERSION_HAS_CASE_INSENSIBILITY */
/* #undef GRPCPP_PP_INCLUDE */

View File

@ -1,5 +1,5 @@
/**
* Autogenerated by Thrift Compiler (0.11.0)
* Autogenerated by Thrift Compiler (0.12.0)
*
* DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
* @generated

View File

@ -1,5 +1,5 @@
/**
* Autogenerated by Thrift Compiler (0.11.0)
* Autogenerated by Thrift Compiler (0.12.0)
*
* DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
* @generated

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,5 @@
/**
* Autogenerated by Thrift Compiler (0.11.0)
* Autogenerated by Thrift Compiler (0.12.0)
*
* DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
* @generated
@ -17,7 +17,7 @@
#include <thrift/stdcxx.h>
#include "parquet/util/windows_compatibility.h"
#include "parquet/windows_compatibility.h"
namespace parquet { namespace format {
@ -161,6 +161,8 @@ class MilliSeconds;
class MicroSeconds;
class NanoSeconds;
class TimeUnit;
class TimestampType;
@ -215,14 +217,14 @@ class OffsetIndex;
class ColumnIndex;
class FileMetaData;
class AesGcmV1;
class AesGcmCtrV1;
class EncryptionAlgorithm;
class FileMetaData;
class FileCryptoMetaData;
typedef struct _Statistics__isset {
@ -629,10 +631,42 @@ void swap(MicroSeconds &a, MicroSeconds &b);
std::ostream& operator<<(std::ostream& out, const MicroSeconds& obj);
class NanoSeconds : public virtual ::apache::thrift::TBase {
public:
NanoSeconds(const NanoSeconds&);
NanoSeconds& operator=(const NanoSeconds&);
NanoSeconds() {
}
virtual ~NanoSeconds() throw();
bool operator == (const NanoSeconds & /* rhs */) const
{
return true;
}
bool operator != (const NanoSeconds &rhs) const {
return !(*this == rhs);
}
bool operator < (const NanoSeconds & ) const;
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
virtual void printTo(std::ostream& out) const;
};
void swap(NanoSeconds &a, NanoSeconds &b);
std::ostream& operator<<(std::ostream& out, const NanoSeconds& obj);
typedef struct _TimeUnit__isset {
_TimeUnit__isset() : MILLIS(false), MICROS(false) {}
_TimeUnit__isset() : MILLIS(false), MICROS(false), NANOS(false) {}
bool MILLIS :1;
bool MICROS :1;
bool NANOS :1;
} _TimeUnit__isset;
class TimeUnit : public virtual ::apache::thrift::TBase {
@ -646,6 +680,7 @@ class TimeUnit : public virtual ::apache::thrift::TBase {
virtual ~TimeUnit() throw();
MilliSeconds MILLIS;
MicroSeconds MICROS;
NanoSeconds NANOS;
_TimeUnit__isset __isset;
@ -653,6 +688,8 @@ class TimeUnit : public virtual ::apache::thrift::TBase {
void __set_MICROS(const MicroSeconds& val);
void __set_NANOS(const NanoSeconds& val);
bool operator == (const TimeUnit & rhs) const
{
if (__isset.MILLIS != rhs.__isset.MILLIS)
@ -663,6 +700,10 @@ class TimeUnit : public virtual ::apache::thrift::TBase {
return false;
else if (__isset.MICROS && !(MICROS == rhs.MICROS))
return false;
if (__isset.NANOS != rhs.__isset.NANOS)
return false;
else if (__isset.NANOS && !(NANOS == rhs.NANOS))
return false;
return true;
}
bool operator != (const TimeUnit &rhs) const {
@ -867,7 +908,7 @@ void swap(BsonType &a, BsonType &b);
std::ostream& operator<<(std::ostream& out, const BsonType& obj);
typedef struct _LogicalType__isset {
_LogicalType__isset() : STRING(false), MAP(false), LIST(false), ENUM(false), DECIMAL(false), DATE(false), TIME(false), TIMESTAMP(false), INTEGER(false), UNKNOWN(false), JSON(false), BSON(false) {}
_LogicalType__isset() : STRING(false), MAP(false), LIST(false), ENUM(false), DECIMAL(false), DATE(false), TIME(false), TIMESTAMP(false), INTEGER(false), UNKNOWN(false), JSON(false), BSON(false), UUID(false) {}
bool STRING :1;
bool MAP :1;
bool LIST :1;
@ -880,6 +921,7 @@ typedef struct _LogicalType__isset {
bool UNKNOWN :1;
bool JSON :1;
bool BSON :1;
bool UUID :1;
} _LogicalType__isset;
class LogicalType : public virtual ::apache::thrift::TBase {
@ -903,6 +945,7 @@ class LogicalType : public virtual ::apache::thrift::TBase {
NullType UNKNOWN;
JsonType JSON;
BsonType BSON;
UUIDType UUID;
_LogicalType__isset __isset;
@ -930,6 +973,8 @@ class LogicalType : public virtual ::apache::thrift::TBase {
void __set_BSON(const BsonType& val);
void __set_UUID(const UUIDType& val);
bool operator == (const LogicalType & rhs) const
{
if (__isset.STRING != rhs.__isset.STRING)
@ -980,6 +1025,10 @@ class LogicalType : public virtual ::apache::thrift::TBase {
return false;
else if (__isset.BSON && !(BSON == rhs.BSON))
return false;
if (__isset.UUID != rhs.__isset.UUID)
return false;
else if (__isset.UUID && !(UUID == rhs.UUID))
return false;
return true;
}
bool operator != (const LogicalType &rhs) const {
@ -1722,8 +1771,8 @@ void swap(EncryptionWithFooterKey &a, EncryptionWithFooterKey &b);
std::ostream& operator<<(std::ostream& out, const EncryptionWithFooterKey& obj);
typedef struct _EncryptionWithColumnKey__isset {
_EncryptionWithColumnKey__isset() : column_key_metadata(false) {}
bool column_key_metadata :1;
_EncryptionWithColumnKey__isset() : key_metadata(false) {}
bool key_metadata :1;
} _EncryptionWithColumnKey__isset;
class EncryptionWithColumnKey : public virtual ::apache::thrift::TBase {
@ -1731,26 +1780,26 @@ class EncryptionWithColumnKey : public virtual ::apache::thrift::TBase {
EncryptionWithColumnKey(const EncryptionWithColumnKey&);
EncryptionWithColumnKey& operator=(const EncryptionWithColumnKey&);
EncryptionWithColumnKey() : column_key_metadata() {
EncryptionWithColumnKey() : key_metadata() {
}
virtual ~EncryptionWithColumnKey() throw();
std::vector<std::string> path_in_schema;
std::string column_key_metadata;
std::string key_metadata;
_EncryptionWithColumnKey__isset __isset;
void __set_path_in_schema(const std::vector<std::string> & val);
void __set_column_key_metadata(const std::string& val);
void __set_key_metadata(const std::string& val);
bool operator == (const EncryptionWithColumnKey & rhs) const
{
if (!(path_in_schema == rhs.path_in_schema))
return false;
if (__isset.column_key_metadata != rhs.__isset.column_key_metadata)
if (__isset.key_metadata != rhs.__isset.key_metadata)
return false;
else if (__isset.column_key_metadata && !(column_key_metadata == rhs.column_key_metadata))
else if (__isset.key_metadata && !(key_metadata == rhs.key_metadata))
return false;
return true;
}
@ -1823,14 +1872,15 @@ void swap(ColumnCryptoMetaData &a, ColumnCryptoMetaData &b);
std::ostream& operator<<(std::ostream& out, const ColumnCryptoMetaData& obj);
typedef struct _ColumnChunk__isset {
_ColumnChunk__isset() : file_path(false), meta_data(false), offset_index_offset(false), offset_index_length(false), column_index_offset(false), column_index_length(false), crypto_meta_data(false) {}
_ColumnChunk__isset() : file_path(false), meta_data(false), offset_index_offset(false), offset_index_length(false), column_index_offset(false), column_index_length(false), crypto_metadata(false), encrypted_column_metadata(false) {}
bool file_path :1;
bool meta_data :1;
bool offset_index_offset :1;
bool offset_index_length :1;
bool column_index_offset :1;
bool column_index_length :1;
bool crypto_meta_data :1;
bool crypto_metadata :1;
bool encrypted_column_metadata :1;
} _ColumnChunk__isset;
class ColumnChunk : public virtual ::apache::thrift::TBase {
@ -1838,7 +1888,7 @@ class ColumnChunk : public virtual ::apache::thrift::TBase {
ColumnChunk(const ColumnChunk&);
ColumnChunk& operator=(const ColumnChunk&);
ColumnChunk() : file_path(), file_offset(0), offset_index_offset(0), offset_index_length(0), column_index_offset(0), column_index_length(0) {
ColumnChunk() : file_path(), file_offset(0), offset_index_offset(0), offset_index_length(0), column_index_offset(0), column_index_length(0), encrypted_column_metadata() {
}
virtual ~ColumnChunk() throw();
@ -1849,7 +1899,8 @@ class ColumnChunk : public virtual ::apache::thrift::TBase {
int32_t offset_index_length;
int64_t column_index_offset;
int32_t column_index_length;
ColumnCryptoMetaData crypto_meta_data;
ColumnCryptoMetaData crypto_metadata;
std::string encrypted_column_metadata;
_ColumnChunk__isset __isset;
@ -1867,7 +1918,9 @@ class ColumnChunk : public virtual ::apache::thrift::TBase {
void __set_column_index_length(const int32_t val);
void __set_crypto_meta_data(const ColumnCryptoMetaData& val);
void __set_crypto_metadata(const ColumnCryptoMetaData& val);
void __set_encrypted_column_metadata(const std::string& val);
bool operator == (const ColumnChunk & rhs) const
{
@ -1897,9 +1950,13 @@ class ColumnChunk : public virtual ::apache::thrift::TBase {
return false;
else if (__isset.column_index_length && !(column_index_length == rhs.column_index_length))
return false;
if (__isset.crypto_meta_data != rhs.__isset.crypto_meta_data)
if (__isset.crypto_metadata != rhs.__isset.crypto_metadata)
return false;
else if (__isset.crypto_meta_data && !(crypto_meta_data == rhs.crypto_meta_data))
else if (__isset.crypto_metadata && !(crypto_metadata == rhs.crypto_metadata))
return false;
if (__isset.encrypted_column_metadata != rhs.__isset.encrypted_column_metadata)
return false;
else if (__isset.encrypted_column_metadata && !(encrypted_column_metadata == rhs.encrypted_column_metadata))
return false;
return true;
}
@ -1920,10 +1977,11 @@ void swap(ColumnChunk &a, ColumnChunk &b);
std::ostream& operator<<(std::ostream& out, const ColumnChunk& obj);
typedef struct _RowGroup__isset {
_RowGroup__isset() : sorting_columns(false), file_offset(false), total_compressed_size(false) {}
_RowGroup__isset() : sorting_columns(false), file_offset(false), total_compressed_size(false), ordinal(false) {}
bool sorting_columns :1;
bool file_offset :1;
bool total_compressed_size :1;
bool ordinal :1;
} _RowGroup__isset;
class RowGroup : public virtual ::apache::thrift::TBase {
@ -1931,7 +1989,7 @@ class RowGroup : public virtual ::apache::thrift::TBase {
RowGroup(const RowGroup&);
RowGroup& operator=(const RowGroup&);
RowGroup() : total_byte_size(0), num_rows(0), file_offset(0), total_compressed_size(0) {
RowGroup() : total_byte_size(0), num_rows(0), file_offset(0), total_compressed_size(0), ordinal(0) {
}
virtual ~RowGroup() throw();
@ -1941,6 +1999,7 @@ class RowGroup : public virtual ::apache::thrift::TBase {
std::vector<SortingColumn> sorting_columns;
int64_t file_offset;
int64_t total_compressed_size;
int16_t ordinal;
_RowGroup__isset __isset;
@ -1956,6 +2015,8 @@ class RowGroup : public virtual ::apache::thrift::TBase {
void __set_total_compressed_size(const int64_t val);
void __set_ordinal(const int16_t val);
bool operator == (const RowGroup & rhs) const
{
if (!(columns == rhs.columns))
@ -1976,6 +2037,10 @@ class RowGroup : public virtual ::apache::thrift::TBase {
return false;
else if (__isset.total_compressed_size && !(total_compressed_size == rhs.total_compressed_size))
return false;
if (__isset.ordinal != rhs.__isset.ordinal)
return false;
else if (__isset.ordinal && !(ordinal == rhs.ordinal))
return false;
return true;
}
bool operator != (const RowGroup &rhs) const {
@ -2215,90 +2280,11 @@ void swap(ColumnIndex &a, ColumnIndex &b);
std::ostream& operator<<(std::ostream& out, const ColumnIndex& obj);
typedef struct _FileMetaData__isset {
_FileMetaData__isset() : key_value_metadata(false), created_by(false), column_orders(false) {}
bool key_value_metadata :1;
bool created_by :1;
bool column_orders :1;
} _FileMetaData__isset;
class FileMetaData : public virtual ::apache::thrift::TBase {
public:
FileMetaData(const FileMetaData&);
FileMetaData& operator=(const FileMetaData&);
FileMetaData() : version(0), num_rows(0), created_by() {
}
virtual ~FileMetaData() throw();
int32_t version;
std::vector<SchemaElement> schema;
int64_t num_rows;
std::vector<RowGroup> row_groups;
std::vector<KeyValue> key_value_metadata;
std::string created_by;
std::vector<ColumnOrder> column_orders;
_FileMetaData__isset __isset;
void __set_version(const int32_t val);
void __set_schema(const std::vector<SchemaElement> & val);
void __set_num_rows(const int64_t val);
void __set_row_groups(const std::vector<RowGroup> & val);
void __set_key_value_metadata(const std::vector<KeyValue> & val);
void __set_created_by(const std::string& val);
void __set_column_orders(const std::vector<ColumnOrder> & val);
bool operator == (const FileMetaData & rhs) const
{
if (!(version == rhs.version))
return false;
if (!(schema == rhs.schema))
return false;
if (!(num_rows == rhs.num_rows))
return false;
if (!(row_groups == rhs.row_groups))
return false;
if (__isset.key_value_metadata != rhs.__isset.key_value_metadata)
return false;
else if (__isset.key_value_metadata && !(key_value_metadata == rhs.key_value_metadata))
return false;
if (__isset.created_by != rhs.__isset.created_by)
return false;
else if (__isset.created_by && !(created_by == rhs.created_by))
return false;
if (__isset.column_orders != rhs.__isset.column_orders)
return false;
else if (__isset.column_orders && !(column_orders == rhs.column_orders))
return false;
return true;
}
bool operator != (const FileMetaData &rhs) const {
return !(*this == rhs);
}
bool operator < (const FileMetaData & ) const;
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
virtual void printTo(std::ostream& out) const;
};
void swap(FileMetaData &a, FileMetaData &b);
std::ostream& operator<<(std::ostream& out, const FileMetaData& obj);
typedef struct _AesGcmV1__isset {
_AesGcmV1__isset() : aad_metadata(false), iv_prefix(false) {}
bool aad_metadata :1;
bool iv_prefix :1;
_AesGcmV1__isset() : aad_prefix(false), aad_file_unique(false), supply_aad_prefix(false) {}
bool aad_prefix :1;
bool aad_file_unique :1;
bool supply_aad_prefix :1;
} _AesGcmV1__isset;
class AesGcmV1 : public virtual ::apache::thrift::TBase {
@ -2306,28 +2292,35 @@ class AesGcmV1 : public virtual ::apache::thrift::TBase {
AesGcmV1(const AesGcmV1&);
AesGcmV1& operator=(const AesGcmV1&);
AesGcmV1() : aad_metadata(), iv_prefix() {
AesGcmV1() : aad_prefix(), aad_file_unique(), supply_aad_prefix(0) {
}
virtual ~AesGcmV1() throw();
std::string aad_metadata;
std::string iv_prefix;
std::string aad_prefix;
std::string aad_file_unique;
bool supply_aad_prefix;
_AesGcmV1__isset __isset;
void __set_aad_metadata(const std::string& val);
void __set_aad_prefix(const std::string& val);
void __set_iv_prefix(const std::string& val);
void __set_aad_file_unique(const std::string& val);
void __set_supply_aad_prefix(const bool val);
bool operator == (const AesGcmV1 & rhs) const
{
if (__isset.aad_metadata != rhs.__isset.aad_metadata)
if (__isset.aad_prefix != rhs.__isset.aad_prefix)
return false;
else if (__isset.aad_metadata && !(aad_metadata == rhs.aad_metadata))
else if (__isset.aad_prefix && !(aad_prefix == rhs.aad_prefix))
return false;
if (__isset.iv_prefix != rhs.__isset.iv_prefix)
if (__isset.aad_file_unique != rhs.__isset.aad_file_unique)
return false;
else if (__isset.iv_prefix && !(iv_prefix == rhs.iv_prefix))
else if (__isset.aad_file_unique && !(aad_file_unique == rhs.aad_file_unique))
return false;
if (__isset.supply_aad_prefix != rhs.__isset.supply_aad_prefix)
return false;
else if (__isset.supply_aad_prefix && !(supply_aad_prefix == rhs.supply_aad_prefix))
return false;
return true;
}
@ -2348,10 +2341,10 @@ void swap(AesGcmV1 &a, AesGcmV1 &b);
std::ostream& operator<<(std::ostream& out, const AesGcmV1& obj);
typedef struct _AesGcmCtrV1__isset {
_AesGcmCtrV1__isset() : aad_metadata(false), gcm_iv_prefix(false), ctr_iv_prefix(false) {}
bool aad_metadata :1;
bool gcm_iv_prefix :1;
bool ctr_iv_prefix :1;
_AesGcmCtrV1__isset() : aad_prefix(false), aad_file_unique(false), supply_aad_prefix(false) {}
bool aad_prefix :1;
bool aad_file_unique :1;
bool supply_aad_prefix :1;
} _AesGcmCtrV1__isset;
class AesGcmCtrV1 : public virtual ::apache::thrift::TBase {
@ -2359,35 +2352,35 @@ class AesGcmCtrV1 : public virtual ::apache::thrift::TBase {
AesGcmCtrV1(const AesGcmCtrV1&);
AesGcmCtrV1& operator=(const AesGcmCtrV1&);
AesGcmCtrV1() : aad_metadata(), gcm_iv_prefix(), ctr_iv_prefix() {
AesGcmCtrV1() : aad_prefix(), aad_file_unique(), supply_aad_prefix(0) {
}
virtual ~AesGcmCtrV1() throw();
std::string aad_metadata;
std::string gcm_iv_prefix;
std::string ctr_iv_prefix;
std::string aad_prefix;
std::string aad_file_unique;
bool supply_aad_prefix;
_AesGcmCtrV1__isset __isset;
void __set_aad_metadata(const std::string& val);
void __set_aad_prefix(const std::string& val);
void __set_gcm_iv_prefix(const std::string& val);
void __set_aad_file_unique(const std::string& val);
void __set_ctr_iv_prefix(const std::string& val);
void __set_supply_aad_prefix(const bool val);
bool operator == (const AesGcmCtrV1 & rhs) const
{
if (__isset.aad_metadata != rhs.__isset.aad_metadata)
if (__isset.aad_prefix != rhs.__isset.aad_prefix)
return false;
else if (__isset.aad_metadata && !(aad_metadata == rhs.aad_metadata))
else if (__isset.aad_prefix && !(aad_prefix == rhs.aad_prefix))
return false;
if (__isset.gcm_iv_prefix != rhs.__isset.gcm_iv_prefix)
if (__isset.aad_file_unique != rhs.__isset.aad_file_unique)
return false;
else if (__isset.gcm_iv_prefix && !(gcm_iv_prefix == rhs.gcm_iv_prefix))
else if (__isset.aad_file_unique && !(aad_file_unique == rhs.aad_file_unique))
return false;
if (__isset.ctr_iv_prefix != rhs.__isset.ctr_iv_prefix)
if (__isset.supply_aad_prefix != rhs.__isset.supply_aad_prefix)
return false;
else if (__isset.ctr_iv_prefix && !(ctr_iv_prefix == rhs.ctr_iv_prefix))
else if (__isset.supply_aad_prefix && !(supply_aad_prefix == rhs.supply_aad_prefix))
return false;
return true;
}
@ -2459,9 +2452,105 @@ void swap(EncryptionAlgorithm &a, EncryptionAlgorithm &b);
std::ostream& operator<<(std::ostream& out, const EncryptionAlgorithm& obj);
typedef struct _FileMetaData__isset {
_FileMetaData__isset() : key_value_metadata(false), created_by(false), column_orders(false), encryption_algorithm(false), footer_signing_key_metadata(false) {}
bool key_value_metadata :1;
bool created_by :1;
bool column_orders :1;
bool encryption_algorithm :1;
bool footer_signing_key_metadata :1;
} _FileMetaData__isset;
class FileMetaData : public virtual ::apache::thrift::TBase {
public:
FileMetaData(const FileMetaData&);
FileMetaData& operator=(const FileMetaData&);
FileMetaData() : version(0), num_rows(0), created_by(), footer_signing_key_metadata() {
}
virtual ~FileMetaData() throw();
int32_t version;
std::vector<SchemaElement> schema;
int64_t num_rows;
std::vector<RowGroup> row_groups;
std::vector<KeyValue> key_value_metadata;
std::string created_by;
std::vector<ColumnOrder> column_orders;
EncryptionAlgorithm encryption_algorithm;
std::string footer_signing_key_metadata;
_FileMetaData__isset __isset;
void __set_version(const int32_t val);
void __set_schema(const std::vector<SchemaElement> & val);
void __set_num_rows(const int64_t val);
void __set_row_groups(const std::vector<RowGroup> & val);
void __set_key_value_metadata(const std::vector<KeyValue> & val);
void __set_created_by(const std::string& val);
void __set_column_orders(const std::vector<ColumnOrder> & val);
void __set_encryption_algorithm(const EncryptionAlgorithm& val);
void __set_footer_signing_key_metadata(const std::string& val);
bool operator == (const FileMetaData & rhs) const
{
if (!(version == rhs.version))
return false;
if (!(schema == rhs.schema))
return false;
if (!(num_rows == rhs.num_rows))
return false;
if (!(row_groups == rhs.row_groups))
return false;
if (__isset.key_value_metadata != rhs.__isset.key_value_metadata)
return false;
else if (__isset.key_value_metadata && !(key_value_metadata == rhs.key_value_metadata))
return false;
if (__isset.created_by != rhs.__isset.created_by)
return false;
else if (__isset.created_by && !(created_by == rhs.created_by))
return false;
if (__isset.column_orders != rhs.__isset.column_orders)
return false;
else if (__isset.column_orders && !(column_orders == rhs.column_orders))
return false;
if (__isset.encryption_algorithm != rhs.__isset.encryption_algorithm)
return false;
else if (__isset.encryption_algorithm && !(encryption_algorithm == rhs.encryption_algorithm))
return false;
if (__isset.footer_signing_key_metadata != rhs.__isset.footer_signing_key_metadata)
return false;
else if (__isset.footer_signing_key_metadata && !(footer_signing_key_metadata == rhs.footer_signing_key_metadata))
return false;
return true;
}
bool operator != (const FileMetaData &rhs) const {
return !(*this == rhs);
}
bool operator < (const FileMetaData & ) const;
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
virtual void printTo(std::ostream& out) const;
};
void swap(FileMetaData &a, FileMetaData &b);
std::ostream& operator<<(std::ostream& out, const FileMetaData& obj);
typedef struct _FileCryptoMetaData__isset {
_FileCryptoMetaData__isset() : footer_key_metadata(false) {}
bool footer_key_metadata :1;
_FileCryptoMetaData__isset() : key_metadata(false) {}
bool key_metadata :1;
} _FileCryptoMetaData__isset;
class FileCryptoMetaData : public virtual ::apache::thrift::TBase {
@ -2469,36 +2558,26 @@ class FileCryptoMetaData : public virtual ::apache::thrift::TBase {
FileCryptoMetaData(const FileCryptoMetaData&);
FileCryptoMetaData& operator=(const FileCryptoMetaData&);
FileCryptoMetaData() : encrypted_footer(0), footer_key_metadata(), footer_offset(0) {
FileCryptoMetaData() : key_metadata() {
}
virtual ~FileCryptoMetaData() throw();
EncryptionAlgorithm encryption_algorithm;
bool encrypted_footer;
std::string footer_key_metadata;
int64_t footer_offset;
std::string key_metadata;
_FileCryptoMetaData__isset __isset;
void __set_encryption_algorithm(const EncryptionAlgorithm& val);
void __set_encrypted_footer(const bool val);
void __set_footer_key_metadata(const std::string& val);
void __set_footer_offset(const int64_t val);
void __set_key_metadata(const std::string& val);
bool operator == (const FileCryptoMetaData & rhs) const
{
if (!(encryption_algorithm == rhs.encryption_algorithm))
return false;
if (!(encrypted_footer == rhs.encrypted_footer))
if (__isset.key_metadata != rhs.__isset.key_metadata)
return false;
if (__isset.footer_key_metadata != rhs.__isset.footer_key_metadata)
return false;
else if (__isset.footer_key_metadata && !(footer_key_metadata == rhs.footer_key_metadata))
return false;
if (!(footer_offset == rhs.footer_offset))
else if (__isset.key_metadata && !(key_metadata == rhs.key_metadata))
return false;
return true;
}

1
contrib/flatbuffers vendored Submodule

@ -0,0 +1 @@
Subproject commit bf9eb67ab9371755c6bcece13cadc7693bcbf264

View File

@ -298,7 +298,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
global_context->shutdown();
LOG_DEBUG(log, "Shutted down storages.");
LOG_DEBUG(log, "Shut down storages.");
/** Explicitly destroy Context. It is more convenient than in destructor of Server, because logger is still available.
* At this moment, no one could own shared part of Context.

View File

@ -84,25 +84,17 @@
<!-- Quota for user. -->
<quota>default</quota>
<!-- For testing the table filters -->
<databases>
<!-- Example of row level security policy. -->
<!-- <databases>
<test>
<!-- Simple expression filter -->
<filtered_table1>
<filter>a = 1</filter>
</filtered_table1>
<!-- Complex expression filter -->
<filtered_table2>
<filter>a + b &lt; 1 or c - d &gt; 5</filter>
</filtered_table2>
<!-- Filter with ALIAS column -->
<filtered_table3>
<filter>c = 1</filter>
</filtered_table3>
</test>
</databases>
</databases> -->
</default>
<!-- Example of user with readonly access. -->

View File

@ -3,6 +3,7 @@
#include <Access/MemoryAccessStorage.h>
#include <Access/UsersConfigAccessStorage.h>
#include <Access/QuotaContextFactory.h>
#include <Access/RowPolicyContextFactory.h>
namespace DB
@ -21,7 +22,8 @@ namespace
AccessControlManager::AccessControlManager()
: MultipleAccessStorage(createStorages()),
quota_context_factory(std::make_unique<QuotaContextFactory>(*this))
quota_context_factory(std::make_unique<QuotaContextFactory>(*this)),
row_policy_context_factory(std::make_unique<RowPolicyContextFactory>(*this))
{
}
@ -49,4 +51,11 @@ std::vector<QuotaUsageInfo> AccessControlManager::getQuotaUsageInfo() const
{
return quota_context_factory->getUsageInfo();
}
std::shared_ptr<RowPolicyContext> AccessControlManager::getRowPolicyContext(const String & user_name) const
{
return row_policy_context_factory->createContext(user_name);
}
}

View File

@ -22,6 +22,8 @@ namespace DB
class QuotaContext;
class QuotaContextFactory;
struct QuotaUsageInfo;
class RowPolicyContext;
class RowPolicyContextFactory;
/// Manages access control entities.
@ -38,8 +40,11 @@ public:
std::vector<QuotaUsageInfo> getQuotaUsageInfo() const;
std::shared_ptr<RowPolicyContext> getRowPolicyContext(const String & user_name) const;
private:
std::unique_ptr<QuotaContextFactory> quota_context_factory;
std::unique_ptr<RowPolicyContextFactory> row_policy_context_factory;
};
}

View File

@ -1,5 +1,6 @@
#include <Access/IAccessEntity.h>
#include <Access/Quota.h>
#include <Access/RowPolicy.h>
#include <common/demangle.h>
@ -9,6 +10,8 @@ String IAccessEntity::getTypeName(std::type_index type)
{
if (type == typeid(Quota))
return "Quota";
if (type == typeid(RowPolicy))
return "Row policy";
return demangle(type.name());
}

View File

@ -0,0 +1,111 @@
#include <Access/RowPolicy.h>
#include <Interpreters/Context.h>
#include <Common/quoteString.h>
#include <boost/range/algorithm/equal.hpp>
namespace DB
{
namespace
{
void generateFullNameImpl(const String & database_, const String & table_name_, const String & policy_name_, String & full_name_)
{
full_name_.clear();
full_name_.reserve(database_.length() + table_name_.length() + policy_name_.length() + 6);
full_name_ += backQuoteIfNeed(policy_name_);
full_name_ += " ON ";
if (!database_.empty())
{
full_name_ += backQuoteIfNeed(database_);
full_name_ += '.';
}
full_name_ += backQuoteIfNeed(table_name_);
}
}
String RowPolicy::FullNameParts::getFullName() const
{
String full_name;
generateFullNameImpl(database, table_name, policy_name, full_name);
return full_name;
}
String RowPolicy::FullNameParts::getFullName(const Context & context) const
{
String full_name;
generateFullNameImpl(database.empty() ? context.getCurrentDatabase() : database, table_name, policy_name, full_name);
return full_name;
}
void RowPolicy::setDatabase(const String & database_)
{
database = database_;
generateFullNameImpl(database, table_name, policy_name, full_name);
}
void RowPolicy::setTableName(const String & table_name_)
{
table_name = table_name_;
generateFullNameImpl(database, table_name, policy_name, full_name);
}
void RowPolicy::setName(const String & policy_name_)
{
policy_name = policy_name_;
generateFullNameImpl(database, table_name, policy_name, full_name);
}
void RowPolicy::setFullName(const String & database_, const String & table_name_, const String & policy_name_)
{
database = database_;
table_name = table_name_;
policy_name = policy_name_;
generateFullNameImpl(database, table_name, policy_name, full_name);
}
bool RowPolicy::equal(const IAccessEntity & other) const
{
if (!IAccessEntity::equal(other))
return false;
const auto & other_policy = typeid_cast<const RowPolicy &>(other);
return (database == other_policy.database) && (table_name == other_policy.table_name) && (policy_name == other_policy.policy_name)
&& boost::range::equal(conditions, other_policy.conditions) && restrictive == other_policy.restrictive
&& (roles == other_policy.roles) && (all_roles == other_policy.all_roles) && (except_roles == other_policy.except_roles);
}
const char * RowPolicy::conditionIndexToString(ConditionIndex index)
{
switch (index)
{
case SELECT_FILTER: return "SELECT_FILTER";
case INSERT_CHECK: return "INSERT_CHECK";
case UPDATE_FILTER: return "UPDATE_FILTER";
case UPDATE_CHECK: return "UPDATE_CHECK";
case DELETE_FILTER: return "DELETE_FILTER";
}
__builtin_unreachable();
}
const char * RowPolicy::conditionIndexToColumnName(ConditionIndex index)
{
switch (index)
{
case SELECT_FILTER: return "select_filter";
case INSERT_CHECK: return "insert_check";
case UPDATE_FILTER: return "update_filter";
case UPDATE_CHECK: return "update_check";
case DELETE_FILTER: return "delete_filter";
}
__builtin_unreachable();
}
}

View File

@ -0,0 +1,81 @@
#pragma once
#include <Access/IAccessEntity.h>
namespace DB
{
class Context;
/** Represents a row level security policy for a table.
*/
struct RowPolicy : public IAccessEntity
{
void setDatabase(const String & database_);
void setTableName(const String & table_name_);
void setName(const String & policy_name_) override;
void setFullName(const String & database_, const String & table_name_, const String & policy_name_);
String getDatabase() const { return database; }
String getTableName() const { return table_name; }
String getName() const override { return policy_name; }
struct FullNameParts
{
String database;
String table_name;
String policy_name;
String getFullName() const;
String getFullName(const Context & context) const;
};
/// Filter is a SQL conditional expression used to figure out which rows should be visible
/// for user or available for modification. If the expression returns NULL or false for some rows
/// those rows are silently suppressed.
/// Check is a SQL condition expression used to check whether a row can be written into
/// the table. If the expression returns NULL or false an exception is thrown.
/// If a conditional expression here is empty it means no filtering is applied.
enum ConditionIndex
{
SELECT_FILTER,
INSERT_CHECK,
UPDATE_FILTER,
UPDATE_CHECK,
DELETE_FILTER,
};
static constexpr size_t MAX_CONDITION_INDEX = 5;
static const char * conditionIndexToString(ConditionIndex index);
static const char * conditionIndexToColumnName(ConditionIndex index);
String conditions[MAX_CONDITION_INDEX];
/// Sets that the policy is permissive.
/// A row is only accessible if at least one of the permissive policies passes,
/// in addition to all the restrictive policies.
void setPermissive(bool permissive_ = true) { setRestrictive(!permissive_); }
bool isPermissive() const { return !isRestrictive(); }
/// Sets that the policy is restrictive.
/// A row is only accessible if at least one of the permissive policies passes,
/// in addition to all the restrictive policies.
void setRestrictive(bool restrictive_ = true) { restrictive = restrictive_; }
bool isRestrictive() const { return restrictive; }
bool equal(const IAccessEntity & other) const override;
std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<RowPolicy>(); }
/// Which roles or users should use this quota.
Strings roles;
bool all_roles = false;
Strings except_roles;
private:
String database;
String table_name;
String policy_name;
bool restrictive = false;
};
using RowPolicyPtr = std::shared_ptr<const RowPolicy>;
}

View File

@ -0,0 +1,59 @@
#include <Access/RowPolicyContext.h>
#include <boost/range/adaptor/map.hpp>
#include <boost/range/algorithm/copy.hpp>
namespace DB
{
size_t RowPolicyContext::Hash::operator()(const DatabaseAndTableNameRef & database_and_table_name) const
{
return std::hash<StringRef>{}(database_and_table_name.first) - std::hash<StringRef>{}(database_and_table_name.second);
}
RowPolicyContext::RowPolicyContext()
: atomic_map_of_mixed_conditions(std::make_shared<MapOfMixedConditions>())
{
}
RowPolicyContext::~RowPolicyContext() = default;
RowPolicyContext::RowPolicyContext(const String & user_name_)
: user_name(user_name_)
{}
ASTPtr RowPolicyContext::getCondition(const String & database, const String & table_name, ConditionIndex index) const
{
/// We don't lock `mutex` here.
auto map_of_mixed_conditions = std::atomic_load(&atomic_map_of_mixed_conditions);
auto it = map_of_mixed_conditions->find({database, table_name});
if (it == map_of_mixed_conditions->end())
return {};
return it->second.mixed_conditions[index];
}
std::vector<UUID> RowPolicyContext::getCurrentPolicyIDs() const
{
/// We don't lock `mutex` here.
auto map_of_mixed_conditions = std::atomic_load(&atomic_map_of_mixed_conditions);
std::vector<UUID> policy_ids;
for (const auto & mixed_conditions : *map_of_mixed_conditions | boost::adaptors::map_values)
boost::range::copy(mixed_conditions.policy_ids, std::back_inserter(policy_ids));
return policy_ids;
}
std::vector<UUID> RowPolicyContext::getCurrentPolicyIDs(const String & database, const String & table_name) const
{
/// We don't lock `mutex` here.
auto map_of_mixed_conditions = std::atomic_load(&atomic_map_of_mixed_conditions);
auto it = map_of_mixed_conditions->find({database, table_name});
if (it == map_of_mixed_conditions->end())
return {};
return it->second.policy_ids;
}
}

View File

@ -0,0 +1,66 @@
#pragma once
#include <Access/RowPolicy.h>
#include <Core/Types.h>
#include <Core/UUID.h>
#include <common/StringRef.h>
#include <memory>
#include <unordered_map>
namespace DB
{
class IAST;
using ASTPtr = std::shared_ptr<IAST>;
/// Provides fast access to row policies' conditions for a specific user and tables.
class RowPolicyContext
{
public:
/// Default constructor makes a row policy usage context which restricts nothing.
RowPolicyContext();
~RowPolicyContext();
using ConditionIndex = RowPolicy::ConditionIndex;
/// Returns prepared filter for a specific table and operations.
/// The function can return nullptr, that means there is no filters applied.
/// The returned filter can be a combination of the filters defined by multiple row policies.
ASTPtr getCondition(const String & database, const String & table_name, ConditionIndex index) const;
/// Returns IDs of all the policies used by the current user.
std::vector<UUID> getCurrentPolicyIDs() const;
/// Returns IDs of the policies used by a concrete table.
std::vector<UUID> getCurrentPolicyIDs(const String & database, const String & table_name) const;
private:
friend class RowPolicyContextFactory;
friend struct ext::shared_ptr_helper<RowPolicyContext>;
RowPolicyContext(const String & user_name_); /// RowPolicyContext should be created by RowPolicyContextFactory.
using DatabaseAndTableName = std::pair<String, String>;
using DatabaseAndTableNameRef = std::pair<StringRef, StringRef>;
struct Hash
{
size_t operator()(const DatabaseAndTableNameRef & database_and_table_name) const;
};
static constexpr size_t MAX_CONDITION_INDEX = RowPolicy::MAX_CONDITION_INDEX;
using ParsedConditions = std::array<ASTPtr, MAX_CONDITION_INDEX>;
struct MixedConditions
{
std::unique_ptr<DatabaseAndTableName> database_and_table_name_keeper;
ParsedConditions mixed_conditions;
std::vector<UUID> policy_ids;
};
using MapOfMixedConditions = std::unordered_map<DatabaseAndTableNameRef, MixedConditions, Hash>;
const String user_name;
std::shared_ptr<const MapOfMixedConditions> atomic_map_of_mixed_conditions; /// Changed atomically, not protected by `mutex`.
};
using RowPolicyContextPtr = std::shared_ptr<RowPolicyContext>;
}

View File

@ -0,0 +1,314 @@
#include <Access/RowPolicyContextFactory.h>
#include <Access/RowPolicyContext.h>
#include <Access/AccessControlManager.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ExpressionListParsers.h>
#include <Parsers/parseQuery.h>
#include <Common/Exception.h>
#include <Common/quoteString.h>
#include <ext/range.h>
#include <boost/range/algorithm/copy.hpp>
#include <boost/range/algorithm_ext/erase.hpp>
namespace DB
{
namespace
{
bool tryGetLiteralBool(const IAST & ast, bool & value)
{
try
{
if (const ASTLiteral * literal = ast.as<ASTLiteral>())
{
value = !literal->value.isNull() && applyVisitor(FieldVisitorConvertToNumber<bool>(), literal->value);
return true;
}
return false;
}
catch (...)
{
return false;
}
}
ASTPtr applyFunctionAND(ASTs arguments)
{
bool const_arguments = true;
boost::range::remove_erase_if(arguments, [&](const ASTPtr & argument) -> bool
{
bool b;
if (!tryGetLiteralBool(*argument, b))
return false;
const_arguments &= b;
return true;
});
if (!const_arguments)
return std::make_shared<ASTLiteral>(Field{UInt8(0)});
if (arguments.empty())
return std::make_shared<ASTLiteral>(Field{UInt8(1)});
if (arguments.size() == 1)
return arguments[0];
auto function = std::make_shared<ASTFunction>();
auto exp_list = std::make_shared<ASTExpressionList>();
function->name = "and";
function->arguments = exp_list;
function->children.push_back(exp_list);
exp_list->children = std::move(arguments);
return function;
}
ASTPtr applyFunctionOR(ASTs arguments)
{
bool const_arguments = false;
boost::range::remove_erase_if(arguments, [&](const ASTPtr & argument) -> bool
{
bool b;
if (!tryGetLiteralBool(*argument, b))
return false;
const_arguments |= b;
return true;
});
if (const_arguments)
return std::make_shared<ASTLiteral>(Field{UInt8(1)});
if (arguments.empty())
return std::make_shared<ASTLiteral>(Field{UInt8(0)});
if (arguments.size() == 1)
return arguments[0];
auto function = std::make_shared<ASTFunction>();
auto exp_list = std::make_shared<ASTExpressionList>();
function->name = "or";
function->arguments = exp_list;
function->children.push_back(exp_list);
exp_list->children = std::move(arguments);
return function;
}
using ConditionIndex = RowPolicy::ConditionIndex;
static constexpr size_t MAX_CONDITION_INDEX = RowPolicy::MAX_CONDITION_INDEX;
/// Accumulates conditions from multiple row policies and joins them using the AND logical operation.
class ConditionsMixer
{
public:
void add(const ASTPtr & condition, bool is_restrictive)
{
if (!condition)
return;
if (is_restrictive)
restrictions.push_back(condition);
else
permissions.push_back(condition);
}
ASTPtr getResult() &&
{
/// Process permissive conditions.
if (!permissions.empty())
restrictions.push_back(applyFunctionOR(std::move(permissions)));
/// Process restrictive conditions.
if (!restrictions.empty())
return applyFunctionAND(std::move(restrictions));
return nullptr;
}
private:
ASTs permissions;
ASTs restrictions;
};
}
void RowPolicyContextFactory::PolicyInfo::setPolicy(const RowPolicyPtr & policy_)
{
policy = policy_;
boost::range::copy(policy->roles, std::inserter(roles, roles.end()));
all_roles = policy->all_roles;
boost::range::copy(policy->except_roles, std::inserter(except_roles, except_roles.end()));
for (auto index : ext::range_with_static_cast<ConditionIndex>(0, MAX_CONDITION_INDEX))
{
const String & condition = policy->conditions[index];
auto previous_range = std::pair(std::begin(policy->conditions), std::begin(policy->conditions) + index);
auto previous_it = std::find(previous_range.first, previous_range.second, condition);
if (previous_it != previous_range.second)
{
/// The condition is already parsed before.
parsed_conditions[index] = parsed_conditions[previous_it - previous_range.first];
}
else
{
/// Try to parse the condition.
try
{
ParserExpression parser;
parsed_conditions[index] = parseQuery(parser, condition, 0);
}
catch (...)
{
tryLogCurrentException(
&Poco::Logger::get("RowPolicy"),
String("Could not parse the condition ") + RowPolicy::conditionIndexToString(index) + " of row policy "
+ backQuote(policy->getFullName()));
}
}
}
}
bool RowPolicyContextFactory::PolicyInfo::canUseWithContext(const RowPolicyContext & context) const
{
if (roles.count(context.user_name))
return true;
if (all_roles && !except_roles.count(context.user_name))
return true;
return false;
}
RowPolicyContextFactory::RowPolicyContextFactory(const AccessControlManager & access_control_manager_)
: access_control_manager(access_control_manager_)
{
}
RowPolicyContextFactory::~RowPolicyContextFactory() = default;
RowPolicyContextPtr RowPolicyContextFactory::createContext(const String & user_name)
{
std::lock_guard lock{mutex};
ensureAllRowPoliciesRead();
auto context = ext::shared_ptr_helper<RowPolicyContext>::create(user_name);
contexts.push_back(context);
mixConditionsForContext(*context);
return context;
}
void RowPolicyContextFactory::ensureAllRowPoliciesRead()
{
/// `mutex` is already locked.
if (all_policies_read)
return;
all_policies_read = true;
subscription = access_control_manager.subscribeForChanges<RowPolicy>(
[&](const UUID & id, const AccessEntityPtr & entity)
{
if (entity)
rowPolicyAddedOrChanged(id, typeid_cast<RowPolicyPtr>(entity));
else
rowPolicyRemoved(id);
});
for (const UUID & id : access_control_manager.findAll<RowPolicy>())
{
auto quota = access_control_manager.tryRead<RowPolicy>(id);
if (quota)
all_policies.emplace(id, PolicyInfo(quota));
}
}
void RowPolicyContextFactory::rowPolicyAddedOrChanged(const UUID & policy_id, const RowPolicyPtr & new_policy)
{
std::lock_guard lock{mutex};
auto it = all_policies.find(policy_id);
if (it == all_policies.end())
{
it = all_policies.emplace(policy_id, PolicyInfo(new_policy)).first;
}
else
{
if (it->second.policy == new_policy)
return;
}
auto & info = it->second;
info.setPolicy(new_policy);
mixConditionsForAllContexts();
}
void RowPolicyContextFactory::rowPolicyRemoved(const UUID & policy_id)
{
std::lock_guard lock{mutex};
all_policies.erase(policy_id);
mixConditionsForAllContexts();
}
void RowPolicyContextFactory::mixConditionsForAllContexts()
{
/// `mutex` is already locked.
boost::range::remove_erase_if(
contexts,
[&](const std::weak_ptr<RowPolicyContext> & weak)
{
auto context = weak.lock();
if (!context)
return true; // remove from the `contexts` list.
mixConditionsForContext(*context);
return false; // keep in the `contexts` list.
});
}
void RowPolicyContextFactory::mixConditionsForContext(RowPolicyContext & context)
{
/// `mutex` is already locked.
struct Mixers
{
ConditionsMixer mixers[MAX_CONDITION_INDEX];
std::vector<UUID> policy_ids;
};
using MapOfMixedConditions = RowPolicyContext::MapOfMixedConditions;
using DatabaseAndTableName = RowPolicyContext::DatabaseAndTableName;
using DatabaseAndTableNameRef = RowPolicyContext::DatabaseAndTableNameRef;
using Hash = RowPolicyContext::Hash;
std::unordered_map<DatabaseAndTableName, Mixers, Hash> map_of_mixers;
for (const auto & [policy_id, info] : all_policies)
{
if (info.canUseWithContext(context))
{
const auto & policy = *info.policy;
auto & mixers = map_of_mixers[std::pair{policy.getDatabase(), policy.getTableName()}];
mixers.policy_ids.push_back(policy_id);
for (auto index : ext::range(0, MAX_CONDITION_INDEX))
mixers.mixers[index].add(info.parsed_conditions[index], policy.isRestrictive());
}
}
auto map_of_mixed_conditions = std::make_shared<MapOfMixedConditions>();
for (auto & [database_and_table_name, mixers] : map_of_mixers)
{
auto database_and_table_name_keeper = std::make_unique<DatabaseAndTableName>();
database_and_table_name_keeper->first = database_and_table_name.first;
database_and_table_name_keeper->second = database_and_table_name.second;
auto & mixed_conditions = (*map_of_mixed_conditions)[DatabaseAndTableNameRef{database_and_table_name_keeper->first,
database_and_table_name_keeper->second}];
mixed_conditions.database_and_table_name_keeper = std::move(database_and_table_name_keeper);
mixed_conditions.policy_ids = std::move(mixers.policy_ids);
for (auto index : ext::range(0, MAX_CONDITION_INDEX))
mixed_conditions.mixed_conditions[index] = std::move(mixers.mixers[index]).getResult();
}
std::atomic_store(&context.atomic_map_of_mixed_conditions, std::shared_ptr<const MapOfMixedConditions>{map_of_mixed_conditions});
}
}

View File

@ -0,0 +1,54 @@
#pragma once
#include <Access/RowPolicyContext.h>
#include <Access/IAccessStorage.h>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
namespace DB
{
class AccessControlManager;
/// Stores read and parsed row policies.
class RowPolicyContextFactory
{
public:
RowPolicyContextFactory(const AccessControlManager & access_control_manager_);
~RowPolicyContextFactory();
RowPolicyContextPtr createContext(const String & user_name);
private:
using ParsedConditions = RowPolicyContext::ParsedConditions;
struct PolicyInfo
{
PolicyInfo(const RowPolicyPtr & policy_) { setPolicy(policy_); }
void setPolicy(const RowPolicyPtr & policy_);
bool canUseWithContext(const RowPolicyContext & context) const;
RowPolicyPtr policy;
std::unordered_set<String> roles;
bool all_roles = false;
std::unordered_set<String> except_roles;
ParsedConditions parsed_conditions;
};
void ensureAllRowPoliciesRead();
void rowPolicyAddedOrChanged(const UUID & policy_id, const RowPolicyPtr & new_policy);
void rowPolicyRemoved(const UUID & policy_id);
void mixConditionsForAllContexts();
void mixConditionsForContext(RowPolicyContext & context);
const AccessControlManager & access_control_manager;
std::unordered_map<UUID, PolicyInfo> all_policies;
bool all_policies_read = false;
IAccessStorage::SubscriptionPtr subscription;
std::vector<std::weak_ptr<RowPolicyContext>> contexts;
std::mutex mutex;
};
}

View File

@ -1,5 +1,6 @@
#include <Access/UsersConfigAccessStorage.h>
#include <Access/Quota.h>
#include <Access/RowPolicy.h>
#include <Common/StringUtils/StringUtils.h>
#include <Common/quoteString.h>
#include <Poco/Util/AbstractConfiguration.h>
@ -15,6 +16,8 @@ namespace
{
if (type == typeid(Quota))
return 'Q';
if (type == typeid(RowPolicy))
return 'P';
return 0;
}
@ -112,6 +115,57 @@ namespace
}
return quotas;
}
std::vector<AccessEntityPtr> parseRowPolicies(const Poco::Util::AbstractConfiguration & config, Poco::Logger * log)
{
std::vector<AccessEntityPtr> policies;
Poco::Util::AbstractConfiguration::Keys user_names;
config.keys("users", user_names);
for (const String & user_name : user_names)
{
const String databases_config = "users." + user_name + ".databases";
if (config.has(databases_config))
{
Poco::Util::AbstractConfiguration::Keys databases;
config.keys(databases_config, databases);
/// Read tables within databases
for (const String & database : databases)
{
const String database_config = databases_config + "." + database;
Poco::Util::AbstractConfiguration::Keys table_names;
config.keys(database_config, table_names);
/// Read table properties
for (const String & table_name : table_names)
{
const auto filter_config = database_config + "." + table_name + ".filter";
if (config.has(filter_config))
{
try
{
auto policy = std::make_shared<RowPolicy>();
policy->setFullName(database, table_name, user_name);
policy->conditions[RowPolicy::SELECT_FILTER] = config.getString(filter_config);
policy->roles.push_back(user_name);
policies.push_back(policy);
}
catch (...)
{
tryLogCurrentException(
log,
"Could not parse row policy " + backQuote(user_name) + " on table " + backQuoteIfNeed(database) + "."
+ backQuoteIfNeed(table_name));
}
}
}
}
}
}
return policies;
}
}
@ -128,6 +182,8 @@ void UsersConfigAccessStorage::loadFromConfig(const Poco::Util::AbstractConfigur
std::vector<std::pair<UUID, AccessEntityPtr>> all_entities;
for (const auto & entity : parseQuotas(config, getLogger()))
all_entities.emplace_back(generateID(*entity), entity);
for (const auto & entity : parseRowPolicies(config, getLogger()))
all_entities.emplace_back(generateID(*entity), entity);
memory_storage.setAll(all_entities);
}

View File

@ -24,11 +24,16 @@ struct AggregateFunctionArgMinMaxData
ResultData result; // the argument at which the minimum/maximum value is reached.
ValueData value; // value for which the minimum/maximum is calculated.
static bool allocatesMemoryInArena()
{
return ResultData::allocatesMemoryInArena() || ValueData::allocatesMemoryInArena();
}
};
/// Returns the first arg value found for the minimum/maximum value. Example: argMax(arg, value).
template <typename Data, bool AllocatesMemoryInArena>
class AggregateFunctionArgMinMax final : public IAggregateFunctionDataHelper<Data, AggregateFunctionArgMinMax<Data, AllocatesMemoryInArena>>
template <typename Data>
class AggregateFunctionArgMinMax final : public IAggregateFunctionDataHelper<Data, AggregateFunctionArgMinMax<Data>>
{
private:
const DataTypePtr & type_res;
@ -36,7 +41,7 @@ private:
public:
AggregateFunctionArgMinMax(const DataTypePtr & type_res_, const DataTypePtr & type_val_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionArgMinMax<Data, AllocatesMemoryInArena>>({type_res_, type_val_}, {}),
: IAggregateFunctionDataHelper<Data, AggregateFunctionArgMinMax<Data>>({type_res_, type_val_}, {}),
type_res(this->argument_types[0]), type_val(this->argument_types[1])
{
if (!type_val->isComparable())
@ -77,7 +82,7 @@ public:
bool allocatesMemoryInArena() const override
{
return AllocatesMemoryInArena;
return Data::allocatesMemoryInArena();
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override

View File

@ -166,6 +166,11 @@ public:
{
return has() && assert_cast<const ColVecType &>(column).getData()[row_num] == value;
}
static bool allocatesMemoryInArena()
{
return false;
}
};
@ -384,6 +389,11 @@ public:
{
return has() && assert_cast<const ColumnString &>(column).getDataAtWithTerminatingZero(row_num) == getStringRef();
}
static bool allocatesMemoryInArena()
{
return true;
}
};
static_assert(
@ -555,6 +565,11 @@ public:
{
return has() && to.value == value;
}
static bool allocatesMemoryInArena()
{
return false;
}
};
@ -675,15 +690,15 @@ struct AggregateFunctionAnyHeavyData : Data
};
template <typename Data, bool use_arena>
class AggregateFunctionsSingleValue final : public IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data, use_arena>>
template <typename Data>
class AggregateFunctionsSingleValue final : public IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data>>
{
private:
DataTypePtr & type;
public:
AggregateFunctionsSingleValue(const DataTypePtr & type_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data, use_arena>>({type_}, {})
: IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data>>({type_}, {})
, type(this->argument_types[0])
{
if (StringRef(Data::name()) == StringRef("min")
@ -724,7 +739,7 @@ public:
bool allocatesMemoryInArena() const override
{
return use_arena;
return Data::allocatesMemoryInArena();
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override

View File

@ -13,8 +13,8 @@
namespace DB
{
/// min, max, any, anyLast
template <template <typename, bool> class AggregateFunctionTemplate, template <typename> class Data>
/// min, max, any, anyLast, anyHeavy, etc...
template <template <typename> class AggregateFunctionTemplate, template <typename> class Data>
static IAggregateFunction * createAggregateFunctionSingleValue(const String & name, const DataTypes & argument_types, const Array & parameters)
{
assertNoParameters(name, parameters);
@ -24,26 +24,26 @@ static IAggregateFunction * createAggregateFunctionSingleValue(const String & na
WhichDataType which(argument_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<TYPE>>, false>(argument_type);
if (which.idx == TypeIndex::TYPE) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<TYPE>>>(argument_type);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Date)
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<DataTypeDate::FieldType>>, false>(argument_type);
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<DataTypeDate::FieldType>>>(argument_type);
if (which.idx == TypeIndex::DateTime)
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<DataTypeDateTime::FieldType>>, false>(argument_type);
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<DataTypeDateTime::FieldType>>>(argument_type);
if (which.idx == TypeIndex::DateTime64)
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<DateTime64>>, false>(argument_type);
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<DateTime64>>>(argument_type);
if (which.idx == TypeIndex::Decimal32)
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Decimal32>>, false>(argument_type);
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Decimal32>>>(argument_type);
if (which.idx == TypeIndex::Decimal64)
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Decimal64>>, false>(argument_type);
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Decimal64>>>(argument_type);
if (which.idx == TypeIndex::Decimal128)
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Decimal128>>, false>(argument_type);
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Decimal128>>>(argument_type);
if (which.idx == TypeIndex::String)
return new AggregateFunctionTemplate<Data<SingleValueDataString>, true>(argument_type);
return new AggregateFunctionTemplate<Data<SingleValueDataString>>(argument_type);
return new AggregateFunctionTemplate<Data<SingleValueDataGeneric>, false>(argument_type);
return new AggregateFunctionTemplate<Data<SingleValueDataGeneric>>(argument_type);
}
@ -52,28 +52,29 @@ template <template <typename> class MinMaxData, typename ResData>
static IAggregateFunction * createAggregateFunctionArgMinMaxSecond(const DataTypePtr & res_type, const DataTypePtr & val_type)
{
WhichDataType which(val_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<TYPE>>>, false>(res_type, val_type);
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<TYPE>>>>(res_type, val_type); \
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Date)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<DataTypeDate::FieldType>>>, false>(res_type, val_type);
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<DataTypeDate::FieldType>>>>(res_type, val_type);
if (which.idx == TypeIndex::DateTime)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<DataTypeDateTime::FieldType>>>, false>(res_type, val_type);
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<DataTypeDateTime::FieldType>>>>(res_type, val_type);
if (which.idx == TypeIndex::DateTime64)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<DateTime64>>>, false>(res_type, val_type);
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<DateTime64>>>>(res_type, val_type);
if (which.idx == TypeIndex::Decimal32)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<Decimal32>>>, false>(res_type, val_type);
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<Decimal32>>>>(res_type, val_type);
if (which.idx == TypeIndex::Decimal64)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<Decimal64>>>, false>(res_type, val_type);
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<Decimal64>>>>(res_type, val_type);
if (which.idx == TypeIndex::Decimal128)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<Decimal128>>>, false>(res_type, val_type);
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<Decimal128>>>>(res_type, val_type);
if (which.idx == TypeIndex::String)
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataString>>, true>(res_type, val_type);
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataString>>>(res_type, val_type);
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataGeneric>>, false>(res_type, val_type);
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataGeneric>>>(res_type, val_type);
}
template <template <typename> class MinMaxData>

View File

@ -339,7 +339,7 @@ bool DataTypeAggregateFunction::equals(const IDataType & rhs) const
}
static DataTypePtr create(const ASTPtr & arguments)
static DataTypePtr create(const String & /*type_name*/, const ASTPtr & arguments)
{
String function_name;
AggregateFunctionPtr function;

View File

@ -508,7 +508,7 @@ size_t DataTypeArray::getNumberOfDimensions() const
}
static DataTypePtr create(const ASTPtr & arguments)
static DataTypePtr create(const String & /*type_name*/, const ASTPtr & arguments)
{
if (!arguments || arguments->children.size() != 1)
throw Exception("Array data type family must have exactly one argument - type of elements", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

View File

@ -102,16 +102,16 @@ public:
void registerDataTypeDomainIPv4AndIPv6(DataTypeFactory & factory)
{
factory.registerSimpleDataTypeCustom("IPv4", []
factory.registerSimpleDataTypeCustom("IPv4", [&](const String & /*type_name*/)
{
return std::make_pair(DataTypeFactory::instance().get("UInt32"),
std::make_unique<DataTypeCustomDesc>(std::make_unique<DataTypeCustomFixedName>("IPv4"), std::make_unique<DataTypeCustomIPv4Serialization>()));
});
factory.registerSimpleDataTypeCustom("IPv6", []
factory.registerSimpleDataTypeCustom("IPv6", [&](const String & /*type_name*/)
{
return std::make_pair(DataTypeFactory::instance().get("FixedString(16)"),
std::make_unique<DataTypeCustomDesc>(std::make_unique<DataTypeCustomFixedName>("IPv6"), std::make_unique<DataTypeCustomIPv6Serialization>()));
std::make_unique<DataTypeCustomDesc>(std::make_unique<DataTypeCustomFixedName>("IPv6"), std::make_unique<DataTypeCustomIPv6Serialization>()));
});
}

View File

@ -58,7 +58,7 @@ String DataTypeCustomSimpleAggregateFunction::getName() const
}
static std::pair<DataTypePtr, DataTypeCustomDescPtr> create(const ASTPtr & arguments)
static std::pair<DataTypePtr, DataTypeCustomDescPtr> create(const String & /*type_name*/, const ASTPtr & arguments)
{
String function_name;
AggregateFunctionPtr function;

View File

@ -113,7 +113,12 @@ bool DataTypeDate::equals(const IDataType & rhs) const
void registerDataTypeDate(DataTypeFactory & factory)
{
factory.registerSimpleDataType("Date", [] { return DataTypePtr(std::make_shared<DataTypeDate>()); }, DataTypeFactory::CaseInsensitive);
const auto & creator = [&](const String & /*type_name*/)
{
return DataTypePtr(std::make_shared<DataTypeDate>());
};
factory.registerSimpleDataType("Date", creator, DataTypeFactory::CaseInsensitive);
}
}

View File

@ -43,8 +43,8 @@ TimezoneMixin::TimezoneMixin(const String & time_zone_name)
utc_time_zone(DateLUT::instance("UTC"))
{}
DataTypeDateTime::DataTypeDateTime(const String & time_zone_name)
: TimezoneMixin(time_zone_name)
DataTypeDateTime::DataTypeDateTime(const String & time_zone_name, const String & type_name_)
: TimezoneMixin(time_zone_name), type_name(type_name_)
{
}
@ -55,10 +55,10 @@ DataTypeDateTime::DataTypeDateTime(const TimezoneMixin & time_zone_)
String DataTypeDateTime::doGetName() const
{
if (!has_explicit_time_zone)
return "DateTime";
return type_name;
WriteBufferFromOwnString out;
out << "DateTime(" << quote << time_zone.getTimeZone() << ")";
out << type_name << "(" << quote << time_zone.getTimeZone() << ")";
return out.str();
}
@ -194,10 +194,10 @@ namespace ErrorCodes
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
static DataTypePtr create(const ASTPtr & arguments)
static DataTypePtr create(const String & type_name, const ASTPtr & arguments)
{
if (!arguments)
return std::make_shared<DataTypeDateTime>();
return std::make_shared<DataTypeDateTime>("", type_name);
if (arguments->children.size() != 1)
throw Exception("DateTime data type can optionally have only one argument - time zone name", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
@ -206,7 +206,7 @@ static DataTypePtr create(const ASTPtr & arguments)
if (!arg || arg->value.getType() != Field::Types::String)
throw Exception("Parameter for DateTime data type must be string literal", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<DataTypeDateTime>(arg->value.get<String>());
return std::make_shared<DataTypeDateTime>(arg->value.get<String>(), type_name);
}
void registerDataTypeDateTime(DataTypeFactory & factory)

View File

@ -49,7 +49,7 @@ protected:
class DataTypeDateTime final : public DataTypeNumberBase<UInt32>, public TimezoneMixin
{
public:
explicit DataTypeDateTime(const String & time_zone_name = "");
explicit DataTypeDateTime(const String & time_zone_name = "", const String & type_name_ = "DateTime");
explicit DataTypeDateTime(const TimezoneMixin & time_zone);
static constexpr auto family_name = "DateTime";
@ -75,6 +75,8 @@ public:
bool canBeInsideNullable() const override { return true; }
bool equals(const IDataType & rhs) const override;
private:
const String type_name;
};
}

View File

@ -233,7 +233,7 @@ getArgument(const ASTPtr & arguments, size_t argument_index, const char * argume
return argument->value.get<NearestResultType>();
}
static DataTypePtr create64(const ASTPtr & arguments)
static DataTypePtr create64(const String & /*type_name*/, const ASTPtr & arguments)
{
if (!arguments || arguments->size() == 0)
return std::make_shared<DataTypeDateTime64>(DataTypeDateTime64::default_scale);

View File

@ -195,7 +195,7 @@ const DecimalType<U> decimalResultType(const DataTypeNumber<T> &, const DecimalT
}
template <template <typename> typename DecimalType>
DataTypePtr createDecimal(UInt64 precision_value, UInt64 scale_value)
DataTypePtr createDecimal(UInt64 precision_value, UInt64 scale_value, const String & type_name = "Decimal", bool only_scale = false)
{
if (precision_value < DecimalUtils::minPrecision() || precision_value > DecimalUtils::maxPrecision<Decimal128>())
throw Exception("Wrong precision", ErrorCodes::ARGUMENT_OUT_OF_BOUND);
@ -204,10 +204,10 @@ DataTypePtr createDecimal(UInt64 precision_value, UInt64 scale_value)
throw Exception("Negative scales and scales larger than precision are not supported", ErrorCodes::ARGUMENT_OUT_OF_BOUND);
if (precision_value <= DecimalUtils::maxPrecision<Decimal32>())
return std::make_shared<DecimalType<Decimal32>>(precision_value, scale_value);
return std::make_shared<DecimalType<Decimal32>>(precision_value, scale_value, type_name, only_scale);
else if (precision_value <= DecimalUtils::maxPrecision<Decimal64>())
return std::make_shared<DecimalType<Decimal64>>(precision_value, scale_value);
return std::make_shared<DecimalType<Decimal128>>(precision_value, scale_value);
return std::make_shared<DecimalType<Decimal64>>(precision_value, scale_value, type_name, only_scale);
return std::make_shared<DecimalType<Decimal128>>(precision_value, scale_value, type_name, only_scale);
}
}

View File

@ -364,7 +364,7 @@ static void checkASTStructure(const ASTPtr & child)
}
template <typename DataTypeEnum>
static DataTypePtr createExact(const ASTPtr & arguments)
static DataTypePtr createExact(const String & /*type_name*/, const ASTPtr & arguments)
{
if (!arguments || arguments->children.empty())
throw Exception("Enum data type cannot be empty", ErrorCodes::EMPTY_DATA_PASSED);
@ -403,7 +403,7 @@ static DataTypePtr createExact(const ASTPtr & arguments)
return std::make_shared<DataTypeEnum>(values);
}
static DataTypePtr create(const ASTPtr & arguments)
static DataTypePtr create(const String & type_name, const ASTPtr & arguments)
{
if (!arguments || arguments->children.empty())
throw Exception("Enum data type cannot be empty", ErrorCodes::EMPTY_DATA_PASSED);
@ -424,10 +424,10 @@ static DataTypePtr create(const ASTPtr & arguments)
Int64 value = value_literal->value.get<Int64>();
if (value > std::numeric_limits<Int8>::max() || value < std::numeric_limits<Int8>::min())
return createExact<DataTypeEnum16>(arguments);
return createExact<DataTypeEnum16>(type_name, arguments);
}
return createExact<DataTypeEnum8>(arguments);
return createExact<DataTypeEnum8>(type_name, arguments);
}
void registerDataTypeEnum(DataTypeFactory & factory)

View File

@ -74,7 +74,7 @@ DataTypePtr DataTypeFactory::get(const String & family_name_param, const ASTPtr
return get("LowCardinality", low_cardinality_params);
}
return findCreatorByName(family_name)(parameters);
return findCreatorByName(family_name)(family_name_param, parameters);
}
@ -107,30 +107,30 @@ void DataTypeFactory::registerSimpleDataType(const String & name, SimpleCreator
throw Exception("DataTypeFactory: the data type " + name + " has been provided "
" a null constructor", ErrorCodes::LOGICAL_ERROR);
registerDataType(name, [name, creator](const ASTPtr & ast)
registerDataType(name, [name, creator](const String & type_name, const ASTPtr & ast)
{
if (ast)
throw Exception("Data type " + name + " cannot have arguments", ErrorCodes::DATA_TYPE_CANNOT_HAVE_ARGUMENTS);
return creator();
return creator(type_name);
}, case_sensitiveness);
}
void DataTypeFactory::registerDataTypeCustom(const String & family_name, CreatorWithCustom creator, CaseSensitiveness case_sensitiveness)
{
registerDataType(family_name, [creator](const ASTPtr & ast)
registerDataType(family_name, [creator](const String & type_name, const ASTPtr & ast)
{
auto res = creator(ast);
auto res = creator(type_name, ast);
res.first->setCustomization(std::move(res.second));
return res.first;
}, case_sensitiveness);
}
void DataTypeFactory::registerSimpleDataTypeCustom(const String &name, SimpleCreatorWithCustom creator, CaseSensitiveness case_sensitiveness)
void DataTypeFactory::registerSimpleDataTypeCustom(const String & name, SimpleCreatorWithCustom creator, CaseSensitiveness case_sensitiveness)
{
registerDataTypeCustom(name, [creator](const ASTPtr & /*ast*/)
registerDataTypeCustom(name, [creator](const String & type_name, const ASTPtr & /*ast*/)
{
return creator();
return creator(type_name);
}, case_sensitiveness);
}

View File

@ -16,16 +16,15 @@ namespace DB
class IDataType;
using DataTypePtr = std::shared_ptr<const IDataType>;
/** Creates a data type by name of data type family and parameters.
*/
class DataTypeFactory final : private boost::noncopyable, public IFactoryWithAliases<std::function<DataTypePtr(const ASTPtr & parameters)>>
class DataTypeFactory final : private boost::noncopyable, public IFactoryWithAliases<std::function<DataTypePtr(const String & type_name, const ASTPtr & parameters)>>
{
private:
using SimpleCreator = std::function<DataTypePtr()>;
using SimpleCreator = std::function<DataTypePtr(const String & type_name)>;
using DataTypesDictionary = std::unordered_map<String, Creator>;
using CreatorWithCustom = std::function<std::pair<DataTypePtr,DataTypeCustomDescPtr>(const ASTPtr & parameters)>;
using SimpleCreatorWithCustom = std::function<std::pair<DataTypePtr,DataTypeCustomDescPtr>()>;
using CreatorWithCustom = std::function<std::pair<DataTypePtr, DataTypeCustomDescPtr>(const String & type_name, const ASTPtr & parameters)>;
using SimpleCreatorWithCustom = std::function<std::pair<DataTypePtr, DataTypeCustomDescPtr>(const String & type_name)>;
public:
static DataTypeFactory & instance();

View File

@ -34,7 +34,7 @@ namespace ErrorCodes
std::string DataTypeFixedString::doGetName() const
{
return "FixedString(" + toString(n) + ")";
return type_name + "(" + toString(n) + ")";
}
@ -279,7 +279,7 @@ bool DataTypeFixedString::equals(const IDataType & rhs) const
}
static DataTypePtr create(const ASTPtr & arguments)
static DataTypePtr create(const String & type_name, const ASTPtr & arguments)
{
if (!arguments || arguments->children.size() != 1)
throw Exception("FixedString data type family must have exactly one argument - size in bytes", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
@ -288,7 +288,7 @@ static DataTypePtr create(const ASTPtr & arguments)
if (!argument || argument->value.getType() != Field::Types::UInt64 || argument->value.get<UInt64>() == 0)
throw Exception("FixedString data type family must have a number (positive integer) as its argument", ErrorCodes::UNEXPECTED_AST_STRUCTURE);
return std::make_shared<DataTypeFixedString>(argument->value.get<UInt64>());
return std::make_shared<DataTypeFixedString>(argument->value.get<UInt64>(), type_name);
}

View File

@ -22,7 +22,7 @@ private:
public:
static constexpr bool is_parametric = true;
DataTypeFixedString(size_t n_) : n(n_)
DataTypeFixedString(size_t n_, const String & type_name_ = "FixedString") : n(n_), type_name(type_name_)
{
if (n == 0)
throw Exception("FixedString size must be positive", ErrorCodes::ARGUMENT_OUT_OF_BOUND);
@ -85,6 +85,9 @@ public:
bool isCategorial() const override { return true; }
bool canBeInsideNullable() const override { return true; }
bool canBeInsideLowCardinality() const override { return true; }
private:
const String type_name;
};
}

View File

@ -10,17 +10,22 @@ bool DataTypeInterval::equals(const IDataType & rhs) const
return typeid(rhs) == typeid(*this) && kind == static_cast<const DataTypeInterval &>(rhs).kind;
}
template <IntervalKind::Kind kind>
static DataTypePtr create(const String & /*type_name*/)
{
return DataTypePtr(std::make_shared<DataTypeInterval>(kind));
}
void registerDataTypeInterval(DataTypeFactory & factory)
{
factory.registerSimpleDataType("IntervalSecond", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Second)); });
factory.registerSimpleDataType("IntervalMinute", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Minute)); });
factory.registerSimpleDataType("IntervalHour", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Hour)); });
factory.registerSimpleDataType("IntervalDay", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Day)); });
factory.registerSimpleDataType("IntervalWeek", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Week)); });
factory.registerSimpleDataType("IntervalMonth", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Month)); });
factory.registerSimpleDataType("IntervalQuarter", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Quarter)); });
factory.registerSimpleDataType("IntervalYear", [] { return DataTypePtr(std::make_shared<DataTypeInterval>(IntervalKind::Year)); });
factory.registerSimpleDataType("IntervalSecond", create<IntervalKind::Second>);
factory.registerSimpleDataType("IntervalMinute", create<IntervalKind::Minute>);
factory.registerSimpleDataType("IntervalHour", create<IntervalKind::Hour>);
factory.registerSimpleDataType("IntervalDay", create<IntervalKind::Day>);
factory.registerSimpleDataType("IntervalWeek", create<IntervalKind::Week>);
factory.registerSimpleDataType("IntervalMonth", create<IntervalKind::Month>);
factory.registerSimpleDataType("IntervalQuarter", create<IntervalKind::Quarter>);
factory.registerSimpleDataType("IntervalYear", create<IntervalKind::Year>);
}
}

View File

@ -949,7 +949,7 @@ bool DataTypeLowCardinality::equals(const IDataType & rhs) const
}
static DataTypePtr create(const ASTPtr & arguments)
static DataTypePtr create(const String & /*type_name*/, const ASTPtr & arguments)
{
if (!arguments || arguments->children.size() != 1)
throw Exception("LowCardinality data type family must have single argument - type of elements",

View File

@ -38,7 +38,9 @@ bool DataTypeNothing::equals(const IDataType & rhs) const
void registerDataTypeNothing(DataTypeFactory & factory)
{
factory.registerSimpleDataType("Nothing", [] { return DataTypePtr(std::make_shared<DataTypeNothing>()); });
const auto & creator = [&](const String & /*type_name*/) { return DataTypePtr(std::make_shared<DataTypeNothing>()); };
factory.registerSimpleDataType("Nothing", creator);
}
}

View File

@ -505,7 +505,7 @@ bool DataTypeNullable::equals(const IDataType & rhs) const
}
static DataTypePtr create(const ASTPtr & arguments)
static DataTypePtr create(const String & /*type_name*/, const ASTPtr & arguments)
{
if (!arguments || arguments->children.size() != 1)
throw Exception("Nullable data type family must have exactly one argument - nested type", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

View File

@ -369,7 +369,7 @@ bool DataTypeString::equals(const IDataType & rhs) const
void registerDataTypeString(DataTypeFactory & factory)
{
auto creator = static_cast<DataTypePtr(*)()>([] { return DataTypePtr(std::make_shared<DataTypeString>()); });
const auto & creator = [&] (const String & type_name) { return std::make_shared<DataTypeString>(type_name); };
factory.registerSimpleDataType("String", creator);

View File

@ -14,6 +14,10 @@ public:
using FieldType = String;
static constexpr bool is_parametric = false;
DataTypeString(const String & type_name_ = "String") : type_name(type_name_) {}
String doGetName() const override { return type_name; }
const char * getFamilyName() const override
{
return "String";
@ -63,6 +67,9 @@ public:
bool isCategorial() const override { return true; }
bool canBeInsideNullable() const override { return true; }
bool canBeInsideLowCardinality() const override { return true; }
private:
const String type_name;
};
}

View File

@ -529,7 +529,7 @@ size_t DataTypeTuple::getSizeOfValueInMemory() const
}
static DataTypePtr create(const ASTPtr & arguments)
static DataTypePtr create(const String & /*type_name*/, const ASTPtr & arguments)
{
if (!arguments || arguments->children.empty())
throw Exception("Tuple cannot be empty", ErrorCodes::EMPTY_DATA_PASSED);
@ -568,7 +568,7 @@ void registerDataTypeTuple(DataTypeFactory & factory)
void registerDataTypeNested(DataTypeFactory & factory)
{
/// Nested(...) data type is just a sugar for Array(Tuple(...))
factory.registerDataType("Nested", [&factory](const ASTPtr & arguments)
factory.registerDataType("Nested", [&factory](const String & /*type_name*/, const ASTPtr & arguments)
{
return std::make_shared<DataTypeArray>(factory.get("Tuple", arguments));
});

View File

@ -106,7 +106,9 @@ bool DataTypeUUID::equals(const IDataType & rhs) const
void registerDataTypeUUID(DataTypeFactory & factory)
{
factory.registerSimpleDataType("UUID", [] { return DataTypePtr(std::make_shared<DataTypeUUID>()); });
const auto & creator = [&] (const String & /*type_name*/) { return std::make_shared<DataTypeUUID>(); };
factory.registerSimpleDataType("UUID", creator);
}
}

View File

@ -14,6 +14,8 @@
#include <Parsers/IAST.h>
#include <type_traits>
#include "DataTypesDecimal.h"
namespace DB
{
@ -31,7 +33,12 @@ template <typename T>
std::string DataTypeDecimal<T>::doGetName() const
{
std::stringstream ss;
ss << "Decimal(" << this->precision << ", " << this->scale << ")";
ss << type_name << "(";
if (!only_scale)
ss << this->precision << ", ";
ss << this->scale << ")";
return ss.str();
}
@ -135,8 +142,14 @@ void DataTypeDecimal<T>::deserializeProtobuf(IColumn & column, ProtobufReader &
container.back() = decimal;
}
template<typename T>
DataTypeDecimal<T>::DataTypeDecimal(UInt32 precision_, UInt32 scale_, const String & type_name_, bool only_scale_)
: Base(precision_, scale_), type_name(type_name_), only_scale(only_scale_)
{
}
static DataTypePtr create(const ASTPtr & arguments)
static DataTypePtr create(const String & type_name, const ASTPtr & arguments)
{
if (!arguments || arguments->children.size() != 2)
throw Exception("Decimal data type family must have exactly two arguments: precision and scale",
@ -152,11 +165,11 @@ static DataTypePtr create(const ASTPtr & arguments)
UInt64 precision_value = precision->value.get<UInt64>();
UInt64 scale_value = scale->value.get<UInt64>();
return createDecimal<DataTypeDecimal>(precision_value, scale_value);
return createDecimal<DataTypeDecimal>(precision_value, scale_value, type_name);
}
template <typename T>
static DataTypePtr createExact(const ASTPtr & arguments)
static DataTypePtr createExact(const String & type_name, const ASTPtr & arguments)
{
if (!arguments || arguments->children.size() != 1)
throw Exception("Decimal data type family must have exactly two arguments: precision and scale",
@ -170,7 +183,7 @@ static DataTypePtr createExact(const ASTPtr & arguments)
UInt64 precision = DecimalUtils::maxPrecision<T>();
UInt64 scale = scale_arg->value.get<UInt64>();
return createDecimal<DataTypeDecimal>(precision, scale);
return createDecimal<DataTypeDecimal>(precision, scale, type_name, true);
}
void registerDataTypeDecimal(DataTypeFactory & factory)

View File

@ -35,6 +35,8 @@ public:
using typename Base::ColumnType;
using Base::Base;
DataTypeDecimal(UInt32 precision_, UInt32 scale_, const String & type_name_ = "Decimal", bool only_scale_ = false);
static constexpr auto family_name = "Decimal";
const char * getFamilyName() const override { return family_name; }
@ -57,6 +59,12 @@ public:
static void readText(T & x, ReadBuffer & istr, UInt32 precision_, UInt32 scale_, bool csv = false);
static bool tryReadText(T & x, ReadBuffer & istr, UInt32 precision_, UInt32 scale_);
private:
/// The name of data type how the user specified it. A single data type may be referenced by various synonims.
const String type_name;
/// If the user specified it only with scale parameter but without precision.
bool only_scale = false;
};
template <typename T>

View File

@ -5,20 +5,26 @@
namespace DB
{
template <typename NumberType>
static DataTypePtr create(const String & type_name)
{
return DataTypePtr(std::make_shared<NumberType>(type_name));
}
void registerDataTypeNumbers(DataTypeFactory & factory)
{
factory.registerSimpleDataType("UInt8", [] { return DataTypePtr(std::make_shared<DataTypeUInt8>()); });
factory.registerSimpleDataType("UInt16", [] { return DataTypePtr(std::make_shared<DataTypeUInt16>()); });
factory.registerSimpleDataType("UInt32", [] { return DataTypePtr(std::make_shared<DataTypeUInt32>()); });
factory.registerSimpleDataType("UInt64", [] { return DataTypePtr(std::make_shared<DataTypeUInt64>()); });
factory.registerSimpleDataType("UInt8", create<DataTypeUInt8>);
factory.registerSimpleDataType("UInt16", create<DataTypeUInt16>);
factory.registerSimpleDataType("UInt32", create<DataTypeUInt32>);
factory.registerSimpleDataType("UInt64", create<DataTypeUInt64>);
factory.registerSimpleDataType("Int8", [] { return DataTypePtr(std::make_shared<DataTypeInt8>()); });
factory.registerSimpleDataType("Int16", [] { return DataTypePtr(std::make_shared<DataTypeInt16>()); });
factory.registerSimpleDataType("Int32", [] { return DataTypePtr(std::make_shared<DataTypeInt32>()); });
factory.registerSimpleDataType("Int64", [] { return DataTypePtr(std::make_shared<DataTypeInt64>()); });
factory.registerSimpleDataType("Int8", create<DataTypeInt8>);
factory.registerSimpleDataType("Int16", create<DataTypeInt16>);
factory.registerSimpleDataType("Int32", create<DataTypeInt32>);
factory.registerSimpleDataType("Int64", create<DataTypeInt64>);
factory.registerSimpleDataType("Float32", [] { return DataTypePtr(std::make_shared<DataTypeFloat32>()); });
factory.registerSimpleDataType("Float64", [] { return DataTypePtr(std::make_shared<DataTypeFloat64>()); });
factory.registerSimpleDataType("Float32", create<DataTypeFloat32>);
factory.registerSimpleDataType("Float64", create<DataTypeFloat64>);
/// These synonyms are added for compatibility.

View File

@ -25,6 +25,13 @@ class DataTypeNumber final : public DataTypeNumberBase<T>
using PromotedType = DataTypeNumber<NearestFieldType<T>>;
return std::make_shared<PromotedType>();
}
public:
DataTypeNumber(const String & type_name_ = TypeName<T>::get()) : type_name(type_name_) {}
String doGetName() const override { return type_name; }
private:
const String type_name;
};
using DataTypeUInt8 = DataTypeNumber<UInt8>;

View File

@ -534,6 +534,7 @@ struct WhichDataType
inline bool isDate(const DataTypePtr & data_type) { return WhichDataType(data_type).isDate(); }
inline bool isDateOrDateTime(const DataTypePtr & data_type) { return WhichDataType(data_type).isDateOrDateTime(); }
inline bool isDateTime(const DataTypePtr & data_type) { return WhichDataType(data_type).isDateTime(); }
inline bool isDateTime64(const DataTypePtr & data_type) { return WhichDataType(data_type).isDateTime64(); }
inline bool isEnum(const DataTypePtr & data_type) { return WhichDataType(data_type).isEnum(); }
inline bool isDecimal(const DataTypePtr & data_type) { return WhichDataType(data_type).isDecimal(); }

View File

@ -19,6 +19,7 @@ namespace ErrorCodes
extern const int ILLEGAL_COLUMN;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
const ColumnConst * checkAndGetColumnConstStringOrFixedString(const IColumn * column)
@ -124,9 +125,9 @@ namespace
void validateArgumentsImpl(const IFunction & func,
const ColumnsWithTypeAndName & arguments,
size_t argument_offset,
const FunctionArgumentTypeValidators & validators)
const FunctionArgumentDescriptors & descriptors)
{
for (size_t i = 0; i < validators.size(); ++i)
for (size_t i = 0; i < descriptors.size(); ++i)
{
const auto argument_index = i + argument_offset;
if (argument_index >= arguments.size())
@ -135,24 +136,36 @@ void validateArgumentsImpl(const IFunction & func,
}
const auto & arg = arguments[i + argument_offset];
const auto validator = validators[i];
if (!validator.validator_func(*arg.type))
throw Exception("Illegal type " + arg.type->getName() +
" of " + std::to_string(i) +
" argument of function " + func.getName() +
" expected " + validator.expected_type_description,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const auto descriptor = descriptors[i];
if (int errorCode = descriptor.isValid(arg.type, arg.column); errorCode != 0)
throw Exception("Illegal type of argument #" + std::to_string(i)
+ (descriptor.argument_name ? " '" + std::string(descriptor.argument_name) + "'" : String{})
+ " of function " + func.getName()
+ (descriptor.expected_type_description ? String(", expected ") + descriptor.expected_type_description : String{})
+ (arg.type ? ", got " + arg.type->getName() : String{}),
errorCode);
}
}
}
int FunctionArgumentDescriptor::isValid(const DataTypePtr & data_type, const ColumnPtr & column) const
{
if (type_validator_func && (data_type == nullptr || type_validator_func(*data_type) == false))
return ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT;
if (column_validator_func && (column == nullptr || column_validator_func(*column) == false))
return ErrorCodes::ILLEGAL_COLUMN;
return 0;
}
void validateFunctionArgumentTypes(const IFunction & func,
const ColumnsWithTypeAndName & arguments,
const FunctionArgumentTypeValidators & mandatory_args,
const FunctionArgumentTypeValidators & optional_args)
const FunctionArgumentDescriptors & mandatory_args,
const FunctionArgumentDescriptors & optional_args)
{
if (arguments.size() < mandatory_args.size())
if (arguments.size() < mandatory_args.size() || arguments.size() > mandatory_args.size() + optional_args.size())
{
auto joinArgumentTypes = [](const auto & args, const String sep = ", ") -> String
{
@ -160,8 +173,13 @@ void validateFunctionArgumentTypes(const IFunction & func,
for (const auto & a : args)
{
using A = std::decay_t<decltype(a)>;
if constexpr (std::is_same_v<A, FunctionArgumentTypeValidator>)
result += a.expected_type_description;
if constexpr (std::is_same_v<A, FunctionArgumentDescriptor>)
{
if (a.argument_name)
result += "'" + std::string(a.argument_name) + "' : ";
if (a.expected_type_description)
result += a.expected_type_description;
}
else if constexpr (std::is_same_v<A, ColumnWithTypeAndName>)
result += a.type->getName();
@ -174,10 +192,14 @@ void validateFunctionArgumentTypes(const IFunction & func,
return result;
};
throw Exception("Incorrect number of arguments of function " + func.getName()
+ " provided " + std::to_string(arguments.size()) + " (" + joinArgumentTypes(arguments) + ")"
+ " expected " + std::to_string(mandatory_args.size()) + (optional_args.size() ? " or " + std::to_string(mandatory_args.size() + optional_args.size()) : "")
+ " (" + joinArgumentTypes(mandatory_args) + (optional_args.size() ? ", [" + joinArgumentTypes(mandatory_args) + "]" : "") + ")",
throw Exception("Incorrect number of arguments for function " + func.getName()
+ " provided " + std::to_string(arguments.size())
+ (arguments.size() ? " (" + joinArgumentTypes(arguments) + ")" : String{})
+ ", expected " + std::to_string(mandatory_args.size())
+ (optional_args.size() ? " to " + std::to_string(mandatory_args.size() + optional_args.size()) : "")
+ " (" + joinArgumentTypes(mandatory_args)
+ (optional_args.size() ? ", [" + joinArgumentTypes(optional_args) + "]" : "")
+ ")",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
}

View File

@ -90,21 +90,46 @@ void validateArgumentType(const IFunction & func, const DataTypes & arguments,
size_t argument_index, bool (* validator_func)(const IDataType &),
const char * expected_type_description);
// Simple validator that is used in conjunction with validateFunctionArgumentTypes() to check if function arguments are as expected.
struct FunctionArgumentTypeValidator
/** Simple validator that is used in conjunction with validateFunctionArgumentTypes() to check if function arguments are as expected
*
* Also it is used to generate function description when arguments do not match expected ones.
* Any field can be null:
* `argument_name` - if not null, reported via type check errors.
* `expected_type_description` - if not null, reported via type check errors.
* `type_validator_func` - if not null, used to validate data type of function argument.
* `column_validator_func` - if not null, used to validate column of function argument.
*/
struct FunctionArgumentDescriptor
{
bool (* validator_func)(const IDataType &);
const char * argument_name;
bool (* type_validator_func)(const IDataType &);
bool (* column_validator_func)(const IColumn &);
const char * expected_type_description;
/** Validate argument type and column.
*
* Returns non-zero error code if:
* Validator != nullptr && (Value == nullptr || Validator(*Value) == false)
* For:
* Validator is either `type_validator_func` or `column_validator_func`
* Value is either `data_type` or `column` respectively.
* ILLEGAL_TYPE_OF_ARGUMENT if type validation fails
*
*/
int isValid(const DataTypePtr & data_type, const ColumnPtr & column) const;
};
using FunctionArgumentTypeValidators = std::vector<FunctionArgumentTypeValidator>;
using FunctionArgumentDescriptors = std::vector<FunctionArgumentDescriptor>;
/** Validate that function arguments match specification.
*
* Designed to simplify argument validation
* for functions with variable arguments (e.g. depending on result type or other trait).
* first, checks that mandatory args present and have valid type.
* second, checks optional arguents types, skipping ones that are missing.
* Designed to simplify argument validation for functions with variable arguments
* (e.g. depending on result type or other trait).
* First, checks that number of arguments is as expected (including optional arguments).
* Second, checks that mandatory args present and have valid type.
* Third, checks optional arguents types, skipping ones that are missing.
*
* Please note that if you have several optional arguments, like f([a, b, c]),
* only these calls are considered valid:
@ -113,11 +138,13 @@ using FunctionArgumentTypeValidators = std::vector<FunctionArgumentTypeValidator
* f(a, b, c)
*
* But NOT these: f(a, c), f(b, c)
* In other words you can't skip
* In other words you can't omit middle optional arguments (just like in regular C++).
*
* If any mandatory arg is missing, throw an exception, with explicit description of expected arguments.
*/
void validateFunctionArgumentTypes(const IFunction & func, const ColumnsWithTypeAndName & arguments, const FunctionArgumentTypeValidators & mandatory_args, const FunctionArgumentTypeValidators & optional_args = {});
void validateFunctionArgumentTypes(const IFunction & func, const ColumnsWithTypeAndName & arguments,
const FunctionArgumentDescriptors & mandatory_args,
const FunctionArgumentDescriptors & optional_args = {});
/// Checks if a list of array columns have equal offsets. Return a pair of nested columns and offsets if true, otherwise throw.
std::pair<std::vector<const IColumn *>, const ColumnArray::Offset *>

View File

@ -918,16 +918,25 @@ public:
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
FunctionArgumentTypeValidators mandatory_args = {{[](const auto &) {return true;}, "ANY TYPE"}};
FunctionArgumentTypeValidators optional_args;
FunctionArgumentDescriptors mandatory_args = {{"Value", nullptr, nullptr, nullptr}};
FunctionArgumentDescriptors optional_args;
if constexpr (to_decimal || to_datetime64)
{
mandatory_args.push_back(FunctionArgumentTypeValidator{&isNativeInteger, "Integer"}); // scale
mandatory_args.push_back({"scale", &isNativeInteger, &isColumnConst, "const Integer"});
}
else
// toString(DateTime or DateTime64, [timezone: String])
if ((std::is_same_v<Name, NameToString> && arguments.size() > 0 && (isDateTime64(arguments[0].type) || isDateTime(arguments[0].type)))
// toUnixTimestamp(value[, timezone : String])
|| std::is_same_v<Name, NameToUnixTimestamp>
// toDate(value[, timezone : String])
|| std::is_same_v<ToDataType, DataTypeDate> // TODO: shall we allow timestamp argument for toDate? DateTime knows nothing about timezones and this arument is ignored below.
// toDateTime(value[, timezone: String])
|| std::is_same_v<ToDataType, DataTypeDateTime>
// toDateTime64(value, scale : Integer[, timezone: String])
|| std::is_same_v<ToDataType, DataTypeDateTime64>)
{
optional_args.push_back(FunctionArgumentTypeValidator{&isString, "String"}); // timezone
optional_args.push_back({"timezone", &isString, &isColumnConst, "const String"});
}
validateFunctionArgumentTypes(*this, arguments, mandatory_args, optional_args);
@ -938,8 +947,8 @@ public:
}
else if constexpr (to_decimal)
{
if (!arguments[1].column)
throw Exception("Second argument for function " + getName() + " must be constant", ErrorCodes::ILLEGAL_COLUMN);
// if (!arguments[1].column)
// throw Exception("Second argument for function " + getName() + " must be constant", ErrorCodes::ILLEGAL_COLUMN);
UInt64 scale = extractToDecimalScale(arguments[1]);

View File

@ -0,0 +1,225 @@
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeUUID.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnTuple.h>
#include <Interpreters/Context.h>
#include <Access/RowPolicyContext.h>
#include <Access/AccessControlManager.h>
#include <ext/range.h>
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
/// The currentRowPolicies() function can be called with 0..2 arguments:
/// currentRowPolicies() returns array of tuples (database, table_name, row_policy_name) for all the row policies applied for the current user;
/// currentRowPolicies(table_name) is equivalent to currentRowPolicies(currentDatabase(), table_name);
/// currentRowPolicies(database, table_name) returns array of names of the row policies applied to a specific table and for the current user.
class FunctionCurrentRowPolicies : public IFunction
{
public:
static constexpr auto name = "currentRowPolicies";
static FunctionPtr create(const Context & context_) { return std::make_shared<FunctionCurrentRowPolicies>(context_); }
explicit FunctionCurrentRowPolicies(const Context & context_) : context(context_) {}
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 0; }
bool isVariadic() const override { return true; }
void checkNumberOfArgumentsIfVariadic(size_t number_of_arguments) const override
{
if (number_of_arguments > 2)
throw Exception("Number of arguments for function " + String(name) + " doesn't match: passed "
+ toString(number_of_arguments) + ", should be 0..2",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
}
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (arguments.empty())
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeTuple>(
DataTypes{std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>()}));
else
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeString>());
}
bool isDeterministic() const override { return false; }
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result_pos, size_t input_rows_count) override
{
if (arguments.empty())
{
auto database_column = ColumnString::create();
auto table_name_column = ColumnString::create();
auto policy_name_column = ColumnString::create();
for (const auto & policy_id : context.getRowPolicy()->getCurrentPolicyIDs())
{
const auto policy = context.getAccessControlManager().tryRead<RowPolicy>(policy_id);
if (policy)
{
const String database = policy->getDatabase();
const String table_name = policy->getTableName();
const String policy_name = policy->getName();
database_column->insertData(database.data(), database.length());
table_name_column->insertData(table_name.data(), table_name.length());
policy_name_column->insertData(policy_name.data(), policy_name.length());
}
}
auto offset_column = ColumnArray::ColumnOffsets::create();
offset_column->insertValue(policy_name_column->size());
block.getByPosition(result_pos).column = ColumnConst::create(
ColumnArray::create(
ColumnTuple::create(Columns{std::move(database_column), std::move(table_name_column), std::move(policy_name_column)}),
std::move(offset_column)),
input_rows_count);
return;
}
const IColumn * database_column = nullptr;
if (arguments.size() == 2)
{
const auto & database_column_with_type = block.getByPosition(arguments[0]);
if (!isStringOrFixedString(database_column_with_type.type))
throw Exception{"The first argument of function " + String(name)
+ " should be a string containing database name, illegal type: "
+ database_column_with_type.type->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
database_column = database_column_with_type.column.get();
}
const auto & table_name_column_with_type = block.getByPosition(arguments[arguments.size() - 1]);
if (!isStringOrFixedString(table_name_column_with_type.type))
throw Exception{"The" + String(database_column ? " last" : "") + " argument of function " + String(name)
+ " should be a string containing table name, illegal type: " + table_name_column_with_type.type->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
const IColumn * table_name_column = table_name_column_with_type.column.get();
auto policy_name_column = ColumnString::create();
auto offset_column = ColumnArray::ColumnOffsets::create();
for (const auto i : ext::range(0, input_rows_count))
{
String database = database_column ? database_column->getDataAt(i).toString() : context.getCurrentDatabase();
String table_name = table_name_column->getDataAt(i).toString();
for (const auto & policy_id : context.getRowPolicy()->getCurrentPolicyIDs(database, table_name))
{
const auto policy = context.getAccessControlManager().tryRead<RowPolicy>(policy_id);
if (policy)
{
const String policy_name = policy->getName();
policy_name_column->insertData(policy_name.data(), policy_name.length());
}
}
offset_column->insertValue(policy_name_column->size());
}
block.getByPosition(result_pos).column = ColumnArray::create(std::move(policy_name_column), std::move(offset_column));
}
private:
const Context & context;
};
/// The currentRowPolicyIDs() function can be called with 0..2 arguments:
/// currentRowPolicyIDs() returns array of IDs of all the row policies applied for the current user;
/// currentRowPolicyIDs(table_name) is equivalent to currentRowPolicyIDs(currentDatabase(), table_name);
/// currentRowPolicyIDs(database, table_name) returns array of IDs of the row policies applied to a specific table and for the current user.
class FunctionCurrentRowPolicyIDs : public IFunction
{
public:
static constexpr auto name = "currentRowPolicyIDs";
static FunctionPtr create(const Context & context_) { return std::make_shared<FunctionCurrentRowPolicyIDs>(context_); }
explicit FunctionCurrentRowPolicyIDs(const Context & context_) : context(context_) {}
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 0; }
bool isVariadic() const override { return true; }
void checkNumberOfArgumentsIfVariadic(size_t number_of_arguments) const override
{
if (number_of_arguments > 2)
throw Exception("Number of arguments for function " + String(name) + " doesn't match: passed "
+ toString(number_of_arguments) + ", should be 0..2",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
}
DataTypePtr getReturnTypeImpl(const DataTypes & /* arguments */) const override
{
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeUUID>());
}
bool isDeterministic() const override { return false; }
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result_pos, size_t input_rows_count) override
{
if (arguments.empty())
{
auto policy_id_column = ColumnVector<UInt128>::create();
for (const auto & policy_id : context.getRowPolicy()->getCurrentPolicyIDs())
policy_id_column->insertValue(policy_id);
auto offset_column = ColumnArray::ColumnOffsets::create();
offset_column->insertValue(policy_id_column->size());
block.getByPosition(result_pos).column
= ColumnConst::create(ColumnArray::create(std::move(policy_id_column), std::move(offset_column)), input_rows_count);
return;
}
const IColumn * database_column = nullptr;
if (arguments.size() == 2)
{
const auto & database_column_with_type = block.getByPosition(arguments[0]);
if (!isStringOrFixedString(database_column_with_type.type))
throw Exception{"The first argument of function " + String(name)
+ " should be a string containing database name, illegal type: "
+ database_column_with_type.type->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
database_column = database_column_with_type.column.get();
}
const auto & table_name_column_with_type = block.getByPosition(arguments[arguments.size() - 1]);
if (!isStringOrFixedString(table_name_column_with_type.type))
throw Exception{"The" + String(database_column ? " last" : "") + " argument of function " + String(name)
+ " should be a string containing table name, illegal type: " + table_name_column_with_type.type->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
const IColumn * table_name_column = table_name_column_with_type.column.get();
auto policy_id_column = ColumnVector<UInt128>::create();
auto offset_column = ColumnArray::ColumnOffsets::create();
for (const auto i : ext::range(0, input_rows_count))
{
String database = database_column ? database_column->getDataAt(i).toString() : context.getCurrentDatabase();
String table_name = table_name_column->getDataAt(i).toString();
for (const auto & policy_id : context.getRowPolicy()->getCurrentPolicyIDs(database, table_name))
policy_id_column->insertValue(policy_id);
offset_column->insertValue(policy_id_column->size());
}
block.getByPosition(result_pos).column = ColumnArray::create(std::move(policy_id_column), std::move(offset_column));
}
private:
const Context & context;
};
void registerFunctionCurrentRowPolicies(FunctionFactory & factory)
{
factory.registerFunction<FunctionCurrentRowPolicies>();
factory.registerFunction<FunctionCurrentRowPolicyIDs>();
}
}

View File

@ -9,6 +9,7 @@ class FunctionFactory;
void registerFunctionCurrentDatabase(FunctionFactory &);
void registerFunctionCurrentUser(FunctionFactory &);
void registerFunctionCurrentQuota(FunctionFactory &);
void registerFunctionCurrentRowPolicies(FunctionFactory &);
void registerFunctionHostName(FunctionFactory &);
void registerFunctionFQDN(FunctionFactory &);
void registerFunctionVisibleWidth(FunctionFactory &);

View File

@ -8,6 +8,7 @@ void registerFunctionsMiscellaneous(FunctionFactory & factory)
registerFunctionCurrentDatabase(factory);
registerFunctionCurrentUser(factory);
registerFunctionCurrentQuota(factory);
registerFunctionCurrentRowPolicies(factory);
registerFunctionHostName(factory);
registerFunctionFQDN(factory);
registerFunctionVisibleWidth(factory);

View File

@ -28,6 +28,7 @@
#include <Access/AccessControlManager.h>
#include <Access/SettingsConstraints.h>
#include <Access/QuotaContext.h>
#include <Access/RowPolicyContext.h>
#include <Interpreters/ExpressionJIT.h>
#include <Interpreters/UsersManager.h>
#include <Dictionaries/Embedded/GeoDictionariesLoader.h>
@ -333,6 +334,7 @@ Context Context::createGlobal()
{
Context res;
res.quota = std::make_shared<QuotaContext>();
res.row_policy = std::make_shared<RowPolicyContext>();
res.shared = std::make_shared<ContextShared>();
return res;
}
@ -625,6 +627,13 @@ void Context::checkQuotaManagementIsAllowed()
"User " + client_info.current_user + " doesn't have enough privileges to manage quotas", ErrorCodes::NOT_ENOUGH_PRIVILEGES);
}
void Context::checkRowPolicyManagementIsAllowed()
{
if (!is_row_policy_management_allowed)
throw Exception(
"User " + client_info.current_user + " doesn't have enough privileges to manage row policies", ErrorCodes::NOT_ENOUGH_PRIVILEGES);
}
void Context::setUsersConfig(const ConfigurationPtr & config)
{
auto lock = getLock();
@ -639,34 +648,6 @@ ConfigurationPtr Context::getUsersConfig()
return shared->users_config;
}
bool Context::hasUserProperty(const String & database, const String & table, const String & name) const
{
auto lock = getLock();
// No user - no properties.
if (client_info.current_user.empty())
return false;
const auto & props = shared->users_manager->getUser(client_info.current_user)->table_props;
auto db = props.find(database);
if (db == props.end())
return false;
auto table_props = db->second.find(table);
if (table_props == db->second.end())
return false;
return !!table_props->second.count(name);
}
const String & Context::getUserProperty(const String & database, const String & table, const String & name) const
{
auto lock = getLock();
const auto & props = shared->users_manager->getUser(client_info.current_user)->table_props;
return props.at(database).at(table).at(name);
}
void Context::calculateUserSettings()
{
auto lock = getLock();
@ -691,6 +672,8 @@ void Context::calculateUserSettings()
quota = getAccessControlManager().createQuotaContext(
client_info.current_user, client_info.current_address.host(), client_info.quota_key);
is_quota_management_allowed = user->is_quota_management_allowed;
row_policy = getAccessControlManager().getRowPolicyContext(client_info.current_user);
is_row_policy_management_allowed = user->is_row_policy_management_allowed;
}

View File

@ -45,6 +45,7 @@ namespace DB
struct ContextShared;
class Context;
class QuotaContext;
class RowPolicyContext;
class EmbeddedDictionaries;
class ExternalDictionariesLoader;
class ExternalModelsLoader;
@ -140,6 +141,8 @@ private:
std::shared_ptr<QuotaContext> quota; /// Current quota. By default - empty quota, that have no limits.
bool is_quota_management_allowed = false; /// Whether the current user is allowed to manage quotas via SQL commands.
std::shared_ptr<RowPolicyContext> row_policy;
bool is_row_policy_management_allowed = false; /// Whether the current user is allowed to manage row policies via SQL commands.
String current_database;
Settings settings; /// Setting for query execution.
std::shared_ptr<const SettingsConstraints> settings_constraints;
@ -210,6 +213,8 @@ public:
const AccessControlManager & getAccessControlManager() const;
std::shared_ptr<QuotaContext> getQuota() const { return quota; }
void checkQuotaManagementIsAllowed();
std::shared_ptr<RowPolicyContext> getRowPolicy() const { return row_policy; }
void checkRowPolicyManagementIsAllowed();
/** Take the list of users, quotas and configuration profiles from this config.
* The list of users is completely replaced.
@ -218,10 +223,6 @@ public:
void setUsersConfig(const ConfigurationPtr & config);
ConfigurationPtr getUsersConfig();
// User property is a key-value pair from the configuration entry: users.<username>.databases.<db_name>.<table_name>.<key_name>
bool hasUserProperty(const String & database, const String & table, const String & name) const;
const String & getUserProperty(const String & database, const String & table, const String & name) const;
/// Must be called before getClientInfo.
void setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address, const String & quota_key);

View File

@ -0,0 +1,93 @@
#include <Interpreters/InterpreterCreateRowPolicyQuery.h>
#include <Parsers/ASTCreateRowPolicyQuery.h>
#include <Parsers/ASTRoleList.h>
#include <Parsers/formatAST.h>
#include <Interpreters/Context.h>
#include <Access/AccessControlManager.h>
#include <boost/range/algorithm/sort.hpp>
namespace DB
{
BlockIO InterpreterCreateRowPolicyQuery::execute()
{
context.checkRowPolicyManagementIsAllowed();
const auto & query = query_ptr->as<const ASTCreateRowPolicyQuery &>();
auto & access_control = context.getAccessControlManager();
if (query.alter)
{
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
{
auto updated_policy = typeid_cast<std::shared_ptr<RowPolicy>>(entity->clone());
updateRowPolicyFromQuery(*updated_policy, query);
return updated_policy;
};
String full_name = query.name_parts.getFullName(context);
if (query.if_exists)
{
if (auto id = access_control.find<RowPolicy>(full_name))
access_control.tryUpdate(*id, update_func);
}
else
access_control.update(access_control.getID<RowPolicy>(full_name), update_func);
}
else
{
auto new_policy = std::make_shared<RowPolicy>();
updateRowPolicyFromQuery(*new_policy, query);
if (query.if_not_exists)
access_control.tryInsert(new_policy);
else if (query.or_replace)
access_control.insertOrReplace(new_policy);
else
access_control.insert(new_policy);
}
return {};
}
void InterpreterCreateRowPolicyQuery::updateRowPolicyFromQuery(RowPolicy & policy, const ASTCreateRowPolicyQuery & query)
{
if (query.alter)
{
if (!query.new_policy_name.empty())
policy.setName(query.new_policy_name);
}
else
{
policy.setDatabase(query.name_parts.database.empty() ? context.getCurrentDatabase() : query.name_parts.database);
policy.setTableName(query.name_parts.table_name);
policy.setName(query.name_parts.policy_name);
}
if (query.is_restrictive)
policy.setRestrictive(*query.is_restrictive);
for (const auto & [index, condition] : query.conditions)
policy.conditions[index] = condition ? serializeAST(*condition) : String{};
if (query.roles)
{
const auto & query_roles = *query.roles;
/// We keep `roles` sorted.
policy.roles = query_roles.roles;
if (query_roles.current_user)
policy.roles.push_back(context.getClientInfo().current_user);
boost::range::sort(policy.roles);
policy.roles.erase(std::unique(policy.roles.begin(), policy.roles.end()), policy.roles.end());
policy.all_roles = query_roles.all_roles;
/// We keep `except_roles` sorted.
policy.except_roles = query_roles.except_roles;
if (query_roles.except_current_user)
policy.except_roles.push_back(context.getClientInfo().current_user);
boost::range::sort(policy.except_roles);
policy.except_roles.erase(std::unique(policy.except_roles.begin(), policy.except_roles.end()), policy.except_roles.end());
}
}
}

View File

@ -0,0 +1,26 @@
#pragma once
#include <Interpreters/IInterpreter.h>
#include <Parsers/IAST_fwd.h>
namespace DB
{
class ASTCreateRowPolicyQuery;
struct RowPolicy;
class InterpreterCreateRowPolicyQuery : public IInterpreter
{
public:
InterpreterCreateRowPolicyQuery(const ASTPtr & query_ptr_, Context & context_) : query_ptr(query_ptr_), context(context_) {}
BlockIO execute() override;
private:
void updateRowPolicyFromQuery(RowPolicy & policy, const ASTCreateRowPolicyQuery & query);
ASTPtr query_ptr;
Context & context;
};
}

View File

@ -3,6 +3,8 @@
#include <Interpreters/Context.h>
#include <Access/AccessControlManager.h>
#include <Access/Quota.h>
#include <Access/RowPolicy.h>
#include <boost/range/algorithm/transform.hpp>
namespace DB
@ -24,6 +26,19 @@ BlockIO InterpreterDropAccessEntityQuery::execute()
access_control.remove(access_control.getIDs<Quota>(query.names));
return {};
}
case Kind::ROW_POLICY:
{
context.checkRowPolicyManagementIsAllowed();
Strings full_names;
boost::range::transform(
query.row_policies_names, std::back_inserter(full_names),
[this](const RowPolicy::FullNameParts & row_policy_name) { return row_policy_name.getFullName(context); });
if (query.if_exists)
access_control.tryRemove(access_control.find<RowPolicy>(full_names));
else
access_control.remove(access_control.getIDs<RowPolicy>(full_names));
return {};
}
}
__builtin_unreachable();

View File

@ -2,6 +2,7 @@
#include <Parsers/ASTCheckQuery.h>
#include <Parsers/ASTCreateQuery.h>
#include <Parsers/ASTCreateQuotaQuery.h>
#include <Parsers/ASTCreateRowPolicyQuery.h>
#include <Parsers/ASTDropAccessEntityQuery.h>
#include <Parsers/ASTDropQuery.h>
#include <Parsers/ASTInsertQuery.h>
@ -14,6 +15,7 @@
#include <Parsers/ASTShowCreateAccessEntityQuery.h>
#include <Parsers/ASTShowProcesslistQuery.h>
#include <Parsers/ASTShowQuotasQuery.h>
#include <Parsers/ASTShowRowPoliciesQuery.h>
#include <Parsers/ASTShowTablesQuery.h>
#include <Parsers/ASTUseQuery.h>
#include <Parsers/ASTExplainQuery.h>
@ -24,6 +26,7 @@
#include <Interpreters/InterpreterCheckQuery.h>
#include <Interpreters/InterpreterCreateQuery.h>
#include <Interpreters/InterpreterCreateQuotaQuery.h>
#include <Interpreters/InterpreterCreateRowPolicyQuery.h>
#include <Interpreters/InterpreterDescribeQuery.h>
#include <Interpreters/InterpreterExplainQuery.h>
#include <Interpreters/InterpreterDropAccessEntityQuery.h>
@ -41,6 +44,7 @@
#include <Interpreters/InterpreterShowCreateQuery.h>
#include <Interpreters/InterpreterShowProcesslistQuery.h>
#include <Interpreters/InterpreterShowQuotasQuery.h>
#include <Interpreters/InterpreterShowRowPoliciesQuery.h>
#include <Interpreters/InterpreterShowTablesQuery.h>
#include <Interpreters/InterpreterSystemQuery.h>
#include <Interpreters/InterpreterUseQuery.h>
@ -199,6 +203,10 @@ std::unique_ptr<IInterpreter> InterpreterFactory::get(ASTPtr & query, Context &
{
return std::make_unique<InterpreterCreateQuotaQuery>(query, context);
}
else if (query->as<ASTCreateRowPolicyQuery>())
{
return std::make_unique<InterpreterCreateRowPolicyQuery>(query, context);
}
else if (query->as<ASTDropAccessEntityQuery>())
{
return std::make_unique<InterpreterDropAccessEntityQuery>(query, context);
@ -211,6 +219,10 @@ std::unique_ptr<IInterpreter> InterpreterFactory::get(ASTPtr & query, Context &
{
return std::make_unique<InterpreterShowQuotasQuery>(query, context);
}
else if (query->as<ASTShowRowPoliciesQuery>())
{
return std::make_unique<InterpreterShowRowPoliciesQuery>(query, context);
}
else
throw Exception("Unknown type of query: " + query->getID(), ErrorCodes::UNKNOWN_TYPE_OF_QUERY);
}

View File

@ -38,6 +38,8 @@
#include <Parsers/ExpressionListParsers.h>
#include <Parsers/parseQuery.h>
#include <Access/RowPolicyContext.h>
#include <Interpreters/InterpreterSelectQuery.h>
#include <Interpreters/InterpreterSelectWithUnionQuery.h>
#include <Interpreters/InterpreterSetQuery.h>
@ -118,11 +120,10 @@ namespace
{
/// Assumes `storage` is set and the table filter (row-level security) is not empty.
String generateFilterActions(ExpressionActionsPtr & actions, const StoragePtr & storage, const Context & context, const Names & prerequisite_columns = {})
String generateFilterActions(ExpressionActionsPtr & actions, const Context & context, const StoragePtr & storage, const ASTPtr & row_policy_filter, const Names & prerequisite_columns = {})
{
const auto & db_name = storage->getDatabaseName();
const auto & table_name = storage->getTableName();
const auto & filter_str = context.getUserProperty(db_name, table_name, "filter");
/// TODO: implement some AST builders for this kind of stuff
ASTPtr query_ast = std::make_shared<ASTSelectQuery>();
@ -131,18 +132,15 @@ String generateFilterActions(ExpressionActionsPtr & actions, const StoragePtr &
select_ast->setExpression(ASTSelectQuery::Expression::SELECT, std::make_shared<ASTExpressionList>());
auto expr_list = select_ast->select();
auto parseExpression = [] (const String & expr)
{
ParserExpression expr_parser;
return parseQuery(expr_parser, expr, 0);
};
// The first column is our filter expression.
expr_list->children.push_back(parseExpression(filter_str));
expr_list->children.push_back(row_policy_filter);
/// Keep columns that are required after the filter actions.
for (const auto & column_str : prerequisite_columns)
expr_list->children.push_back(parseExpression(column_str));
{
ParserExpression expr_parser;
expr_list->children.push_back(parseQuery(expr_parser, column_str, 0));
}
select_ast->setExpression(ASTSelectQuery::Expression::TABLES, std::make_shared<ASTTablesInSelectQuery>());
auto tables = select_ast->tables();
@ -378,10 +376,11 @@ InterpreterSelectQuery::InterpreterSelectQuery(
source_header = storage->getSampleBlockForColumns(required_columns);
/// Fix source_header for filter actions.
if (context->hasUserProperty(storage->getDatabaseName(), storage->getTableName(), "filter"))
auto row_policy_filter = context->getRowPolicy()->getCondition(storage->getDatabaseName(), storage->getTableName(), RowPolicy::SELECT_FILTER);
if (row_policy_filter)
{
filter_info = std::make_shared<FilterInfo>();
filter_info->column_name = generateFilterActions(filter_info->actions, storage, *context, required_columns);
filter_info->column_name = generateFilterActions(filter_info->actions, *context, storage, row_policy_filter, required_columns);
source_header = storage->getSampleBlockForColumns(filter_info->actions->getRequiredColumns());
}
}
@ -502,7 +501,7 @@ Block InterpreterSelectQuery::getSampleBlockImpl()
/// PREWHERE optimization.
/// Turn off, if the table filter (row-level security) is applied.
if (storage && !context->hasUserProperty(storage->getDatabaseName(), storage->getTableName(), "filter"))
if (storage && !context->getRowPolicy()->getCondition(storage->getDatabaseName(), storage->getTableName(), RowPolicy::SELECT_FILTER))
{
query_analyzer->makeSetsForIndex(query.where());
query_analyzer->makeSetsForIndex(query.prewhere());
@ -1363,11 +1362,12 @@ void InterpreterSelectQuery::executeFetchColumns(
if (storage)
{
/// Append columns from the table filter to required
if (context->hasUserProperty(storage->getDatabaseName(), storage->getTableName(), "filter"))
auto row_policy_filter = context->getRowPolicy()->getCondition(storage->getDatabaseName(), storage->getTableName(), RowPolicy::SELECT_FILTER);
if (row_policy_filter)
{
auto initial_required_columns = required_columns;
ExpressionActionsPtr actions;
generateFilterActions(actions, storage, *context, initial_required_columns);
generateFilterActions(actions, *context, storage, row_policy_filter, initial_required_columns);
auto required_columns_from_filter = actions->getRequiredColumns();
for (const auto & column : required_columns_from_filter)

View File

@ -1,9 +1,12 @@
#include <Interpreters/InterpreterShowCreateAccessEntityQuery.h>
#include <Interpreters/Context.h>
#include <Parsers/ASTCreateQuotaQuery.h>
#include <Parsers/ASTCreateRowPolicyQuery.h>
#include <Parsers/ASTShowCreateAccessEntityQuery.h>
#include <Parsers/ASTRoleList.h>
#include <Parsers/ExpressionListParsers.h>
#include <Parsers/formatAST.h>
#include <Parsers/parseQuery.h>
#include <Access/AccessControlManager.h>
#include <Access/QuotaContext.h>
#include <Columns/ColumnString.h>
@ -28,7 +31,7 @@ BlockInputStreamPtr InterpreterShowCreateAccessEntityQuery::executeImpl()
const auto & show_query = query_ptr->as<ASTShowCreateAccessEntityQuery &>();
/// Build a create query.
ASTPtr create_query = getCreateQuotaQuery(show_query);
ASTPtr create_query = getCreateQuery(show_query);
/// Build the result column.
std::stringstream create_query_ss;
@ -49,6 +52,18 @@ BlockInputStreamPtr InterpreterShowCreateAccessEntityQuery::executeImpl()
}
ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuery(const ASTShowCreateAccessEntityQuery & show_query) const
{
using Kind = ASTShowCreateAccessEntityQuery::Kind;
switch (show_query.kind)
{
case Kind::QUOTA: return getCreateQuotaQuery(show_query);
case Kind::ROW_POLICY: return getCreateRowPolicyQuery(show_query);
}
__builtin_unreachable();
}
ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuotaQuery(const ASTShowCreateAccessEntityQuery & show_query) const
{
auto & access_control = context.getAccessControlManager();
@ -86,4 +101,38 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuotaQuery(const ASTShow
return create_query;
}
ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateRowPolicyQuery(const ASTShowCreateAccessEntityQuery & show_query) const
{
auto & access_control = context.getAccessControlManager();
RowPolicyPtr policy = access_control.read<RowPolicy>(show_query.row_policy_name.getFullName(context));
auto create_query = std::make_shared<ASTCreateRowPolicyQuery>();
create_query->name_parts = RowPolicy::FullNameParts{policy->getDatabase(), policy->getTableName(), policy->getName()};
if (policy->isRestrictive())
create_query->is_restrictive = policy->isRestrictive();
for (auto index : ext::range_with_static_cast<RowPolicy::ConditionIndex>(RowPolicy::MAX_CONDITION_INDEX))
{
const auto & condition = policy->conditions[index];
if (!condition.empty())
{
ParserExpression parser;
ASTPtr expr = parseQuery(parser, condition, 0);
create_query->conditions.push_back(std::pair{index, expr});
}
}
if (!policy->roles.empty() || policy->all_roles)
{
auto create_query_roles = std::make_shared<ASTRoleList>();
create_query_roles->roles = policy->roles;
create_query_roles->all_roles = policy->all_roles;
create_query_roles->except_roles = policy->except_roles;
create_query->roles = std::move(create_query_roles);
}
return create_query;
}
}

View File

@ -28,7 +28,9 @@ private:
const Context & context;
BlockInputStreamPtr executeImpl();
ASTPtr getCreateQuery(const ASTShowCreateAccessEntityQuery & show_query) const;
ASTPtr getCreateQuotaQuery(const ASTShowCreateAccessEntityQuery & show_query) const;
ASTPtr getCreateRowPolicyQuery(const ASTShowCreateAccessEntityQuery & show_query) const;
};

View File

@ -0,0 +1,68 @@
#include <Interpreters/InterpreterShowRowPoliciesQuery.h>
#include <Parsers/ASTShowRowPoliciesQuery.h>
#include <Parsers/formatAST.h>
#include <Interpreters/executeQuery.h>
#include <Common/quoteString.h>
#include <Interpreters/Context.h>
#include <ext/range.h>
namespace DB
{
InterpreterShowRowPoliciesQuery::InterpreterShowRowPoliciesQuery(const ASTPtr & query_ptr_, Context & context_)
: query_ptr(query_ptr_), context(context_)
{
}
BlockIO InterpreterShowRowPoliciesQuery::execute()
{
return executeQuery(getRewrittenQuery(), context, true);
}
String InterpreterShowRowPoliciesQuery::getRewrittenQuery() const
{
const auto & query = query_ptr->as<ASTShowRowPoliciesQuery &>();
const String & table_name = query.table_name;
String database;
if (!table_name.empty())
{
database = query.database;
if (database.empty())
database = context.getCurrentDatabase();
}
String filter;
if (query.current)
{
if (table_name.empty())
filter = "has(currentRowPolicyIDs(), id)";
else
filter = "has(currentRowPolicyIDs(" + quoteString(database) + ", " + quoteString(table_name) + "), id)";
}
else
{
if (!table_name.empty())
filter = "database = " + quoteString(database) + " AND table = " + quoteString(table_name);
}
String expr = table_name.empty() ? "full_name" : "name";
return "SELECT " + expr + " AS " + backQuote(getResultDescription()) + " from system.row_policies"
+ (filter.empty() ? "" : " WHERE " + filter) + " ORDER BY " + expr;
}
String InterpreterShowRowPoliciesQuery::getResultDescription() const
{
std::stringstream ss;
formatAST(*query_ptr, ss, false, true);
String desc = ss.str();
String prefix = "SHOW ";
if (startsWith(desc, prefix))
desc = desc.substr(prefix.length()); /// `desc` always starts with "SHOW ", so we can trim this prefix.
return desc;
}
}

View File

@ -0,0 +1,25 @@
#pragma once
#include <Interpreters/IInterpreter.h>
#include <Parsers/IAST_fwd.h>
namespace DB
{
class Context;
class InterpreterShowRowPoliciesQuery : public IInterpreter
{
public:
InterpreterShowRowPoliciesQuery(const ASTPtr & query_ptr_, Context & context_);
BlockIO execute() override;
private:
String getRewrittenQuery() const;
String getResultDescription() const;
ASTPtr query_ptr;
Context & context;
};
}

View File

@ -13,6 +13,7 @@
#include <Interpreters/InterpreterDropQuery.h>
#include <Interpreters/InterpreterCreateQuery.h>
#include <Interpreters/QueryLog.h>
#include <Interpreters/DDLWorker.h>
#include <Interpreters/PartLog.h>
#include <Interpreters/QueryThreadLog.h>
#include <Interpreters/TraceLog.h>
@ -101,14 +102,14 @@ void startStopAction(Context & context, ASTSystemQuery & query, StorageActionBlo
auto manager = context.getActionLocksManager();
manager->cleanExpired();
if (!query.target_table.empty())
if (!query.table.empty())
{
String database = !query.target_database.empty() ? query.target_database : context.getCurrentDatabase();
String database = !query.database.empty() ? query.database : context.getCurrentDatabase();
if (start)
manager->remove(database, query.target_table, action_type);
manager->remove(database, query.table, action_type);
else
manager->add(database, query.target_table, action_type);
manager->add(database, query.table, action_type);
}
else
{
@ -131,6 +132,9 @@ BlockIO InterpreterSystemQuery::execute()
{
auto & query = query_ptr->as<ASTSystemQuery &>();
if (!query.cluster.empty())
return executeDDLQueryOnCluster(query_ptr, context, {query.database});
using Type = ASTSystemQuery::Type;
/// Use global context with fresh system profile settings
@ -138,11 +142,11 @@ BlockIO InterpreterSystemQuery::execute()
system_context.setSetting("profile", context.getSystemProfileName());
/// Make canonical query for simpler processing
if (!query.target_table.empty() && query.target_database.empty())
query.target_database = context.getCurrentDatabase();
if (!query.table.empty() && query.database.empty())
query.database = context.getCurrentDatabase();
if (!query.target_dictionary.empty() && !query.target_database.empty())
query.target_dictionary = query.target_database + "." + query.target_dictionary;
if (!query.target_dictionary.empty() && !query.database.empty())
query.target_dictionary = query.database + "." + query.target_dictionary;
switch (query.type)
{
@ -237,8 +241,8 @@ BlockIO InterpreterSystemQuery::execute()
restartReplicas(system_context);
break;
case Type::RESTART_REPLICA:
if (!tryRestartReplica(query.target_database, query.target_table, system_context))
throw Exception("There is no " + query.target_database + "." + query.target_table + " replicated table",
if (!tryRestartReplica(query.database, query.table, system_context))
throw Exception("There is no " + query.database + "." + query.table + " replicated table",
ErrorCodes::BAD_ARGUMENTS);
break;
case Type::FLUSH_LOGS:
@ -338,8 +342,8 @@ void InterpreterSystemQuery::restartReplicas(Context & system_context)
void InterpreterSystemQuery::syncReplica(ASTSystemQuery & query)
{
String database_name = !query.target_database.empty() ? query.target_database : context.getCurrentDatabase();
const String & table_name = query.target_table;
String database_name = !query.database.empty() ? query.database : context.getCurrentDatabase();
const String & table_name = query.table;
StoragePtr table = context.getTable(database_name, table_name);
@ -361,8 +365,8 @@ void InterpreterSystemQuery::syncReplica(ASTSystemQuery & query)
void InterpreterSystemQuery::flushDistributed(ASTSystemQuery & query)
{
String database_name = !query.target_database.empty() ? query.target_database : context.getCurrentDatabase();
String & table_name = query.target_table;
String database_name = !query.database.empty() ? query.database : context.getCurrentDatabase();
String & table_name = query.table;
if (auto storage_distributed = dynamic_cast<StorageDistributed *>(context.getTable(database_name, table_name).get()))
storage_distributed->flushClusterNodesAllData();

View File

@ -5,6 +5,7 @@
#include <IO/ReadHelpers.h>
#include <Interpreters/Users.h>
#include <common/logger_useful.h>
#include <Poco/MD5Engine.h>
namespace DB
@ -102,36 +103,10 @@ User::User(const String & name_, const String & config_elem, const Poco::Util::A
}
}
/// Read properties per "database.table"
/// Only tables are expected to have properties, so that all the keys inside "database" are table names.
const auto config_databases = config_elem + ".databases";
if (config.has(config_databases))
{
Poco::Util::AbstractConfiguration::Keys database_names;
config.keys(config_databases, database_names);
/// Read tables within databases
for (const auto & database : database_names)
{
const auto config_database = config_databases + "." + database;
Poco::Util::AbstractConfiguration::Keys table_names;
config.keys(config_database, table_names);
/// Read table properties
for (const auto & table : table_names)
{
const auto config_filter = config_database + "." + table + ".filter";
if (config.has(config_filter))
{
const auto filter_query = config.getString(config_filter);
table_props[database][table]["filter"] = filter_query;
}
}
}
}
if (config.has(config_elem + ".allow_quota_management"))
is_quota_management_allowed = config.getBool(config_elem + ".allow_quota_management");
if (config.has(config_elem + ".allow_row_policy_management"))
is_row_policy_management_allowed = config.getBool(config_elem + ".allow_row_policy_management");
}
}

View File

@ -1,12 +1,13 @@
#pragma once
#include <Core/Types.h>
#include <Core/UUID.h>
#include <Access/Authentication.h>
#include <Access/AllowedClientHosts.h>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace Poco
@ -41,13 +42,8 @@ struct User
using DictionarySet = std::unordered_set<std::string>;
std::optional<DictionarySet> dictionaries;
/// Table properties.
using PropertyMap = std::unordered_map<std::string /* name */, std::string /* value */>;
using TableMap = std::unordered_map<std::string /* table */, PropertyMap /* properties */>;
using DatabaseMap = std::unordered_map<std::string /* database */, TableMap /* tables */>;
DatabaseMap table_props;
bool is_quota_management_allowed = false;
bool is_row_policy_management_allowed = false;
User(const String & name_, const String & config_elem, const Poco::Util::AbstractConfiguration & config);
};

View File

@ -0,0 +1,164 @@
#include <Parsers/ASTCreateRowPolicyQuery.h>
#include <Parsers/ASTRoleList.h>
#include <Parsers/formatAST.h>
#include <Common/quoteString.h>
#include <boost/range/algorithm/transform.hpp>
#include <sstream>
namespace DB
{
namespace
{
using ConditionIndex = RowPolicy::ConditionIndex;
void formatRenameTo(const String & new_policy_name, const IAST::FormatSettings & settings)
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " RENAME TO " << (settings.hilite ? IAST::hilite_none : "")
<< backQuote(new_policy_name);
}
void formatIsRestrictive(bool is_restrictive, const IAST::FormatSettings & settings)
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " AS " << (is_restrictive ? "RESTRICTIVE" : "PERMISSIVE")
<< (settings.hilite ? IAST::hilite_none : "");
}
void formatConditionalExpression(const ASTPtr & expr, const IAST::FormatSettings & settings)
{
if (!expr)
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " NONE" << (settings.hilite ? IAST::hilite_none : "");
return;
}
expr->format(settings);
}
std::vector<std::pair<ConditionIndex, String>>
conditionalExpressionsToStrings(const std::vector<std::pair<ConditionIndex, ASTPtr>> & exprs, const IAST::FormatSettings & settings)
{
std::vector<std::pair<ConditionIndex, String>> result;
std::stringstream ss;
IAST::FormatSettings temp_settings(ss, settings);
boost::range::transform(exprs, std::back_inserter(result), [&](const std::pair<ConditionIndex, ASTPtr> & in)
{
formatConditionalExpression(in.second, temp_settings);
auto out = std::pair{in.first, ss.str()};
ss.str("");
return out;
});
return result;
}
void formatConditions(const char * op, const std::optional<String> & filter, const std::optional<String> & check, bool alter, const IAST::FormatSettings & settings)
{
if (op)
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " FOR" << (settings.hilite ? IAST::hilite_none : "");
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << ' ' << op << (settings.hilite ? IAST::hilite_none : "");
}
if (filter)
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " USING " << (settings.hilite ? IAST::hilite_none : "") << *filter;
if (check && (alter || (check != filter)))
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " WITH CHECK " << (settings.hilite ? IAST::hilite_none : "") << *check;
}
void formatMultipleConditions(const std::vector<std::pair<ConditionIndex, ASTPtr>> & conditions, bool alter, const IAST::FormatSettings & settings)
{
std::optional<String> scond[RowPolicy::MAX_CONDITION_INDEX];
for (const auto & [index, scondition] : conditionalExpressionsToStrings(conditions, settings))
scond[index] = scondition;
if ((scond[RowPolicy::SELECT_FILTER] == scond[RowPolicy::UPDATE_FILTER])
&& (scond[RowPolicy::UPDATE_FILTER] == scond[RowPolicy::DELETE_FILTER])
&& (scond[RowPolicy::INSERT_CHECK] == scond[RowPolicy::UPDATE_CHECK])
&& (scond[RowPolicy::SELECT_FILTER] || scond[RowPolicy::INSERT_CHECK]))
{
formatConditions(nullptr, scond[RowPolicy::SELECT_FILTER], scond[RowPolicy::INSERT_CHECK], alter, settings);
return;
}
bool need_comma = false;
if (scond[RowPolicy::SELECT_FILTER])
{
if (std::exchange(need_comma, true))
settings.ostr << ',';
formatConditions("SELECT", scond[RowPolicy::SELECT_FILTER], {}, alter, settings);
}
if (scond[RowPolicy::INSERT_CHECK])
{
if (std::exchange(need_comma, true))
settings.ostr << ',';
formatConditions("INSERT", {}, scond[RowPolicy::INSERT_CHECK], alter, settings);
}
if (scond[RowPolicy::UPDATE_FILTER] || scond[RowPolicy::UPDATE_CHECK])
{
if (std::exchange(need_comma, true))
settings.ostr << ',';
formatConditions("UPDATE", scond[RowPolicy::UPDATE_FILTER], scond[RowPolicy::UPDATE_CHECK], alter, settings);
}
if (scond[RowPolicy::DELETE_FILTER])
{
if (std::exchange(need_comma, true))
settings.ostr << ',';
formatConditions("DELETE", scond[RowPolicy::DELETE_FILTER], {}, alter, settings);
}
}
void formatRoles(const ASTRoleList & roles, const IAST::FormatSettings & settings)
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " TO " << (settings.hilite ? IAST::hilite_none : "");
roles.format(settings);
}
}
String ASTCreateRowPolicyQuery::getID(char) const
{
return "CREATE POLICY or ALTER POLICY query";
}
ASTPtr ASTCreateRowPolicyQuery::clone() const
{
return std::make_shared<ASTCreateRowPolicyQuery>(*this);
}
void ASTCreateRowPolicyQuery::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const
{
settings.ostr << (settings.hilite ? hilite_keyword : "") << (alter ? "ALTER POLICY" : "CREATE POLICY")
<< (settings.hilite ? hilite_none : "");
if (if_exists)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " IF EXISTS" << (settings.hilite ? hilite_none : "");
else if (if_not_exists)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " IF NOT EXISTS" << (settings.hilite ? hilite_none : "");
else if (or_replace)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " OR REPLACE" << (settings.hilite ? hilite_none : "");
const String & database = name_parts.database;
const String & table_name = name_parts.table_name;
const String & policy_name = name_parts.policy_name;
settings.ostr << " " << backQuoteIfNeed(policy_name) << (settings.hilite ? hilite_keyword : "") << " ON "
<< (settings.hilite ? hilite_none : "") << (database.empty() ? String{} : backQuoteIfNeed(database) + ".") << table_name;
if (!new_policy_name.empty())
formatRenameTo(new_policy_name, settings);
if (is_restrictive)
formatIsRestrictive(*is_restrictive, settings);
formatMultipleConditions(conditions, alter, settings);
if (roles)
formatRoles(*roles, settings);
}
}

View File

@ -0,0 +1,50 @@
#pragma once
#include <Parsers/IAST.h>
#include <Access/RowPolicy.h>
#include <utility>
#include <vector>
namespace DB
{
class ASTRoleList;
/** CREATE [ROW] POLICY [IF NOT EXISTS | OR REPLACE] name ON [database.]table
* [AS {PERMISSIVE | RESTRICTIVE}]
* [FOR {SELECT | INSERT | UPDATE | DELETE | ALL}]
* [USING condition]
* [WITH CHECK condition] [,...]
* [TO {role [,...] | ALL | ALL EXCEPT role [,...]}]
*
* ALTER [ROW] POLICY [IF EXISTS] name ON [database.]table
* [RENAME TO new_name]
* [AS {PERMISSIVE | RESTRICTIVE}]
* [FOR {SELECT | INSERT | UPDATE | DELETE | ALL}]
* [USING {condition | NONE}]
* [WITH CHECK {condition | NONE}] [,...]
* [TO {role [,...] | ALL | ALL EXCEPT role [,...]}]
*/
class ASTCreateRowPolicyQuery : public IAST
{
public:
bool alter = false;
bool if_exists = false;
bool if_not_exists = false;
bool or_replace = false;
RowPolicy::FullNameParts name_parts;
String new_policy_name;
std::optional<bool> is_restrictive;
using ConditionIndex = RowPolicy::ConditionIndex;
std::vector<std::pair<ConditionIndex, ASTPtr>> conditions;
std::shared_ptr<ASTRoleList> roles;
String getID(char) const override;
ASTPtr clone() const override;
void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override;
};
}

View File

@ -13,6 +13,7 @@ namespace
switch (kind)
{
case Kind::QUOTA: return "QUOTA";
case Kind::ROW_POLICY: return "POLICY";
}
__builtin_unreachable();
}
@ -44,13 +45,32 @@ void ASTDropAccessEntityQuery::formatImpl(const FormatSettings & settings, Forma
<< (if_exists ? " IF EXISTS" : "")
<< (settings.hilite ? hilite_none : "");
bool need_comma = false;
for (const auto & name : names)
if (kind == Kind::ROW_POLICY)
{
if (need_comma)
settings.ostr << ',';
need_comma = true;
settings.ostr << ' ' << backQuoteIfNeed(name);
bool need_comma = false;
for (const auto & row_policy_name : row_policies_names)
{
if (need_comma)
settings.ostr << ',';
need_comma = true;
const String & database = row_policy_name.database;
const String & table_name = row_policy_name.table_name;
const String & policy_name = row_policy_name.policy_name;
settings.ostr << ' ' << backQuoteIfNeed(policy_name) << (settings.hilite ? hilite_keyword : "") << " ON "
<< (settings.hilite ? hilite_none : "") << (database.empty() ? String{} : backQuoteIfNeed(database) + ".")
<< backQuoteIfNeed(table_name);
}
}
else
{
bool need_comma = false;
for (const auto & name : names)
{
if (need_comma)
settings.ostr << ',';
need_comma = true;
settings.ostr << ' ' << backQuoteIfNeed(name);
}
}
}
}

View File

@ -1,12 +1,14 @@
#pragma once
#include <Parsers/IAST.h>
#include <Access/RowPolicy.h>
namespace DB
{
/** DROP QUOTA [IF EXISTS] name [,...]
* DROP [ROW] POLICY [IF EXISTS] name [,...] ON [database.]table [,...]
*/
class ASTDropAccessEntityQuery : public IAST
{
@ -14,11 +16,13 @@ public:
enum class Kind
{
QUOTA,
ROW_POLICY,
};
const Kind kind;
const char * const keyword;
bool if_exists = false;
Strings names;
std::vector<RowPolicy::FullNameParts> row_policies_names;
ASTDropAccessEntityQuery(Kind kind_);
String getID(char) const override;

View File

@ -13,6 +13,7 @@ namespace
switch (kind)
{
case Kind::QUOTA: return "QUOTA";
case Kind::ROW_POLICY: return "POLICY";
}
__builtin_unreachable();
}
@ -43,7 +44,16 @@ void ASTShowCreateAccessEntityQuery::formatQueryImpl(const FormatSettings & sett
<< "SHOW CREATE " << keyword
<< (settings.hilite ? hilite_none : "");
if (current_quota)
if (kind == Kind::ROW_POLICY)
{
const String & database = row_policy_name.database;
const String & table_name = row_policy_name.table_name;
const String & policy_name = row_policy_name.policy_name;
settings.ostr << ' ' << backQuoteIfNeed(policy_name) << (settings.hilite ? hilite_keyword : "") << " ON "
<< (settings.hilite ? hilite_none : "") << (database.empty() ? String{} : backQuoteIfNeed(database) + ".")
<< backQuoteIfNeed(table_name);
}
else if ((kind == Kind::QUOTA) && current_quota)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " CURRENT" << (settings.hilite ? hilite_none : "");
else
settings.ostr << " " << backQuoteIfNeed(name);

View File

@ -1,11 +1,13 @@
#pragma once
#include <Parsers/ASTQueryWithOutput.h>
#include <Access/RowPolicy.h>
namespace DB
{
/** SHOW CREATE QUOTA [name | CURRENT]
* SHOW CREATE [ROW] POLICY name ON [database.]table
*/
class ASTShowCreateAccessEntityQuery : public ASTQueryWithOutput
{
@ -13,11 +15,13 @@ public:
enum class Kind
{
QUOTA,
ROW_POLICY,
};
const Kind kind;
const char * const keyword;
String name;
bool current_quota = false;
RowPolicy::FullNameParts row_policy_name;
ASTShowCreateAccessEntityQuery(Kind kind_);
String getID(char) const override;

View File

@ -0,0 +1,22 @@
#include <Parsers/ASTShowRowPoliciesQuery.h>
#include <Common/quoteString.h>
namespace DB
{
void ASTShowRowPoliciesQuery::formatQueryImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const
{
settings.ostr << (settings.hilite ? hilite_keyword : "") << "SHOW POLICIES" << (settings.hilite ? hilite_none : "");
if (current)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " CURRENT" << (settings.hilite ? hilite_none : "");
if (!table_name.empty())
{
settings.ostr << (settings.hilite ? hilite_keyword : "") << " ON " << (settings.hilite ? hilite_none : "");
if (!database.empty())
settings.ostr << backQuoteIfNeed(database) << ".";
settings.ostr << backQuoteIfNeed(table_name);
}
}
}

View File

@ -0,0 +1,23 @@
#pragma once
#include <Parsers/ASTQueryWithOutput.h>
namespace DB
{
/// SHOW [ROW] POLICIES [CURRENT] [ON [database.]table]
class ASTShowRowPoliciesQuery : public ASTQueryWithOutput
{
public:
bool current = false;
String database;
String table_name;
String getID(char) const override { return "SHOW POLICIES query"; }
ASTPtr clone() const override { return std::make_shared<ASTShowRowPoliciesQuery>(*this); }
protected:
void formatQueryImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override;
};
}

View File

@ -96,27 +96,30 @@ void ASTSystemQuery::formatImpl(const FormatSettings & settings, FormatState &,
auto print_database_table = [&]
{
settings.ostr << " ";
if (!target_database.empty())
if (!database.empty())
{
settings.ostr << (settings.hilite ? hilite_identifier : "") << backQuoteIfNeed(target_database)
settings.ostr << (settings.hilite ? hilite_identifier : "") << backQuoteIfNeed(database)
<< (settings.hilite ? hilite_none : "") << ".";
}
settings.ostr << (settings.hilite ? hilite_identifier : "") << backQuoteIfNeed(target_table)
settings.ostr << (settings.hilite ? hilite_identifier : "") << backQuoteIfNeed(table)
<< (settings.hilite ? hilite_none : "");
};
auto print_database_dictionary = [&]
{
settings.ostr << " ";
if (!target_database.empty())
if (!database.empty())
{
settings.ostr << (settings.hilite ? hilite_identifier : "") << backQuoteIfNeed(target_database)
settings.ostr << (settings.hilite ? hilite_identifier : "") << backQuoteIfNeed(database)
<< (settings.hilite ? hilite_none : "") << ".";
}
settings.ostr << (settings.hilite ? hilite_identifier : "") << backQuoteIfNeed(target_dictionary)
<< (settings.hilite ? hilite_none : "");
};
if (!cluster.empty())
formatOnCluster(settings);
if ( type == Type::STOP_MERGES
|| type == Type::START_MERGES
|| type == Type::STOP_TTL_MERGES
@ -132,7 +135,7 @@ void ASTSystemQuery::formatImpl(const FormatSettings & settings, FormatState &,
|| type == Type::STOP_DISTRIBUTED_SENDS
|| type == Type::START_DISTRIBUTED_SENDS)
{
if (!target_table.empty())
if (!table.empty())
print_database_table();
}
else if (type == Type::RESTART_REPLICA || type == Type::SYNC_REPLICA || type == Type::FLUSH_DISTRIBUTED)

View File

@ -1,13 +1,14 @@
#pragma once
#include "config_core.h"
#include <Parsers/ASTQueryWithOnCluster.h>
#include <Parsers/IAST.h>
namespace DB
{
class ASTSystemQuery : public IAST
class ASTSystemQuery : public IAST, public ASTQueryWithOnCluster
{
public:
@ -55,13 +56,18 @@ public:
Type type = Type::UNKNOWN;
String target_dictionary;
String target_database;
String target_table;
String database;
String table;
String getID(char) const override { return "SYSTEM query"; }
ASTPtr clone() const override { return std::make_shared<ASTSystemQuery>(*this); }
ASTPtr getRewrittenASTWithoutOnCluster(const std::string & new_database) const override
{
return removeOnCluster<ASTSystemQuery>(clone(), new_database);
}
protected:
void formatImpl(const FormatSettings & settings, FormatState & state, FormatStateStacked frame) const override;

View File

@ -0,0 +1,261 @@
#include <Parsers/ParserCreateRowPolicyQuery.h>
#include <Parsers/ASTCreateRowPolicyQuery.h>
#include <Access/RowPolicy.h>
#include <Parsers/ParserRoleList.h>
#include <Parsers/ASTRoleList.h>
#include <Parsers/parseIdentifierOrStringLiteral.h>
#include <Parsers/parseDatabaseAndTableName.h>
#include <Parsers/ExpressionListParsers.h>
#include <Parsers/ASTLiteral.h>
namespace DB
{
namespace ErrorCodes
{
extern const int SYNTAX_ERROR;
}
namespace
{
using ConditionIndex = RowPolicy::ConditionIndex;
bool parseRenameTo(IParserBase::Pos & pos, Expected & expected, String & new_policy_name, bool alter)
{
return IParserBase::wrapParseImpl(pos, [&]
{
if (!new_policy_name.empty() || !alter)
return false;
if (!ParserKeyword{"RENAME TO"}.ignore(pos, expected))
return false;
return parseIdentifierOrStringLiteral(pos, expected, new_policy_name);
});
}
bool parseIsRestrictive(IParserBase::Pos & pos, Expected & expected, std::optional<bool> & is_restrictive)
{
return IParserBase::wrapParseImpl(pos, [&]
{
if (is_restrictive)
return false;
if (!ParserKeyword{"AS"}.ignore(pos, expected))
return false;
if (ParserKeyword{"RESTRICTIVE"}.ignore(pos, expected))
is_restrictive = true;
else if (ParserKeyword{"PERMISSIVE"}.ignore(pos, expected))
is_restrictive = false;
else
return false;
return true;
});
}
bool parseConditionalExpression(IParserBase::Pos & pos, Expected & expected, std::optional<ASTPtr> & expr)
{
if (ParserKeyword("NONE").ignore(pos, expected))
{
expr = nullptr;
return true;
}
ParserExpression parser;
ASTPtr x;
if (parser.parse(pos, x, expected))
{
expr = x;
return true;
}
expr.reset();
return false;
}
bool parseConditions(IParserBase::Pos & pos, Expected & expected, std::vector<std::pair<ConditionIndex, ASTPtr>> & conditions, bool alter)
{
return IParserBase::wrapParseImpl(pos, [&]
{
static constexpr char select_op[] = "SELECT";
static constexpr char insert_op[] = "INSERT";
static constexpr char update_op[] = "UPDATE";
static constexpr char delete_op[] = "DELETE";
std::vector<const char *> ops;
bool keyword_for = false;
if (ParserKeyword{"FOR"}.ignore(pos, expected))
{
keyword_for = true;
do
{
if (ParserKeyword{"SELECT"}.ignore(pos, expected))
ops.push_back(select_op);
else if (ParserKeyword{"INSERT"}.ignore(pos, expected))
ops.push_back(insert_op);
else if (ParserKeyword{"UPDATE"}.ignore(pos, expected))
ops.push_back(update_op);
else if (ParserKeyword{"DELETE"}.ignore(pos, expected))
ops.push_back(delete_op);
else if (ParserKeyword{"ALL"}.ignore(pos, expected))
{
}
else
return false;
}
while (ParserToken{TokenType::Comma}.ignore(pos, expected));
}
if (ops.empty())
{
ops.push_back(select_op);
ops.push_back(insert_op);
ops.push_back(update_op);
ops.push_back(delete_op);
}
std::optional<ASTPtr> filter;
std::optional<ASTPtr> check;
bool keyword_using = false, keyword_with_check = false;
if (ParserKeyword{"USING"}.ignore(pos, expected))
{
keyword_using = true;
if (!parseConditionalExpression(pos, expected, filter))
return false;
}
if (ParserKeyword{"WITH CHECK"}.ignore(pos, expected))
{
keyword_with_check = true;
if (!parseConditionalExpression(pos, expected, check))
return false;
}
if (!keyword_for && !keyword_using && !keyword_with_check)
return false;
if (filter && !check && !alter)
check = filter;
auto set_condition = [&](ConditionIndex index, const ASTPtr & condition)
{
auto it = std::find_if(conditions.begin(), conditions.end(), [index](const std::pair<ConditionIndex, ASTPtr> & element)
{
return element.first == index;
});
if (it == conditions.end())
it = conditions.insert(conditions.end(), std::pair<ConditionIndex, ASTPtr>{index, nullptr});
it->second = condition;
};
for (const auto & op : ops)
{
if ((op == select_op) && filter)
set_condition(RowPolicy::SELECT_FILTER, *filter);
else if ((op == insert_op) && check)
set_condition(RowPolicy::INSERT_CHECK, *check);
else if (op == update_op)
{
if (filter)
set_condition(RowPolicy::UPDATE_FILTER, *filter);
if (check)
set_condition(RowPolicy::UPDATE_CHECK, *check);
}
else if ((op == delete_op) && filter)
set_condition(RowPolicy::DELETE_FILTER, *filter);
else
__builtin_unreachable();
}
return true;
});
}
bool parseMultipleConditions(IParserBase::Pos & pos, Expected & expected, std::vector<std::pair<ConditionIndex, ASTPtr>> & conditions, bool alter)
{
return IParserBase::wrapParseImpl(pos, [&]
{
do
{
if (!parseConditions(pos, expected, conditions, alter))
return false;
}
while (ParserToken{TokenType::Comma}.ignore(pos, expected));
return true;
});
}
bool parseRoles(IParserBase::Pos & pos, Expected & expected, std::shared_ptr<ASTRoleList> & roles)
{
return IParserBase::wrapParseImpl(pos, [&]
{
ASTPtr node;
if (roles || !ParserKeyword{"TO"}.ignore(pos, expected) || !ParserRoleList{}.parse(pos, node, expected))
return false;
roles = std::static_pointer_cast<ASTRoleList>(node);
return true;
});
}
}
bool ParserCreateRowPolicyQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
bool alter;
if (ParserKeyword{"CREATE POLICY"}.ignore(pos, expected) || ParserKeyword{"CREATE ROW POLICY"}.ignore(pos, expected))
alter = false;
else if (ParserKeyword{"ALTER POLICY"}.ignore(pos, expected) || ParserKeyword{"ALTER ROW POLICY"}.ignore(pos, expected))
alter = true;
else
return false;
bool if_exists = false;
bool if_not_exists = false;
bool or_replace = false;
if (alter)
{
if (ParserKeyword{"IF EXISTS"}.ignore(pos, expected))
if_exists = true;
}
else
{
if (ParserKeyword{"IF NOT EXISTS"}.ignore(pos, expected))
if_not_exists = true;
else if (ParserKeyword{"OR REPLACE"}.ignore(pos, expected))
or_replace = true;
}
RowPolicy::FullNameParts name_parts;
String & database = name_parts.database;
String & table_name = name_parts.table_name;
String & policy_name = name_parts.policy_name;
if (!parseIdentifierOrStringLiteral(pos, expected, policy_name) || !ParserKeyword{"ON"}.ignore(pos, expected)
|| !parseDatabaseAndTableName(pos, expected, database, table_name))
return false;
String new_policy_name;
std::optional<bool> is_restrictive;
std::vector<std::pair<ConditionIndex, ASTPtr>> conditions;
std::shared_ptr<ASTRoleList> roles;
while (parseRenameTo(pos, expected, new_policy_name, alter) || parseIsRestrictive(pos, expected, is_restrictive)
|| parseMultipleConditions(pos, expected, conditions, alter) || parseRoles(pos, expected, roles))
;
auto query = std::make_shared<ASTCreateRowPolicyQuery>();
node = query;
query->alter = alter;
query->if_exists = if_exists;
query->if_not_exists = if_not_exists;
query->or_replace = or_replace;
query->name_parts = std::move(name_parts);
query->new_policy_name = std::move(new_policy_name);
query->is_restrictive = is_restrictive;
query->conditions = std::move(conditions);
query->roles = std::move(roles);
return true;
}
}

View File

@ -0,0 +1,30 @@
#pragma once
#include <Parsers/IParserBase.h>
namespace DB
{
/** Parses queries like
* CREATE [ROW] POLICY [IF NOT EXISTS | OR REPLACE] name ON [database.]table
* [AS {PERMISSIVE | RESTRICTIVE}]
* [FOR {SELECT | INSERT | UPDATE | DELETE | ALL}]
* [USING condition]
* [WITH CHECK condition] [,...]
* [TO {role [,...] | ALL | ALL EXCEPT role [,...]}]
*
* ALTER [ROW] POLICY [IF EXISTS] name ON [database.]table
* [RENAME TO new_name]
* [AS {PERMISSIVE | RESTRICTIVE}]
* [FOR {SELECT | INSERT | UPDATE | DELETE | ALL}]
* [USING {condition | NONE}]
* [WITH CHECK {condition | NONE}] [,...]
* [TO {role [,...] | ALL | ALL EXCEPT role [,...]}]
*/
class ParserCreateRowPolicyQuery : public IParserBase
{
protected:
const char * getName() const override { return "CREATE ROW POLICY or ALTER ROW POLICY query"; }
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override;
};
}

View File

@ -2,11 +2,30 @@
#include <Parsers/ASTDropAccessEntityQuery.h>
#include <Parsers/CommonParsers.h>
#include <Parsers/parseIdentifierOrStringLiteral.h>
#include <Parsers/parseDatabaseAndTableName.h>
#include <Access/Quota.h>
namespace DB
{
namespace
{
bool parseNames(IParserBase::Pos & pos, Expected & expected, Strings & names)
{
do
{
String name;
if (!parseIdentifierOrStringLiteral(pos, expected, name))
return false;
names.push_back(std::move(name));
}
while (ParserToken{TokenType::Comma}.ignore(pos, expected));
return true;
}
}
bool ParserDropAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
if (!ParserKeyword{"DROP"}.ignore(pos, expected))
@ -16,6 +35,8 @@ bool ParserDropAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expected &
Kind kind;
if (ParserKeyword{"QUOTA"}.ignore(pos, expected))
kind = Kind::QUOTA;
else if (ParserKeyword{"POLICY"}.ignore(pos, expected) || ParserKeyword{"ROW POLICY"}.ignore(pos, expected))
kind = Kind::ROW_POLICY;
else
return false;
@ -24,21 +45,35 @@ bool ParserDropAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expected &
if_exists = true;
Strings names;
do
{
String name;
if (!parseIdentifierOrStringLiteral(pos, expected, name))
return false;
std::vector<RowPolicy::FullNameParts> row_policies_names;
names.push_back(std::move(name));
if (kind == Kind::ROW_POLICY)
{
do
{
Strings policy_names;
if (!parseNames(pos, expected, policy_names))
return false;
String database, table_name;
if (!ParserKeyword{"ON"}.ignore(pos, expected) || !parseDatabaseAndTableName(pos, expected, database, table_name))
return false;
for (const String & policy_name : policy_names)
row_policies_names.push_back({database, table_name, policy_name});
}
while (ParserToken{TokenType::Comma}.ignore(pos, expected));
}
else
{
if (!parseNames(pos, expected, names))
return false;
}
while (ParserToken{TokenType::Comma}.ignore(pos, expected));
auto query = std::make_shared<ASTDropAccessEntityQuery>(kind);
node = query;
query->if_exists = if_exists;
query->names = std::move(names);
query->row_policies_names = std::move(row_policies_names);
return true;
}

View File

@ -10,6 +10,7 @@
#include <Parsers/ParserAlterQuery.h>
#include <Parsers/ParserSystemQuery.h>
#include <Parsers/ParserCreateQuotaQuery.h>
#include <Parsers/ParserCreateRowPolicyQuery.h>
#include <Parsers/ParserDropAccessEntityQuery.h>
@ -25,6 +26,7 @@ bool ParserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
ParserSetQuery set_p;
ParserSystemQuery system_p;
ParserCreateQuotaQuery create_quota_p;
ParserCreateRowPolicyQuery create_row_policy_p;
ParserDropAccessEntityQuery drop_access_entity_p;
bool res = query_with_output_p.parse(pos, node, expected)
@ -33,6 +35,7 @@ bool ParserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
|| set_p.parse(pos, node, expected)
|| system_p.parse(pos, node, expected)
|| create_quota_p.parse(pos, node, expected)
|| create_row_policy_p.parse(pos, node, expected)
|| drop_access_entity_p.parse(pos, node, expected);
return res;

View File

@ -16,6 +16,7 @@
#include <Parsers/ASTExplainQuery.h>
#include <Parsers/ParserShowCreateAccessEntityQuery.h>
#include <Parsers/ParserShowQuotasQuery.h>
#include <Parsers/ParserShowRowPoliciesQuery.h>
namespace DB
@ -38,6 +39,7 @@ bool ParserQueryWithOutput::parseImpl(Pos & pos, ASTPtr & node, Expected & expec
ParserWatchQuery watch_p;
ParserShowCreateAccessEntityQuery show_create_access_entity_p;
ParserShowQuotasQuery show_quotas_p;
ParserShowRowPoliciesQuery show_row_policies_p;
ASTPtr query;
@ -66,7 +68,8 @@ bool ParserQueryWithOutput::parseImpl(Pos & pos, ASTPtr & node, Expected & expec
|| kill_query_p.parse(pos, query, expected)
|| optimize_p.parse(pos, query, expected)
|| watch_p.parse(pos, query, expected)
|| show_quotas_p.parse(pos, query, expected);
|| show_quotas_p.parse(pos, query, expected)
|| show_row_policies_p.parse(pos, query, expected);
if (!parsed)
return false;

View File

@ -2,6 +2,8 @@
#include <Parsers/ASTShowCreateAccessEntityQuery.h>
#include <Parsers/CommonParsers.h>
#include <Parsers/parseIdentifierOrStringLiteral.h>
#include <Parsers/parseDatabaseAndTableName.h>
#include <assert.h>
namespace DB
@ -15,25 +17,41 @@ bool ParserShowCreateAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expe
Kind kind;
if (ParserKeyword{"QUOTA"}.ignore(pos, expected))
kind = Kind::QUOTA;
else if (ParserKeyword{"POLICY"}.ignore(pos, expected) || ParserKeyword{"ROW POLICY"}.ignore(pos, expected))
kind = Kind::ROW_POLICY;
else
return false;
String name;
bool current_quota = false;
RowPolicy::FullNameParts row_policy_name;
if ((kind == Kind::QUOTA) && ParserKeyword{"CURRENT"}.ignore(pos, expected))
if (kind == Kind::ROW_POLICY)
{
/// SHOW CREATE QUOTA CURRENT
current_quota = true;
}
else if (parseIdentifierOrStringLiteral(pos, expected, name))
{
/// SHOW CREATE QUOTA name
String & database = row_policy_name.database;
String & table_name = row_policy_name.table_name;
String & policy_name = row_policy_name.policy_name;
if (!parseIdentifierOrStringLiteral(pos, expected, policy_name) || !ParserKeyword{"ON"}.ignore(pos, expected)
|| !parseDatabaseAndTableName(pos, expected, database, table_name))
return false;
}
else
{
/// SHOW CREATE QUOTA
current_quota = true;
assert(kind == Kind::QUOTA);
if (ParserKeyword{"CURRENT"}.ignore(pos, expected))
{
/// SHOW CREATE QUOTA CURRENT
current_quota = true;
}
else if (parseIdentifierOrStringLiteral(pos, expected, name))
{
/// SHOW CREATE QUOTA name
}
else
{
/// SHOW CREATE QUOTA
current_quota = true;
}
}
auto query = std::make_shared<ASTShowCreateAccessEntityQuery>(kind);
@ -41,6 +59,7 @@ bool ParserShowCreateAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expe
query->name = std::move(name);
query->current_quota = current_quota;
query->row_policy_name = std::move(row_policy_name);
return true;
}

View File

@ -0,0 +1,40 @@
#include <Parsers/ParserShowRowPoliciesQuery.h>
#include <Parsers/ASTShowRowPoliciesQuery.h>
#include <Parsers/CommonParsers.h>
#include <Parsers/parseDatabaseAndTableName.h>
namespace DB
{
namespace
{
bool parseONDatabaseAndTableName(IParserBase::Pos & pos, Expected & expected, String & database, String & table_name)
{
return IParserBase::wrapParseImpl(pos, [&]
{
database.clear();
table_name.clear();
return ParserKeyword{"ON"}.ignore(pos, expected) && parseDatabaseAndTableName(pos, expected, database, table_name);
});
}
}
bool ParserShowRowPoliciesQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
if (!ParserKeyword{"SHOW POLICIES"}.ignore(pos, expected) && !ParserKeyword{"SHOW ROW POLICIES"}.ignore(pos, expected))
return false;
bool current = ParserKeyword{"CURRENT"}.ignore(pos, expected);
String database, table_name;
parseONDatabaseAndTableName(pos, expected, database, table_name);
auto query = std::make_shared<ASTShowRowPoliciesQuery>();
query->current = current;
query->database = std::move(database);
query->table_name = std::move(table_name);
node = query;
return true;
}
}

View File

@ -0,0 +1,17 @@
#pragma once
#include <Parsers/IParserBase.h>
namespace DB
{
/** Parses queries like
* SHOW [ROW] POLICIES [CURRENT] [ON [database.]table]
*/
class ParserShowRowPoliciesQuery : public IParserBase
{
protected:
const char * getName() const override { return "SHOW POLICIES query"; }
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override;
};
}

View File

@ -43,10 +43,17 @@ bool ParserSystemQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Expected &
{
case Type::RELOAD_DICTIONARY:
{
String cluster_str;
if (ParserKeyword{"ON"}.ignore(pos, expected))
{
if (!ASTQueryWithOnCluster::parse(pos, cluster_str, expected))
return false;
}
res->cluster = cluster_str;
ASTPtr ast;
if (ParserStringLiteral{}.parse(pos, ast, expected))
res->target_dictionary = ast->as<ASTLiteral &>().value.safeGet<String>();
else if (!parseDatabaseAndTableName(pos, expected, res->target_database, res->target_dictionary))
else if (!parseDatabaseAndTableName(pos, expected, res->database, res->target_dictionary))
return false;
break;
}
@ -54,7 +61,7 @@ bool ParserSystemQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Expected &
case Type::RESTART_REPLICA:
case Type::SYNC_REPLICA:
case Type::FLUSH_DISTRIBUTED:
if (!parseDatabaseAndTableName(pos, expected, res->target_database, res->target_table))
if (!parseDatabaseAndTableName(pos, expected, res->database, res->table))
return false;
break;
@ -72,7 +79,7 @@ bool ParserSystemQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Expected &
case Type::START_REPLICATION_QUEUES:
case Type::STOP_DISTRIBUTED_SENDS:
case Type::START_DISTRIBUTED_SENDS:
parseDatabaseAndTableName(pos, expected, res->target_database, res->target_table);
parseDatabaseAndTableName(pos, expected, res->database, res->table);
break;
default:

View File

@ -27,7 +27,7 @@ public:
/** In some usecase (hello Kafka) we need to read a lot of tiny streams in exactly the same format.
* The recreating of parser for each small stream takes too long, so we introduce a method
* resetParser() which allow to reset the state of parser to continure reading of
* resetParser() which allow to reset the state of parser to continue reading of
* source stream w/o recreating that.
* That should be called after current buffer was fully read.
*/

View File

@ -61,14 +61,14 @@ namespace DB
/// Inserts numeric data right into internal column data to reduce an overhead
template <typename NumericType, typename VectorType = ColumnVector<NumericType>>
static void fillColumnWithNumericData(std::shared_ptr<arrow::Column> & arrow_column, MutableColumnPtr & internal_column)
static void fillColumnWithNumericData(std::shared_ptr<arrow::ChunkedArray> & arrow_column, MutableColumnPtr & internal_column)
{
auto & column_data = static_cast<VectorType &>(*internal_column).getData();
column_data.reserve(arrow_column->length());
for (size_t chunk_i = 0, num_chunks = static_cast<size_t>(arrow_column->data()->num_chunks()); chunk_i < num_chunks; ++chunk_i)
for (size_t chunk_i = 0, num_chunks = static_cast<size_t>(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i)
{
std::shared_ptr<arrow::Array> chunk = arrow_column->data()->chunk(chunk_i);
std::shared_ptr<arrow::Array> chunk = arrow_column->chunk(chunk_i);
/// buffers[0] is a null bitmap and buffers[1] are actual values
std::shared_ptr<arrow::Buffer> buffer = chunk->data()->buffers[1];
@ -80,15 +80,15 @@ namespace DB
/// Inserts chars and offsets right into internal column data to reduce an overhead.
/// Internal offsets are shifted by one to the right in comparison with Arrow ones. So the last offset should map to the end of all chars.
/// Also internal strings are null terminated.
static void fillColumnWithStringData(std::shared_ptr<arrow::Column> & arrow_column, MutableColumnPtr & internal_column)
static void fillColumnWithStringData(std::shared_ptr<arrow::ChunkedArray> & arrow_column, MutableColumnPtr & internal_column)
{
PaddedPODArray<UInt8> & column_chars_t = assert_cast<ColumnString &>(*internal_column).getChars();
PaddedPODArray<UInt64> & column_offsets = assert_cast<ColumnString &>(*internal_column).getOffsets();
size_t chars_t_size = 0;
for (size_t chunk_i = 0, num_chunks = static_cast<size_t>(arrow_column->data()->num_chunks()); chunk_i < num_chunks; ++chunk_i)
for (size_t chunk_i = 0, num_chunks = static_cast<size_t>(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i)
{
arrow::BinaryArray & chunk = static_cast<arrow::BinaryArray &>(*(arrow_column->data()->chunk(chunk_i)));
arrow::BinaryArray & chunk = static_cast<arrow::BinaryArray &>(*(arrow_column->chunk(chunk_i)));
const size_t chunk_length = chunk.length();
chars_t_size += chunk.value_offset(chunk_length - 1) + chunk.value_length(chunk_length - 1);
@ -98,9 +98,9 @@ namespace DB
column_chars_t.reserve(chars_t_size);
column_offsets.reserve(arrow_column->length());
for (size_t chunk_i = 0, num_chunks = static_cast<size_t>(arrow_column->data()->num_chunks()); chunk_i < num_chunks; ++chunk_i)
for (size_t chunk_i = 0, num_chunks = static_cast<size_t>(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i)
{
arrow::BinaryArray & chunk = static_cast<arrow::BinaryArray &>(*(arrow_column->data()->chunk(chunk_i)));
arrow::BinaryArray & chunk = static_cast<arrow::BinaryArray &>(*(arrow_column->chunk(chunk_i)));
std::shared_ptr<arrow::Buffer> buffer = chunk.value_data();
const size_t chunk_length = chunk.length();
@ -118,14 +118,14 @@ namespace DB
}
}
static void fillColumnWithBooleanData(std::shared_ptr<arrow::Column> & arrow_column, MutableColumnPtr & internal_column)
static void fillColumnWithBooleanData(std::shared_ptr<arrow::ChunkedArray> & arrow_column, MutableColumnPtr & internal_column)
{
auto & column_data = assert_cast<ColumnVector<UInt8> &>(*internal_column).getData();
column_data.reserve(arrow_column->length());
for (size_t chunk_i = 0, num_chunks = static_cast<size_t>(arrow_column->data()->num_chunks()); chunk_i < num_chunks; ++chunk_i)
for (size_t chunk_i = 0, num_chunks = static_cast<size_t>(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i)
{
arrow::BooleanArray & chunk = static_cast<arrow::BooleanArray &>(*(arrow_column->data()->chunk(chunk_i)));
arrow::BooleanArray & chunk = static_cast<arrow::BooleanArray &>(*(arrow_column->chunk(chunk_i)));
/// buffers[0] is a null bitmap and buffers[1] are actual values
std::shared_ptr<arrow::Buffer> buffer = chunk.data()->buffers[1];
@ -135,14 +135,14 @@ namespace DB
}
/// Arrow stores Parquet::DATE in Int32, while ClickHouse stores Date in UInt16. Therefore, it should be checked before saving
static void fillColumnWithDate32Data(std::shared_ptr<arrow::Column> & arrow_column, MutableColumnPtr & internal_column)
static void fillColumnWithDate32Data(std::shared_ptr<arrow::ChunkedArray> & arrow_column, MutableColumnPtr & internal_column)
{
PaddedPODArray<UInt16> & column_data = assert_cast<ColumnVector<UInt16> &>(*internal_column).getData();
column_data.reserve(arrow_column->length());
for (size_t chunk_i = 0, num_chunks = static_cast<size_t>(arrow_column->data()->num_chunks()); chunk_i < num_chunks; ++chunk_i)
for (size_t chunk_i = 0, num_chunks = static_cast<size_t>(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i)
{
arrow::Date32Array & chunk = static_cast<arrow::Date32Array &>(*(arrow_column->data()->chunk(chunk_i)));
arrow::Date32Array & chunk = static_cast<arrow::Date32Array &>(*(arrow_column->chunk(chunk_i)));
for (size_t value_i = 0, length = static_cast<size_t>(chunk.length()); value_i < length; ++value_i)
{
@ -150,7 +150,7 @@ namespace DB
if (days_num > DATE_LUT_MAX_DAY_NUM)
{
// TODO: will it rollback correctly?
throw Exception{"Input value " + std::to_string(days_num) + " of a column \"" + arrow_column->name()
throw Exception{"Input value " + std::to_string(days_num) + " of a column \"" + internal_column->getName()
+ "\" is greater than "
"max allowed Date value, which is "
+ std::to_string(DATE_LUT_MAX_DAY_NUM),
@ -163,14 +163,14 @@ namespace DB
}
/// Arrow stores Parquet::DATETIME in Int64, while ClickHouse stores DateTime in UInt32. Therefore, it should be checked before saving
static void fillColumnWithDate64Data(std::shared_ptr<arrow::Column> & arrow_column, MutableColumnPtr & internal_column)
static void fillColumnWithDate64Data(std::shared_ptr<arrow::ChunkedArray> & arrow_column, MutableColumnPtr & internal_column)
{
auto & column_data = assert_cast<ColumnVector<UInt32> &>(*internal_column).getData();
column_data.reserve(arrow_column->length());
for (size_t chunk_i = 0, num_chunks = static_cast<size_t>(arrow_column->data()->num_chunks()); chunk_i < num_chunks; ++chunk_i)
for (size_t chunk_i = 0, num_chunks = static_cast<size_t>(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i)
{
auto & chunk = static_cast<arrow::Date64Array &>(*(arrow_column->data()->chunk(chunk_i)));
auto & chunk = static_cast<arrow::Date64Array &>(*(arrow_column->chunk(chunk_i)));
for (size_t value_i = 0, length = static_cast<size_t>(chunk.length()); value_i < length; ++value_i)
{
auto timestamp = static_cast<UInt32>(chunk.Value(value_i) / 1000); // Always? in ms
@ -179,14 +179,14 @@ namespace DB
}
}
static void fillColumnWithTimestampData(std::shared_ptr<arrow::Column> & arrow_column, MutableColumnPtr & internal_column)
static void fillColumnWithTimestampData(std::shared_ptr<arrow::ChunkedArray> & arrow_column, MutableColumnPtr & internal_column)
{
auto & column_data = assert_cast<ColumnVector<UInt32> &>(*internal_column).getData();
column_data.reserve(arrow_column->length());
for (size_t chunk_i = 0, num_chunks = static_cast<size_t>(arrow_column->data()->num_chunks()); chunk_i < num_chunks; ++chunk_i)
for (size_t chunk_i = 0, num_chunks = static_cast<size_t>(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i)
{
auto & chunk = static_cast<arrow::TimestampArray &>(*(arrow_column->data()->chunk(chunk_i)));
auto & chunk = static_cast<arrow::TimestampArray &>(*(arrow_column->chunk(chunk_i)));
const auto & type = static_cast<const ::arrow::TimestampType &>(*chunk.type());
UInt32 divide = 1;
@ -215,15 +215,15 @@ namespace DB
}
}
static void fillColumnWithDecimalData(std::shared_ptr<arrow::Column> & arrow_column, MutableColumnPtr & internal_column)
static void fillColumnWithDecimalData(std::shared_ptr<arrow::ChunkedArray> & arrow_column, MutableColumnPtr & internal_column)
{
auto & column = assert_cast<ColumnDecimal<Decimal128> &>(*internal_column);
auto & column_data = column.getData();
column_data.reserve(arrow_column->length());
for (size_t chunk_i = 0, num_chunks = static_cast<size_t>(arrow_column->data()->num_chunks()); chunk_i < num_chunks; ++chunk_i)
for (size_t chunk_i = 0, num_chunks = static_cast<size_t>(arrow_column->num_chunks()); chunk_i < num_chunks; ++chunk_i)
{
auto & chunk = static_cast<arrow::DecimalArray &>(*(arrow_column->data()->chunk(chunk_i)));
auto & chunk = static_cast<arrow::DecimalArray &>(*(arrow_column->chunk(chunk_i)));
for (size_t value_i = 0, length = static_cast<size_t>(chunk.length()); value_i < length; ++value_i)
{
column_data.emplace_back(chunk.IsNull(value_i) ? Decimal128(0) : *reinterpret_cast<const Decimal128 *>(chunk.Value(value_i))); // TODO: copy column
@ -232,14 +232,14 @@ namespace DB
}
/// Creates a null bytemap from arrow's null bitmap
static void fillByteMapFromArrowColumn(std::shared_ptr<arrow::Column> & arrow_column, MutableColumnPtr & bytemap)
static void fillByteMapFromArrowColumn(std::shared_ptr<arrow::ChunkedArray> & arrow_column, MutableColumnPtr & bytemap)
{
PaddedPODArray<UInt8> & bytemap_data = assert_cast<ColumnVector<UInt8> &>(*bytemap).getData();
bytemap_data.reserve(arrow_column->length());
for (size_t chunk_i = 0; chunk_i != static_cast<size_t>(arrow_column->data()->num_chunks()); ++chunk_i)
for (size_t chunk_i = 0; chunk_i != static_cast<size_t>(arrow_column->num_chunks()); ++chunk_i)
{
std::shared_ptr<arrow::Array> chunk = arrow_column->data()->chunk(chunk_i);
std::shared_ptr<arrow::Array> chunk = arrow_column->chunk(chunk_i);
for (size_t value_i = 0; value_i != static_cast<size_t>(chunk->length()); ++value_i)
bytemap_data.emplace_back(chunk->IsNull(value_i));
@ -255,7 +255,7 @@ namespace DB
columns_list.reserve(header.rows());
using NameToColumnPtr = std::unordered_map<std::string, std::shared_ptr<arrow::Column>>;
using NameToColumnPtr = std::unordered_map<std::string, std::shared_ptr<arrow::ChunkedArray>>;
if (!read_status.ok())
throw Exception{"Error while reading " + format_name + " data: " + read_status.ToString(),
ErrorCodes::CANNOT_READ_ALL_DATA};
@ -270,10 +270,10 @@ namespace DB
++row_group_current;
NameToColumnPtr name_to_column_ptr;
for (size_t i = 0, num_columns = static_cast<size_t>(table->num_columns()); i < num_columns; ++i)
for (const auto& column_name : table->ColumnNames())
{
std::shared_ptr<arrow::Column> arrow_column = table->column(i);
name_to_column_ptr[arrow_column->name()] = arrow_column;
std::shared_ptr<arrow::ChunkedArray> arrow_column = table->GetColumnByName(column_name);
name_to_column_ptr[column_name] = arrow_column;
}
for (size_t column_i = 0, columns = header.columns(); column_i < columns; ++column_i)
@ -285,7 +285,7 @@ namespace DB
throw Exception{"Column \"" + header_column.name + "\" is not presented in input data",
ErrorCodes::THERE_IS_NO_COLUMN};
std::shared_ptr<arrow::Column> arrow_column = name_to_column_ptr[header_column.name];
std::shared_ptr<arrow::ChunkedArray> arrow_column = name_to_column_ptr[header_column.name];
arrow::Type::type arrow_type = arrow_column->type()->id();
// TODO: check if a column is const?
@ -313,7 +313,7 @@ namespace DB
}
else
{
throw Exception{"The type \"" + arrow_column->type()->name() + "\" of an input column \"" + arrow_column->name()
throw Exception{"The type \"" + arrow_column->type()->name() + "\" of an input column \"" + header_column.name
+ "\" is not supported for conversion from a " + format_name + " data format",
ErrorCodes::CANNOT_CONVERT_TYPE};
}
@ -374,7 +374,7 @@ namespace DB
throw Exception
{
"Unsupported " + format_name + " type \"" + arrow_column->type()->name() + "\" of an input column \""
+ arrow_column->name() + "\"",
+ header_column.name + "\"",
ErrorCodes::UNKNOWN_TYPE
};
}

View File

@ -45,9 +45,11 @@ namespace DB
buffer = std::make_unique<arrow::Buffer>(file_data);
// TODO: maybe use parquet::RandomAccessSource?
auto reader = parquet::ParquetFileReader::Open(std::make_shared<::arrow::io::BufferReader>(*buffer));
file_reader = std::make_unique<parquet::arrow::FileReader>(::arrow::default_memory_pool(),
std::move(reader));
auto status = parquet::arrow::FileReader::Make(
::arrow::default_memory_pool(),
parquet::ParquetFileReader::Open(std::make_shared<::arrow::io::BufferReader>(*buffer)),
&file_reader);
row_group_total = file_reader->num_row_groups();
row_group_current = 0;
}

View File

@ -21,9 +21,10 @@
#include <arrow/api.h>
#include <arrow/io/api.h>
#include <arrow/util/decimal.h>
#include <arrow/util/memory.h>
#include <parquet/arrow/writer.h>
#include <parquet/exception.h>
#include <parquet/util/memory.h>
#include <parquet/deprecated_io.h>
namespace DB
@ -238,22 +239,39 @@ static const PaddedPODArray<UInt8> * extractNullBytemapPtr(ColumnPtr column)
}
class OstreamOutputStream : public parquet::OutputStream
class OstreamOutputStream : public arrow::io::OutputStream
{
public:
explicit OstreamOutputStream(WriteBuffer & ostr_) : ostr(ostr_) {}
virtual ~OstreamOutputStream() {}
virtual void Close() {}
virtual int64_t Tell() { return total_length; }
virtual void Write(const uint8_t * data, int64_t length)
explicit OstreamOutputStream(WriteBuffer & ostr_) : ostr(ostr_) { is_open = true; }
~OstreamOutputStream() override {}
// FileInterface
::arrow::Status Close() override
{
is_open = false;
return ::arrow::Status::OK();
}
::arrow::Status Tell(int64_t* position) const override
{
*position = total_length;
return ::arrow::Status::OK();
}
bool closed() const override { return !is_open; }
// Writable
::arrow::Status Write(const void* data, int64_t length) override
{
ostr.write(reinterpret_cast<const char *>(data), length);
total_length += length;
return ::arrow::Status::OK();
}
private:
WriteBuffer & ostr;
int64_t total_length = 0;
bool is_open = false;
PARQUET_DISALLOW_COPY_AND_ASSIGN(OstreamOutputStream);
};
@ -396,7 +414,6 @@ void ParquetBlockOutputFormat::consume(Chunk chunk)
arrow::default_memory_pool(),
sink,
props, /*parquet::default_writer_properties(),*/
parquet::arrow::default_arrow_writer_properties(),
&file_writer);
if (!status.ok())
throw Exception{"Error while opening a table: " + status.ToString(), ErrorCodes::UNKNOWN_EXCEPTION};

View File

@ -51,11 +51,12 @@
#include <algorithm>
#include <iomanip>
#include <optional>
#include <set>
#include <thread>
#include <typeinfo>
#include <typeindex>
#include <optional>
#include <unordered_set>
namespace ProfileEvents
@ -637,11 +638,22 @@ void MergeTreeData::setTTLExpressions(const ColumnsDescription::ColumnTTLs & new
else
{
auto new_ttl_entry = create_ttl_entry(ttl_element.children[0]);
new_ttl_entry.entry_ast = ttl_element_ptr;
new_ttl_entry.destination_type = ttl_element.destination_type;
new_ttl_entry.destination_name = ttl_element.destination_name;
if (!new_ttl_entry.getDestination(getStoragePolicy()))
{
String message;
if (new_ttl_entry.destination_type == PartDestinationType::DISK)
message = "No such disk " + backQuote(new_ttl_entry.destination_name) + " for given storage policy.";
else
message = "No such volume " + backQuote(new_ttl_entry.destination_name) + " for given storage policy.";
throw Exception(message, ErrorCodes::BAD_TTL_EXPRESSION);
}
if (!only_check)
{
new_ttl_entry.entry_ast = ttl_element_ptr;
new_ttl_entry.destination_type = ttl_element.destination_type;
new_ttl_entry.destination_name = ttl_element.destination_name;
move_ttl_entries.emplace_back(std::move(new_ttl_entry));
}
}
@ -791,6 +803,27 @@ void MergeTreeData::loadDataParts(bool skip_sanity_checks)
auto disks = storage_policy->getDisks();
if (getStoragePolicy()->getName() != "default")
{
/// Check extra parts at different disks, in order to not allow to miss data parts at undefined disks.
std::unordered_set<String> defined_disk_names;
for (const auto & disk_ptr : disks)
defined_disk_names.insert(disk_ptr->getName());
for (auto & [disk_name, disk_ptr] : global_context.getDiskSelector().getDisksMap())
{
if (defined_disk_names.count(disk_name) == 0 && Poco::File(getFullPathOnDisk(disk_ptr)).exists())
{
for (Poco::DirectoryIterator it(getFullPathOnDisk(disk_ptr)); it != end; ++it)
{
MergeTreePartInfo part_info;
if (MergeTreePartInfo::tryParsePartName(it.name(), &part_info, format_version))
throw Exception("Part " + backQuote(it.name()) + " was found on disk " + backQuote(disk_name) + " which is not defined in the storage policy", ErrorCodes::UNKNOWN_DISK);
}
}
}
}
/// Reversed order to load part from low priority disks firstly.
/// Used for keep part on low priority disk if duplication found
for (auto disk_it = disks.rbegin(); disk_it != disks.rend(); ++disk_it)

View File

@ -70,7 +70,7 @@ void StorageJoin::truncate(const ASTPtr &, const Context &, TableStructureWriteL
HashJoinPtr StorageJoin::getJoin(std::shared_ptr<AnalyzedJoin> analyzed_join) const
{
if (!(kind == analyzed_join->kind() && strictness == analyzed_join->strictness()))
if (kind != analyzed_join->kind() || strictness != analyzed_join->strictness())
throw Exception("Table " + table_name + " has incompatible type of JOIN.", ErrorCodes::INCOMPATIBLE_TYPE_OF_JOIN);
/// TODO: check key columns
@ -96,58 +96,14 @@ void registerStorageJoin(StorageFactory & factory)
ASTs & engine_args = args.engine_args;
if (engine_args.size() < 3)
throw Exception(
"Storage Join requires at least 3 parameters: Join(ANY|ALL, LEFT|INNER, keys...).",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
auto opt_strictness_id = tryGetIdentifierName(engine_args[0]);
if (!opt_strictness_id)
throw Exception("First parameter of storage Join must be ANY or ALL (without quotes).", ErrorCodes::BAD_ARGUMENTS);
const String strictness_str = Poco::toLower(*opt_strictness_id);
ASTTableJoin::Strictness strictness;
if (strictness_str == "any")
strictness = ASTTableJoin::Strictness::RightAny;
else if (strictness_str == "all")
strictness = ASTTableJoin::Strictness::All;
else
throw Exception("First parameter of storage Join must be ANY or ALL (without quotes).", ErrorCodes::BAD_ARGUMENTS);
auto opt_kind_id = tryGetIdentifierName(engine_args[1]);
if (!opt_kind_id)
throw Exception("Second parameter of storage Join must be LEFT or INNER (without quotes).", ErrorCodes::BAD_ARGUMENTS);
const String kind_str = Poco::toLower(*opt_kind_id);
ASTTableJoin::Kind kind;
if (kind_str == "left")
kind = ASTTableJoin::Kind::Left;
else if (kind_str == "inner")
kind = ASTTableJoin::Kind::Inner;
else if (kind_str == "right")
kind = ASTTableJoin::Kind::Right;
else if (kind_str == "full")
kind = ASTTableJoin::Kind::Full;
else
throw Exception("Second parameter of storage Join must be LEFT or INNER or RIGHT or FULL (without quotes).", ErrorCodes::BAD_ARGUMENTS);
Names key_names;
key_names.reserve(engine_args.size() - 2);
for (size_t i = 2, size = engine_args.size(); i < size; ++i)
{
auto opt_key = tryGetIdentifierName(engine_args[i]);
if (!opt_key)
throw Exception("Parameter №" + toString(i + 1) + " of storage Join don't look like column name.", ErrorCodes::BAD_ARGUMENTS);
key_names.push_back(*opt_key);
}
auto & settings = args.context.getSettingsRef();
auto join_use_nulls = settings.join_use_nulls;
auto max_rows_in_join = settings.max_rows_in_join;
auto max_bytes_in_join = settings.max_bytes_in_join;
auto join_overflow_mode = settings.join_overflow_mode;
auto join_any_take_last_row = settings.join_any_take_last_row;
auto old_any_join = settings.any_join_distinct_right_table_keys;
if (args.storage_def && args.storage_def->settings)
{
@ -163,6 +119,8 @@ void registerStorageJoin(StorageFactory & factory)
join_overflow_mode.set(setting.value);
else if (setting.name == "join_any_take_last_row")
join_any_take_last_row.set(setting.value);
else if (setting.name == "any_join_distinct_right_table_keys")
old_any_join.set(setting.value);
else
throw Exception(
"Unknown setting " + setting.name + " for storage " + args.engine_name,
@ -170,6 +128,68 @@ void registerStorageJoin(StorageFactory & factory)
}
}
if (engine_args.size() < 3)
throw Exception(
"Storage Join requires at least 3 parameters: Join(ANY|ALL|SEMI|ANTI, LEFT|INNER|RIGHT, keys...).",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
ASTTableJoin::Strictness strictness = ASTTableJoin::Strictness::Unspecified;
ASTTableJoin::Kind kind = ASTTableJoin::Kind::Comma;
if (auto opt_strictness_id = tryGetIdentifierName(engine_args[0]))
{
const String strictness_str = Poco::toLower(*opt_strictness_id);
if (strictness_str == "any" || strictness_str == "\'any\'")
{
if (old_any_join)
strictness = ASTTableJoin::Strictness::RightAny;
else
strictness = ASTTableJoin::Strictness::Any;
}
else if (strictness_str == "all" || strictness_str == "\'all\'")
strictness = ASTTableJoin::Strictness::All;
else if (strictness_str == "semi" || strictness_str == "\'semi\'")
strictness = ASTTableJoin::Strictness::Semi;
else if (strictness_str == "anti" || strictness_str == "\'anti\'")
strictness = ASTTableJoin::Strictness::Anti;
}
if (strictness == ASTTableJoin::Strictness::Unspecified)
throw Exception("First parameter of storage Join must be ANY or ALL or SEMI or ANTI.", ErrorCodes::BAD_ARGUMENTS);
if (auto opt_kind_id = tryGetIdentifierName(engine_args[1]))
{
const String kind_str = Poco::toLower(*opt_kind_id);
if (kind_str == "left" || kind_str == "\'left\'")
kind = ASTTableJoin::Kind::Left;
else if (kind_str == "inner" || kind_str == "\'inner\'")
kind = ASTTableJoin::Kind::Inner;
else if (kind_str == "right" || kind_str == "\'right\'")
kind = ASTTableJoin::Kind::Right;
else if (kind_str == "full" || kind_str == "\'full\'")
{
if (strictness == ASTTableJoin::Strictness::Any)
strictness = ASTTableJoin::Strictness::RightAny;
kind = ASTTableJoin::Kind::Full;
}
}
if (kind == ASTTableJoin::Kind::Comma)
throw Exception("Second parameter of storage Join must be LEFT or INNER or RIGHT or FULL.", ErrorCodes::BAD_ARGUMENTS);
Names key_names;
key_names.reserve(engine_args.size() - 2);
for (size_t i = 2, size = engine_args.size(); i < size; ++i)
{
auto opt_key = tryGetIdentifierName(engine_args[i]);
if (!opt_key)
throw Exception("Parameter №" + toString(i + 1) + " of storage Join don't look like column name.", ErrorCodes::BAD_ARGUMENTS);
key_names.push_back(*opt_key);
}
return StorageJoin::create(
args.relative_data_path,
args.database_name,
@ -246,8 +266,8 @@ protected:
Block block;
if (!joinDispatch(parent.kind, parent.strictness, parent.data->maps,
[&](auto, auto strictness, auto & map) { block = createBlock<strictness>(map); }))
throw Exception("Logical error: unknown JOIN strictness (must be ANY or ALL)", ErrorCodes::LOGICAL_ERROR);
[&](auto kind, auto strictness, auto & map) { block = createBlock<kind, strictness>(map); }))
throw Exception("Logical error: unknown JOIN strictness", ErrorCodes::LOGICAL_ERROR);
return block;
}
@ -265,7 +285,7 @@ private:
std::unique_ptr<void, std::function<void(void *)>> position; /// type erasure
template <ASTTableJoin::Strictness STRICTNESS, typename Maps>
template <ASTTableJoin::Kind KIND, ASTTableJoin::Strictness STRICTNESS, typename Maps>
Block createBlock(const Maps & maps)
{
for (size_t i = 0; i < sample_block.columns(); ++i)
@ -292,7 +312,7 @@ private:
{
#define M(TYPE) \
case Join::Type::TYPE: \
rows_added = fillColumns<STRICTNESS>(*maps.TYPE); \
rows_added = fillColumns<KIND, STRICTNESS>(*maps.TYPE); \
break;
APPLY_FOR_JOIN_VARIANTS_LIMITED(M)
#undef M
@ -323,8 +343,7 @@ private:
return res;
}
template <ASTTableJoin::Strictness STRICTNESS, typename Map>
template <ASTTableJoin::Kind KIND, ASTTableJoin::Strictness STRICTNESS, typename Map>
size_t fillColumns(const Map & map)
{
size_t rows_added = 0;
@ -341,34 +360,35 @@ private:
{
if constexpr (STRICTNESS == ASTTableJoin::Strictness::RightAny)
{
for (size_t j = 0; j < columns.size(); ++j)
if (j == key_pos)
columns[j]->insertData(rawData(it->getKey()), rawSize(it->getKey()));
else
columns[j]->insertFrom(*it->getMapped().block->getByPosition(column_indices[j]).column.get(), it->getMapped().row_num);
++rows_added;
fillOne<Map>(columns, column_indices, it, key_pos, rows_added);
}
else if constexpr (STRICTNESS == ASTTableJoin::Strictness::All)
{
fillAll<Map>(columns, column_indices, it, key_pos, rows_added);
}
else if constexpr (STRICTNESS == ASTTableJoin::Strictness::Any)
{
throw Exception("New ANY join storage is not implemented yet (set any_join_distinct_right_table_keys=1 to use old one)",
ErrorCodes::NOT_IMPLEMENTED);
if constexpr (KIND == ASTTableJoin::Kind::Left || KIND == ASTTableJoin::Kind::Inner)
fillOne<Map>(columns, column_indices, it, key_pos, rows_added);
else if constexpr (KIND == ASTTableJoin::Kind::Right)
fillAll<Map>(columns, column_indices, it, key_pos, rows_added);
}
else if constexpr (STRICTNESS == ASTTableJoin::Strictness::Asof ||
STRICTNESS == ASTTableJoin::Strictness::Semi ||
STRICTNESS == ASTTableJoin::Strictness::Anti)
else if constexpr (STRICTNESS == ASTTableJoin::Strictness::Semi)
{
throw Exception("ASOF|SEMI|ANTI join storage is not implemented yet", ErrorCodes::NOT_IMPLEMENTED);
if constexpr (KIND == ASTTableJoin::Kind::Left)
fillOne<Map>(columns, column_indices, it, key_pos, rows_added);
else if constexpr (KIND == ASTTableJoin::Kind::Right)
fillAll<Map>(columns, column_indices, it, key_pos, rows_added);
}
else if constexpr (STRICTNESS == ASTTableJoin::Strictness::Anti)
{
if constexpr (KIND == ASTTableJoin::Kind::Left)
fillOne<Map>(columns, column_indices, it, key_pos, rows_added);
else if constexpr (KIND == ASTTableJoin::Kind::Right)
fillAll<Map>(columns, column_indices, it, key_pos, rows_added);
}
else
for (auto ref_it = it->getMapped().begin(); ref_it.ok(); ++ref_it)
{
for (size_t j = 0; j < columns.size(); ++j)
if (j == key_pos)
columns[j]->insertData(rawData(it->getKey()), rawSize(it->getKey()));
else
columns[j]->insertFrom(*ref_it->block->getByPosition(column_indices[j]).column.get(), ref_it->row_num);
++rows_added;
}
throw Exception("This JOIN is not implemented yet", ErrorCodes::NOT_IMPLEMENTED);
if (rows_added >= max_block_size)
{
@ -379,6 +399,33 @@ private:
return rows_added;
}
template <typename Map>
static void fillOne(MutableColumns & columns, const ColumnNumbers & column_indices, typename Map::const_iterator & it,
const std::optional<size_t> & key_pos, size_t & rows_added)
{
for (size_t j = 0; j < columns.size(); ++j)
if (j == key_pos)
columns[j]->insertData(rawData(it->getKey()), rawSize(it->getKey()));
else
columns[j]->insertFrom(*it->getMapped().block->getByPosition(column_indices[j]).column.get(), it->getMapped().row_num);
++rows_added;
}
template <typename Map>
static void fillAll(MutableColumns & columns, const ColumnNumbers & column_indices, typename Map::const_iterator & it,
const std::optional<size_t> & key_pos, size_t & rows_added)
{
for (auto ref_it = it->getMapped().begin(); ref_it.ok(); ++ref_it)
{
for (size_t j = 0; j < columns.size(); ++j)
if (j == key_pos)
columns[j]->insertData(rawData(it->getKey()), rawSize(it->getKey()));
else
columns[j]->insertFrom(*ref_it->block->getByPosition(column_indices[j]).column.get(), ref_it->row_num);
++rows_added;
}
}
};

View File

@ -196,7 +196,7 @@ BlockInputStreams IStorageURLBase::read(const Names & column_names,
context,
max_block_size,
ConnectionTimeouts::getHTTPTimeouts(context),
IStorage::chooseCompressionMethod(request_uri.toString(), compression_method));
IStorage::chooseCompressionMethod(request_uri.getPath(), compression_method));
auto column_defaults = getColumns().getDefaults();
if (column_defaults.empty())

View File

@ -0,0 +1,59 @@
#include <Storages/System/StorageSystemRowPolicies.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeUUID.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeNullable.h>
#include <Interpreters/Context.h>
#include <Access/AccessControlManager.h>
#include <Access/RowPolicy.h>
#include <ext/range.h>
namespace DB
{
NamesAndTypesList StorageSystemRowPolicies::getNamesAndTypes()
{
NamesAndTypesList names_and_types{
{"database", std::make_shared<DataTypeString>()},
{"table", std::make_shared<DataTypeString>()},
{"name", std::make_shared<DataTypeString>()},
{"full_name", std::make_shared<DataTypeString>()},
{"id", std::make_shared<DataTypeUUID>()},
{"source", std::make_shared<DataTypeString>()},
{"restrictive", std::make_shared<DataTypeUInt8>()},
};
for (auto index : ext::range_with_static_cast<RowPolicy::ConditionIndex>(RowPolicy::MAX_CONDITION_INDEX))
names_and_types.push_back({RowPolicy::conditionIndexToColumnName(index), std::make_shared<DataTypeString>()});
return names_and_types;
}
void StorageSystemRowPolicies::fillData(MutableColumns & res_columns, const Context & context, const SelectQueryInfo &) const
{
const auto & access_control = context.getAccessControlManager();
std::vector<UUID> ids = access_control.findAll<RowPolicy>();
for (const auto & id : ids)
{
auto policy = access_control.tryRead<RowPolicy>(id);
if (!policy)
continue;
const auto * storage = access_control.findStorage(id);
size_t i = 0;
res_columns[i++]->insert(policy->getDatabase());
res_columns[i++]->insert(policy->getTableName());
res_columns[i++]->insert(policy->getName());
res_columns[i++]->insert(policy->getFullName());
res_columns[i++]->insert(id);
res_columns[i++]->insert(storage ? storage->getStorageName() : "");
res_columns[i++]->insert(policy->isRestrictive());
for (auto index : ext::range(RowPolicy::MAX_CONDITION_INDEX))
res_columns[i++]->insert(policy->conditions[index]);
}
}
}

Some files were not shown because too many files have changed in this diff Show More