networks.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:LasagneNLP 作者: XuezheMax 项目源码 文件源码
def build_BiLSTM_HighCNN(incoming1, incoming2, num_units, mask=None, grad_clipping=0, precompute_input=True,
                         peepholes=False, num_filters=20, dropout=True, in_to_out=False):
    """Stack a CNN + highway feature extractor on ``incoming1`` and feed it, joined with ``incoming2``, to a BiLSTM.

    Pipeline: (optional dropout) -> 1D convolution (window 3, tanh, 'full' padding)
    -> max-pool over the entire convolution output -> flatten -> highway layer
    (rectify) -> reshape to one row of features per sentence position ->
    concatenate with ``incoming2`` on the last axis -> ``build_BiLSTM``.

    Args:
        incoming1: layer fed to the convolution. Presumably a 3D layer laid out as
            [batch * sent_length, channels, length] -- TODO confirm against the caller.
        incoming2: layer whose ``output_shape`` has exactly three entries; the
            second entry is used as the sentence length.
        num_units: forwarded to ``build_BiLSTM``.
        mask: forwarded to ``build_BiLSTM``.
        grad_clipping: forwarded to ``build_BiLSTM``.
        precompute_input: forwarded to ``build_BiLSTM``.
        peepholes: forwarded to ``build_BiLSTM``.
        num_filters: number of convolution filters.
        dropout: when truthy, a dropout layer (p=0.5) is placed in front of the
            convolution; the flag is also forwarded to ``build_BiLSTM``.
        in_to_out: forwarded to ``build_BiLSTM``.

    Returns:
        Whatever ``build_BiLSTM`` returns for the concatenated input layer.
    """
    layers = lasagne.layers
    filter_width = 3

    # The sentence length is the middle entry of incoming2's 3-entry output shape.
    _batch, sent_length, _feat = incoming2.output_shape

    # Optionally regularise the convolution input.
    conv_input = layers.DropoutLayer(incoming1, p=0.5) if dropout else incoming1

    # Convolution over the last axis of the input.
    cnn_layer = layers.Conv1DLayer(
        conv_input,
        num_filters=num_filters,
        filter_size=filter_width,
        pad='full',
        nonlinearity=lasagne.nonlinearities.tanh,
        name='cnn',
    )

    # Pool across every time step the convolution produced, leaving a single
    # value per filter: [batch * sent_length, num_filters, 1].
    full_span = cnn_layer.output_shape[2]
    pooled = layers.MaxPool1DLayer(cnn_layer, pool_size=full_span)

    # Drop the trailing singleton axis: --> [batch * sent_length, num_filters].
    flat_features = layers.reshape(pooled, ([0], -1))

    # Note: no dropout is applied between the CNN and the highway layer.
    highway = HighwayDenseLayer(flat_features, nonlinearity=nonlinearities.rectify)

    # Regroup rows by sentence: --> [batch, sent_length, num_filters].
    per_position = layers.reshape(highway, (-1, sent_length, [1]))

    # Join the highway features with the second input along the feature axis.
    merged = layers.concat([per_position, incoming2], axis=2)

    return build_BiLSTM(
        merged,
        num_units,
        mask=mask,
        grad_clipping=grad_clipping,
        peepholes=peepholes,
        precompute_input=precompute_input,
        dropout=dropout,
        in_to_out=in_to_out,
    )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号