def reverse_sequence(self, x, x_lens):
    """Reverse the valid prefix of every sequence in a padded batch.

    For batch element ``b`` with length ``n = x_lens[b]``, output step
    ``t < n`` receives ``x[b, n - 1 - t]``, so the first ``n`` steps are
    reversed in place. Steps ``t >= n`` receive ``x[b, 0]`` (the first time
    step), *not* the original padding. Callers that care about padded
    positions must mask them. Source indices are clamped to
    ``[0, seq_len - 1]``, so lengths outside ``[0, seq_len]`` do not raise.

    This is exactly "reverse the whole time axis, then shift each row left
    by ``seq_len - n``". It is computed here with a single ``gather`` instead
    of two.

    Args:
        self: Unused; kept so the signature stays method-compatible.
        x: Tensor whose dim 0 is the batch and dim 1 is time. Any trailing
            dims (e.g. a feature dim) are carried along unchanged.
        x_lens: Per-batch lengths, shape ``(batch,)``. It may be an integer
            tensor on any device or a Python sequence of ints. It is
            converted to ``int64`` on ``x.device``.

    Returns:
        A tensor with the same shape, dtype and device as ``x``.
    """
    seq_len = x.size(1)
    # gather() needs int64 indices on x's device; accept lists, int32 and
    # CPU tensors instead of requiring the caller to pre-convert.
    lens = torch.as_tensor(x_lens, dtype=torch.long, device=x.device).view(-1, 1)
    steps = torch.arange(seq_len, device=x.device).unsqueeze(0)  # (1, seq_len)
    # Source time index for every output step. The clamp reproduces the
    # two-gather formulation exactly: padded steps map to index 0, and
    # over-long lengths cannot index past the end.
    src = (lens - 1 - steps).clamp(0, seq_len - 1)  # (batch, seq_len)
    # Broadcast the index over any trailing dims so gather can select along dim 1.
    src = src.view(tuple(src.shape) + (1,) * (x.dim() - 2)).expand_as(x)
    return x.gather(1, src)
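# NOTE: the two lines originally here ("评论列表" / "文章目录", i.e. "comment list" /
# "table of contents") were web-page navigation residue, not code; kept as a comment.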