list_image_data_provider.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:tfplus 作者: renmengye 项目源码 文件源码
def get_batch_idx(self, idx):
        hh = self.inp_height
        ww = self.inp_width
        x = np.zeros([len(idx), hh, ww, 3], dtype='float32')
        orig_height = []
        orig_width = []
        ids = []
        for kk, ii in enumerate(idx):
            fname = self.ids[ii]
            ids.append('{:06}'.format(ii))
            x_ = cv2.imread(fname).astype('float32') / 255
            x[kk] = cv2.resize(
                x_, (self.inp_width, self.inp_height),
                interpolation=cv2.INTER_CUBIC)
            orig_height.append(x_.shape[0])
            orig_width.append(x_.shape[1])
            pass
        return {
            'x': x,
            'orig_height': np.array(orig_height),
            'orig_width': np.array(orig_width),
            'id': ids
        }
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号