readded new funcs + basic tests, big decimals WiP

This commit is contained in:
zvonand 2023-03-21 13:38:39 +01:00
parent e23f624968
commit 1bd6eef9c2
5 changed files with 366 additions and 11 deletions

View File

@ -0,0 +1,15 @@
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionFormatDecimal.h>
namespace DB
{
REGISTER_FUNCTION(FormatDecimal)
{
factory.registerFunction<FunctionFormatDecimal>();
}
}

View File

@ -0,0 +1,303 @@
#pragma once
#include <Core/DecimalFunctions.h>
#include <Functions/IFunction.h>
#include <Functions/FunctionHelpers.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnDecimal.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <IO/WriteBufferFromVector.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/Context_fwd.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_COLUMN;
}
struct Processor
{
/// For operations with Integer/Float
template <typename FromVectorType>
void vectorConstant(const FromVectorType & vec_from, const UInt8 value_precision,
ColumnString::Chars & vec_to, ColumnString::Offsets & offsets_to) const
{
WriteBufferFromVector<ColumnString::Chars> buf_to(vec_to);
size_t input_rows_count = vec_from.size();
offsets_to.resize(input_rows_count);
for (size_t i = 0; i < input_rows_count; ++i)
{
format(vec_from[i], buf_to, value_precision);
writeChar(0, buf_to);
offsets_to[i] = buf_to.count();
}
buf_to.finalize();
}
template <typename FirstArgVectorType>
void vectorVector(const FirstArgVectorType & vec_from, const ColumnVector<UInt8>::Container & vec_precision,
ColumnString::Chars & vec_to, ColumnString::Offsets & offsets_to) const
{
WriteBufferFromVector<ColumnString::Chars> buf_to(vec_to);
size_t input_rows_count = vec_from.size();
offsets_to.resize(input_rows_count);
for (size_t i = 0; i < input_rows_count; ++i)
{
format(vec_from[i], buf_to, vec_precision[i]);
writeChar(0, buf_to);
offsets_to[i] = buf_to.count();
}
buf_to.finalize();
}
template <typename FirstArgType>
void constantVector(const FirstArgType & value_from, const ColumnVector<UInt8>::Container & vec_precision,
ColumnString::Chars & vec_to, ColumnString::Offsets & offsets_to) const
{
WriteBufferFromVector<ColumnString::Chars> buf_to(vec_to);
size_t input_rows_count = vec_precision.size();
offsets_to.resize(input_rows_count);
for (size_t i = 0; i < input_rows_count; ++i)
{
format(value_from, buf_to, vec_precision[i]);
writeChar(0, buf_to);
offsets_to[i] = buf_to.count();
}
buf_to.finalize();
}
/// For operations with Decimal
template <typename FirstArgVectorType>
void vectorConstant(const FirstArgVectorType & vec_from, const UInt8 value_precision,
ColumnString::Chars & vec_to, ColumnString::Offsets & offsets_to, const UInt8 from_scale) const
{
WriteBufferFromVector<ColumnString::Chars> buf_to(vec_to);
size_t input_rows_count = vec_from.size();
offsets_to.resize(input_rows_count);
for (size_t i = 0; i < input_rows_count; ++i)
{
writeText(vec_from[i], from_scale, buf_to, true, true, static_cast<UInt32>(value_precision));
offsets_to[i] = buf_to.count();
}
buf_to.finalize();
}
template <typename FirstArgVectorType>
void vectorVector(const FirstArgVectorType & vec_from, const ColumnVector<UInt8>::Container & vec_precision,
ColumnString::Chars & vec_to, ColumnString::Offsets & offsets_to, const UInt8 from_scale) const
{
WriteBufferFromVector<ColumnString::Chars> buf_to(vec_to);
size_t input_rows_count = vec_from.size();
offsets_to.resize(input_rows_count);
for (size_t i = 0; i < input_rows_count; ++i)
{
writeText(vec_from[i], from_scale, buf_to, true, true, static_cast<UInt32>(vec_precision[i]));
writeChar(0, buf_to);
offsets_to[i] = buf_to.count();
}
buf_to.finalize();
}
template <typename FirstArgType>
void constantVector(const FirstArgType & value_from, const ColumnVector<UInt8>::Container & vec_precision,
ColumnString::Chars & vec_to, ColumnString::Offsets & offsets_to, const UInt8 from_scale) const
{
WriteBufferFromVector<ColumnString::Chars> buf_to(vec_to);
size_t input_rows_count = vec_precision.size();
offsets_to.resize(input_rows_count);
for (size_t i = 0; i < input_rows_count; ++i)
{
writeText(value_from, from_scale, buf_to, true, true, static_cast<UInt32>(vec_precision[i]));
offsets_to[i] = buf_to.count();
}
buf_to.finalize();
}
private:
static void format(double value, DB::WriteBuffer & out, int precision)
{
DB::DoubleConverter<false>::BufferType buffer;
double_conversion::StringBuilder builder{buffer, sizeof(buffer)};
const auto result = DB::DoubleConverter<false>::instance().ToFixed(value, precision, &builder);
if (!result)
throw DB::Exception(DB::ErrorCodes::CANNOT_PRINT_FLOAT_OR_DOUBLE_NUMBER, "Cannot print float or double number");
out.write(buffer, builder.position());
}
};
class FunctionFormatDecimal : public IFunction
{
public:
static constexpr auto name = "formatDecimal";
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionFormatDecimal>(); }
String getName() const override
{
return name;
}
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override
{
return true;
}
size_t getNumberOfArguments() const override { return 2; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!isNumber(*arguments[0]))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal first argument for formatDecimal function: got {}, expected numeric type",
arguments[0]->getName());
if (!isUInt8(*arguments[1]))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal second argument for formatDecimal function: got {}, expected UInt8",
arguments[1]->getName());
return std::make_shared<DataTypeString>();
}
bool useDefaultImplementationForConstants() const override { return true; }
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override
{
WhichDataType which_from(arguments[0].type.get());
if (which_from.isUInt8())
return executeType<UInt8>(arguments);
else if (which_from.isUInt16())
return executeType<UInt16>(arguments);
else if (which_from.isUInt32())
return executeType<UInt32>(arguments);
else if (which_from.isUInt64())
return executeType<UInt64>(arguments);
else if (which_from.isUInt128())
return executeType<UInt128>(arguments);
else if (which_from.isUInt256())
return executeType<UInt256>(arguments);
else if (which_from.isInt8())
return executeType<Int8>(arguments);
else if (which_from.isInt16())
return executeType<Int16>(arguments);
else if (which_from.isInt32())
return executeType<Int32>(arguments);
else if (which_from.isInt64())
return executeType<Int64>(arguments);
else if (which_from.isInt128())
return executeType<Int128>(arguments);
else if (which_from.isInt256())
return executeType<Int256>(arguments);
else if (which_from.isFloat32())
return executeType<Float32>(arguments);
else if (which_from.isFloat64())
return executeType<Float64>(arguments);
else if (which_from.isDecimal32())
return executeType<Decimal32>(arguments);
else if (which_from.isDecimal64())
return executeType<Decimal64>(arguments);
else if (which_from.isDecimal128())
return executeType<Decimal128>(arguments);
else if (which_from.isDecimal256())
return executeType<Decimal256>(arguments);
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of argument of function {}",
arguments[0].column->getName(), getName());
}
private:
template <typename T, std::enable_if_t<is_integer<T> || is_floating_point<T>, bool> = true>
ColumnPtr executeType(const ColumnsWithTypeAndName & arguments) const
{
auto result_column_string = ColumnString::create();
auto col_to = assert_cast<ColumnString *>(result_column_string.get());
ColumnString::Chars & data_to = col_to->getChars();
ColumnString::Offsets & offsets_to = col_to->getOffsets();
const auto * from_col = checkAndGetColumn<ColumnVector<T>>(arguments[0].column.get());
const auto * precision_col = checkAndGetColumn<ColumnVector<UInt8>>(arguments[1].column.get());
const auto * from_col_const = typeid_cast<const ColumnConst *>(arguments[0].column.get());
const auto * precision_col_const = typeid_cast<const ColumnConst *>(arguments[1].column.get());
Processor processor;
if (from_col)
{
if (precision_col_const)
processor.vectorConstant(from_col->getData(), precision_col_const->template getValue<UInt8>(), data_to, offsets_to);
else
processor.vectorVector(from_col->getData(), precision_col->getData(), data_to, offsets_to);
}
else if (from_col_const)
{
processor.constantVector(from_col_const->template getValue<T>(), precision_col->getData(), data_to, offsets_to);
}
else
{
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of first argument of function formatDecimal",
arguments[0].column->getName());
}
return result_column_string;
}
template <typename T, std::enable_if_t<is_decimal<T>, bool> = true>
ColumnPtr executeType(const ColumnsWithTypeAndName & arguments) const
{
auto result_column_string = ColumnString::create();
auto col_to = assert_cast<ColumnString *>(result_column_string.get());
ColumnString::Chars & data_to = col_to->getChars();
ColumnString::Offsets & offsets_to = col_to->getOffsets();
const auto * from_col = checkAndGetColumn<ColumnDecimal<T>>(arguments[0].column.get());
const auto * precision_col = checkAndGetColumn<ColumnVector<UInt8>>(arguments[1].column.get());
const auto * from_col_const = typeid_cast<const ColumnConst *>(arguments[0].column.get());
const auto * precision_col_const = typeid_cast<const ColumnConst *>(arguments[1].column.get());
UInt8 from_scale = from_col->getScale();
Processor processor;
if (from_col)
{
if (precision_col_const)
processor.vectorConstant(from_col->getData(), precision_col_const->template getValue<UInt8>(), data_to, offsets_to, from_scale);
else
processor.vectorVector(from_col->getData(), precision_col->getData(), data_to, offsets_to, from_scale);
}
else if (from_col_const)
{
processor.constantVector(from_col_const->template getValue<T>(), precision_col->getData(), data_to, offsets_to, from_scale);
}
else
{
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of first argument of function formatDecimal",
arguments[0].column->getName());
}
return result_column_string;
}
};
}

