Add statistics cmsketch

This commit is contained in:
JackyWoo 2024-06-21 16:29:25 +08:00
parent a829444ff4
commit 15e20f56fa
13 changed files with 290 additions and 30 deletions

View File

@ -9,6 +9,7 @@ set(DATASKETCHES_LIBRARY theta)
add_library(_datasketches INTERFACE)
target_include_directories(_datasketches SYSTEM BEFORE INTERFACE
"${ClickHouse_SOURCE_DIR}/contrib/datasketches-cpp/common/include"
"${ClickHouse_SOURCE_DIR}/contrib/datasketches-cpp/count/include"
"${ClickHouse_SOURCE_DIR}/contrib/datasketches-cpp/theta/include")
add_library(ch_contrib::datasketches ALIAS _datasketches)

View File

@ -25,7 +25,7 @@ Also, they are replicated, syncing statistics metadata via ZooKeeper.
Here is an example that adds several statistics types to two columns:
```
ALTER TABLE t1 MODIFY STATISTICS c, d TYPE TDigest, Uniq;
ALTER TABLE t1 MODIFY STATISTICS c, d TYPE TDigest, Uniq, CMSketch;
```
:::note

View File

@ -546,6 +546,7 @@ endif()
if (TARGET ch_contrib::datasketches)
target_link_libraries (clickhouse_aggregate_functions PRIVATE ch_contrib::datasketches)
dbms_target_link_libraries(PRIVATE ch_contrib::datasketches)
endif ()
target_link_libraries (clickhouse_common_io PRIVATE ch_contrib::lz4)

View File

@ -0,0 +1,82 @@
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Storages/Statistics/CMSketchStatistics.h>
#if USE_DATASKETCHES
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_STATISTICS;
}
/// Constructs the statistics object for one column: forwards the description to IStatistics,
/// builds the sketch with CMSKETCH_HASH_COUNT and CMSKETCH_BUCKET_COUNT, and keeps the column's data type.
CMSketchStatistics::CMSketchStatistics(const SingleStatisticsDescription & stat_, DataTypePtr data_type_)
    : IStatistics(stat_)
    , data(CMSKETCH_HASH_COUNT, CMSKETCH_BUCKET_COUNT)
    , data_type(data_type_)
{
}
/// Returns the sketch's estimate of how many inserted values equal `value`.
///
/// Int64 / UInt64 / Float64 fields are converted to Float64 and looked up by the raw bytes of
/// that Float64; String fields are looked up by their character data. This mirrors how
/// update() keys the values it inserts.
///
/// Throws ILLEGAL_STATISTICS for any other field type. The previous UNREACHABLE() was in fact
/// reachable: the constant comes straight from the query and may be of any Field type.
Float64 CMSketchStatistics::estimateEqual(const Field & value) const
{
    if (auto float_val = getFloat64(value))
        return data.get_estimate(&float_val.value(), sizeof(Float64));

    if (auto string_val = getString(value))
        return data.get_estimate(string_val->data(), string_val->size());

    throw Exception(
        ErrorCodes::ILLEGAL_STATISTICS,
        "Statistics of type 'cmsketch' does not support estimating a constant of type {}",
        value.getTypeName());
}
/// Writes the sketch to `buf` as a UInt64 byte count followed by the sketch's serialized bytes.
void CMSketchStatistics::serialize(WriteBuffer & buf)
{
    const auto sketch_bytes = data.serialize();
    const UInt64 byte_count = static_cast<UInt64>(sketch_bytes.size());

    writeIntBinary(byte_count, buf);
    buf.write(reinterpret_cast<const char *>(sketch_bytes.data()), sketch_bytes.size());
}
/// Reads a sketch written by serialize() (UInt64 byte count, then the bytes) and merges it
/// into this object's sketch.
void CMSketchStatistics::deserialize(ReadBuffer & buf)
{
    UInt64 size = 0;
    readIntBinary(size, buf);

    String s;
    /// resize(), not reserve(): reserve() only grows capacity and leaves size() == 0, so writing
    /// `size` bytes through data() would be out of the string's valid range (undefined behaviour,
    /// reported as container-overflow by AddressSanitizer).
    s.resize(size);
    buf.readStrict(s.data(), size); /// Extra copy can be avoided by implementing count_min_sketch<Float64>::deserialize with ReadBuffer

    auto read_sketch = datasketches::count_min_sketch<Float64>::deserialize(s.data(), size, datasketches::DEFAULT_SEED);
    data.merge(read_sketch);
}
void CMSketchStatistics::update(const ColumnPtr & column)
{
size_t size = column->size();
for (size_t i = 0; i < size; ++i)
{
Field f;
column->get(i, f);
if (f.isNull())
continue;
if (auto float_val = getFloat64(f))
data.update(&float_val.value(), 8, 1.0);
else if (auto string_val = getString(f))
data.update(*string_val, 1.0);
}
}
void CMSketchValidator(const SingleStatisticsDescription &, DataTypePtr data_type)
{
data_type = removeNullable(data_type);
data_type = removeLowCardinalityAndNullable(data_type);
if (!data_type->isValueRepresentedByNumber() && !isStringOrFixedString(data_type))
throw Exception(ErrorCodes::ILLEGAL_STATISTICS, "Statistics of type 'cmsketch' does not support type {}", data_type->getName());
}
/// Factory callback: creates a CMSketchStatistics for the given description and column type.
StatisticsPtr CMSketchCreator(const SingleStatisticsDescription & stat, DataTypePtr data_type)
{
    StatisticsPtr statistics = std::make_shared<CMSketchStatistics>(stat, data_type);
    return statistics;
}
}
#endif

