Merge pull request #24468 from kitaisreal/compile-expressions-cached-functions-with-context-fix

CompileExpression cached functions with context fix
This commit is contained in:
Maksim Kita 2021-05-26 10:52:22 +03:00 committed by GitHub
commit d56dc82784
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 72 additions and 50 deletions

View File

@ -463,12 +463,6 @@ struct ContextSharedPart
dictionaries_xmls.reset();
delete_system_logs = std::move(system_logs);
#if USE_EMBEDDED_COMPILER
if (auto * cache = CompiledExpressionCacheFactory::instance().tryGetCache())
cache->reset();
#endif
embedded_dictionaries.reset();
external_dictionaries_loader.reset();
models_repository_guard.reset();

View File

@ -28,7 +28,6 @@ namespace DB
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int NOT_IMPLEMENTED;
}
static CHJIT & getJITInstance()
@ -43,13 +42,36 @@ static Poco::Logger * getLogger()
return &logger;
}
/** Holds a pointer to JIT-compiled code together with the descriptor of the module it belongs to.
  *
  * The destructor asks the JIT instance to delete that module, so every instance acts as the
  * owner of its module. Copying is disabled on purpose: an implicit copy would carry the same
  * module_info and the module would be deleted twice. Share an instance through
  * std::shared_ptr<CompiledFunction> instead of copying it.
  */
class CompiledFunction
{
public:
    /// compiled_function_ - raw pointer to the compiled code (obtained from the JIT lookup by the callers in this file).
    /// module_info_ - descriptor of the compiled module; it is handed back to the JIT on destruction.
    CompiledFunction(void * compiled_function_, CHJIT::CompiledModuleInfo module_info_)
        : compiled_function(compiled_function_)
        , module_info(std::move(module_info_))
    {}

    /// Non-copyable: the destructor releases the module, so a copy would release it a second time.
    CompiledFunction(const CompiledFunction &) = delete;
    CompiledFunction & operator=(const CompiledFunction &) = delete;

    /// Returns the untyped pointer to the compiled code; the caller casts it to the proper signature.
    void * getCompiledFunction() const { return compiled_function; }

