model.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:GNMT2 作者: Mingyearn 项目源码 文件源码
def _compute_loss(self, logits):
    """Compute optimization loss."""
    target_output = self.iterator.target_output
    if self.time_major:
      target_output = tf.transpose(target_output)
    max_time = self.get_max_time(target_output)
    crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=target_output, logits=logits)
    target_weights = tf.sequence_mask(
        self.iterator.target_sequence_length, max_time, dtype=logits.dtype)
    if self.time_major:
      target_weights = tf.transpose(target_weights)

    loss = tf.reduce_sum(
        crossent * target_weights) / tf.to_float(self.batch_size)
    return loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号