diff --git a/src/Functions/FunctionsBitmap.h b/src/Functions/FunctionsBitmap.h index 93da4906658..601a7524213 100644 --- a/src/Functions/FunctionsBitmap.h +++ b/src/Functions/FunctionsBitmap.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -14,6 +15,7 @@ #include #include + // TODO include this last because of a broken roaring header. See the comment // inside. #include @@ -724,10 +726,11 @@ public: throw Exception( "First argument for function " + getName() + " must be a bitmap but it has type " + arguments[0]->getName() + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - const auto * arg_type1 = typeid_cast *>(arguments[1].get()); - if (!(arg_type1)) + + WhichDataType which(arguments[1].get()); + if (!(which.isUInt8() || which.isUInt16() || which.isUInt32() || which.isUInt64())) throw Exception( - "Second argument for function " + getName() + " must be UInt32 but it has type " + arguments[1]->getName() + ".", + "Second argument for function " + getName() + " must be UInt but it has type " + arguments[1]->getName() + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); return std::make_shared>(); @@ -765,27 +768,32 @@ private: { const IColumn * column_ptrs[2]; bool is_column_const[2]; - const PaddedPODArray * container0; - const PaddedPODArray * container1; - for (size_t i = 0; i < 2; ++i) - { - column_ptrs[i] = arguments[i].column.get(); - is_column_const[i] = isColumnConst(*column_ptrs[i]); - } + const PaddedPODArray * container0; + const PaddedPODArray * container1; + + column_ptrs[0] = arguments[0].column.get(); + is_column_const[0] = isColumnConst(*column_ptrs[0]); + if (is_column_const[0]) container0 = &typeid_cast(typeid_cast(column_ptrs[0])->getDataColumnPtr().get())->getData(); else container0 = &typeid_cast(column_ptrs[0])->getData(); + + // we can always cast the second column to ColumnUInt64 + auto super_type = std::make_shared(); + column_ptrs[1] = castColumn(arguments[1], super_type).get(); + is_column_const[1] = isColumnConst(*column_ptrs[1]); + if (is_column_const[1]) - container1 = &typeid_cast(typeid_cast(column_ptrs[1])->getDataColumnPtr().get())->getData(); + container1 = &typeid_cast(typeid_cast(column_ptrs[1])->getDataColumnPtr().get())->getData(); else - container1 = &typeid_cast(column_ptrs[1])->getData(); + container1 = &typeid_cast(column_ptrs[1])->getData(); for (size_t i = 0; i < input_rows_count; ++i) { const AggregateDataPtr data_ptr_0 = is_column_const[0] ? (*container0)[0] : (*container0)[i]; - const UInt32 data1 = is_column_const[1] ? (*container1)[0] : (*container1)[i]; + const UInt64 data1 = is_column_const[1] ? (*container1)[0] : (*container1)[i]; const AggregateFunctionGroupBitmapData & bitmap_data_0 = *reinterpret_cast *>(data_ptr_0); vec_to[i] = bitmap_data_0.rbs.rb_contains(data1); diff --git a/tests/queries/0_stateless/00974_bitmapContains_with_primary_key.reference b/tests/queries/0_stateless/00974_bitmapContains_with_primary_key.reference index d00491fd7e5..98fb6a68656 100644 --- a/tests/queries/0_stateless/00974_bitmapContains_with_primary_key.reference +++ b/tests/queries/0_stateless/00974_bitmapContains_with_primary_key.reference @@ -1 +1,4 @@ 1 +1 +1 +1 diff --git a/tests/queries/0_stateless/00974_bitmapContains_with_primary_key.sql b/tests/queries/0_stateless/00974_bitmapContains_with_primary_key.sql index 81dd7cab9f4..520b4a03057 100644 --- a/tests/queries/0_stateless/00974_bitmapContains_with_primary_key.sql +++ b/tests/queries/0_stateless/00974_bitmapContains_with_primary_key.sql @@ -1,5 +1,8 @@ DROP TABLE IF EXISTS test; CREATE TABLE test (num UInt64, str String) ENGINE = MergeTree ORDER BY num; INSERT INTO test (num) VALUES (1), (2), (10), (15), (23); +SELECT count(*) FROM test WHERE bitmapContains(bitmapBuild([1, 5, 7, 9]), toUInt8(num)); +SELECT count(*) FROM test WHERE bitmapContains(bitmapBuild([1, 5, 7, 9]), toUInt16(num)); SELECT count(*) FROM test WHERE bitmapContains(bitmapBuild([1, 5, 7, 9]), toUInt32(num)); +SELECT count(*) FROM test WHERE bitmapContains(bitmapBuild([1, 5, 7, 9]), toUInt64(num)); DROP TABLE test;