def prepare_batch(xs, lens, gpu=True):
    """Reorder a batch by descending length and wrap it for inference.

    The entries of ``xs`` along dimension 1 are permuted so that they follow
    the descending order of ``lens``. The inverse permutation is returned as
    well, so results computed on the sorted batch can be put back into the
    caller's original order with ``result.index_select(batch_dim, ridx)``.

    Args:
        xs: Tensor to reorder. The index expansion below requires it to be
            3-D with ``xs.size(1) == lens.size(0)``
            (assumes ``lens`` is 1-D -- TODO confirm against callers).
        lens: Tensor of per-item lengths; sorted along dimension 0.
        gpu: If true, ``.cuda()`` is called on all three results. Nothing
            here checks that CUDA is available.

    Returns:
        Tuple ``(xs, lens, ridx)``: the reordered ``xs``, ``lens`` sorted in
        descending order, and ``ridx``, the permutation that undoes the sort.
    """
    # Third positional argument of torch.sort is `descending`.
    # idx[j] is the original position of the item now at sorted position j.
    lens, idx = torch.sort(lens, 0, True)

    # Sorting a permutation yields its inverse: ridx[i] is the sorted
    # position of the item that was originally at position i.
    _, ridx = torch.sort(idx, 0)

    # gather needs an index tensor shaped like xs, so view idx as (1, B, 1)
    # and broadcast it over the first and last dimensions.
    idx_exp = idx.unsqueeze(0).unsqueeze(-1).expand_as(xs)
    xs = torch.gather(xs, 1, idx_exp)

    # NOTE(review): `volatile=True` is the pre-0.4 PyTorch way of disabling
    # graph construction. On PyTorch >= 0.4 the flag is accepted but ignored
    # (with a UserWarning); callers on newer versions must wrap the forward
    # pass in `torch.no_grad()` to get the same effect.
    xs = Variable(xs, volatile=True)
    lens = Variable(lens, volatile=True)
    ridx = Variable(ridx, volatile=True)

    if gpu:
        xs = xs.cuda()
        lens = lens.cuda()
        ridx = ridx.cuda()

    return xs, lens, ridx
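# (Scraped page-navigation text, originally "Comment list" / "Article table of contents"; not part of the program.)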