diff --git a/src/Functions/h3PointDistM.cpp b/src/Functions/h3PointDistM.cpp new file mode 100644 index 00000000000..8aff1d7adec --- /dev/null +++ b/src/Functions/h3PointDistM.cpp @@ -0,0 +1,129 @@ +#include "config_functions.h" + +#if USE_H3 + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + + +namespace DB +{ +namespace ErrorCodes +{ +extern const int ILLEGAL_TYPE_OF_ARGUMENT; +extern const int ILLEGAL_COLUMN; +} + +namespace +{ + +class FunctionH3PointDistM final : public IFunction +{ +public: + static constexpr auto name = "h3PointDistM"; + + static FunctionPtr create(ContextPtr) { return std::make_shared(); } + + std::string getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 4; } + bool useDefaultImplementationForConstants() const override { return true; } + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + for (size_t i = 0; i < getNumberOfArguments(); ++i) + { + const auto * arg = arguments[i].get(); + if (!WhichDataType(arg).isFloat64()) + throw Exception( + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal type {} of argument {} of function {}. Must be Float64", + arg->getName(), i, getName()); + } + return std::make_shared(); + } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override + { + const auto * col_lat1 = checkAndGetColumn(arguments[0].column.get()); + if (!col_lat1) + throw Exception( + ErrorCodes::ILLEGAL_COLUMN, + "Illegal type {} of argument {} of function {}. Must be Float64", + arguments[0].type->getName(), + 1, + getName()); + const auto & data_lat1 = col_lat1->getData(); + + const auto * col_lon1 = checkAndGetColumn(arguments[1].column.get()); + if (!col_lon1) + throw Exception( + ErrorCodes::ILLEGAL_COLUMN, + "Illegal type {} of argument {} of function {}. Must be Float64", + arguments[1].type->getName(), + 2, + getName()); + const auto & data_lon1 = col_lon1->getData(); + + const auto * col_lat2 = checkAndGetColumn(arguments[2].column.get()); + if (!col_lat2) + throw Exception( + ErrorCodes::ILLEGAL_COLUMN, + "Illegal type {} of argument {} of function {}. Must be Float64", + arguments[2].type->getName(), + 3, + getName()); + const auto & data_lat2 = col_lat2->getData(); + + const auto * col_lon2 = checkAndGetColumn(arguments[3].column.get()); + if (!col_lon2) + throw Exception( + ErrorCodes::ILLEGAL_COLUMN, + "Illegal type {} of argument {} of function {}. Must be Float64", + arguments[3].type->getName(), + 4, + getName()); + const auto & data_lon2 = col_lon2->getData(); + + auto dst = ColumnVector::create(); + auto & dst_data = dst->getData(); + dst_data.resize(input_rows_count); + + for (size_t row = 0; row < input_rows_count; ++row) + { + const double lat1 = data_lat1[row]; + const double lon1 = data_lon1[row]; + const auto lat2 = data_lat2[row]; + const auto lon2 = data_lon2[row]; + + LatLng point1 = {degsToRads(lat1), degsToRads(lon1)}; + LatLng point2 = {degsToRads(lat2), degsToRads(lon2)}; + + Float64 res = distanceM(&point1, &point2); + dst_data[row] = res; + } + + return dst; + } +}; + +} + +void registerFunctionH3PointDistM(FunctionFactory & factory) +{ + factory.registerFunction(); +} + +} + +#endif diff --git a/src/Functions/registerFunctionsGeo.cpp b/src/Functions/registerFunctionsGeo.cpp index 0501b603c57..54fb79c8b31 100644 --- a/src/Functions/registerFunctionsGeo.cpp +++ b/src/Functions/registerFunctionsGeo.cpp @@ -52,6 +52,7 @@ void registerFunctionH3HexAreaKm2(FunctionFactory &); void registerFunctionH3CellAreaM2(FunctionFactory &); void registerFunctionH3CellAreaRads2(FunctionFactory &); void registerFunctionH3NumHexagons(FunctionFactory &); +void registerFunctionH3PointDistM(FunctionFactory &); #endif @@ -118,6 +119,7 @@ void registerFunctionsGeo(FunctionFactory & factory) registerFunctionH3CellAreaM2(factory); registerFunctionH3CellAreaRads2(factory); registerFunctionH3NumHexagons(factory); + registerFunctionH3PointDistM(factory); #endif #if USE_S2_GEOMETRY