Add function for get an ipv4 range using an ipv4 and a cidr mask

This commit is contained in:
Guillaume Tassery 2019-04-23 14:02:23 +07:00
parent 153c9d6455
commit 8dee4fe5d1

View File

@ -1474,7 +1474,7 @@ public:
else
{
const size_t shifts_bits = new_byte_offset - bits_to_keep > 8 ? 8 : new_byte_offset - bits_to_keep;
UInt8 byte_reference = lower_range ? 0 : 0b11111111;
UInt8 byte_reference = lower_range ? 0 : std::numeric_limits<UInt8>::max();
dst[offset] = ((src[offset] >> shifts_bits << shifts_bits))
| (byte_reference >> (8 - shifts_bits));
@ -1485,7 +1485,7 @@ public:
static constexpr auto name = "IPv6CIDRtoIPv6Range";
static FunctionPtr create(const Context &) { return std::make_shared<FunctionIPv4CIDRtoIPv4Range>(); }
static FunctionPtr create(const Context &) { return std::make_shared<FunctionIPv6CIDRtoIPv6Range>(); }
String getName() const override { return name; }
@ -1582,5 +1582,109 @@ public:
}
};
class FunctionIPv4CIDRtoIPv4Range : public IFunction
{
public:
template <bool lower_range>
static UInt32 setCIDRMask(UInt32 src, UInt8 bits_to_keep)
{
UInt32 byte_reference = lower_range ? 0 : std::numeric_limits<UInt32>::max();
UInt8 shifts_bits = 32 - bits_to_keep;
return (src >> shifts_bits << shifts_bits)
| (byte_reference >> bits_to_keep);
}
static constexpr auto name = "IPv4CIDRtoIPv4Range";
static FunctionPtr create(const Context &) { return std::make_shared<FunctionIPv4CIDRtoIPv4Range>(); }
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 2; }
bool isInjective(const Block &) override { return true; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!WhichDataType(arguments[0]).isUInt32())
throw Exception("Illegal type " + arguments[0]->getName() +
" of first argument of function " + getName() +
", expected UInt32",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const DataTypePtr & second_argument = arguments[1];
if (!isNumber(second_argument))
throw Exception{"Illegal type " + second_argument->getName()
+ " of second argument of function " + getName()
+ ", expected numeric type.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
decltype(arguments) return_type
{
std::make_shared<DataTypeUInt32>(),
std::make_shared<DataTypeUInt32>()
};
return std::make_shared<DataTypeTuple>(return_type);
}
bool useDefaultImplementationForConstants() const override { return true; }
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
{
const auto & col_type_name_ip = block.getByPosition(arguments[0]);
const ColumnPtr & column_ip = col_type_name_ip.column;
const auto col_ip_in = checkAndGetColumn<ColumnUInt32>(column_ip.get());
const auto & col_type_name_cidr = block.getByPosition(arguments[1]);
const ColumnPtr & column_cidr = col_type_name_cidr.column;
const auto col_const_cidr_in = checkAndGetColumnConst<ColumnUInt8>(column_cidr.get());
const auto col_cidr_in = checkAndGetColumn<ColumnUInt8>(column_cidr.get());
if (col_ip_in && (col_const_cidr_in || col_cidr_in))
{
const auto size = col_ip_in->size();
const auto & vec_in = col_ip_in->getData();
Columns tuple_columns(IP_RANGE_TUPLE_SIZE);
auto col_res_lower_range = ColumnUInt32::create();
auto col_res_upper_range = ColumnUInt32::create();
auto & vec_res_lower_range = col_res_lower_range->getData();
vec_res_lower_range.resize(size);
auto & vec_res_upper_range = col_res_upper_range->getData();
vec_res_upper_range.resize(size);
for (size_t i = 0; i < vec_in.size(); ++i)
{
UInt8 cidr = col_const_cidr_in
? col_const_cidr_in->getValue<UInt8>()
: col_cidr_in->getData()[i];
vec_res_lower_range[i] = setCIDRMask<true>(vec_in[i], cidr);
vec_res_upper_range[i] = setCIDRMask<false>(vec_in[i], cidr);
}
tuple_columns[0] = std::move(col_res_lower_range);
tuple_columns[1] = std::move(col_res_upper_range);
block.getByPosition(result).column = ColumnTuple::create(tuple_columns);
}
else if (!col_ip_in)
throw Exception("Illegal column " + block.getByPosition(arguments[0]).column->getName()
+ " of argument of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN);
else
throw Exception("Illegal column " + block.getByPosition(arguments[1]).column->getName()
+ " of argument of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN);
}
};
}