MemNN_classifier.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:ParseLawDocuments 作者: FanhuaandLuomu 项目源码 文件源码
def Mem_Model2(story_maxlen, query_maxlen, vocab_size,
               embedding_dim=128, lstm_units=64, n_outputs=50,
               dropout_rate=0.5, plot_prefix='model'):
    """Build an End-To-End Memory Network classifier (legacy Keras 1.x API).

    Follows the MemN2N architecture (Sukhbaatar et al., 2015): a story/input
    memory and a question encoder are matched by a dot product, the match is
    summed with a second input encoding, and an LSTM reduces the result to a
    fixed-size prediction.

    NOTE(review): this relies on the Keras 1.x `Sequential`-based `Merge`
    layer and `plot` visualizer, both removed in Keras 2 — confirm the
    project pins keras<2 before reuse.

    Parameters
    ----------
    story_maxlen : int
        Padded length of the story (memory) sequences.
    query_maxlen : int
        Padded length of the question sequences.
    vocab_size : int
        Size of the token vocabulary for both embeddings.
    embedding_dim : int, optional
        Dimensionality of the story/question embeddings (default 128,
        matching the original hard-coded value).
    lstm_units : int, optional
        Hidden size of the reduction LSTM (default 64).
    n_outputs : int, optional
        Number of output units in the final Dense layer (default 50).
    dropout_rate : float, optional
        Dropout applied after each embedding and before the output (default 0.5).
    plot_prefix : str, optional
        Prefix for the diagnostic model-diagram files written as a side
        effect ('<prefix>_1.png' and '<prefix>_2.png'; default 'model',
        reproducing the original 'model_1.png' / 'model_2.png').

    Returns
    -------
    Sequential
        The assembled (uncompiled) Keras model. Its output is
        `n_outputs` independent sigmoid activations — element-wise
        probabilities, NOT a softmax distribution over the vocabulary.
    """
    # Memory encoder m: embed the story tokens.
    # output: (samples, story_maxlen, embedding_dim)
    input_encoder_m = Sequential()
    input_encoder_m.add(Embedding(input_dim=vocab_size,
                                  output_dim=embedding_dim,
                                  input_length=story_maxlen))
    input_encoder_m.add(Dropout(dropout_rate))

    # Question encoder: embed the query into a sequence of vectors.
    # output: (samples, query_maxlen, embedding_dim)
    question_encoder = Sequential()
    question_encoder.add(Embedding(input_dim=vocab_size,
                                   output_dim=embedding_dim,
                                   input_length=query_maxlen))
    question_encoder.add(Dropout(dropout_rate))

    # Match: dot product over the embedding axis of story and question,
    # softmax-normalized into attention weights.
    # output: (samples, story_maxlen, query_maxlen)
    match = Sequential()
    match.add(Merge([input_encoder_m, question_encoder],
                    mode='dot',
                    dot_axes=[2, 2]))
    match.add(Activation('softmax'))

    # Side effect: write a diagram of the match sub-model to disk.
    plot(match, to_file='%s_1.png' % plot_prefix)

    # Memory encoder c: a second story embedding whose width equals
    # query_maxlen so it can be summed with the match tensor.
    # output: (samples, story_maxlen, query_maxlen)
    input_encoder_c = Sequential()
    input_encoder_c.add(Embedding(input_dim=vocab_size,
                                  output_dim=query_maxlen,
                                  input_length=story_maxlen))
    input_encoder_c.add(Dropout(dropout_rate))

    # Response: add the attention match to the c-encoding, then transpose
    # to (samples, query_maxlen, story_maxlen) for the RNN reduction.
    response = Sequential()
    response.add(Merge([match, input_encoder_c], mode='sum'))
    response.add(Permute((2, 1)))

    # Side effect: write a diagram of the response sub-model to disk.
    plot(response, to_file='%s_2.png' % plot_prefix)

    # Concatenate the response with the question encoding and reduce with
    # an LSTM (the original paper uses a matrix multiplication instead).
    answer = Sequential()
    answer.add(Merge([response, question_encoder], mode='concat', concat_axis=-1))
    answer.add(LSTM(lstm_units))
    # One regularization layer -- more would probably be needed.
    answer.add(Dropout(dropout_rate))
    answer.add(Dense(n_outputs))
    # Element-wise sigmoid head: n_outputs independent probabilities
    # (use 'softmax' instead if a single-label distribution is wanted).
    answer.add(Activation('sigmoid'))

    return answer

# NOTE(review): original comment lost to an encoding error; from the surviving
# "k...1" it presumably described a top-k (k=1) selection step — confirm with the author.
评论列表


问题


面经


文章

微信
公众号

扫码关注公众号