utils.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:Efficient-Dynamic-Batching 作者: jsuarez5341 项目源码 文件源码
def stats(criterion, a, y, mask):
   if mask is not None:
      _, preds = t.max(a.data, 2)
      batch, sLen, c = a.size()
      loss = criterion(a.view(-1, c), y.view(-1))

      m = t.sum(mask)
      mask = _sequence_mask(mask, sLen)
      acc = t.sum(mask.data.float() * (y.data == preds).float()) / float(m.data[0])
      #loss = criterion(a.view(-1, c), y.view(-1))
   else:
      _, preds = t.max(a.data, 1)
      loss = criterion(a, y)
      acc = t.mean((y.data == preds).float())

   return loss, acc
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号