training.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:spatial-reasoning 作者: JannerM 项目源码 文件源码
def __get_batch(self, inputs, targets):
        data_size = targets.size(0)

        inds = torch.floor(torch.rand(self.batch_size) * data_size).long().cuda()
        # bug: floor(rand()) sometimes gives 1
        inds[inds >= data_size] = data_size - 1

        if type(inputs) == tuple:
            inp = tuple([Variable( i.index_select(0, inds).cuda() ) for i in inputs])
        else:
            inp = Variable( inputs.index_select(0, inds).cuda() )

        targ = Variable( targets.index_select(0, inds).cuda() )
        return inp, targ
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号