lstm_models.py source file

python

Project: magenta  Author: tensorflow
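# Method excerpt from MultiOutCategoricalLstmDecoder; the enclosing class is not shown.
import tensorflow as tf

def _flat_reconstruction_loss(self, flat_x_target, flat_rnn_output):
    """Sums the parent class's reconstruction loss over every output."""
    # Split targets and RNN outputs along the last axis, one slice per output.
    # self._output_depths holds the depth (size) of each output.
    split_x_target = tf.split(flat_x_target, self._output_depths, axis=-1)
    split_rnn_output = tf.split(
        flat_rnn_output, self._output_depths, axis=-1)

    losses = []
    truths = []
    predictions = []
    metric_map = {}
    for i in range(len(self._output_depths)):
      # Delegate the loss for output i to the parent class implementation.
      l, m, t, p = (
          super(MultiOutCategoricalLstmDecoder, self)._flat_reconstruction_loss(
              split_x_target[i], split_rnn_output[i]))
      losses.append(l)
      truths.append(t)
      predictions.append(p)
      # Suffix each metric name with the output index so keys do not collide.
      for k, v in m.items():
        metric_map['%s_%d' % (k, i)] = v

    # Sum the losses across outputs; stack truths and predictions along a new
    # leading axis, one entry per output.
    return (tf.reduce_sum(losses, axis=0),
            metric_map,
            tf.stack(truths),
            tf.stack(predictions))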
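
The split-then-sum pattern in the method above can be sketched without TensorFlow. This is a minimal plain-Python illustration, not the Magenta implementation. The helper names `split_by_depths`, `softmax_cross_entropy`, and `multi_out_loss` are hypothetical, and using softmax cross-entropy as the per-output loss is an assumption, since the parent class's loss is not shown in this excerpt.

```python
import math


def split_by_depths(row, depths):
    """Split one flat vector into consecutive chunks of the given sizes."""
    chunks, start = [], 0
    for depth in depths:
        chunks.append(row[start:start + depth])
        start += depth
    return chunks


def softmax_cross_entropy(one_hot, logits):
    """Cross-entropy between a one-hot target and softmax(logits).

    Assumed stand-in for the parent class's per-output loss.
    """
    peak = max(logits)  # subtract the max for numerical stability
    log_z = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return -sum(t * (v - log_z) for t, v in zip(one_hot, logits))


def multi_out_loss(flat_target, flat_logits, depths):
    """Return the summed loss and the list of per-output losses."""
    targets = split_by_depths(flat_target, depths)
    logits = split_by_depths(flat_logits, depths)
    losses = [softmax_cross_entropy(t, l) for t, l in zip(targets, logits)]
    return sum(losses), losses


# Two outputs with depths 2 and 3, packed into one flat vector of length 5.
total, per_output = multi_out_loss(
    flat_target=[1, 0, 0, 0, 1],
    flat_logits=[0.0, 0.0, 0.0, 0.0, 0.0],
    depths=[2, 3])
```

With all-zero logits each output predicts a uniform distribution, so the per-output losses here are log 2 and log 3, and the total is their sum.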