View File

@ -0,0 +1,39 @@
#pragma once
#if USE_DATASKETCHES
#include <Storages/Statistics/Statistics.h>
#include <count_min.hpp>
#include <Common/Allocator.h>
namespace DB
{
/// CMSketchStatistics is used to estimate expressions like col = 'value' or col in ('v1', 'v2').
/// It keeps a datasketches count_min_sketch with Float64 weights for one column and
/// implements the IStatistics serialize / deserialize / update interface on top of it.
class CMSketchStatistics : public IStatistics
{
public:
/// `stat_` is forwarded to IStatistics; `data_type_` is stored as the column's data type.
explicit CMSketchStatistics(const SingleStatisticsDescription & stat_, DataTypePtr data_type_);
/// Returns the sketch's estimate for the number of inserted values equal to `value`.
Float64 estimateEqual(const Field & value) const;
/// Writes the sketch to `buf`.
void serialize(WriteBuffer & buf) override;
/// Reads a sketch from `buf`.
void deserialize(ReadBuffer & buf) override;
/// Feeds the values of `column` into the sketch.
void update(const ColumnPtr & column) override;
private:
/// Sketch dimensions passed to the count_min_sketch constructor (hash count, then bucket count).
static constexpr size_t CMSKETCH_HASH_COUNT = 8;
static constexpr size_t CMSKETCH_BUCKET_COUNT = 2048;
/// NOTE: `data` is declared before `data_type`; the constructor initializes them in this order.
datasketches::count_min_sketch<Float64> data;
DataTypePtr data_type;
};
StatisticsPtr CMSketchCreator(const SingleStatisticsDescription & stat, DataTypePtr);
void CMSketchValidator(const SingleStatisticsDescription &, DataTypePtr data_type);
}
#endif

View File

