def train(self, nIter, machine=None, summary_op=None):
    """Run the main training loop under a ``tf.train.Supervisor``.

    The Supervisor uses ``self.dirs['logdir']`` as its log/checkpoint
    directory, saves the model every 300 seconds and tracks
    ``self.opt['global_step']``. While the managed session is open,
    ``self._refresh_status(sess)`` is scheduled every 60 seconds via
    ``sv.loop``, and ``self.opt['g']`` is run once per iteration for
    ``self.arch['training']['max_iter']`` iterations, or until the
    Supervisor requests a stop.

    Args:
        nIter: Not used by this method; the iteration count comes from
            ``self.arch['training']['max_iter']``.
            NOTE(review): confirm whether callers expect this to take effect.
        machine: Not used by this method.
        summary_op: Not used by this method (the Supervisor is created with
            its default summary settings).
    """
    sv = tf.train.Supervisor(
        logdir=self.dirs['logdir'],
        save_model_secs=300,
        global_step=self.opt['global_step'])

    # Let ops without a kernel on the requested device fall back to another
    # device, and grow GPU memory on demand instead of reserving it up front.
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        gpu_options=tf.GPUOptions(allow_growth=True))

    with sv.managed_session(config=sess_config) as sess:
        # Supervisor-managed background thread: call
        # self._refresh_status(sess) every 60 seconds while training runs.
        sv.loop(60, self._refresh_status, (sess,))
        for _ in range(self.arch['training']['max_iter']):
            if sv.should_stop():
                break
            # Main training step. NOTE(review): presumably self.opt['g'] is
            # the optimizer/update op — confirm against where it is built.
            sess.run(self.opt['g'])
评论列表
文章目录