Merge pull request #39541 from ClickHouse/obfuscator-save-load

Add save/load capabilities to Obfuscator
This commit is contained in:
Alexey Milovidov 2022-07-25 21:18:28 +03:00 committed by GitHub
commit 75d0232265
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 261 additions and 6 deletions

View File

@ -34,6 +34,10 @@
#include <base/bit_cast.h>
#include <IO/ReadBufferFromFileDescriptor.h>
#include <IO/WriteBufferFromFileDescriptor.h>
#include <IO/ReadBufferFromFile.h>
#include <IO/WriteBufferFromFile.h>
#include <Compression/CompressedReadBuffer.h>
#include <Compression/CompressedWriteBuffer.h>
#include <memory>
#include <cmath>
#include <unistd.h>
@ -95,6 +99,9 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
extern const int NOT_IMPLEMENTED;
extern const int CANNOT_SEEK_THROUGH_FILE;
extern const int UNKNOWN_FORMAT_VERSION;
extern const int INCORRECT_NUMBER_OF_COLUMNS;
extern const int TYPE_MISMATCH;
}
@ -115,6 +122,12 @@ public:
/// Deterministically change seed to some other value. This can be used to generate more values than were in source.
virtual void updateSeed() = 0;
/// Save into file. Binary, platform-dependent, version-dependent serialization.
virtual void serialize(WriteBuffer & out) const = 0;
/// Read from file
virtual void deserialize(ReadBuffer & in) = 0;
virtual ~IModel() = default;
};
@ -189,6 +202,8 @@ public:
void train(const IColumn &) override {}
void finalize() override {}
void serialize(WriteBuffer &) const override {}
void deserialize(ReadBuffer &) override {}
ColumnPtr generate(const IColumn & column) override
{
@ -230,6 +245,8 @@ public:
void train(const IColumn &) override {}
void finalize() override {}
void serialize(WriteBuffer &) const override {}
void deserialize(ReadBuffer &) override {}
ColumnPtr generate(const IColumn & column) override
{
@ -279,6 +296,8 @@ public:
void train(const IColumn &) override {}
void finalize() override {}
void serialize(WriteBuffer &) const override {}
void deserialize(ReadBuffer &) override {}
ColumnPtr generate(const IColumn & column) override
{
@ -311,6 +330,8 @@ class IdentityModel : public IModel
public:
void train(const IColumn &) override {}
void finalize() override {}
void serialize(WriteBuffer &) const override {}
void deserialize(ReadBuffer &) override {}
ColumnPtr generate(const IColumn & column) override
{
@ -395,6 +416,8 @@ public:
void train(const IColumn &) override {}
void finalize() override {}
void serialize(WriteBuffer &) const override {}
void deserialize(ReadBuffer &) override {}
ColumnPtr generate(const IColumn & column) override
{
@ -431,6 +454,8 @@ public:
void train(const IColumn &) override {}
void finalize() override {}
void serialize(WriteBuffer &) const override {}
void deserialize(ReadBuffer &) override {}
ColumnPtr generate(const IColumn & column) override
{
@ -469,6 +494,8 @@ public:
void train(const IColumn &) override {}
void finalize() override {}
void serialize(WriteBuffer &) const override {}
void deserialize(ReadBuffer &) override {}
ColumnPtr generate(const IColumn & column) override
{
@ -512,6 +539,26 @@ struct MarkovModelParameters
size_t frequency_add;
double frequency_desaturate;
size_t determinator_sliding_window_size;
void serialize(WriteBuffer & out) const
{
writeBinary(order, out);
writeBinary(frequency_cutoff, out);
writeBinary(num_buckets_cutoff, out);
writeBinary(frequency_add, out);
writeBinary(frequency_desaturate, out);
writeBinary(determinator_sliding_window_size, out);
}
void deserialize(ReadBuffer & in)
{
readBinary(order, in);
readBinary(frequency_cutoff, in);
readBinary(num_buckets_cutoff, in);
readBinary(frequency_add, in);
readBinary(frequency_desaturate, in);
readBinary(determinator_sliding_window_size, in);
}
};
@ -565,6 +612,39 @@ private:
return END;
}
void serialize(WriteBuffer & out) const
{
writeBinary(total, out);
writeBinary(count_end, out);
size_t size = buckets.size();
writeBinary(size, out);
for (const auto & elem : buckets)
{
writeBinary(elem.first, out);
writeBinary(elem.second, out);
}
}
void deserialize(ReadBuffer & in)
{
readBinary(total, in);
readBinary(count_end, in);
size_t size = 0;
readBinary(size, in);
buckets.reserve(size);
for (size_t i = 0; i < size; ++i)
{
Buckets::value_type elem;
readBinary(elem.first, in);
readBinary(elem.second, in);
buckets.emplace(std::move(elem));
}
}
};
using Table = HashMap<NGramHash, Histogram, TrivialHash>;
@ -621,6 +701,37 @@ public:
explicit MarkovModel(MarkovModelParameters params_)
: params(std::move(params_)), code_points(params.order, BEGIN) {}
void serialize(WriteBuffer & out) const
{
params.serialize(out);
size_t size = table.size();
writeBinary(size, out);
for (const auto & elem : table)
{
writeBinary(elem.getKey(), out);
elem.getMapped().serialize(out);
}
}
void deserialize(ReadBuffer & in)
{
params.deserialize(in);
size_t size = 0;
readBinary(size, in);
table.reserve(size);
for (size_t i = 0; i < size; ++i)
{
NGramHash key{};
readBinary(key, in);
Histogram & histogram = table[key];
histogram.deserialize(in);
}
}
void consume(const char * data, size_t size)
{
/// First 'order' number of code points are pre-filled with BEGIN.
@ -655,7 +766,6 @@ public:
}
}
void finalize()
{
if (params.num_buckets_cutoff)
@ -878,6 +988,16 @@ public:
{
seed = hash(seed);
}
void serialize(WriteBuffer & out) const override
{
markov_model.serialize(out);
}
void deserialize(ReadBuffer & in) override
{
markov_model.deserialize(in);
}
};
@ -916,6 +1036,16 @@ public:
{
nested_model->updateSeed();
}
void serialize(WriteBuffer & out) const override
{
nested_model->serialize(out);
}
void deserialize(ReadBuffer & in) override
{
nested_model->deserialize(in);
}
};
@ -954,6 +1084,16 @@ public:
{
nested_model->updateSeed();
}
void serialize(WriteBuffer & out) const override
{
nested_model->serialize(out);
}
void deserialize(ReadBuffer & in) override
{
nested_model->deserialize(in);
}
};
@ -1046,6 +1186,18 @@ public:
for (auto & model : models)
model->updateSeed();
}
void serialize(WriteBuffer & out) const
{
for (const auto & model : models)
model->serialize(out);
}
void deserialize(ReadBuffer & in)
{
for (auto & model : models)
model->deserialize(in);
}
};
}
@ -1068,8 +1220,10 @@ try
("input-format", po::value<std::string>(), "input format of the initial table data")
("output-format", po::value<std::string>(), "default output format")
("seed", po::value<std::string>(), "seed (arbitrary string), must be random string with at least 10 bytes length; note that a seed for each column is derived from this seed and a column name: you can obfuscate data for different tables and as long as you use identical seed and identical column names, the data for corresponding non-text columns for different tables will be transformed in the same way, so the data for different tables can be JOINed after obfuscation")
("limit", po::value<UInt64>(), "if specified - stop after generating that number of rows")
("limit", po::value<UInt64>(), "if specified - stop after generating that number of rows; the limit can be also greater than the number of source dataset - in this case it will process the dataset in a loop more than one time, using different seeds on every iteration, generating result as large as needed")
("silent", po::value<bool>()->default_value(false), "don't print information messages to stderr")
("save", po::value<std::string>(), "save the models after training to the specified file. You can use --limit 0 to skip the generation step. The file is using binary, platform-dependent, opaque serialization format. The model parameters are saved, while the seed is not.")
("load", po::value<std::string>(), "load the models instead of training from the specified file. The table structure must match the saved file. The seed should be specified separately, while other model parameters are loaded.")
("order", po::value<UInt64>()->default_value(5), "order of markov model to generate strings")
("frequency-cutoff", po::value<UInt64>()->default_value(5), "frequency cutoff for markov model: remove all buckets with count less than specified")
("num-buckets-cutoff", po::value<UInt64>()->default_value(0), "cutoff for number of different possible continuations for a context: remove all histograms with less than specified number of buckets")
@ -1096,12 +1250,26 @@ try
return 0;
}
if (options.count("save") && options.count("load"))
{
std::cerr << "The options --save and --load cannot be used together.\n";
return 1;
}
UInt64 seed = sipHash64(options["seed"].as<std::string>());
std::string structure = options["structure"].as<std::string>();
std::string input_format = options["input-format"].as<std::string>();
std::string output_format = options["output-format"].as<std::string>();
std::string load_from_file;
std::string save_into_file;
if (options.count("load"))
load_from_file = options["load"].as<std::string>();
else if (options.count("save"))
save_into_file = options["save"].as<std::string>();
UInt64 limit = 0;
if (options.count("limit"))
limit = options["limit"].as<UInt64>();
@ -1117,7 +1285,7 @@ try
markov_model_params.frequency_desaturate = options["frequency-desaturate"].as<double>();
markov_model_params.determinator_sliding_window_size = options["determinator-sliding-window-size"].as<UInt64>();
// Create header block
/// Create the header block
std::vector<std::string> structure_vals;
boost::split(structure_vals, structure, boost::algorithm::is_any_of(" ,"), boost::algorithm::token_compress_on);
@ -1143,6 +1311,7 @@ try
ReadBufferFromFileDescriptor file_in(STDIN_FILENO);
WriteBufferFromFileDescriptor file_out(STDOUT_FILENO);
if (load_from_file.empty())
{
/// stdin must be seekable
auto res = lseek(file_in.getFD(), 0, SEEK_SET);
@ -1156,6 +1325,9 @@ try
/// Train step
UInt64 source_rows = 0;
bool rewind_needed = false;
if (load_from_file.empty())
{
if (!silent)
std::cerr << "Training models\n";
@ -1173,11 +1345,71 @@ try
if (!silent)
std::cerr << "Processed " << source_rows << " rows\n";
}
}
obfuscator.finalize();
rewind_needed = true;
}
else
{
if (!silent)
std::cerr << "Loading models\n";
if (!limit)
ReadBufferFromFile model_file_in(load_from_file);
CompressedReadBuffer model_in(model_file_in);
UInt8 version = 0;
readBinary(version, model_in);
if (version != 0)
throw Exception("Unknown version of the model file", ErrorCodes::UNKNOWN_FORMAT_VERSION);
readBinary(source_rows, model_in);
Names data_types = header.getDataTypeNames();
size_t header_size = 0;
readBinary(header_size, model_in);
if (header_size != data_types.size())
throw Exception("The saved model was created for different number of columns", ErrorCodes::INCORRECT_NUMBER_OF_COLUMNS);
for (size_t i = 0; i < header_size; ++i)
{
String type;
readBinary(type, model_in);
if (type != data_types[i])
throw Exception("The saved model was created for different types of columns", ErrorCodes::TYPE_MISMATCH);
}
obfuscator.deserialize(model_in);
}
if (!save_into_file.empty())
{
if (!silent)
std::cerr << "Saving models\n";
WriteBufferFromFile model_file_out(save_into_file);
CompressedWriteBuffer model_out(model_file_out, CompressionCodecFactory::instance().get("ZSTD", 1));
/// You can change version on format change, it is currently set to zero.
UInt8 version = 0;
writeBinary(version, model_out);
writeBinary(source_rows, model_out);
/// We are writing the data types for validation, because the models serialization depends on the data types.
Names data_types = header.getDataTypeNames();
size_t header_size = data_types.size();
writeBinary(header_size, model_out);
for (const auto & type : data_types)
writeBinary(type, model_out);
/// Write the models.
obfuscator.serialize(model_out);
model_out.finalize();
model_file_out.finalize();
}
if (!options.count("limit"))
limit = source_rows;
/// Generation step
@ -1187,7 +1419,8 @@ try
if (!silent)
std::cerr << "Generating data\n";
file_in.seek(0, SEEK_SET);
if (rewind_needed)
file_in.rewind();
Pipe pipe(context->getInputFormat(input_format, file_in, header, max_block_size));
@ -1220,6 +1453,7 @@ try
out_executor.finish();
obfuscator.updateSeed();
rewind_needed = true;
}
return 0;

