trajmodel.py source code

python

Project: RNN-TrajModel  Author: wuhao5688
from tensorflow.python.client import timeline  # only needed when time tracing is enabled


def step(self, sess, batch, eval_op=None):
    """
    Run one step on a batch.
    Either performs an SGD training step (set `eval_op` to `self.update_op`)
    or only evaluates the loss (leave `eval_op` as `None`).
    :param sess: a TensorFlow session
    :param batch: a Batch object
    :param eval_op: a TensorFlow op to run alongside the fetches, or None
    :return: vals: dict containing the values evaluated by `sess.run()`
    """
    feed_dict = self.feed(batch)
    fetch_dict = self.fetch(eval_op)
    # run the session; `run_metadata` collects the step statistics read by the timeline below
    vals = sess.run(fetch_dict, feed_dict,
                    options=self.config.run_options,
                    run_metadata=self.config.run_metadata)
    # trace time consumption (very slow and memory-hungry; enable only for profiling)
    if self.config.time_trace:
        tl = timeline.Timeline(self.config.run_metadata.step_stats)
        ctf = tl.generate_chrome_trace_format()
        with open(self.config.trace_filename, 'w') as f:
            f.write(ctf)
        print("time tracing output to " + self.config.trace_filename)
    return vals
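The listing does not show `feed` or `fetch`, so the following is only a TensorFlow-free sketch of the train-versus-evaluate dispatch that `step` relies on. It assumes that `fetch(eval_op)` returns a dict holding a `loss` entry and adds `eval_op` only when it is not `None`. `FakeSession`, `ToyModel`, and the one-weight least-squares "model" are hypothetical stand-ins, not part of RNN-TrajModel.

```python
# Hypothetical, TensorFlow-free sketch of the feed/fetch pattern used by `step`.
from typing import Any, Callable, Dict, List, Optional, Tuple

Feed = Dict[str, Any]
Op = Callable[[Feed], Any]


class FakeSession:
    """Stand-in for tf.Session: evaluates each fetch callable against the feed dict."""

    def run(self, fetches: Dict[str, Op], feed_dict: Feed) -> Dict[str, Any]:
        # Like tf.Session.run with a dict of fetches, return a dict with the same keys.
        # Dicts keep insertion order, so "loss" is computed before any update op.
        return {name: op(feed_dict) for name, op in fetches.items()}


class ToyModel:
    """Hypothetical one-weight model y ~ weight * x trained with plain SGD."""

    def __init__(self, lr: float = 0.1) -> None:
        self.weight = 1.0
        self.lr = lr
        self.update_op: Op = self._sgd_update  # plays the role of `self.update_op`

    def _loss(self, feed: Feed) -> float:
        # Mean squared error over the batch.
        xs, ys = feed["x"], feed["y"]
        return sum((self.weight * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

    def _sgd_update(self, feed: Feed) -> None:
        # Gradient of the MSE with respect to `weight`, then one SGD step.
        xs, ys = feed["x"], feed["y"]
        grad = sum(2 * (self.weight * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        self.weight -= self.lr * grad

    def feed(self, batch: Tuple[List[float], List[float]]) -> Feed:
        return {"x": batch[0], "y": batch[1]}

    def fetch(self, eval_op: Optional[Op] = None) -> Dict[str, Op]:
        fetches: Dict[str, Op] = {"loss": self._loss}
        if eval_op is not None:  # training: also run the update op
            fetches["eval_op"] = eval_op
        return fetches

    def step(self, sess: FakeSession, batch, eval_op: Optional[Op] = None) -> Dict[str, Any]:
        return sess.run(self.fetch(eval_op), self.feed(batch))
```

Calling `model.step(sess, batch)` only evaluates the loss and leaves `weight` unchanged, while `model.step(sess, batch, eval_op=model.update_op)` also applies one SGD update. Note that real TensorFlow does not guarantee an evaluation order between independent fetches; the sketch computes the loss first only because of dict ordering.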