hierarchical_layers.py file source code

Language: Python

Project: text_classification, Author: senochow
# NOTE: imports reconstructed for a Keras-era codebase; MGU (a Minimal Gated Unit
# recurrent layer) and AttentionLayer are project-specific, so the import path
# below is an assumption.
from keras.models import Sequential, Model
from keras.layers import Input, Embedding, TimeDistributed, Bidirectional, GlobalMaxPooling1D
# from .layers import MGU, AttentionLayer  # hypothetical project-local import

def HierarchicalRNN(embed_matrix, max_words, ans_cnt, sequence_length, embedding_dim, lstm_dim=100):
    ''' Hierarchical RNN model
        Input: (batch_size, answers, answer words)
    Args:
        embed_matrix: pretrained word-embedding matrix, shape (max_words, embedding_dim)
        max_words:    vocabulary size of the embedding layer
        ans_cnt:      number of answers per sample
        sequence_length: number of words per answer
        embedding_dim: embedding dimension
        lstm_dim:     hidden size of each (bidirectional) MGU layer
    '''
    hnn = Sequential()
    x = Input(shape=(ans_cnt, sequence_length))
    # 1. time distributed word embedding: (None, steps, words, embed_dim)
    words_embed = TimeDistributed(Embedding(max_words, embedding_dim, input_length=sequence_length, weights=[embed_matrix]))(x)
    # 2. word-level bidirectional MGU over each answer: --> (None, steps, sequence_length, 2*lstm_dim)
    word_lstm = TimeDistributed(Bidirectional(MGU(lstm_dim, return_sequences=True)))(words_embed)

    # 3. max pooling over the words of each answer: --> (None, steps, 2*lstm_dim)
    word_avg = TimeDistributed(GlobalMaxPooling1D())(word_lstm)
    # alternative: word_avg = TimeDistributed(AttentionLayer(lstm_dim*2))(word_lstm)

    # 4. sentence-level bidirectional MGU across answers: --> (None, steps, 2*lstm_dim)
    sent_lstm = Bidirectional(MGU(lstm_dim, return_sequences=True))(word_avg)

    # 5. max pooling over answers: --> (None, 2*lstm_dim)
    sent_avg = GlobalMaxPooling1D()(sent_lstm)
    # alternative: sent_avg = AttentionLayer(lstm_dim*2)(sent_lstm)
    model = Model(inputs=x, outputs=sent_avg)  # keyword names per the Keras 2 API
    hnn.add(model)
    return hnn


# vim: set expandtab ts=4 sw=4 sts=4 tw=100:
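
MGU above is not a stock Keras layer: it is the Minimal Gated Unit of Zhou et al. (2016), a GRU variant with a single (forget) gate. As a point of reference, here is a minimal sketch of such a cell in tf.keras, following the paper's equations; the names MGUCell and the MGU wrapper function are mine, not taken from the repo, and this stand-in is not the project's actual implementation.

import tensorflow as tf

class MGUCell(tf.keras.layers.Layer):
    """Minimal Gated Unit cell (Zhou et al., 2016): a GRU variant with one gate."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.state_size = units  # required by tf.keras.layers.RNN

    def build(self, input_shape):
        dim = input_shape[-1]
        init = "glorot_uniform"
        self.Wf = self.add_weight(name="Wf", shape=(dim, self.units), initializer=init)
        self.Uf = self.add_weight(name="Uf", shape=(self.units, self.units), initializer=init)
        self.bf = self.add_weight(name="bf", shape=(self.units,), initializer="zeros")
        self.Wh = self.add_weight(name="Wh", shape=(dim, self.units), initializer=init)
        self.Uh = self.add_weight(name="Uh", shape=(self.units, self.units), initializer=init)
        self.bh = self.add_weight(name="bh", shape=(self.units,), initializer="zeros")

    def call(self, inputs, states):
        h_prev = states[0]
        # forget gate: f_t = sigmoid(Wf x_t + Uf h_{t-1} + bf)
        f = tf.sigmoid(tf.matmul(inputs, self.Wf) + tf.matmul(h_prev, self.Uf) + self.bf)
        # candidate state is computed from the gated previous state, as in the MGU paper
        h_tilde = tf.tanh(tf.matmul(inputs, self.Wh) + tf.matmul(f * h_prev, self.Uh) + self.bh)
        # interpolate between previous and candidate state with the single gate
        h = (1.0 - f) * h_prev + f * h_tilde
        return h, [h]

def MGU(units, **kwargs):
    # drop-in stand-in for the project's MGU layer; forwards return_sequences etc.
    return tf.keras.layers.RNN(MGUCell(units), **kwargs)

With that in place, a hypothetical call to HierarchicalRNN might look like the following; the vocabulary size, answer count, and sequence length are illustrative values, not the project's settings.

import numpy as np

vocab_size, embed_dim = 10000, 100  # hypothetical sizes
embed_matrix = np.random.uniform(-0.05, 0.05, (vocab_size, embed_dim))

hnn = HierarchicalRNN(embed_matrix, max_words=vocab_size, ans_cnt=20,
                      sequence_length=50, embedding_dim=embed_dim)
# maps integer word ids of shape (batch, 20, 50) to a (batch, 2*lstm_dim) encoding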