From 844eb4ccdc87df7df95ce27229599e1227fcf609 Mon Sep 17 00:00:00 2001 From: Maksim Kita Date: Thu, 20 Jan 2022 11:16:18 +0000 Subject: [PATCH 1/2] RangeHashedDictionary handle invalid intervals --- src/Common/IntervalTree.h | 36 +++++++++------ src/Common/tests/gtest_interval_tree.cpp | 46 +++++++++++++++++++ src/Dictionaries/RangeHashedDictionary.cpp | 4 +- ...shed_dictionary_invalid_interval.reference | 5 ++ ...nge_hashed_dictionary_invalid_interval.sql | 36 +++++++++++++++ 5 files changed, 112 insertions(+), 15 deletions(-) create mode 100644 tests/queries/0_stateless/02179_range_hashed_dictionary_invalid_interval.reference create mode 100644 tests/queries/0_stateless/02179_range_hashed_dictionary_invalid_interval.sql diff --git a/src/Common/IntervalTree.h b/src/Common/IntervalTree.h index fd2fec528a4..dc2987247d8 100644 --- a/src/Common/IntervalTree.h +++ b/src/Common/IntervalTree.h @@ -10,6 +10,7 @@ namespace DB { /** Structure that holds closed interval with left and right. + * Interval left must be less than interval right. * Example: [1, 1] is valid interval, that contain point 1. */ template @@ -70,6 +71,9 @@ struct IntervalTreeVoidValue * Search for all intervals intersecting point has complexity O(log(n) + k), k is count of intervals that intersect point. * If we need to only check if there are some interval intersecting point such operation has complexity O(log(n)). * + * There is invariant that interval left must be less than interval right, otherwise such interval could not contain any point. + * If that invariant is broken, inserting such interval in IntervalTree will return false. + * * Explanation: * * IntervalTree structure is balanced tree. Each node contains: @@ -125,44 +129,48 @@ public: IntervalTree() { nodes.resize(1); } template , bool> = true> - void emplace(Interval interval) + ALWAYS_INLINE bool emplace(Interval interval) { assert(!tree_is_built); + if (unlikely(interval.left > interval.right)) + return false; + sorted_intervals.emplace_back(interval); increaseIntervalsSize(); + + return true; } template , bool> = true, typename... Args> - void emplace(Interval interval, Args &&... args) + ALWAYS_INLINE bool emplace(Interval interval, Args &&... args) { assert(!tree_is_built); + if (unlikely(interval.left > interval.right)) + return false; + sorted_intervals.emplace_back( std::piecewise_construct, std::forward_as_tuple(interval), std::forward_as_tuple(std::forward(args)...)); increaseIntervalsSize(); + + return true; } template , bool> = true> - void insert(Interval interval) + bool insert(Interval interval) { - assert(!tree_is_built); - sorted_intervals.emplace_back(interval); - increaseIntervalsSize(); + return emplace(interval); } template , bool> = true> - void insert(Interval interval, const Value & value) + bool insert(Interval interval, const Value & value) { - assert(!tree_is_built); - sorted_intervals.emplace_back(interval, value); - increaseIntervalsSize(); + return emplace(interval, value); } template , bool> = true> - void insert(Interval interval, Value && value) + bool insert(Interval interval, Value && value) { - assert(!tree_is_built); - sorted_intervals.emplace_back(interval, std::move(value)); - increaseIntervalsSize(); + return emplace(interval, std::move(value)); } /// Build tree, after that intervals cannot be inserted, and only search or iteration can be performed. diff --git a/src/Common/tests/gtest_interval_tree.cpp b/src/Common/tests/gtest_interval_tree.cpp index d9f19841b66..e99bfe83a98 100644 --- a/src/Common/tests/gtest_interval_tree.cpp +++ b/src/Common/tests/gtest_interval_tree.cpp @@ -309,6 +309,29 @@ TEST(IntervalTree, IntervalSetIterators) } } +TEST(IntervalTree, IntervalSetInvalidInterval) +{ + IntervalSet interval_set; + ASSERT_TRUE(!interval_set.insert(Int64Interval(10, 0))); + ASSERT_TRUE(!interval_set.insert(Int64Interval(15, 10))); + ASSERT_TRUE(interval_set.insert(Int64Interval(20, 25))); + + std::set expected; + expected.insert({20, 25}); + + auto actual = intervalSetFindIntervals(interval_set, 20); + + ASSERT_TRUE(actual == expected); + ASSERT_TRUE(interval_set.has(20)); + + interval_set.build(); + + actual = intervalSetFindIntervals(interval_set, 20); + + ASSERT_TRUE(actual == expected); + ASSERT_TRUE(interval_set.has(20)); +} + TEST(IntervalTree, IntervalMapBasic) { for (size_t intervals_size = 0; intervals_size < 120; ++intervals_size) @@ -538,3 +561,26 @@ TEST(IntervalTree, IntervalMapIterators) } } } + +TEST(IntervalTree, IntervalMapInvalidInterval) +{ + IntervalMap interval_map; + ASSERT_TRUE(!interval_map.insert(Int64Interval(10, 0), "Value")); + ASSERT_TRUE(!interval_map.insert(Int64Interval(15, 10), "Value")); + ASSERT_TRUE(interval_map.insert(Int64Interval(20, 25), "Value")); + + std::map expected; + expected.emplace(Int64Interval{20, 25}, "Value"); + + auto actual = intervalMapFindIntervals(interval_map, 20); + + ASSERT_TRUE(actual == expected); + ASSERT_TRUE(interval_map.has(20)); + + interval_map.build(); + + actual = intervalMapFindIntervals(interval_map, 20); + + ASSERT_TRUE(actual == expected); + ASSERT_TRUE(interval_map.has(20)); +} diff --git a/src/Dictionaries/RangeHashedDictionary.cpp b/src/Dictionaries/RangeHashedDictionary.cpp index 2d98583d4a3..657ef3706a1 100644 --- a/src/Dictionaries/RangeHashedDictionary.cpp +++ b/src/Dictionaries/RangeHashedDictionary.cpp @@ -537,7 +537,9 @@ void RangeHashedDictionary::blockToAttributes(const Block & if constexpr (std::is_same_v) key = copyStringInArena(string_arena, key); - setAttributeValue(attribute, key, RangeInterval{lower_bound, upper_bound}, attribute_column[key_index]); + if (likely(lower_bound < upper_bound)) + setAttributeValue(attribute, key, RangeInterval{lower_bound, upper_bound}, attribute_column[key_index]); + keys_extractor.rollbackCurrentKey(); } diff --git a/tests/queries/0_stateless/02179_range_hashed_dictionary_invalid_interval.reference b/tests/queries/0_stateless/02179_range_hashed_dictionary_invalid_interval.reference new file mode 100644 index 00000000000..d7753418087 --- /dev/null +++ b/tests/queries/0_stateless/02179_range_hashed_dictionary_invalid_interval.reference @@ -0,0 +1,5 @@ +Value +DefaultValue +1 +0 +0 15 20 Value diff --git a/tests/queries/0_stateless/02179_range_hashed_dictionary_invalid_interval.sql b/tests/queries/0_stateless/02179_range_hashed_dictionary_invalid_interval.sql new file mode 100644 index 00000000000..b68ee475273 --- /dev/null +++ b/tests/queries/0_stateless/02179_range_hashed_dictionary_invalid_interval.sql @@ -0,0 +1,36 @@ +DROP TABLE IF EXISTS 02179_test_table; +CREATE TABLE 02179_test_table +( + id UInt64, + value String, + start Int64, + end Int64 +) Engine = TinyLog; + +INSERT INTO 02179_test_table VALUES (0, 'Value', 10, 0); +INSERT INTO 02179_test_table VALUES (0, 'Value', 15, 10); +INSERT INTO 02179_test_table VALUES (0, 'Value', 15, 20); + +DROP DICTIONARY IF EXISTS 02179_test_dictionary; +CREATE DICTIONARY 02179_test_dictionary +( + id UInt64, + value String DEFAULT 'DefaultValue', + start Int64, + end Int64 +) PRIMARY KEY id +LAYOUT(RANGE_HASHED()) +SOURCE(CLICKHOUSE(TABLE '02179_test_table')) +RANGE(MIN start MAX end) +LIFETIME(0); + +SELECT dictGet('02179_test_dictionary', 'value', 0, 15); +SELECT dictGet('02179_test_dictionary', 'value', 0, 5); + +SELECT dictHas('02179_test_dictionary', 0, 15); +SELECT dictHas('02179_test_dictionary', 0, 5); + +SELECT * FROM 02179_test_dictionary; + +DROP DICTIONARY 02179_test_dictionary; +DROP TABLE 02179_test_table; From 97605b7c9caf1bb79189c33f5b1076cd6f35385e Mon Sep 17 00:00:00 2001 From: Maksim Kita Date: Thu, 20 Jan 2022 13:36:12 +0000 Subject: [PATCH 2/2] Fixed tests --- src/Dictionaries/RangeHashedDictionary.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Dictionaries/RangeHashedDictionary.cpp b/src/Dictionaries/RangeHashedDictionary.cpp index 657ef3706a1..d970498d97a 100644 --- a/src/Dictionaries/RangeHashedDictionary.cpp +++ b/src/Dictionaries/RangeHashedDictionary.cpp @@ -537,7 +537,7 @@ void RangeHashedDictionary::blockToAttributes(const Block & if constexpr (std::is_same_v) key = copyStringInArena(string_arena, key); - if (likely(lower_bound < upper_bound)) + if (likely(lower_bound <= upper_bound)) setAttributeValue(attribute, key, RangeInterval{lower_bound, upper_bound}, attribute_column[key_index]); keys_extractor.rollbackCurrentKey();