Return non-Nullable results from COUNT(DISTINCT)

This commit is contained in:
Alexey Milovidov 2020-06-11 06:45:12 +03:00
parent cbc08bc393
commit e40ee1a173
7 changed files with 83 additions and 35 deletions

View File

@ -67,6 +67,12 @@ public:
{
data(place).count = new_count;
}
/// The function returns non-Nullable type even when wrapped with Null combinator.
bool returnDefaultWhenOnlyNull() const override
{
return true;
}
};

View File

@ -63,14 +63,15 @@ AggregateFunctionPtr AggregateFunctionFactory::get(
{
auto type_without_low_cardinality = convertLowCardinalityTypesToNested(argument_types);
/// If one of types is Nullable, we apply aggregate function combinator "Null".
/// If one of the types is Nullable, we apply aggregate function combinator "Null".
if (std::any_of(type_without_low_cardinality.begin(), type_without_low_cardinality.end(),
[](const auto & type) { return type->isNullable(); }))
{
AggregateFunctionCombinatorPtr combinator = AggregateFunctionCombinatorFactory::instance().tryFindSuffix("Null");
if (!combinator)
throw Exception("Logical error: cannot find aggregate function combinator to apply a function to Nullable arguments.", ErrorCodes::LOGICAL_ERROR);
throw Exception("Logical error: cannot find aggregate function combinator to apply a function to Nullable arguments.",
ErrorCodes::LOGICAL_ERROR);
DataTypes nested_types = combinator->transformArguments(type_without_low_cardinality);
Array nested_parameters = combinator->transformParameters(parameters);
@ -132,9 +133,10 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
auto hints = this->getHints(name);
if (!hints.empty())
throw Exception("Unknown aggregate function " + name + ". Maybe you meant: " + toString(hints), ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
throw Exception(fmt::format("Unknown aggregate function {}. Maybe you meant: {}", name, toString(hints)),
ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
else
throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
throw Exception(fmt::format("Unknown aggregate function {}", name), ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
}

View File

@ -49,35 +49,52 @@ public:
}
if (!has_nullable_types)
throw Exception("Aggregate function combinator 'Null' requires at least one argument to be Nullable", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (nested_function)
if (auto adapter = nested_function->getOwnNullAdapter(nested_function, arguments, params))
return adapter;
/// Special case for 'count' function. It could be called with Nullable arguments
/// - that means - count number of calls, when all arguments are not NULL.
if (nested_function && nested_function->getName() == "count")
return std::make_shared<AggregateFunctionCountNotNullUnary>(arguments[0], params);
throw Exception("Aggregate function combinator 'Null' requires at least one argument to be Nullable",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (has_null_types)
return std::make_shared<AggregateFunctionNothing>(arguments, params);
bool return_type_is_nullable = nested_function->getReturnType()->canBeInsideNullable();
assert(nested_function);
if (auto adapter = nested_function->getOwnNullAdapter(nested_function, arguments, params))
return adapter;
/// Special case for 'count' function. It could be called with Nullable arguments
/// - that means - count number of calls, when all arguments are not NULL.
if (nested_function->getName() == "count")
return std::make_shared<AggregateFunctionCountNotNullUnary>(arguments[0], params);
bool return_type_is_nullable = !nested_function->returnDefaultWhenOnlyNull() && nested_function->getReturnType()->canBeInsideNullable();
bool serialize_flag = return_type_is_nullable || nested_function->returnDefaultWhenOnlyNull();
if (arguments.size() == 1)
{
if (return_type_is_nullable)
return std::make_shared<AggregateFunctionNullUnary<true>>(nested_function, arguments, params);
{
return std::make_shared<AggregateFunctionNullUnary<true, true>>(nested_function, arguments, params);
}
else
return std::make_shared<AggregateFunctionNullUnary<false>>(nested_function, arguments, params);
{
if (serialize_flag)
return std::make_shared<AggregateFunctionNullUnary<false, true>>(nested_function, arguments, params);
else
return std::make_shared<AggregateFunctionNullUnary<false, false>>(nested_function, arguments, params);
}
}
else
{
if (return_type_is_nullable)
return std::make_shared<AggregateFunctionNullVariadic<true, true>>(nested_function, arguments, params);
{
return std::make_shared<AggregateFunctionNullVariadic<true, true, true>>(nested_function, arguments, params);
}
else
return std::make_shared<AggregateFunctionNullVariadic<false, true>>(nested_function, arguments, params);
{
if (serialize_flag)
return std::make_shared<AggregateFunctionNullVariadic<false, true, true>>(nested_function, arguments, params);
else
return std::make_shared<AggregateFunctionNullVariadic<false, true, false>>(nested_function, arguments, params);
}
}
}
};

View File

@ -28,7 +28,10 @@ namespace ErrorCodes
/// If all rows had NULL, the behaviour is determined by "result_is_nullable" template parameter.
/// true - return NULL; false - return value from empty aggregation state of nested function.
template <bool result_is_nullable, typename Derived>
/// When serialize_flag is set to true, the flag about presense of values is serialized
/// regardless to the "result_is_nullable" even if it's unneeded - for protocol compatibility.
template <bool result_is_nullable, bool serialize_flag, typename Derived>
class AggregateFunctionNullBase : public IAggregateFunctionHelper<Derived>
{
protected:
@ -129,7 +132,7 @@ public:
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
bool flag = getFlag(place);
if constexpr (result_is_nullable)
if constexpr (serialize_flag)
writeBinary(flag, buf);
if (flag)
nested_function->serialize(nestedPlace(place), buf);
@ -138,7 +141,7 @@ public:
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
{
bool flag = 1;
if constexpr (result_is_nullable)
if constexpr (serialize_flag)
readBinary(flag, buf);
if (flag)
{
@ -183,12 +186,15 @@ public:
/** There are two cases: for single argument and variadic.
* Code for single argument is much more efficient.
*/
template <bool result_is_nullable>
class AggregateFunctionNullUnary final : public AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullUnary<result_is_nullable>>
template <bool result_is_nullable, bool serialize_flag>
class AggregateFunctionNullUnary final
: public AggregateFunctionNullBase<result_is_nullable, serialize_flag,
AggregateFunctionNullUnary<result_is_nullable, serialize_flag>>
{
public:
AggregateFunctionNullUnary(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullUnary<result_is_nullable>>(std::move(nested_function_), arguments, params)
: AggregateFunctionNullBase<result_is_nullable, serialize_flag,
AggregateFunctionNullUnary<result_is_nullable, serialize_flag>>(std::move(nested_function_), arguments, params)
{
}
@ -218,12 +224,15 @@ public:
};
template <bool result_is_nullable, bool null_is_skipped>
class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullVariadic<result_is_nullable, null_is_skipped>>
template <bool result_is_nullable, bool serialize_flag, bool null_is_skipped>
class AggregateFunctionNullVariadic final
: public AggregateFunctionNullBase<result_is_nullable, serialize_flag,
AggregateFunctionNullVariadic<result_is_nullable, serialize_flag, null_is_skipped>>
{
public:
AggregateFunctionNullVariadic(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullVariadic<result_is_nullable, null_is_skipped>>(std::move(nested_function_), arguments, params),
: AggregateFunctionNullBase<result_is_nullable, serialize_flag,
AggregateFunctionNullVariadic<result_is_nullable, serialize_flag, null_is_skipped>>(std::move(nested_function_), arguments, params),
number_of_arguments(arguments.size())
{
if (number_of_arguments == 1)
@ -263,11 +272,6 @@ public:
this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena);
}
bool allocatesMemoryInArena() const override
{
return this->nested_function->allocatesMemoryInArena();
}
private:
enum { MAX_ARGS = 8 };
size_t number_of_arguments = 0;

View File

@ -244,6 +244,12 @@ public:
{
assert_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size());
}
/// The function returns non-Nullable type even when wrapped with Null combinator.
bool returnDefaultWhenOnlyNull() const override
{
return true;
}
};
@ -298,6 +304,12 @@ public:
{
assert_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size());
}
/// The function returns non-Nullable type even when wrapped with Null combinator.
bool returnDefaultWhenOnlyNull() const override
{
return true;
}
};
}

View File

@ -240,9 +240,10 @@ public:
return std::make_shared<DataTypeUInt8>();
}
AggregateFunctionPtr getOwnNullAdapter(const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override
AggregateFunctionPtr getOwnNullAdapter(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override
{
return std::make_shared<AggregateFunctionNullVariadic<false, false>>(nested_function, arguments, params);
return std::make_shared<AggregateFunctionNullVariadic<false, false, false>>(nested_function, arguments, params);
}
void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override

View File

@ -171,6 +171,12 @@ public:
return nullptr;
}
/** When the function is wrapped with Null combinator,
* should we return Nullable type with NULL when no values were aggregated
* or we should return non-Nullable type with default value (example: count, countDistinct).
*/
virtual bool returnDefaultWhenOnlyNull() const { return false; }
const DataTypes & getArgumentTypes() const { return argument_types; }
const Array & getParameters() const { return parameters; }