predict.py 文件源码

python
阅读 53 收藏 0 点赞 0 评论 0

项目:pytorch-bilstmcrf 作者: kaniblu 项目源码 文件源码
def prepare_batch(xs, lens, gpu=True):
    lens, idx = torch.sort(lens, 0, True)
    _, ridx = torch.sort(idx, 0)
    idx_exp = idx.unsqueeze(0).unsqueeze(-1).expand_as(xs)
    xs = torch.gather(xs, 1, idx_exp)

    xs = Variable(xs, volatile=True)
    lens = Variable(lens, volatile=True)
    ridx = Variable(ridx, volatile=True)

    if gpu:
        xs = xs.cuda()
        lens = lens.cuda()
        ridx = ridx.cuda()

    return xs, lens, ridx
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号