def _flat_reconstruction_loss(self, flat_x_target, flat_rnn_output):
    """Computes the reconstruction loss summed over several categorical outputs.

    The last axis of both inputs is treated as a concatenation of separate
    outputs whose sizes are given by `self._output_depths`. Each matching pair
    of slices is scored by the parent class's `_flat_reconstruction_loss`, and
    the per-output results are combined.

    Args:
      flat_x_target: Target tensor. Its last axis is split into chunks of the
        sizes in `self._output_depths`; `tf.split` requires those sizes to sum
        to the length of the last axis.
      flat_rnn_output: RNN output tensor, split along its last axis the same
        way as `flat_x_target`.

    Returns:
      A 4-tuple `(loss, metric_map, truths, predictions)`:
        loss: Element-wise sum (over outputs) of the per-output losses.
        metric_map: Every metric reported by the parent for each output, with
          its key suffixed by `_<output index>` so outputs do not collide.
        truths: Per-output truths stacked along a new leading axis.
        predictions: Per-output predictions stacked along a new leading axis.
    """
    split_x_target = tf.split(flat_x_target, self._output_depths, axis=-1)
    split_rnn_output = tf.split(flat_rnn_output, self._output_depths, axis=-1)

    losses = []
    truths = []
    predictions = []
    metric_map = {}
    for i, (x_target, rnn_output) in enumerate(
            zip(split_x_target, split_rnn_output)):
        # Score this single output with the parent class's implementation,
        # which returns (loss, metrics dict, truth, predictions).
        loss, metrics, truth, prediction = (
            super(MultiOutCategoricalLstmDecoder, self)
            ._flat_reconstruction_loss(x_target, rnn_output))
        losses.append(loss)
        truths.append(truth)
        predictions.append(prediction)
        for name, value in metrics.items():
            metric_map['%s_%d' % (name, i)] = value

    # reduce_sum packs the list of losses into one tensor before summing over
    # axis 0, so the per-output losses are assumed to share a shape
    # (NOTE(review): confirm against the parent implementation).
    return (tf.reduce_sum(losses, axis=0),
            metric_map,
            tf.stack(truths),
            tf.stack(predictions))
# Scraped web-page residue (not code): "评论列表" = "Comment list",
# "文章目录" = "Table of contents".