srl.py source code (Python)

Project: tag_srl | Author: danfriedman0
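def run_training_batch(self, session, batch):
    """
    A batch contains input tensors for words, pos, lemmas, preds,
      preds_idx, and labels (in that order).
    Runs one training step on the batch: train_op is always fetched,
      so the model parameters are updated.
    Returns the loss on this batch.
    """
    feed_dict = self.batch_to_feed(batch)
    # Set the dropout flag to 1.0 so dropout is active during training.
    feed_dict[self.use_dropout_placeholder] = 1.0
    # session.run returns results in the order of `fetches`; fetching
    # train_op yields None, so only the loss value is kept.
    fetches = [self.loss, self.train_op]

    # To profile this step: uncomment the two lines below, use the
    # commented-out session.run call instead of the active one, and
    # uncomment the timeline block. This needs `import tensorflow as tf`
    # and `from tensorflow.python.client import timeline` (TF 1.x API).
    # options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    # run_metadata = tf.RunMetadata()

    loss, _ = session.run(fetches, feed_dict=feed_dict)
    # loss, _ = session.run(fetches,
    #                       feed_dict=feed_dict,
    #                       options=options,
    #                       run_metadata=run_metadata)

    # Dump the collected step stats as a Chrome trace (view in chrome://tracing).
    # fetched_timeline = timeline.Timeline(run_metadata.step_stats)
    # chrome_trace = fetched_timeline.generate_chrome_trace_format()
    # with open('timeline.json', 'w') as f:
    #     f.write(chrome_trace)

    return loss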
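The method's contract (six positional tensors in, one loss value out, with train_op fetched only for its side effect) can be sketched without TensorFlow. This is a minimal pure-Python sketch under stated assumptions: `batch_to_feed`, `StubSession`, `train_epoch`, the `*_ph` placeholder keys, and the toy loss are all hypothetical stand-ins, not code from tag_srl (the real `batch_to_feed` is not shown in this file). What it does mirror is that `session.run` returns fetched values in the same order as the fetch list, and a fetched op yields None, which is why the method can write `loss, _ = ...`.

```python
from typing import Any, Callable, Dict, List, Sequence

# Assumed field order, taken from the run_training_batch docstring.
BATCH_FIELDS = ("words", "pos", "lemmas", "preds", "preds_idx", "labels")

Feed = Dict[str, Any]
Fetch = Callable[[Feed], Any]


def batch_to_feed(batch: Sequence[Any], placeholders: Dict[str, str]) -> Feed:
    """Hypothetical helper: key each positional batch tensor by its placeholder."""
    if len(batch) != len(BATCH_FIELDS):
        raise ValueError(f"expected {len(BATCH_FIELDS)} tensors, got {len(batch)}")
    return {placeholders[name]: value for name, value in zip(BATCH_FIELDS, batch)}


class StubSession:
    """Stand-in for a TensorFlow session (not the real API).

    Mirrors one property of Session.run: results come back in the same
    order as `fetches`, and an op run only for its side effect yields None.
    """

    def run(self, fetches: List[Fetch], feed_dict: Feed) -> List[Any]:
        return [fetch(feed_dict) for fetch in fetches]


def run_training_batch(session: StubSession, batch: Sequence[Any],
                       placeholders: Dict[str, str], loss: Fetch,
                       train_op: Fetch, dropout_key: str) -> Any:
    """One training step, following the structure of the original method."""
    feed_dict = batch_to_feed(batch, placeholders)
    feed_dict[dropout_key] = 1.0  # dropout flag on for training steps
    loss_value, _ = session.run([loss, train_op], feed_dict=feed_dict)
    return loss_value


def train_epoch(session, batches, placeholders, loss, train_op, dropout_key) -> float:
    """Hypothetical caller: run every batch once and return the mean loss."""
    losses = [run_training_batch(session, b, placeholders, loss, train_op, dropout_key)
              for b in batches]
    return sum(losses) / len(losses)


# Toy usage: the "loss" is just the mean label value, and the "train op"
# records the dropout flag it was fed, then returns None like a real op.
placeholders = {name: f"{name}_ph" for name in BATCH_FIELDS}
dropout_flags: List[float] = []


def toy_loss(feed: Feed) -> float:
    labels = feed["labels_ph"]
    return sum(labels) / len(labels)


def toy_train_op(feed: Feed) -> None:
    dropout_flags.append(feed["use_dropout_ph"])


batches = [
    ([1, 2], [3, 4], [5, 6], [0, 1], [1, 0], [2, 4]),  # mean label 3.0
    ([7, 8], [3, 3], [5, 5], [1, 0], [0, 1], [1, 1]),  # mean label 1.0
]
print(train_epoch(StubSession(), batches, placeholders, toy_loss, toy_train_op,
                  "use_dropout_ph"))  # prints 2.0
print(dropout_flags)  # prints [1.0, 1.0]
```

In real TensorFlow 1.x code the stub would be a `tf.Session` (or `tf.compat.v1.Session` under TF 2), and the callables would be graph tensors and ops; the unpacking pattern stays the same.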