handled edge cases and addresses review comments

This commit is contained in:
Bhavna Jindal 2023-11-27 09:51:05 -08:00
parent 6d58c99408
commit c0cff7b4f3
7 changed files with 51 additions and 40 deletions

View File

@ -1,10 +1,9 @@
option (ENABLE_SEASONAL "Enable stl-cpp" ${ENABLE_LIBRARIES})
option (ENABLE_SEASONAL "Enable seasonal decompose (stl-cpp) library" ${ENABLE_LIBRARIES})
if (NOT ENABLE_SEASONAL)
message(STATUS "Not using stl-cpp")
message(STATUS "Not using seasonal decompose (stl-cpp)")
return()
endif()
#set (LIBRARY_DIR "${ClickHouse_SOURCE_DIR}/contrib/stl-cpp")
add_library(_stl-cpp INTERFACE)
target_include_directories(_stl-cpp INTERFACE ${ClickHouse_SOURCE_DIR}/contrib/stl-cpp/include)
add_library(ch_contrib::stl-cpp ALIAS _stl-cpp)
add_library(ch_contrib::seasonal-decompose ALIAS _stl-cpp)

View File

@ -61,8 +61,9 @@ Result:
## seriesDecomposeSTL
Decompose time series data based on STL(Seasonal-Trend Decomposition Procedure Based on Loess)
Returns an array of arrays where the first array include seasonal components, the second array - trend,
Returns an array of three arrays where the first array include seasonal components, the second array - trend,
and the third array - residue component.
https://www.wessa.net/download/stl.pdf
**Syntax**

View File

@ -396,8 +396,8 @@ if (TARGET ch_contrib::nuraft)
target_link_libraries (clickhouse_compression PUBLIC ch_contrib::nuraft)
endif()
if (TARGET ch_contrib::stl-cpp)
target_link_libraries(clickhouse_common_io PUBLIC ch_contrib::stl-cpp)
if (TARGET ch_contrib::seasonal-decompose)
target_link_libraries(clickhouse_common_io PUBLIC ch_contrib::seasonal-decompose)
endif ()
dbms_target_link_libraries (

View File

@ -75,8 +75,8 @@ if (TARGET ch_contrib::base64)
list (APPEND PRIVATE_LIBS ch_contrib::base64)
endif()
if (TARGET ch_contrib::stl-cpp)
list (APPEND PRIVATE_LIBS ch_contrib::stl-cpp)
if (TARGET ch_contrib::seasonal-decompose)
list (APPEND PRIVATE_LIBS ch_contrib::seasonal-decompose)
endif()
if (ENABLE_NLP)

View File

@ -1,28 +1,28 @@
#include "config.h"
#if USE_SEASONAL
# ifdef __clang__
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wold-style-cast"
# pragma clang diagnostic ignored "-Wshadow"
# pragma clang diagnostic ignored "-Wimplicit-float-conversion"
# endif
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wshadow"
#pragma clang diagnostic ignored "-Wimplicit-float-conversion"
#endif
# include <stl.hpp>
#include <stl.hpp>
# ifdef __clang__
# pragma clang diagnostic pop
# endif
#ifdef __clang__
#pragma clang diagnostic pop
#endif
# include <cmath>
# include <Columns/ColumnArray.h>
# include <Columns/ColumnConst.h>
# include <Columns/ColumnsNumber.h>
# include <DataTypes/DataTypeArray.h>
# include <DataTypes/DataTypesNumber.h>
# include <Functions/FunctionFactory.h>
# include <Functions/FunctionHelpers.h>
# include <Functions/IFunction.h>
#include <cmath>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
namespace DB
{
@ -46,8 +46,6 @@ public:
size_t getNumberOfArguments() const override { return 2; }
bool isVariadic() const override { return true; }
bool useDefaultImplementationForConstants() const override { return true; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
@ -56,7 +54,7 @@ public:
{
FunctionArgumentDescriptors args{
{"time-series", &isArray<IDataType>, nullptr, "Array"},
{"period", &isNumber<IDataType>, nullptr, "Number"},
{"period", &isNativeNumber<IDataType>, nullptr, "Number"},
};
validateFunctionArgumentTypes(*this, arguments, args);
@ -158,18 +156,28 @@ public:
size_t len = src_vec.size();
if (len < 4)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "At least four data points are needed for function {}", getName());
else if (period > (len / 2))
throw Exception(
ErrorCodes::BAD_ARGUMENTS, "The series should have data of at least two period lengths for function {}", getName());
std::vector<float> src(src_vec.begin(), src_vec.end());
auto res = stl::params().fit(src, static_cast<size_t>(period));
try
{
auto res = stl::params().fit(src, static_cast<size_t>(period));
if (res.seasonal.empty())
return false;
if (res.seasonal.empty())
return false;
seasonal = res.seasonal;
trend = res.trend;
residue = res.remainder;
return true;
seasonal = res.seasonal;
trend = res.trend;
residue = res.remainder;
return true;
}
catch (const std::exception & e)
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, e.what());
}
}
};
REGISTER_FUNCTION(seriesDecomposeSTL)

View File

@ -161,7 +161,7 @@ endif ()
if (ENABLE_OPENSSL)
set(USE_OPENSSL_INTREE 1)
endif ()
if (TARGET ch_contrib::stl-cpp)
if (TARGET ch_contrib::seasonal-decompose)
set(USE_SEASONAL 1)
endif()
if (TARGET ch_contrib::fiu)

View File

@ -1,4 +1,5 @@
-- Tags: no-fasttest
-- Tags: no-fasttest, no-cpu-aarch64
-- Tag no-cpu-aarch64: values generated are slighly different on aarch64
SELECT seriesDecomposeSTL([10.1, 20.45, 40.34, 10.1, 20.45, 40.34, 10.1, 20.45, 40.34, 10.1, 20.45, 40.34, 10.1, 20.45, 40.34, 10.1, 20.45, 40.34, 10.1, 20.45, 40.34, 10.1, 20.45, 40.34], 3);
SELECT seriesDecomposeSTL([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], 0);
@ -8,3 +9,5 @@ SELECT seriesDecomposeSTL(); --{ serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH}
SELECT seriesDecomposeSTL([]); --{ serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH}
SELECT seriesDecomposeSTL([1,2,3], 2); --{ serverError BAD_ARGUMENTS}
SELECT seriesDecomposeSTL([2,2,2,3,3,3]); --{ serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH}
SELECT seriesDecomposeSTL([2,2,2,3,3,3], 9272653446478); --{ serverError BAD_ARGUMENTS}
SELECT seriesDecomposeSTL([2,2,2,3,3,3], 7); --{ serverError BAD_ARGUMENTS}