From a646a3242c8237b0ca4de42b060fcbab843fd1f5 Mon Sep 17 00:00:00 2001
From: nikita4109
Date: Fri, 8 Nov 2024 13:19:45 +0100
Subject: [PATCH] quantize functions

---
 .../FunctionApproximateL2Distance.cpp         | 120 +++++++++
 src/Functions/FunctionApproximateL2Distance.h |  30 +++
 src/Functions/FunctionDequantize.cpp          | 171 +++++++++++++
 src/Functions/FunctionDequantize.h            |  35 +++
 src/Functions/FunctionQuantize.cpp            | 227 ++++++++++++++++++
 src/Functions/FunctionQuantize.h              |  36 +++
 6 files changed, 619 insertions(+)
 create mode 100644 src/Functions/FunctionApproximateL2Distance.cpp
 create mode 100644 src/Functions/FunctionApproximateL2Distance.h
 create mode 100644 src/Functions/FunctionDequantize.cpp
 create mode 100644 src/Functions/FunctionDequantize.h
 create mode 100644 src/Functions/FunctionQuantize.cpp
 create mode 100644 src/Functions/FunctionQuantize.h

diff --git a/src/Functions/FunctionApproximateL2Distance.cpp b/src/Functions/FunctionApproximateL2Distance.cpp
new file mode 100644
index 00000000000..91d2c8ee36d
--- /dev/null
+++ b/src/Functions/FunctionApproximateL2Distance.cpp
@@ -0,0 +1,120 @@
+#include "FunctionApproximateL2Distance.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "Columns/ColumnsNumber.h"
+#include "Functions/FunctionDequantize.h"
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+extern const int BAD_ARGUMENTS;
+}
+/// Type-checks the three arguments (the first two per the FixedString error text, the third as UInt8), throws ILLEGAL_TYPE_OF_ARGUMENT on mismatch, and returns the result type.
+DataTypePtr FunctionApproximateL2Distance::getReturnTypeImpl(const DataTypes & arguments) const
+{
+    if (!checkDataTypes(arguments[0].get()) || !checkDataTypes(arguments[1].get()))
+    {
+        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Arguments of function {} must be FixedString", getName());
+    }
+
+    if (!WhichDataType(arguments[2]).isUInt8())
+        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Third argument of function {} must be UInt8", getName());
+
+    return std::make_shared(); // NOTE(review): presumably the Float64 type, matching the ColumnFloat64 built in executeImpl - confirm
+}
+/// For every row, decodes the bit-packed elements of both inputs and returns sqrt(sum((left_i - right_i)^2)) in a ColumnFloat64.
+ColumnPtr
+FunctionApproximateL2Distance::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const
+{
+    const auto * col_left = checkAndGetColumn(arguments[0].column.get());
+    const auto * col_right = checkAndGetColumn(arguments[1].column.get());
+
+    if (!col_left || !col_right)
+        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Arguments of function {} must be FixedString", getName());
+
+    if (col_left->getN() != col_right->getN())
+        throw Exception(ErrorCodes::BAD_ARGUMENTS, "FixedStrings in function {} must have the same length", getName());
+
+    const auto & data_left = col_left->getChars();
+    const auto & data_right = col_right->getChars();
+    size_t fixed_string_length = col_left->getN();
+    // The bit width must arrive as a constant column; anything else is rejected below.
+    UInt8 bit_width = 0;
+    if (const auto * const_col = checkAndGetColumnConst(arguments[2].column.get()))
+        bit_width = const_col->getValue();
+    else
+        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Third argument of function {} must be a constant UInt8", getName());
+
+    if (!(bit_width == 16 || bit_width == 8 || bit_width == 4 || bit_width == 2 || bit_width == 1))
+        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Bit width must be one of 16, 8, 4, 2, or 1");
+
+    auto result_column = ColumnFloat64::create();
+    auto & result_data = result_column->getData();
+    result_data.resize(input_rows_count);
+
+    size_t total_bits_per_row = fixed_string_length * 8;
+
+    for (size_t row = 0; row < input_rows_count; ++row)
+    {
+        size_t row_offset = row * fixed_string_length; // rows are stored back to back, fixed_string_length bytes each
+
+        size_t num_elements = total_bits_per_row / bit_width; // integer division: trailing bits that do not fill an element are ignored
+
+        size_t bits_read = 0;
+
+        double sum = 0.0;
+        // Decode the idx-th element of both rows bit by bit, most significant bit first.
+        for (size_t idx = 0; idx < num_elements; ++idx)
+        {
+            UInt32 value_left = 0;
+            UInt32 value_right = 0;
+            for (size_t bit = 0; bit < bit_width; ++bit)
+            {
+                size_t bit_pos = bits_read++;
+                size_t byte_pos = bit_pos / 8;
+                size_t bit_in_byte = bit_pos % 8;
+
+                UInt8 byte_left = data_left[row_offset + byte_pos];
+                UInt8 byte_right = data_right[row_offset + byte_pos];
+
+                if (byte_left & (1 << (7 - bit_in_byte)))
+                    value_left |= (1 << (bit_width - 1 - bit));
+
+                if (byte_right & (1 << (7 - bit_in_byte)))
+                    value_right |= (1 << (bit_width - 1 - bit));
+            }
+
+            Float32 left_value = 0.0f;
+            Float32 right_value = 0.0f;
+            // 16-bit codes go through uint16ToFloat32; narrower codes go through dequantizeFromBits.
+            if (bit_width == 16)
+            {
+                left_value = uint16ToFloat32(static_cast(value_left));
+                right_value = uint16ToFloat32(static_cast(value_right));
+            }
+            else
+            {
+                left_value = dequantizeFromBits(value_left, bit_width);
+                right_value = dequantizeFromBits(value_right, bit_width);
+            }
+            // Accumulate the squared difference of the two decoded values.
+            double diff = static_cast(left_value) - static_cast(right_value);
+            sum += diff * diff;
+        }
+
+        double distance = std::sqrt(sum);
+        result_data[row] = distance;
+    }
+
+    return result_column;
+}
+
+}
diff --git a/src/Functions/FunctionApproximateL2Distance.h b/src/Functions/FunctionApproximateL2Distance.h
new file mode 100644
index 00000000000..ce3bbb09d32
--- /dev/null
+++ b/src/Functions/FunctionApproximateL2Distance.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+namespace DB
+{
+/// IFunction named "approximateL2Distance" taking three arguments; the member definitions live in FunctionApproximateL2Distance.cpp.
+class FunctionApproximateL2Distance : public IFunction
+{
+public:
+    static constexpr auto name = "approximateL2Distance";
+    static FunctionPtr create(ContextPtr) { return std::make_shared(); }
+
+    String getName() const override { return name; }
+
+    size_t getNumberOfArguments() const override { return 3; }
+
+    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; }
+
+    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
+
+    ColumnPtr
+    executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override;
+};
+
+}
diff --git a/src/Functions/FunctionDequantize.cpp b/src/Functions/FunctionDequantize.cpp
new file mode 100644
index 00000000000..726d333e77e
--- /dev/null
+++ b/src/Functions/FunctionDequantize.cpp
@@ -0,0 +1,171 @@
+#include "FunctionDequantize.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "Columns/ColumnsNumber.h"
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+extern const int BAD_ARGUMENTS;
+}
+/// Maps an unsigned bit_width-bit code linearly onto [-1, 1]: code 0 gives -1 and code 2^bit_width - 1 gives +1.
+Float32 dequantizeFromBits(UInt32 quantized_value, UInt8 bit_width)
+{
+    UInt32 max_int = (1 << bit_width) - 1; // largest code representable in bit_width bits
+    Float32 normalized = static_cast(quantized_value) / static_cast(max_int);
+    Float32 value = normalized * 2.0f - 1.0f;
+    return value;
+}
+/// Expands a 16-bit pattern laid out as 1 sign, 5 exponent and 10 mantissa bits into a Float32 (1 sign, 8 exponent, 23 mantissa bits).
+Float32 uint16ToFloat32(UInt16 h)
+{
+    UInt32 sign = (h >> 15) & 0x1;
+    UInt32 exponent = (h >> 10) & 0x1F;
+    UInt32 mantissa = h & 0x3FF;
+
+    UInt32 f_sign = sign << 31;
+    UInt32 f_exponent;
+    UInt32 f_mantissa;
+
+    if (exponent == 0) // zero or subnormal input
+    {
+        if (mantissa == 0)
+        {
+            f_exponent = 0;
+            f_mantissa = 0;
+        }
+        else
+        {
+            exponent = 1;
+            while ((mantissa & 0x400) == 0) // shift left until the leading bit reaches position 10
+            {
+                mantissa <<= 1;
+                exponent -= 1; // may wrap below zero (unsigned); the += below brings it back into range
+            }
+            mantissa &= 0x3FF; // drop the leading bit, which becomes implicit
+            exponent += 127 - 15;
+            f_exponent = exponent << 23;
+            f_mantissa = mantissa << 13;
+        }
+    }
+    else if (exponent == 0x1F) // all-ones exponent (infinity / NaN): mantissa bits are carried over
+    {
+        f_exponent = 0xFF << 23;
+        f_mantissa = mantissa << 13;
+    }
+    else
+    {
+        exponent += 127 - 15; // re-bias the exponent from 15 to 127
+        f_exponent = exponent << 23;
+        f_mantissa = mantissa << 13;
+    }
+
+    UInt32 f = f_sign | f_exponent | f_mantissa;
+    // NOTE(review): the union below is read through its inactive member; consider std::bit_cast or memcpy.
+    union
+    {
+        UInt32 u;
+        Float32 f;
+    } val = {f};
+
+    return val.f;
+}
+/// Type-checks the arguments (first per the FixedString error text, second as UInt8), throws ILLEGAL_TYPE_OF_ARGUMENT on mismatch, and returns the result type.
+DataTypePtr FunctionDequantize::getReturnTypeImpl(const DataTypes & arguments) const
+{
+    if (!checkDataTypes(arguments[0].get()))
+        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function {} must be FixedString", getName());
+
+    if (!WhichDataType(arguments[1]).isUInt8())
+        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument of function {} must be UInt8", getName());
+
+    return std::make_shared(std::make_shared()); // NOTE(review): presumably Array(Float32), matching the ColumnArray over ColumnFloat32 built in executeImpl - confirm
+}
+/// Unpacks every input row into num_elements Float32 values and returns them as a ColumnArray with one array per row.
+ColumnPtr FunctionDequantize::executeImpl(
+    const ColumnsWithTypeAndName & arguments, const DataTypePtr & /*result_type*/, size_t input_rows_count) const
+{
+    const auto & fixed_string_col = arguments[0].column;
+    const auto * col_fixed_string = checkAndGetColumn(fixed_string_col.get());
+    if (!col_fixed_string)
+        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function {} must be FixedString", getName());
+
+    const auto & data = col_fixed_string->getChars();
+    size_t fixed_string_length = col_fixed_string->getN();
+    // The bit width must arrive as a constant column; anything else is rejected below.
+    UInt8 bit_width = 0;
+    if (const auto * const_col = checkAndGetColumnConst(arguments[1].column.get()))
+        bit_width = const_col->getValue();
+    else
+        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument of function {} must be a constant UInt8", getName());
+
+    if (!(bit_width == 16 || bit_width == 8 || bit_width == 4 || bit_width == 2 || bit_width == 1))
+        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Bit width must be one of 16, 8, 4, 2, or 1");
+
+    size_t total_bits_per_row = fixed_string_length * 8;
+    if (total_bits_per_row % bit_width != 0) // a row must hold a whole number of elements
+    {
+        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Total bits per row is not divisible by bit width in function {}", getName());
+    }
+    size_t num_elements = total_bits_per_row / bit_width;
+
+    auto col_float = ColumnFloat32::create();
+    auto & float_data = col_float->getData();
+    float_data.reserve(input_rows_count * num_elements);
+
+    auto col_array = ColumnArray::create(std::move(col_float));
+    auto & offsets = col_array->getOffsets();
+    offsets.resize(input_rows_count);
+
+    size_t offset = 0;
+
+    for (size_t row = 0; row < input_rows_count; ++row)
+    {
+        size_t row_offset = row * fixed_string_length;
+
+        size_t bits_read = 0;
+        // Rebuild each code bit by bit, most significant bit first, then decode it.
+        for (size_t idx = 0; idx < num_elements; ++idx)
+        {
+            UInt32 quantized_value = 0;
+            for (size_t bit = 0; bit < bit_width; ++bit)
+            {
+                size_t bit_pos = bits_read++;
+                size_t byte_pos = bit_pos / 8;
+                size_t bit_in_byte = bit_pos % 8;
+
+                UInt8 current_byte = data[row_offset + byte_pos];
+
+                if (current_byte & (1 << (7 - bit_in_byte)))
+                    quantized_value |= (1 << (bit_width - 1 - bit));
+            }
+
+            Float32 value = 0.0f;
+            if (bit_width == 16)
+            {
+                value = uint16ToFloat32(static_cast(quantized_value));
+            }
+            else
+            {
+                value = dequantizeFromBits(quantized_value, bit_width);
+            }
+
+            float_data.emplace_back(value);
+        }
+
+        offset += num_elements;
+        offsets[row] = offset; // cumulative end position of this row's array
+    }
+
+    return col_array;
+}
+
+}
diff --git a/src/Functions/FunctionDequantize.h b/src/Functions/FunctionDequantize.h
new file mode 100644
index 00000000000..cd76eb4eac1
--- /dev/null
+++ b/src/Functions/FunctionDequantize.h
@@ -0,0 +1,35 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace DB
+{
+/// IFunction named "dequantize" taking two arguments; the member definitions live in FunctionDequantize.cpp.
+class FunctionDequantize : public IFunction
+{
+public:
+    static constexpr auto name = "dequantize";
+    static FunctionPtr create(ContextPtr) { return std::make_shared(); }
+
+    String getName() const override { return name; }
+
+    size_t getNumberOfArguments() const override { return 2; }
+
+    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; }
+
+    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
+
+    ColumnPtr
+    executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override;
+};
+/// Decoding helpers defined in FunctionDequantize.cpp and also called from FunctionApproximateL2Distance.cpp.
+Float32 dequantizeFromBits(UInt32 quantized_value, UInt8 bit_width);
+Float32 uint16ToFloat32(UInt16 h);
+
+}
diff --git a/src/Functions/FunctionQuantize.cpp b/src/Functions/FunctionQuantize.cpp
new file mode 100644
index 00000000000..3f534dfc374
--- /dev/null
+++ b/src/Functions/FunctionQuantize.cpp
@@ -0,0 +1,227 @@
+#include "FunctionQuantize.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "Columns/ColumnsNumber.h"
+#include "Functions/FunctionApproximateL2Distance.h"
+#include "Functions/FunctionDequantize.h"
+
+namespace DB
+{
+/// Registers three functions with the factory: the first and third case-insensitively, the second case-sensitively.
+REGISTER_FUNCTION(Quantize)
+{
+    factory.registerFunction({}, FunctionFactory::Case::Insensitive);
+    factory.registerFunction({}, FunctionFactory::Case::Sensitive);
+    factory.registerFunction({}, FunctionFactory::Case::Insensitive);
+}
+
+namespace ErrorCodes
+{
+extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+extern const int BAD_ARGUMENTS;
+}
+/// Clamps value to [-1, 1], rescales it to [0, 1] and rounds it to the nearest code in [0, 2^bit_width - 1].
+UInt32 quantizeToBits(Float64 value, UInt8 bit_width)
+{
+    constexpr Float64 min_value = -1.0;
+    constexpr Float64 max_value = 1.0;
+
+    value = std::max(value, min_value);
+    value = std::min(value, max_value);
+
+    Float64 normalized = (value - min_value) / (max_value - min_value);
+
+    UInt32 max_int = (1U << bit_width) - 1;
+    UInt32 quantized = static_cast(std::round(normalized * max_int)); // NOTE(review): a NaN input passes through both clamps above; confirm this cast is safe for it
+
+    return quantized;
+}
+/// Packs a Float32 into a 16-bit pattern laid out as 1 sign, 5 exponent and 10 mantissa bits.
+UInt16 float32ToUInt16(Float32 value)
+{
+    union // NOTE(review): val.u is read after writing val.f (inactive member); consider std::bit_cast or memcpy
+    {
+        Float32 f;
+        UInt32 u;
+    } val;
+    val.f = value;
+
+    UInt32 f = val.u;
+    UInt16 h = 0;
+
+    UInt32 sign = (f >> 31) & 0x1;
+    UInt32 exponent = (f >> 23) & 0xFF;
+    UInt32 mantissa = f & 0x7FFFFF;
+
+    if (exponent == 0xFF) // all-ones exponent: infinity or NaN
+    {
+        h = (sign << 15) | (0x1F << 10);
+        if (mantissa)
+            h |= (mantissa & 0x3FF); // NOTE(review): keeps only the low 10 mantissa bits; a NaN with those bits clear comes out as the infinity pattern
+    }
+    else if (exponent > 0x70)
+    {
+        exponent -= 0x70; // re-bias by 112 (127 - 15); NOTE(review): the result is not checked against the 5-bit exponent range
+        h = (sign << 15) | (exponent << 10) | (mantissa >> 13); // keeps the top 10 of 23 mantissa bits, no rounding
+    }
+    else // biased exponent <= 0x70: only the sign bit is kept
+    {
+        h = (sign << 15);
+    }
+
+    return h;
+}
+/// Requires an array with floating-point elements as the first argument and a UInt8 as the second; throws ILLEGAL_TYPE_OF_ARGUMENT otherwise.
+DataTypePtr FunctionQuantize::getReturnTypeImpl(const DataTypes & arguments) const
+{
+    const DataTypeArray * array_type = checkAndGetDataType(arguments[0].get());
+    if (!array_type)
+        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument must be an array in function {}", getName());
+
+    const IDataType * nested_type = array_type->getNestedType().get();
+    if (!WhichDataType(nested_type).isFloat())
+        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Array elements must be Float32 or Float64 in function {}", getName());
+
+    if (!WhichDataType(arguments[1]).isUInt8())
+        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument of function {} must be UInt8", getName());
+
+    return std::make_shared(0); // NOTE(review): built with 0 while executeImpl sizes its column with fixed_string_length - confirm these agree
+}
+/// Packs each input array into one ColumnFixedString row: every element is encoded to bit_width bits and written MSB-first.
+/// All arrays must have the same number of elements; the last byte of a row is zero-padded when the bit count is not a multiple of 8.
+ColumnPtr FunctionQuantize::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const
+{
+    const auto & array_col_with_type = arguments[0];
+    const auto & array_col = array_col_with_type.column;
+
+    const auto * col_array = checkAndGetColumn(array_col.get());
+    if (!col_array)
+        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function {} must be an array", getName());
+
+    const auto & data_col = col_array->getData();
+    const auto & offsets = col_array->getOffsets();
+    // Exactly one of the two data pointers is set, selected by is_float32.
+    const PaddedPODArray * float32_data = nullptr;
+    const PaddedPODArray * float64_data = nullptr;
+    bool is_float32 = false;
+
+    if (const auto * col_float32 = checkAndGetColumn(&data_col))
+    {
+        float32_data = &col_float32->getData();
+        is_float32 = true;
+    }
+    else if (const auto * col_float64 = checkAndGetColumn(&data_col))
+    {
+        float64_data = &col_float64->getData();
+        is_float32 = false;
+    }
+    else
+    {
+        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Array elements in function {} must be Float32 or Float64", getName());
+    }
+    // The bit width must arrive as a constant column; anything else is rejected below.
+    UInt8 bit_width = 0;
+    if (const auto * const_col = checkAndGetColumnConst(arguments[1].column.get()))
+        bit_width = const_col->getValue();
+    else
+        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument of function {} must be a constant UInt8", getName());
+
+    if (!(bit_width == 16 || bit_width == 8 || bit_width == 4 || bit_width == 2 || bit_width == 1))
+        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Bit width must be one of 16, 8, 4, 2, or 1 in function {}", getName());
+
+    size_t num_elements_per_row = 0;
+    if (input_rows_count > 0)
+    {
+        num_elements_per_row = offsets[0]; // the first array sets the expected element count
+    }
+    // Every later row must have the same element count as the first one.
+    for (size_t row = 1; row < input_rows_count; ++row)
+    {
+        size_t num_elements = offsets[row] - offsets[row - 1];
+        if (num_elements != num_elements_per_row)
+        {
+            throw Exception(ErrorCodes::BAD_ARGUMENTS, "All input arrays must have the same number of elements in function {}", getName());
+        }
+    }
+
+    size_t total_bits = num_elements_per_row * bit_width;
+    size_t fixed_string_length = (total_bits + 7) / 8; // round up to whole bytes
+
+    auto result_column = ColumnFixedString::create(fixed_string_length);
+    auto & result_chars = result_column->getChars();
+    result_chars.resize(input_rows_count * fixed_string_length);
+
+    size_t prev_offset = 0;
+
+    for (size_t row = 0; row < input_rows_count; ++row)
+    {
+        size_t current_offset = offsets[row];
+        size_t offset_in_result = row * fixed_string_length;
+
+        size_t bits_written = 0;
+        UInt8 current_byte = 0;
+
+        for (size_t idx = prev_offset; idx < current_offset; ++idx)
+        {
+            UInt32 quantized_value = 0;
+            if (bit_width == 16) // 16-bit codes are produced by float32ToUInt16, narrower ones by quantizeToBits
+            {
+                Float32 value;
+                if (is_float32)
+                {
+                    value = (*float32_data)[idx];
+                }
+                else
+                {
+                    value = static_cast((*float64_data)[idx]);
+                }
+                quantized_value = float32ToUInt16(value);
+            }
+            else
+            {
+                Float64 value;
+                if (is_float32)
+                {
+                    value = static_cast((*float32_data)[idx]);
+                }
+                else
+                {
+                    value = (*float64_data)[idx];
+                }
+                quantized_value = quantizeToBits(value, bit_width);
+            }
+            // Append the code to the row most significant bit first, flushing each completed byte.
+            for (size_t bit = 0; bit < bit_width; ++bit)
+            {
+                if (quantized_value & (1U << (bit_width - 1 - bit)))
+                    current_byte |= (1U << (7 - (bits_written % 8)));
+
+                bits_written++;
+                if (bits_written % 8 == 0)
+                {
+                    result_chars[offset_in_result++] = current_byte;
+                    current_byte = 0;
+                }
+            }
+        }
+        // Flush a partially filled last byte; its unused low bits stay zero.
+        if (bits_written % 8 != 0)
+        {
+            result_chars[offset_in_result++] = current_byte;
+        }
+
+        prev_offset = current_offset;
+    }
+
+    return result_column;
+}
+
+
+}
diff --git a/src/Functions/FunctionQuantize.h b/src/Functions/FunctionQuantize.h
new file mode 100644
index 00000000000..08a27af63de
--- /dev/null
+++ b/src/Functions/FunctionQuantize.h
@@ -0,0 +1,36 @@
+// FunctionQuantize.h
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "Interpreters/Context_fwd.h"
+
+namespace DB
+{
+/// IFunction named "quantize" taking two arguments; the member definitions live in FunctionQuantize.cpp.
+class FunctionQuantize : public IFunction
+{
+public:
+    static constexpr auto name = "quantize";
+    static FunctionPtr create(ContextPtr) { return std::make_shared(); }
+
+    String getName() const override { return name; }
+
+    size_t getNumberOfArguments() const override { return 2; }
+
+    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; }
+
+    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
+
+    ColumnPtr
+    executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override;
+};
+
+}