train.py source code

python

Project: pytorch-bilstmcrf    Author: kaniblu
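import numpy as np
import torch


def prepare_val_texts(model, batch_xs, batch_y, batch_lens,
                      logits, preds, n_previews):
    # Randomly choose up to n_previews sentences from the batch to preview.
    idx = np.random.permutation(batch_lens.size(0))[:n_previews]

    # Variable(..., volatile=True) is deprecated since PyTorch 0.4; a plain
    # tensor under torch.no_grad() does the same job. Creating the index on
    # logits.device replaces the manual model.is_cuda check.
    with torch.no_grad():
        idx_t = torch.as_tensor(idx, dtype=torch.long, device=logits.device)

        # Greedy BiLSTM predictions: argmax over the label dimension (dim 2).
        # Current PyTorch drops the reduced dim by default (keepdim=False),
        # so the old .squeeze(-1) is no longer needed.
        logits = torch.index_select(logits, 0, idx_t)
        bilstm_preds = logits.argmax(2).cpu().numpy()

        # CRF predictions, inputs, gold labels and lengths for the same rows.
        crf_preds = preds.cpu().numpy()[idx]
        xs = batch_xs.cpu().numpy()[:, idx]  # batch is dim 1 of batch_xs
        y = batch_y.cpu().numpy()[idx]
        lens = batch_lens.cpu().numpy()[idx]

    # val_sents / val_texts are helpers defined elsewhere in train.py.
    sents = val_sents(model.word_vocabs, model.label_vocab,
                      xs, y, bilstm_preds, crf_preds, lens)
    texts = val_texts(*sents)

    return texts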
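The sampling and greedy-decoding steps in prepare_val_texts can be illustrated without PyTorch. The sketch below is a minimal NumPy-only approximation, assuming logits are laid out as [batch, seq_len, n_labels] and label ids map to strings through a plain list. `preview_predictions` and `id2label` are hypothetical names, not part of pytorch-bilstmcrf, and the sketch does not reproduce the repository's `val_sents`/`val_texts` formatting. It shows why the function computes two prediction sets: the BiLSTM-only output is a per-token argmax, while the CRF output comes from sequence-level decoding, so previewing both side by side makes their disagreements visible.

```python
import numpy as np


def preview_predictions(logits, crf_preds, gold, lens, id2label,
                        n_previews, seed=None):
    """Hypothetical helper: pick up to n_previews sentences and decode them.

    logits:    float array [batch, seq_len, n_labels] (BiLSTM scores)
    crf_preds: int array [batch, seq_len] (CRF-decoded label ids)
    gold:      int array [batch, seq_len] (gold label ids)
    lens:      int array [batch] (true lengths, padding excluded)
    id2label:  list mapping label id -> label string (assumed vocab format)
    """
    rng = np.random.default_rng(seed)
    # Same idea as the original: shuffle batch indices, keep the first n.
    idx = rng.permutation(len(lens))[:n_previews]

    # Greedy BiLSTM decoding: independent argmax per token over labels.
    bilstm_preds = logits[idx].argmax(axis=2)

    previews = []
    for row, i in enumerate(idx):
        n = int(lens[i])  # ignore padded positions
        previews.append({
            "gold": [id2label[t] for t in gold[i, :n]],
            "bilstm": [id2label[t] for t in bilstm_preds[row, :n]],
            "crf": [id2label[t] for t in crf_preds[i, :n]],
        })
    return previews


# Toy batch of two sentences (lengths 2 and 3, padded to 3).
id2label = ["O", "B-PER", "I-PER"]
logits = np.array([
    [[2.0, 0.1, 0.0], [0.0, 3.0, 0.5], [0.1, 0.2, 0.0]],
    [[0.0, 1.0, 0.2], [0.0, 0.1, 2.0], [5.0, 0.0, 0.0]],
])
crf_preds = np.array([[0, 1, 0], [1, 2, 0]])
gold = np.array([[0, 1, 0], [1, 2, 0]])
lens = np.array([2, 3])

out = preview_predictions(logits, crf_preds, gold, lens, id2label,
                          n_previews=2, seed=0)
for p in out:
    print(p)
```

Indexing the sampled rows once with `idx` and reusing it for every array, as the original does, keeps the gold labels, both prediction sets and the lengths aligned for each previewed sentence.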