diff --git a/dbms/src/Functions/FunctionsFindCluster.h b/dbms/src/Functions/FunctionsFindCluster.h index 706842ef1ff..60d55665132 100644 --- a/dbms/src/Functions/FunctionsFindCluster.h +++ b/dbms/src/Functions/FunctionsFindCluster.h @@ -16,12 +16,14 @@ #include #include + namespace DB { namespace ErrorCodes { extern const int BAD_ARGUMENTS; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; } enum ClusterOperation @@ -47,154 +49,147 @@ template class Centroids { public: - Centroids() {} + Centroids() {} - Centroids(const Centroids & c): centroids(c.centroids) {} + Centroids(const Centroids & c): centroids(c.centroids) {} - virtual ~Centroids() {} + virtual ~Centroids() {} - bool fill (const IColumn* centroids_array_untyped) - { - const ColumnArray * centroids_array = typeid_cast(centroids_array_untyped); + bool fill (const IColumn* centroids_array_untyped) + { + const ColumnArray * centroids_array = typeid_cast(centroids_array_untyped); - if (centroids_array) - { - if (centroids_array->empty()) - throw Exception{"Centroids array must be not empty", ErrorCodes::ILLEGAL_COLUMN}; + if (centroids_array) + { + if (centroids_array->empty()) + throw Exception{"Centroids array must be not empty", ErrorCodes::ILLEGAL_COLUMN}; - for (size_t k = 0; k < centroids_array->size(); k++) - { - const Field& tmp_field = (*centroids_array)[k]; - CentroidsType value; - if (!tmp_field.tryGet(value)) - return false; - centroids.push_back(Float64(value)); - } - } - else - { - const ColumnConst * const_centroids_array = typeid_cast *>(centroids_array_untyped); + for (size_t k = 0; k < centroids_array->size(); k++) + { + const Field& tmp_field = (*centroids_array)[k]; + CentroidsType value; + if (!tmp_field.tryGet(value)) + return false; + centroids.push_back(Float64(value)); + } + } + else + { + const ColumnConst * const_centroids_array = typeid_cast *>(centroids_array_untyped); - if (!const_centroids_array) - return false; + if (!const_centroids_array) + return false; - if (const_centroids_array->getData().empty()) - throw Exception{"Centroids array must be not empty", ErrorCodes::ILLEGAL_COLUMN}; + if (const_centroids_array->getData().empty()) + throw Exception{"Centroids array must be not empty", ErrorCodes::ILLEGAL_COLUMN}; - for (size_t k = 0; k < const_centroids_array->getData().size(); ++k) - { - const Field& tmp_field = (const_centroids_array->getData())[k]; - CentroidsType value; - if (!tmp_field.tryGet(value)) - return false; - centroids.push_back(Float64(value)); - } - } - return true; - } + for (size_t k = 0; k < const_centroids_array->getData().size(); ++k) + { + const Field& tmp_field = (const_centroids_array->getData())[k]; + CentroidsType value; + if (!tmp_field.tryGet(value)) + return false; + centroids.push_back(Float64(value)); + } + } + return true; + } - template - bool findCluster( - const IColumn* in_untyped, - IColumn* out_untyped, - ClusterOperation operation) - { - if (operation == ClusterOperation::FindClusterIndex) - return findClusterTyped(in_untyped, out_untyped, operation); - else if (operation == ClusterOperation::FindCentroidValue) - return findClusterTyped(in_untyped, out_untyped, operation); - - throw Exception{"Unexpected error in findCluster* function", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; - } - + template + bool findCluster( + const IColumn* in_untyped, + IColumn* out_untyped, + ClusterOperation operation) + { + if (operation == ClusterOperation::FindClusterIndex) + return findClusterTyped(in_untyped, out_untyped, operation); + else if (operation == ClusterOperation::FindCentroidValue) + return findClusterTyped(in_untyped, out_untyped, operation); + throw Exception{"Unexpected error in findCluster* function", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; + } private: - std::vector centroids; + std::vector centroids; - // Centroids array has the same size as number of clusters. We expect it - // to be small, maybe 10s or 100s in most real life situation, so we - // choose the naive implementation + // Centroids array has the same size as number of clusters. We expect it + // to be small, maybe 10s or 100s in most real life situation, so we + // choose the naive implementation size_t find_centroid(CentroidsType x) { - Float64 y = Float64(x); + Float64 y = Float64(x); - // Centroids array has to have at least one element, and if it has only one element, - // it is also the result of this Function. - Float64 distance = abs(centroids[0]-y); - size_t index = 0; + // Centroids array has to have at least one element, and if it has only one element, + // it is also the result of this Function. + Float64 distance = abs(centroids[0]-y); + size_t index = 0; - // Check if we have more clusters and if we have, whether some is closer to src[i] - for (size_t j = 1; j < centroids.size(); ++j) - { - Float64 next_distance = abs(centroids[j]-y); + // Check if we have more clusters and if we have, whether some is closer to src[i] + for (size_t j = 1; j < centroids.size(); ++j) + { + Float64 next_distance = abs(centroids[j]-y); - if (next_distance < distance) - { - distance = next_distance; - index = j; - } - } - - // Index of the closest cluster, or 0 in case of just one cluster - return index; + if (next_distance < distance) + { + distance = next_distance; + index = j; + } + } + // Index of the closest cluster, or 0 in case of just one cluster + return index; } template - bool findClusterTyped( - const IColumn* in_untyped, - IColumn* out_untyped, - ClusterOperation operation) - { - ColumnVector * out = typeid_cast *>(out_untyped); + bool findClusterTyped( + const IColumn* in_untyped, + IColumn* out_untyped, + ClusterOperation operation) + { + ColumnVector * out = typeid_cast *>(out_untyped); - if (!out) - return false; + if (!out) + return false; - PaddedPODArray & dst = out->getData(); + PaddedPODArray & dst = out->getData(); - const auto in_vector = typeid_cast *>(in_untyped); - if (in_vector) - { - const PaddedPODArray & src = in_vector->getData(); + const auto in_vector = typeid_cast *>(in_untyped); + if (in_vector) + { + const PaddedPODArray & src = in_vector->getData(); - if (operation == ClusterOperation::FindClusterIndex) - for (size_t i = 0; i < src.size(); ++i) - // Note that array indexes start with 1 in Clickhouse - dst.push_back(UInt64(find_centroid(CentroidsType(src[i]))+1)); - else if (operation == ClusterOperation::FindCentroidValue) - for (size_t i = 0; i < src.size(); ++i) - dst.push_back(centroids[find_centroid(CentroidsType(src[i]))]); - else - throw Exception{"Unexpected error in findCluster* function", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; + if (operation == ClusterOperation::FindClusterIndex) + for (size_t i = 0; i < src.size(); ++i) + // Note that array indexes start with 1 in Clickhouse + dst.push_back(UInt64(find_centroid(CentroidsType(src[i]))+1)); + else if (operation == ClusterOperation::FindCentroidValue) + for (size_t i = 0; i < src.size(); ++i) + dst.push_back(centroids[find_centroid(CentroidsType(src[i]))]); + else + throw Exception{"Unexpected error in findCluster* function", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; - return true; + return true; + } + else + { + const auto in_const = typeid_cast *>(in_untyped); - } - else - { - const auto in_const = typeid_cast *>(in_untyped); - - if (!in_const) - return false; - - if (operation == ClusterOperation::FindClusterIndex) - // Note that array indexes start with 1 in Clickhouse - dst.push_back(UInt64(find_centroid(CentroidsType(in_const->getData()))+1)); - else if (operation == ClusterOperation::FindCentroidValue) - dst.push_back(centroids[find_centroid(CentroidsType(in_const->getData()))]); - else - throw Exception{"Unexpected error in findCluster* function", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; - - } - - return true; - } + if (!in_const) + return false; + if (operation == ClusterOperation::FindClusterIndex) + // Note that array indexes start with 1 in Clickhouse + dst.push_back(UInt64(find_centroid(CentroidsType(in_const->getData()))+1)); + else if (operation == ClusterOperation::FindCentroidValue) + dst.push_back(centroids[find_centroid(CentroidsType(in_const->getData()))]); + else + throw Exception{"Unexpected error in findCluster* function", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; + } + return true; + } }; @@ -211,7 +206,7 @@ private: class FunctionFindClusterIndex : public IFunction { public: - static constexpr auto name = "findClusterIndex"; + static constexpr auto name = "findClusterIndex"; static FunctionPtr create(const Context &) { return std::make_shared(); } String getName() const override @@ -245,7 +240,7 @@ public: throw Exception{"Second argument of function " + getName() + ", must be array.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; - return std::make_shared(); + return std::make_shared(); } @@ -257,16 +252,16 @@ public: auto out_untyped = column_result.get(); if ( !executeByCentroidsType(in_untyped, out_untyped, centroids_array_untyped) - && !executeByCentroidsType(in_untyped, out_untyped, centroids_array_untyped) - && !executeByCentroidsType(in_untyped, out_untyped, centroids_array_untyped) - && !executeByCentroidsType(in_untyped, out_untyped, centroids_array_untyped) - && !executeByCentroidsType(in_untyped, out_untyped, centroids_array_untyped) - && !executeByCentroidsType(in_untyped, out_untyped, centroids_array_untyped) - && !executeByCentroidsType(in_untyped, out_untyped, centroids_array_untyped) - && !executeByCentroidsType(in_untyped, out_untyped, centroids_array_untyped) - && !executeByCentroidsType(in_untyped, out_untyped, centroids_array_untyped) - && !executeByCentroidsType(in_untyped, out_untyped, centroids_array_untyped) - ) + && !executeByCentroidsType(in_untyped, out_untyped, centroids_array_untyped) + && !executeByCentroidsType(in_untyped, out_untyped, centroids_array_untyped) + && !executeByCentroidsType(in_untyped, out_untyped, centroids_array_untyped) + && !executeByCentroidsType(in_untyped, out_untyped, centroids_array_untyped) + && !executeByCentroidsType(in_untyped, out_untyped, centroids_array_untyped) + && !executeByCentroidsType(in_untyped, out_untyped, centroids_array_untyped) + && !executeByCentroidsType(in_untyped, out_untyped, centroids_array_untyped) + && !executeByCentroidsType(in_untyped, out_untyped, centroids_array_untyped) + && !executeByCentroidsType(in_untyped, out_untyped, centroids_array_untyped) + ) { throw Exception{ "Function " + getName() + " expects centroids_array of a numeric type", @@ -278,32 +273,32 @@ public: protected: virtual ClusterOperation getOperation() - { - return ClusterOperation::FindClusterIndex; - } + { + return ClusterOperation::FindClusterIndex; + } template bool executeByCentroidsType( - const IColumn* in_untyped, - IColumn* out_untyped, - const IColumn* centroids_array_untyped) + const IColumn* in_untyped, + IColumn* out_untyped, + const IColumn* centroids_array_untyped) { - Centroids::Type> centroids; + Centroids::Type> centroids; - if (!centroids.fill(centroids_array_untyped)) - return false; + if (!centroids.fill(centroids_array_untyped)) + return false; - if ( !centroids.template findCluster(in_untyped, out_untyped, getOperation()) - && !centroids.template findCluster(in_untyped, out_untyped, getOperation()) - && !centroids.template findCluster(in_untyped, out_untyped, getOperation()) - && !centroids.template findCluster(in_untyped, out_untyped, getOperation()) - && !centroids.template findCluster(in_untyped, out_untyped, getOperation()) - && !centroids.template findCluster(in_untyped, out_untyped, getOperation()) - && !centroids.template findCluster(in_untyped, out_untyped, getOperation()) - && !centroids.template findCluster(in_untyped, out_untyped, getOperation()) - && !centroids.template findCluster(in_untyped, out_untyped, getOperation()) - && !centroids.template findCluster(in_untyped, out_untyped, getOperation()) - ) + if ( !centroids.template findCluster(in_untyped, out_untyped, getOperation()) + && !centroids.template findCluster(in_untyped, out_untyped, getOperation()) + && !centroids.template findCluster(in_untyped, out_untyped, getOperation()) + && !centroids.template findCluster(in_untyped, out_untyped, getOperation()) + && !centroids.template findCluster(in_untyped, out_untyped, getOperation()) + && !centroids.template findCluster(in_untyped, out_untyped, getOperation()) + && !centroids.template findCluster(in_untyped, out_untyped, getOperation()) + && !centroids.template findCluster(in_untyped, out_untyped, getOperation()) + && !centroids.template findCluster(in_untyped, out_untyped, getOperation()) + && !centroids.template findCluster(in_untyped, out_untyped, getOperation()) + ) { throw Exception{ "Illegal column " + in_untyped->getName() + " of first argument of function " + getName(), @@ -315,18 +310,19 @@ protected: }; + class FunctionFindClusterValue : public FunctionFindClusterIndex { public: - static constexpr auto name = "findClusterValue"; + static constexpr auto name = "findClusterValue"; static FunctionPtr create(const Context &) { return std::make_shared(); } DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override - { - FunctionFindClusterIndex::getReturnTypeImpl(arguments); - const DataTypeArray * type_arr_from = typeid_cast(arguments[1].get()); - return type_arr_from->getNestedType(); - } + { + FunctionFindClusterIndex::getReturnTypeImpl(arguments); + const DataTypeArray * type_arr_from = typeid_cast(arguments[1].get()); + return type_arr_from->getNestedType(); + } String getName() const override { @@ -335,9 +331,9 @@ public: protected: ClusterOperation getOperation() override - { - return ClusterOperation::FindCentroidValue; - } + { + return ClusterOperation::FindCentroidValue; + } }; }