Add save/load capabilities to Obfuscator

This commit is contained in:
Alexey Milovidov 2022-07-25 03:27:10 +02:00
parent 6fdcb009ff
commit c09413e3b9

View File

@ -34,6 +34,10 @@
#include <base/bit_cast.h>
#include <IO/ReadBufferFromFileDescriptor.h>
#include <IO/WriteBufferFromFileDescriptor.h>
#include <IO/ReadBufferFromFile.h>
#include <IO/WriteBufferFromFile.h>
#include <Compression/CompressedReadBuffer.h>
#include <Compression/CompressedWriteBuffer.h>
#include <memory>
#include <cmath>
#include <unistd.h>
@ -95,6 +99,9 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
extern const int NOT_IMPLEMENTED;
extern const int CANNOT_SEEK_THROUGH_FILE;
extern const int UNKNOWN_FORMAT_VERSION;
extern const int INCORRECT_NUMBER_OF_COLUMNS;
extern const int TYPE_MISMATCH;
}
@ -115,6 +122,12 @@ public:
/// Deterministically change seed to some other value. This can be used to generate more values than were in source.
virtual void updateSeed() = 0;
/// Save into file. Binary, platform-dependent, version-dependent serialization.
virtual void serialize(WriteBuffer & out) const = 0;
/// Read from file
virtual void deserialize(ReadBuffer & in) = 0;
virtual ~IModel() = default;
};
@ -189,6 +202,8 @@ public:
void train(const IColumn &) override {}
void finalize() override {}
void serialize(WriteBuffer &) const override {}
void deserialize(ReadBuffer &) override {}
ColumnPtr generate(const IColumn & column) override
{
@ -230,6 +245,8 @@ public:
void train(const IColumn &) override {}
void finalize() override {}
void serialize(WriteBuffer &) const override {}
void deserialize(ReadBuffer &) override {}
ColumnPtr generate(const IColumn & column) override
{
@ -279,6 +296,8 @@ public:
void train(const IColumn &) override {}
void finalize() override {}
void serialize(WriteBuffer &) const override {}
void deserialize(ReadBuffer &) override {}
ColumnPtr generate(const IColumn & column) override
{
@ -311,6 +330,8 @@ class IdentityModel : public IModel
public:
void train(const IColumn &) override {}
void finalize() override {}
void serialize(WriteBuffer &) const override {}
void deserialize(ReadBuffer &) override {}
ColumnPtr generate(const IColumn & column) override
{
@ -395,6 +416,8 @@ public:
void train(const IColumn &) override {}
void finalize() override {}
void serialize(WriteBuffer &) const override {}
void deserialize(ReadBuffer &) override {}
ColumnPtr generate(const IColumn & column) override
{
@ -431,6 +454,8 @@ public:
void train(const IColumn &) override {}
void finalize() override {}
void serialize(WriteBuffer &) const override {}
void deserialize(ReadBuffer &) override {}
ColumnPtr generate(const IColumn & column) override
{
@ -469,6 +494,8 @@ public:
void train(const IColumn &) override {}
void finalize() override {}
void serialize(WriteBuffer &) const override {}
void deserialize(ReadBuffer &) override {}
ColumnPtr generate(const IColumn & column) override
{
@ -512,6 +539,26 @@ struct MarkovModelParameters
size_t frequency_add;
double frequency_desaturate;
size_t determinator_sliding_window_size;
void serialize(WriteBuffer & out) const
{
writeBinary(order, out);
writeBinary(frequency_cutoff, out);
writeBinary(num_buckets_cutoff, out);
writeBinary(frequency_add, out);
writeBinary(frequency_desaturate, out);
writeBinary(determinator_sliding_window_size, out);
}
void deserialize(ReadBuffer & in)
{
readBinary(order, in);
readBinary(frequency_cutoff, in);
readBinary(num_buckets_cutoff, in);
readBinary(frequency_add, in);
readBinary(frequency_desaturate, in);
readBinary(determinator_sliding_window_size, in);
}
};
@ -565,6 +612,39 @@ private:
return END;
}
void serialize(WriteBuffer & out) const
{
writeBinary(total, out);
writeBinary(count_end, out);
size_t size = buckets.size();
writeBinary(size, out);
for (const auto & elem : buckets)
{
writeBinary(elem.first, out);
writeBinary(elem.second, out);
}
}
void deserialize(ReadBuffer & in)
{
readBinary(total, in);
readBinary(count_end, in);
size_t size = 0;
readBinary(size, in);
buckets.reserve(size);
for (size_t i = 0; i < size; ++i)
{
Buckets::value_type elem;
readBinary(elem.first, in);
readBinary(elem.second, in);
buckets.emplace(std::move(elem));
}
}
};
using Table = HashMap<NGramHash, Histogram, TrivialHash>;
@ -621,6 +701,37 @@ public:
explicit MarkovModel(MarkovModelParameters params_)
: params(std::move(params_)), code_points(params.order, BEGIN) {}
void serialize(WriteBuffer & out) const
{
params.serialize(out);
size_t size = table.size();
writeBinary(size, out);
for (const auto & elem : table)
{
writeBinary(elem.getKey(), out);
elem.getMapped().serialize(out);
}
}
void deserialize(ReadBuffer & in)
{
params.deserialize(in);
size_t size = 0;
readBinary(size, in);
table.reserve(size);
for (size_t i = 0; i < size; ++i)
{
NGramHash key{};
readBinary(key, in);
Histogram & histogram = table[key];
histogram.deserialize(in);
}
}
void consume(const char * data, size_t size)
{
/// First 'order' number of code points are pre-filled with BEGIN.
@ -878,6 +989,16 @@ public:
{
seed = hash(seed);
}
void serialize(WriteBuffer & out) const override
{
markov_model.serialize(out);
}
void deserialize(ReadBuffer & in) override
{
markov_model.deserialize(in);
}
};
@ -916,6 +1037,16 @@ public:
{
nested_model->updateSeed();
}
void serialize(WriteBuffer & out) const override
{
nested_model->serialize(out);
}
void deserialize(ReadBuffer & in) override
{
nested_model->deserialize(in);
}
};
@ -954,6 +1085,16 @@ public:
{
nested_model->updateSeed();
}
void serialize(WriteBuffer & out) const override
{
nested_model->serialize(out);
}
void deserialize(ReadBuffer & in) override
{
nested_model->deserialize(in);
}
};
@ -1046,6 +1187,18 @@ public:
for (auto & model : models)
model->updateSeed();
}
void serialize(WriteBuffer & out) const
{
for (const auto & model : models)
model->serialize(out);
}
void deserialize(ReadBuffer & in)
{
for (auto & model : models)
model->deserialize(in);
}
};
}
@ -1070,6 +1223,8 @@ try
("seed", po::value<std::string>(), "seed (arbitrary string), must be random string with at least 10 bytes length; note that a seed for each column is derived from this seed and a column name: you can obfuscate data for different tables and as long as you use identical seed and identical column names, the data for corresponding non-text columns for different tables will be transformed in the same way, so the data for different tables can be JOINed after obfuscation")
("limit", po::value<UInt64>(), "if specified - stop after generating that number of rows")
("silent", po::value<bool>()->default_value(false), "don't print information messages to stderr")
("save", po::value<std::string>(), "save the models after training to the specified file. You can use --limit 0 to skip the generation step. The file is using binary, platform-dependent, opaque serialization format. The model parameters are saved, while the seed is not.")
("load", po::value<std::string>(), "load the models instead of training from the specified file. The table structure must match the saved file. The seed should be specified separately, while other model parameters are loaded.")
("order", po::value<UInt64>()->default_value(5), "order of markov model to generate strings")
("frequency-cutoff", po::value<UInt64>()->default_value(5), "frequency cutoff for markov model: remove all buckets with count less than specified")
("num-buckets-cutoff", po::value<UInt64>()->default_value(0), "cutoff for number of different possible continuations for a context: remove all histograms with less than specified number of buckets")
@ -1096,12 +1251,33 @@ try
return 0;
}
if (options.count("save") && options.count("load"))
{
std::cerr << "The options --save and --load cannot be used together.\n";
return 1;
}
if (options.count("load")
&& (options.count("order")
|| options.count("frequency-cutoff")
|| options.count("num-buckets-cutoff")
|| options.count("frequency-add")
|| options.count("frequency-desaturate")
|| options.count("determinator-sliding-window-size")))
{
std::cerr << "Model parameters should not be specified with the --load options, as they will be loaded from the file.\n";
return 1;
}
UInt64 seed = sipHash64(options["seed"].as<std::string>());
std::string structure = options["structure"].as<std::string>();
std::string input_format = options["input-format"].as<std::string>();
std::string output_format = options["output-format"].as<std::string>();
std::string load_from_file = options["load"].as<std::string>();
std::string save_into_file = options["save"].as<std::string>();
UInt64 limit = 0;
if (options.count("limit"))
limit = options["limit"].as<UInt64>();
@ -1117,7 +1293,7 @@ try
markov_model_params.frequency_desaturate = options["frequency-desaturate"].as<double>();
markov_model_params.determinator_sliding_window_size = options["determinator-sliding-window-size"].as<UInt64>();
// Create header block
/// Create the header block
std::vector<std::string> structure_vals;
boost::split(structure_vals, structure, boost::algorithm::is_any_of(" ,"), boost::algorithm::token_compress_on);
@ -1143,6 +1319,7 @@ try
ReadBufferFromFileDescriptor file_in(STDIN_FILENO);
WriteBufferFromFileDescriptor file_out(STDOUT_FILENO);
if (load_from_file.empty())
{
/// stdin must be seekable
auto res = lseek(file_in.getFD(), 0, SEEK_SET);
@ -1156,6 +1333,9 @@ try
/// Train step
UInt64 source_rows = 0;
bool rewind_needed = false;
if (load_from_file.empty())
{
if (!silent)
std::cerr << "Training models\n";
@ -1173,9 +1353,63 @@ try
if (!silent)
std::cerr << "Processed " << source_rows << " rows\n";
}
obfuscator.finalize();
rewind_needed = true;
}
else
{
ReadBufferFromFile model_file_in(load_from_file);
CompressedReadBuffer model_in(model_file_in);
UInt8 version = 0;
readBinary(version, model_in);
if (version != 0)
throw Exception("Unknown version of the model file", ErrorCodes::UNKNOWN_FORMAT_VERSION);
readBinary(source_rows, model_in);
Names data_types = header.getDataTypeNames();
size_t header_size = 0;
readBinary(header_size, model_in);
if (header_size != data_types.size())
throw Exception("The saved model was created for different number of columns", ErrorCodes::INCORRECT_NUMBER_OF_COLUMNS);
for (size_t i = 0; i < header_size; ++i)
{
String type;
readBinary(type, model_in);
if (type != data_types[i])
throw Exception("The saved model was created for different types of columns", ErrorCodes::TYPE_MISMATCH);
}
obfuscator.deserialize(model_in);
}
obfuscator.finalize();
if (!save_into_file.empty())
{
WriteBufferFromFile model_file_out(save_into_file);
CompressedWriteBuffer model_out(model_file_out, CompressionCodecFactory::instance().get("ZSTD", 1));
/// You can change version on format change, it is currently set to zero.
UInt8 version = 0;
writeBinary(version, model_out);
writeBinary(source_rows, model_out);
/// We are writing the data types for validation, because the models serialization depends on the data types.
Names data_types = header.getDataTypeNames();
size_t header_size = data_types.size();
writeBinary(header_size, model_out);
for (const auto & type : data_types)
writeBinary(type, model_out);
/// Write the models.
obfuscator.serialize(model_out);
model_out.finalize();
model_file_out.finalize();
}
if (!limit)
limit = source_rows;
@ -1187,7 +1421,8 @@ try
if (!silent)
std::cerr << "Generating data\n";
file_in.seek(0, SEEK_SET);
if (rewind_needed)
file_in.rewind();
Pipe pipe(context->getInputFormat(input_format, file_in, header, max_block_size));
@ -1220,6 +1455,7 @@ try
out_executor.finish();
obfuscator.updateSeed();
rewind_needed = true;
}
return 0;