def __init__(self, dataroot='/home/aberenguel/Dataset/miniImagenet', type='train',
             nEpisodes=1000, classes_per_set=10, samples_per_class=1):
    """Load one miniImagenet split and build its episodes.

    Args:
        dataroot: Dataset root directory. Images are looked up under
            ``<dataroot>/images`` and the split listing is read from
            ``<dataroot>/<type>.csv``.
        type: Split name selecting which ``<type>.csv`` file is read
            (default ``'train'``). Shadows the builtin ``type`` inside this
            method; the name is kept so keyword callers keep working.
        nEpisodes: Number of episodes; forwarded to ``self.create_episodes``.
        classes_per_set: Classes per set; multiplied by ``samples_per_class``
            to give ``self.n_samples``.
        samples_per_class: Samples per class in each set.
    """
    self.nEpisodes = nEpisodes
    self.classes_per_set = classes_per_set
    self.samples_per_class = samples_per_class
    self.n_samples = self.samples_per_class * self.classes_per_set
    # Samples per meta-test. NOTE(review): an earlier comment said this is 1
    # "as is OneShot", but the value is 5 -- confirm which is intended.
    self.n_samplesNShot = 5
    # Image pipeline applied to each sample: filenameToPILImage, then
    # PiLImageResize, then conversion to a tensor. (Both helpers are defined
    # elsewhere in the module; presumably they load and resize the image --
    # TODO confirm.)
    self.transform = transforms.Compose([filenameToPILImage,
                                         PiLImageResize,
                                         transforms.ToTensor()
                                         ])

    def loadSplit(splitFile):
        """Read a split CSV into a ``{label: [filename, ...]}`` dict.

        The first row is treated as a header and skipped. Column 0 of each
        remaining row is the image filename and column 1 its label. Blank
        rows (which ``csv.reader`` yields as empty lists) are ignored instead
        of raising ``IndexError``.
        """
        dictLabels = {}
        with open(splitFile) as csvfile:
            csvreader = csv.reader(csvfile, delimiter=',')
            next(csvreader, None)  # skip the header row
            for row in csvreader:
                if not row:
                    continue
                filename, label = row[0], row[1]
                dictLabels.setdefault(label, []).append(filename)
        return dictLabels

    self.miniImagenetImagesDir = os.path.join(dataroot, 'images')
    self.data = loadSplit(splitFile=os.path.join(dataroot, type + '.csv'))
    # Sort classes by label so the label -> index mapping below is deterministic.
    self.data = collections.OrderedDict(sorted(self.data.items()))
    # Assign each label a contiguous integer index in sorted order. Iterating
    # the dict (rather than indexing self.data.keys()[i]) also works on
    # Python 3, where dict key views are not subscriptable.
    self.classes_dict = {label: idx for idx, label in enumerate(self.data)}
    self.create_episodes(self.nEpisodes)
miniImagenetOneShot.py 文件源码
python
阅读 20
收藏 0
点赞 0
评论 0
评论列表
文章目录