def get_eval_fn(model, in3D=False, use_dice=False, dice_eps=1e-7):
    """Compile the evaluation function of the model.

    The compiled Theano function substitutes fresh symbolic inputs for
    ``model.x`` and ``model.trg`` (via ``givens``) and returns the
    evaluation error together with ``model.output``.

    Args:
        model: Object exposing symbolic ``x`` (input), ``trg`` (target)
            and ``output`` expressions. Presumably axis 0 of
            ``output``/``trg`` indexes samples and axis 1 the per-sample
            values -- TODO confirm against the model definition.
        in3D: If True, the input placeholder is created with
            ``T.tensor4``; otherwise with ``T.fmatrix``. The target
            placeholder is always ``T.fmatrix``.
        use_dice: If True, the error is the soft-Dice loss computed along
            axis 1 and then averaged; otherwise it is the squared error
            averaged along axis 1 and then over all remaining elements.
        dice_eps: Smoothing constant added to both the numerator and the
            denominator of the Dice ratio, so a row where target and
            output both sum to zero gives a loss of 0 instead of NaN
            (0/0). Only used when ``use_dice`` is True.

    Returns:
        A compiled ``theano.function`` called as ``eval_fn(x, y)`` that
        returns ``[error, output]``.
    """
    if use_dice:
        intersection = T.sum(model.trg * model.output, axis=1)
        denominator = T.sum(model.trg, axis=1) + T.sum(model.output, axis=1)
        # Smoothed Dice: without dice_eps an all-zero row divides 0 by 0,
        # and the resulting NaN propagates through T.mean to the whole batch.
        per_row = 1 - (2.0 * intersection + dice_eps) / (denominator + dice_eps)
        error = T.mean(per_row)
    else:
        error = T.mean(T.mean(T.power(model.output - model.trg, 2), axis=1))

    if in3D:
        x = T.tensor4('x')
    else:
        x = T.fmatrix("x")
    y = T.fmatrix("y")

    theano_arg_vl = [x, y]
    output_fn_vl = [error, model.output]
    # `givens` rebinds the model's own symbolic input/target to the new
    # placeholders, so the compiled function takes (x, y) directly.
    eval_fn = theano.function(
        theano_arg_vl, output_fn_vl,
        givens={model.x: x,
                model.trg: y})
    return eval_fn
评论列表
文章目录