# Mirror of https://github.com/ClickHouse/ClickHouse.git
# (synced 2024-11-17 13:13:36 +00:00)
# 27 lines, 1.0 KiB, Python
import argparse
from model import Model
# Command-line interface for training a Model.
# ArgumentDefaultsHelpFormatter appends each option's default to its --help text.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--n_iter', type=int, default=10000,
                    help='number of iterations')
parser.add_argument('--save_dir', type=str, default='save',
                    help='dir for saving weights')
# No default: checked explicitly in the __main__ block below.
parser.add_argument('--data_path', type=str,
                    help='path to train data')
# The learning rate is fractional (default 0.0001). It was previously declared
# type=int, which made any value given on the command line, such as
# "--learning_rate 0.001", fail to parse.
parser.add_argument('--learning_rate', type=float, default=0.0001,
                    help='learning rate')
parser.add_argument('--batch_size', type=int, default=64,
                    help='batch size')
# Optional; passed through to Model.train unchanged (None when omitted).
parser.add_argument('--restore_from', type=str,
                    help='path to train saved weights')

# NOTE(review): arguments are parsed at import time, so importing this module
# also consumes sys.argv. Kept at module level to preserve the existing
# module-level `args` name.
args = parser.parse_args()

if __name__ == '__main__':
    if not args.data_path:
        raise Exception('please specify path to train data with --data_path')

    # Construct the model with the learning rate, then run training with the
    # remaining CLI options.
    gen = Model(args.learning_rate)
    gen.train(args.data_path, args.save_dir, args.n_iter, args.batch_size, args.restore_from)