From 91c7ae83a30422c33ec814f7d340b2d854c90be6 Mon Sep 17 00:00:00 2001 From: chertus Date: Fri, 11 Oct 2019 20:56:26 +0300 Subject: [PATCH] asof join inequalities --- dbms/src/Interpreters/AnalyzedJoin.h | 5 ++ .../Interpreters/CollectJoinOnKeysVisitor.cpp | 23 +++++---- .../Interpreters/CollectJoinOnKeysVisitor.h | 8 ++- dbms/src/Interpreters/Join.cpp | 3 +- dbms/src/Interpreters/Join.h | 2 + dbms/src/Interpreters/RowRefs.cpp | 19 +++---- dbms/src/Interpreters/RowRefs.h | 51 +++++++++---------- dbms/src/Interpreters/asof.h | 46 +++++++++++++++++ .../0_stateless/00976_asof_join_on.reference | 22 ++++++++ .../0_stateless/00976_asof_join_on.sql | 14 +++-- 10 files changed, 141 insertions(+), 52 deletions(-) create mode 100644 dbms/src/Interpreters/asof.h diff --git a/dbms/src/Interpreters/AnalyzedJoin.h b/dbms/src/Interpreters/AnalyzedJoin.h index 9629547328d..991e5d2f395 100644 --- a/dbms/src/Interpreters/AnalyzedJoin.h +++ b/dbms/src/Interpreters/AnalyzedJoin.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -48,6 +49,7 @@ class AnalyzedJoin ASTs key_asts_left; ASTs key_asts_right; ASTTableJoin table_join; + ASOF::Inequality asof_inequality = ASOF::Inequality::GreaterOrEquals; /// All columns which can be read from joined table. Duplicating names are qualified. NamesAndTypesList columns_from_joined_table; @@ -100,6 +102,9 @@ public: void addJoinedColumn(const NameAndTypePair & joined_column); void addJoinedColumnsAndCorrectNullability(Block & sample_block) const; + void setAsofInequality(ASOF::Inequality inequality) { asof_inequality = inequality; } + ASOF::Inequality getAsofInequality() { return asof_inequality; } + ASTPtr leftKeysList() const; ASTPtr rightKeysList() const; /// For ON syntax only diff --git a/dbms/src/Interpreters/CollectJoinOnKeysVisitor.cpp b/dbms/src/Interpreters/CollectJoinOnKeysVisitor.cpp index 68e04b45d99..83d44629537 100644 --- a/dbms/src/Interpreters/CollectJoinOnKeysVisitor.cpp +++ b/dbms/src/Interpreters/CollectJoinOnKeysVisitor.cpp @@ -32,16 +32,24 @@ void CollectJoinOnKeysMatcher::Data::addJoinKeys(const ASTPtr & left_ast, const } void CollectJoinOnKeysMatcher::Data::addAsofJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, - const std::pair & table_no) + const std::pair & table_no, const ASOF::Inequality & inequality) { if (table_no.first == 1 || table_no.second == 2) { asof_left_key = left_ast->clone(); asof_right_key = right_ast->clone(); + analyzed_join.setAsofInequality(inequality); + return; + } + else if (table_no.first == 2 || table_no.second == 1) + { + asof_left_key = right_ast->clone(); + asof_right_key = left_ast->clone(); + analyzed_join.setAsofInequality(ASOF::reverseInequality(inequality)); return; } - throw Exception("ASOF JOIN for (left_table.x <= right_table.x) is not implemented", ErrorCodes::NOT_IMPLEMENTED); + throw Exception("ASOF JOIN requires keys inequality from different tables", ErrorCodes::NOT_IMPLEMENTED); } void CollectJoinOnKeysMatcher::Data::asofToJoinKeys() @@ -66,10 +74,9 @@ void CollectJoinOnKeysMatcher::visit(const ASTFunction & func, const ASTPtr & as return; } - bool less_or_equals = (func.name == "lessOrEquals"); - bool greater_or_equals = (func.name == "greaterOrEquals"); + ASOF::Inequality inequality = ASOF::getInequality(func.name); - if (data.is_asof && (less_or_equals || greater_or_equals)) + if (data.is_asof && (inequality != ASOF::Inequality::None)) { if (data.asof_left_key || data.asof_right_key) throwSyntaxException("ASOF JOIN expects exactly one inequality in ON section, unexpected " + queryToString(ast) + "."); @@ -78,11 +85,7 @@ void CollectJoinOnKeysMatcher::visit(const ASTFunction & func, const ASTPtr & as ASTPtr right = func.arguments->children.at(1); auto table_numbers = getTableNumbers(ast, left, right, data); - if (greater_or_equals) - data.addAsofJoinKeys(left, right, table_numbers); - else - data.addAsofJoinKeys(right, left, std::make_pair(table_numbers.second, table_numbers.first)); - + data.addAsofJoinKeys(left, right, table_numbers, inequality); return; } diff --git a/dbms/src/Interpreters/CollectJoinOnKeysVisitor.h b/dbms/src/Interpreters/CollectJoinOnKeysVisitor.h index 4d085dfcc31..0b4cb1fe857 100644 --- a/dbms/src/Interpreters/CollectJoinOnKeysVisitor.h +++ b/dbms/src/Interpreters/CollectJoinOnKeysVisitor.h @@ -12,6 +12,11 @@ namespace DB class ASTIdentifier; class AnalyzedJoin; +namespace ASOF +{ + enum class Inequality; +} + class CollectJoinOnKeysMatcher { public: @@ -29,7 +34,8 @@ public: bool has_some{false}; void addJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, const std::pair & table_no); - void addAsofJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, const std::pair & table_no); + void addAsofJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, const std::pair & table_no, + const ASOF::Inequality & asof_inequality); void asofToJoinKeys(); }; diff --git a/dbms/src/Interpreters/Join.cpp b/dbms/src/Interpreters/Join.cpp index d5381e1dc6d..18d9e0fdab9 100644 --- a/dbms/src/Interpreters/Join.cpp +++ b/dbms/src/Interpreters/Join.cpp @@ -70,6 +70,7 @@ Join::Join(std::shared_ptr table_join_, const Block & right_sample , nullable_right_side(table_join->forceNullableRight()) , nullable_left_side(table_join->forceNullableLeft()) , any_take_last_row(any_take_last_row_) + , asof_inequality(table_join->getAsofInequality()) , log(&Logger::get("Join")) { setSampleBlock(right_sample_block); @@ -635,7 +636,7 @@ std::unique_ptr NO_INLINE joinRightIndexedColumns( if constexpr (STRICTNESS == ASTTableJoin::Strictness::Asof) { - if (const RowRef * found = mapped.findAsof(join.getAsofType(), asof_column, i)) + if (const RowRef * found = mapped.findAsof(join.getAsofType(), join.getAsofInequality(), asof_column, i)) { filter[i] = 1; mapped.setUsed(); diff --git a/dbms/src/Interpreters/Join.h b/dbms/src/Interpreters/Join.h index 424512266fb..403621ccd75 100644 --- a/dbms/src/Interpreters/Join.h +++ b/dbms/src/Interpreters/Join.h @@ -166,6 +166,7 @@ public: ASTTableJoin::Kind getKind() const { return kind; } ASTTableJoin::Strictness getStrictness() const { return strictness; } AsofRowRefs::Type getAsofType() const { return *asof_type; } + ASOF::Inequality getAsofInequality() const { return asof_inequality; } bool anyTakeLastRow() const { return any_take_last_row; } /// Different types of keys for maps. @@ -305,6 +306,7 @@ private: Type type = Type::EMPTY; std::optional asof_type; + ASOF::Inequality asof_inequality; static Type chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes); diff --git a/dbms/src/Interpreters/RowRefs.cpp b/dbms/src/Interpreters/RowRefs.cpp index 2ac61af7d9f..949bdd33096 100644 --- a/dbms/src/Interpreters/RowRefs.cpp +++ b/dbms/src/Interpreters/RowRefs.cpp @@ -58,26 +58,27 @@ void AsofRowRefs::insert(Type type, const IColumn * asof_column, const Block * b callWithType(type, call); } -const RowRef * AsofRowRefs::findAsof(Type type, const IColumn * asof_column, size_t row_num) const +const RowRef * AsofRowRefs::findAsof(Type type, ASOF::Inequality inequality, const IColumn * asof_column, size_t row_num) const { const RowRef * out = nullptr; + bool ascending = (inequality == ASOF::Inequality::Less) || (inequality == ASOF::Inequality::LessOrEquals); + bool is_strict = (inequality == ASOF::Inequality::Less) || (inequality == ASOF::Inequality::Greater); + auto call = [&](const auto & t) { using T = std::decay_t; - using LookupPtr = typename Entry::LookupPtr; + using EntryType = Entry; + using LookupPtr = typename EntryType::LookupPtr; auto * column = typeid_cast *>(asof_column); T key = column->getElement(row_num); auto & typed_lookup = std::get(lookups); - // The first thread that calls upper_bound ensures that the data is sorted - auto it = typed_lookup->upper_bound(Entry(key)); - - // cbegin() is safe to call now because the array is immutable after sorting - // hence the pointer to a entry can be returned - if (it != typed_lookup->cbegin()) - out = &((--it)->row_ref); + if (is_strict) + out = typed_lookup->upperBound(EntryType(key), ascending); + else + out = typed_lookup->lowerBound(EntryType(key), ascending); }; callWithType(type, call); diff --git a/dbms/src/Interpreters/RowRefs.h b/dbms/src/Interpreters/RowRefs.h index 03309831322..ea0a101f370 100644 --- a/dbms/src/Interpreters/RowRefs.h +++ b/dbms/src/Interpreters/RowRefs.h @@ -1,8 +1,8 @@ #pragma once #include -#include #include +#include #include #include @@ -144,34 +144,44 @@ public: array.push_back(std::forward(x), std::forward(allocator_params)...); } - // Transition into second stage, ensures that the vector is sorted - typename Base::const_iterator upper_bound(const TEntry & k) + const RowRef * upperBound(const TEntry & k, bool ascending) { - sort(); - return std::upper_bound(array.cbegin(), array.cend(), k); + sort(ascending); + auto it = std::upper_bound(array.cbegin(), array.cend(), k, (ascending ? less : less)); + if (it != array.cend()) + return &(it->row_ref); + return nullptr; } - // After ensuring that the vector is sorted by calling a lookup these are safe to call - typename Base::const_iterator cbegin() const { return array.cbegin(); } - typename Base::const_iterator cend() const { return array.cend(); } + const RowRef * lowerBound(const TEntry & k, bool ascending) + { + sort(ascending); + auto it = std::lower_bound(array.cbegin(), array.cend(), k, (ascending ? less : less)); + if (it != array.cend()) + return &(it->row_ref); + return nullptr; + } private: std::atomic sorted = false; Base array; mutable std::mutex lock; - struct RadixSortTraits : RadixSortNumTraits + template + static bool less(const TEntry & a, const TEntry & b) { - using Element = TEntry; - static TKey & extractKey(Element & elem) { return elem.asof_value; } - }; + if constexpr (ascending) + return a.asof_value < b.asof_value; + else + return a.asof_value > b.asof_value; + } // Double checked locking with SC atomics works in C++ // https://preshing.com/20130930/double-checked-locking-is-fixed-in-cpp11/ // The first thread that calls one of the lookup methods sorts the data // After calling the first lookup method it is no longer allowed to insert any data // the array becomes immutable - void sort() + void sort(bool ascending) { if (!sorted.load(std::memory_order_acquire)) { @@ -179,13 +189,7 @@ private: if (!sorted.load(std::memory_order_relaxed)) { if (!array.empty()) - { - /// TODO: It has been tested only for UInt32 yet. It needs to check UInt64, Float32/64. - if constexpr (std::is_same_v) - RadixSort::executeLSD(&array[0], array.size()); - else - std::sort(array.begin(), array.end()); - } + std::sort(array.begin(), array.end(), (ascending ? less : less)); sorted.store(true, std::memory_order_release); } @@ -206,11 +210,6 @@ public: Entry(T v) : asof_value(v) {} Entry(T v, RowRef rr) : asof_value(v), row_ref(rr) {} - - bool operator < (const Entry & o) const - { - return asof_value < o.asof_value; - } }; using Lookups = std::variant< @@ -236,7 +235,7 @@ public: void insert(Type type, const IColumn * asof_column, const Block * block, size_t row_num); // This will internally synchronize - const RowRef * findAsof(Type type, const IColumn * asof_column, size_t row_num) const; + const RowRef * findAsof(Type type, ASOF::Inequality inequality, const IColumn * asof_column, size_t row_num) const; private: // Lookups can be stored in a HashTable because it is memmovable diff --git a/dbms/src/Interpreters/asof.h b/dbms/src/Interpreters/asof.h new file mode 100644 index 00000000000..439bf4cc58c --- /dev/null +++ b/dbms/src/Interpreters/asof.h @@ -0,0 +1,46 @@ +#pragma once +#include + +namespace DB +{ +namespace ASOF +{ + +enum class Inequality +{ + None = 0, + Less, + Greater, + LessOrEquals, + GreaterOrEquals, +}; + +inline Inequality getInequality(const std::string & func_name) +{ + Inequality inequality{Inequality::None}; + if (func_name == "less") + inequality = Inequality::Less; + else if (func_name == "greater") + inequality = Inequality::Greater; + else if (func_name == "lessOrEquals") + inequality = Inequality::LessOrEquals; + else if (func_name == "greaterOrEquals") + inequality = Inequality::GreaterOrEquals; + return inequality; +} + +inline Inequality reverseInequality(Inequality inequality) +{ + if (inequality == Inequality::Less) + return Inequality::Greater; + else if (inequality == Inequality::Greater) + return Inequality::Less; + else if (inequality == Inequality::LessOrEquals) + return Inequality::GreaterOrEquals; + else if (inequality == Inequality::GreaterOrEquals) + return Inequality::LessOrEquals; + return Inequality::None; +} + +} +} diff --git a/dbms/tests/queries/0_stateless/00976_asof_join_on.reference b/dbms/tests/queries/0_stateless/00976_asof_join_on.reference index ffa8117cc75..4d1b1273363 100644 --- a/dbms/tests/queries/0_stateless/00976_asof_join_on.reference +++ b/dbms/tests/queries/0_stateless/00976_asof_join_on.reference @@ -11,3 +11,25 @@ 1 2 1 2 1 3 1 2 2 3 2 3 +- +1 1 1 2 +1 2 1 2 +1 3 1 4 +2 1 2 3 +2 2 2 3 +2 3 2 3 +- +1 1 1 2 +1 2 1 2 +1 3 1 4 +2 1 2 3 +2 2 2 3 +2 3 2 3 +- +1 3 1 2 +- +1 1 1 2 +1 2 1 4 +1 3 1 4 +2 1 2 3 +2 2 2 3 diff --git a/dbms/tests/queries/0_stateless/00976_asof_join_on.sql b/dbms/tests/queries/0_stateless/00976_asof_join_on.sql index 740287b7c30..ccecc0999c9 100644 --- a/dbms/tests/queries/0_stateless/00976_asof_join_on.sql +++ b/dbms/tests/queries/0_stateless/00976_asof_join_on.sql @@ -9,11 +9,15 @@ INSERT INTO B (b,t) VALUES (1,2),(1,4),(2,3); SELECT A.a, A.t, B.b, B.t FROM A ASOF LEFT JOIN B ON A.a == B.b AND A.t >= B.t ORDER BY (A.a, A.t); SELECT count() FROM A ASOF LEFT JOIN B ON A.a == B.b AND B.t <= A.t; -SELECT A.a, A.t, B.b, B.t FROM A ASOF INNER JOIN B ON B.t <= A.t AND A.a == B.b; -SELECT count() FROM A ASOF JOIN B ON A.a == B.b AND A.t <= B.t; -- { serverError 48 } -SELECT count() FROM A ASOF JOIN B ON A.a == B.b AND B.t >= A.t; -- { serverError 48 } -SELECT count() FROM A ASOF JOIN B ON A.a == B.b AND A.t > B.t; -- { serverError 403 } -SELECT count() FROM A ASOF JOIN B ON A.a == B.b AND A.t < B.t; -- { serverError 403 } +SELECT A.a, A.t, B.b, B.t FROM A ASOF INNER JOIN B ON B.t <= A.t AND A.a == B.b ORDER BY (A.a, A.t); +SELECT '-'; +SELECT A.a, A.t, B.b, B.t FROM A ASOF JOIN B ON A.a == B.b AND A.t <= B.t ORDER BY (A.a, A.t); +SELECT '-'; +SELECT A.a, A.t, B.b, B.t FROM A ASOF JOIN B ON A.a == B.b AND B.t >= A.t ORDER BY (A.a, A.t); +SELECT '-'; +SELECT A.a, A.t, B.b, B.t FROM A ASOF JOIN B ON A.a == B.b AND A.t > B.t ORDER BY (A.a, A.t); +SELECT '-'; +SELECT A.a, A.t, B.b, B.t FROM A ASOF JOIN B ON A.a == B.b AND A.t < B.t ORDER BY (A.a, A.t); SELECT count() FROM A ASOF JOIN B ON A.a == B.b AND A.t == B.t; -- { serverError 403 } SELECT count() FROM A ASOF JOIN B ON A.a == B.b AND A.t != B.t; -- { serverError 403 }