miniImagenetOneShot.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目: MatchingNetworks 作者: gitabcworld 项目源码 文件源码
def __init__(self, dataroot = '/home/aberenguel/Dataset/miniImagenet', type = 'train',
             nEpisodes = 1000, classes_per_set=10, samples_per_class=1):
    """Load one miniImagenet split from CSV and create its episodes.

    Args:
        dataroot: Directory containing an ``images`` subdirectory and the
            split CSV files.
        type: Split name; the file read is ``<dataroot>/<type>.csv``.
            (Shadows the builtin ``type``; the name is kept so existing
            keyword callers keep working.)
        nEpisodes: Number of episodes, forwarded to ``self.create_episodes``.
        classes_per_set: Stored as ``self.classes_per_set``.
        samples_per_class: Stored as ``self.samples_per_class``.
    """
    self.nEpisodes = nEpisodes
    self.classes_per_set = classes_per_set
    self.samples_per_class = samples_per_class
    self.n_samples = self.samples_per_class * self.classes_per_set
    # Samples per meta-test.
    # NOTE(review): an earlier comment said this is 1 "as is OneShot", but
    # the value is 5 -- confirm which one create_episodes expects.
    self.n_samplesNShot = 5
    # Image transformation pipeline. filenameToPILImage and PiLImageResize are
    # defined elsewhere in this module; presumably they load a path as a PIL
    # image and resize it -- verify against their definitions.
    self.transform = transforms.Compose([filenameToPILImage,
                                         PiLImageResize,
                                         transforms.ToTensor()
                                         ])

    def loadSplit(splitFile):
        """Group the filenames of a split CSV by label.

        The first row is treated as a header and skipped. Every other
        non-empty row appends ``row[0]`` (filename) to the list for
        ``row[1]`` (label), preserving file order within each label.
        """
        dictLabels = collections.defaultdict(list)
        with open(splitFile) as csvfile:
            csvreader = csv.reader(csvfile, delimiter=',')
            next(csvreader, None)  # skip the header row
            for row in csvreader:
                if not row:
                    # Blank lines come back as [] and would raise IndexError.
                    continue
                filename, label = row[0], row[1]
                dictLabels[label].append(filename)
        return dictLabels

    self.miniImagenetImagesDir = os.path.join(dataroot, 'images')
    self.data = loadSplit(splitFile=os.path.join(dataroot, type + '.csv'))
    # Sort by label so the ordering (and hence class indices) is deterministic.
    self.data = collections.OrderedDict(sorted(self.data.items()))
    # Map each label to its position in the sorted order. Iterating the
    # OrderedDict directly avoids indexing keys(), which is a non-indexable
    # view on Python 3.
    self.classes_dict = {label: index for index, label in enumerate(self.data)}
    self.create_episodes(self.nEpisodes)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号