From 20daaf11dae3ed69fab13a5995799a6edfb7c4b9 Mon Sep 17 00:00:00 2001 From: Roman Glinskikh Date: Tue, 28 May 2024 11:45:11 +0300 Subject: [PATCH] add test --- src/Functions/PartitionByHyperplanes.cpp | 46 +++++++++++-------- .../03152_partition_by_hyperplanes.reference | 3 ++ .../03152_partition_by_hyperplanes.sql | 5 ++ 3 files changed, 36 insertions(+), 18 deletions(-) diff --git a/src/Functions/PartitionByHyperplanes.cpp b/src/Functions/PartitionByHyperplanes.cpp index 0dfd292249a..03146cef47c 100644 --- a/src/Functions/PartitionByHyperplanes.cpp +++ b/src/Functions/PartitionByHyperplanes.cpp @@ -348,11 +348,11 @@ private: } } - static ColumnPtr createFixedStringResult( + ColumnPtr createFixedStringResult( const IColumn * nested_offsets_data, const ColumnFloat32::MutablePtr & col_res, size_t vector_count, - size_t normal_count) + size_t normal_count) const { auto res = ColumnFixedString::create((normal_count / 8) + (normal_count % 8 != 0)); auto& res_chars = res->getChars(); @@ -410,32 +410,45 @@ public: return std::make_shared(1); } - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t /*input_rows_count*/) const override + const ColumnFloat32 & getVectorData(const ColumnArray * vectors, const ColumnArray::Offsets ** offsets) const + { + if (vectors->getData().getDataType() == TypeIndex::Array) + { + const auto & vectors_data = typeid_cast(vectors->getData()); + *offsets = &vectors_data.getOffsets(); + if (vectors_data.getData().getDataType() != TypeIndex::Float32) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Arguments of function {} must be Array of Float32", getName()); + return typeid_cast(vectors_data.getData()); + } + return typeid_cast(vectors->getData()); + } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & /*result_type*/, size_t /*input_rows_count*/) const override { const ColumnArray * vectors = typeid_cast(arguments[0].column.get()); if (!vectors) throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of function {} must be Array", getName()); - const auto & nested_vectors_data = typeid_cast(vectors->getData()); const ColumnArray * normals = typeid_cast(arguments[1].column.get()); if (!normals) throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Second argument of function {} must be Array", getName()); + + if (vectors->getData().getDataType() == TypeIndex::Nothing || normals->getData().getDataType() == TypeIndex::Nothing) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Arguments of function {} must be Array", getName()); + + const auto * vector_offsets = &vectors->getOffsets(); + const auto & nested_vectors_data = getVectorData(vectors, &vector_offsets); + const auto * normal_offsets = &normals->getOffsets(); - const auto & nested_normals_data = [&] -> const DB::ColumnVector & - { - if (normals->getData().getDataType() == TypeIndex::Array) - { - const auto & normals_data = typeid_cast(normals->getData()); - normal_offsets = &normals_data.getOffsets(); - return typeid_cast(normals_data.getData()); - } - return typeid_cast(normals->getData()); - }(); + const auto & nested_normals_data = getVectorData(normals, &normal_offsets); const size_t dimension = vectors->getOffsets().front(); const size_t vector_count = vectors->getOffsets().size(); const size_t normal_count = normal_offsets->size(); + if (vector_count == 0 || normal_count == 0) + return ColumnFixedString::create(0); + auto offsets = ColumnConst::create(ColumnFloat32::create(1, 0), normal_count); const IColumn * nested_offsets_data = offsets.get(); if (arguments.size() >= 3) @@ -446,10 +459,7 @@ public: nested_offsets_data = offsets_const; } - if (vector_count == 0 || normal_count == 0) - return result_type->createColumn(); - - checkDimension(vectors->getOffsets(), dimension); + checkDimension(*vector_offsets, dimension); checkDimension(*normal_offsets, dimension); auto col_res = ColumnFloat32::create(vector_count * normal_count); diff --git a/tests/queries/0_stateless/03152_partition_by_hyperplanes.reference b/tests/queries/0_stateless/03152_partition_by_hyperplanes.reference index 9557e4d7b4a..ac02c4a5b06 100644 --- a/tests/queries/0_stateless/03152_partition_by_hyperplanes.reference +++ b/tests/queries/0_stateless/03152_partition_by_hyperplanes.reference @@ -1,3 +1,6 @@ 01 00 03 +01 +00 +1F diff --git a/tests/queries/0_stateless/03152_partition_by_hyperplanes.sql b/tests/queries/0_stateless/03152_partition_by_hyperplanes.sql index 6d586745ccd..37071a45570 100644 --- a/tests/queries/0_stateless/03152_partition_by_hyperplanes.sql +++ b/tests/queries/0_stateless/03152_partition_by_hyperplanes.sql @@ -1,5 +1,10 @@ select hex((SELECT partitionByHyperplanes([toFloat32(2.0), toFloat32(3.0)], [toFloat32(1.0), toFloat32(1.0)]))); select hex((SELECT partitionByHyperplanes([toFloat32(2.0), toFloat32(3.0)], [toFloat32(1.0), toFloat32(-1.0)]))); SELECT hex((select partitionByHyperplanes([toFloat32(2.0), toFloat32(3.0)], [[toFloat32(1.0), toFloat32(1.0)], [toFloat32(1.0), toFloat32(2.0)]]))); +select hex((SELECT partitionByHyperplanes((select range(1024) :: Array(Float32)), (select range(1024) :: Array(Float32))))); +select hex((SELECT partitionByHyperplanes((select range(1024) :: Array(Float32)), (select range(-1024, 0) :: Array(Float32))))); +select hex((SELECT partitionByHyperplanes((select range(1024) :: Array(Float32)), (select arrayWithConstant(5, (select range(1024) :: Array(Float32))))))); SELECT hex((select partitionByHyperplanes([toFloat32(2.0), toFloat32(3.0)], [[toFloat32(1.0), toFloat32(1.0)], [toFloat32(1.0)]]))); -- { serverError 190 } +SELECT hex((select partitionByHyperplanes([toFloat32(2.0), toFloat32(3.0)], []))); -- { serverError 44 } +SELECT hex((select partitionByHyperplanes([toFloat32(2.0), toFloat32(3.0)], [[[toFloat32(2.0), toFloat32(3.0)]]]))); -- { serverError 44 }