stack_3bilstm_last_encoder.py source code


Project: multiNLI_encoder  Author: easonnie
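```python
import torch
import torch.nn as nn

# config, data_loader, StackBiLSTMMaxout and model_eval come from the
# multiNLI_encoder project and are assumed to be importable here.


def eval_model(model_path, mode='dev'):
    # `mode` is not used below: only the dev sets are evaluated.
    torch.manual_seed(6)

    # `reseversed` is kept as written; it must match the keyword accepted
    # by the project's data_loader.load_data_sm.
    snli_d, mnli_d, embd = data_loader.load_data_sm(
        config.DATA_ROOT, config.EMBD_FILE, reseversed=False,
        batch_sizes=(32, 32, 32, 32, 32), device=0)

    m_train, m_dev_m, m_dev_um, m_test_m, m_test_um = mnli_d

    # Iterate the evaluation sets in their original order (no shuffling or
    # sorting) so that results are deterministic.
    for it in (m_dev_m, m_dev_um, m_test_m, m_test_um):
        it.shuffle = False
        it.sort = False

    model = StackBiLSTMMaxout()
    # Initialise the embedding layer with the pretrained word vectors.
    model.Embd.weight.data = embd

    if torch.cuda.is_available():
        # Moves every parameter, including the embedding weights, to the GPU.
        model.cuda()

    criterion = nn.CrossEntropyLoss()

    # map_location='cpu' lets a GPU-saved checkpoint be read on any machine;
    # load_state_dict then copies the tensors onto the model's device.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()  # inference mode: disables dropout, if the model uses any

    # Presumably the maximum sentence length the model processes.
    model.max_l = 150
    m_pred = model_eval(model, m_dev_m, criterion)
    um_pred = model_eval(model, m_dev_um, criterion)

    print("dev_mismatched_score (acc, loss):", um_pred)
    print("dev_matched_score (acc, loss):", m_pred)
```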
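The function prints the `(acc, loss)` pairs returned by `model_eval`, whose body is not shown on this page. As a rough illustration of what such a pair typically contains, the sketch below computes accuracy and mean cross-entropy (the quantity `nn.CrossEntropyLoss` measures) from raw logits in plain Python. The batch format, the helper names `cross_entropy` and `accuracy_and_loss`, and the label ordering are hypothetical; the project's `model_eval` may, for example, average losses per batch rather than per example.

```python
import math
from typing import Iterable, Sequence, Tuple

# Hypothetical batch format: (list of logit vectors, list of gold labels).
Batch = Tuple[Sequence[Sequence[float]], Sequence[int]]


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """-log softmax(logits)[target], using the log-sum-exp trick for stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def accuracy_and_loss(batches: Iterable[Batch]) -> Tuple[float, float]:
    """Return (accuracy, mean per-example cross-entropy) over all batches."""
    correct, total, loss_sum = 0, 0, 0.0
    for logits_batch, labels in batches:
        for logits, label in zip(logits_batch, labels):
            # argmax; ties resolve to the lowest class index
            pred = max(range(len(logits)), key=lambda k: logits[k])
            correct += int(pred == label)
            loss_sum += cross_entropy(logits, label)
            total += 1
    return correct / total, loss_sum / total


# Usage with three-way NLI labels (assumed ordering:
# 0 = entailment, 1 = neutral, 2 = contradiction).
batches = [
    ([[2.0, 0.5, 0.1], [0.2, 0.3, 1.5]], [0, 2]),
    ([[1.0, 1.0, 1.0]], [1]),  # a tie: predicted as class 0, so counted wrong
]
acc, loss = accuracy_and_loss(batches)
print("dev score (acc, loss):", (acc, loss))
```

Averaging per example, as here, weights every sentence pair equally; averaging per-batch means instead gives a smaller final batch the same weight as a full one.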