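# Uses the Keras 1.x functional API (`merge` with `mode=`, `Model(input=..., output=...)`).
from keras.layers import Input, Embedding, Dropout, Activation, Permute, LSTM, Dense, merge
from keras.models import Model


def Mem_Model(story_maxlen, query_maxlen, vocab_size):
    # Story (memory) encoder: (story_maxlen,) -> (story_maxlen, 64)
    input_encoder_m = Input(shape=(story_maxlen,), dtype='int32', name='input_encoder_m')
    x = Embedding(output_dim=64, input_dim=vocab_size, input_length=story_maxlen)(input_encoder_m)
    x = Dropout(0.5)(x)

    # Question encoder: (query_maxlen,) -> (query_maxlen, 64)
    question_encoder = Input(shape=(query_maxlen,), dtype='int32', name='question_encoder')
    y = Embedding(output_dim=64, input_dim=vocab_size, input_length=query_maxlen)(question_encoder)
    y = Dropout(0.5)(y)

    # Dot product over the embedding axis, then softmax: (story_maxlen, query_maxlen)
    z = merge([x, y], mode='dot', dot_axes=[2, 2])
    # z = merge([x, y], mode='sum')
    match = Activation('softmax')(z)

    # Second story embedding; output_dim=query_maxlen so its shape equals that of `match`
    input_encoder_c = Input(shape=(story_maxlen,), dtype='int32', name='input_encoder_c')
    c = Embedding(output_dim=query_maxlen, input_dim=vocab_size, input_length=story_maxlen)(input_encoder_c)
    c = Dropout(0.5)(c)

    # Add the match weights to the second embedding, then swap axes: (query_maxlen, story_maxlen)
    response = merge([match, c], mode='sum')
    w = Permute((2, 1))(response)

    # Concatenate with the question embedding: (query_maxlen, story_maxlen + 64)
    answer = merge([w, y], mode='concat', concat_axis=-1)
    lstm = LSTM(32)(answer)
    lstm = Dropout(0.5)(lstm)

    # 50 independent sigmoid outputs
    main_loss = Dense(50, activation='sigmoid', name='main_output')(lstm)
    model = Model(input=[input_encoder_m, question_encoder, input_encoder_c], output=main_loss)
    return model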
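
The shape constraints in `Mem_Model` are easy to miss, so here is a minimal sketch in plain Python that traces the per-sample tensor shapes (batch axis omitted) through the same sequence of layers. The helper name `mem_model_shapes` and the example lengths 68 and 4 are hypothetical and only serve as illustration; the layer sizes 64, 32 and 50 are taken from the code above. It does not build a Keras model.

```python
def mem_model_shapes(story_maxlen, query_maxlen, embed_dim=64, lstm_units=32, n_outputs=50):
    """Trace per-sample shapes through the layers of Mem_Model (hypothetical helper)."""
    x = (story_maxlen, embed_dim)   # story embedding
    y = (query_maxlen, embed_dim)   # question embedding

    # mode='dot' with dot_axes=[2, 2] contracts the embedding axis of x and y.
    assert x[1] == y[1], "embedding widths must agree for the dot product"
    match = (x[0], y[0])            # softmax keeps the shape

    # The second story embedding uses output_dim=query_maxlen.
    c = (story_maxlen, query_maxlen)

    # mode='sum' is element-wise, so both operands need identical shapes.
    assert match == c, "sum requires identical shapes"
    response = c

    # Permute((2, 1)) swaps the two non-batch axes.
    w = (response[1], response[0])

    # mode='concat' on the last axis needs matching leading axes.
    assert w[0] == y[0], "concat requires the same number of time steps"
    answer = (w[0], w[1] + y[1])

    return {
        "match": match,
        "response": response,
        "answer": answer,           # LSTM input: (time steps, features)
        "lstm": (lstm_units,),
        "output": (n_outputs,),
    }


if __name__ == "__main__":
    # Hypothetical lengths chosen only for illustration.
    for name, shape in mem_model_shapes(68, 4).items():
        print(name, shape)
```

The trace shows why the second embedding uses `output_dim=query_maxlen` rather than 64: the element-wise sum with `match` only works when both tensors have shape `(story_maxlen, query_maxlen)`. After the permute and concat, the LSTM reads `query_maxlen` time steps of `story_maxlen + 64` features each.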
MemNN_classifier.py source file