train_mycnn.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:tensorflow_homographynet 作者: linjian93 项目源码 文件源码
def next_batch(self):
        start = self.index_in_epoch
        self.index_in_epoch += batch_size / pairs_per_img
        if self.index_in_epoch > self.number:
            self.index_in_epoch = 0
            start = self.index_in_epoch
            self.index_in_epoch += batch_size / pairs_per_img
        end = self.index_in_epoch

        data_batch, label_batch = generate_data(self.img_path_list[start])
        for i in range(start+1, end):
            data, label = generate_data(self.img_path_list[i])
            data_batch = np.concatenate((data_batch, data))
            label_batch = np.concatenate((label_batch, label))

        data_batch = np.array(data_batch).transpose([0, 2, 3, 1])
        # cv2.imshow('window2', data_batch[1,:,:,1].squeeze())
        # cv2.waitKey()
        label_batch = np.array(label_batch).squeeze()

        return data_batch, label_batch
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号