train_net.py source code

python

Project: tensorflow_homographynet  Author: linjian93
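import numpy as np

# Method of a data-loader class in train_net.py. batch_size, pairs_per_img and
# generate_data() are module-level names defined elsewhere in train_net.py.
def next_batch(self):
    """Return the next (data_batch, label_batch), wrapping to the first image when the epoch ends."""
    self.count += 1
    # Each source image yields pairs_per_img training pairs, so one batch consumes
    # batch_size // pairs_per_img images. Integer division keeps the indices ints under Python 3.
    imgs_per_batch = batch_size // pairs_per_img
    start = self.index_in_epoch
    self.index_in_epoch += imgs_per_batch
    if self.index_in_epoch > self.number:
        # Not enough images left for a full batch: restart from the first image.
        start = 0
        self.index_in_epoch = imgs_per_batch
    end = self.index_in_epoch

    # Collect per-image arrays, then concatenate once (avoids repeated copying in the loop).
    data_list, label_list = [], []
    for i in range(start, end):
        data, label = generate_data(self.img_path_list[i])  # [4, 2, 128, 128], [4, 1, 8]
        data_list.append(data)
        label_list.append(label)
    data_batch = np.concatenate(data_list)  # [64, 2, 128, 128]
    label_batch = np.concatenate(label_list)  # [64, 1, 8]

    # NCHW -> NHWC, the default layout of TensorFlow convolution ops.
    data_batch = data_batch.transpose([0, 2, 3, 1])  # (64, 128, 128, 2)
    # Optional debug view of one channel:
    # cv2.imshow('window2', data_batch[1, :, :, 1])
    # cv2.waitKey()
    label_batch = label_batch.squeeze(axis=1)  # (64, 8): drop only the singleton axis

    return data_batch, label_batch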
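To see the batching arithmetic and the resulting shapes in isolation, here is a minimal, self-contained sketch. It assumes batch_size = 64 and pairs_per_img = 4 (inferred from the shape comments above); generate_data is a hypothetical stub returning random arrays of those shapes, and the BatchLoader class and its last_range attribute are illustrative additions, not code from train_net.py.

```python
import numpy as np

# Hypothetical values inferred from the shape comments ([4, ...] per image, 64 per batch).
batch_size = 64
pairs_per_img = 4
rng = np.random.default_rng(0)


def generate_data(img_path):
    """Hypothetical stub: pairs_per_img random image pairs and 8-value labels per image."""
    data = rng.random((pairs_per_img, 2, 128, 128), dtype=np.float32)
    label = rng.random((pairs_per_img, 1, 8), dtype=np.float32)
    return data, label


class BatchLoader:
    """Illustrative loader holding the attributes next_batch() reads."""

    def __init__(self, img_path_list):
        self.img_path_list = img_path_list
        self.number = len(img_path_list)
        self.index_in_epoch = 0
        self.count = 0
        self.last_range = None  # (start, end) image indices of the last batch, for inspection

    def next_batch(self):
        self.count += 1
        imgs_per_batch = batch_size // pairs_per_img
        start = self.index_in_epoch
        self.index_in_epoch += imgs_per_batch
        if self.index_in_epoch > self.number:
            # Same wrap-around rule as the excerpt: leftover images are skipped.
            start = 0
            self.index_in_epoch = imgs_per_batch
        end = self.index_in_epoch
        self.last_range = (start, end)
        pairs = [generate_data(self.img_path_list[i]) for i in range(start, end)]
        data_batch = np.concatenate([d for d, _ in pairs]).transpose([0, 2, 3, 1])
        label_batch = np.concatenate([lab for _, lab in pairs]).squeeze(axis=1)
        return data_batch, label_batch


loader = BatchLoader([f"img_{i}.jpg" for i in range(40)])
for _ in range(3):
    data, labels = loader.next_batch()
    print(loader.last_range, data.shape, labels.shape)
# (0, 16) (64, 128, 128, 2) (64, 8)
# (16, 32) (64, 128, 128, 2) (64, 8)
# (0, 16) (64, 128, 128, 2) (64, 8)
```

With 40 images and 16 images per batch, the third call wraps back to image 0, so images 32 to 39 are never used in that pass. This is a consequence of resetting to 0 instead of carrying the remainder over. Note also that if the list holds fewer than batch_size // pairs_per_img images, the wrap still produces an end index past the list and indexing fails.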