losses.py source code


Project: tefla    Author: openAGI
import tensorflow as tf


def cross_entropy_sequence_loss(logits, targets, sequence_length):
    """Calculates the per-example cross-entropy loss for a sequence of logits and
        masks out all losses past the sequence length.

    Args:
        logits: Logits of shape `[T, B, vocab_size]`
        targets: Target classes of shape `[T, B]`
        sequence_length: An int32 tensor of shape `[B]` corresponding
           to the length of each input

    Returns:
        A tensor of shape [T, B] that contains the loss per example, per time step.
    """
    with tf.name_scope("cross_entropy_sequence_loss"):
        losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=targets)
        # Build a [B, T] mask that is True while t < sequence_length[b].
        loss_mask = tf.sequence_mask(tf.to_int32(
            sequence_length), tf.to_int32(tf.shape(targets)[0]))
        # Transpose the mask to [T, B] and zero out losses past each length.
        losses = losses * tf.transpose(tf.to_float(loss_mask), [1, 0])

    return losses
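The masking logic can be sketched outside TensorFlow. The following NumPy function is a hypothetical illustration (not part of the tefla code) that mirrors the same shapes: logits `[T, B, vocab_size]`, targets `[T, B]`, lengths `[B]`, producing a `[T, B]` loss with positions past each example's length zeroed out.

```python
import numpy as np


def masked_sequence_loss(logits, targets, sequence_length):
    """Per-timestep cross-entropy with positions past each example's
    length masked to zero (a NumPy sketch of the TF function above)."""
    T, B, V = logits.shape
    # Log-softmax over the vocab axis, numerically stabilized.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Negative log-likelihood of the target class at each (t, b).
    nll = -np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    # mask[t, b] is 1 while t < sequence_length[b], else 0 --
    # the NumPy analogue of tf.sequence_mask plus the transpose.
    mask = (np.arange(T)[:, None] < sequence_length[None, :]).astype(float)
    return nll * mask
```

With uniform logits every timestep inside the mask costs `log(V)` nats, and every timestep at or beyond an example's length contributes exactly zero, matching the `[T, B]` output described in the docstring.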