han_max.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:3HAN 作者: ni9elf 项目源码 文件源码
def fhan2_max(MAX_NB_WORDS, MAX_WORDS, MAX_SENTS, EMBEDDING_DIM, WORDGRU, embedding_matrix, DROPOUTPER):
    """Build a two-level GRU network that uses max pooling at both levels.

    A sentence encoder (embedding -> bidirectional GRU -> global max pooling)
    is wrapped in ``TimeDistributed`` and applied to every row of the document
    input; the resulting sequence goes through a second bidirectional GRU,
    another global max pooling, and a single sigmoid unit.

    Args:
        MAX_NB_WORDS: Input dimension of the embedding layer.
        MAX_WORDS: Length of each word sequence fed to the sentence encoder.
        MAX_SENTS: Not used by this function; kept for signature
            compatibility.
        EMBEDDING_DIM: Output dimension of the embedding layer.
        WORDGRU: Number of units per direction for both bidirectional GRUs.
        embedding_matrix: Initial weights for the embedding layer (the layer
            stays trainable).
        DROPOUTPER: Not used by this function; kept for signature
            compatibility.

    Returns:
        A ``(model, wordEncoder)`` tuple: the compiled document-level model
        and the sentence-level encoder it shares weights with.
    """
    # ---- Sentence-level encoder: words -> one vector per sentence ----
    word_ids = Input(shape=(MAX_WORDS,), name="wordInputs", dtype='float32')

    embedding_layer = Embedding(
        MAX_NB_WORDS,
        EMBEDDING_DIM,
        weights=[embedding_matrix],
        mask_zero=False,
        trainable=True,
        name='wordEmbedding'
    )
    embedded_words = embedding_layer(word_ids)

    word_rnn = Bidirectional(GRU(WORDGRU, return_sequences=True), name='gru1')
    word_states = word_rnn(embedded_words)

    # Collapse the time axis by taking the element-wise maximum.
    sentence_vector = GlobalMaxPooling1D()(word_states)

    sentence_encoder = Model(word_ids, sentence_vector)

    # ---- Document-level model: sentences -> one prediction ----
    # The first (sentence) axis is left as None, so it is not fixed here.
    document = Input(shape=(None, MAX_WORDS), name='docInputs', dtype='float32')

    # No Masking layer is applied to the document input.
    encoded_sentences = TimeDistributed(sentence_encoder, name='sentEncoding')(document)

    sentence_rnn = Bidirectional(
        GRU(WORDGRU, return_sequences=True),
        merge_mode='concat',
        name='gru2'
    )
    sentence_states = sentence_rnn(encoded_sentences)

    document_vector = GlobalMaxPooling1D()(sentence_states)

    prediction = Dense(
        1,
        activation="sigmoid",
        kernel_initializer='glorot_uniform',
        name="dense"
    )(document_vector)
    document_model = Model(inputs=[document], outputs=[prediction])

    optimizer = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
    document_model.compile(
        loss='binary_crossentropy',
        optimizer=optimizer,
        metrics=['accuracy']
    )
    return document_model, sentence_encoder
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号