def __getitem__(self, index):
    """Build one few-shot episode (support set + target set) for ``index``.

    The file names for the episode come from ``self.support_set_x_batch[index]``
    and ``self.target_x_batch[index]``. Each is a list of lists of file names,
    flattened in order. Each name is joined onto ``self.miniImagenetImagesDir``,
    and the resulting path is passed to ``self.transform``. Its label is
    ``self.classes_dict[name[:9]]``. The 9-character prefix is presumably the
    WordNet id of miniImageNet file names — TODO confirm.

    Args:
        index: Position of the episode in the pre-sampled batches.

    Returns:
        A tuple ``(support_set_x, support_set_y, target_x, target_y)``.

        - ``support_set_x`` / ``target_x``: float32 tensors of shape
          ``(self.n_samples, 3, 84, 84)`` and
          ``(self.n_samples_eval, 3, 84, 84)``. They are zero-filled when
          ``self.transform`` is None or when fewer names than rows are given.
        - ``support_set_y`` / ``target_y``: 1-D ``torch.int32`` label
          tensors, one entry per flattened file name.

    Raises:
        KeyError: If a name's 9-character prefix is missing from
            ``self.classes_dict``.
        IndexError: If an episode holds more names than the tensor has rows.
    """
    apply_transform = self.transform is not None  # loop-invariant, hoisted

    def _load_split(episode, n_rows):
        # Shared by the support and target splits: returns (images, labels).
        names = [name for sublist in episode for name in sublist]
        # Resolve labels first so a bad class id fails before any image I/O.
        labels = torch.tensor(
            [self.classes_dict[name[:9]] for name in names], dtype=torch.int32
        )
        # Zero-initialised (the legacy torch.FloatTensor(...) constructor
        # returned uninitialised memory), so unfilled rows are deterministic.
        images = torch.zeros(n_rows, 3, 84, 84, dtype=torch.float32)
        if apply_transform:
            for row, name in enumerate(names):
                # NOTE(review): assumes transform takes a file path and returns
                # a tensor broadcastable to (3, 84, 84) — confirm with caller.
                images[row] = self.transform(
                    os.path.join(self.miniImagenetImagesDir, name)
                )
        return images, labels

    support_set_x, support_set_y = _load_split(
        self.support_set_x_batch[index], self.n_samples
    )
    target_x, target_y = _load_split(
        self.target_x_batch[index], self.n_samples_eval
    )
    return support_set_x, support_set_y, target_x, target_y
评论列表
文章目录