def stats(criterion, a, y, mask):
    """Compute the loss and the prediction accuracy for one batch.

    Args:
        criterion: Callable invoked as ``criterion(scores, targets)``; its
            result is returned unchanged as the loss.
        a: Score tensor. When ``mask`` is given it is unpacked as three
            dimensions and predictions are the argmax over dimension 2;
            otherwise predictions are the argmax over dimension 1.
        y: Target tensor, compared element-wise with the predictions.
        mask: ``None`` when every position counts. Otherwise a tensor that is
            summed to obtain the number of positions that count and passed
            to ``_sequence_mask(mask, seq_len)`` to build the per-position
            mask (presumably per-sequence lengths -- TODO confirm against
            ``_sequence_mask`` and the callers).

    Returns:
        A ``(loss, acc)`` tuple: the value returned by ``criterion`` and the
        fraction of counted positions whose prediction equals the target.
    """
    if mask is None:
        _, preds = t.max(a.data, 1)
        loss = criterion(a, y)
        acc = t.mean((y.data == preds).float())
        return loss, acc

    _, preds = t.max(a.data, 2)
    _, seq_len, n_classes = a.size()
    # The criterion sees every position flattened; no masking is applied to
    # the loss here, only to the accuracy below.
    loss = criterion(a.view(-1, n_classes), y.view(-1))
    # .item() replaces the former ``.data[0]``: the sum is a zero-dimensional
    # tensor on current PyTorch and indexing it with [0] raises IndexError.
    n_counted = float(t.sum(mask).item())
    position_mask = _sequence_mask(mask, seq_len).data.float()
    hits = (y.data == preds).float()
    acc = t.sum(position_mask * hits) / n_counted
    return loss, acc
评论列表
文章目录