models.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:diagnose-heart 作者: woshialex 项目源码 文件源码
def get_segmenter_function(params_loc, img_size, NCV=1, version=1,
        param_file_key = None):
    """Compile a Theano function that runs the FCN segmenter on an input batch.

    With ``NCV > 1`` an ensemble is built: every file in the directory
    ``params_loc`` whose name contains ``'fcn_v<version>'`` (and
    ``param_file_key``, when given) is loaded into its own network, and the
    networks' ``output_det`` expressions are combined into a weighted average.
    Otherwise a single network is built and ``params_loc`` is handed directly
    to ``u.load_params``.

    Args:
        params_loc: Directory to list for parameter files when ``NCV > 1``;
            otherwise the value passed straight to ``u.load_params``
            (presumably the path of one parameter file -- confirm with callers).
        img_size: Height and width used for the last two dimensions of the
            network input shape ``(None, 1, img_size, img_size)``.
        NCV: Ensemble size parameter; exactly ``NCV + 1`` matching parameter
            files are required. Presumably the number of cross-validation
            folds -- TODO confirm.
        version: Model version; selects file names containing
            ``'fcn_v<version>'`` and is forwarded to ``build_fcn_segmenter``.
        param_file_key: Optional substring that parameter file names must also
            contain (ensemble mode only).

    Returns:
        The compiled ``theano.function`` taking the 4D ``input`` tensor and
        returning the (ensembled) segmenter output expression.

    Raises:
        AssertionError: In ensemble mode, if the number of matching parameter
            files is not ``NCV + 1``.
    """
    shape = (None, 1, img_size, img_size)
    input_var = T.tensor4('input')
    if NCV > 1:
        # Real lists (not filter objects) so that len() and repeated use work
        # on Python 3 as well as Python 2.
        version_tag = 'fcn_v{}'.format(version)
        params_files = [fn for fn in os.listdir(params_loc)
                        if version_tag in fn]
        if param_file_key is not None:
            params_files = [fn for fn in params_files if param_file_key in fn]
        # Check the file count before building and loading any network, so a
        # misconfigured directory fails fast instead of after all the work.
        assert len(params_files) == NCV + 1, (
            'expected {} parameter files in {}, found {}'.format(
                NCV + 1, params_loc, len(params_files)))
        expr = 0
        for pfn in params_files:
            net, _, output_det = build_fcn_segmenter(input_var, shape, version)
            u.load_params(net['output'], os.path.join(params_loc, pfn))
            # The index is the second character of the last '_'-separated
            # token of the file name, so only a single digit is read.
            # NOTE(review): this limits NCV to at most 9 -- confirm the
            # parameter-file naming scheme.
            cv = int(pfn.split('_')[-1][1])
            if cv == NCV:
                # The file whose index equals NCV is weighted NCV times, i.e.
                # as much as all the other files together.
                expr = expr + output_det * NCV
            else:
                expr = expr + output_det
            print('loaded {}'.format(pfn))
        # Total weight is NCV (index == NCV file) + NCV (the others) = 2 * NCV.
        expr = expr / NCV / 2
        print('loaded {} in ensemble'.format(len(params_files)))
    else:
        net, _, output_det = build_fcn_segmenter(input_var, shape, version)
        u.load_params(net['output'], params_loc)
        expr = output_det
        print('loaded indiv function {}'.format(params_loc))
    return theano.function([input_var], expr)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号