train.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:ActiveBoundary 作者: MiriamHu 项目源码 文件源码
def init(self, opt, X, y, y_groundtruth, X_val, y_val, unlabeled_class=-5, ali_model=None):
        dataset = Dataset(X, y, y_groundtruth, X_val, y_val, unlabeled_class=unlabeled_class,
                          al_batch_size=opt.al_batch_size, save_path_db_points=opt.save_path,
                          dataset=opt.hdf5_dataset_encoded)
        print "All samples: ", len(dataset)
        print "Labeled samples: ", dataset.len_labeled()
        print "Unlabeled samples: ", dataset.len_unlabeled()

        print "Initializing model"
        model = JointOptimisationSVM(initial_model=self.initial_model,
                                     hyperparameters=opt.hyperparameters,
                                     save_path_boundaries=opt.save_path)  # declare model instance
        print "Done declaring model"
        print "Initializing query strategy", opt.query_strategy
        if opt.query_strategy == "uncertainty":
            query_strategy = UncertaintySamplingLine(dataset, model=model, generative_model=ali_model,
                                                     save_path_queries=opt.save_path,
                                                     human_experiment=self.human_experiment,
                                                     base_precision=opt.base_precision)  # declare a QueryStrategy instance
        elif opt.query_strategy == "uncertainty-dense":
            query_strategy = UncertaintyDenseSamplingLine(dataset, model=model, generative_model=ali_model,
                                                          save_path_queries=opt.save_path,
                                                          human_experiment=self.human_experiment,
                                                          base_precision=opt.base_precision)
        elif opt.query_strategy == "clustercentroids":
            query_strategy = ClusterCentroidsLine(dataset, model=model, generative_model=ali_model,
                                                  batch_size=opt.al_batch_size, save_path_queries=opt.save_path,
                                                  human_experiment=self.human_experiment,
                                                  base_precision=opt.base_precision)
        elif opt.query_strategy == "random":
            query_strategy = RandomSamplingLine(dataset, model=model, generative_model=ali_model,
                                                save_path_queries=opt.save_path,
                                                human_experiment=self.human_experiment,
                                                base_precision=opt.base_precision)
        else:
            raise Exception("Please specify a query strategy")
        print "Done declaring query strategy", opt.query_strategy

        if opt.oracle_type == "noisy_line_labeler":
            labeler = NoisyLineLabeler(dataset, opt.std_noise, pretrained_groundtruth=self.groundtruth_model,
                                       hyperparameters=opt.hyperparameters)
            print "Done declaring NoisyLineLabeler"
        elif opt.oracle_type == "human_line_labeler":
            labeler = HumanLineLabeler(dataset, ali_model, hyperparameters=opt.hyperparameters)
        else:
            labeler = LineLabeler(dataset, pretrained_groundtruth=self.groundtruth_model,
                                  hyperparameters=opt.hyperparameters)  # declare Labeler instance
            print "Done declaring LineLabeler"
        print "Done initializing"
        return dataset, model, query_strategy, labeler
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号