ClickHouse/src/Functions/bitHammingDistance.cpp

163 lines
5.7 KiB
C++
Raw Normal View History

2019-11-06 10:35:23 +00:00
#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <Functions/castTypeToEither.h>
namespace DB
{
namespace ErrorCodes
{
/// Error codes thrown by this file. These are extern declarations only; the definitions live in another translation unit.
extern const int ILLEGAL_COLUMN;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
template <typename A, typename B>
struct BitHammingDistanceImpl
{
using ResultType = UInt8;
2020-08-16 07:42:35 +00:00
static void NO_INLINE vectorVector(const PaddedPODArray<A> & a, const PaddedPODArray<B> & b, PaddedPODArray<ResultType> & c)
2019-11-06 10:35:23 +00:00
{
size_t size = a.size();
for (size_t i = 0; i < size; ++i)
c[i] = apply(a[i], b[i]);
}
2020-08-16 07:42:35 +00:00
static void NO_INLINE vectorConstant(const PaddedPODArray<A> & a, B b, PaddedPODArray<ResultType> & c)
2019-11-06 10:35:23 +00:00
{
size_t size = a.size();
for (size_t i = 0; i < size; ++i)
c[i] = apply(a[i], b);
}
2020-08-16 07:42:35 +00:00
static void NO_INLINE constantVector(A a, const PaddedPODArray<B> & b, PaddedPODArray<ResultType> & c)
2019-11-06 10:35:23 +00:00
{
size_t size = b.size();
for (size_t i = 0; i < size; ++i)
c[i] = apply(a, b[i]);
}
private:
static inline UInt8 apply(UInt64 a, UInt64 b)
{
UInt64 res = a ^ b;
2020-05-22 13:23:49 +00:00
return __builtin_popcountll(res);
2019-11-06 10:35:23 +00:00
}
};
/// Dispatches `type` to `f` as one of the concrete integer data types bitHammingDistance supports
/// (Int8..Int64, UInt8..UInt64). Returns the value returned by `f` (expected to be bool),
/// or false if `type` is none of the listed types.
///
/// Declared `static` (like castBothTypes below): this is a file-local helper, and giving it internal linkage
/// avoids ODR clashes with identically named helpers in namespace DB in other translation units.
template <typename F>
static bool castType(const IDataType * type, F && f)
{
    return castTypeToEither<
        DataTypeInt8,
        DataTypeInt16,
        DataTypeInt32,
        DataTypeInt64,
        DataTypeUInt8,
        DataTypeUInt16,
        DataTypeUInt32,
        DataTypeUInt64>(type, std::forward<F>(f));
}
/// Double dispatch over two data types: resolves `left`, then `right`, via castType and invokes
/// f(concrete_left, concrete_right). Returns f's result, or false if either type is unsupported.
template <typename F>
static bool castBothTypes(const IDataType * left, const IDataType * right, F && f)
{
    /// Called once `left` has been resolved to its concrete type; resolves `right` next.
    auto resolve_right = [&](const auto & concrete_left)
    {
        return castType(right, [&](const auto & concrete_right) { return f(concrete_left, concrete_right); });
    };

    return castType(left, resolve_right);
}
// bitHammingDistance function: (Integer, Integer) -> UInt8
2019-11-06 10:35:23 +00:00
class FunctionBitHammingDistance : public IFunction
{
public:
static constexpr auto name = "bitHammingDistance";
using ResultType = UInt8;
2021-06-01 12:20:52 +00:00
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionBitHammingDistance>(); }
2019-11-06 10:35:23 +00:00
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 2; }
2021-06-22 16:21:23 +00:00
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }
2019-11-06 10:35:23 +00:00
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!isInteger(arguments[0]))
throw Exception(
"Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!isInteger(arguments[1]))
throw Exception(
"Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<DataTypeUInt8>();
}
2020-05-22 13:23:49 +00:00
bool useDefaultImplementationForConstants() const override { return true; }
2020-12-17 19:14:01 +00:00
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
2019-11-06 10:35:23 +00:00
{
2020-12-17 19:14:01 +00:00
const auto * left_generic = arguments[0].type.get();
const auto * right_generic = arguments[1].type.get();
ColumnPtr result_column;
2020-12-22 15:17:23 +00:00
bool valid = castBothTypes(left_generic, right_generic, [&](const auto & left, const auto & right)
{
2019-11-06 10:35:23 +00:00
using LeftDataType = std::decay_t<decltype(left)>;
using RightDataType = std::decay_t<decltype(right)>;
using T0 = typename LeftDataType::FieldType;
using T1 = typename RightDataType::FieldType;
using ColVecT0 = ColumnVector<T0>;
using ColVecT1 = ColumnVector<T1>;
using ColVecResult = ColumnVector<ResultType>;
using OpImpl = BitHammingDistanceImpl<T0, T1>;
2020-12-17 19:14:01 +00:00
const auto * const col_left_raw = arguments[0].column.get();
const auto * const col_right_raw = arguments[1].column.get();
2019-11-06 10:35:23 +00:00
typename ColVecResult::MutablePtr col_res = nullptr;
col_res = ColVecResult::create();
auto & vec_res = col_res->getData();
2020-12-17 19:14:01 +00:00
vec_res.resize(input_rows_count);
2019-11-06 10:35:23 +00:00
if (auto col_left_const = checkAndGetColumnConst<ColVecT0>(col_left_raw))
{
if (auto col_right = checkAndGetColumn<ColVecT1>(col_right_raw))
{
// constant integer - non-constant integer
2020-08-16 07:42:35 +00:00
OpImpl::constantVector(col_left_const->template getValue<T0>(), col_right->getData(), vec_res);
2019-11-06 10:35:23 +00:00
}
else
return false;
}
else if (auto col_left = checkAndGetColumn<ColVecT0>(col_left_raw))
{
if (auto col_right = checkAndGetColumn<ColVecT1>(col_right_raw))
// non-constant integer - non-constant integer
2020-08-16 07:42:35 +00:00
OpImpl::vectorVector(col_left->getData(), col_right->getData(), vec_res);
2019-11-06 10:35:23 +00:00
else if (auto col_right_const = checkAndGetColumnConst<ColVecT1>(col_right_raw))
// non-constant integer - constant integer
2020-08-16 07:42:35 +00:00
OpImpl::vectorConstant(col_left->getData(), col_right_const->template getValue<T1>(), vec_res);
2019-11-06 10:35:23 +00:00
else
return false;
}
else
return false;
2020-12-17 19:14:01 +00:00
result_column = std::move(col_res);
2019-11-06 10:35:23 +00:00
return true;
});
if (!valid)
throw Exception(getName() + "'s arguments do not match the expected data types", ErrorCodes::ILLEGAL_COLUMN);
2020-12-17 19:14:01 +00:00
return result_column;
2019-11-06 10:35:23 +00:00
}
};
/// Registers FunctionBitHammingDistance with the given FunctionFactory.
void registerFunctionBitHammingDistance(FunctionFactory & factory)
{
factory.registerFunction<FunctionBitHammingDistance>();
}
}