Merge pull request #39541 from ClickHouse/obfuscator-save-load

Add save/load capabilities to Obfuscator

Commit 75d0232265
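This PR splits the obfuscator into a train-once, generate-many workflow: a run with --limit 0 --save model.bin trains the per-column models and writes them to disk, and a later run with --load model.bin skips training entirely (the new test script at the bottom exercises exactly this). The saved file is a small versioned container. Below is a hypothetical standalone sketch of its layout, inferred from the save path in the diff; it uses plain std::ofstream instead of ClickHouse's CompressedWriteBuffer (ZSTD) and a fixed-width length prefix where writeBinary's actual wire encoding may differ.

// Hypothetical, simplified sketch of the model file written by --save:
// a format version byte, the source row count, the column type names
// (stored for validation), then the serialized models.
#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

static void writeString(std::ofstream & out, const std::string & s)
{
    uint64_t size = s.size();  // length-prefixed string (approximation of writeBinary)
    out.write(reinterpret_cast<const char *>(&size), sizeof(size));
    out.write(s.data(), static_cast<std::streamsize>(size));
}

int main()
{
    std::ofstream out("model.bin", std::ios::binary);

    uint8_t version = 0;  // "You can change version on format change"
    out.write(reinterpret_cast<const char *>(&version), sizeof(version));

    uint64_t source_rows = 1000;  // hypothetical number of trained rows
    out.write(reinterpret_cast<const char *>(&source_rows), sizeof(source_rows));

    // Type names are stored because the models' serialization depends on them.
    std::vector<std::string> data_types{"String", "String", "String"};
    uint64_t header_size = data_types.size();
    out.write(reinterpret_cast<const char *>(&header_size), sizeof(header_size));
    for (const auto & type : data_types)
        writeString(out, type);

    // ... followed by each column's serialized model (omitted here).
}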
programs/obfuscator/Obfuscator.cpp

@@ -34,6 +34,10 @@
 #include <base/bit_cast.h>
 #include <IO/ReadBufferFromFileDescriptor.h>
 #include <IO/WriteBufferFromFileDescriptor.h>
+#include <IO/ReadBufferFromFile.h>
+#include <IO/WriteBufferFromFile.h>
+#include <Compression/CompressedReadBuffer.h>
+#include <Compression/CompressedWriteBuffer.h>
 #include <memory>
 #include <cmath>
 #include <unistd.h>
@@ -95,6 +99,9 @@ namespace ErrorCodes
     extern const int LOGICAL_ERROR;
     extern const int NOT_IMPLEMENTED;
     extern const int CANNOT_SEEK_THROUGH_FILE;
+    extern const int UNKNOWN_FORMAT_VERSION;
+    extern const int INCORRECT_NUMBER_OF_COLUMNS;
+    extern const int TYPE_MISMATCH;
 }
@@ -115,6 +122,12 @@ public:
     /// Deterministically change seed to some other value. This can be used to generate more values than were in source.
     virtual void updateSeed() = 0;
 
+    /// Save into file. Binary, platform-dependent, version-dependent serialization.
+    virtual void serialize(WriteBuffer & out) const = 0;
+
+    /// Read from file
+    virtual void deserialize(ReadBuffer & in) = 0;
+
     virtual ~IModel() = default;
 };
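Every concrete model now overrides serialize/deserialize; in the hunks that follow, models with no trained state (the numeric scramblers and IdentityModel) override them with empty bodies. A minimal standalone sketch of this pattern, with hypothetical names and std::ostream/std::istream standing in for WriteBuffer/ReadBuffer:

#include <cstdint>
#include <iostream>
#include <sstream>

// Stand-in for the IModel additions: every model must be able to
// round-trip its trained state through a stream.
struct IModel
{
    virtual void serialize(std::ostream & out) const = 0;
    virtual void deserialize(std::istream & in) = 0;
    virtual ~IModel() = default;
};

// Stateless transforms have nothing to save, so their overrides are empty,
// mirroring the one-line stubs added in the diff.
struct StatelessModel : IModel
{
    void serialize(std::ostream &) const override {}
    void deserialize(std::istream &) override {}
};

// A model with trained state writes it out field by field.
struct CountingModel : IModel
{
    uint64_t rows_seen = 0;

    void serialize(std::ostream & out) const override
    {
        out.write(reinterpret_cast<const char *>(&rows_seen), sizeof(rows_seen));
    }
    void deserialize(std::istream & in) override
    {
        in.read(reinterpret_cast<char *>(&rows_seen), sizeof(rows_seen));
    }
};

int main()
{
    CountingModel m;
    m.rows_seen = 42;

    std::stringstream buf;
    m.serialize(buf);

    CountingModel restored;
    restored.deserialize(buf);
    std::cout << restored.rows_seen << "\n";  // prints 42
}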
@@ -189,6 +202,8 @@
 
     void train(const IColumn &) override {}
     void finalize() override {}
+    void serialize(WriteBuffer &) const override {}
+    void deserialize(ReadBuffer &) override {}
 
     ColumnPtr generate(const IColumn & column) override
     {
@@ -230,6 +245,8 @@
 
     void train(const IColumn &) override {}
     void finalize() override {}
+    void serialize(WriteBuffer &) const override {}
+    void deserialize(ReadBuffer &) override {}
 
     ColumnPtr generate(const IColumn & column) override
     {
@@ -279,6 +296,8 @@
 
     void train(const IColumn &) override {}
     void finalize() override {}
+    void serialize(WriteBuffer &) const override {}
+    void deserialize(ReadBuffer &) override {}
 
     ColumnPtr generate(const IColumn & column) override
     {
@@ -311,6 +330,8 @@ class IdentityModel : public IModel
 public:
     void train(const IColumn &) override {}
     void finalize() override {}
+    void serialize(WriteBuffer &) const override {}
+    void deserialize(ReadBuffer &) override {}
 
     ColumnPtr generate(const IColumn & column) override
     {
@@ -395,6 +416,8 @@
 
     void train(const IColumn &) override {}
     void finalize() override {}
+    void serialize(WriteBuffer &) const override {}
+    void deserialize(ReadBuffer &) override {}
 
     ColumnPtr generate(const IColumn & column) override
     {
@@ -431,6 +454,8 @@
 
     void train(const IColumn &) override {}
     void finalize() override {}
+    void serialize(WriteBuffer &) const override {}
+    void deserialize(ReadBuffer &) override {}
 
     ColumnPtr generate(const IColumn & column) override
     {
@@ -469,6 +494,8 @@
 
     void train(const IColumn &) override {}
     void finalize() override {}
+    void serialize(WriteBuffer &) const override {}
+    void deserialize(ReadBuffer &) override {}
 
     ColumnPtr generate(const IColumn & column) override
     {
@@ -512,6 +539,26 @@ struct MarkovModelParameters
     size_t frequency_add;
     double frequency_desaturate;
     size_t determinator_sliding_window_size;
+
+    void serialize(WriteBuffer & out) const
+    {
+        writeBinary(order, out);
+        writeBinary(frequency_cutoff, out);
+        writeBinary(num_buckets_cutoff, out);
+        writeBinary(frequency_add, out);
+        writeBinary(frequency_desaturate, out);
+        writeBinary(determinator_sliding_window_size, out);
+    }
+
+    void deserialize(ReadBuffer & in)
+    {
+        readBinary(order, in);
+        readBinary(frequency_cutoff, in);
+        readBinary(num_buckets_cutoff, in);
+        readBinary(frequency_add, in);
+        readBinary(frequency_desaturate, in);
+        readBinary(determinator_sliding_window_size, in);
+    }
 };
@@ -565,6 +612,39 @@ private:
 
             return END;
         }
+
+        void serialize(WriteBuffer & out) const
+        {
+            writeBinary(total, out);
+            writeBinary(count_end, out);
+
+            size_t size = buckets.size();
+            writeBinary(size, out);
+
+            for (const auto & elem : buckets)
+            {
+                writeBinary(elem.first, out);
+                writeBinary(elem.second, out);
+            }
+        }
+
+        void deserialize(ReadBuffer & in)
+        {
+            readBinary(total, in);
+            readBinary(count_end, in);
+
+            size_t size = 0;
+            readBinary(size, in);
+
+            buckets.reserve(size);
+            for (size_t i = 0; i < size; ++i)
+            {
+                Buckets::value_type elem;
+                readBinary(elem.first, in);
+                readBinary(elem.second, in);
+                buckets.emplace(std::move(elem));
+            }
+        }
     };
 
     using Table = HashMap<NGramHash, Histogram, TrivialHash>;
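The Histogram above, and the MarkovModel table in the next hunk, both follow the usual size-prefixed pattern for hash maps: write the element count, then each key/value pair; on read, reserve and re-insert. A self-contained sketch of the same pattern (hypothetical, with std::unordered_map and iostreams standing in for HashMap and ReadBuffer/WriteBuffer):

#include <cstdint>
#include <iostream>
#include <sstream>
#include <unordered_map>

using Buckets = std::unordered_map<uint32_t, uint64_t>;

template <typename T>
void writePod(std::ostream & out, const T & value)
{
    out.write(reinterpret_cast<const char *>(&value), sizeof(value));
}

template <typename T>
void readPod(std::istream & in, T & value)
{
    in.read(reinterpret_cast<char *>(&value), sizeof(value));
}

void serialize(std::ostream & out, const Buckets & buckets)
{
    uint64_t size = buckets.size();  // element count first...
    writePod(out, size);
    for (const auto & [key, count] : buckets)  // ...then each key/value pair
    {
        writePod(out, key);
        writePod(out, count);
    }
}

void deserialize(std::istream & in, Buckets & buckets)
{
    uint64_t size = 0;
    readPod(in, size);
    buckets.reserve(size);  // mirrors the reserve(size) in the diff
    for (uint64_t i = 0; i < size; ++i)
    {
        uint32_t key;
        uint64_t count;
        readPod(in, key);
        readPod(in, count);
        buckets.emplace(key, count);
    }
}

int main()
{
    Buckets original{{1, 10}, {2, 20}};
    std::stringstream buf;
    serialize(buf, original);

    Buckets restored;
    deserialize(buf, restored);
    std::cout << restored.at(1) << " " << restored.at(2) << "\n";  // 10 20
}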
@@ -621,6 +701,37 @@ public:
     explicit MarkovModel(MarkovModelParameters params_)
         : params(std::move(params_)), code_points(params.order, BEGIN) {}
 
+    void serialize(WriteBuffer & out) const
+    {
+        params.serialize(out);
+
+        size_t size = table.size();
+        writeBinary(size, out);
+
+        for (const auto & elem : table)
+        {
+            writeBinary(elem.getKey(), out);
+            elem.getMapped().serialize(out);
+        }
+    }
+
+    void deserialize(ReadBuffer & in)
+    {
+        params.deserialize(in);
+
+        size_t size = 0;
+        readBinary(size, in);
+
+        table.reserve(size);
+        for (size_t i = 0; i < size; ++i)
+        {
+            NGramHash key{};
+            readBinary(key, in);
+            Histogram & histogram = table[key];
+            histogram.deserialize(in);
+        }
+    }
+
     void consume(const char * data, size_t size)
     {
         /// First 'order' number of code points are pre-filled with BEGIN.
@@ -655,7 +766,6 @@
         }
     }
 
-
     void finalize()
     {
         if (params.num_buckets_cutoff)
@@ -878,6 +988,16 @@
     {
         seed = hash(seed);
     }
+
+    void serialize(WriteBuffer & out) const override
+    {
+        markov_model.serialize(out);
+    }
+
+    void deserialize(ReadBuffer & in) override
+    {
+        markov_model.deserialize(in);
+    }
 };
 
 
@@ -916,6 +1036,16 @@
     {
         nested_model->updateSeed();
     }
+
+    void serialize(WriteBuffer & out) const override
+    {
+        nested_model->serialize(out);
+    }
+
+    void deserialize(ReadBuffer & in) override
+    {
+        nested_model->deserialize(in);
+    }
 };
 
 
@@ -954,6 +1084,16 @@
     {
         nested_model->updateSeed();
    }
+
+    void serialize(WriteBuffer & out) const override
+    {
+        nested_model->serialize(out);
+    }
+
+    void deserialize(ReadBuffer & in) override
+    {
+        nested_model->deserialize(in);
+    }
 };
@@ -1046,6 +1186,18 @@
         for (auto & model : models)
             model->updateSeed();
     }
+
+    void serialize(WriteBuffer & out) const
+    {
+        for (const auto & model : models)
+            model->serialize(out);
+    }
+
+    void deserialize(ReadBuffer & in)
+    {
+        for (auto & model : models)
+            model->deserialize(in);
+    }
 };
 
 }
@@ -1068,8 +1220,10 @@ try
         ("input-format", po::value<std::string>(), "input format of the initial table data")
         ("output-format", po::value<std::string>(), "default output format")
         ("seed", po::value<std::string>(), "seed (arbitrary string), must be random string with at least 10 bytes length; note that a seed for each column is derived from this seed and a column name: you can obfuscate data for different tables and as long as you use identical seed and identical column names, the data for corresponding non-text columns for different tables will be transformed in the same way, so the data for different tables can be JOINed after obfuscation")
-        ("limit", po::value<UInt64>(), "if specified - stop after generating that number of rows")
+        ("limit", po::value<UInt64>(), "if specified - stop after generating that number of rows; the limit can be also greater than the number of source dataset - in this case it will process the dataset in a loop more than one time, using different seeds on every iteration, generating result as large as needed")
         ("silent", po::value<bool>()->default_value(false), "don't print information messages to stderr")
+        ("save", po::value<std::string>(), "save the models after training to the specified file. You can use --limit 0 to skip the generation step. The file is using binary, platform-dependent, opaque serialization format. The model parameters are saved, while the seed is not.")
+        ("load", po::value<std::string>(), "load the models instead of training from the specified file. The table structure must match the saved file. The seed should be specified separately, while other model parameters are loaded.")
         ("order", po::value<UInt64>()->default_value(5), "order of markov model to generate strings")
         ("frequency-cutoff", po::value<UInt64>()->default_value(5), "frequency cutoff for markov model: remove all buckets with count less than specified")
         ("num-buckets-cutoff", po::value<UInt64>()->default_value(0), "cutoff for number of different possible continuations for a context: remove all histograms with less than specified number of buckets")
@@ -1096,12 +1250,26 @@ try
         return 0;
     }
 
+    if (options.count("save") && options.count("load"))
+    {
+        std::cerr << "The options --save and --load cannot be used together.\n";
+        return 1;
+    }
+
     UInt64 seed = sipHash64(options["seed"].as<std::string>());
 
     std::string structure = options["structure"].as<std::string>();
     std::string input_format = options["input-format"].as<std::string>();
     std::string output_format = options["output-format"].as<std::string>();
 
+    std::string load_from_file;
+    std::string save_into_file;
+
+    if (options.count("load"))
+        load_from_file = options["load"].as<std::string>();
+    else if (options.count("save"))
+        save_into_file = options["save"].as<std::string>();
+
     UInt64 limit = 0;
     if (options.count("limit"))
         limit = options["limit"].as<UInt64>();
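--save and --load are deliberately mutually exclusive: saving re-serializes models, while loading replaces training, so combining the two in one run would at best round-trip the same file. A minimal boost::program_options sketch of the same guard, with a hypothetical, stripped-down option set:

#include <boost/program_options.hpp>
#include <iostream>
#include <string>

namespace po = boost::program_options;

int main(int argc, char ** argv)
{
    po::options_description desc("Options");
    desc.add_options()
        ("save", po::value<std::string>(), "save the models to a file")
        ("load", po::value<std::string>(), "load the models from a file");

    po::variables_map options;
    po::store(po::parse_command_line(argc, argv, desc), options);

    // Mirrors the guard added above: the two modes cannot be combined.
    if (options.count("save") && options.count("load"))
    {
        std::cerr << "The options --save and --load cannot be used together.\n";
        return 1;
    }
    return 0;
}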
@@ -1117,7 +1285,7 @@ try
     markov_model_params.frequency_desaturate = options["frequency-desaturate"].as<double>();
     markov_model_params.determinator_sliding_window_size = options["determinator-sliding-window-size"].as<UInt64>();
 
-    // Create header block
+    /// Create the header block
     std::vector<std::string> structure_vals;
     boost::split(structure_vals, structure, boost::algorithm::is_any_of(" ,"), boost::algorithm::token_compress_on);
@@ -1143,6 +1311,7 @@ try
     ReadBufferFromFileDescriptor file_in(STDIN_FILENO);
     WriteBufferFromFileDescriptor file_out(STDOUT_FILENO);
 
+    if (load_from_file.empty())
     {
         /// stdin must be seekable
         auto res = lseek(file_in.getFD(), 0, SEEK_SET);
@@ -1156,6 +1325,9 @@ try
 
     /// Train step
     UInt64 source_rows = 0;
 
+    bool rewind_needed = false;
+    if (load_from_file.empty())
+    {
         if (!silent)
             std::cerr << "Training models\n";
@@ -1173,11 +1345,71 @@ try
             if (!silent)
                 std::cerr << "Processed " << source_rows << " rows\n";
         }
+
+        obfuscator.finalize();
+        rewind_needed = true;
     }
+    else
+    {
+        if (!silent)
+            std::cerr << "Loading models\n";
+
+        ReadBufferFromFile model_file_in(load_from_file);
+        CompressedReadBuffer model_in(model_file_in);
+
+        UInt8 version = 0;
+        readBinary(version, model_in);
+        if (version != 0)
+            throw Exception("Unknown version of the model file", ErrorCodes::UNKNOWN_FORMAT_VERSION);
+
+        readBinary(source_rows, model_in);
+
+        Names data_types = header.getDataTypeNames();
+        size_t header_size = 0;
+        readBinary(header_size, model_in);
+        if (header_size != data_types.size())
+            throw Exception("The saved model was created for different number of columns", ErrorCodes::INCORRECT_NUMBER_OF_COLUMNS);
+
+        for (size_t i = 0; i < header_size; ++i)
+        {
+            String type;
+            readBinary(type, model_in);
+            if (type != data_types[i])
+                throw Exception("The saved model was created for different types of columns", ErrorCodes::TYPE_MISMATCH);
+        }
+
+        obfuscator.deserialize(model_in);
+    }
 
-    obfuscator.finalize();
+    if (!save_into_file.empty())
+    {
+        if (!silent)
+            std::cerr << "Saving models\n";
 
-    if (!limit)
+        WriteBufferFromFile model_file_out(save_into_file);
+        CompressedWriteBuffer model_out(model_file_out, CompressionCodecFactory::instance().get("ZSTD", 1));
+
+        /// You can change version on format change, it is currently set to zero.
+        UInt8 version = 0;
+        writeBinary(version, model_out);
+
+        writeBinary(source_rows, model_out);
+
+        /// We are writing the data types for validation, because the models serialization depends on the data types.
+        Names data_types = header.getDataTypeNames();
+        size_t header_size = data_types.size();
+        writeBinary(header_size, model_out);
+        for (const auto & type : data_types)
+            writeBinary(type, model_out);
+
+        /// Write the models.
+        obfuscator.serialize(model_out);
+
+        model_out.finalize();
+        model_file_out.finalize();
+    }
+
+    if (!options.count("limit"))
         limit = source_rows;
 
     /// Generation step
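On --load, the code above validates before trusting the file: an exact version match, then the column count, then each type name against the table structure given on the command line. The stored row count is also what lets --load work without an explicit --limit (see if (!options.count("limit")) limit = source_rows; above). A hypothetical reader for the toy layout sketched at the top, with the same caveats (no ZSTD layer, fixed-width length prefixes):

#include <cstdint>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

static std::string readString(std::ifstream & in)
{
    uint64_t size = 0;
    in.read(reinterpret_cast<char *>(&size), sizeof(size));
    std::string s(size, '\0');
    in.read(s.data(), static_cast<std::streamsize>(size));
    return s;
}

int main()
{
    std::ifstream in("model.bin", std::ios::binary);
    std::vector<std::string> expected_types{"String", "String", "String"};

    uint8_t version = 0;  // must match exactly; the format is opaque and versioned
    in.read(reinterpret_cast<char *>(&version), sizeof(version));
    if (version != 0)
        throw std::runtime_error("Unknown version of the model file");

    uint64_t source_rows = 0;  // default row count when --limit is omitted
    in.read(reinterpret_cast<char *>(&source_rows), sizeof(source_rows));

    uint64_t header_size = 0;  // column count, then per-column type names
    in.read(reinterpret_cast<char *>(&header_size), sizeof(header_size));
    if (header_size != expected_types.size())
        throw std::runtime_error("The saved model was created for a different number of columns");

    for (uint64_t i = 0; i < header_size; ++i)
        if (readString(in) != expected_types[i])
            throw std::runtime_error("The saved model was created for different types of columns");

    std::cout << "model trained on " << source_rows << " rows\n";
    // ... the per-column models would be deserialized from here.
}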
@@ -1187,7 +1419,8 @@ try
         if (!silent)
             std::cerr << "Generating data\n";
 
-        file_in.seek(0, SEEK_SET);
+        if (rewind_needed)
+            file_in.rewind();
 
         Pipe pipe(context->getInputFormat(input_format, file_in, header, max_block_size));
 
@@ -1220,6 +1453,7 @@ try
         out_executor.finish();
 
         obfuscator.updateSeed();
+        rewind_needed = true;
     }
 
     return 0;
tests/queries/1_stateful/00096_obfuscator_save_load.reference (new file, 3 lines)

@@ -0,0 +1,3 @@
+403499
+1000 320 171 23
+2500 569 354 13
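(The three reference lines are, in order: the byte size of the saved model file from wc -c; the count()/uniq() statistics of the 1000 source rows; and the same statistics for the 2500 rows generated from the loaded model — more rows than the source, exercising the looped generation described in the --limit help text.)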
tests/queries/1_stateful/00096_obfuscator_save_load.sh (new executable file, 18 lines)
@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+
+CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
+# shellcheck source=../shell_config.sh
+. "$CURDIR"/../shell_config.sh
+
+$CLICKHOUSE_CLIENT --max_threads 1 --query="SELECT URL, Title, SearchPhrase FROM test.hits LIMIT 1000" > "${CLICKHOUSE_TMP}"/data.tsv
+
+$CLICKHOUSE_OBFUSCATOR --structure "URL String, Title String, SearchPhrase String" --input-format TSV --output-format TSV --seed hello --limit 0 --save "${CLICKHOUSE_TMP}"/model.bin < "${CLICKHOUSE_TMP}"/data.tsv 2>/dev/null
+wc -c < "${CLICKHOUSE_TMP}"/model.bin
+$CLICKHOUSE_OBFUSCATOR --structure "URL String, Title String, SearchPhrase String" --input-format TSV --output-format TSV --seed hello --limit 2500 --load "${CLICKHOUSE_TMP}"/model.bin < "${CLICKHOUSE_TMP}"/data.tsv > "${CLICKHOUSE_TMP}"/data2500.tsv 2>/dev/null
+rm "${CLICKHOUSE_TMP}"/model.bin
+
+$CLICKHOUSE_LOCAL --structure "URL String, Title String, SearchPhrase String" --input-format TSV --output-format TSV --query "SELECT count(), uniq(URL), uniq(Title), uniq(SearchPhrase) FROM table" < "${CLICKHOUSE_TMP}"/data.tsv
+$CLICKHOUSE_LOCAL --structure "URL String, Title String, SearchPhrase String" --input-format TSV --output-format TSV --query "SELECT count(), uniq(URL), uniq(Title), uniq(SearchPhrase) FROM table" < "${CLICKHOUSE_TMP}"/data2500.tsv
+
+rm "${CLICKHOUSE_TMP}"/data.tsv
+rm "${CLICKHOUSE_TMP}"/data2500.tsv