data_loader.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:torch_light 作者: ne7ermore 项目源码 文件源码
def __next__(self):
        def img2variable(img_files):
            tensors = [self._encode(Image.open(self._path + img_name)).unsqueeze(0)
                    for img_name in img_files]
            v = Variable(torch.cat(tensors, 0))
            if self._is_cuda: v = v.cuda()
            return v

        if self._step == self._stop_step:
            self._step = 0
            raise StopIteration()       

        _start = self._step*self._batch_size
        self._step += 1

        return img2variable(self._img_files[_start:_start+self._batch_size])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号