seq2seq_ops.py source file

python

Project: lsdc  Author: febert
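# Imports assumed from the pre-1.0 TensorFlow module layout this snippet targets.
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops as array_ops_
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn


def sequence_classifier(decoding, labels, sampling_decoding=None, name=None):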
  """Returns predictions and loss for sequence of predictions.

  Args:
    decoding: List of Tensors with predictions.
    labels: List of Tensors with labels.
    sampling_decoding: Optional, List of Tensors with predictions to be used
      in sampling. E.g. they shouldn't have a dependency on outputs.
      If not provided, decoding is used.
    name: Operation name.

  Returns:
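    A predictions tensor (per-step softmax outputs stacked along axis 1) and
    a scalar loss tensor (cross-entropy summed over steps and examples).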
  """
  with ops.name_scope(name, "sequence_classifier", [decoding, labels]):
    predictions, xent_list = [], []
    for i, pred in enumerate(decoding):
      xent_list.append(nn.softmax_cross_entropy_with_logits(
          pred, labels[i],
          name="sequence_loss/xent_raw{0}".format(i)))
      if sampling_decoding:
        predictions.append(nn.softmax(sampling_decoding[i]))
      else:
        predictions.append(nn.softmax(pred))
    xent = math_ops.add_n(xent_list, name="sequence_loss/xent")
    loss = math_ops.reduce_sum(xent, name="sequence_loss")
    return array_ops_.pack(predictions, axis=1), loss
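
The computation above can be sketched in plain NumPy, which may help when checking shapes without an old TensorFlow install. This is only an illustration, assuming each list element is a [batch, num_classes] array of logits and the labels are one-hot; `sequence_classifier_np` and the two helpers are hypothetical names, not part of the original file.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def softmax_cross_entropy(logits, labels):
    """Per-example cross-entropy between labels and softmax(logits)."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -(labels * log_probs).sum(axis=-1)


def sequence_classifier_np(decoding, labels, sampling_decoding=None):
    """NumPy stand-in (hypothetical) mirroring the TensorFlow function above."""
    # Predictions come from sampling_decoding when given, else from decoding.
    source = sampling_decoding if sampling_decoding else decoding
    predictions = [softmax(step) for step in source]
    # Cross-entropy per step, added element-wise across steps (like add_n).
    xent = np.sum(
        [softmax_cross_entropy(p, l) for p, l in zip(decoding, labels)], axis=0)
    # Sum over the batch to get a scalar loss (like reduce_sum).
    loss = xent.sum()
    # Stack per-step predictions along axis 1: [batch, steps, num_classes].
    return np.stack(predictions, axis=1), loss


rng = np.random.default_rng(0)
decoding = [rng.normal(size=(2, 4)) for _ in range(3)]
labels = [np.eye(4)[rng.integers(0, 4, size=2)] for _ in range(3)]
predictions, loss = sequence_classifier_np(decoding, labels)
print(predictions.shape)  # (2, 3, 4)
```

Note that the loss is always computed from `decoding`; `sampling_decoding` only changes which logits feed the returned predictions.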