def step(self, sess, batch, eval_op=None):
    """
    Execute the graph once on a single batch.

    Pass ``self.update_op`` as ``eval_op`` to perform an SGD training step,
    or leave it as ``None`` to only evaluate the loss.

    :param sess: a tensorflow session
    :param batch: a Batch object, converted into the feed dict by ``self.feed``
    :param eval_op: an optional tensorflow operator, forwarded to ``self.fetch``
    :return: the values evaluated by ``sess.run()`` for the fetches built by ``self.fetch``
    """
    feeds = self.feed(batch)
    fetches = self.fetch(eval_op)
    cfg = self.config
    results = sess.run(
        fetches,
        feeds,
        options=cfg.run_options,
        run_metadata=cfg.run_metadata,
    )

    # Guard clause: timeline tracing is opt-in via the config.
    if not cfg.time_trace:
        return results

    # Dump this step's timeline as a Chrome trace. This is very slow and needs
    # a lot of memory; the file is opened in 'w' mode, so each traced step
    # replaces the previous trace.
    chrome_trace = timeline.Timeline(cfg.run_metadata.step_stats).generate_chrome_trace_format()
    trace_path = cfg.trace_filename
    with open(trace_path, 'w') as trace_file:
        trace_file.write(chrome_trace)
    print("time tracing output to " + trace_path)
    return results
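# 评论列表 ("comment list") — stray web-page navigation text, not part of the code
# 文章目录 ("table of contents") — stray web-page navigation text, not part of the code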