Add bitmapContains support for all UInt types

This commit is contained in:
sundy-li 2021-01-06 13:20:01 +00:00
parent 67e7e6b235
commit 9c3c1d13ab
3 changed files with 27 additions and 13 deletions

View File

@ -6,6 +6,7 @@
#include <Columns/ColumnConst.h> #include <Columns/ColumnConst.h>
#include <Columns/ColumnVector.h> #include <Columns/ColumnVector.h>
#include <Columns/ColumnsNumber.h> #include <Columns/ColumnsNumber.h>
#include <Interpreters/castColumn.h>
#include <DataTypes/DataTypeAggregateFunction.h> #include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeArray.h> #include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
@ -14,6 +15,7 @@
#include <Common/typeid_cast.h> #include <Common/typeid_cast.h>
#include <Common/assert_cast.h> #include <Common/assert_cast.h>
// TODO include this last because of a broken roaring header. See the comment // TODO include this last because of a broken roaring header. See the comment
// inside. // inside.
#include <AggregateFunctions/AggregateFunctionGroupBitmapData.h> #include <AggregateFunctions/AggregateFunctionGroupBitmapData.h>
@ -724,10 +726,11 @@ public:
throw Exception( throw Exception(
"First argument for function " + getName() + " must be a bitmap but it has type " + arguments[0]->getName() + ".", "First argument for function " + getName() + " must be a bitmap but it has type " + arguments[0]->getName() + ".",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const auto * arg_type1 = typeid_cast<const DataTypeNumber<UInt32> *>(arguments[1].get());
if (!(arg_type1)) WhichDataType which(arguments[1].get());
if (!(which.isUInt8() || which.isUInt16() || which.isUInt32() || which.isUInt64()))
throw Exception( throw Exception(
"Second argument for function " + getName() + " must be UInt32 but it has type " + arguments[1]->getName() + ".", "Second argument for function " + getName() + " must be UInt but it has type " + arguments[1]->getName() + ".",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<DataTypeNumber<UInt8>>(); return std::make_shared<DataTypeNumber<UInt8>>();
@ -765,27 +768,32 @@ private:
{ {
const IColumn * column_ptrs[2]; const IColumn * column_ptrs[2];
bool is_column_const[2]; bool is_column_const[2];
const PaddedPODArray<AggregateDataPtr> * container0;
const PaddedPODArray<UInt32> * container1;
for (size_t i = 0; i < 2; ++i) const PaddedPODArray<AggregateDataPtr> * container0;
{ const PaddedPODArray<UInt64> * container1;
column_ptrs[i] = arguments[i].column.get();
is_column_const[i] = isColumnConst(*column_ptrs[i]); column_ptrs[0] = arguments[0].column.get();
} is_column_const[0] = isColumnConst(*column_ptrs[0]);
if (is_column_const[0]) if (is_column_const[0])
container0 = &typeid_cast<const ColumnAggregateFunction*>(typeid_cast<const ColumnConst*>(column_ptrs[0])->getDataColumnPtr().get())->getData(); container0 = &typeid_cast<const ColumnAggregateFunction*>(typeid_cast<const ColumnConst*>(column_ptrs[0])->getDataColumnPtr().get())->getData();
else else
container0 = &typeid_cast<const ColumnAggregateFunction*>(column_ptrs[0])->getData(); container0 = &typeid_cast<const ColumnAggregateFunction*>(column_ptrs[0])->getData();
// we can always cast the second column to ColumnUInt64
auto super_type = std::make_shared<DataTypeUInt64>();
column_ptrs[1] = castColumn(arguments[1], super_type).get();
is_column_const[1] = isColumnConst(*column_ptrs[1]);
if (is_column_const[1]) if (is_column_const[1])
container1 = &typeid_cast<const ColumnUInt32*>(typeid_cast<const ColumnConst*>(column_ptrs[1])->getDataColumnPtr().get())->getData(); container1 = &typeid_cast<const ColumnUInt64*>(typeid_cast<const ColumnConst*>(column_ptrs[1])->getDataColumnPtr().get())->getData();
else else
container1 = &typeid_cast<const ColumnUInt32*>(column_ptrs[1])->getData(); container1 = &typeid_cast<const ColumnUInt64*>(column_ptrs[1])->getData();
for (size_t i = 0; i < input_rows_count; ++i) for (size_t i = 0; i < input_rows_count; ++i)
{ {
const AggregateDataPtr data_ptr_0 = is_column_const[0] ? (*container0)[0] : (*container0)[i]; const AggregateDataPtr data_ptr_0 = is_column_const[0] ? (*container0)[0] : (*container0)[i];
const UInt32 data1 = is_column_const[1] ? (*container1)[0] : (*container1)[i]; const UInt64 data1 = is_column_const[1] ? (*container1)[0] : (*container1)[i];
const AggregateFunctionGroupBitmapData<T> & bitmap_data_0 const AggregateFunctionGroupBitmapData<T> & bitmap_data_0
= *reinterpret_cast<const AggregateFunctionGroupBitmapData<T> *>(data_ptr_0); = *reinterpret_cast<const AggregateFunctionGroupBitmapData<T> *>(data_ptr_0);
vec_to[i] = bitmap_data_0.rbs.rb_contains(data1); vec_to[i] = bitmap_data_0.rbs.rb_contains(data1);

View File

@ -1,5 +1,8 @@
DROP TABLE IF EXISTS test; DROP TABLE IF EXISTS test;
CREATE TABLE test (num UInt64, str String) ENGINE = MergeTree ORDER BY num; CREATE TABLE test (num UInt64, str String) ENGINE = MergeTree ORDER BY num;
INSERT INTO test (num) VALUES (1), (2), (10), (15), (23); INSERT INTO test (num) VALUES (1), (2), (10), (15), (23);
SELECT count(*) FROM test WHERE bitmapContains(bitmapBuild([1, 5, 7, 9]), toUInt8(num));
SELECT count(*) FROM test WHERE bitmapContains(bitmapBuild([1, 5, 7, 9]), toUInt16(num));
SELECT count(*) FROM test WHERE bitmapContains(bitmapBuild([1, 5, 7, 9]), toUInt32(num)); SELECT count(*) FROM test WHERE bitmapContains(bitmapBuild([1, 5, 7, 9]), toUInt32(num));
SELECT count(*) FROM test WHERE bitmapContains(bitmapBuild([1, 5, 7, 9]), toUInt64(num));
DROP TABLE test; DROP TABLE test;