This commit is contained in:
Roman Glinskikh 2024-05-28 11:45:11 +03:00
parent 09606d3760
commit 20daaf11da
3 changed files with 36 additions and 18 deletions

View File

@ -348,11 +348,11 @@ private:
}
}
static ColumnPtr createFixedStringResult(
ColumnPtr createFixedStringResult(
const IColumn * nested_offsets_data,
const ColumnFloat32::MutablePtr & col_res,
size_t vector_count,
size_t normal_count)
size_t normal_count) const
{
auto res = ColumnFixedString::create((normal_count / 8) + (normal_count % 8 != 0));
auto& res_chars = res->getChars();
@ -410,32 +410,45 @@ public:
return std::make_shared<DataTypeFixedString>(1);
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t /*input_rows_count*/) const override
const ColumnFloat32 & getVectorData(const ColumnArray * vectors, const ColumnArray::Offsets ** offsets) const
{
if (vectors->getData().getDataType() == TypeIndex::Array)
{
const auto & vectors_data = typeid_cast<const ColumnArray &>(vectors->getData());
*offsets = &vectors_data.getOffsets();
if (vectors_data.getData().getDataType() != TypeIndex::Float32)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Arguments of function {} must be Array of Float32", getName());
return typeid_cast<const ColumnFloat32 &>(vectors_data.getData());
}
return typeid_cast<const ColumnFloat32 &>(vectors->getData());
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & /*result_type*/, size_t /*input_rows_count*/) const override
{
const ColumnArray * vectors = typeid_cast<const ColumnArray *>(arguments[0].column.get());
if (!vectors)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of function {} must be Array", getName());
const auto & nested_vectors_data = typeid_cast<const ColumnFloat32 &>(vectors->getData());
const ColumnArray * normals = typeid_cast<const ColumnArray *>(arguments[1].column.get());
if (!normals)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Second argument of function {} must be Array", getName());
if (vectors->getData().getDataType() == TypeIndex::Nothing || normals->getData().getDataType() == TypeIndex::Nothing)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Arguments of function {} must be Array", getName());
const auto * vector_offsets = &vectors->getOffsets();
const auto & nested_vectors_data = getVectorData(vectors, &vector_offsets);
const auto * normal_offsets = &normals->getOffsets();
const auto & nested_normals_data = [&] -> const DB::ColumnVector<Float32> &
{
if (normals->getData().getDataType() == TypeIndex::Array)
{
const auto & normals_data = typeid_cast<const ColumnArray &>(normals->getData());
normal_offsets = &normals_data.getOffsets();
return typeid_cast<const ColumnFloat32 &>(normals_data.getData());
}
return typeid_cast<const ColumnFloat32 &>(normals->getData());
}();
const auto & nested_normals_data = getVectorData(normals, &normal_offsets);
const size_t dimension = vectors->getOffsets().front();
const size_t vector_count = vectors->getOffsets().size();
const size_t normal_count = normal_offsets->size();
if (vector_count == 0 || normal_count == 0)
return ColumnFixedString::create(0);
auto offsets = ColumnConst::create(ColumnFloat32::create(1, 0), normal_count);
const IColumn * nested_offsets_data = offsets.get();
if (arguments.size() >= 3)
@ -446,10 +459,7 @@ public:
nested_offsets_data = offsets_const;
}
if (vector_count == 0 || normal_count == 0)
return result_type->createColumn();
checkDimension(vectors->getOffsets(), dimension);
checkDimension(*vector_offsets, dimension);
checkDimension(*normal_offsets, dimension);
auto col_res = ColumnFloat32::create(vector_count * normal_count);

View File

@ -1,3 +1,6 @@
01
00
03
01
00
1F

View File

@ -1,5 +1,10 @@
select hex((SELECT partitionByHyperplanes([toFloat32(2.0), toFloat32(3.0)], [toFloat32(1.0), toFloat32(1.0)])));
select hex((SELECT partitionByHyperplanes([toFloat32(2.0), toFloat32(3.0)], [toFloat32(1.0), toFloat32(-1.0)])));
SELECT hex((select partitionByHyperplanes([toFloat32(2.0), toFloat32(3.0)], [[toFloat32(1.0), toFloat32(1.0)], [toFloat32(1.0), toFloat32(2.0)]])));
select hex((SELECT partitionByHyperplanes((select range(1024) :: Array(Float32)), (select range(1024) :: Array(Float32)))));
select hex((SELECT partitionByHyperplanes((select range(1024) :: Array(Float32)), (select range(-1024, 0) :: Array(Float32)))));
select hex((SELECT partitionByHyperplanes((select range(1024) :: Array(Float32)), (select arrayWithConstant(5, (select range(1024) :: Array(Float32)))))));
SELECT hex((select partitionByHyperplanes([toFloat32(2.0), toFloat32(3.0)], [[toFloat32(1.0), toFloat32(1.0)], [toFloat32(1.0)]]))); -- { serverError 190 }
SELECT hex((select partitionByHyperplanes([toFloat32(2.0), toFloat32(3.0)], []))); -- { serverError 44 }
SELECT hex((select partitionByHyperplanes([toFloat32(2.0), toFloat32(3.0)], [[[toFloat32(2.0), toFloat32(3.0)]]]))); -- { serverError 44 }