MemNN_classifier.py source code (Python)

Project: ParseLawDocuments  Author: FanhuaandLuomu
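
```python
from tensorflow.keras.layers import (Input, Embedding, Dropout, Activation,
                                     Permute, LSTM, Dense, dot, add, concatenate)
from tensorflow.keras.models import Model


def Mem_Model(story_maxlen, query_maxlen, vocab_size):
    # Story encoder m: token ids -> 64-d embeddings, (batch, story_maxlen, 64)
    input_encoder_m = Input(shape=(story_maxlen,), dtype='int32', name='input_encoder_m')
    x = Embedding(input_dim=vocab_size, output_dim=64)(input_encoder_m)
    x = Dropout(0.5)(x)

    # Question encoder: (batch, query_maxlen, 64)
    question_encoder = Input(shape=(query_maxlen,), dtype='int32', name='question_encoder')
    y = Embedding(input_dim=vocab_size, output_dim=64)(question_encoder)
    y = Dropout(0.5)(y)

    # Dot product over the embedding axis: (batch, story_maxlen, query_maxlen)
    z = dot([x, y], axes=(2, 2))
    # z = add([x, y])

    # Softmax over the last axis (query positions)
    match = Activation('softmax')(z)

    # Story encoder c embeds directly into query_maxlen dims so it can be summed with match
    input_encoder_c = Input(shape=(story_maxlen,), dtype='int32', name='input_encoder_c')
    c = Embedding(input_dim=vocab_size, output_dim=query_maxlen)(input_encoder_c)
    c = Dropout(0.5)(c)

    response = add([match, c])        # (batch, story_maxlen, query_maxlen)
    w = Permute((2, 1))(response)     # (batch, query_maxlen, story_maxlen)

    # Append the question embeddings: (batch, query_maxlen, story_maxlen + 64)
    answer = concatenate([w, y], axis=-1)

    lstm = LSTM(32)(answer)
    lstm = Dropout(0.5)(lstm)

    # 50 independent sigmoid outputs, one score per label
    main_loss = Dense(50, activation='sigmoid', name='main_output')(lstm)

    model = Model(inputs=[input_encoder_m, question_encoder, input_encoder_c], outputs=main_loss)
    return model
```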
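The shape bookkeeping in `Mem_Model` is the least obvious part, so here is a minimal NumPy sketch that traces one sample through the attention section (embeddings, dot, softmax, sum, permute, concatenate), with dropout omitted and the LSTM/Dense head left out. It is an illustration, not the author's code: the helper names `softmax` and `memory_block`, the random embedding tables, and the toy sizes are hypothetical, and it assumes both story inputs (`input_encoder_m` and `input_encoder_c`) receive the same token ids.

```python
import numpy as np


def softmax(a, axis=-1):
    # Numerically stable softmax along `axis` (the model applies it on the last axis)
    e = np.exp(a - a.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def memory_block(story_ids, query_ids, emb_m, emb_q, emb_c):
    """Hypothetical single-sample sketch of Mem_Model's attention part.

    story_ids: (story_maxlen,) token ids, assumed shared by both story inputs
    query_ids: (query_maxlen,) token ids
    emb_m: (vocab_size, 64), emb_q: (vocab_size, 64), emb_c: (vocab_size, query_maxlen)
    """
    x = emb_m[story_ids]                      # (story_maxlen, 64)
    y = emb_q[query_ids]                      # (query_maxlen, 64)
    z = x @ y.T                               # dot over embedding axis -> (story_maxlen, query_maxlen)
    match = softmax(z, axis=-1)               # each row sums to 1 over query positions
    c = emb_c[story_ids]                      # (story_maxlen, query_maxlen)
    response = match + c                      # element-wise sum, shapes must match
    w = response.T                            # Permute((2, 1)) -> (query_maxlen, story_maxlen)
    answer = np.concatenate([w, y], axis=-1)  # (query_maxlen, story_maxlen + 64)
    return match, answer


# Toy sizes and random weights, purely for illustration
rng = np.random.default_rng(0)
vocab_size, story_maxlen, query_maxlen = 100, 7, 4
emb_m = rng.normal(size=(vocab_size, 64))
emb_q = rng.normal(size=(vocab_size, 64))
emb_c = rng.normal(size=(vocab_size, query_maxlen))
story = rng.integers(0, vocab_size, size=story_maxlen)
query = rng.integers(0, vocab_size, size=query_maxlen)

match, answer = memory_block(story, query, emb_m, emb_q, emb_c)
print(match.shape, answer.shape)  # (7, 4) (4, 71)
```

With these sizes the LSTM would see 4 time steps (one per query position) of 71 features each: 7 attention/response values over the story plus the 64-dimensional question embedding.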