def train_batch(self, bs):
    """
    Sample a random training batch of images with their attributes.

    Indices are drawn uniformly *with replacement* from
    ``range(len(self.images))``, so a batch may contain duplicate samples.
    The selected images go through ``normalize_images``. Both tensors are
    moved to the GPU with ``.cuda()``.

    Data augmentation is applied to the whole batch at once, not per sample:
      - if ``self.v_flip`` is truthy, with probability 0.5 the batch is
        reversed along dim 2;
      - if ``self.h_flip`` is truthy, with probability 0.5 the batch is
        reversed along dim 3.
    NOTE(review): calling dims 2/3 "vertical"/"horizontal" assumes images are
    laid out as (batch, channels, height, width). Confirm against the loader.

    Args:
        bs: number of samples to draw for the batch.

    Returns:
        A pair ``(batch_x, batch_y)``, each wrapped in ``Variable``:
        - ``batch_x``: the normalized (and possibly flipped) images;
        - ``batch_y``: the matching rows of ``self.attributes``.
    """
    def _flip(x, dim):
        # Reverse `x` along `dim` by selecting with a descending index.
        # The index is moved to the GPU to match the (already .cuda()) batch.
        rev = torch.arange(x.size(dim) - 1, -1, -1).long().cuda()
        return x.index_select(dim, rev)

    # Random sample indices in [0, len(self.images)), drawn with replacement.
    idx = torch.LongTensor(bs).random_(len(self.images))

    # Select rows before calling .cuda(), so only the batch is transferred
    # (presumably self.images / self.attributes live on the CPU; verify).
    batch_x = normalize_images(self.images.index_select(0, idx).cuda())
    batch_y = self.attributes.index_select(0, idx).cuda()

    # Batch-level random flips. `< 0.5` gives an exact 50% chance.
    # Short-circuiting keeps the RNG untouched when a flip is disabled.
    if self.v_flip and np.random.rand() < 0.5:
        batch_x = _flip(batch_x, 2)
    if self.h_flip and np.random.rand() < 0.5:
        batch_x = _flip(batch_x, 3)

    # `volatile` was removed in PyTorch 0.4 (no effect, emits a warning) and
    # defaulted to False before that, so it is omitted here.
    return Variable(batch_x), Variable(batch_y)
评论列表
文章目录