model.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:pytorch-skipthoughts 作者: kaniblu 项目源码 文件源码
def reverse_sequence(self, x, x_lens):
        batch_size, seq_len, word_dim = x.size()

        inv_idx = Variable(torch.arange(seq_len - 1, -1, -1).long())
        shift_idx = Variable(torch.arange(0, seq_len).long())

        if x.is_cuda:
            inv_idx = inv_idx.cuda(x.get_device())
            shift_idx = shift_idx.cuda(x.get_device())

        inv_idx = inv_idx.unsqueeze(0).unsqueeze(-1).expand_as(x)
        shift_idx = shift_idx.unsqueeze(0).unsqueeze(-1).expand_as(x)
        shift = (seq_len + (-1 * x_lens)).unsqueeze(-1).unsqueeze(-1).expand_as(x)
        shift_idx = shift_idx + shift
        shift_idx = shift_idx.clamp(0, seq_len - 1)

        x = x.gather(1, inv_idx)
        x = x.gather(1, shift_idx)

        return x
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号