experiments.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:oasis 作者: ngmarchant 项目源码 文件源码
def repeat_expt(smplr, n_expts, n_labels, output_file=None):
    """Run a sampler repeatedly and store per-experiment results in HDF5.

    Each experiment resets the sampler, calls ``sample_distinct(n_labels)``
    and records the sampler's estimate history, its ``t_`` counter and the
    CPU time the experiment took.

    Parameters
    ----------
    smplr : sub-class of PassiveSampler
        sampler must have ``sample_distinct`` and ``reset`` methods and the
        attributes ``_max_iter``, ``_n_class``, ``estimate_`` and ``t_``.
        ``queried_oracle_`` is used to select rows of ``estimate_`` when
        the sampler has it.

    n_expts : int
        number of expts to run

    n_labels : int
        number of labels to query from the oracle in each expt

    output_file : str, optional
        path of the HDF5 file to write. Defaults to
        ``expt_<dd-mm-YYYY_HH:MM:SS>.h5`` in the current directory. The file
        is opened with mode ``'w'``, so an existing file is overwritten.

    Raises
    ------
    ValueError
        if ``n_labels`` exceeds the sampler's ``_max_iter``.

    Notes
    -----
    The output file holds three zlib-compressed arrays under its root group:
    ``F_measure`` with shape (n_expts, n_labels, n_class), ``n_iterations``
    with shape (n_expts, 1) and ``CPU_time`` with shape (n_expts, 1).
    """
    max_iter = smplr._max_iter
    n_class = smplr._n_class
    # Validate before creating any output so a bad request leaves no file.
    if max_iter < n_labels:
        raise ValueError("Cannot query {} labels. Sampler ".format(n_labels) +
                         "instance supports only {} iterations".format(max_iter))

    if output_file is None:
        # Use current date/time as filename.
        # NOTE(review): ':' is not a legal filename character on Windows —
        # confirm whether this default needs to be portable.
        output_file = 'expt_' + time.strftime("%d-%m-%Y_%H:%M:%S") + '.h5'
    logging.info("Writing output to %s", output_file)

    filters = tables.Filters(complib='zlib', complevel=5)
    float_atom = tables.Float64Atom()
    int_atom = tables.Int64Atom()

    # Log progress roughly every 10% of the run: integer ceiling of
    # n_expts / 10, which is >= 1 whenever the loop body executes.
    log_every = (n_expts + 9) // 10

    # The context manager guarantees the file is flushed and closed even if
    # the sampler raises part-way through an experiment.
    with tables.open_file(output_file, mode='w', filters=filters) as f:
        array_F = f.create_carray(f.root, 'F_measure', float_atom,
                                  (n_expts, n_labels, n_class))
        array_s = f.create_carray(f.root, 'n_iterations', int_atom,
                                  (n_expts, 1))
        array_t = f.create_carray(f.root, 'CPU_time', float_atom,
                                  (n_expts, 1))

        logging.info("Starting %s experiments", n_expts)
        for i in range(n_expts):
            if i % log_every == 0:
                logging.info("Completed %s of %s experiments", i, n_expts)
            # process_time() counts CPU time of this process only, so time
            # spent sleeping or blocked is excluded from the measurement.
            ti = time.process_time()
            smplr.reset()
            smplr.sample_distinct(n_labels)
            tf = time.process_time()
            if hasattr(smplr, 'queried_oracle_'):
                # Keep only rows of estimate_ selected by queried_oracle_.
                # NOTE(review): assumes this selection yields exactly
                # n_labels rows of n_class values — confirm with the sampler.
                array_F[i, :, :] = smplr.estimate_[smplr.queried_oracle_]
            else:
                # NOTE(review): assumes estimate_ already has shape
                # (n_labels, n_class) for samplers without queried_oracle_.
                array_F[i, :, :] = smplr.estimate_
            array_s[i] = smplr.t_
            array_t[i] = tf - ti

    logging.info("Completed all experiments")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号