han3_pretrain.py 文件源码

python
阅读 16 收藏 0 点赞 0 评论 0

项目:3HAN 作者: ni9elf 项目源码 文件源码
def fhan3_pretrain(MAX_NB_WORDS, MAX_WORDS, MAX_SENTS, EMBEDDING_DIM, WORDGRU, embedding_matrix, DROPOUTPER,
                   pretrained_weights='han1_pretrain.h5'):
    """Build and compile the 3HAN model (sentence -> document -> headline+body).

    The word-level (sentence) encoder is built first and, unless
    ``pretrained_weights`` is None, initialised by layer name from a saved
    weights file. It is then applied to every sentence of a document; the
    resulting document vector is appended after the headline word embeddings
    and encoded once more to produce a single sigmoid output.

    Args:
        MAX_NB_WORDS: vocabulary size (input_dim of both Embedding layers).
        MAX_WORDS: number of word positions per sentence and per headline.
        MAX_SENTS: not used in this function; the document input accepts a
            variable number of sentences (``shape=(None, MAX_WORDS)``). Kept for
            interface compatibility.
        EMBEDDING_DIM: embedding width. Must equal the width of the document
            vector produced by ``att2``, because that vector is concatenated
            with the headline embeddings along the time axis.
        WORDGRU: units per direction of every bidirectional GRU.
        embedding_matrix: initial weights for the body word embedding ``emb1``.
        DROPOUTPER: dropout rate applied after each bidirectional GRU.
        pretrained_weights: weights file loaded into the word encoder with
            ``by_name=True``. Pass None to skip loading. Defaults to the
            previously hard-coded ``'han1_pretrain.h5'``.

    Returns:
        A tuple ``(model, wordEncoder)``. ``model`` takes
        ``[docInputs, headlineInput]``, emits one sigmoid unit, and is compiled
        with binary cross-entropy and Nesterov SGD. ``wordEncoder`` is the
        sentence-level sub-model shared inside ``model``.

    Raises:
        ValueError: if the document vector width differs from EMBEDDING_DIM.
    """
    # ---- Sentence level: word indices of one sentence -> sentence vector ----
    wordInputs = Input(shape=(MAX_WORDS,), name='word1', dtype='float32')

    # mask_zero=True treats word index 0 as padding for downstream layers.
    # Assumes all sentences are padded to the same number of words (MAX_WORDS).
    wordEmbedding = Embedding(MAX_NB_WORDS, EMBEDDING_DIM,
                              weights=[embedding_matrix], mask_zero=True,
                              trainable=True, name='emb1')(wordInputs)

    # NOTE(review): the name is set on the inner GRU, not on the Bidirectional
    # wrapper (unlike 'gru2'/'gru3' below), so the wrapper gets an
    # auto-generated name. load_weights(by_name=True) matches layer names —
    # confirm this layer's weights are really restored from the file.
    hij = Bidirectional(GRU(WORDGRU, name='gru1', return_sequences=True))(wordEmbedding)
    wordDrop = Dropout(DROPOUTPER, name='drop1')(hij)

    # AttentionLayer yields two outputs (presumably attention weights and the
    # attended vector, judging by the original names); only the second is used.
    _, Si = AttentionLayer(name='att1')(wordDrop)

    wordEncoder = Model(wordInputs, Si)
    if pretrained_weights is not None:
        # by_name=True: only layers whose names match entries in the file load.
        wordEncoder.load_weights(pretrained_weights, by_name=True)

    # ---- Document level: sentences -> document (body) vector ----------------
    # The sentence axis is None, so documents may have any number of sentences.
    docInputs = Input(shape=(None, MAX_WORDS), name='docInputs', dtype='float32')

    # Masks sentence time steps whose word indices are all 0 (padding sentences).
    sentenceMasking = Masking(mask_value=0.0, name='sentenceMasking')(docInputs)

    # Run the (possibly pretrained) word encoder on each sentence independently.
    sentEncoding = TimeDistributed(wordEncoder, name='sentEncoding')(sentenceMasking)

    hi = Bidirectional(GRU(WORDGRU, return_sequences=True),
                       merge_mode='concat', name='gru2')(sentEncoding)
    sentDrop = Dropout(DROPOUTPER, name='sentDrop')(hi)
    _, docVector = AttentionLayer(name='att2')(sentDrop)

    docDim = docVector._keras_shape[1]
    # The body vector is appended as one extra time step after the headline
    # words, so its feature width must equal the headline embedding width.
    if docDim is not None and docDim != EMBEDDING_DIM:
        raise ValueError(
            'Document vector width ({}) must equal EMBEDDING_DIM ({}) so it can be '
            'concatenated with the headline embeddings.'.format(docDim, EMBEDDING_DIM))

    # (batch, docDim) -> (batch, 1, docDim): a length-1 sequence.
    docVector = Reshape((1, docDim))(docVector)

    # ---- Headline level: headline words + body vector -> prediction ---------
    headlineInput = Input(shape=(MAX_WORDS,), name='headlineInput', dtype='float32')

    # NOTE(review): unlike 'emb1', this is a separate embedding layer that is not
    # initialised from embedding_matrix — confirm that is intended.
    headlineEmb = Embedding(MAX_NB_WORDS, EMBEDDING_DIM, mask_zero=True,
                            name='headlineEmb')(headlineInput)

    docVector = Masking(mask_value=0.0, name='Vb')(docVector)

    # (batch, MAX_WORDS + 1, EMBEDDING_DIM): headline words followed by the body vector.
    headlineBodyEmb = concatenate([headlineEmb, docVector], axis=1, name='headlineBodyEmb')

    h3 = Bidirectional(GRU(WORDGRU, return_sequences=True),
                       merge_mode='concat', name='gru3')(headlineBodyEmb)
    h3Drop = Dropout(DROPOUTPER, name='h3drop')(h3)
    _, Vn = AttentionLayer(name='att3')(h3Drop)

    v6 = Dense(1, activation="sigmoid", kernel_initializer='he_normal', name="dense")(Vn)
    model = Model(inputs=[docInputs, headlineInput], outputs=[v6])

    sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(loss='binary_crossentropy', optimizer=sgd, metrics=['accuracy'])
    return model, wordEncoder
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号