models.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目: diagnose-heart　作者: woshialex　项目源码　文件源码
def _ensemble_weights(params_files, weight_full_params):
    """Return one averaging weight per parameter file, in the same order.

    Files whose name contains 'f-1' share ``weight_full_params`` equally and
    the remaining files share ``1 - weight_full_params`` equally. When the
    ensemble contains only one of the two kinds of file, every member gets
    the same weight so the weights still sum to 1.
    """
    n_total = len(params_files)
    is_full = ['f-1' in name for name in params_files]
    n_full = sum(is_full)
    n_norm = n_total - n_full
    if n_full == 0 or n_norm == 0:
        # Nothing to split between two groups: plain uniform average.
        # (An all-'f-1' ensemble would otherwise divide by zero below.)
        return [1. / n_total] * n_total
    wt_norm = (1. - weight_full_params) / n_norm
    # float() avoids Python 2 floor division when an int weight is passed.
    wt_full = float(weight_full_params) / n_full
    return [wt_full if full else wt_norm for full in is_full]


def get_segmenter_function(params_loc, img_size, ensemble=False, version=2,
        param_file_key='.npz', weight_full_params=0.33):
    """Build a segmentation network, load its weights and compile it.

    Args:
        params_loc: when ``ensemble`` is False, the path of a single
            parameter file; otherwise a directory whose matching files are
            each loaded into their own copy of the network.
        img_size: spatial size of the input; the network is built for input
            shape ``(None, 1, img_size, img_size)``.
        ensemble: when True, return a weighted average of the deterministic
            outputs of one network per matching file in ``params_loc``.
        version: network version passed to ``build_fcn_segmenter``; in
            ensemble mode only files whose name contains ``'v<version>'``
            are used.
        param_file_key: extra substring a file name must contain (ensemble
            mode only); ``None`` disables this filter.
        weight_full_params: total weight given to files whose name contains
            ``'f-1'`` when the ensemble mixes them with other files.

    Returns:
        A compiled ``theano.function`` mapping the input tensor4 to the
        (possibly weighted-averaged) deterministic network output.

    Raises:
        ValueError: in ensemble mode, if no file in ``params_loc`` matches.
    """
    shape = (None, 1, img_size, img_size)
    input_var = T.tensor4('input')
    if not ensemble:
        net, _, output_det = build_fcn_segmenter(input_var, shape, version)
        u.load_params(net['output'], params_loc)
        print('loaded indiv function {}'.format(params_loc))
        return theano.function([input_var], output_det)

    version_tag = 'v{}'.format(version)
    # sorted() makes the member order (and the log output) independent of
    # the arbitrary order returned by os.listdir.
    params_files = [name for name in sorted(os.listdir(params_loc))
                    if version_tag in name
                    and (param_file_key is None or param_file_key in name)]
    if not params_files:
        raise ValueError(
            'no parameter files matching {!r} / {!r} in {}'.format(
                version_tag, param_file_key, params_loc))

    weights = _ensemble_weights(params_files, weight_full_params)
    expr = 0
    for pfn, w in zip(params_files, weights):
        # Each member gets its own network instance sharing the same input.
        net, _, output_det = build_fcn_segmenter(input_var, shape, version)
        u.load_params(net['output'], os.path.join(params_loc, pfn))
        expr = expr + w * output_det
        print('loaded {} wt {}'.format(pfn, w))
    print('loaded {} in ensemble'.format(len(params_files)))
    return theano.function([input_var], expr)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号