Data obfuscator: development [#CLICKHOUSE-2]

This commit is contained in:
Alexey Milovidov 2018-06-15 11:53:06 +03:00
parent 43a98634a9
commit bd5247864b

View File

@ -27,6 +27,7 @@
#include <IO/WriteBufferFromFileDescriptor.h>
#include <ext/bit_cast.h>
#include <memory>
#include <cmath>
#include <boost/program_options/options_description.hpp>
#include <boost/program_options.hpp>
#include <boost/algorithm/string.hpp>
@ -357,16 +358,19 @@ private:
using CodePoint = UInt32;
using NGramHash = UInt32;
struct HistogramElement
struct Bucket
{
CodePoint code;
UInt64 count;
Bucket(CodePoint code) : code(code), count(1) {}
};
struct Histogram
{
UInt32 total = 0;
std::vector<HistogramElement> data;
UInt64 total = 0; /// Not including count_end.
UInt64 count_end = 0;
std::vector<Bucket> data;
void add(CodePoint code)
{
@ -381,12 +385,21 @@ private:
}
}
data.emplace_back(HistogramElement{.code = code, .count = 1});
data.emplace_back(code);
}
UInt8 sample(UInt64 random) const
void addEnd()
{
random %= total;
++count_end;
}
CodePoint sample(UInt64 random, double end_multiplier) const
{
UInt64 range = total + UInt64(count_end * end_multiplier);
if (range == 0)
return END;
random %= range;
UInt64 sum = 0;
for (const auto & elem : data)
@ -396,7 +409,7 @@ private:
return elem.code;
}
__builtin_unreachable();
return END;
}
};
@ -404,9 +417,13 @@ private:
Table table;
size_t order;
size_t frequency_cutoff;
std::vector<CodePoint> code_points;
static constexpr CodePoint BEGIN = -1;
static constexpr CodePoint END = -2;
NGramHash hashContext(const CodePoint * begin, const CodePoint * end) const
{
@ -443,7 +460,8 @@ private:
}
public:
explicit MarkovModel(size_t order) : order(order), code_points(order, -1) {}
/// Stores the model parameters and pre-fills the context buffer with `order` BEGIN markers.
explicit MarkovModel(size_t order, size_t frequency_cutoff)
    : order(order)
    , frequency_cutoff(frequency_cutoff)
    , code_points(order, BEGIN)
{
}
void consume(const char * data, size_t size)
{
@ -452,29 +470,80 @@ public:
const char * pos = data;
const char * end = data + size;
while (pos < end)
while (true)
{
code_points.push_back(readCodePoint(pos, end));
bool inside = pos < end;
CodePoint next_code_point;
if (inside)
next_code_point = readCodePoint(pos, end);
for (size_t context_size = 0; context_size < order; ++context_size)
table[hashContext(&code_points.back() - context_size, &code_points.back())].add(code_points.back());
{
NGramHash context_hash = hashContext(code_points.data() + code_points.size() - context_size, code_points.data() + code_points.size());
if (inside)
table[context_hash].add(next_code_point);
else /// if (context_size != 0 || order == 0) /// Don't allow to break string without context (except order-0 model).
table[context_hash].addEnd();
}
if (inside)
code_points.push_back(next_code_point);
else
break;
}
}
/** Drops rarely seen statistics from the model.
  *
  * Does nothing when frequency_cutoff is zero. Otherwise, for every histogram in the table:
  *  - if total + count_end is below the cutoff, all buckets are discarded and total is reset to zero;
  *  - else every bucket whose count is below the cutoff is removed and total is decreased
  *    by the sum of the removed counts.
  */
