han3_max.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:3HAN 作者: ni9elf 项目源码 文件源码
def fhan3_max(MAX_NB_WORDS, MAX_WORDS, MAX_SENTS, EMBEDDING_DIM, WORDGRU, embedding_matrix, DROPOUTPER):
    """Build and compile the three-level hierarchical model using max pooling.

    The network has three stacked encoders, each a bidirectional GRU followed
    by dropout and global max pooling over the time axis:

    1. a word encoder that turns one sentence (``MAX_WORDS`` token ids) into
       a single vector,
    2. a sentence encoder that applies the word encoder to every sentence of
       a document and pools the result into one body vector,
    3. a headline encoder that runs over the headline word embeddings with
       the body vector appended as one extra time step, feeding a single
       sigmoid output unit.

    Args:
        MAX_NB_WORDS: Vocabulary size passed to both embedding layers.
        MAX_WORDS: Number of token ids per sentence and per headline.
        MAX_SENTS: Unused; the document input accepts any number of
            sentences. Kept so existing callers keep working.
        EMBEDDING_DIM: Width of the word embeddings. Must equal
            ``2 * WORDGRU`` because the headline embeddings and the body
            vector are concatenated along the time axis.
        WORDGRU: Number of units in each GRU direction.
        embedding_matrix: Initial weights for the body word embedding layer.
        DROPOUTPER: Dropout rate applied after every bidirectional GRU.

    Returns:
        A tuple ``(model, wordEncoder)``: the compiled two-input model
        (document, headline) and the sentence-level word encoder sub-model.

    Raises:
        ValueError: If ``EMBEDDING_DIM`` differs from ``2 * WORDGRU``.
    """
    # Each Bidirectional(GRU(WORDGRU)) concatenates both directions, so every
    # pooled encoding is 2 * WORDGRU wide.
    encoding_dim = 2 * WORDGRU
    if EMBEDDING_DIM != encoding_dim:
        # Fail early with a clear message instead of a generic shape error
        # from the concatenate layer further down.
        raise ValueError(
            "EMBEDDING_DIM ({}) must equal 2 * WORDGRU ({}) so the body vector "
            "can be concatenated with the headline embeddings".format(
                EMBEDDING_DIM, encoding_dim))

    # ------------------------------------------------------------------
    # Level 1: word encoder (one sentence -> one vector).
    wordInputs = Input(shape=(MAX_WORDS,), name="wordInputs", dtype='float32')

    wordEmbedding = Embedding(
        MAX_NB_WORDS,
        EMBEDDING_DIM,
        weights=[embedding_matrix],
        mask_zero=False,
        trainable=True,
        name='wordEmbedding',
    )(wordInputs)

    hij = Bidirectional(GRU(WORDGRU, return_sequences=True), name='gru1')(wordEmbedding)

    wordDrop = Dropout(DROPOUTPER, name='wordDrop')(hij)

    # Pool over the word axis to get a fixed-size sentence vector.
    word_max = GlobalMaxPooling1D()(wordDrop)

    wordEncoder = Model(wordInputs, word_max)

    # ------------------------------------------------------------------
    # Level 2: sentence encoder (document -> one body vector).
    docInputs = Input(shape=(None, MAX_WORDS), name='docInputs', dtype='float32')

    # Apply the word encoder independently to every sentence of the document.
    sentEncoding = TimeDistributed(wordEncoder, name='sentEncoding')(docInputs)

    hi = Bidirectional(GRU(WORDGRU, return_sequences=True), merge_mode='concat', name='gru2')(sentEncoding)

    sentDrop = Dropout(DROPOUTPER, name='sentDrop')(hi)

    sent_max = GlobalMaxPooling1D()(sentDrop)

    # Give the body vector a length-1 time axis so it can be appended to the
    # headline word sequence. The width is known from the arguments, so there
    # is no need to read the private ``_keras_shape`` tensor attribute.
    Vb = Reshape((1, encoding_dim))(sent_max)

    # ------------------------------------------------------------------
    # Level 3: headline encoder (headline words + body vector -> prediction).
    headlineInput = Input(shape=(MAX_WORDS,), name='headlineInput', dtype='float32')

    headlineEmb = Embedding(MAX_NB_WORDS, EMBEDDING_DIM, mask_zero=False, name='headlineEmb')(headlineInput)

    # Sequence of MAX_WORDS headline embeddings followed by the body vector.
    headlineBodyEmb = concatenate([headlineEmb, Vb], axis=1, name='headlineBodyEmb')

    h3 = Bidirectional(GRU(WORDGRU, return_sequences=True), merge_mode='concat', name='gru3')(headlineBodyEmb)

    headDrop = Dropout(DROPOUTPER, name='3Drop')(h3)

    head_max = GlobalMaxPooling1D()(headDrop)

    v6 = Dense(1, activation="sigmoid", kernel_initializer='he_normal', name="dense")(head_max)
    model = Model(inputs=[docInputs, headlineInput], outputs=[v6])

    sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(loss='binary_crossentropy', optimizer=sgd, metrics=['accuracy'])
    return model, wordEncoder
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号