def _parse_cv_index(filename):
    """Return the fold index encoded in an ensemble parameter file name.

    The index is read from the last '_'-separated token of *filename*,
    starting at the token's second character (e.g. a token such as 'p3.npz'
    gives 3 -- the naming scheme is defined elsewhere, TODO confirm).  All
    consecutive digits are taken, so fold numbers >= 10 are not truncated to
    their first digit.

    Raises:
        ValueError: if no digit follows the token's first character.
    """
    import re  # function-scope import: keeps this block self-contained

    token = filename.split('_')[-1]
    match = re.match(r'.(\d+)', token)
    if match is None:
        raise ValueError(
            'cannot parse CV fold index from file name {!r}'.format(filename))
    return int(match.group(1))


def get_segmenter_function(params_loc, img_size, NCV=1, version=1,
                           param_file_key=None):
    """Build a compiled Theano function that runs the FCN segmenter.

    Args:
        params_loc: if NCV <= 1, the path of a single parameter file; if
            NCV > 1, a directory containing the ensemble's parameter files.
        img_size: side length of the square, single-channel input images;
            the network is built for input shape (None, 1, img_size, img_size).
        NCV: number of cross-validation folds.  Values > 1 select ensemble
            mode, which expects NCV + 1 matching parameter files.
        version: FCN version passed to build_fcn_segmenter; in ensemble mode
            only files whose name contains 'fcn_v<version>' are used.
        param_file_key: optional extra substring that ensemble file names
            must contain.

    Returns:
        A theano.function taking a 4-D input tensor and returning the
        deterministic network output.  In ensemble mode this is a weighted
        average over the loaded models.

    Raises:
        ValueError: in ensemble mode, if the number of matching files is not
            NCV + 1, or a file name carries no parsable fold index.
    """
    shape = (None, 1, img_size, img_size)
    input_var = T.tensor4('input')

    if NCV <= 1:
        # Single model: params_loc is the parameter file itself.
        net, _, output_det = build_fcn_segmenter(input_var, shape, version)
        u.load_params(net['output'], params_loc)
        print('loaded indiv function {}'.format(params_loc))
        return theano.function([input_var], output_det)

    # Ensemble mode.  NOTE(review): this is a plain substring test, so
    # 'fcn_v1' also matches 'fcn_v10...' files -- confirm that is acceptable.
    version_tag = 'fcn_v{}'.format(version)
    # A sorted list (not a lazy filter) gives a deterministic load order and
    # supports len() on both Python 2 and 3.
    params_files = sorted(
        name for name in os.listdir(params_loc)
        if version_tag in name
        and (param_file_key is None or param_file_key in name))

    # Validate before building any network so a bad directory fails fast;
    # an explicit raise (unlike assert) survives `python -O`.
    if len(params_files) != NCV + 1:
        raise ValueError(
            'expected {} parameter files matching {!r} in {}, found {}: {}'
            .format(NCV + 1, version_tag, params_loc,
                    len(params_files), params_files))

    # The model whose fold index equals NCV gets weight NCV, i.e. as much as
    # all per-fold models combined; every other model gets weight 1.
    # Parsing all names up front also fails fast on a malformed name.
    weights = [NCV if _parse_cv_index(pfn) == NCV else 1
               for pfn in params_files]

    expr = 0
    for pfn, weight in zip(params_files, weights):
        net, _, output_det = build_fcn_segmenter(input_var, shape, version)
        u.load_params(net['output'], os.path.join(params_loc, pfn))
        expr = expr + output_det * weight
        print('loaded {}'.format(pfn))

    # Normalise by the weights actually used.  With exactly one fold-NCV
    # model this is 2 * NCV, the same as the former fixed denominator.
    expr = expr / sum(weights)
    print('loaded {} in ensemble'.format(len(params_files)))
    return theano.function([input_var], expr)
评论列表
文章目录