crf.py 文件源码

python
阅读 40 收藏 0 点赞 0 评论 0

项目:chordrec 作者: fdlm 项目源码 文件源码
def build_net(in_shape, out_size, model):
    # input variables
    input_var = (tt.tensor4('input', dtype='float32')
                 if len(in_shape) > 1 else
                 tt.tensor3('input', dtype='float32'))
    target_var = tt.tensor3('target_output', dtype='float32')
    mask_var = tt.matrix('mask_input', dtype='float32')

    # stack more layers
    network = lnn.layers.InputLayer(
        name='input', shape=(None, None) + in_shape,
        input_var=input_var
    )

    mask_in = lnn.layers.InputLayer(name='mask',
                                    input_var=mask_var,
                                    shape=(None, None))

    network = spg.layers.CrfLayer(
        network, mask_input=mask_in, num_states=out_size, name='CRF')

    return network, input_var, target_var, mask_var
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号