simple_gru_model.py source code


Project: neural-reading-comp  Author: tianwang95
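```python
import os

import numpy as np
# Keras 1.x API: consume_less=, merge(mode=...) and Model(input=, output=)
# were renamed or removed in later Keras versions.
from keras.layers import Input, Embedding, GRU, Dense, merge
from keras.models import Model

# masked_concat is not a Keras built-in; it is presumably a helper defined
# elsewhere in the neural-reading-comp project (not shown here).


def get_model(
        data_path,  # Path to dataset
        hid_dim,  # Dimension of the hidden GRU layers
        optimizer='rmsprop',  # Optimization function to be used
        loss='categorical_crossentropy'  # Loss function to be used
        ):

    # metadata.txt holds one "key:value" line per entry, with integer values.
    metadata_dict = {}
    with open(os.path.join(data_path, 'metadata', 'metadata.txt'), 'r') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines
            key, value = line.split(':', 1)
            metadata_dict[key.strip()] = int(value)
    story_maxlen = metadata_dict['input_length']
    query_maxlen = metadata_dict['query_length']
    vocab_size = metadata_dict['vocab_size']
    entity_dim = metadata_dict['entity_dim']

    # Pretrained embedding matrix; its shape must be (vocab_size + 2, word_dim)
    # to match input_dim of the Embedding layers below.
    embed_weights = np.load(os.path.join(data_path, 'metadata', 'weights.npy'))
    word_dim = embed_weights.shape[1]

    ########## MODEL ############

    story_input = Input(shape=(story_maxlen,), dtype='int32', name='StoryInput')

    # mask_zero=True: token id 0 is treated as padding and masked downstream.
    x = Embedding(input_dim=vocab_size + 2,
                  output_dim=word_dim,
                  input_length=story_maxlen,
                  mask_zero=True,
                  weights=[embed_weights])(story_input)

    query_input = Input(shape=(query_maxlen,), dtype='int32', name='QueryInput')

    # A second, separate Embedding layer initialised from the same weights;
    # the two layers are not shared, so they are trained independently.
    x_q = Embedding(input_dim=vocab_size + 2,
                    output_dim=word_dim,
                    input_length=query_maxlen,
                    mask_zero=True,
                    weights=[embed_weights])(query_input)

    # Query tokens first, then story tokens, joined along the time axis (axis 1).
    concat_embeddings = masked_concat([x_q, x], concat_axis=1)

    # Forward and backward GRUs; each returns only its final hidden state.
    forward_gru = GRU(hid_dim, consume_less='gpu')(concat_embeddings)
    backward_gru = GRU(hid_dim, consume_less='gpu', go_backwards=True)(concat_embeddings)

    # Concatenate both final states: shape (batch, 2 * hid_dim).
    merged = merge([forward_gru, backward_gru], mode='concat')

    # Softmax over the entity_dim output classes.
    result = Dense(entity_dim, activation='softmax')(merged)

    model = Model(input=[story_input, query_input], output=result)
    model.compile(optimizer=optimizer,
                  loss=loss,
                  metrics=['accuracy'])
    model.summary()  # summary() prints the table itself and returns None
    return model
```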
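The encoder in `get_model` embeds the query and the story with id 0 as padding, joins them along the time axis (query first), runs one GRU forward and one backward, and concatenates the two final states. The stdlib-only sketch below reproduces that data flow so the effect of masking is visible. It is a toy under stated assumptions: the weights are random, biases and the final `Dense` softmax are omitted, every helper name (`masked_concat_time`, `init_gru`, `encode`, ...) is hypothetical, and it assumes the project's `masked_concat` concatenates the masks alongside the embeddings, which its definition (not shown) would need to confirm. Masked steps simply carry the previous state forward, which is how Keras recurrent layers treat masked timesteps.

```python
# Toy, stdlib-only sketch of get_model's encoder. All names are hypothetical;
# weights are random and untrained, and GRU biases are omitted.
import math
import random


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def init_gru(in_dim, hid_dim, seed=0):
    """Random weights for a toy GRU cell."""
    rng = random.Random(seed)

    def mat(rows, cols):
        return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

    # One (input weights W, recurrent weights U) pair per gate:
    # update gate "z", reset gate "r", candidate state "h".
    return {gate: (mat(hid_dim, in_dim), mat(hid_dim, hid_dim)) for gate in ("z", "r", "h")}


def gru_step(p, x, h):
    """One GRU update of hidden state h given input vector x."""
    def gate(name, h_in, act):
        w, u = p[name]
        return [act(a + b) for a, b in zip(matvec(w, x), matvec(u, h_in))]

    z = gate("z", h, sigmoid)  # update gate
    r = gate("r", h, sigmoid)  # reset gate
    h_tilde = gate("h", [ri * hi for ri, hi in zip(r, h)], math.tanh)  # candidate
    # Keep a fraction z of the old state and take (1 - z) of the candidate.
    return [zi * hi + (1.0 - zi) * ti for zi, hi, ti in zip(z, h, h_tilde)]


def run_gru(p, seq, mask, hid_dim, go_backwards=False):
    """Final GRU state over seq; steps whose mask is 0 leave the state unchanged."""
    steps = list(zip(seq, mask))
    if go_backwards:
        steps.reverse()
    h = [0.0] * hid_dim
    for x, m in steps:
        if m:
            h = gru_step(p, x, h)
    return h


def embed(ids, table):
    """Embedding lookup; id 0 is padding and gets mask 0 (like mask_zero=True)."""
    return [table[i] for i in ids], [int(i != 0) for i in ids]


def masked_concat_time(seqs, masks):
    """Hypothetical stand-in for masked_concat: join sequences and masks in time order."""
    out_seq, out_mask = [], []
    for s, m in zip(seqs, masks):
        out_seq.extend(s)
        out_mask.extend(m)
    return out_seq, out_mask


def encode(query_ids, story_ids, table, fwd, bwd, hid_dim):
    """Query-then-story sequence -> forward final state + backward final state."""
    xq, mq = embed(query_ids, table)
    xs, ms = embed(story_ids, table)
    seq, mask = masked_concat_time([xq, xs], [mq, ms])
    return run_gru(fwd, seq, mask, hid_dim) + run_gru(bwd, seq, mask, hid_dim, go_backwards=True)


if __name__ == "__main__":
    word_dim, hid_dim = 4, 3
    rng = random.Random(1)
    table = [[rng.uniform(-1, 1) for _ in range(word_dim)] for _ in range(10)]
    fwd, bwd = init_gru(word_dim, hid_dim, seed=2), init_gru(word_dim, hid_dim, seed=3)
    enc = encode([5, 7, 0, 0], [1, 2, 3, 0, 0], table, fwd, bwd, hid_dim)
    print(len(enc))  # 2 * hid_dim, like the merged layer in get_model
```

Because masked steps leave the state untouched, the query's trailing padding, which ends up in the middle of the concatenated sequence, has no effect: encoding `[5, 7, 0, 0]` + `[1, 2, 3, 0, 0]` yields exactly the same vector as `[5, 7]` + `[1, 2, 3]`. Without the mask, both GRUs would have to read through those padding embeddings.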