@ -35,11 +35,14 @@ Float64 ConditionSelectivityEstimator::ColumnSelectivityEstimator::estimateGreat
return rows - estimateLess(val, rows);
}
Float64 ConditionSelectivityEstimator::ColumnSelectivityEstimator::estimateEqual(Float64 val, Float64 rows) const
Float64 ConditionSelectivityEstimator::ColumnSelectivityEstimator::estimateEqual(Field val, Float64 rows) const
{
auto float_val = getFloat64(val);
if (part_statistics.empty())
{
if (val < - threshold || val > threshold)
if (!float_val)
return default_unknown_cond_factor * rows;
else if (float_val.value() < - threshold || float_val.value() > threshold)
return default_normal_cond_factor * rows;
else
return default_good_cond_factor * rows;
@ -87,7 +90,7 @@ static std::pair<String, Int32> tryToExtractSingleColumn(const RPNBuilderTreeNod
return result;
}
std::pair<String, Float64> ConditionSelectivityEstimator::extractBinaryOp(const RPNBuilderTreeNode & node, const String & column_name) const
std::pair<String, Field> ConditionSelectivityEstimator::extractBinaryOp(const RPNBuilderTreeNode & node, const String & column_name) const
{
if (!node.isFunction())
return {};
@ -123,18 +126,7 @@ std::pair<String, Float64> ConditionSelectivityEstimator::extractBinaryOp(const
DataTypePtr output_type;
if (!constant_node->tryGetConstant(output_value, output_type))
return {};
const auto type = output_value.getType();
Float64 value;
if (type == Field::Types::Int64)
value = output_value.get<Int64>();
else if (type == Field::Types::UInt64)
value = output_value.get<UInt64>();
else if (type == Field::Types::Float64)
value = output_value.get<Float64>();
else
return {};
return std::make_pair(function_name, value);
return std::make_pair(function_name, output_value);
}
Float64 ConditionSelectivityEstimator::estimateRowCount(const RPNBuilderTreeNode & node) const
@ -142,7 +134,7 @@ Float64 ConditionSelectivityEstimator::estimateRowCount(const RPNBuilderTreeNode
auto result = tryToExtractSingleColumn(node);
if (result.second != 1)
{
return default_unknown_cond_factor;
return default_unknown_cond_factor * total_rows;
}
String col = result.first;
auto it = column_estimators.find(col);
@ -152,19 +144,16 @@ Float64 ConditionSelectivityEstimator::estimateRowCount(const RPNBuilderTreeNode
bool dummy = total_rows == 0;
ColumnSelectivityEstimator estimator;
if (it != column_estimators.end())
{
estimator = it->second;
}
else
{
dummy = true;
}
auto [op, val] = extractBinaryOp(node, col);
auto float_val = getFloat64(val);
if (op == "equals")
{
if (dummy)
{
if (val < - threshold || val > threshold)
if (!float_val || (float_val < - threshold || float_val > threshold))
return default_normal_cond_factor * total_rows;
else
return default_good_cond_factor * total_rows;
@ -175,13 +164,13 @@ Float64 ConditionSelectivityEstimator::estimateRowCount(const RPNBuilderTreeNode
{
if (dummy)
return default_normal_cond_factor * total_rows;
return estimator.estimateLess(val, total_rows);
return estimator.estimateLess(float_val.value(), total_rows);
}
else if (op == "greater" || op == "greaterOrEquals")
{
if (dummy)
return default_normal_cond_factor * total_rows;
return estimator.estimateGreater(val, total_rows);
return estimator.estimateGreater(float_val.value(), total_rows);
}
else
return default_unknown_cond_factor * total_rows;

View File

@ -1,6 +1,7 @@
#pragma once
#include <Storages/Statistics/Statistics.h>
#include <Core/Field.h>
namespace DB
{
@ -24,7 +25,7 @@ private:
Float64 estimateGreater(Float64 val, Float64 rows) const;
Float64 estimateEqual(Float64 val, Float64 rows) const;
Float64 estimateEqual(Field val, Float64 rows) const;
};
static constexpr auto default_good_cond_factor = 0.1;
@ -37,7 +38,7 @@ private:
UInt64 total_rows = 0;
std::set<String> part_names;
std::map<String, ColumnSelectivityEstimator> column_estimators;
std::pair<String, Float64> extractBinaryOp(const RPNBuilderTreeNode & node, const String & column_name) const;
std::pair<String, Field> extractBinaryOp(const RPNBuilderTreeNode & node, const String & column_name) const;
public:
/// TODO: Support the condition consists of CNF/DNF like (cond1 and cond2) or (cond3) ...

View File

@ -5,6 +5,7 @@
#include <Storages/Statistics/ConditionSelectivityEstimator.h>
#include <Storages/Statistics/TDigestStatistics.h>
#include <Storages/Statistics/UniqStatistics.h>
#include <Storages/Statistics/CMSketchStatistics.h>
#include <Storages/StatisticsDescription.h>
#include <Storages/ColumnsDescription.h>
#include <IO/ReadHelpers.h>
@ -26,6 +27,28 @@ enum StatisticsFileVersion : UInt16
V0 = 0,
};
std::optional<Float64> getFloat64(const Field & f)
{
const auto type = f.getType();
Float64 value;
if (type == Field::Types::Int64)
value = f.get<Int64>();
else if (type == Field::Types::UInt64)
value = f.get<UInt64>();
else if (type == Field::Types::Float64)
value = f.get<Float64>();
else
return {};
return value;
}
std::optional<String> getString(const Field & f)
{
if (f.getType() == Field::Types::String)
return f.get<String>();
return {};
}
IStatistics::IStatistics(const SingleStatisticsDescription & stat_) : stat(stat_) {}
ColumnStatistics::ColumnStatistics(const ColumnStatisticsDescription & stats_desc_)
@ -54,9 +77,10 @@ Float64 ColumnStatistics::estimateGreater(Float64 val) const
return rows - estimateLess(val);
}
Float64 ColumnStatistics::estimateEqual(Float64 val) const
Float64 ColumnStatistics::estimateEqual(Field val) const
{
if (stats.contains(StatisticsType::Uniq) && stats.contains(StatisticsType::TDigest))
auto float_val = getFloat64(val);
if (float_val && stats.contains(StatisticsType::Uniq) && stats.contains(StatisticsType::TDigest))
{
auto uniq_static = std::static_pointer_cast<UniqStatistics>(stats.at(StatisticsType::Uniq));
/// 2048 is the default number of buckets in TDigest. In this case, TDigest stores exactly one value (with many rows)
@ -64,9 +88,16 @@ Float64 ColumnStatistics::estimateEqual(Float64 val) const
if (uniq_static->getCardinality() < 2048)
{
auto tdigest_static = std::static_pointer_cast<TDigestStatistics>(stats.at(StatisticsType::TDigest));
return tdigest_static->estimateEqual(val);
return tdigest_static->estimateEqual(float_val.value());
}
}
#if USE_DATASKETCHES
if (stats.contains(StatisticsType::CMSketch))
{
auto cmsketch_static = std::static_pointer_cast<CMSketchStatistics>(stats.at(StatisticsType::CMSketch));
return cmsketch_static->estimateEqual(val);
}
#endif
if (val < - ConditionSelectivityEstimator::threshold || val > ConditionSelectivityEstimator::threshold)
return rows * ConditionSelectivityEstimator::default_normal_cond_factor;
else
@ -145,6 +176,10 @@ MergeTreeStatisticsFactory::MergeTreeStatisticsFactory()
registerCreator(StatisticsType::Uniq, UniqCreator);
registerValidator(StatisticsType::TDigest, TDigestValidator);
registerValidator(StatisticsType::Uniq, UniqValidator);
#if USE_DATASKETCHES
registerCreator(StatisticsType::CMSketch, CMSketchCreator);
registerValidator(StatisticsType::CMSketch, CMSketchValidator);
#endif
}
MergeTreeStatisticsFactory & MergeTreeStatisticsFactory::instance()

View File

@ -7,6 +7,7 @@
#include <Common/logger_useful.h>
#include <IO/ReadBuffer.h>
#include <IO/WriteBuffer.h>
#include <Core/Field.h>
#include <Storages/StatisticsDescription.h>
@ -58,7 +59,7 @@ public:
Float64 estimateGreater(Float64 val) const;
Float64 estimateEqual(Float64 val) const;
Float64 estimateEqual(Field val) const;
private:
@ -100,4 +101,6 @@ private:
Validators validators;
};
std::optional<Float64> getFloat64(const Field & f);
std::optional<String> getString(const Field & f);
}

View File

@ -54,6 +54,8 @@ static StatisticsType stringToStatisticsType(String type)
return StatisticsType::TDigest;
if (type == "uniq")
return StatisticsType::Uniq;
if (type == "cmsketch")
return StatisticsType::CMSketch;
throw Exception(ErrorCodes::INCORRECT_QUERY, "Unknown statistics type: {}. Supported statistics types are `tdigest` and `uniq`.", type);
}
@ -65,6 +67,8 @@ String SingleStatisticsDescription::getTypeName() const
return "TDigest";
case StatisticsType::Uniq:
return "Uniq";
case StatisticsType::CMSketch:
return "CMSketch";
default:
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unknown statistics type: {}. Supported statistics types are `tdigest` and `uniq`.", type);
}

View File

@ -13,6 +13,7 @@ enum class StatisticsType : UInt8
{
TDigest = 0,
Uniq = 1,
CMSketch = 2,
Max = 63,
};

View File

@ -0,0 +1,26 @@
CREATE TABLE default.t1\n(\n `a` String STATISTICS(cmsketch),\n `b` Int64 STATISTICS(cmsketch),\n `c` UInt64 STATISTICS(cmsketch),\n `pk` String\n)\nENGINE = MergeTree\nORDER BY pk\nSETTINGS min_bytes_for_wide_part = 0, index_granularity = 8192
Prewhere info
Prewhere filter
Prewhere filter column: and(equals(a, \'0\'_String), equals(b, 0), equals(c, 0)) (removed)
Prewhere info
Prewhere filter
Prewhere filter column: and(equals(a, \'0\'_String), equals(c, 0), greater(b, 0)) (removed)
After drop statistics for a
Prewhere info
Prewhere filter
Prewhere filter column: and(equals(b, 0), equals(c, 0), equals(a, \'0\'_String)) (removed)
Prewhere info
Prewhere filter
Prewhere filter column: and(equals(c, 0), equals(a, \'0\'_String), greater(b, 0)) (removed)
LowCardinality
Prewhere info
Prewhere filter
Prewhere filter column: and(equals(a, \'0\'_String), equals(b, 0), equals(c, 0)) (removed)
Nullable
Prewhere info
Prewhere filter
Prewhere filter column: and(equals(a, \'0\'_String), equals(b, 0), equals(c, 0)) (removed)
LowCardinality(Nullable)
Prewhere info
Prewhere filter
Prewhere filter column: and(equals(a, \'0\'_String), equals(b, 0), equals(c, 0)) (removed)

View File

@ -0,0 +1,78 @@
-- Tags: no-fasttest
-- Checks the order of PREWHERE conditions shown by EXPLAIN when columns carry 'cmsketch' statistics.
-- Each table gets 10000 rows: a = toString(number % 1000), b = number % 100, c = number % 10.
DROP TABLE IF EXISTS t1;
SET allow_experimental_statistics = 1;
SET allow_statistics_optimize = 1;
-- Case 1: plain String / Int64 / UInt64 columns with cmsketch statistics.
CREATE TABLE t1
(
a String STATISTICS(cmsketch),
b Int64 STATISTICS(cmsketch),
c UInt64 STATISTICS(cmsketch),
pk String,
) Engine = MergeTree() ORDER BY pk
SETTINGS min_bytes_for_wide_part = 0;
SHOW CREATE TABLE t1;
INSERT INTO t1 select toString(number % 1000), number % 100, number % 10, generateUUIDv4() FROM system.numbers LIMIT 10000;
-- The conditions are written as c, b, a; the EXPLAIN output shows the order actually chosen.
-- replaceRegexpAll strips the table alias and literal type suffixes from the plan text.
SELECT replaceRegexpAll(explain, '__table1.|_UInt8|_Int8', '') FROM (EXPLAIN actions=1 SELECT count(*) FROM t1 WHERE c = 0 and b = 0 and a = '0') WHERE explain LIKE '%Prewhere%' OR explain LIKE '%Filter column%';
SELECT replaceRegexpAll(explain, '__table1.|_UInt8|_Int8', '') FROM (EXPLAIN actions=1 SELECT count(*) FROM t1 WHERE c = 0 and b > 0 and a = '0') WHERE explain LIKE '%Prewhere%' OR explain LIKE '%Filter column%';
-- Case 2: same queries after the statistics on column a are dropped.
ALTER TABLE t1 DROP STATISTICS a;
SELECT 'After drop statistics for a';
SELECT replaceRegexpAll(explain, '__table1.|_UInt8|_Int8', '') FROM (EXPLAIN actions=1 SELECT count(*) FROM t1 WHERE c = 0 and b = 0 and a = '0') WHERE explain LIKE '%Prewhere%' OR explain LIKE '%Filter column%';
SELECT replaceRegexpAll(explain, '__table1.|_UInt8|_Int8', '') FROM (EXPLAIN actions=1 SELECT count(*) FROM t1 WHERE c = 0 and b > 0 and a = '0') WHERE explain LIKE '%Prewhere%' OR explain LIKE '%Filter column%';
DROP TABLE IF EXISTS t1;
-- Case 3: column a is LowCardinality(String).
DROP TABLE IF EXISTS t2;
SET allow_suspicious_low_cardinality_types=1;
CREATE TABLE t2
(
a LowCardinality(String) STATISTICS(cmsketch),
b Int64 STATISTICS(cmsketch),
c UInt64 STATISTICS(cmsketch),
pk String,
) Engine = MergeTree() ORDER BY pk
SETTINGS min_bytes_for_wide_part = 0;
INSERT INTO t2 select toString(number % 1000), number % 100, number % 10, generateUUIDv4() FROM system.numbers LIMIT 10000;
SELECT 'LowCardinality';
SELECT replaceRegexpAll(explain, '__table1.|_UInt8|_Int8', '') FROM (EXPLAIN actions=1 SELECT count(*) FROM t2 WHERE c = 0 and b = 0 and a = '0') WHERE explain LIKE '%Prewhere%' OR explain LIKE '%Filter column%';
DROP TABLE IF EXISTS t2;
-- Case 4: column a is Nullable(String).
DROP TABLE IF EXISTS t3;
CREATE TABLE t3
(
a Nullable(String) STATISTICS(cmsketch),
b Int64 STATISTICS(cmsketch),
c UInt64 STATISTICS(cmsketch),
pk String,
) Engine = MergeTree() ORDER BY pk
SETTINGS min_bytes_for_wide_part = 0;
INSERT INTO t3 select toString(number % 1000), number % 100, number % 10, generateUUIDv4() FROM system.numbers LIMIT 10000;
SELECT 'Nullable';
SELECT replaceRegexpAll(explain, '__table1.|_UInt8|_Int8', '') FROM (EXPLAIN actions=1 SELECT count(*) FROM t3 WHERE c = 0 and b = 0 and a = '0') WHERE explain LIKE '%Prewhere%' OR explain LIKE '%Filter column%';
DROP TABLE IF EXISTS t3;
-- Case 5: column a is LowCardinality(Nullable(String)).
DROP TABLE IF EXISTS t4;
CREATE TABLE t4
(
a LowCardinality(Nullable(String)) STATISTICS(cmsketch),
b Int64 STATISTICS(cmsketch),
c UInt64 STATISTICS(cmsketch),
pk String,
) Engine = MergeTree() ORDER BY pk
SETTINGS min_bytes_for_wide_part = 0;
INSERT INTO t4 select toString(number % 1000), number % 100, number % 10, generateUUIDv4() FROM system.numbers LIMIT 10000;
SELECT 'LowCardinality(Nullable)';
SELECT replaceRegexpAll(explain, '__table1.|_UInt8|_Int8', '') FROM (EXPLAIN actions=1 SELECT count(*) FROM t4 WHERE c = 0 and b = 0 and a = '0') WHERE explain LIKE '%Prewhere%' OR explain LIKE '%Filter column%';
DROP TABLE IF EXISTS t4;