han3.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:3HAN 作者: ni9elf 项目源码 文件源码
def HAN(MAX_NB_WORDS, MAX_WORDS, MAX_SENTS, EMBEDDING_DIM, WORDGRU, embedding_matrix, DROPOUTPER):
    """Build and compile the three-stage attention model and its sentence encoder.

    The network is assembled in three stages:

    1. A sentence encoder (returned separately) that embeds a sequence of
       ``MAX_WORDS`` word indices, runs a bidirectional GRU with dropout and
       reduces the sequence with ``AttentionLayer``.
    2. A document stage that applies the sentence encoder to every sentence
       of the body via ``TimeDistributed`` and repeats the
       BiGRU -> dropout -> ``AttentionLayer`` pattern over sentences.
    3. A headline stage that embeds the headline, concatenates the single
       document vector to it along axis 1, and applies the same pattern once
       more before a one-unit sigmoid ``Dense`` output.

    Args:
        MAX_NB_WORDS: First argument (``input_dim``) of both ``Embedding`` layers.
        MAX_WORDS: Length of each sentence input and of the headline input.
        MAX_SENTS: Not referenced in the body; the document input leaves the
            sentence axis as ``None``.
        EMBEDDING_DIM: Second argument (``output_dim``) of both ``Embedding`` layers.
        WORDGRU: Number of units for each of the three GRUs.
        embedding_matrix: Initial weights for the ``wordEmbedding`` layer only.
        DROPOUTPER: Rate passed to each of the three ``Dropout`` layers.

    Returns:
        A ``(model, wordEncoder)`` tuple: the compiled two-input model
        (document, headline) and the sentence-level encoder sub-model.
    """
    # ---- Stage 1: sentence encoder (words -> one vector per sentence) ----
    sentence_tokens = Input(shape=(MAX_WORDS,), name="wordInputs", dtype='float32')
    token_vectors = Embedding(
        MAX_NB_WORDS,
        EMBEDDING_DIM,
        weights=[embedding_matrix],
        mask_zero=True,
        trainable=True,
        name='wordEmbedding',
    )(sentence_tokens)
    word_states = Bidirectional(
        GRU(WORDGRU, return_sequences=True),
        name='gru1',
    )(token_vectors)
    word_states = Dropout(DROPOUTPER, name='wordDrop')(word_states)
    # AttentionLayer yields two values; only the second one is used downstream.
    # Presumably the first holds the attention weights -- confirm in AttentionLayer.
    _, sentence_vector = AttentionLayer(name='att1')(word_states)

    wordEncoder = Model(sentence_tokens, sentence_vector)

    # ---- Stage 2: document encoder (sentences -> one vector per document) ----
    document_tokens = Input(shape=(None, MAX_WORDS), name='docInputs', dtype='float32')
    masked_document = Masking(mask_value=0.0, name='sentenceMasking')(document_tokens)
    sentence_vectors = TimeDistributed(wordEncoder, name='sentEncoding')(masked_document)
    sentence_states = Bidirectional(
        GRU(WORDGRU, return_sequences=True),
        merge_mode='concat',
        name='gru2',
    )(sentence_vectors)
    sentence_states = Dropout(DROPOUTPER, name='sentDrop')(sentence_states)
    _, body_vector = AttentionLayer(name='att2')(sentence_states)

    # Insert a length-1 sequence axis so the body vector can be joined to the
    # headline sequence along axis 1 below.
    body_step = Reshape((1, body_vector._keras_shape[1]))(body_vector)

    # ---- Stage 3: headline + body vector -> prediction ----
    headline_tokens = Input(shape=(MAX_WORDS,), name='headlineInput', dtype='float32')
    # NOTE(review): unlike 'wordEmbedding', this layer is not given
    # embedding_matrix as initial weights -- confirm that this is intended.
    headline_vectors = Embedding(
        MAX_NB_WORDS,
        EMBEDDING_DIM,
        mask_zero=True,
        name='headlineEmb',
    )(headline_tokens)

    body_step = Masking(mask_value=0.0, name='Vb')(body_step)

    # NOTE(review): joining along axis 1 presumably requires EMBEDDING_DIM to
    # equal the feature width of the body vector -- verify against callers.
    headline_and_body = concatenate(
        [headline_vectors, body_step],
        axis=1,
        name='headlineBodyEmb',
    )
    joint_states = Bidirectional(
        GRU(WORDGRU, return_sequences=True),
        merge_mode='concat',
        name='gru3',
    )(headline_and_body)
    joint_states = Dropout(DROPOUTPER, name='3Drop')(joint_states)
    _, news_vector = AttentionLayer(name='att3')(joint_states)

    prediction = Dense(
        1,
        activation="sigmoid",
        kernel_initializer='he_normal',
        name="dense",
    )(news_vector)
    model = Model(inputs=[document_tokens, headline_tokens], outputs=[prediction])

    # Binary classification objective, optimised by SGD with Nesterov momentum.
    optimizer = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return model, wordEncoder
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号