evaluate_new.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:motion-classification 作者: matthiasplappert 项目源码 文件源码
def evaluate_fhmms(dataset, iterator, args):
    """Evaluate models with 1, 2, 3 and 4 chains and report the best one.

    For every chain count a model is trained and scored through
    ``_compute_averaged_pos_and_neg_lls`` and ``_compute_measure``. A chain
    count of 1 is evaluated as a plain ``'hmm'`` with the ``'exact'``
    loglikelihood method; all other counts use ``'fhmm-seq'``. A
    configuration whose evaluation raises is recorded with a measure of NaN
    instead of aborting the whole run.

    Args:
        dataset: Dataset to evaluate on. If ``args.features`` is set and
            differs from ``dataset.feature_names``, a reduced dataset is
            derived via ``dataset.dataset_from_feature_names``.
        iterator: Passed through unchanged to
            ``_compute_averaged_pos_and_neg_lls``.
        args: Configuration namespace. Reads ``features``,
            ``loglikelihood_method`` and ``output_dir``. ``n_chains`` and
            ``model`` are overwritten for every configuration and keep the
            values of the last one; ``loglikelihood_method`` is restored
            after each configuration.

    Side effects:
        Prints progress to stdout and, when ``args.output_dir`` is not None,
        writes ``_results.csv`` (``;``-delimited) into that directory. The
        row of the best configuration is marked with ``*``; no row is marked
        if no configuration produced a measure.
    """
    # Select features
    if args.features is not None and args.features != dataset.feature_names:
        print('selecting features ...')
        features = _explode_features(args.features)
        start = timeit.default_timer()
        dataset = dataset.dataset_from_feature_names(features)
        print('done, took %fs' % (timeit.default_timer() - start))
    print('')

    chains = [1, 2, 3, 4]
    total_steps = len(chains)

    measures = []
    for curr_step, chain in enumerate(chains, 1):
        prefix = '%.3d_%d-chains' % (curr_step, chain)
        print('(%.3d/%.3d) evaluating n_chains=%d ...' % (curr_step, total_steps, chain))
        start = timeit.default_timer()
        old_loglikelihood_method = args.loglikelihood_method
        try:
            # Configure args from which the HMMs are created
            args.n_chains = chain
            if chain == 1:
                args.model = 'hmm'
                args.loglikelihood_method = 'exact'  # there's no approx loglikelihood method for HMMs
            else:
                args.model = 'fhmm-seq'

            ll_stats = _compute_averaged_pos_and_neg_lls(dataset, iterator, prefix, args, save_model=True, compute_distances=False)
            measure = _compute_measure(ll_stats, dataset, args)
        except Exception as e:
            # Best effort: one failing configuration must not abort the
            # remaining ones. KeyboardInterrupt/SystemExit still propagate.
            print('evaluation failed: %s' % (e,))
            measure = np.nan
        finally:
            # Undo the temporary override made for chain == 1, even when an
            # exception propagates.
            args.loglikelihood_method = old_loglikelihood_method
        if np.isnan(measure):
            print('measure: not computable')
        else:
            print('measure: %f' % measure)
        measures.append(measure)
        print('done, took %fs' % (timeit.default_timer() - start))
        print('')

    if np.all(np.isnan(measures)):
        # np.nanargmax cannot pick a winner from an all-NaN input.
        best_idx = None
        print('no model could be evaluated')
    else:
        best_idx = np.nanargmax(np.array(measures))  # get the argmax ignoring NaNs
        print('best model with score %f: %d chains' % (measures[best_idx], chains[best_idx]))
    print('detailed reports have been saved')

    # Save results
    assert len(chains) == len(measures)
    if args.output_dir is not None:
        filename = '_results.csv'
        # NOTE(review): binary mode is what Python 2's csv module expects; on
        # Python 3 this must become open(..., 'w', newline='') - confirm the
        # target interpreter before changing it.
        with open(os.path.join(args.output_dir, filename), 'wb') as f:
            writer = csv.writer(f, delimiter=';')
            writer.writerow(['', 'idx', 'measure', 'chains'])
            for idx, (measure, chain) in enumerate(zip(measures, chains)):
                selected = '*' if best_idx == idx else ''
                writer.writerow([selected, '%d' % idx, '%f' % measure, '%d' % chain])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号