def step(self, data, update_model=True, align=False, use_sgd=False, **kwargs):
    """Evaluate the cross-entropy loss on one batch, optionally updating.

    Args:
        data: Raw batch; ``self.get_batch`` turns it into
            ``(encoder_inputs, targets, input_length)``.
        update_model: If true, run ``self.dropout_on`` and also fetch the
            cross-entropy update op; otherwise run ``self.dropout_off``
            and fetch only the loss.
        align: If true, also fetch ``self.attention_weights``.
        use_sgd: When updating, use ``self.update_ops.xent[1]`` instead of
            ``self.update_ops.xent[0]`` (presumably an SGD optimizer versus
            the default one -- confirm).
        **kwargs: Accepted but ignored.

    Returns:
        A namedtuple ``output(loss, weights)``. ``weights`` is ``None``
        unless ``align`` was set.
    """
    # Exactly one of the two ops runs (presumably switching dropout on/off).
    dropout_switch = self.dropout_on if update_model else self.dropout_off
    dropout_switch.run()

    batch_inputs, batch_targets, batch_lengths = self.get_batch(data)

    feed = {self.targets: batch_targets}
    for idx in range(len(self.encoders)):
        # Feed each encoder's inputs and input lengths under its own key.
        feed[self.encoder_inputs[idx]] = batch_inputs[idx]
        feed[self.encoder_input_length[idx]] = batch_lengths[idx]

    fetches = {'loss': self.xent_loss}
    if update_model:
        optimizer_slot = 1 if use_sgd else 0
        fetches['update'] = self.update_ops.xent[optimizer_slot]
    if align:
        fetches['weights'] = self.attention_weights

    results = tf.get_default_session().run(fetches, feed)

    # 'weights' is only present in results when align was requested.
    Output = namedtuple('output', 'loss weights')
    return Output(loss=results['loss'], weights=results.get('weights'))
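# (Stray web-page navigation labels from the source this was copied from:
#  "Comment list" / "Table of contents" -- not part of the code.)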