def __get_batch(self, inputs, targets):
    """Sample a random mini-batch of size ``self.batch_size`` (with replacement).

    Args:
        inputs: a tensor, or a tuple of tensors, indexed along dim 0 in
            lockstep with ``targets``. Assumed to share ``targets``'s first
            dimension -- TODO confirm against callers.
        targets: tensor of labels; its ``size(0)`` defines the data size.

    Returns:
        (inp, targ): CUDA ``Variable``s holding the sampled rows; ``inp`` is
        a tuple of Variables when ``inputs`` is a tuple.
    """
    data_size = targets.size(0)
    # randint draws uniformly from [0, data_size), so every index is valid
    # by construction. The previous floor(rand() * data_size) approach could
    # yield data_size when rand() rounded up near 1.0, and needed a clamp.
    inds = torch.randint(0, data_size, (self.batch_size,), dtype=torch.long).cuda()
    if isinstance(inputs, tuple):
        inp = tuple(Variable(i.index_select(0, inds).cuda()) for i in inputs)
    else:
        inp = Variable(inputs.index_select(0, inds).cuda())
    targ = Variable(targets.index_select(0, inds).cuda())
    return inp, targ
# 评论列表 ("comment list") / 文章目录 ("table of contents") -- leftover
# web-page navigation text from the article this snippet was scraped from;
# commented out so the module parses. Not part of the code.