quantize functions

This commit is contained in:
nikita4109 2024-11-08 13:19:45 +01:00
parent 1794d8ed27
commit a646a3242c
6 changed files with 619 additions and 0 deletions

View File

@ -0,0 +1,120 @@
#include "FunctionApproximateL2Distance.h"
#include <Columns/ColumnConst.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Common/Exception.h>
#include <Common/assert_cast.h>
#include "Columns/ColumnsNumber.h"
#include "Functions/FunctionDequantize.h"
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int BAD_ARGUMENTS;
}
DataTypePtr FunctionApproximateL2Distance::getReturnTypeImpl(const DataTypes & arguments) const
{
if (!checkDataTypes<DataTypeFixedString>(arguments[0].get()) || !checkDataTypes<DataTypeFixedString>(arguments[1].get()))
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Arguments of function {} must be FixedString", getName());
}
if (!WhichDataType(arguments[2]).isUInt8())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Third argument of function {} must be UInt8", getName());
return std::make_shared<DataTypeFloat64>();
}
ColumnPtr
FunctionApproximateL2Distance::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const
{
const auto * col_left = checkAndGetColumn<ColumnFixedString>(arguments[0].column.get());
const auto * col_right = checkAndGetColumn<ColumnFixedString>(arguments[1].column.get());
if (!col_left || !col_right)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Arguments of function {} must be FixedString", getName());
if (col_left->getN() != col_right->getN())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "FixedStrings in function {} must have the same length", getName());
const auto & data_left = col_left->getChars();
const auto & data_right = col_right->getChars();
size_t fixed_string_length = col_left->getN();
UInt8 bit_width = 0;
if (const auto * const_col = checkAndGetColumnConst<ColumnUInt8>(arguments[2].column.get()))
bit_width = const_col->getValue<UInt8>();
else
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Third argument of function {} must be a constant UInt8", getName());
if (!(bit_width == 16 || bit_width == 8 || bit_width == 4 || bit_width == 2 || bit_width == 1))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Bit width must be one of 16, 8, 4, 2, or 1");
auto result_column = ColumnFloat64::create();
auto & result_data = result_column->getData();
result_data.resize(input_rows_count);
size_t total_bits_per_row = fixed_string_length * 8;
for (size_t row = 0; row < input_rows_count; ++row)
{
size_t row_offset = row * fixed_string_length;
size_t num_elements = total_bits_per_row / bit_width;
size_t bits_read = 0;
double sum = 0.0;
for (size_t idx = 0; idx < num_elements; ++idx)
{
UInt32 value_left = 0;
UInt32 value_right = 0;
for (size_t bit = 0; bit < bit_width; ++bit)
{
size_t bit_pos = bits_read++;
size_t byte_pos = bit_pos / 8;
size_t bit_in_byte = bit_pos % 8;
UInt8 byte_left = data_left[row_offset + byte_pos];
UInt8 byte_right = data_right[row_offset + byte_pos];
if (byte_left & (1 << (7 - bit_in_byte)))
value_left |= (1 << (bit_width - 1 - bit));
if (byte_right & (1 << (7 - bit_in_byte)))
value_right |= (1 << (bit_width - 1 - bit));
}
Float32 left_value = 0.0f;
Float32 right_value = 0.0f;
if (bit_width == 16)
{
left_value = uint16ToFloat32(static_cast<UInt16>(value_left));
right_value = uint16ToFloat32(static_cast<UInt16>(value_right));
}
else
{
left_value = dequantizeFromBits(value_left, bit_width);
right_value = dequantizeFromBits(value_right, bit_width);
}
double diff = static_cast<double>(left_value) - static_cast<double>(right_value);
sum += diff * diff;
}
double distance = std::sqrt(sum);
result_data[row] = distance;
}
return result_column;
}
}

View File