void finalize()
{
    if (frequency_cutoff == 0)
        return;

    for (auto & elem : table)
    {
        Histogram & histogram = elem.second;

        if (histogram.total + histogram.count_end < frequency_cutoff)
        {
            /// NOTE(review): count_end is left untouched for a wiped histogram - confirm that is intended.
            histogram.data.clear();
            histogram.total = 0;
            continue;
        }

        const auto is_rare = [cutoff = frequency_cutoff](const Bucket & bucket) { return bucket.count < cutoff; };

        /// Sum the counts of rare buckets BEFORE compacting the vector: std::remove_if leaves
        /// the elements past the returned iterator in an unspecified state, so the tail
        /// cannot be relied upon to contain the removed buckets.
        UInt64 erased_count = 0;
        for (const auto & bucket : histogram.data)
            if (is_rare(bucket))
                erased_count += bucket.count;

        /// remove_if keeps the relative order of the surviving buckets.
        histogram.data.erase(
            std::remove_if(histogram.data.begin(), histogram.data.end(), is_rare),
            histogram.data.end());

        histogram.total -= erased_count;
    }
}
size_t generate(char * data, size_t size,
size_t generate(char * data, size_t desired_size, size_t buffer_size,
UInt64 seed, const char * determinator_data, size_t determinator_size)
{
code_points.resize(order);
char * pos = data;
char * end = data + size;
char * end = data + buffer_size;
while (pos < end)
{
@ -484,7 +553,7 @@ public:
while (true)
{
it = table.find(hashContext(code_points.data() + code_points.size() - context_size, code_points.data() + code_points.size()));
if (table.end() != it)
if (table.end() != it && it->second.total + it->second.count_end != 0)
break;
if (context_size == 0)
@ -509,11 +578,21 @@ public:
hash.update(determinator_sliding_window_overflow);
UInt64 determinator = hash.get64();
CodePoint code = it->second.sample(determinator);
code_points.push_back(code);
/// If string is greater than desired_size, increase probability of end.
/// Until the desired size is exceeded the multiplier stays zero, so the end marker gets no extra weight.
double end_probability_multiplier = 0;
Int64 num_bytes_after_desired_size = static_cast<Int64>(pos - data) - static_cast<Int64>(desired_size);
/// Must be a strict "> 0" check: a negative difference is also non-zero and would
/// otherwise yield a fractional, non-zero multiplier before the desired size is reached.
if (num_bytes_after_desired_size > 0)
    end_probability_multiplier = std::pow(1.25, num_bytes_after_desired_size);
CodePoint code = it->second.sample(determinator, end_probability_multiplier);
if (code == END)
break;
if (!writeCodePoint(code, pos, end))
break;
code_points.push_back(code);
}
return pos - data;
@ -531,10 +610,10 @@ class StringModel : public IModel
{
private:
UInt64 seed;
MarkovModel markov_model{3};
MarkovModel markov_model;
public:
StringModel(UInt64 seed) : seed(seed) {}
StringModel(UInt64 seed, UInt8 order, UInt64 frequency_cutoff) : seed(seed), markov_model(order, frequency_cutoff) {}
void train(const IColumn & column) override
{
@ -543,7 +622,6 @@ public:
for (size_t i = 0; i < size; ++i)
{
std::cerr << i << "\n";
StringRef string = column_string.getDataAt(i);
markov_model.consume(string.data, string.size);
}
@ -551,7 +629,7 @@ public:
/// Finishes training by delegating to the Markov model's own finalize().
void finalize() override
{
/// The low-frequency cutoff is applied inside MarkovModel::finalize().
markov_model.finalize();
}
ColumnPtr generate(const IColumn & column) override
@ -567,11 +645,11 @@ public:
{
StringRef src_string = column_string.getDataAt(i);
size_t desired_string_size = transform(src_string.size, seed);
new_string.resize(desired_string_size);
new_string.resize(desired_string_size * 2);
size_t actual_size = 0;
if (desired_string_size != 0)
actual_size = markov_model.generate(new_string.data(), desired_string_size, seed, src_string.data, src_string.size);
actual_size = markov_model.generate(new_string.data(), desired_string_size, new_string.size(), seed, src_string.data, src_string.size);
res_column->insertData(new_string.data(), actual_size);
}
@ -650,7 +728,7 @@ public:
class ModelFactory
{
public:
ModelPtr get(const IDataType & data_type, UInt64 seed) const
ModelPtr get(const IDataType & data_type, UInt64 seed, UInt8 markov_model_order, UInt64 frequency_cutoff) const
{
if (data_type.isInteger())
{
@ -673,16 +751,16 @@ public:
return std::make_unique<DateTimeModel>(seed);
if (typeid_cast<const DataTypeString *>(&data_type))
return std::make_unique<StringModel>(seed);
return std::make_unique<StringModel>(seed, markov_model_order, frequency_cutoff);
if (typeid_cast<const DataTypeFixedString *>(&data_type))
return std::make_unique<FixedStringModel>(seed);
if (auto type = typeid_cast<const DataTypeArray *>(&data_type))
return std::make_unique<ArrayModel>(get(*type->getNestedType(), seed));
return std::make_unique<ArrayModel>(get(*type->getNestedType(), seed, markov_model_order, frequency_cutoff));
if (auto type = typeid_cast<const DataTypeNullable *>(&data_type))
return std::make_unique<NullableModel>(get(*type->getNestedType(), seed));
return std::make_unique<NullableModel>(get(*type->getNestedType(), seed, markov_model_order, frequency_cutoff));
throw Exception("Unsupported data type");
}
@ -695,7 +773,7 @@ private:
std::vector<ModelPtr> models;
public:
Anonymizer(const Block & header, UInt64 seed)
Anonymizer(const Block & header, UInt64 seed, UInt8 markov_model_order, UInt64 frequency_cutoff)
{
ModelFactory factory;
@ -703,7 +781,7 @@ public:
models.reserve(columns);
for (size_t i = 0; i < columns; ++i)
models.emplace_back(factory.get(*header.getByPosition(i).type, hash(seed, i)));
models.emplace_back(factory.get(*header.getByPosition(i).type, hash(seed, i), markov_model_order, frequency_cutoff));
}
void train(const Columns & columns)
@ -745,6 +823,8 @@ try
("input-format", po::value<std::string>(), "input format of the initial table data")
("output-format", po::value<std::string>(), "default output format")
("seed", po::value<std::string>(), "seed (arbitary string), must be random string with at least 10 bytes length")
("order", po::value<UInt64>()->default_value(5), "order of markov model to generate strings")
("cutoff", po::value<UInt64>()->default_value(5), "frequency cutoff for markov model")
;
po::parsed_options parsed = po::command_line_parser(argc, argv).options(description).run();
@ -763,6 +843,9 @@ try
std::string input_format = options["input-format"].as<std::string>();
std::string output_format = options["output-format"].as<std::string>();
UInt64 markov_model_order = options["order"].as<UInt64>();
UInt64 frequency_cutoff = options["cutoff"].as<UInt64>();
// Create header block
std::vector<std::string> structure_vals;
boost::split(structure_vals, structure, boost::algorithm::is_any_of(" ,"), boost::algorithm::token_compress_on);
@ -788,11 +871,12 @@ try
ReadBufferFromFileDescriptor file_in(STDIN_FILENO);
WriteBufferFromFileDescriptor file_out(STDOUT_FILENO);
Anonymizer anonymizer(header, seed);
Anonymizer anonymizer(header, seed, markov_model_order, frequency_cutoff);
size_t max_block_size = 8192;
/// Train step
std::cerr << "Training models\n";
{
BlockInputStreamPtr input = context.getInputFormat(input_format, file_in, header, max_block_size);
@ -805,6 +889,7 @@ try
anonymizer.finalize();
/// Generation step
std::cerr << "Generating data\n";
{
file_in.seek(0);