def Mem_Model2(story_maxlen, query_maxlen, vocab_size, *,
               embedding_dim=128, dropout=0.5, lstm_units=64,
               num_outputs=50, plot_graphs=True):
    """Build an (uncompiled) memory-network style classifier with the Keras 1.x API.

    The model combines three sub-encoders through nested ``Merge`` layers:
    ``input_encoder_m`` (story), ``question_encoder`` (query) and
    ``input_encoder_c`` (story). Its final layer is ``Dense(num_outputs)``
    followed by a sigmoid. The structure appears adapted from the Keras 1.x
    bAbI memory-network example, with its vocabulary softmax head replaced.

    Args:
        story_maxlen: Length of the padded story token sequence.
        query_maxlen: Length of the padded query token sequence.
        vocab_size: Number of token ids accepted by every ``Embedding``.
        embedding_dim: Embedding size for the story (match) and question
            encoders. A single value keeps the two in agreement, as the
            ``dot`` merge over axis 2 requires. Default 128.
        dropout: Rate used by every ``Dropout`` layer. Default 0.5.
        lstm_units: Units in the LSTM reduction layer. Default 64.
        num_outputs: Number of sigmoid output units. Default 50.
        plot_graphs: If true (the default, preserving the original
            behaviour), draw the intermediate ``match`` and ``response``
            models to ``model_1.png`` and ``model_2.png`` in the current
            working directory via ``plot``.

    Returns:
        The ``answer`` ``Sequential`` model. ``compile()`` is not called here.
    """
    # Story encoder for the match step -> (samples, story_maxlen, embedding_dim)
    input_encoder_m = Sequential()
    input_encoder_m.add(Embedding(input_dim=vocab_size,
                                  output_dim=embedding_dim,
                                  input_length=story_maxlen))
    input_encoder_m.add(Dropout(dropout))

    # Question encoder -> (samples, query_maxlen, embedding_dim)
    question_encoder = Sequential()
    question_encoder.add(Embedding(input_dim=vocab_size,
                                   output_dim=embedding_dim,
                                   input_length=query_maxlen))
    question_encoder.add(Dropout(dropout))

    # Match: dot product of every story position with every query position
    # over the embedding axis -> (samples, story_maxlen, query_maxlen).
    # NOTE(review): in Keras 1.x, softmax on a 3D tensor normalises over the
    # last axis (query positions) -- confirm this is the intended axis.
    match = Sequential()
    match.add(Merge([input_encoder_m, question_encoder],
                    mode='dot',
                    dot_axes=[2, 2]))
    match.add(Activation('softmax'))
    if plot_graphs:
        plot(match, to_file='model_1.png')

    # Second story encoder. Its embedding size is query_maxlen (not
    # story_maxlen), so its output can be summed element-wise with the match
    # tensor -> (samples, story_maxlen, query_maxlen).
    input_encoder_c = Sequential()
    input_encoder_c.add(Embedding(input_dim=vocab_size,
                                  output_dim=query_maxlen,
                                  input_length=story_maxlen))
    input_encoder_c.add(Dropout(dropout))

    # Response: element-wise sum of match and input_encoder_c
    # -> (samples, story_maxlen, query_maxlen), then swap the last two axes
    # -> (samples, query_maxlen, story_maxlen).
    response = Sequential()
    response.add(Merge([match, input_encoder_c], mode='sum'))
    response.add(Permute((2, 1)))
    if plot_graphs:
        plot(response, to_file='model_2.png')

    # Concatenate the response (not the raw match tensor) with the question
    # embedding along the feature axis
    # -> (samples, query_maxlen, story_maxlen + embedding_dim).
    answer = Sequential()
    answer.add(Merge([response, question_encoder], mode='concat', concat_axis=-1))
    # The original paper uses a matrix multiplication for this reduction
    # step; an RNN is used here instead.
    answer.add(LSTM(lstm_units))
    # A single regularisation layer -- more would probably be needed.
    answer.add(Dropout(dropout))
    answer.add(Dense(num_outputs))
    # Sigmoid gives independent per-unit scores in (0, 1). This is NOT a
    # probability distribution over the vocabulary, which would need
    # Dense(vocab_size) + softmax. Presumably a multi-label target; verify
    # against the training code.
    answer.add(Activation('sigmoid'))
    return answer
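# NOTE: the comment originally on this line was destroyed by a character-encoding
# error during extraction ("??????? ?????k???1"); its text cannot be recovered.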
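# Source-page residue (not code), kept as a comment so the file stays valid Python:
#   MemNN_classifier.py 文件源码 (file source code) · language: python
#   阅读 21 (views: 21) · 收藏 0 (bookmarks: 0) · 点赞 0 (likes: 0) · 评论 0 (comments: 0)
#   评论列表 (comment list) · 文章目录 (table of contents)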