OneShotBuilder.py source code

python

Project: MatchingNetworks  Author: gitabcworld
import torch
import torch.backends.cudnn as cudnn

# MatchingNetwork is defined elsewhere in the MatchingNetworks project;
# its import is not part of this excerpt.


def build_experiment(self, batch_size, classes_per_set, samples_per_class, channels, fce):
    """
    Build the matching network and store the training settings on the builder.

    :param batch_size: The experiment batch size
    :param classes_per_set: An integer indicating the number of classes per support set
    :param samples_per_class: An integer indicating the number of samples per class
    :param channels: The number of image channels
    :param fce: Whether to use full context embeddings or not
    :return: None; the network and the settings are stored as attributes on self
    """
    self.classes_per_set = classes_per_set
    self.samples_per_class = samples_per_class
    # Note: torch.FloatTensor(1) allocates a 1-element tensor with
    # uninitialized contents; it does not hold the value 1.0.
    self.keep_prob = torch.FloatTensor(1)
    self.matchingNet = MatchingNetwork(batch_size=batch_size,
                                       keep_prob=self.keep_prob,
                                       num_channels=channels,
                                       fce=fce,
                                       num_classes_per_set=classes_per_set,
                                       num_samples_per_class=samples_per_class,
                                       nClasses=0, image_size=28)
    # Optimizer settings
    self.optimizer = 'adam'
    self.lr = 1e-03
    self.current_lr = 1e-03
    self.lr_decay = 1e-6
    self.wd = 1e-4  # presumably the weight decay coefficient
    self.total_train_iter = 0
    # Move the model to the GPU when CUDA is available
    self.isCudaAvailable = torch.cuda.is_available()
    if self.isCudaAvailable:
        cudnn.benchmark = True  # let cuDNN autotune its convolution algorithms
        torch.cuda.manual_seed_all(0)  # fixed seed on all GPUs
        self.matchingNet.cuda()
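
The method stores lr, current_lr, lr_decay and total_train_iter, but the excerpt does not show how they are combined during training. The sketch below illustrates one common convention, inverse-time decay, where the rate at a given step is lr / (1 + step * lr_decay). The schedule and the helper name decayed_lr are assumptions made for illustration; the excerpt does not confirm that the project uses them.

```python
def decayed_lr(base_lr: float, lr_decay: float, step: int) -> float:
    """Inverse-time decay: base_lr / (1 + step * lr_decay).

    Hypothetical helper; the schedule is an assumption, not taken
    from the excerpt.
    """
    if step < 0:
        raise ValueError("step must be non-negative")
    return base_lr / (1.0 + step * lr_decay)


# Values copied from build_experiment.
lr = 1e-03
lr_decay = 1e-6

# Print the learning rate at a few training steps.
for step in (0, 1_000, 100_000, 1_000_000):
    print(f"step={step:>9,d}  lr={decayed_lr(lr, lr_decay, step):.6e}")
```

With lr_decay = 1e-6 the rate changes slowly under this schedule: it takes one million steps to halve the initial value of 1e-03.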