def run_training_batch(self, session, batch, options=None, run_metadata=None):
    """Run one training step on ``batch`` and return the loss.

    The batch is expected to hold input tensors for words, pos, lemmas,
    preds, preds_idx, and labels (in that order); it is converted to a
    feed dict by ``self.batch_to_feed``. ``self.train_op`` is always
    fetched together with ``self.loss``, so every call performs an update.

    Args:
        session: Session whose ``run(fetches, feed_dict=...)`` method
            evaluates the graph.
        batch: Batch passed to ``self.batch_to_feed``.
        options: Optional run options (e.g. a ``tf.RunOptions`` with
            ``trace_level=FULL_TRACE`` for profiling), forwarded to
            ``session.run`` only when not None.
        run_metadata: Optional ``tf.RunMetadata`` that ``session.run``
            fills with step stats; forwarded only when not None.

    Returns:
        The value fetched for ``self.loss``.
    """
    feed_dict = self.batch_to_feed(batch)
    # NOTE(review): presumably 1.0 means "apply dropout" (training mode);
    # confirm against how use_dropout_placeholder is consumed in the graph.
    feed_dict[self.use_dropout_placeholder] = 1.0
    fetches = [self.loss, self.train_op]
    # Forward profiling arguments only when supplied, so the default call
    # is exactly session.run(fetches, feed_dict=feed_dict).
    run_kwargs = {}
    if options is not None:
        run_kwargs['options'] = options
    if run_metadata is not None:
        run_kwargs['run_metadata'] = run_metadata
    loss, _ = session.run(fetches, feed_dict=feed_dict, **run_kwargs)
    return loss
评论列表
文章目录