def test_accuracy_mini_batch(tokens, features, labels, word_attn, sent_attn,
                             threshold=0.5):
    """Return the fraction of binary predictions that match the labels in one mini-batch.

    Predictions are obtained from ``get_predictions`` (not shown here), binarised
    with ``> threshold``, and compared element-wise against ``labels`` binarised
    with ``> 0.5``. Both sides are flattened first, so e.g. ``(batch, 1)``
    predictions line up with ``(batch,)`` labels.

    Args:
        tokens, features, word_attn, sent_attn: forwarded unchanged to
            ``get_predictions``.
        labels: tensor of targets; presumably 0/1 values — TODO confirm.
        threshold: decision cut-off applied to the predictions (default 0.5,
            the value that was previously hard-coded).

    Returns:
        Accuracy as a Python float in [0, 1].

    Raises:
        ValueError: if predictions and labels hold a different number of elements.
        ZeroDivisionError: if the batch is empty (kept from the original behaviour).
    """
    y_pred = get_predictions(tokens, features, word_attn, sent_attn)
    # `.data` kept (rather than `.detach()`) so this also works on old
    # Variable-based PyTorch releases.
    pred_bits = torch.gt(y_pred, threshold).data.cpu().numpy().ravel()
    label_bits = torch.gt(labels, 0.5).data.cpu().numpy().ravel()
    if pred_bits.size != label_bits.size:
        raise ValueError(
            "prediction/label size mismatch: %d vs %d"
            % (pred_bits.size, label_bits.size)
        )
    # Vectorised count; the builtin sum() walked the array one element at a time.
    num_correct = int(np.count_nonzero(pred_bits == label_bits))
    return float(num_correct) / pred_bits.size
评论列表
文章目录