@ -0,0 +1,30 @@
#pragma once
#include <Columns/ColumnFixedString.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/IFunction.h>
#include <Interpreters/Context_fwd.h>
#include <DataTypes/DataTypeFixedString.h>
namespace DB
{
class FunctionApproximateL2Distance : public IFunction
{
public:
static constexpr auto name = "approximateL2Distance";
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionApproximateL2Distance>(); }
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 3; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
ColumnPtr
executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override;
};
}

View File

@ -0,0 +1,171 @@
#include "FunctionDequantize.h"
#include <Columns/ColumnArray.h>
#include <Columns/ColumnFixedString.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Common/Exception.h>
#include <Common/assert_cast.h>
#include "Columns/ColumnsNumber.h"
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int BAD_ARGUMENTS;
}
Float32 dequantizeFromBits(UInt32 quantized_value, UInt8 bit_width)
{
UInt32 max_int = (1 << bit_width) - 1;
Float32 normalized = static_cast<Float32>(quantized_value) / static_cast<Float32>(max_int);
Float32 value = normalized * 2.0f - 1.0f;
return value;
}
Float32 uint16ToFloat32(UInt16 h)
{
UInt32 sign = (h >> 15) & 0x1;
UInt32 exponent = (h >> 10) & 0x1F;
UInt32 mantissa = h & 0x3FF;
UInt32 f_sign = sign << 31;
UInt32 f_exponent;
UInt32 f_mantissa;
if (exponent == 0)
{
if (mantissa == 0)
{
f_exponent = 0;
f_mantissa = 0;
}
else
{
exponent = 1;
while ((mantissa & 0x400) == 0)
{
mantissa <<= 1;
exponent -= 1;
}
mantissa &= 0x3FF;
exponent += 127 - 15;
f_exponent = exponent << 23;
f_mantissa = mantissa << 13;
}
}
else if (exponent == 0x1F)
{
f_exponent = 0xFF << 23;
f_mantissa = mantissa << 13;
}
else
{
exponent += 127 - 15;
f_exponent = exponent << 23;
f_mantissa = mantissa << 13;
}
UInt32 f = f_sign | f_exponent | f_mantissa;
union
{
UInt32 u;
Float32 f;
} val = {f};
return val.f;
}
DataTypePtr FunctionDequantize::getReturnTypeImpl(const DataTypes & arguments) const
{
if (!checkDataTypes<DataTypeFixedString>(arguments[0].get()))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function {} must be FixedString", getName());
if (!WhichDataType(arguments[1]).isUInt8())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument of function {} must be UInt8", getName());
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat32>());
}
ColumnPtr FunctionDequantize::executeImpl(
const ColumnsWithTypeAndName & arguments, const DataTypePtr & /*result_type*/, size_t input_rows_count) const
{
const auto & fixed_string_col = arguments[0].column;
const auto * col_fixed_string = checkAndGetColumn<ColumnFixedString>(fixed_string_col.get());
if (!col_fixed_string)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function {} must be FixedString", getName());
const auto & data = col_fixed_string->getChars();
size_t fixed_string_length = col_fixed_string->getN();
UInt8 bit_width = 0;
if (const auto * const_col = checkAndGetColumnConst<ColumnUInt8>(arguments[1].column.get()))
bit_width = const_col->getValue<UInt8>();
else
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument of function {} must be a constant UInt8", getName());
if (!(bit_width == 16 || bit_width == 8 || bit_width == 4 || bit_width == 2 || bit_width == 1))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Bit width must be one of 16, 8, 4, 2, or 1");
size_t total_bits_per_row = fixed_string_length * 8;
if (total_bits_per_row % bit_width != 0)
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Total bits per row is not divisible by bit width in function {}", getName());
}
size_t num_elements = total_bits_per_row / bit_width;
auto col_float = ColumnFloat32::create();
auto & float_data = col_float->getData();
float_data.reserve(input_rows_count * num_elements);
auto col_array = ColumnArray::create(std::move(col_float));
auto & offsets = col_array->getOffsets();
offsets.resize(input_rows_count);
size_t offset = 0;
for (size_t row = 0; row < input_rows_count; ++row)
{
size_t row_offset = row * fixed_string_length;
size_t bits_read = 0;
for (size_t idx = 0; idx < num_elements; ++idx)
{
UInt32 quantized_value = 0;
for (size_t bit = 0; bit < bit_width; ++bit)
{
size_t bit_pos = bits_read++;
size_t byte_pos = bit_pos / 8;
size_t bit_in_byte = bit_pos % 8;
UInt8 current_byte = data[row_offset + byte_pos];
if (current_byte & (1 << (7 - bit_in_byte)))
quantized_value |= (1 << (bit_width - 1 - bit));
}
Float32 value = 0.0f;
if (bit_width == 16)
{
value = uint16ToFloat32(static_cast<UInt16>(quantized_value));
}
else
{
value = dequantizeFromBits(quantized_value, bit_width);
}
float_data.emplace_back(value);
}
offset += num_elements;
offsets[row] = offset;
}
return col_array;
}
}

