train.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:pytorch-skipthoughts 作者: kaniblu 项目源码 文件源码
def prepare_batch(self, batch_data, volatile=False):
    """Sort a raw batch by length and wrap it for the model.

    The input sentences ``x`` are reordered by descending length, and each
    context slot of ``ys`` is independently reordered by descending length
    along the batch dimension.  ``xys_idx`` records, for every entry of the
    sorted ``ys``, where its originating batch item ended up in the sorted
    ``x``, so the two orderings can be matched up again later.

    Args:
        batch_data: 4-tuple ``(x, x_lens, ys, ys_lens)`` of tensors.
            ``x_lens`` is sorted along dim 0 and ``ys_lens`` along the
            batch dimension.  Presumably ``x`` is ``(batch, time)``,
            ``ys`` is ``(batch, context, time)`` and ``ys_lens`` is
            ``(batch, context)`` when ``self.batch_first`` is true --
            TODO confirm against the data loader.
        volatile: forwarded to ``Variable``.  On PyTorch >= 0.4 the flag
            no longer disables autograd; wrap the call in
            ``torch.no_grad()`` for inference instead.

    Returns:
        Tuple ``(x, x_lens, ys_i, ys_t, ys_lens, xys_idx)`` where ``ys_i``
        is ``ys`` without its last time step, ``ys_t`` is ``ys`` without
        its first time step, and ``ys_lens`` is reduced by one to match.
        All items are moved to the GPU when ``self.is_cuda`` is true.
    """
    x, x_lens, ys, ys_lens = batch_data
    batch_dim = 0 if self.batch_first else 1
    context_dim = 1 if self.batch_first else 0

    # Descending sort of the input lengths; x_ridx is the inverse
    # permutation (original batch position -> position after sorting).
    x_lens, x_idx = torch.sort(x_lens, 0, True)
    _, x_ridx = torch.sort(x_idx)
    # Each context slot gets its own descending ordering over the batch.
    ys_lens, ys_idx = torch.sort(ys_lens, batch_dim, True)

    # Map every sorted ys entry to the sorted-x position of its source item.
    x_ridx_exp = x_ridx.unsqueeze(context_dim).expand_as(ys_idx)
    xys_idx = torch.gather(x_ridx_exp, batch_dim, ys_idx)

    # NOTE(review): this always indexes dim 0 of x, regardless of
    # batch_first -- confirm that x is batch-major in both layouts.
    x = x[x_idx]
    ys = torch.gather(ys, batch_dim, ys_idx.unsqueeze(-1).expand_as(ys))

    x = Variable(x, volatile=volatile)
    x_lens = Variable(x_lens, volatile=volatile)
    # Shift by one step: ys_i feeds the decoder, ys_t is what it predicts.
    ys_i = Variable(ys[..., :-1], volatile=volatile).contiguous()
    ys_t = Variable(ys[..., 1:], volatile=volatile).contiguous()
    ys_lens = Variable(ys_lens - 1, volatile=volatile)
    xys_idx = Variable(xys_idx, volatile=volatile)

    prepared = (x, x_lens, ys_i, ys_t, ys_lens, xys_idx)

    if self.is_cuda:
        # `async` has been a reserved word since Python 3.7, so the old
        # `.cuda(async=True)` spelling is a SyntaxError; `non_blocking`
        # is the PyTorch >= 0.4 name for the same option.
        prepared = tuple(item.cuda(non_blocking=True) for item in prepared)

    return prepared
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号