PR post-review fixes

This commit is contained in:
Mikhail Gorshkov 2024-06-05 14:06:31 +00:00
parent 83901b82c9
commit 7ce67265c7

View File

@ -373,11 +373,11 @@ public:
} }
} }
static void applyOne(const T* __restrict in, size_t scale, T* __restrict out) static void applyOne(T in, size_t scale, T& out)
{ {
using ScalarOp = Op<Vectorize::No>; using ScalarOp = Op<Vectorize::No>;
auto s = ScalarOp::prepare(scale); auto s = ScalarOp::prepare(scale);
ScalarOp::compute(in, s, out); ScalarOp::compute(&in, s, &out);
} }
}; };
@ -435,9 +435,9 @@ public:
} }
} }
static void applyOne(const T* __restrict in, size_t scale, T* __restrict out) static void applyOne(T in, size_t scale, T& out)
{ {
Op::compute(in, scale, out); Op::compute(&in, scale, &out);
} }
}; };
@ -475,17 +475,17 @@ public:
} }
} }
static void applyOne(const NativeType* __restrict in, UInt32 in_scale, NativeType* __restrict out, Scale scale_arg) static void applyOne(NativeType in, UInt32 in_scale, NativeType& out, Scale scale_arg)
{ {
scale_arg = in_scale - scale_arg; scale_arg = in_scale - scale_arg;
if (scale_arg > 0) if (scale_arg > 0)
{ {
auto scale = intExp10OfSize<NativeType>(scale_arg); auto scale = intExp10OfSize<NativeType>(scale_arg);
Op::compute(in, scale, out); Op::compute(&in, scale, &out);
} }
else else
{ {
memcpy(out, in, sizeof(T)); memcpy(&out, &in, sizeof(T));
} }
} }
}; };
@ -553,35 +553,27 @@ struct Dispatcher
const auto & scale_data = scale_col_typed->getData(); const auto & scale_data = scale_col_typed->getData();
const size_t rows = value_data.size(); const size_t rows = value_data.size();
const T * end_in = value_data.data() + rows; for (size_t i = 0; i < rows; ++i)
const T * __restrict p_in = value_data.data();
const ScaleType * __restrict p_scale = scale_data.data();
T * __restrict p_out = vec_res.data();
while (p_in < end_in)
{ {
Int64 scale64 = *p_scale; Int64 scale64 = scale_data[i];
validateScale(scale64); validateScale(scale64);
Scale raw_scale = scale64; Scale raw_scale = scale64;
if (raw_scale == 0) if (raw_scale == 0)
{ {
size_t scale = 1; size_t scale = 1;
FunctionRoundingImpl<ScaleMode::Zero>::applyOne(p_in, scale, p_out); FunctionRoundingImpl<ScaleMode::Zero>::applyOne(value_data[i], scale, vec_res[i]);
} }
else if (raw_scale > 0) else if (raw_scale > 0)
{ {
size_t scale = intExp10(raw_scale); size_t scale = intExp10(raw_scale);
FunctionRoundingImpl<ScaleMode::Positive>::applyOne(p_in, scale, p_out); FunctionRoundingImpl<ScaleMode::Positive>::applyOne(value_data[i], scale, vec_res[i]);
} }
else else
{ {
size_t scale = intExp10(-raw_scale); size_t scale = intExp10(-raw_scale);
FunctionRoundingImpl<ScaleMode::Negative>::applyOne(p_in, scale, p_out); FunctionRoundingImpl<ScaleMode::Negative>::applyOne(value_data[i], scale, vec_res[i]);
} }
++p_in;
++p_scale;
++p_out;
} }
} }
} }
@ -611,27 +603,20 @@ public:
auto scale_arg = scale_col == nullptr ? 0 : getScaleArg(checkAndGetColumnConst<ColumnVector<ScaleType>>(scale_col)); auto scale_arg = scale_col == nullptr ? 0 : getScaleArg(checkAndGetColumnConst<ColumnVector<ScaleType>>(scale_col));
DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>::apply(value_col_typed.getData(), value_col_typed.getScale(), vec_res, scale_arg); DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>::apply(value_col_typed.getData(), value_col_typed.getScale(), vec_res, scale_arg);
} }
/// Non-cosnt scale argument /// Non-const scale argument
else if (const auto * scale_col_typed = checkAndGetColumn<ColumnVector<ScaleType>>(scale_col)) else if (const auto * scale_col_typed = checkAndGetColumn<ColumnVector<ScaleType>>(scale_col))
{ {
const auto & scale = scale_col_typed->getData(); const auto & scale = scale_col_typed->getData();
const size_t rows = vec_src.size(); const size_t rows = vec_src.size();
using NativeType = typename T::NativeType; for (size_t i = 0; i < rows; ++i)
const NativeType * __restrict p_in = reinterpret_cast<const NativeType *>(vec_src.data());
const ScaleType * __restrict p_scale = scale.data();
const NativeType * end_in = p_in + rows;
NativeType * __restrict p_out = reinterpret_cast<NativeType *>(vec_res.data());
while (p_in < end_in)
{ {
Int64 scale64 = *p_scale; Int64 scale64 = scale[i];
validateScale(scale64); validateScale(scale64);
Scale raw_scale = scale64; Scale raw_scale = scale64;
DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>::applyOne(p_in, value_col_typed.getScale(), p_out, raw_scale); DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>::applyOne(value_col_typed.getElement(i), value_col_typed.getScale(),
++p_in; reinterpret_cast<ColumnDecimal<T>::NativeT&>(col_res->getElement(i)), raw_scale);
++p_scale;
++p_out;
} }
} }
} }