ClickHouse/dbms/scripts/gen_benchmark_data/train.py
2018-05-29 19:14:03 +03:00

27 lines
1.0 KiB
Python

import argparse
from model import Model
# Command-line interface of the training script.
# ArgumentDefaultsHelpFormatter makes `--help` print every option's default.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--n_iter', type=int, default=10000,
                    help='number of iterations')
parser.add_argument('--save_dir', type=str, default='save',
                    help='dir for saving weights')
# No default here: a missing --data_path is rejected explicitly in the
# __main__ section below.
parser.add_argument('--data_path', type=str,
                    help='path to train data')
# The learning rate must be parsed as a float. With type=int argparse
# rejects any fractional value passed on the command line (e.g. 0.001),
# which left the float default as the only usable setting.
parser.add_argument('--learning_rate', type=float, default=0.0001,
                    help='learning rate')
parser.add_argument('--batch_size', type=int, default=64,
                    help='batch size')
# Optional; stays None when not given and is forwarded to train() as-is.
parser.add_argument('--restore_from', type=str,
                    help='path to train saved weights')

if __name__ == '__main__':
    # Parse only when executed as a script, so that importing this module
    # never consumes (or exits on) the importing process's sys.argv.
    args = parser.parse_args()
    if not args.data_path:
        raise Exception('please specify path to train data with --data_path')
    # Model comes from the sibling `model` module; it is constructed with the
    # learning rate and then driven with the remaining command-line options.
    gen = Model(args.learning_rate)
    gen.train(args.data_path, args.save_dir, args.n_iter, args.batch_size,
              args.restore_from)