batcher.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:arc-pytorch 作者: sanyam5 项目源码 文件源码
def fetch_batch(self, part, batch_size: int = None):

        if batch_size is None:
            batch_size = self.batch_size

        X, Y = self._fetch_batch(part, batch_size)

        X = Variable(torch.from_numpy(X)).view(2*batch_size, self.image_size, self.image_size)

        X1 = X[:batch_size]  # (B, h, w)
        X2 = X[batch_size:]  # (B, h, w)

        X = torch.stack([X1, X2], dim=1)  # (B, 2, h, w)

        Y = Variable(torch.from_numpy(Y))

        if use_cuda:
            X, Y = X.cuda(), Y.cuda()

        return X, Y
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号