diff --git a/src/AggregateFunctions/Bfloat16Histogram.h b/src/AggregateFunctions/Bfloat16Histogram.h index 3a32a8aee46..8b056a1fbac 100644 --- a/src/AggregateFunctions/Bfloat16Histogram.h +++ b/src/AggregateFunctions/Bfloat16Histogram.h @@ -20,7 +20,7 @@ struct Bfloat16Histogram { using bfloat16 = UInt16; using Data = HashMap; - using Array = PODArrayWithStackMemory; + using Array = PODArrayWithStackMemory; Data data; Array array; @@ -33,15 +33,15 @@ struct Bfloat16Histogram if (!data.find(val)) { sorted = false; - count += to_add; - array.push_back(x); + array.push_back(to_Float32(val)); } + count += to_add; data[val] += to_add; } void merge(const Bfloat16Histogram & rhs) { - for (const Value & value : rhs.array) + for (const Float32 & value : rhs.array) { add(value, rhs.data.find(to_bfloat16(value))->getMapped()); } @@ -78,9 +78,9 @@ struct Bfloat16Histogram size_t sum = 0; size_t need = level * count; - for (const Value & value : array) + for (const Float32 & value : array) { - sum += data.find(to_bfloat16(value))->getMapped(); + sum += data[to_bfloat16(value)]; if (sum >= need) return value; } @@ -104,9 +104,9 @@ struct Bfloat16Histogram size_t sum = 0; size_t it = 0; - for (const auto & value : array) + for (const Float32 & value : array) { - sum += data.find(to_bfloat16(value))->getMapped(); + sum += data[to_bfloat16(value)]; while (it < size && sum >= static_cast(levels[indices[it]] * count)) { result[indices[it++]] = value; @@ -142,6 +142,11 @@ private: return ext::bit_cast(static_cast(x)) >> 16; } + Float32 to_Float32(const bfloat & x) const + { + return ext::bit_cast(x << 16); + } + void sortIfNeeded() { if (sorted)