def __next__(self):
def img2variable(img_files):
tensors = [self._encode(Image.open(self._path + img_name)).unsqueeze(0)
for img_name in img_files]
v = Variable(torch.cat(tensors, 0))
if self._is_cuda: v = v.cuda()
return v
if self._step == self._stop_step:
self._step = 0
raise StopIteration()
_start = self._step*self._batch_size
self._step += 1
return img2variable(self._img_files[_start:_start+self._batch_size])
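# Stray web-page labels pasted with the code (not part of the program):
# 评论列表 ("comment list")
# 文章目录 ("table of contents")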