upd FullSortingJoin.AsofGeneratedTestData

This commit is contained in:
vdimir 2023-09-21 11:38:37 +00:00
parent 6330b466aa
commit da4f355561
No known key found for this signature in database
GPG Key ID: 6EE4CE2BEDC51862

View File

@ -208,6 +208,14 @@ Block executePipeline(QueryPipeline && pipeline)
return concatenateBlocks(result_blocks);
}
/// Compares the numeric column `name` of `block` against `expected`.
///
/// The column is obtained with `block.getByName(name)` and cast to `ColumnVector<T>`,
/// so `T` is presumably required to match the column's real element type
/// (NOTE(review): confirm assert_cast semantics for a mismatching type).
///
/// Both assertions carry the column name: this helper is invoked for several columns
/// with the same `expected` container, and without the name a failure is ambiguous.
///
/// Note: ASSERT_EQ inside a helper only returns from the helper itself, the calling
/// test body keeps running after a fatal failure here.
template <typename T>
void checkColumn(const typename ColumnVector<T>::Container & expected, const Block & block, const std::string & name)
{
    const auto & actual = assert_cast<const ColumnVector<T> *>(block.getByName(name).column.get())->getData();

    /// Non-fatal: report the size mismatch first, then let the element-wise comparison print the details.
    EXPECT_EQ(actual.size(), expected.size()) << "row count mismatch in column '" << name << "'";
    ASSERT_EQ(actual, expected) << "data mismatch in column '" << name << "'";
}
TEST(FullSortingJoin, Asof)
try
{
@ -306,18 +314,27 @@ catch (Exception & e)
TEST(FullSortingJoin, AsofGeneratedTestData)
try
{
std::vector<JoinKind> join_kinds = {JoinKind::Inner, JoinKind::Left};
auto join_kind = join_kinds[std::uniform_int_distribution<size_t>(0, join_kinds.size() - 1)(rng)];
std::vector<ASOFJoinInequality> asof_inequalities = {
ASOFJoinInequality::Less, ASOFJoinInequality::LessOrEquals,
// ASOFJoinInequality::Greater, ASOFJoinInequality::GreaterOrEquals,
};
auto asof_inequality = asof_inequalities[std::uniform_int_distribution<size_t>(0, asof_inequalities.size() - 1)(rng)];
auto left_source_builder = SourceChunksBuilder({
{std::make_shared<DataTypeUInt64>(), "k1"},
{std::make_shared<DataTypeString>(), "k2"},
{std::make_shared<DataTypeUInt64>(), "t"},
{std::make_shared<DataTypeUInt64>(), "attr"},
{std::make_shared<DataTypeInt64>(), "attr"},
});
auto right_source_builder = SourceChunksBuilder({
{std::make_shared<DataTypeUInt64>(), "k1"},
{std::make_shared<DataTypeString>(), "k2"},
{std::make_shared<DataTypeUInt64>(), "t"},
{std::make_shared<DataTypeUInt64>(), "attr"},
{std::make_shared<DataTypeInt64>(), "attr"},
});
/// uniform_int_distribution to have 0.0 and 1.0 probabilities
@ -333,14 +350,14 @@ try
k2 = new_k2;
};
ColumnUInt64::Container expected;
ColumnInt64::Container expected;
UInt64 k1 = 0;
String k2 = "asdfg";
auto key_num_total = std::uniform_int_distribution<>(1, 1000)(rng);
for (size_t key_num = 0; key_num < key_num_total; ++key_num)
{
UInt64 left_t = 0;
Int64 left_t = 0;
size_t num_left_rows = std::uniform_int_distribution<>(1, 100)(rng);
for (size_t i = 0; i < num_left_rows; ++i)
{
@ -351,11 +368,22 @@ try
auto num_matches = 1 + std::poisson_distribution<>(4)(rng);
size_t right_t = left_t;
auto right_t = left_t;
for (size_t j = 0; j < num_matches; ++j)
{
right_t += std::uniform_int_distribution<>(0, 3)(rng);
right_source_builder.addRow({k1, k2, right_t, j == 0 ? 100 * left_t : 0});
int min_step = 1;
if (asof_inequality == ASOFJoinInequality::LessOrEquals || asof_inequality == ASOFJoinInequality::GreaterOrEquals)
min_step = 0;
right_t += std::uniform_int_distribution<>(min_step, 3)(rng);
bool is_match = false;
if (asof_inequality == ASOFJoinInequality::LessOrEquals || asof_inequality == ASOFJoinInequality::Less)
is_match = j == 0;
else if (asof_inequality == ASOFJoinInequality::GreaterOrEquals || asof_inequality == ASOFJoinInequality::Greater)
is_match = j == num_matches - 1;
right_source_builder.addRow({k1, k2, right_t, is_match ? 100 * left_t : -1});
}
/// next left_t should be greater than right_t not to match with previous rows
left_t = right_t;
@ -366,7 +394,10 @@ try
for (size_t i = 0; i < num_left_rows; ++i)
{
left_t += std::uniform_int_distribution<>(1, 10)(rng);
left_source_builder.addRow({k1, k2, left_t, 10 * left_t});
left_source_builder.addRow({k1, k2, left_t, -10 * left_t});
if (join_kind == JoinKind::Left)
expected.push_back(-10 * left_t);
}
get_next_key(k1, k2);
@ -375,13 +406,14 @@ try
Block result_block = executePipeline(buildJoinPipeline(
left_source_builder.build(), right_source_builder.build(),
/* key_length = */ 3,
JoinKind::Inner, JoinStrictness::Asof, ASOFJoinInequality::LessOrEquals));
join_kind, JoinStrictness::Asof, asof_inequality));
ASSERT_EQ(assert_cast<const ColumnUInt64 *>(block.getByName("t1.attr").column.get())->getData(), expected);
checkColumn<Int64>(expected, result_block, "t1.attr");
for (auto & e : expected)
e = 10 * e;
ASSERT_EQ(assert_cast<const ColumnUInt64 *>(block.getByName("t2.attr").column.get())->getData(), expected);
e = e < 0 ? 0 : 10 * e; /// non matched rows from left table have negative attr
checkColumn<Int64>(expected, result_block, "t2.attr");
}
catch (Exception & e) {
std::cout << e.getStackTraceString() << std::endl;