main.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:mitre 作者: gerberlab 项目源码 文件源码
def sample(config, model=None):
    """Create a sampler and sample according to options in the configuration file.

    If section 'sampling' has an option 'load_model_from_pickle', its value
    is taken as the path to a pickle file containing a single
    LogisticRuleModel object. That model is loaded and the ``model``
    argument is ignored. If ``model`` is None and that option is absent,
    a ValueError is raised.

    The sampler starts from a rule list holding one rule: the entry of
    ``model.rule_population.flat_rules`` whose corresponding entry in
    ``flat_truth`` has the smallest Hamming distance to ``model.data.y``.

    Sampling runs for 'sampling_time' seconds if that option is set.
    Otherwise it draws 'total_samples' samples. If neither option is set,
    a ValueError is raised. If 'pickle_sampler' is present and true, the
    sampler is pickled to '<description.tag>_sampler_object.pickle'.

    Parameters
    ----------
    config : ConfigParser-like object
        Configuration supplying the 'sampling' (and optionally
        'description') sections.
    model : optional
        Model to sample from. It is used only when the config does not name
        a pickled model.

    Returns
    -------
    The LogisticRuleSampler after sampling has finished.

    Raises
    ------
    ValueError
        If no model is available, or neither 'sampling_time' nor
        'total_samples' is configured.
    """
    if config.has_option('sampling', 'load_model_from_pickle'):
        # Pickle data is binary. Open in 'rb' so loading works on Python 3
        # and is not corrupted by newline translation on Windows.
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # this option at trusted files.
        with open(config.get('sampling', 'load_model_from_pickle'), 'rb') as f:
            model = pickle.load(f)

    if model is None:
        raise ValueError('Model must be passed as argument if not specified in config file.')

    # Seed the sampler with the single rule whose truth entry is closest
    # (by Hamming distance) to the observed outcomes. flat_rules and
    # flat_truth are indexed in parallel here.
    population = model.rule_population
    distances = [hamming(model.data.y, t) for t in population.flat_truth]
    best_rule = population.flat_rules[np.argmin(distances)]
    initial_rule_list = rules.RuleList([[best_rule]])
    sampler = logit_rules.LogisticRuleSampler(model, initial_rule_list)

    if config.has_option('sampling', 'sampling_time'):
        sampling_time = config.getfloat('sampling', 'sampling_time')
        logger.info('Starting sampling: will continue for %.1f seconds',
                    sampling_time)
        sampler.sample_for(sampling_time)
    elif config.has_option('sampling', 'total_samples'):
        total_samples = config.getint('sampling', 'total_samples')
        logger.info('Starting to draw %d samples', total_samples)
        sampler.sample(total_samples)
    else:
        raise ValueError('Either number of samples or sampling time must be specified.')

    if (config.has_option('sampling', 'pickle_sampler')
            and config.getboolean('sampling', 'pickle_sampler')):
        # Read description.tag only when actually writing, so configs that
        # disable pickling do not need to define it.
        filename = config.get('description', 'tag') + '_sampler_object.pickle'
        # Binary mode is required for pickle.dump on Python 3.
        with open(filename, 'wb') as f:
            pickle.dump(sampler, f)
        logger.info('Sampler written to %s', filename)

    return sampler
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号