def fetch_batch(self, part, batch_size: "int | None" = None):
    """Fetch one batch from ``part`` and return it as torch tensors.

    The raw arrays come from ``self._fetch_batch(part, batch_size)``. ``X`` is
    reshaped into ``2 * batch_size`` square images of side
    ``self.image_size``. The first ``batch_size`` images and the last
    ``batch_size`` images are then stacked along a new axis 1, so the result
    satisfies ``X[i, 0] == images[i]`` and
    ``X[i, 1] == images[batch_size + i]``. These are presumably image pairs
    fed to a two-input model; TODO confirm against ``_fetch_batch``.

    Args:
        part: Forwarded unchanged to ``self._fetch_batch``. Its meaning
            (e.g. a dataset split) is defined there.
        batch_size: Number of rows in the returned batch. ``None`` (the
            default) falls back to ``self.batch_size``.

    Returns:
        ``(X, Y)``, where ``X`` has shape
        ``(batch_size, 2, image_size, image_size)`` and ``Y`` is a tensor
        built from the second array returned by ``_fetch_batch``. The shape
        of ``Y`` is not constrained here. Both tensors are moved to the GPU
        when ``use_cuda`` is truthy.

    Raises:
        RuntimeError: Raised by ``Tensor.view`` if the fetched ``X`` does not
            hold exactly ``2 * batch_size * image_size ** 2`` elements.
    """
    if batch_size is None:
        batch_size = self.batch_size
    size = self.image_size
    X, Y = self._fetch_batch(part, batch_size)

    # ``Variable`` is kept for compatibility with pre-0.4 PyTorch. On newer
    # versions it is a deprecated wrapper that simply returns a Tensor.
    X = Variable(torch.from_numpy(X)).view(2 * batch_size, size, size)
    X1 = X[:batch_size]  # (B, h, w): first half of the fetched images
    X2 = X[batch_size:]  # (B, h, w): second half of the fetched images
    X = torch.stack([X1, X2], dim=1)  # (B, 2, h, w)
    Y = Variable(torch.from_numpy(Y))

    # ``use_cuda`` is a module-level name not defined in this method; it is
    # presumably a global CUDA-availability flag. Verify where it is set.
    if use_cuda:
        X, Y = X.cuda(), Y.cuda()
    return X, Y
评论列表
文章目录