Remove Nummable support from isIPAddressContainedIn, do not use OverloadResolver

This commit is contained in:
vdimir 2021-03-29 12:48:12 +03:00
parent 24aa25d7dc
commit 81ff4f4791
No known key found for this signature in database
GPG Key ID: F57B3E10A21DBB31
3 changed files with 71 additions and 343 deletions

View File

@ -1,10 +1,8 @@
#include <Columns/ColumnConst.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnsNumber.h>
#include <Common/IPv6ToBinary.h>
#include <Common/formatIPv6.h>
#include <Common/IPv6ToBinary.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/IFunctionImpl.h>
@ -19,7 +17,7 @@ namespace DB::ErrorCodes
{
extern const int CANNOT_PARSE_TEXT;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int CANNOT_PARSE_NUMBER;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
namespace
@ -31,8 +29,8 @@ public:
explicit IPAddressVariant(const StringRef & address_str)
{
// IP address parser functions require that the input is
// NULL-terminated so we need to copy it.
/// IP address parser functions require that the input is
/// NULL-terminated so we need to copy it.
const auto address_str_copy = std::string(address_str);
UInt32 v4;
@ -122,49 +120,57 @@ inline bool isAddressInRange(const IPAddressVariant & address, const IPAddressCI
namespace DB
{
template <typename Name>
class ExecutableFunctionIsIPAddressContainedIn : public IExecutableFunctionImpl
class FunctionIsIPAddressContainedIn : public IFunction
{
public:
String getName() const override
{
return Name::name;
}
static constexpr auto name = "isIPAddressContainedIn";
String getName() const override { return name; }
static FunctionPtr create(const Context &) { return std::make_shared<FunctionIsIPAddressContainedIn>(); }
ColumnPtr execute(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type, size_t input_rows_count) const override
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & /* return_type */, size_t input_rows_count) const override
{
const IColumn * col_addr = arguments[0].column.get();
const IColumn * col_cidr = arguments[1].column.get();
if (const auto * col_addr_const = checkAndGetAnyColumnConst(col_addr))
{
// col_addr_const is constant and is either String or Nullable(String).
// We don't care which one it exactly is.
if (const auto * col_cidr_const = checkAndGetAnyColumnConst(col_cidr))
return executeImpl(*col_addr_const, *col_cidr_const, return_type, input_rows_count);
return executeImpl(*col_addr_const, *col_cidr_const, input_rows_count);
else
return executeImpl(*col_addr_const, *col_cidr, return_type, input_rows_count);
return executeImpl(*col_addr_const, *col_cidr, input_rows_count);
}
else
{
if (const auto * col_cidr_const = checkAndGetAnyColumnConst(col_cidr))
return executeImpl(*col_addr, *col_cidr_const, return_type, input_rows_count);
return executeImpl(*col_addr, *col_cidr_const, input_rows_count);
else
return executeImpl(*col_addr, *col_cidr, input_rows_count);
}
}
bool useDefaultImplementationForNulls() const override
virtual DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
// We can't use the default implementation because that would end up
// parsing invalid addresses or prefixes at NULL fields, which would
// throw exceptions instead of returning NULL.
return false;
if (arguments.size() != 2)
throw Exception(
"Number of arguments for function " + getName() + " doesn't match: passed " + toString(arguments.size()) + ", should be 2",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
const DataTypePtr & addr_type = arguments[0];
const DataTypePtr & prefix_type = arguments[1];
if (!isString(addr_type) || !isString(prefix_type))
throw Exception("The arguments of function " + getName() + " must be String",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<DataTypeUInt8>();
}
virtual size_t getNumberOfArguments() const override { return 2; }
bool useDefaultImplementationForNulls() const override { return false; }
private:
// Like checkAndGetColumnConst() but this function doesn't
// care about the type of data column.
/// Like checkAndGetColumnConst() but this function doesn't
/// care about the type of data column.
static const ColumnConst * checkAndGetAnyColumnConst(const IColumn * column)
{
if (!column || !isColumnConst(*column))
@ -173,277 +179,81 @@ namespace DB
return assert_cast<const ColumnConst *>(column);
}
// Both columns are constant.
ColumnPtr executeImpl(const ColumnConst & col_addr_const, const ColumnConst & col_cidr_const, const DataTypePtr & return_type, size_t input_rows_count) const
/// Both columns are constant.
ColumnPtr executeImpl(
const ColumnConst & col_addr_const,
const ColumnConst & col_cidr_const,
size_t input_rows_count) const
{
const auto & col_addr = col_addr_const.getDataColumn();
const auto & col_cidr = col_cidr_const.getDataColumn();
if (col_addr.isNullAt(0) || col_cidr.isNullAt(0))
{
// If either of the arguments are NULL, the result is also NULL.
assert(return_type->isNullable());
return return_type->createColumnConstWithDefaultValue(input_rows_count);
}
else
{
const auto addr = IPAddressVariant(col_addr.getDataAt(0));
const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(0));
const auto addr = IPAddressVariant(col_addr.getDataAt(0));
const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(0));
ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(1);
ColumnUInt8::Container & vec_res = col_res->getData();
ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(1);
ColumnUInt8::Container & vec_res = col_res->getData();
vec_res[0] = isAddressInRange(addr, cidr) ? 1 : 0;
vec_res[0] = isAddressInRange(addr, cidr) ? 1 : 0;
if (return_type->isNullable())
{
ColumnUInt8::MutablePtr col_null_map_res = ColumnUInt8::create(1);
ColumnUInt8::Container & vec_null_map_res = col_null_map_res->getData();
vec_null_map_res[0] = false;
return ColumnConst::create(ColumnNullable::create(std::move(col_res), std::move(col_null_map_res)), input_rows_count);
}
else
{
return ColumnConst::create(std::move(col_res), input_rows_count);
}
}
return ColumnConst::create(std::move(col_res), input_rows_count);
}
// Address is constant.
ColumnPtr executeImpl(const ColumnConst & col_addr_const, const IColumn & col_cidr, const DataTypePtr & return_type, size_t input_rows_count) const
/// Address is constant.
ColumnPtr executeImpl(const ColumnConst & col_addr_const, const IColumn & col_cidr, size_t input_rows_count) const
{
const auto & col_addr = col_addr_const.getDataColumn();
if (col_addr.isNullAt(0))
const auto addr = IPAddressVariant(col_addr.getDataAt (0));
ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(input_rows_count);
ColumnUInt8::Container & vec_res = col_res->getData();
for (size_t i = 0; i < input_rows_count; i++)
{
// It's constant NULL so the result is also constant NULL.
assert(return_type->isNullable());
return return_type->createColumnConstWithDefaultValue(input_rows_count);
}
else
{
const auto addr = IPAddressVariant(col_addr.getDataAt (0));
ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(input_rows_count);
ColumnUInt8::Container & vec_res = col_res->getData();
if (col_addr.isNullable() || col_cidr.isNullable())
{
ColumnUInt8::MutablePtr col_null_map_res = ColumnUInt8::create(input_rows_count);
ColumnUInt8::Container & vec_null_map_res = col_null_map_res->getData();
for (size_t i = 0; i < input_rows_count; i++)
{
if (col_cidr.isNullAt(i))
{
vec_null_map_res[i] = true;
}
else
{
const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(i));
vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0;
vec_null_map_res[i] = false;
}
}
return ColumnNullable::create(std::move(col_res), std::move(col_null_map_res));
}
else
{
for (size_t i = 0; i < input_rows_count; i++)
{
const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(i));
vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0;
}
return col_res;
}
const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(i));
vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0;
}
return col_res;
}
// CIDR is constant.
ColumnPtr executeImpl(const IColumn & col_addr, const ColumnConst & col_cidr_const, const DataTypePtr & return_type, size_t input_rows_count) const
/// CIDR is constant.
ColumnPtr executeImpl(const IColumn & col_addr, const ColumnConst & col_cidr_const, size_t input_rows_count) const
{
const auto & col_cidr = col_cidr_const.getDataColumn();
if (col_cidr.isNullAt(0))
const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(0));
ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(input_rows_count);
ColumnUInt8::Container & vec_res = col_res->getData();
for (size_t i = 0; i < input_rows_count; i++)
{
// It's constant NULL so the result is also constant NULL.
assert(return_type->isNullable());
return return_type->createColumnConstWithDefaultValue(input_rows_count);
}
else
{
const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(0));
ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(input_rows_count);
ColumnUInt8::Container & vec_res = col_res->getData();
if (col_addr.isNullable() || col_cidr.isNullable())
{
ColumnUInt8::MutablePtr col_null_map_res = ColumnUInt8::create(input_rows_count);
ColumnUInt8::Container & vec_null_map_res = col_null_map_res->getData();
for (size_t i = 0; i < input_rows_count; i++)
{
if (col_addr.isNullAt(i))
{
vec_null_map_res[i] = true;
}
else
{
const auto addr = IPAddressVariant(col_addr.getDataAt(i));
vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0;
vec_null_map_res[i] = false;
}
}
return ColumnNullable::create(std::move(col_res), std::move(col_null_map_res));
}
else
{
for (size_t i = 0; i < input_rows_count; i++)
{
const auto addr = IPAddressVariant(col_addr.getDataAt(i));
vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0;
}
return col_res;
}
const auto addr = IPAddressVariant(col_addr.getDataAt(i));
vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0;
}
return col_res;
}
// Neither are constant.
/// Neither are constant.
ColumnPtr executeImpl(const IColumn & col_addr, const IColumn & col_cidr, size_t input_rows_count) const
{
ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(input_rows_count);
ColumnUInt8::Container & vec_res = col_res->getData();
if (col_addr.isNullable() || col_cidr.isNullable())
for (size_t i = 0; i < input_rows_count; i++)
{
ColumnUInt8::MutablePtr col_null_map_res = ColumnUInt8::create(input_rows_count);
ColumnUInt8::Container & vec_null_map_res = col_null_map_res->getData();
const auto addr = IPAddressVariant(col_addr.getDataAt(i));
const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(i));
for (size_t i = 0; i < input_rows_count; i++)
{
if (col_addr.isNullAt(i) || col_cidr.isNullAt(i))
{
vec_null_map_res[i] = true;
}
else
{
const auto addr = IPAddressVariant(col_addr.getDataAt(i));
const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(i));
vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0;
vec_null_map_res[i] = false;
}
}
return ColumnNullable::create(std::move(col_res), std::move(col_null_map_res));
vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0;
}
else
{
for (size_t i = 0; i < input_rows_count; i++)
{
const auto addr = IPAddressVariant(col_addr.getDataAt(i));
const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(i));
vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0;
}
return col_res;
}
return col_res;
}
};
template <typename Name>
class FunctionBaseIsIPAddressContainedIn : public IFunctionBaseImpl
{
public:
explicit FunctionBaseIsIPAddressContainedIn(DataTypes argument_types_, DataTypePtr return_type_)
: argument_types(std::move(argument_types_))
, return_type(std::move(return_type_)) {}
String getName() const override
{
return Name::name;
}
const DataTypes & getArgumentTypes() const override
{
return argument_types;
}
const DataTypePtr & getResultType() const override
{
return return_type;
}
ExecutableFunctionImplPtr prepare(const ColumnsWithTypeAndName &) const override
{
return std::make_unique<ExecutableFunctionIsIPAddressContainedIn<Name>>();
}
private:
DataTypes argument_types;
DataTypePtr return_type;
};
template <typename Name>
class IsIPAddressContainedInOverloadResolver : public IFunctionOverloadResolverImpl
{
public:
static constexpr auto name = Name::name;
static FunctionOverloadResolverImplPtr create(const Context &)
{
return std::make_unique<IsIPAddressContainedInOverloadResolver<Name>>();
}
String getName() const override
{
return Name::name;
}
FunctionBaseImplPtr build(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const override
{
const DataTypePtr & addr_type = removeNullable(arguments[0].type);
const DataTypePtr & prefix_type = removeNullable(arguments[1].type);
DataTypes argument_types = { addr_type, prefix_type };
/* The arguments can be any of Nullable(NULL), Nullable(String), and
* String. We can't do this check in getReturnType() because it
* won't be called when there are any constant NULLs in the
* arguments. */
if (!(WhichDataType(addr_type).isNothing() || isString(addr_type)) ||
!(WhichDataType(prefix_type).isNothing() || isString(prefix_type)))
throw Exception("The arguments of function " + getName() + " must be String",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_unique<FunctionBaseIsIPAddressContainedIn<Name>>(argument_types, return_type);
}
DataTypePtr getReturnType(const DataTypes &) const override
{
return std::make_shared<DataTypeUInt8>();
}
size_t getNumberOfArguments() const override
{
return 2;
}
};
struct NameIsIPAddressContainedIn
{
static constexpr auto name = "isIPAddressContainedIn";
};
void registerFunctionIsIPAddressContainedIn(FunctionFactory & factory)
{
factory.registerFunction<IsIPAddressContainedInOverloadResolver<NameIsIPAddressContainedIn>>();
factory.registerFunction<FunctionIsIPAddressContainedIn>();
}
}

View File

@ -33,44 +33,5 @@
0
0
0
# Arguments can be nullable.
## Nullable address
\N
\N
1
## Nullable prefix
\N
\N
1
## Both nullable
\N
\N
1
# Non-constant nullable arguments
## Non-constant address
127.0.0.1 127.0.0.0/8 1
\N 127.0.0.0/8 \N
127.0.0.1 \N \N
\N \N \N
127.0.0.1 \N \N
\N \N \N
127.0.0.1 127.0.0.0/8 1
\N 127.0.0.0/8 \N
127.0.0.1 127.0.0.0/8 1
## Non-constant prefix
127.0.0.1 127.0.0.0/8 1
127.0.0.1 \N \N
\N 127.0.0.0/8 \N
\N \N \N
\N 127.0.0.0/8 \N
\N \N \N
127.0.0.1 127.0.0.0/8 1
127.0.0.1 \N \N
127.0.0.1 127.0.0.0/8 1
## Both non-constant
127.0.0.1 127.0.0.0/8 1
127.0.0.1 \N \N
\N 127.0.0.0/8 \N
\N \N \N
# Unparsable arguments
# Wrong argument types

View File

@ -1,4 +1,3 @@
--
SELECT '# Invocation with constants';
SELECT isIPAddressContainedIn('127.0.0.1', '127.0.0.0/8');
@ -7,25 +6,21 @@ SELECT isIPAddressContainedIn('128.0.0.1', '127.0.0.0/8');
SELECT isIPAddressContainedIn('ffff::1', 'ffff::/16');
SELECT isIPAddressContainedIn('fffe::1', 'ffff::/16');
--
SELECT '# Invocation with non-constant addresses';
WITH arrayJoin(['192.168.99.255', '192.168.100.1', '192.168.103.255', '192.168.104.0']) as addr, '192.168.100.0/22' as prefix SELECT addr, prefix, isIPAddressContainedIn(addr, prefix);
WITH arrayJoin(['::192.168.99.255', '::192.168.100.1', '::192.168.103.255', '::192.168.104.0']) as addr, '::192.168.100.0/118' as prefix SELECT addr, prefix, isIPAddressContainedIn(addr, prefix);
--
SELECT '# Invocation with non-constant prefixes';
WITH '192.168.100.1' as addr, arrayJoin(['192.168.100.0/22', '192.168.100.0/24', '192.168.100.0/32']) as prefix SELECT addr, prefix, isIPAddressContainedIn(addr, prefix);
WITH '::192.168.100.1' as addr, arrayJoin(['::192.168.100.0/118', '::192.168.100.0/120', '::192.168.100.0/128']) as prefix SELECT addr, prefix, isIPAddressContainedIn(addr, prefix);
--
SELECT '# Invocation with non-constants';
WITH arrayJoin(['192.168.100.1', '192.168.103.255']) as addr, arrayJoin(['192.168.100.0/22', '192.168.100.0/24']) as prefix SELECT addr, prefix, isIPAddressContainedIn(addr, prefix);
WITH arrayJoin(['::192.168.100.1', '::192.168.103.255']) as addr, arrayJoin(['::192.168.100.0/118', '::192.168.100.0/120']) as prefix SELECT addr, prefix, isIPAddressContainedIn(addr, prefix);
--
SELECT '# Mismatching IP versions is not an error.';
SELECT isIPAddressContainedIn('127.0.0.1', 'ffff::/16');
@ -33,54 +28,16 @@ SELECT isIPAddressContainedIn('127.0.0.1', '::127.0.0.1/128');
SELECT isIPAddressContainedIn('::1', '127.0.0.0/8');
SELECT isIPAddressContainedIn('::127.0.0.1', '127.0.0.1/32');
--
SELECT '# Arguments can be nullable.';
SELECT '## Nullable address';
SELECT isIPAddressContainedIn(NULL , '127.0.0.0/8');
SELECT isIPAddressContainedIn(CAST(NULL, 'Nullable(String)') , '127.0.0.0/8');
SELECT isIPAddressContainedIn(CAST('127.0.0.1', 'Nullable(String)'), '127.0.0.0/8');
SELECT '## Nullable prefix';
SELECT isIPAddressContainedIn('127.0.0.1', NULL);
SELECT isIPAddressContainedIn('127.0.0.1', CAST(NULL, 'Nullable(String)'));
SELECT isIPAddressContainedIn('127.0.0.1', CAST('127.0.0.0/8', 'Nullable(String)'));
SELECT '## Both nullable';
SELECT isIPAddressContainedIn(NULL , NULL);
SELECT isIPAddressContainedIn(CAST(NULL, 'Nullable(String)') , CAST(NULL, 'Nullable(String)'));
SELECT isIPAddressContainedIn(CAST('127.0.0.1', 'Nullable(String)'), CAST('127.0.0.0/8', 'Nullable(String)'));
--
SELECT '# Non-constant nullable arguments';
SELECT '## Non-constant address';
WITH arrayJoin(['127.0.0.1', NULL]) as addr, '127.0.0.0/8' as prefix SELECT addr, prefix, isIPAddressContainedIn(addr, prefix);
WITH arrayJoin(['127.0.0.1', NULL]) as addr, NULL as prefix SELECT addr, prefix, isIPAddressContainedIn(addr, prefix);
WITH arrayJoin(['127.0.0.1', NULL]) as addr, CAST(NULL, 'Nullable(String)') as prefix SELECT addr, prefix, isIPAddressContainedIn(addr, prefix);
WITH arrayJoin(['127.0.0.1', NULL]) as addr, CAST('127.0.0.0/8', 'Nullable(String)') as prefix SELECT addr, prefix, isIPAddressContainedIn(addr, prefix);
WITH arrayJoin(['127.0.0.1']) as addr, CAST('127.0.0.0/8', 'Nullable(String)') as prefix SELECT addr, prefix, isIPAddressContainedIn(addr, prefix);
SELECT '## Non-constant prefix';
WITH '127.0.0.1' as addr, arrayJoin(['127.0.0.0/8', NULL]) as prefix SELECT addr, prefix, isIPAddressContainedIn(addr, prefix);
WITH NULL as addr, arrayJoin(['127.0.0.0/8', NULL]) as prefix SELECT addr, prefix, isIPAddressContainedIn(addr, prefix);
WITH CAST(NULL, 'Nullable(String)') as addr, arrayJoin(['127.0.0.0/8', NULL]) as prefix SELECT addr, prefix, isIPAddressContainedIn(addr, prefix);
WITH CAST('127.0.0.1', 'Nullable(String)') as addr, arrayJoin(['127.0.0.0/8', NULL]) as prefix SELECT addr, prefix, isIPAddressContainedIn(addr, prefix);
WITH CAST('127.0.0.1', 'Nullable(String)') as addr, arrayJoin(['127.0.0.0/8']) as prefix SELECT addr, prefix, isIPAddressContainedIn(addr, prefix);
SELECT '## Both non-constant';
WITH arrayJoin(['127.0.0.1', NULL]) as addr, arrayJoin(['127.0.0.0/8', NULL]) as prefix SELECT addr, prefix, isIPAddressContainedIn(addr, prefix);
--
SELECT '# Unparsable arguments';
SELECT isIPAddressContainedIn('unparsable', '127.0.0.0/8'); -- { serverError 6 }
SELECT isIPAddressContainedIn('127.0.0.1', 'unparsable'); -- { serverError 6 }
--
SELECT '# Wrong argument types';
SELECT isIPAddressContainedIn(100, '127.0.0.0/8'); -- { serverError 43 }
SELECT isIPAddressContainedIn(NULL, '127.0.0.0/8'); -- { serverError 43 }
SELECT isIPAddressContainedIn(CAST(NULL, 'Nullable(String)'), '127.0.0.0/8'); -- { serverError 43 }
SELECT isIPAddressContainedIn('127.0.0.1', 100); -- { serverError 43 }
SELECT isIPAddressContainedIn(100, NULL); -- { serverError 43 }
WITH arrayJoin([NULL, NULL, NULL, NULL]) AS prefix SELECT isIPAddressContainedIn([NULL, NULL, 0, 255, 0], prefix); -- { serverError 43 }