mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-12-17 20:02:05 +00:00
quantize functions
This commit is contained in:
parent
1794d8ed27
commit
a646a3242c
120
src/Functions/FunctionApproximateL2Distance.cpp
Normal file
120
src/Functions/FunctionApproximateL2Distance.cpp
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
#include "FunctionApproximateL2Distance.h"
|
||||||
|
#include <Columns/ColumnConst.h>
|
||||||
|
#include <Columns/ColumnFixedString.h>
|
||||||
|
#include <Columns/ColumnVector.h>
|
||||||
|
#include <DataTypes/DataTypesNumber.h>
|
||||||
|
#include <Functions/FunctionFactory.h>
|
||||||
|
#include <Common/Exception.h>
|
||||||
|
#include <Common/assert_cast.h>
|
||||||
|
#include "Columns/ColumnsNumber.h"
|
||||||
|
#include "Functions/FunctionDequantize.h"
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace ErrorCodes
|
||||||
|
{
|
||||||
|
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||||
|
extern const int BAD_ARGUMENTS;
|
||||||
|
}
|
||||||
|
|
||||||
|
DataTypePtr FunctionApproximateL2Distance::getReturnTypeImpl(const DataTypes & arguments) const
|
||||||
|
{
|
||||||
|
if (!checkDataTypes<DataTypeFixedString>(arguments[0].get()) || !checkDataTypes<DataTypeFixedString>(arguments[1].get()))
|
||||||
|
{
|
||||||
|
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Arguments of function {} must be FixedString", getName());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!WhichDataType(arguments[2]).isUInt8())
|
||||||
|
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Third argument of function {} must be UInt8", getName());
|
||||||
|
|
||||||
|
return std::make_shared<DataTypeFloat64>();
|
||||||
|
}
|
||||||
|
|
||||||
|
ColumnPtr
|
||||||
|
FunctionApproximateL2Distance::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const
|
||||||
|
{
|
||||||
|
const auto * col_left = checkAndGetColumn<ColumnFixedString>(arguments[0].column.get());
|
||||||
|
const auto * col_right = checkAndGetColumn<ColumnFixedString>(arguments[1].column.get());
|
||||||
|
|
||||||
|
if (!col_left || !col_right)
|
||||||
|
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Arguments of function {} must be FixedString", getName());
|
||||||
|
|
||||||
|
if (col_left->getN() != col_right->getN())
|
||||||
|
throw Exception(ErrorCodes::BAD_ARGUMENTS, "FixedStrings in function {} must have the same length", getName());
|
||||||
|
|
||||||
|
const auto & data_left = col_left->getChars();
|
||||||
|
const auto & data_right = col_right->getChars();
|
||||||
|
size_t fixed_string_length = col_left->getN();
|
||||||
|
|
||||||
|
UInt8 bit_width = 0;
|
||||||
|
if (const auto * const_col = checkAndGetColumnConst<ColumnUInt8>(arguments[2].column.get()))
|
||||||
|
bit_width = const_col->getValue<UInt8>();
|
||||||
|
else
|
||||||
|
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Third argument of function {} must be a constant UInt8", getName());
|
||||||
|
|
||||||
|
if (!(bit_width == 16 || bit_width == 8 || bit_width == 4 || bit_width == 2 || bit_width == 1))
|
||||||
|
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Bit width must be one of 16, 8, 4, 2, or 1");
|
||||||
|
|
||||||
|
auto result_column = ColumnFloat64::create();
|
||||||
|
auto & result_data = result_column->getData();
|
||||||
|
result_data.resize(input_rows_count);
|
||||||
|
|
||||||
|
size_t total_bits_per_row = fixed_string_length * 8;
|
||||||
|
|
||||||
|
for (size_t row = 0; row < input_rows_count; ++row)
|
||||||
|
{
|
||||||
|
size_t row_offset = row * fixed_string_length;
|
||||||
|
|
||||||
|
size_t num_elements = total_bits_per_row / bit_width;
|
||||||
|
|
||||||
|
size_t bits_read = 0;
|
||||||
|
|
||||||
|
double sum = 0.0;
|
||||||
|
|
||||||
|
for (size_t idx = 0; idx < num_elements; ++idx)
|
||||||
|
{
|
||||||
|
UInt32 value_left = 0;
|
||||||
|
UInt32 value_right = 0;
|
||||||
|
for (size_t bit = 0; bit < bit_width; ++bit)
|
||||||
|
{
|
||||||
|
size_t bit_pos = bits_read++;
|
||||||
|
size_t byte_pos = bit_pos / 8;
|
||||||
|
size_t bit_in_byte = bit_pos % 8;
|
||||||
|
|
||||||
|
UInt8 byte_left = data_left[row_offset + byte_pos];
|
||||||
|
UInt8 byte_right = data_right[row_offset + byte_pos];
|
||||||
|
|
||||||
|
if (byte_left & (1 << (7 - bit_in_byte)))
|
||||||
|
value_left |= (1 << (bit_width - 1 - bit));
|
||||||
|
|
||||||
|
if (byte_right & (1 << (7 - bit_in_byte)))
|
||||||
|
value_right |= (1 << (bit_width - 1 - bit));
|
||||||
|
}
|
||||||
|
|
||||||
|
Float32 left_value = 0.0f;
|
||||||
|
Float32 right_value = 0.0f;
|
||||||
|
|
||||||
|
if (bit_width == 16)
|
||||||
|
{
|
||||||
|
left_value = uint16ToFloat32(static_cast<UInt16>(value_left));
|
||||||
|
right_value = uint16ToFloat32(static_cast<UInt16>(value_right));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
left_value = dequantizeFromBits(value_left, bit_width);
|
||||||
|
right_value = dequantizeFromBits(value_right, bit_width);
|
||||||
|
}
|
||||||
|
|
||||||
|
double diff = static_cast<double>(left_value) - static_cast<double>(right_value);
|
||||||
|
sum += diff * diff;
|
||||||
|
}
|
||||||
|
|
||||||
|
double distance = std::sqrt(sum);
|
||||||
|
result_data[row] = distance;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result_column;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
30
src/Functions/FunctionApproximateL2Distance.h
Normal file
30
src/Functions/FunctionApproximateL2Distance.h
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <Columns/ColumnFixedString.h>
|
||||||
|
#include <DataTypes/DataTypesNumber.h>
|
||||||
|
#include <Functions/IFunction.h>
|
||||||
|
#include <Interpreters/Context_fwd.h>
|
||||||
|
#include <DataTypes/DataTypeFixedString.h>
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
class FunctionApproximateL2Distance : public IFunction
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
static constexpr auto name = "approximateL2Distance";
|
||||||
|
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionApproximateL2Distance>(); }
|
||||||
|
|
||||||
|
String getName() const override { return name; }
|
||||||
|
|
||||||
|
size_t getNumberOfArguments() const override { return 3; }
|
||||||
|
|
||||||
|
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; }
|
||||||
|
|
||||||
|
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
|
||||||
|
|
||||||
|
ColumnPtr
|
||||||
|
executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
171
src/Functions/FunctionDequantize.cpp
Normal file
171
src/Functions/FunctionDequantize.cpp
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
#include "FunctionDequantize.h"
|
||||||
|
#include <Columns/ColumnArray.h>
|
||||||
|
#include <Columns/ColumnFixedString.h>
|
||||||
|
#include <DataTypes/DataTypeArray.h>
|
||||||
|
#include <DataTypes/DataTypesNumber.h>
|
||||||
|
#include <Functions/FunctionFactory.h>
|
||||||
|
#include <Common/Exception.h>
|
||||||
|
#include <Common/assert_cast.h>
|
||||||
|
#include "Columns/ColumnsNumber.h"
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace ErrorCodes
|
||||||
|
{
|
||||||
|
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||||
|
extern const int BAD_ARGUMENTS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Float32 dequantizeFromBits(UInt32 quantized_value, UInt8 bit_width)
|
||||||
|
{
|
||||||
|
UInt32 max_int = (1 << bit_width) - 1;
|
||||||
|
Float32 normalized = static_cast<Float32>(quantized_value) / static_cast<Float32>(max_int);
|
||||||
|
Float32 value = normalized * 2.0f - 1.0f;
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
Float32 uint16ToFloat32(UInt16 h)
|
||||||
|
{
|
||||||
|
UInt32 sign = (h >> 15) & 0x1;
|
||||||
|
UInt32 exponent = (h >> 10) & 0x1F;
|
||||||
|
UInt32 mantissa = h & 0x3FF;
|
||||||
|
|
||||||
|
UInt32 f_sign = sign << 31;
|
||||||
|
UInt32 f_exponent;
|
||||||
|
UInt32 f_mantissa;
|
||||||
|
|
||||||
|
if (exponent == 0)
|
||||||
|
{
|
||||||
|
if (mantissa == 0)
|
||||||
|
{
|
||||||
|
f_exponent = 0;
|
||||||
|
f_mantissa = 0;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
exponent = 1;
|
||||||
|
while ((mantissa & 0x400) == 0)
|
||||||
|
{
|
||||||
|
mantissa <<= 1;
|
||||||
|
exponent -= 1;
|
||||||
|
}
|
||||||
|
mantissa &= 0x3FF;
|
||||||
|
exponent += 127 - 15;
|
||||||
|
f_exponent = exponent << 23;
|
||||||
|
f_mantissa = mantissa << 13;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (exponent == 0x1F)
|
||||||
|
{
|
||||||
|
f_exponent = 0xFF << 23;
|
||||||
|
f_mantissa = mantissa << 13;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
exponent += 127 - 15;
|
||||||
|
f_exponent = exponent << 23;
|
||||||
|
f_mantissa = mantissa << 13;
|
||||||
|
}
|
||||||
|
|
||||||
|
UInt32 f = f_sign | f_exponent | f_mantissa;
|
||||||
|
|
||||||
|
union
|
||||||
|
{
|
||||||
|
UInt32 u;
|
||||||
|
Float32 f;
|
||||||
|
} val = {f};
|
||||||
|
|
||||||
|
return val.f;
|
||||||
|
}
|
||||||
|
|
||||||
|
DataTypePtr FunctionDequantize::getReturnTypeImpl(const DataTypes & arguments) const
|
||||||
|
{
|
||||||
|
if (!checkDataTypes<DataTypeFixedString>(arguments[0].get()))
|
||||||
|
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function {} must be FixedString", getName());
|
||||||
|
|
||||||
|
if (!WhichDataType(arguments[1]).isUInt8())
|
||||||
|
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument of function {} must be UInt8", getName());
|
||||||
|
|
||||||
|
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat32>());
|
||||||
|
}
|
||||||
|
|
||||||
|
ColumnPtr FunctionDequantize::executeImpl(
|
||||||
|
const ColumnsWithTypeAndName & arguments, const DataTypePtr & /*result_type*/, size_t input_rows_count) const
|
||||||
|
{
|
||||||
|
const auto & fixed_string_col = arguments[0].column;
|
||||||
|
const auto * col_fixed_string = checkAndGetColumn<ColumnFixedString>(fixed_string_col.get());
|
||||||
|
if (!col_fixed_string)
|
||||||
|
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function {} must be FixedString", getName());
|
||||||
|
|
||||||
|
const auto & data = col_fixed_string->getChars();
|
||||||
|
size_t fixed_string_length = col_fixed_string->getN();
|
||||||
|
|
||||||
|
UInt8 bit_width = 0;
|
||||||
|
if (const auto * const_col = checkAndGetColumnConst<ColumnUInt8>(arguments[1].column.get()))
|
||||||
|
bit_width = const_col->getValue<UInt8>();
|
||||||
|
else
|
||||||
|
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument of function {} must be a constant UInt8", getName());
|
||||||
|
|
||||||
|
if (!(bit_width == 16 || bit_width == 8 || bit_width == 4 || bit_width == 2 || bit_width == 1))
|
||||||
|
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Bit width must be one of 16, 8, 4, 2, or 1");
|
||||||
|
|
||||||
|
size_t total_bits_per_row = fixed_string_length * 8;
|
||||||
|
if (total_bits_per_row % bit_width != 0)
|
||||||
|
{
|
||||||
|
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Total bits per row is not divisible by bit width in function {}", getName());
|
||||||
|
}
|
||||||
|
size_t num_elements = total_bits_per_row / bit_width;
|
||||||
|
|
||||||
|
auto col_float = ColumnFloat32::create();
|
||||||
|
auto & float_data = col_float->getData();
|
||||||
|
float_data.reserve(input_rows_count * num_elements);
|
||||||
|
|
||||||
|
auto col_array = ColumnArray::create(std::move(col_float));
|
||||||
|
auto & offsets = col_array->getOffsets();
|
||||||
|
offsets.resize(input_rows_count);
|
||||||
|
|
||||||
|
size_t offset = 0;
|
||||||
|
|
||||||
|
for (size_t row = 0; row < input_rows_count; ++row)
|
||||||
|
{
|
||||||
|
size_t row_offset = row * fixed_string_length;
|
||||||
|
|
||||||
|
size_t bits_read = 0;
|
||||||
|
|
||||||
|
for (size_t idx = 0; idx < num_elements; ++idx)
|
||||||
|
{
|
||||||
|
UInt32 quantized_value = 0;
|
||||||
|
for (size_t bit = 0; bit < bit_width; ++bit)
|
||||||
|
{
|
||||||
|
size_t bit_pos = bits_read++;
|
||||||
|
size_t byte_pos = bit_pos / 8;
|
||||||
|
size_t bit_in_byte = bit_pos % 8;
|
||||||
|
|
||||||
|
UInt8 current_byte = data[row_offset + byte_pos];
|
||||||
|
|
||||||
|
if (current_byte & (1 << (7 - bit_in_byte)))
|
||||||
|
quantized_value |= (1 << (bit_width - 1 - bit));
|
||||||
|
}
|
||||||
|
|
||||||
|
Float32 value = 0.0f;
|
||||||
|
if (bit_width == 16)
|
||||||
|
{
|
||||||
|
value = uint16ToFloat32(static_cast<UInt16>(quantized_value));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
value = dequantizeFromBits(quantized_value, bit_width);
|
||||||
|
}
|
||||||
|
|
||||||
|
float_data.emplace_back(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
offset += num_elements;
|
||||||
|
offsets[row] = offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
return col_array;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
35
src/Functions/FunctionDequantize.h
Normal file
35
src/Functions/FunctionDequantize.h
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <Columns/ColumnArray.h>
|
||||||
|
#include <Columns/ColumnFixedString.h>
|
||||||
|
#include <DataTypes/DataTypeArray.h>
|
||||||
|
#include <DataTypes/DataTypesNumber.h>
|
||||||
|
#include <Functions/IFunction.h>
|
||||||
|
#include <Interpreters/Context_fwd.h>
|
||||||
|
#include <DataTypes/DataTypeFixedString.h>
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
class FunctionDequantize : public IFunction
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
static constexpr auto name = "dequantize";
|
||||||
|
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionDequantize>(); }
|
||||||
|
|
||||||
|
String getName() const override { return name; }
|
||||||
|
|
||||||
|
size_t getNumberOfArguments() const override { return 2; }
|
||||||
|
|
||||||
|
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; }
|
||||||
|
|
||||||
|
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
|
||||||
|
|
||||||
|
ColumnPtr
|
||||||
|
executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
Float32 dequantizeFromBits(UInt32 quantized_value, UInt8 bit_width);
|
||||||
|
Float32 uint16ToFloat32(UInt16 h);
|
||||||
|
|
||||||
|
}
|
227
src/Functions/FunctionQuantize.cpp
Normal file
227
src/Functions/FunctionQuantize.cpp
Normal file
@ -0,0 +1,227 @@
|
|||||||
|
#include "FunctionQuantize.h"
|
||||||
|
#include <cmath>
|
||||||
|
#include <Columns/ColumnArray.h>
|
||||||
|
#include <Columns/ColumnFixedString.h>
|
||||||
|
#include <Columns/ColumnVector.h>
|
||||||
|
#include <DataTypes/DataTypeArray.h>
|
||||||
|
#include <Functions/FunctionFactory.h>
|
||||||
|
#include <Common/Arena.h>
|
||||||
|
#include <Common/Exception.h>
|
||||||
|
#include <Common/assert_cast.h>
|
||||||
|
#include "Columns/ColumnsNumber.h"
|
||||||
|
#include "Functions/FunctionApproximateL2Distance.h"
|
||||||
|
#include "Functions/FunctionDequantize.h"
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
REGISTER_FUNCTION(Quantize)
|
||||||
|
{
|
||||||
|
factory.registerFunction<FunctionQuantize>({}, FunctionFactory::Case::Insensitive);
|
||||||
|
factory.registerFunction<FunctionDequantize>({}, FunctionFactory::Case::Sensitive);
|
||||||
|
factory.registerFunction<FunctionApproximateL2Distance>({}, FunctionFactory::Case::Insensitive);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace ErrorCodes
|
||||||
|
{
|
||||||
|
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||||
|
extern const int BAD_ARGUMENTS;
|
||||||
|
}
|
||||||
|
|
||||||
|
UInt32 quantizeToBits(Float64 value, UInt8 bit_width)
|
||||||
|
{
|
||||||
|
constexpr Float64 min_value = -1.0;
|
||||||
|
constexpr Float64 max_value = 1.0;
|
||||||
|
|
||||||
|
value = std::max(value, min_value);
|
||||||
|
value = std::min(value, max_value);
|
||||||
|
|
||||||
|
Float64 normalized = (value - min_value) / (max_value - min_value);
|
||||||
|
|
||||||
|
UInt32 max_int = (1U << bit_width) - 1;
|
||||||
|
UInt32 quantized = static_cast<UInt32>(std::round(normalized * max_int));
|
||||||
|
|
||||||
|
return quantized;
|
||||||
|
}
|
||||||
|
|
||||||
|
UInt16 float32ToUInt16(Float32 value)
|
||||||
|
{
|
||||||
|
union
|
||||||
|
{
|
||||||
|
Float32 f;
|
||||||
|
UInt32 u;
|
||||||
|
} val;
|
||||||
|
val.f = value;
|
||||||
|
|
||||||
|
UInt32 f = val.u;
|
||||||
|
UInt16 h = 0;
|
||||||
|
|
||||||
|
UInt32 sign = (f >> 31) & 0x1;
|
||||||
|
UInt32 exponent = (f >> 23) & 0xFF;
|
||||||
|
UInt32 mantissa = f & 0x7FFFFF;
|
||||||
|
|
||||||
|
if (exponent == 0xFF)
|
||||||
|
{
|
||||||
|
h = (sign << 15) | (0x1F << 10);
|
||||||
|
if (mantissa)
|
||||||
|
h |= (mantissa & 0x3FF);
|
||||||
|
}
|
||||||
|
else if (exponent > 0x70)
|
||||||
|
{
|
||||||
|
exponent -= 0x70;
|
||||||
|
h = (sign << 15) | (exponent << 10) | (mantissa >> 13);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
h = (sign << 15);
|
||||||
|
}
|
||||||
|
|
||||||
|
return h;
|
||||||
|
}
|
||||||
|
|
||||||
|
DataTypePtr FunctionQuantize::getReturnTypeImpl(const DataTypes & arguments) const
|
||||||
|
{
|
||||||
|
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(arguments[0].get());
|
||||||
|
if (!array_type)
|
||||||
|
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument must be an array in function {}", getName());
|
||||||
|
|
||||||
|
const IDataType * nested_type = array_type->getNestedType().get();
|
||||||
|
if (!WhichDataType(nested_type).isFloat())
|
||||||
|
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Array elements must be Float32 or Float64 in function {}", getName());
|
||||||
|
|
||||||
|
if (!WhichDataType(arguments[1]).isUInt8())
|
||||||
|
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument of function {} must be UInt8", getName());
|
||||||
|
|
||||||
|
return std::make_shared<DataTypeFixedString>(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
ColumnPtr FunctionQuantize::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const
|
||||||
|
{
|
||||||
|
const auto & array_col_with_type = arguments[0];
|
||||||
|
const auto & array_col = array_col_with_type.column;
|
||||||
|
|
||||||
|
const auto * col_array = checkAndGetColumn<ColumnArray>(array_col.get());
|
||||||
|
if (!col_array)
|
||||||
|
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function {} must be an array", getName());
|
||||||
|
|
||||||
|
const auto & data_col = col_array->getData();
|
||||||
|
const auto & offsets = col_array->getOffsets();
|
||||||
|
|
||||||
|
const PaddedPODArray<Float32> * float32_data = nullptr;
|
||||||
|
const PaddedPODArray<Float64> * float64_data = nullptr;
|
||||||
|
bool is_float32 = false;
|
||||||
|
|
||||||
|
if (const auto * col_float32 = checkAndGetColumn<ColumnFloat32>(&data_col))
|
||||||
|
{
|
||||||
|
float32_data = &col_float32->getData();
|
||||||
|
is_float32 = true;
|
||||||
|
}
|
||||||
|
else if (const auto * col_float64 = checkAndGetColumn<ColumnFloat64>(&data_col))
|
||||||
|
{
|
||||||
|
float64_data = &col_float64->getData();
|
||||||
|
is_float32 = false;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Array elements in function {} must be Float32 or Float64", getName());
|
||||||
|
}
|
||||||
|
|
||||||
|
UInt8 bit_width = 0;
|
||||||
|
if (const auto * const_col = checkAndGetColumnConst<ColumnUInt8>(arguments[1].column.get()))
|
||||||
|
bit_width = const_col->getValue<UInt8>();
|
||||||
|
else
|
||||||
|
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument of function {} must be a constant UInt8", getName());
|
||||||
|
|
||||||
|
if (!(bit_width == 16 || bit_width == 8 || bit_width == 4 || bit_width == 2 || bit_width == 1))
|
||||||
|
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Bit width must be one of 16, 8, 4, 2, or 1 in function {}", getName());
|
||||||
|
|
||||||
|
size_t num_elements_per_row = 0;
|
||||||
|
if (input_rows_count > 0)
|
||||||
|
{
|
||||||
|
num_elements_per_row = offsets[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t row = 1; row < input_rows_count; ++row)
|
||||||
|
{
|
||||||
|
size_t num_elements = offsets[row] - offsets[row - 1];
|
||||||
|
if (num_elements != num_elements_per_row)
|
||||||
|
{
|
||||||
|
throw Exception(ErrorCodes::BAD_ARGUMENTS, "All input arrays must have the same number of elements in function {}", getName());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t total_bits = num_elements_per_row * bit_width;
|
||||||
|
size_t fixed_string_length = (total_bits + 7) / 8;
|
||||||
|
|
||||||
|
auto result_column = ColumnFixedString::create(fixed_string_length);
|
||||||
|
auto & result_chars = result_column->getChars();
|
||||||
|
result_chars.resize(input_rows_count * fixed_string_length);
|
||||||
|
|
||||||
|
size_t prev_offset = 0;
|
||||||
|
|
||||||
|
for (size_t row = 0; row < input_rows_count; ++row)
|
||||||
|
{
|
||||||
|
size_t current_offset = offsets[row];
|
||||||
|
size_t offset_in_result = row * fixed_string_length;
|
||||||
|
|
||||||
|
size_t bits_written = 0;
|
||||||
|
UInt8 current_byte = 0;
|
||||||
|
|
||||||
|
for (size_t idx = prev_offset; idx < current_offset; ++idx)
|
||||||
|
{
|
||||||
|
UInt32 quantized_value = 0;
|
||||||
|
if (bit_width == 16)
|
||||||
|
{
|
||||||
|
Float32 value;
|
||||||
|
if (is_float32)
|
||||||
|
{
|
||||||
|
value = (*float32_data)[idx];
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
value = static_cast<Float32>((*float64_data)[idx]);
|
||||||
|
}
|
||||||
|
quantized_value = float32ToUInt16(value);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
Float64 value;
|
||||||
|
if (is_float32)
|
||||||
|
{
|
||||||
|
value = static_cast<Float64>((*float32_data)[idx]);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
value = (*float64_data)[idx];
|
||||||
|
}
|
||||||
|
quantized_value = quantizeToBits(value, bit_width);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t bit = 0; bit < bit_width; ++bit)
|
||||||
|
{
|
||||||
|
if (quantized_value & (1U << (bit_width - 1 - bit)))
|
||||||
|
current_byte |= (1U << (7 - (bits_written % 8)));
|
||||||
|
|
||||||
|
bits_written++;
|
||||||
|
if (bits_written % 8 == 0)
|
||||||
|
{
|
||||||
|
result_chars[offset_in_result++] = current_byte;
|
||||||
|
current_byte = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (bits_written % 8 != 0)
|
||||||
|
{
|
||||||
|
result_chars[offset_in_result++] = current_byte;
|
||||||
|
}
|
||||||
|
|
||||||
|
prev_offset = current_offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result_column;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
36
src/Functions/FunctionQuantize.h
Normal file
36
src/Functions/FunctionQuantize.h
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
// FunctionQuantize.h
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <Columns/ColumnArray.h>
|
||||||
|
#include <Columns/ColumnFixedString.h>
|
||||||
|
#include <Columns/ColumnVector.h>
|
||||||
|
#include <DataTypes/DataTypeFixedString.h>
|
||||||
|
#include <DataTypes/DataTypesNumber.h>
|
||||||
|
#include <Functions/IFunction.h>
|
||||||
|
#include <Common/Logger.h>
|
||||||
|
#include "Interpreters/Context_fwd.h"
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
class FunctionQuantize : public IFunction
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
static constexpr auto name = "quantize";
|
||||||
|
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionQuantize>(); }
|
||||||
|
|
||||||
|
String getName() const override { return name; }
|
||||||
|
|
||||||
|
size_t getNumberOfArguments() const override { return 2; }
|
||||||
|
|
||||||
|
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; }
|
||||||
|
|
||||||
|
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
|
||||||
|
|
||||||
|
ColumnPtr
|
||||||
|
executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user