fix add method

This commit is contained in:
redclusive 2021-04-15 17:30:41 +03:00
parent 8db89e493b
commit be547193ae

View File

@ -20,7 +20,7 @@ struct Bfloat16Histogram
{ {
using bfloat16 = UInt16; using bfloat16 = UInt16;
using Data = HashMap<bfloat16, size_t>; using Data = HashMap<bfloat16, size_t>;
using Array = PODArrayWithStackMemory<Value, 64>; using Array = PODArrayWithStackMemory<Float32, 64>;
Data data; Data data;
Array array; Array array;
@ -33,15 +33,15 @@ struct Bfloat16Histogram
if (!data.find(val)) if (!data.find(val))
{ {
sorted = false; sorted = false;
count += to_add; array.push_back(to_Float32(val));
array.push_back(x);
} }
count += to_add;
data[val] += to_add; data[val] += to_add;
} }
void merge(const Bfloat16Histogram & rhs) void merge(const Bfloat16Histogram & rhs)
{ {
for (const Value & value : rhs.array) for (const Float32 & value : rhs.array)
{ {
add(value, rhs.data.find(to_bfloat16(value))->getMapped()); add(value, rhs.data.find(to_bfloat16(value))->getMapped());
} }
@ -78,9 +78,9 @@ struct Bfloat16Histogram
size_t sum = 0; size_t sum = 0;
size_t need = level * count; size_t need = level * count;
for (const Value & value : array) for (const Float32 & value : array)
{ {
sum += data.find(to_bfloat16(value))->getMapped(); sum += data[to_bfloat16(value)];
if (sum >= need) if (sum >= need)
return value; return value;
} }
@ -104,9 +104,9 @@ struct Bfloat16Histogram
size_t sum = 0; size_t sum = 0;
size_t it = 0; size_t it = 0;
for (const auto & value : array) for (const Float32 & value : array)
{ {
sum += data.find(to_bfloat16(value))->getMapped(); sum += data[to_bfloat16(value)];
while (it < size && sum >= static_cast<size_t>(levels[indices[it]] * count)) while (it < size && sum >= static_cast<size_t>(levels[indices[it]] * count))
{ {
result[indices[it++]] = value; result[indices[it++]] = value;
@ -142,6 +142,11 @@ private:
return ext::bit_cast<UInt32>(static_cast<Float32>(x)) >> 16; return ext::bit_cast<UInt32>(static_cast<Float32>(x)) >> 16;
} }
Float32 to_Float32(const bfloat & x) const
{
return ext::bit_cast<Float32>(x << 16);
}
void sortIfNeeded() void sortIfNeeded()
{ {
if (sorted) if (sorted)