train.py source code

python

Project: pytorch-skipthoughts  Author: kaniblu
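import torch


# Excerpt: a method of a class in train.py (the enclosing class is not shown).
def prepare_batches(self, batch_data, chunks, **kwargs):
    # Presumably x/x_lens are the input sentences and their lengths,
    # and ys/ys_lens the target sentences and their lengths.
    x, x_lens, ys, ys_lens = batch_data
    # ys and ys_lens are split along dim 0 only when batch_first is set
    batch_dim = 0 if self.batch_first else 1

    # Split every input into `chunks` pieces along its batch dimension
    # (x and x_lens are always split along dim 0)
    x_list = x.chunk(chunks, 0)
    x_lens_list = x_lens.chunk(chunks, 0)
    ys_list = ys.chunk(chunks, batch_dim)
    ys_lens_list = ys_lens.chunk(chunks, batch_dim)
    inp_list = [x_list, x_lens_list, ys_list, ys_lens_list]

    # Prepare each chunk separately; each result is a sequence of tensors
    data_list = []
    for inp in zip(*inp_list):
        data = self.prepare_batch(inp, **kwargs)
        data_list.append(data)

    # Transpose: from one tuple per chunk to one tuple per output field
    data_list = list(zip(*data_list))
    ret_list = []

    # Stack each field's chunks along a new leading dimension
    # (equivalent to torch.stack(data))
    for data in data_list:
        data = [d.unsqueeze(0) for d in data]
        data = torch.cat(data)
        ret_list.append(data)

    return ret_list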
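The split, prepare, transpose, and stack pattern in `prepare_batches` can be illustrated without PyTorch. The following is a minimal sketch in plain Python, assuming each batch field is a list indexed along the batch dimension. The helper `chunk` is a hypothetical stand-in that mimics how `torch.chunk` splits along dim 0, and `toy_prepare` is a hypothetical stand-in for `self.prepare_batch`, whose real behavior is not shown in this file.

```python
from typing import Callable, List, Sequence, Tuple


def chunk(seq: Sequence, chunks: int) -> List[Sequence]:
    """Hypothetical helper: split seq into pieces of ceil(len / chunks),
    mirroring how torch.chunk splits along dim 0 (last piece may be shorter)."""
    size = max(1, -(-len(seq) // chunks))  # ceiling division, at least 1
    return [seq[i:i + size] for i in range(0, len(seq), size)]


def prepare_batches(batch: Tuple[Sequence, ...], chunks: int,
                    prepare_batch: Callable[[Tuple[Sequence, ...]], Tuple]) -> List[List]:
    # 1. Split every field of the batch into the same number of chunks.
    split_fields = [chunk(field, chunks) for field in batch]
    # 2. zip(*...) regroups so each item holds one chunk of every field;
    #    each such group is prepared independently.
    per_chunk = [prepare_batch(parts) for parts in zip(*split_fields)]
    # 3. Transpose back: one list per output field, indexed by chunk
    #    (the role of unsqueeze(0) + torch.cat in the original).
    return [list(field) for field in zip(*per_chunk)]


def toy_prepare(parts):
    # Hypothetical per-chunk preparation: keep x, report the longest length.
    x, x_lens = parts
    return list(x), max(x_lens)


result = prepare_batches(([[1, 2], [3, 4], [5, 6], [7, 8]], [2, 1, 2, 2]), 2, toy_prepare)
print(result)  # [[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], [2, 2]]
```

One consequence worth checking in the original: `torch.chunk` makes each piece ceil(n / chunks) long, so when the batch size is not divisible by `chunks` the last piece is shorter, or fewer pieces come back. Because `torch.cat` requires every dimension other than the concatenation dimension to match, stacking the per-chunk outputs would then presumably fail unless `prepare_batch` pads them to a common shape. The plain-list sketch does not enforce that constraint.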