def main(_):
    """Entry point: merge flag values with a JSON config and run the selected task.

    Copies the parsed flag values, overlays the JSON object passed through the
    ``config`` flag, and dispatches on ``config['task']``:

    * ``'search'`` -- runs a ``Hyperband`` hyperparameter search, writing into
      a freshly created, timestamped results directory;
    * ``'train'`` / ``'test'`` -- builds a model with ``make_model`` and calls
      its ``train()`` / ``test()`` method.

    Any other task value prints an error message.

    Args:
        _: Unused positional argument (presumably the argv list supplied by the
            flags/app runner -- confirm against the caller).
    """
    # NOTE(review): ``FLAGS.__flags`` is a private attribute of the flags
    # object (no name mangling applies, as this is not inside a class body).
    config = flags.FLAGS.__flags.copy()
    # Values from the JSON-encoded ``config`` flag override individual flags.
    config.update(json.loads(config['config']))
    del config['config']

    # An empty ``results_dir`` flag means "not given": drop the key instead of
    # passing an empty path along.
    if config['results_dir'] == '':
        del config['results_dir']

    task = config['task']
    if task == 'search':
        # Hyperparameter search cannot be continued, so a new results dir is
        # created, named after the model plus a UTC timestamp.
        # NOTE(review): ``results_dir`` is not defined in this function;
        # presumably a module-level base directory -- confirm.
        config['results_dir'] = os.path.join(
            results_dir, 'hs',
            config['model_name']
            + time.strftime('_%Y-%m-%d_%H-%M-%S', time.gmtime()))
        hb = Hyperband(config)
        hb.run()
    else:
        model = make_model(config)
        if task == 'train':
            model.train()
        elif task == 'test':
            model.test()
        else:
            # Parenthesize so ``%`` formats the full message: ``%`` binds
            # tighter than ``+``, and the second literal has no ``%s``.
            print(('Invalid argument: --task=%s. '
                   'It should be either of {train, test, search}.') % task)