View File

@ -0,0 +1,35 @@
#pragma once
#include <Columns/ColumnArray.h>
#include <Columns/ColumnFixedString.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/IFunction.h>
#include <Interpreters/Context_fwd.h>
#include <DataTypes/DataTypeFixedString.h>
namespace DB
{
class FunctionDequantize : public IFunction
{
public:
static constexpr auto name = "dequantize";
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionDequantize>(); }
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 2; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
ColumnPtr
executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override;
};
Float32 dequantizeFromBits(UInt32 quantized_value, UInt8 bit_width);
Float32 uint16ToFloat32(UInt16 h);
}

View File

@ -0,0 +1,227 @@
#include "FunctionQuantize.h"
#include <cmath>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypeArray.h>
#include <Functions/FunctionFactory.h>
#include <Common/Arena.h>
#include <Common/Exception.h>
#include <Common/assert_cast.h>
#include "Columns/ColumnsNumber.h"
#include "Functions/FunctionApproximateL2Distance.h"
#include "Functions/FunctionDequantize.h"
namespace DB
{
REGISTER_FUNCTION(Quantize)
{
factory.registerFunction<FunctionQuantize>({}, FunctionFactory::Case::Insensitive);
factory.registerFunction<FunctionDequantize>({}, FunctionFactory::Case::Sensitive);
factory.registerFunction<FunctionApproximateL2Distance>({}, FunctionFactory::Case::Insensitive);
}
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int BAD_ARGUMENTS;
}
UInt32 quantizeToBits(Float64 value, UInt8 bit_width)
{
constexpr Float64 min_value = -1.0;
constexpr Float64 max_value = 1.0;
value = std::max(value, min_value);
value = std::min(value, max_value);
Float64 normalized = (value - min_value) / (max_value - min_value);
UInt32 max_int = (1U << bit_width) - 1;
UInt32 quantized = static_cast<UInt32>(std::round(normalized * max_int));
return quantized;
}
UInt16 float32ToUInt16(Float32 value)
{
union
{
Float32 f;
UInt32 u;
} val;
val.f = value;
UInt32 f = val.u;
UInt16 h = 0;
UInt32 sign = (f >> 31) & 0x1;
UInt32 exponent = (f >> 23) & 0xFF;
UInt32 mantissa = f & 0x7FFFFF;
if (exponent == 0xFF)
{
h = (sign << 15) | (0x1F << 10);
if (mantissa)
h |= (mantissa & 0x3FF);
}
else if (exponent > 0x70)
{
exponent -= 0x70;
h = (sign << 15) | (exponent << 10) | (mantissa >> 13);
}
else
{
h = (sign << 15);
}
return h;
}
DataTypePtr FunctionQuantize::getReturnTypeImpl(const DataTypes & arguments) const
{
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(arguments[0].get());
if (!array_type)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument must be an array in function {}", getName());
const IDataType * nested_type = array_type->getNestedType().get();
if (!WhichDataType(nested_type).isFloat())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Array elements must be Float32 or Float64 in function {}", getName());
if (!WhichDataType(arguments[1]).isUInt8())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument of function {} must be UInt8", getName());
return std::make_shared<DataTypeFixedString>(0);
}
ColumnPtr FunctionQuantize::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const
{
const auto & array_col_with_type = arguments[0];
const auto & array_col = array_col_with_type.column;
const auto * col_array = checkAndGetColumn<ColumnArray>(array_col.get());
if (!col_array)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function {} must be an array", getName());
const auto & data_col = col_array->getData();
const auto & offsets = col_array->getOffsets();
const PaddedPODArray<Float32> * float32_data = nullptr;
const PaddedPODArray<Float64> * float64_data = nullptr;
bool is_float32 = false;
if (const auto * col_float32 = checkAndGetColumn<ColumnFloat32>(&data_col))
{
float32_data = &col_float32->getData();
is_float32 = true;
}
else if (const auto * col_float64 = checkAndGetColumn<ColumnFloat64>(&data_col))
{
float64_data = &col_float64->getData();
is_float32 = false;
}
else
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Array elements in function {} must be Float32 or Float64", getName());
}
UInt8 bit_width = 0;
if (const auto * const_col = checkAndGetColumnConst<ColumnUInt8>(arguments[1].column.get()))
bit_width = const_col->getValue<UInt8>();
else
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument of function {} must be a constant UInt8", getName());
if (!(bit_width == 16 || bit_width == 8 || bit_width == 4 || bit_width == 2 || bit_width == 1))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Bit width must be one of 16, 8, 4, 2, or 1 in function {}", getName());
size_t num_elements_per_row = 0;
if (input_rows_count > 0)
{
num_elements_per_row = offsets[0];
}
for (size_t row = 1; row < input_rows_count; ++row)
{
size_t num_elements = offsets[row] - offsets[row - 1];
if (num_elements != num_elements_per_row)
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "All input arrays must have the same number of elements in function {}", getName());
}
}
size_t total_bits = num_elements_per_row * bit_width;
size_t fixed_string_length = (total_bits + 7) / 8;
auto result_column = ColumnFixedString::create(fixed_string_length);
auto & result_chars = result_column->getChars();
result_chars.resize(input_rows_count * fixed_string_length);
size_t prev_offset = 0;
for (size_t row = 0; row < input_rows_count; ++row)
{
size_t current_offset = offsets[row];
size_t offset_in_result = row * fixed_string_length;
size_t bits_written = 0;
UInt8 current_byte = 0;
for (size_t idx = prev_offset; idx < current_offset; ++idx)
{
UInt32 quantized_value = 0;
if (bit_width == 16)
{
Float32 value;
if (is_float32)
{
value = (*float32_data)[idx];
}
else
{
value = static_cast<Float32>((*float64_data)[idx]);
}
quantized_value = float32ToUInt16(value);
}
else
{
Float64 value;
if (is_float32)
{
value = static_cast<Float64>((*float32_data)[idx]);
}
else
{
value = (*float64_data)[idx];
}
quantized_value = quantizeToBits(value, bit_width);
}
for (size_t bit = 0; bit < bit_width; ++bit)
{
if (quantized_value & (1U << (bit_width - 1 - bit)))
current_byte |= (1U << (7 - (bits_written % 8)));
bits_written++;
if (bits_written % 8 == 0)
{
result_chars[offset_in_result++] = current_byte;
current_byte = 0;
}
}
}
if (bits_written % 8 != 0)
{
result_chars[offset_in_result++] = current_byte;
}
prev_offset = current_offset;
}
return result_column;
}
}

View File

@ -0,0 +1,36 @@
// FunctionQuantize.h
#pragma once
#include <memory>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypeFixedString.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/IFunction.h>
#include <Common/Logger.h>
#include "Interpreters/Context_fwd.h"
namespace DB
{
class FunctionQuantize : public IFunction
{
public:
static constexpr auto name = "quantize";
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionQuantize>(); }
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 2; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
ColumnPtr
executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override;
};
}