def trainer(model_params):
    """Train a sketch-rnn model.

    Logs the hyperparameters, loads the datasets, builds the training and
    evaluation models, optionally restores a checkpoint, writes the model
    configuration to ``model_config.json`` under ``FLAGS.log_root`` and then
    hands control to ``train``.

    Args:
        model_params: Hyperparameter object whose ``values()`` method returns
            a mapping of names to values; it is logged pair by pair and passed
            to ``load_dataset``.
    """
    np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)

    tf.logging.info('sketch-rnn')
    tf.logging.info('Hyperparams:')
    # Use items() rather than iteritems(): the latter exists only on Python 2
    # dicts, while items() works on both Python 2 and Python 3.
    for key, val in model_params.values().items():
        tf.logging.info('%s = %s', key, str(val))

    tf.logging.info('Loading data files.')
    datasets = load_dataset(FLAGS.data_dir, model_params)
    train_set = datasets[0]
    valid_set = datasets[1]
    test_set = datasets[2]
    # From here on, model_params refers to the hyperparameters returned by
    # load_dataset, not the ones passed in by the caller.
    model_params = datasets[3]
    eval_model_params = datasets[4]

    reset_graph()
    model = sketch_rnn_model.Model(model_params)
    # The evaluation model is constructed with reuse=True after the training
    # model has been built.
    eval_model = sketch_rnn_model.Model(eval_model_params, reuse=True)

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())

    if FLAGS.resume_training:
        load_checkpoint(sess, FLAGS.log_root)

    # Write config file to json file.
    tf.gfile.MakeDirs(FLAGS.log_root)
    config_path = os.path.join(FLAGS.log_root, 'model_config.json')
    with tf.gfile.Open(config_path, 'w') as f:
        json.dump(model_params.values(), f, indent=True)

    train(sess, model, eval_model, train_set, valid_set, test_set)
sketch_rnn_class.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录