def init(self, opt, X, y, y_groundtruth, X_val, y_val, unlabeled_class=-5, ali_model=None):
    """Build the dataset, model, query strategy and labeler described by ``opt``.

    Args:
        opt: Options object. Attributes read here: ``al_batch_size``,
            ``save_path``, ``hdf5_dataset_encoded``, ``hyperparameters``,
            ``query_strategy``, ``base_precision``, ``oracle_type`` and, for
            the noisy oracle only, ``std_noise``.
        X, y, y_groundtruth, X_val, y_val: Forwarded unchanged to ``Dataset``.
        unlabeled_class: Forwarded to ``Dataset`` (default -5).
        ali_model: Passed to the query strategy as ``generative_model`` and
            to ``HumanLineLabeler`` when the human oracle is selected.

    Also reads ``self.initial_model``, ``self.human_experiment`` and
    ``self.groundtruth_model``.

    Returns:
        Tuple ``(dataset, model, query_strategy, labeler)``.

    Raises:
        ValueError: If ``opt.query_strategy`` is not one of "uncertainty",
            "uncertainty-dense", "clustercentroids" or "random".
    """

    def _report(*parts):
        # A single pre-joined, parenthesised argument prints the same text
        # under the Python 2 print statement and the Python 3 print function;
        # the single-space separator mirrors ``print a, b``.
        print(" ".join("%s" % (part,) for part in parts))

    dataset = Dataset(X, y, y_groundtruth, X_val, y_val,
                      unlabeled_class=unlabeled_class,
                      al_batch_size=opt.al_batch_size,
                      save_path_db_points=opt.save_path,
                      dataset=opt.hdf5_dataset_encoded)
    _report("All samples: ", len(dataset))
    _report("Labeled samples: ", dataset.len_labeled())
    _report("Unlabeled samples: ", dataset.len_unlabeled())

    _report("Initializing model")
    model = JointOptimisationSVM(initial_model=self.initial_model,
                                 hyperparameters=opt.hyperparameters,
                                 save_path_boundaries=opt.save_path)
    _report("Done declaring model")

    _report("Initializing query strategy", opt.query_strategy)
    # Pick the strategy class first so an unknown name is rejected before any
    # of the strategy's constructor arguments are read.
    extra_kwargs = {}
    if opt.query_strategy == "uncertainty":
        strategy_cls = UncertaintySamplingLine
    elif opt.query_strategy == "uncertainty-dense":
        strategy_cls = UncertaintyDenseSamplingLine
    elif opt.query_strategy == "clustercentroids":
        strategy_cls = ClusterCentroidsLine
        # Only the cluster-centroid strategy is given a batch size.
        extra_kwargs["batch_size"] = opt.al_batch_size
    elif opt.query_strategy == "random":
        strategy_cls = RandomSamplingLine
    else:
        # ValueError is still caught by callers that handle ``Exception``.
        raise ValueError(
            "Please specify a query strategy (got %r; expected one of "
            "'uncertainty', 'uncertainty-dense', 'clustercentroids', 'random')"
            % (opt.query_strategy,))
    query_strategy = strategy_cls(dataset, model=model, generative_model=ali_model,
                                  save_path_queries=opt.save_path,
                                  human_experiment=self.human_experiment,
                                  base_precision=opt.base_precision,
                                  **extra_kwargs)
    _report("Done declaring query strategy", opt.query_strategy)

    if opt.oracle_type == "noisy_line_labeler":
        labeler = NoisyLineLabeler(dataset, opt.std_noise,
                                   pretrained_groundtruth=self.groundtruth_model,
                                   hyperparameters=opt.hyperparameters)
        _report("Done declaring NoisyLineLabeler")
    elif opt.oracle_type == "human_line_labeler":
        labeler = HumanLineLabeler(dataset, ali_model, hyperparameters=opt.hyperparameters)
    else:
        # Any other oracle type falls back to the plain LineLabeler.
        labeler = LineLabeler(dataset, pretrained_groundtruth=self.groundtruth_model,
                              hyperparameters=opt.hyperparameters)
        _report("Done declaring LineLabeler")
    _report("Done initializing")
    return dataset, model, query_strategy, labeler
评论列表
文章目录