def eval_model(model_path, mode='dev'):
    """Load a trained StackBiLSTMMaxout checkpoint and score it on the MNLI dev splits.

    Args:
        model_path: Path to a checkpoint file whose contents are passed directly
            to ``model.load_state_dict`` (presumably a saved ``state_dict()``).
        mode: Accepted for backward compatibility. NOTE(review): this argument is
            currently ignored; only the dev-matched and dev-mismatched splits are
            evaluated regardless of its value.

    Returns:
        A ``(dev_matched, dev_mismatched)`` tuple containing whatever
        ``model_eval`` returns for each split (printed below as "(acc, loss)").
    """
    torch.manual_seed(6)

    _, mnli_d, embd = data_loader.load_data_sm(
        config.DATA_ROOT, config.EMBD_FILE, reseversed=False,
        batch_sizes=(32, 32, 32, 32, 32), device=0)
    _, m_dev_m, m_dev_um, m_test_m, m_test_um = mnli_d

    # Iterate the evaluation splits in a fixed order: no shuffling, no sorting.
    for split in (m_dev_um, m_dev_m, m_test_um, m_test_m):
        split.shuffle = False
        split.sort = False

    model = StackBiLSTMMaxout()
    model.Embd.weight.data = embd
    if torch.cuda.is_available():
        # Moving the model also moves Embd.weight. A standalone `embd.cuda()`
        # returns a copy that would be discarded, so it is not needed.
        model.cuda()

    criterion = nn.CrossEntropyLoss()
    # Load tensors onto CPU first so the checkpoint opens regardless of the
    # device it was saved from; load_state_dict copies them onto the model's device.
    state = torch.load(model_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(state)
    model.max_l = 150
    # Switch modules such as dropout to inference behaviour before scoring.
    model.eval()

    m_pred = model_eval(model, m_dev_m, criterion)
    um_pred = model_eval(model, m_dev_um, criterion)
    print("dev_mismatched_score (acc, loss):", um_pred)
    print("dev_matched_score (acc, loss):", m_pred)
    return m_pred, um_pred
stack_3bilstm_last_encoder.py 文件源码
python
阅读 21
收藏 0
点赞 0
评论 0
评论列表
文章目录