def repeat_expt(smplr, n_expts, n_labels, output_file=None):
    """Run a sampler repeatedly and store the results in an HDF5 file.

    Each experiment resets the sampler, asks it to query ``n_labels``
    distinct labels via ``sample_distinct``, and records the sampler's
    estimates, its iteration counter and the CPU time spent.

    Parameters
    ----------
    smplr : sub-class of PassiveSampler
        Sampler instance. It must provide ``reset()``,
        ``sample_distinct(n)`` and the attributes ``_max_iter``,
        ``_n_class``, ``estimate_`` and ``t_``. If it also exposes
        ``queried_oracle_``, it is used to index ``estimate_`` and only
        the selected rows are stored (presumably the iterations at which
        the oracle was queried -- confirm against the sampler class).
    n_expts : int
        Number of experiments to run.
    n_labels : int
        Number of labels to query from the oracle in each experiment.
    output_file : str, optional
        Path of the HDF5 file to create; an existing file at that path is
        overwritten (``mode='w'``). If None, a name of the form
        ``expt_<dd-mm-YYYY_HH:MM:SS>.h5`` in the current directory is used.

    Raises
    ------
    ValueError
        If ``n_labels`` exceeds the sampler's ``_max_iter``.

    Notes
    -----
    The output file contains three zlib-compressed arrays at its root:
    ``F_measure`` with shape (n_expts, n_labels, n_class),
    ``n_iterations`` with shape (n_expts, 1) and
    ``CPU_time`` with shape (n_expts, 1).
    """
    filters = tables.Filters(complib='zlib', complevel=5)
    max_iter = smplr._max_iter
    n_class = smplr._n_class
    if max_iter < n_labels:
        raise ValueError("Cannot query {} labels. Sampler ".format(n_labels) +
                         "instance supports only {} iterations".format(max_iter))
    if output_file is None:
        # Use current date/time as filename.
        # NOTE(review): ':' is not a legal filename character on Windows;
        # consider another time separator if portability matters.
        output_file = 'expt_' + time.strftime("%d-%m-%Y_%H:%M:%S") + '.h5'
    logging.info("Writing output to {}".format(output_file))

    # Log progress roughly every 10% of the experiments. Integer ceil
    # division, computed once and clamped so the modulus is never zero.
    log_every = max(1, -(-n_expts // 10))

    # The context manager guarantees the HDF5 file is closed (and flushed)
    # even if the sampler raises part-way through the experiments.
    with tables.open_file(output_file, mode='w', filters=filters) as f:
        float_atom = tables.Float64Atom()
        int_atom = tables.Int64Atom()
        array_F = f.create_carray(f.root, 'F_measure', float_atom,
                                  (n_expts, n_labels, n_class))
        array_s = f.create_carray(f.root, 'n_iterations', int_atom,
                                  (n_expts, 1))
        array_t = f.create_carray(f.root, 'CPU_time', float_atom,
                                  (n_expts, 1))
        logging.info("Starting {} experiments".format(n_expts))
        for i in range(n_expts):
            if i % log_every == 0:
                logging.info("Completed {} of {} experiments".format(i, n_expts))
            # Time only the sampler's own work (reset + sampling).
            ti = time.process_time()
            smplr.reset()
            smplr.sample_distinct(n_labels)
            tf = time.process_time()
            if hasattr(smplr, 'queried_oracle_'):
                # Keep only the estimate rows selected by queried_oracle_
                # so the result fits the (n_labels, n_class) slot.
                array_F[i, :, :] = smplr.estimate_[smplr.queried_oracle_]
            else:
                array_F[i, :, :] = smplr.estimate_
            array_s[i] = smplr.t_
            array_t[i] = tf - ti
    logging.info("Completed all experiments")
评论列表
文章目录