diff --git a/src/AggregateFunctions/AggregateFunctionAnyHeavy.cpp b/src/AggregateFunctions/AggregateFunctionAnyHeavy.cpp index ffddd46f2e3..dbc5f9be72f 100644 --- a/src/AggregateFunctions/AggregateFunctionAnyHeavy.cpp +++ b/src/AggregateFunctions/AggregateFunctionAnyHeavy.cpp @@ -68,7 +68,10 @@ public: if (data().isEqualTo(to.data())) counter += to.counter; else if (!data().has() || counter < to.counter) + { data().set(to.data(), arena); + counter = to.counter - counter; + } else counter -= to.counter; } diff --git a/tests/integration/test_backward_compatibility/test_functions.py b/tests/integration/test_backward_compatibility/test_functions.py index 3231fb87f33..202a741bfb5 100644 --- a/tests/integration/test_backward_compatibility/test_functions.py +++ b/tests/integration/test_backward_compatibility/test_functions.py @@ -67,6 +67,11 @@ def test_aggregate_states(start_cluster): f"select hex(initializeAggregation('{function_name}State', 'foo'))" ).strip() + def get_final_value_unhex(node, function_name, value): + return node.query( + f"select finalizeAggregation(unhex('{value}')::AggregateFunction({function_name}, String))" + ).strip() + for aggregate_function in aggregate_functions: logging.info("Checking %s", aggregate_function) @@ -99,13 +104,39 @@ def test_aggregate_states(start_cluster): upstream_state = get_aggregate_state_hex(upstream, aggregate_function) if upstream_state != backward_state: - logging.info( - "Failed %s, %s (backward) != %s (upstream)", - aggregate_function, - backward_state, - upstream_state, - ) - failed += 1 + allowed_changes_if_result_is_the_same = ["anyHeavy"] + + if aggregate_function in allowed_changes_if_result_is_the_same: + backward_final_from_upstream = get_final_value_unhex( + backward, aggregate_function, upstream_state + ) + upstream_final_from_backward = get_final_value_unhex( + upstream, aggregate_function, backward_state + ) + + if backward_final_from_upstream == upstream_final_from_backward: + logging.info( + "OK %s (but different intermediate states)", aggregate_function + ) + passed += 1 + else: + logging.error( + "Failed %s, Intermediate: %s (backward) != %s (upstream). Final from intermediate: %s (backward from upstream state) != %s (upstream from backward state)", + aggregate_function, + backward_state, + upstream_state, + backward_final_from_upstream, + upstream_final_from_backward, + ) + failed += 1 + else: + logging.error( + "Failed %s, %s (backward) != %s (upstream)", + aggregate_function, + backward_state, + upstream_state, + ) + failed += 1 else: logging.info("OK %s", aggregate_function) passed += 1 diff --git a/tests/queries/0_stateless/03230_anyHeavy_merge.reference b/tests/queries/0_stateless/03230_anyHeavy_merge.reference new file mode 100644 index 00000000000..78981922613 --- /dev/null +++ b/tests/queries/0_stateless/03230_anyHeavy_merge.reference @@ -0,0 +1 @@ +a diff --git a/tests/queries/0_stateless/03230_anyHeavy_merge.sql b/tests/queries/0_stateless/03230_anyHeavy_merge.sql new file mode 100644 index 00000000000..5d4c0e55d0f --- /dev/null +++ b/tests/queries/0_stateless/03230_anyHeavy_merge.sql @@ -0,0 +1,4 @@ +DROP TABLE IF EXISTS t; +CREATE TABLE t (letter String) ENGINE=MergeTree order by () partition by letter; +INSERT INTO t VALUES ('a'), ('a'), ('a'), ('a'), ('b'), ('a'), ('a'), ('a'), ('a'), ('a'), ('a'), ('a'), ('a'), ('a'), ('a'), ('a'), ('c'); +SELECT anyHeavy(if(letter != 'b', letter, NULL)) FROM t;