View File

@ -891,7 +891,8 @@ inline void writeText(const IPv4 & x, WriteBuffer & buf) { writeIPv4Text(x, buf)
inline void writeText(const IPv6 & x, WriteBuffer & buf) { writeIPv6Text(x, buf); }
template <typename T>
void writeDecimalFractional(const T & x, UInt32 scale, WriteBuffer & ostr, bool trailing_zeros)
void writeDecimalFractional(const T & x, UInt32 scale, WriteBuffer & ostr, bool trailing_zeros,
bool exact_frac_digits_set, UInt32 frac_digits_num)
{
/// If it's big integer, but the number of digits is small,
/// use the implementation for smaller integers for more efficient arithmetic.
@ -900,17 +901,17 @@ void writeDecimalFractional(const T & x, UInt32 scale, WriteBuffer & ostr, bool
{
if (x <= std::numeric_limits<UInt32>::max())
{
writeDecimalFractional(static_cast<UInt32>(x), scale, ostr, trailing_zeros);
writeDecimalFractional(static_cast<UInt32>(x), scale, ostr, trailing_zeros, exact_frac_digits_set, frac_digits_num);
return;
}
else if (x <= std::numeric_limits<UInt64>::max())
{
writeDecimalFractional(static_cast<UInt64>(x), scale, ostr, trailing_zeros);
writeDecimalFractional(static_cast<UInt64>(x), scale, ostr, trailing_zeros, exact_frac_digits_set, frac_digits_num);
return;
}
else if (x <= std::numeric_limits<UInt128>::max())
{
writeDecimalFractional(static_cast<UInt128>(x), scale, ostr, trailing_zeros);
writeDecimalFractional(static_cast<UInt128>(x), scale, ostr, trailing_zeros, exact_frac_digits_set, frac_digits_num);
return;
}
}
@ -918,27 +919,48 @@ void writeDecimalFractional(const T & x, UInt32 scale, WriteBuffer & ostr, bool
{
if (x <= std::numeric_limits<UInt32>::max())
{
writeDecimalFractional(static_cast<UInt32>(x), scale, ostr, trailing_zeros);
writeDecimalFractional(static_cast<UInt32>(x), scale, ostr, trailing_zeros, exact_frac_digits_set, frac_digits_num);
return;
}
else if (x <= std::numeric_limits<UInt64>::max())
{
writeDecimalFractional(static_cast<UInt64>(x), scale, ostr, trailing_zeros);
writeDecimalFractional(static_cast<UInt64>(x), scale, ostr, trailing_zeros, exact_frac_digits_set, frac_digits_num);
return;
}
}
constexpr size_t max_digits = std::numeric_limits<UInt256>::digits10;
assert(scale <= max_digits);
char buf[max_digits];
memset(buf, '0', scale);
T value = x;
Int32 last_nonzero_pos = 0;
for (Int32 pos = scale - 1; pos >= 0; --pos)
if (exact_frac_digits_set && frac_digits_num < scale)
{
T new_value = value / DecimalUtils::scaleMultiplier<Int256>(frac_digits_num - 1);
auto round_carry = new_value % 10;
value = new_value / 10;
if (round_carry >= 5)
value += 1;
}
for (Int32 pos = exact_frac_digits_set ? frac_digits_num - 1 : scale - 1; pos >= 0; --pos)
{
auto remainder = value % 10;
value /= 10;
// if (unlikely(carry))
// {
// if (remainder)
// {
// --remainder;
// carry = 0;
// }
// else
// remainder = 9;
// }
if (remainder != 0 && last_nonzero_pos == 0)
last_nonzero_pos = pos;
@ -946,12 +968,15 @@ void writeDecimalFractional(const T & x, UInt32 scale, WriteBuffer & ostr, bool
buf[pos] += static_cast<char>(remainder);
}
writeChar('.', ostr);
ostr.write(buf, trailing_zeros ? scale : last_nonzero_pos + 1);
if (likely(!exact_frac_digits_set || frac_digits_num != 0))
writeChar('.', ostr);
ostr.write(buf, exact_frac_digits_set ? frac_digits_num : trailing_zeros ? scale : last_nonzero_pos + 1);
}
template <typename T>
void writeText(Decimal<T> x, UInt32 scale, WriteBuffer & ostr, bool trailing_zeros)
void writeText(Decimal<T> x, UInt32 scale, WriteBuffer & ostr, bool trailing_zeros,
bool exact_frac_digits_set = false, UInt32 frac_digits_num = 0)
{
T part = DecimalUtils::getWholePart(x, scale);
@ -970,7 +995,7 @@ void writeText(Decimal<T> x, UInt32 scale, WriteBuffer & ostr, bool trailing_zer
if (part < 0)
part *= T(-1);
writeDecimalFractional(part, scale, ostr, trailing_zeros);
writeDecimalFractional(part, scale, ostr, trailing_zeros, exact_frac_digits_set, frac_digits_num);
}
}
}

View File

@ -0,0 +1,6 @@
2.00
2.12
2.15
64.32
64.32
64.32

View File

@ -0,0 +1,6 @@
SELECT formatDecimal(2,2); -- 2.00
SELECT formatDecimal(2.123456,2); -- 2.12
SELECT formatDecimal(2.1456,2); -- 2.15 -- rounding!
SELECT formatDecimal(64.32::Float64, 2); -- 64.32
SELECT formatDecimal(64.32::Decimal32(2), 2); -- 64.32
SELECT formatDecimal(64.32::Decimal64(2), 2); -- 64.32