utils.py source code

python

Project: Efficient-Dynamic-Batching  Author: jsuarez5341
import torch as t
import torch.nn.functional as F


def maskedCE(logits, target, length):
   """
   Args:
       logits: A Variable containing a FloatTensor of size
           (batch, max_len, num_classes) which contains the
           unnormalized scores (logits) for each class.
       target: A Variable containing a LongTensor of size
           (batch, max_len) which contains the index of the true
           class for each corresponding step.
       length: A Variable containing a LongTensor of size (batch,)
           which contains the length of each sequence in the batch.

   Returns:
       loss: The cross-entropy loss averaged over the valid
           (unpadded) steps of all sequences in the batch.
   """

   # logits_flat: (batch * max_len, num_classes)
   logits_flat = logits.view(-1, logits.size(-1))
   # log_probs_flat: (batch * max_len, num_classes)
   log_probs_flat = F.log_softmax(logits_flat, dim=1)
   # target_flat: (batch * max_len, 1)
   target_flat = target.view(-1, 1)
   # losses_flat: (batch * max_len, 1)
   losses_flat = -t.gather(log_probs_flat, dim=1, index=target_flat)
   # losses: (batch, max_len)
   losses = losses_flat.view(*target.size())
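   # mask: (batch, max_len), nonzero at valid steps and zero at padding;
   # _sequence_mask is a helper that is not shown in this snippet
   mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
   # zero out the losses at padded steps
   losses = losses * mask.float()
   # average over the total number of valid steps in the batch
   loss = losses.sum() / length.float().sum()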
   return loss
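
The function above relies on a `_sequence_mask` helper that this snippet does not include. The sketch below shows one way such a mask could be built. The name `sequence_mask_sketch` and its body are assumptions for illustration only, not the project's actual helper. It compares each position index against each sequence length, so valid steps come out True and padded steps False.

```python
import torch


def sequence_mask_sketch(sequence_length, max_len=None):
    """Hypothetical stand-in for `_sequence_mask` (not the project's code).

    Returns a bool tensor of shape (batch, max_len) that is True at
    position (i, j) when j < sequence_length[i].
    """
    if max_len is None:
        max_len = int(sequence_length.max())
    # positions: (1, max_len), holding 0 .. max_len - 1
    positions = torch.arange(max_len, device=sequence_length.device).unsqueeze(0)
    # Broadcasting compares every position with every length: (batch, max_len)
    return positions < sequence_length.unsqueeze(1)


lengths = torch.tensor([3, 1, 2])
mask = sequence_mask_sketch(lengths, max_len=4)
print(mask.int().tolist())
```

Because each row of the mask has exactly `length[i]` True entries, summing the masked losses and dividing by `length.sum()` gives the mean over valid steps only.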