def test_accuracy_full_batch(tokens, features, mini_batch_size, word_attn, sent_attn, th=0.5):
    """Run the model over every mini-batch and collect the flattened predictions.

    Despite its name, this function does not compute an accuracy: it returns
    the raw outputs of ``get_predictions`` for all batches, flattened and
    concatenated into a single 1-D array.

    Args:
        tokens: token inputs, passed through to ``gen_minibatch1``.
        features: feature inputs, passed through to ``gen_minibatch1``.
        mini_batch_size: batch size passed to ``gen_minibatch1``.
        word_attn: passed through to ``get_predictions`` (presumably the
            word-level attention model -- TODO confirm).
        sent_attn: passed through to ``get_predictions`` (presumably the
            sentence-level attention model -- TODO confirm).
        th: currently unused. Kept so existing callers that pass a threshold
            keep working; any thresholding must be done by the caller.

    Returns:
        numpy.ndarray: 1-D array of every batch's predictions, in the order the
        batches were yielded. Empty if the generator yields no batches.
    """
    batch_preds = []
    # NOTE(review): the trailing False is presumably a "shuffle" flag, kept off
    # so predictions line up with the input order -- confirm in gen_minibatch1.
    batches = gen_minibatch1(tokens, features, mini_batch_size, False)
    for batch_idx, (token, feature) in enumerate(batches):
        # Progress indicator: prints the batch index every 100 batches (0, 100, ...).
        # print(x) with a single argument behaves the same on Python 2 and 3.
        if batch_idx % 100 == 0:
            print(batch_idx)
        y_pred = get_predictions(token, feature, word_attn, sent_attn)
        # Assumes y_pred is a torch tensor/Variable: copy it to CPU memory,
        # convert to numpy and flatten to 1-D.
        batch_preds.append(np.ravel(y_pred.data.cpu().numpy()))
    if not batch_preds:
        # np.concatenate rejects an empty list; keep the original's empty result.
        return np.array([])
    # Vectorized join instead of flattening element-by-element in Python.
    return np.concatenate(batch_preds)
# Scraped web-page residue (not code), kept for reference:
# 评论列表 (comment list)
# 文章目录 (table of contents)