def build_experiment(self, batch_size, classes_per_set, samples_per_class, channels, fce):
    """
    Create the matching network and record the training hyper-parameters.

    All results are stored as attributes on ``self`` (``matchingNet``,
    ``keep_prob``, ``optimizer``, ``lr``, ``current_lr``, ``lr_decay``, ``wd``,
    ``total_train_iter``, ``isCudaAvailable``); nothing is returned.

    :param batch_size: The experiment batch size
    :param classes_per_set: An integer indicating the number of classes per support set
    :param samples_per_class: An integer indicating the number of samples per class
    :param channels: The image channels
    :param fce: Whether to use full context embeddings or not
    :return: None
    """
    self.classes_per_set = classes_per_set
    self.samples_per_class = samples_per_class

    # ``torch.FloatTensor(1)`` allocates one *uninitialized* float, so the
    # value handed to the network used to be whatever was in memory. Build an
    # explicitly initialized float32 tensor of the same shape instead.
    # NOTE(review): 1.0 (i.e. keep everything) is assumed to be the intended
    # default -- confirm against how MatchingNetwork consumes keep_prob.
    self.keep_prob = torch.ones(1, dtype=torch.float32)

    # nClasses and image_size are fixed here rather than taken from arguments.
    self.matchingNet = MatchingNetwork(
        batch_size=batch_size,
        keep_prob=self.keep_prob,
        num_channels=channels,
        fce=fce,
        num_classes_per_set=classes_per_set,
        num_samples_per_class=samples_per_class,
        nClasses=0,
        image_size=28,
    )

    # Optimisation settings; they are only recorded here, not applied.
    self.optimizer = 'adam'
    self.lr = 1e-03
    # Starts equal to the base learning rate.
    self.current_lr = self.lr
    self.lr_decay = 1e-6
    self.wd = 1e-4  # presumably weight decay -- verify against the optimizer setup
    self.total_train_iter = 0

    self.isCudaAvailable = torch.cuda.is_available()
    if self.isCudaAvailable:
        cudnn.benchmark = True
        # Fixed seed on every GPU, set before the model is moved to the device.
        torch.cuda.manual_seed_all(0)
        self.matchingNet.cuda()
评论列表
文章目录