Datagenerator.py source code

python

Project: tensorflow-AlexNet    Author: qiansi
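import cv2
import numpy as np


def getNext_batch(self):
    """Return the next batch of mean-subtracted images and one-hot labels.

    Method of the data generator class; uses self.images (file paths),
    self.labels, self.pointer, self.batch_size, self.scale_size,
    self.horizontal, self.mean and self.n_class.
    """
    paths = self.images[self.pointer:self.pointer + self.batch_size]
    labels = self.labels[self.pointer:self.pointer + self.batch_size]
    self.pointer += self.batch_size

    # The batch is laid out as (height, width) = scale_size. Zero-filling
    # (instead of np.ndarray) keeps a short final batch free of garbage values.
    images = np.zeros([self.batch_size, self.scale_size[0], self.scale_size[1], 3],
                      dtype=np.float32)
    for i, path in enumerate(paths):
        image = cv2.imread(path)  # BGR uint8 array, or None if unreadable
        if image is None:
            raise IOError('could not read image {}'.format(path))
        # Random horizontal flip (data augmentation), probability 0.5
        if self.horizontal and np.random.random() < 0.5:
            image = cv2.flip(image, 1)
        # cv2.resize takes dsize as (width, height)
        image = cv2.resize(image, (self.scale_size[1], self.scale_size[0]))
        image = image.astype(np.float32)

        # Subtract the mean stored in self.mean
        image -= self.mean
        images[i] = image

    # One-hot encode the integer class labels
    one_hot_labels = np.zeros((self.batch_size, self.n_class))
    for i, label in enumerate(labels):
        one_hot_labels[i][int(label)] = 1
    return images, one_hot_labels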
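The method above flips, mean-subtracts and one-hot encodes one image or label at a time inside Python loops. As a hedged sketch only (not code from the tensorflow-AlexNet project), the same three steps can be vectorized in NumPy, assuming the images have already been loaded and resized into a single uint8 array of shape (N, H, W, 3). The helper names `one_hot` and `preprocess_batch` and the example mean values are hypothetical.

```python
import numpy as np


def one_hot(labels, n_class):
    """Row i gets a 1.0 in column labels[i] (vectorized one-hot)."""
    idx = np.asarray(labels, dtype=np.int64)
    return np.eye(n_class, dtype=np.float32)[idx]


def preprocess_batch(images, mean, flip_mask):
    """Flip selected images left-right, then subtract a per-channel mean.

    images:    uint8 array of shape (N, H, W, 3), already resized (assumed).
    mean:      per-channel means, broadcastable to shape (3,) (hypothetical).
    flip_mask: boolean array of shape (N,); True flips that image.
    """
    batch = images.astype(np.float32)
    # Reversing axis 2 (width) mirrors each image, like cv2.flip(img, 1).
    batch[flip_mask] = batch[flip_mask][:, :, ::-1, :]
    batch -= np.asarray(mean, dtype=np.float32)
    return batch


# Example usage with random data; 104/117/123 are illustrative BGR means only.
rng = np.random.default_rng(0)
imgs = rng.integers(0, 256, size=(4, 8, 8, 3), dtype=np.uint8)
flips = rng.random(4) < 0.5  # same 50% flip probability as the method above
x = preprocess_batch(imgs, mean=[104.0, 117.0, 123.0], flip_mask=flips)
y = one_hot([2, 0, 1, 2], n_class=3)
print(x.shape, y.shape)  # (4, 8, 8, 3) (4, 3)
```

Indexing an identity matrix with the label array (`np.eye(n_class)[labels]`) replaces the loop over labels, and the `[:, :, ::-1, :]` slice reverses the width axis of every selected image at once, which is what `cv2.flip(image, 1)` does to a single H x W x 3 image.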