diff --git a/dbms/src/Dictionaries/CatBoostModel.cpp b/dbms/src/Dictionaries/CatBoostModel.cpp index 412cdc1a923..888bb710673 100644 --- a/dbms/src/Dictionaries/CatBoostModel.cpp +++ b/dbms/src/Dictionaries/CatBoostModel.cpp @@ -206,9 +206,9 @@ private: size_t column_size = columns[offset]->size(); auto data_column = std::make_shared>(size * column_size); T* data = data_column->getData().data(); - for (size_t i = offset; i < offset + size; ++i) + for (size_t i = 0; i < size; ++i) { - const auto & column = columns[i]; + const auto & column = columns[offset + i]; if (column->isNumeric()) placeColumnAsNumber(column, data + i, size); } @@ -226,19 +226,19 @@ private: /// Place columns into buffer, returns data which was used for fixed string columns. /// Buffer should contains column->size() values, each value contains size strings. std::vector> placeStringColumns( - const Columns & columns, size_t offset, size_t size, const char *** buffer) const + const Columns & columns, size_t offset, size_t size, const char ** buffer) const { if (size == 0) return {}; std::vector> data; - for (size_t i = offset; i < offset + size; ++i) + for (size_t i = 0; i < size; ++i) { - const auto & column = columns[i]; + const auto & column = columns[offset + i]; if (auto column_string = typeid_cast(column.get())) - placeStringColumn(*column_string, buffer[i], size); + placeStringColumn(*column_string, buffer + i, size); else if (auto column_fixed_string = typeid_cast(column.get())) - data.push_back(placeFixedStringColumn(*column_fixed_string, buffer[i], size)); + data.push_back(placeFixedStringColumn(*column_fixed_string, buffer + i, size)); else throw Exception("Cannot place string column.", ErrorCodes::LOGICAL_ERROR); } @@ -248,27 +248,30 @@ private: /// Calc hash for string cat feature at ps positions. template - void calcStringHashes(const Column * column, size_t features_count, size_t ps, const int ** buffer) const + void calcStringHashes(const Column * column, size_t ps, const int ** buffer) const { size_t column_size = column->size(); for (size_t j = 0; j < column_size; ++j) { auto ref = column->getDataAt(j); const_cast(*buffer)[ps] = api->GetStringCatFeatureHash(ref.data, ref.size); - buffer += features_count; + ++buffer; } } /// Calc hash for int cat feature at ps position. Buffer at positions ps should contains unhashed values. - void calcIntHashes(size_t column_size, size_t features_count, size_t ps, const int ** buffer) const + void calcIntHashes(size_t column_size, size_t ps, const int ** buffer) const { for (size_t j = 0; j < column_size; ++j) { const_cast(*buffer)[ps] = api->GetIntegerCatFeatureHash((*buffer)[ps]); - buffer += features_count; + ++buffer; } } + /// buffer contains column->size() rows and size columns. + /// For int cat features calc hash inplace. + /// For string cat features calc hash from column rows. void calcHashes(const Columns & columns, size_t offset, size_t size, const int ** buffer) const { if (size == 0) @@ -276,18 +279,19 @@ private: size_t column_size = columns[offset]->size(); std::vector> data; - for (size_t i = offset; i < offset + size; ++i) + for (size_t i = 0; i < size; ++i) { - const auto & column = columns[i]; + const auto & column = columns[offset + i]; if (auto column_string = typeid_cast(column.get())) - calcStringHashes(column_string, size, column_size, buffer); + calcStringHashes(column_string, i, buffer); else if (auto column_fixed_string = typeid_cast(column.get())) - calcStringHashes(column_fixed_string, size, column_size, buffer); + calcStringHashes(column_fixed_string, i, buffer); else - calcIntHashes(column_size, size, column_size, buffer); + calcIntHashes(column_size, i, buffer); } } + /// buffer[column_size * cat_features_count] -> char * => cat_features[column_size][cat_features_count] -> char * void fillCatFeaturesBuffer(const char *** cat_features, const char ** buffer, size_t column_size, size_t cat_features_count) const { @@ -335,7 +339,7 @@ private: fillCatFeaturesBuffer(cat_features_buf, cat_features_holder.data(), column_size, cat_features_count); auto fixed_strings_data = placeStringColumns(columns, float_features_count, - cat_features_count, cat_features_buf); + cat_features_count, cat_features_holder.data()); if (!api->CalcModelPrediction(handle->get(), column_size, float_features_buf, float_features_count,