miniImagenet.py source code

python

Project: FewShotLearning  Author: gitabcworld
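import os
import numpy as np
import torch

def __getitem__(self, index):
        # Pre-allocate image tensors for the support set and the target (evaluation) set.
        # The label arrays are placeholders; they are overwritten below.
        # np.int was removed in NumPy 1.24, so np.int64 is used instead.
        support_set_x = torch.FloatTensor(self.n_samples, 3, 84, 84)
        support_set_y = np.zeros((self.n_samples), dtype=np.int64)
        target_x = torch.FloatTensor(self.n_samples_eval, 3, 84, 84)
        target_y = np.zeros((self.n_samples_eval), dtype=np.int64)

        # Flatten the nested (per-class) lists of file names into full image paths.
        # The first 9 characters of each file name are the key into classes_dict.
        flatten_support_set_x_batch = [os.path.join(self.miniImagenetImagesDir,item)
                                       for sublist in self.support_set_x_batch[index] for item in sublist]
        support_set_y = np.array([self.classes_dict[item[:9]]
                                      for sublist in self.support_set_x_batch[index] for item in sublist])
        flatten_target_x = [os.path.join(self.miniImagenetImagesDir,item)
                            for sublist in self.target_x_batch[index] for item in sublist]
        target_y = np.array([self.classes_dict[item[:9]]
                            for sublist in self.target_x_batch[index] for item in sublist])

        # self.transform receives a file path and must return a 3x84x84 tensor.
        # If self.transform is None, the image tensors are left uninitialized.
        for i,path in enumerate(flatten_support_set_x_batch):
            if self.transform is not None:
                support_set_x[i] = self.transform(path)

        for i,path in enumerate(flatten_target_x):
            if self.transform is not None:
                target_x[i] = self.transform(path)
        return support_set_x, torch.IntTensor(support_set_y), target_x, torch.IntTensor(target_y)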
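
The flattening and label lookup in `__getitem__` can be tried in isolation. The following is a minimal sketch, not the project's code: the helper name `flatten_episode`, the file names, the `images` directory, and the contents of `classes_dict` are all hypothetical. It assumes only what the method itself does: each episode is a list of per-class lists of file names, and the first 9 characters of a file name are the key into `classes_dict`.

```python
import os

def flatten_episode(episode, images_dir, classes_dict, id_len=9):
    """Flatten a nested per-class list of file names into paths and labels.

    Hypothetical helper mirroring the list comprehensions in __getitem__.
    """
    paths, labels = [], []
    for sublist in episode:          # one sublist per class
        for item in sublist:         # one file name per sample
            paths.append(os.path.join(images_dir, item))
            # The file-name prefix identifies the class.
            labels.append(classes_dict[item[:id_len]])
    return paths, labels

# Hypothetical 2-way, 2-shot episode (file names are made up).
episode = [
    ["n0153282900000005.jpg", "n0153282900000006.jpg"],
    ["n0211335900000010.jpg", "n0211335900000011.jpg"],
]
classes_dict = {"n01532829": 0, "n02113359": 1}

paths, labels = flatten_episode(episode, "images", classes_dict)
```

Because paths and labels are built in the same iteration order, `labels[i]` always corresponds to `paths[i]`, which is the property the original method relies on when it fills `support_set_x[i]` from the i-th path.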