nn1_stress_test.py 文件源码

python
阅读 40 收藏 0 点赞 0 评论 0

项目:YellowFin_Pytorch 作者: JianGoForIt 项目源码 文件源码
def test_accuracy_full_batch(tokens, features, mini_batch_size, word_attn, sent_attn, th=0.5):
    p = []
    l = []
    cnt = 0
    g = gen_minibatch1(tokens, features, mini_batch_size, False)
    for token, feature in g:
        if cnt % 100 == 0:
            print(cnt)
        cnt +=1
#         print token.size()
#         y_pred = get_predictions(token, word_attn, sent_attn)
#         print y_pred
        y_pred = get_predictions(token, feature, word_attn, sent_attn)
#         print y_pred
#         _, y_pred = torch.max(y_pred, 1)
#         y_pred = y_pred[:, 1]
#         print y_pred
        p.append(np.ndarray.flatten(y_pred.data.cpu().numpy()))
    p = [item for sublist in p for item in sublist]
    p = np.array(p)
    return p
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号