mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 23:21:59 +00:00
rewrite bitHammingDistance with FunctionBinaryArithmetic
fix fix typo
This commit is contained in:
parent
cdcd93330a
commit
53fde56cb7
@ -139,6 +139,9 @@ public:
|
||||
Case<IsOperation<Operation>::allow_decimal && IsDataTypeDecimal<RightDataType> && IsFloatingPoint<LeftDataType>,
|
||||
DataTypeFloat64>,
|
||||
|
||||
Case<IsOperation<Operation>::bit_hamming_distance && IsIntegral<LeftDataType> && IsIntegral<RightDataType>,
|
||||
DataTypeUInt8>,
|
||||
|
||||
/// Decimal <op> Real is not supported (traditional DBs convert Decimal <op> Real to Real)
|
||||
Case<IsDataTypeDecimal<LeftDataType> && !IsIntegralOrExtendedOrDecimal<RightDataType>, InvalidType>,
|
||||
Case<IsDataTypeDecimal<RightDataType> && !IsIntegralOrExtendedOrDecimal<LeftDataType>, InvalidType>,
|
||||
|
@ -19,6 +19,7 @@ template <typename, typename> struct EqualsOp;
|
||||
template <typename, typename> struct NotEqualsOp;
|
||||
template <typename, typename> struct LessOrEqualsOp;
|
||||
template <typename, typename> struct GreaterOrEqualsOp;
|
||||
template <typename, typename> struct BitHammingDistanceImpl;
|
||||
|
||||
template <typename>
|
||||
struct SignImpl;
|
||||
@ -55,6 +56,8 @@ struct IsOperation
|
||||
static constexpr bool least = IsSameOperation<Op, LeastBaseImpl>::value;
|
||||
static constexpr bool greatest = IsSameOperation<Op, GreatestBaseImpl>::value;
|
||||
|
||||
static constexpr bool bit_hamming_distance = IsSameOperation<Op, BitHammingDistanceImpl>::value;
|
||||
|
||||
static constexpr bool division = div_floating || div_int || div_int_or_zero;
|
||||
|
||||
static constexpr bool allow_decimal = plus || minus || multiply || division || least || greatest;
|
||||
|
@ -1,159 +1,32 @@
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Functions/FunctionBinaryArithmetic.h>
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include <Functions/IFunction.h>
|
||||
#include <Functions/castTypeToEither.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_COLUMN;
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
}
|
||||
|
||||
|
||||
template <typename A, typename B>
|
||||
struct BitHammingDistanceImpl
|
||||
{
|
||||
using ResultType = UInt8;
|
||||
static const constexpr bool allow_fixed_string = false;
|
||||
static const constexpr bool allow_string_integer = false;
|
||||
|
||||
static void NO_INLINE vectorVector(const PaddedPODArray<A> & a, const PaddedPODArray<B> & b, PaddedPODArray<ResultType> & c)
|
||||
template <typename Result = ResultType>
|
||||
static inline NO_SANITIZE_UNDEFINED Result apply(A a, B b)
|
||||
{
|
||||
size_t size = a.size();
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
c[i] = apply(a[i], b[i]);
|
||||
}
|
||||
|
||||
static void NO_INLINE vectorConstant(const PaddedPODArray<A> & a, B b, PaddedPODArray<ResultType> & c)
|
||||
{
|
||||
size_t size = a.size();
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
c[i] = apply(a[i], b);
|
||||
}
|
||||
|
||||
static void NO_INLINE constantVector(A a, const PaddedPODArray<B> & b, PaddedPODArray<ResultType> & c)
|
||||
{
|
||||
size_t size = b.size();
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
c[i] = apply(a, b[i]);
|
||||
}
|
||||
|
||||
private:
|
||||
static inline UInt8 apply(UInt64 a, UInt64 b)
|
||||
{
|
||||
UInt64 res = a ^ b;
|
||||
UInt64 res = static_cast<UInt64>(a) ^ static_cast<UInt64>(b);
|
||||
return __builtin_popcountll(res);
|
||||
}
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
static constexpr bool compilable = false; /// special type handling, some other time
|
||||
#endif
|
||||
};
|
||||
|
||||
template <typename F>
|
||||
bool castType(const IDataType * type, F && f)
|
||||
struct NameBitHammingDistance
|
||||
{
|
||||
return castTypeToEither<
|
||||
DataTypeInt8,
|
||||
DataTypeInt16,
|
||||
DataTypeInt32,
|
||||
DataTypeInt64,
|
||||
DataTypeUInt8,
|
||||
DataTypeUInt16,
|
||||
DataTypeUInt32,
|
||||
DataTypeUInt64>(type, std::forward<F>(f));
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
static bool castBothTypes(const IDataType * left, const IDataType * right, F && f)
|
||||
{
|
||||
return castType(left, [&](const auto & left_) { return castType(right, [&](const auto & right_) { return f(left_, right_); }); });
|
||||
}
|
||||
|
||||
// bitHammingDistance function: (Integer, Integer) -> UInt8
|
||||
class FunctionBitHammingDistance : public IFunction
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = "bitHammingDistance";
|
||||
using ResultType = UInt8;
|
||||
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionBitHammingDistance>(); }
|
||||
|
||||
String getName() const override { return name; }
|
||||
|
||||
size_t getNumberOfArguments() const override { return 2; }
|
||||
|
||||
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
|
||||
{
|
||||
if (!isInteger(arguments[0]))
|
||||
throw Exception(
|
||||
"Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
if (!isInteger(arguments[1]))
|
||||
throw Exception(
|
||||
"Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
return std::make_shared<DataTypeUInt8>();
|
||||
}
|
||||
|
||||
bool useDefaultImplementationForConstants() const override { return true; }
|
||||
|
||||
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
|
||||
{
|
||||
const auto * left_generic = arguments[0].type.get();
|
||||
const auto * right_generic = arguments[1].type.get();
|
||||
ColumnPtr result_column;
|
||||
bool valid = castBothTypes(left_generic, right_generic, [&](const auto & left, const auto & right)
|
||||
{
|
||||
using LeftDataType = std::decay_t<decltype(left)>;
|
||||
using RightDataType = std::decay_t<decltype(right)>;
|
||||
using T0 = typename LeftDataType::FieldType;
|
||||
using T1 = typename RightDataType::FieldType;
|
||||
using ColVecT0 = ColumnVector<T0>;
|
||||
using ColVecT1 = ColumnVector<T1>;
|
||||
using ColVecResult = ColumnVector<ResultType>;
|
||||
|
||||
using OpImpl = BitHammingDistanceImpl<T0, T1>;
|
||||
|
||||
const auto * const col_left_raw = arguments[0].column.get();
|
||||
const auto * const col_right_raw = arguments[1].column.get();
|
||||
|
||||
typename ColVecResult::MutablePtr col_res = nullptr;
|
||||
col_res = ColVecResult::create();
|
||||
|
||||
auto & vec_res = col_res->getData();
|
||||
vec_res.resize(input_rows_count);
|
||||
|
||||
if (auto col_left_const = checkAndGetColumnConst<ColVecT0>(col_left_raw))
|
||||
{
|
||||
if (auto col_right = checkAndGetColumn<ColVecT1>(col_right_raw))
|
||||
{
|
||||
// constant integer - non-constant integer
|
||||
OpImpl::constantVector(col_left_const->template getValue<T0>(), col_right->getData(), vec_res);
|
||||
}
|
||||
else
|
||||
return false;
|
||||
}
|
||||
else if (auto col_left = checkAndGetColumn<ColVecT0>(col_left_raw))
|
||||
{
|
||||
if (auto col_right = checkAndGetColumn<ColVecT1>(col_right_raw))
|
||||
// non-constant integer - non-constant integer
|
||||
OpImpl::vectorVector(col_left->getData(), col_right->getData(), vec_res);
|
||||
else if (auto col_right_const = checkAndGetColumnConst<ColVecT1>(col_right_raw))
|
||||
// non-constant integer - constant integer
|
||||
OpImpl::vectorConstant(col_left->getData(), col_right_const->template getValue<T1>(), vec_res);
|
||||
else
|
||||
return false;
|
||||
}
|
||||
else
|
||||
return false;
|
||||
|
||||
result_column = std::move(col_res);
|
||||
return true;
|
||||
});
|
||||
if (!valid)
|
||||
throw Exception(getName() + "'s arguments do not match the expected data types", ErrorCodes::ILLEGAL_COLUMN);
|
||||
|
||||
return result_column;
|
||||
}
|
||||
};
|
||||
using FunctionBitHammingDistance = BinaryArithmeticOverloadResolver<BitHammingDistanceImpl, NameBitHammingDistance>;
|
||||
|
||||
void registerFunctionBitHammingDistance(FunctionFactory & factory)
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user