nn1_stress_test.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:YellowFin_Pytorch 作者: JianGoForIt 项目源码 文件源码
def test_accuracy_mini_batch(tokens, features, labels, word_attn, sent_attn):
    y_pred = get_predictions(tokens, features, word_attn, sent_attn)
    y_pred = torch.gt(y_pred, 0.5)
    correct = np.ndarray.flatten(y_pred.data.cpu().numpy())
    labels = torch.gt(labels, 0.5)
    labels = np.ndarray.flatten(labels.data.cpu().numpy())

    num_correct = sum(correct == labels)

    return float(num_correct) / len(correct)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号