    /// Releases the compiled module in the JIT instance.
    ~CompiledFunction()
    {
        getJITInstance().deleteCompiledModule(module_info);
    }

private:
    void * compiled_function;
    CHJIT::CompiledModuleInfo module_info;
};
class LLVMExecutableFunction : public IExecutableFunction
{
public:
explicit LLVMExecutableFunction(const std::string & name_, JITCompiledFunction function_)
explicit LLVMExecutableFunction(const std::string & name_, std::shared_ptr<CompiledFunction> compiled_function_)
: name(name_)
, function(function_)
, compiled_function(compiled_function_)
{
}
@ -81,7 +103,9 @@ public:
}
columns[arguments.size()] = getColumnData(result_column.get());
function(input_rows_count, columns.data());
JITCompiledFunction jit_compiled_function_typed = reinterpret_cast<JITCompiledFunction>(compiled_function->getCompiledFunction());
jit_compiled_function_typed(input_rows_count, columns.data());
#if defined(MEMORY_SANITIZER)
/// Memory sanitizer doesn't know about stores from JIT-ed code.
@ -111,7 +135,7 @@ public:
private:
std::string name;
JITCompiledFunction function = nullptr;
std::shared_ptr<CompiledFunction> compiled_function;
};
class LLVMFunction : public IFunctionBase
@ -131,17 +155,13 @@ public:
else if (node.type == CompileDAG::CompileType::INPUT)
argument_types.emplace_back(node.result_type);
}
module_info = compileFunction(getJITInstance(), *this);
}
~LLVMFunction() override
void setCompiledFunction(std::shared_ptr<CompiledFunction> compiled_function_)
{
getJITInstance().deleteCompiledModule(module_info);
compiled_function = compiled_function_;
}
size_t getCompiledSize() const { return module_info.size; }
bool isCompilable() const override { return true; }
llvm::Value * compile(llvm::IRBuilderBase & builder, Values values) const override
@ -157,13 +177,10 @@ public:
ExecutableFunctionPtr prepare(const ColumnsWithTypeAndName &) const override
{
void * function = getJITInstance().findCompiledFunction(module_info, name);
if (!compiled_function)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Compiled function was not initialized {}", name);
if (!function)
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Cannot find compiled function {}", name);
JITCompiledFunction function_typed = reinterpret_cast<JITCompiledFunction>(function);
return std::make_unique<LLVMExecutableFunction>(name, function_typed);
return std::make_unique<LLVMExecutableFunction>(name, compiled_function);
}
bool isDeterministic() const override
@ -252,7 +269,7 @@ private:
CompileDAG dag;
DataTypes argument_types;
std::vector<FunctionBasePtr> nested_functions;
CHJIT::CompiledModuleInfo module_info;
std::shared_ptr<CompiledFunction> compiled_function;
};
static FunctionBasePtr compile(
@ -271,43 +288,42 @@ static FunctionBasePtr compile(
LOG_TRACE(getLogger(), "Try to compile expression {}", dag.dump());
FunctionBasePtr fn;
auto llvm_function = std::make_shared<LLVMFunction>(dag);
if (auto * compilation_cache = CompiledExpressionCacheFactory::instance().tryGetCache())
{
auto [compiled_function, was_inserted] = compilation_cache->getOrSet(hash_key, [&dag] ()
auto [compiled_function_cache_entry, was_inserted] = compilation_cache->getOrSet(hash_key, [&] ()
{
auto llvm_function = std::make_unique<LLVMFunction>(dag);
size_t compiled_size = llvm_function->getCompiledSize();
CHJIT::CompiledModuleInfo compiled_module_info = compileFunction(getJITInstance(), *llvm_function);
auto * compiled_jit_function = getJITInstance().findCompiledFunction(compiled_module_info, llvm_function->getName());
auto compiled_function = std::make_shared<CompiledFunction>(compiled_jit_function, compiled_module_info);
CompiledFunction function
{
.function = std::move(llvm_function),
.compiled_size = compiled_size
};
return std::make_shared<CompiledFunction>(function);
return std::make_shared<CompiledFunctionCacheEntry>(std::move(compiled_function), compiled_module_info.size);
});
if (was_inserted)
LOG_TRACE(getLogger(),
"Put compiled expression {} in cache used cache size {} total cache size {}",
compiled_function->function->getName(),
llvm_function->getName(),
compilation_cache->weight(),
compilation_cache->maxSize());
else
LOG_TRACE(getLogger(), "Get compiled expression {} from cache", compiled_function->function->getName());
LOG_TRACE(getLogger(), "Get compiled expression {} from cache", llvm_function->getName());
fn = compiled_function->function;
llvm_function->setCompiledFunction(compiled_function_cache_entry->getCompiledFunction());
}
else
{
fn = std::make_unique<LLVMFunction>(dag);
CHJIT::CompiledModuleInfo compiled_module_info = compileFunction(getJITInstance(), *llvm_function);
auto * compiled_jit_function = getJITInstance().findCompiledFunction(compiled_module_info, llvm_function->getName());
auto compiled_function = std::make_shared<CompiledFunction>(compiled_jit_function, compiled_module_info);
llvm_function->setCompiledFunction(compiled_function);
}
LOG_TRACE(getLogger(), "Use compiled expression {}", fn->getName());
LOG_TRACE(getLogger(), "Use compiled expression {}", llvm_function->getName());
return fn;
return llvm_function;
}
static bool isCompilableConstant(const ActionsDAG::Node & node)

View File

@ -5,35 +5,47 @@
#endif
#if USE_EMBEDDED_COMPILER
# include <Functions/IFunction.h>
# include <Common/LRUCache.h>
# include <Common/HashTable/Hash.h>
namespace DB
{
struct CompiledFunction
class CompiledFunction;
/// Value stored in the compiled expression cache: a shared handle to a compiled function
/// plus the size that the cache weight function reports for this entry.
class CompiledFunctionCacheEntry
{
/// NOTE(review): the two members below look like leftovers of the previous struct layout
/// shown by this diff view - confirm against the actual header.
FunctionBasePtr function;
size_t compiled_size;
public:
/// Takes shared ownership of the compiled function and remembers its size.
CompiledFunctionCacheEntry(std::shared_ptr<CompiledFunction> compiled_function_, size_t compiled_function_size_)
: compiled_function(std::move(compiled_function_))
, compiled_function_size(compiled_function_size_)
{}
/// Returns a copy of the shared pointer, so the caller becomes a co-owner of the compiled function.
std::shared_ptr<CompiledFunction> getCompiledFunction() const { return compiled_function; }
/// Returns the size passed at construction.
size_t getCompiledFunctionSize() const { return compiled_function_size; }
private:
std::shared_ptr<CompiledFunction> compiled_function;
size_t compiled_function_size;
};
/// Weight functor passed to LRUCache: the weight of a cache entry is its compiled function size.
struct CompiledFunctionWeightFunction
{
size_t operator()(const CompiledFunction & compiled_function) const
size_t operator()(const CompiledFunctionCacheEntry & compiled_function) const
{
return compiled_function.compiled_size;
return compiled_function.getCompiledFunctionSize();
}
};
/** This child of LRUCache breaks one of its invariants: total weight may be changed after insertion.
* We have to do so, because we don't know the real memory consumption of generated LLVM code for every function.
*/
class CompiledExpressionCache : public LRUCache<UInt128, CompiledFunction, UInt128Hash, CompiledFunctionWeightFunction>
class CompiledExpressionCache : public LRUCache<UInt128, CompiledFunctionCacheEntry, UInt128Hash, CompiledFunctionWeightFunction>
{
public:
using Base = LRUCache<UInt128, CompiledFunction, UInt128Hash, CompiledFunctionWeightFunction>;
using Base = LRUCache<UInt128, CompiledFunctionCacheEntry, UInt128Hash, CompiledFunctionWeightFunction>;
/// Inherit the constructors of the underlying LRUCache.
using Base::Base;
};