def build_net(in_shape, out_size, model):
    """Build a CRF network together with the symbolic variables that feed it.

    The graph consists of an ``InputLayer`` for the features, a second
    ``InputLayer`` for the mask, and a single ``spg.layers.CrfLayer`` on top.

    :param in_shape: feature shape without the two leading ``None``
        dimensions of the input layer (presumably batch and time — TODO
        confirm). Any sequence with 1 or 2 entries is accepted; it is
        converted to a tuple.
    :param out_size: number of CRF states, passed as ``num_states``.
    :param model: not used by this function; kept so existing callers keep
        working.
    :returns: ``(network, input_var, target_var, mask_var)``. ``network`` is
        the CRF layer; the other three are the float32 symbolic variables for
        the input, the target output and the mask.
    :raises ValueError: if ``in_shape`` does not have exactly 1 or 2 entries.
    """
    # Accept lists as well as tuples: ``(None, None) + [...]`` would raise a
    # TypeError when building the InputLayer shape below.
    in_shape = tuple(in_shape)
    if len(in_shape) not in (1, 2):
        # Only 3-D and 4-D input variables are created below. Any other rank
        # would disagree with the InputLayer shape ``(None, None) + in_shape``
        # and be rejected later with a less helpful message.
        raise ValueError(
            'in_shape must have 1 or 2 entries, got %d' % len(in_shape))

    # Symbolic input variables. The input rank is 2 + len(in_shape), so that
    # it matches the InputLayer shape declared below.
    input_var = (tt.tensor4('input', dtype='float32')
                 if len(in_shape) > 1 else
                 tt.tensor3('input', dtype='float32'))
    target_var = tt.tensor3('target_output', dtype='float32')
    # 2-D mask matching the two leading ``None`` dimensions of the input
    # (assumed batch x time — TODO confirm against the CRF layer).
    mask_var = tt.matrix('mask_input', dtype='float32')

    # Stack the layers: feature input -> CRF, with the mask as a side input.
    network = lnn.layers.InputLayer(
        name='input', shape=(None, None) + in_shape,
        input_var=input_var
    )
    mask_in = lnn.layers.InputLayer(name='mask',
                                    input_var=mask_var,
                                    shape=(None, None))
    network = spg.layers.CrfLayer(
        network, mask_input=mask_in, num_states=out_size, name='CRF')
    return network, input_var, target_var, mask_var
评论列表
文章目录