View File

@ -0,0 +1,3 @@
403499
1000 320 171 23
2500 569 354 13

View File

@ -0,0 +1,18 @@
#!/usr/bin/env bash
CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# shellcheck source=../shell_config.sh
. "$CURDIR"/../shell_config.sh
$CLICKHOUSE_CLIENT --max_threads 1 --query="SELECT URL, Title, SearchPhrase FROM test.hits LIMIT 1000" > "${CLICKHOUSE_TMP}"/data.tsv
$CLICKHOUSE_OBFUSCATOR --structure "URL String, Title String, SearchPhrase String" --input-format TSV --output-format TSV --seed hello --limit 0 --save "${CLICKHOUSE_TMP}"/model.bin < "${CLICKHOUSE_TMP}"/data.tsv 2>/dev/null
wc -c < "${CLICKHOUSE_TMP}"/model.bin
$CLICKHOUSE_OBFUSCATOR --structure "URL String, Title String, SearchPhrase String" --input-format TSV --output-format TSV --seed hello --limit 2500 --load "${CLICKHOUSE_TMP}"/model.bin < "${CLICKHOUSE_TMP}"/data.tsv > "${CLICKHOUSE_TMP}"/data2500.tsv 2>/dev/null
rm "${CLICKHOUSE_TMP}"/model.bin
$CLICKHOUSE_LOCAL --structure "URL String, Title String, SearchPhrase String" --input-format TSV --output-format TSV --query "SELECT count(), uniq(URL), uniq(Title), uniq(SearchPhrase) FROM table" < "${CLICKHOUSE_TMP}"/data.tsv
$CLICKHOUSE_LOCAL --structure "URL String, Title String, SearchPhrase String" --input-format TSV --output-format TSV --query "SELECT count(), uniq(URL), uniq(Title), uniq(SearchPhrase) FROM table" < "${CLICKHOUSE_TMP}"/data2500.tsv
rm "${CLICKHOUSE_TMP}"/data.tsv
rm "${CLICKHOUSE_TMP}"/data2500.tsv