def prepare_batch(self, batch_data, volatile=False):
    """Sort, align, and (optionally) move one batch to the GPU.

    Encoder inputs are sorted by decreasing length (presumably for
    ``pack_padded_sequence``-style packing — confirm against the encoder),
    decoder sequences are sorted likewise along the batch dimension, and an
    index tensor is built that maps each sorted-``ys`` row back to its
    sorted-``x`` row.

    Args:
        batch_data: tuple ``(x, x_lens, ys, ys_lens)`` of tensors.  ``ys``
            carries a trailing time dimension that is split into decoder
            input/target below.
        volatile: legacy flag from the pre-0.4 autograd API, retained only
            for backward compatibility; it has no effect.  Suppress
            gradients with ``torch.no_grad()`` at the call site instead.

    Returns:
        ``(x, x_lens, ys_i, ys_t, ys_lens, xys_idx)`` where ``ys_i``/``ys_t``
        are the teacher-forcing input/target (last / first step dropped,
        lengths reduced by one) and ``xys_idx`` maps sorted-``ys`` rows to
        sorted-``x`` rows.
    """
    x, x_lens, ys, ys_lens = batch_data
    batch_dim = 0 if self.batch_first else 1
    context_dim = 1 if self.batch_first else 0

    # Sort encoder inputs by decreasing length; x_ridx inverts the
    # permutation (original position -> position after sorting).
    x_lens, x_idx = torch.sort(x_lens, 0, True)
    _, x_ridx = torch.sort(x_idx)
    # Sort decoder sequences by decreasing length along the batch dim.
    ys_lens, ys_idx = torch.sort(ys_lens, batch_dim, True)
    # For every sorted-ys entry, look up which sorted-x row it came from.
    x_ridx_exp = x_ridx.unsqueeze(context_dim).expand_as(ys_idx)
    xys_idx = torch.gather(x_ridx_exp, batch_dim, ys_idx)

    x = x[x_idx]
    ys = torch.gather(ys, batch_dim, ys_idx.unsqueeze(-1).expand_as(ys))

    # Teacher forcing: the decoder input drops the last token, the target
    # drops the first, so effective lengths shrink by one.
    # (The pre-0.4 Variable(..., volatile=...) wrappers were removed: since
    # the Tensor/Variable merge they were no-ops, and `volatile` is gone.)
    ys_i = ys[..., :-1].contiguous()
    ys_t = ys[..., 1:].contiguous()
    ys_lens = ys_lens - 1

    if self.is_cuda:
        # `non_blocking` replaces the removed `async` kwarg (`async` is a
        # reserved keyword since Python 3.7); the copy only overlaps with
        # compute when the source tensor lives in pinned memory.
        x = x.cuda(non_blocking=True)
        x_lens = x_lens.cuda(non_blocking=True)
        ys_i = ys_i.cuda(non_blocking=True)
        ys_t = ys_t.cuda(non_blocking=True)
        ys_lens = ys_lens.cuda(non_blocking=True)
        xys_idx = xys_idx.cuda(non_blocking=True)
    return x, x_lens, ys_i, ys_t, ys_lens, xys_idx