architectures.py source code

python
Project: unblackboxing_webinar  Author: deepsense-ai
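from keras import backend as K
from keras.layers import (Activation, Dense, Flatten, Input, Lambda, LSTM,
                          Multiply, Permute, RepeatVector, TimeDistributed)
from keras.models import Model


def arch_attention(embedding_layer, sequence_length, classes):
    # Integer token ids, one tweet per row.
    tweet_input = Input(shape=(sequence_length,), dtype='int32')
    embedded_tweet = embedding_layer(tweet_input)

    # Per-timestep hidden states: (batch, sequence_length, 128).
    activations = LSTM(128, return_sequences=True, name='recurrent_layer')(embedded_tweet)

    # One tanh score per timestep, normalized over the sequence with softmax.
    attention = TimeDistributed(Dense(1, activation='tanh'))(activations)
    attention = Flatten()(attention)
    attention = Activation('softmax')(attention)
    # Repeat the weights across the 128 LSTM units: (batch, sequence_length, 128).
    attention = RepeatVector(128)(attention)
    attention = Permute([2, 1], name='attention_layer')(attention)

    # Element-wise product (Multiply is the Keras 2 layer for this), then a sum
    # over time, giving an attention-weighted sentence vector: (batch, 128).
    sent_representation = Multiply()([activations, attention])
    sent_representation = Lambda(lambda xin: K.sum(xin, axis=1), name='merged_layer')(sent_representation)

    tweet_output = Dense(classes, activation='softmax', name='predictions')(sent_representation)

    tweetnet = Model(tweet_input, tweet_output)
    tweetnet.compile(optimizer='adam',
                     loss='categorical_crossentropy',
                     metrics=['accuracy'])
    return tweetnet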
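
The attention steps in `arch_attention` (a tanh score per timestep, a softmax over the sequence, then a weighted sum of the LSTM states) can be traced by hand for a single tweet. The following is a minimal pure-Python sketch of that arithmetic, not the project's code: `attention_pool`, `w`, and `b` are hypothetical names standing in for the weights learned by the `Dense(1, activation='tanh')` layer, and the toy input uses 2 units instead of 128.

```python
import math


def attention_pool(activations, w, b=0.0):
    """Attention-weighted sum over timesteps for one sequence.

    activations: list of timesteps, each a list of `units` floats
    w, b: hypothetical scoring weights (one weight per unit) and bias
    """
    # Score each timestep: tanh(h . w + b), as in TimeDistributed(Dense(1, 'tanh')).
    scores = [math.tanh(sum(h_i * w_i for h_i, w_i in zip(h, w)) + b)
              for h in activations]

    # Softmax over the timesteps (shifted by the max for numerical stability).
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]

    # Weighted sum of the hidden states over time, one value per unit.
    units = len(activations[0])
    pooled = [sum(weights[t] * activations[t][j] for t in range(len(activations)))
              for j in range(units)]
    return weights, pooled


# Toy input: 3 timesteps, 2 units.
activations = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w = [0.5, -0.5]
weights, pooled = attention_pool(activations, w)
```

Here `weights` corresponds to the output of the softmax `Activation` layer before it is repeated and permuted, and `pooled` to the output of `merged_layer`.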