model.py source code

python

Project: deep-attention-text-classifier-tf · Author: krayush07
def apply_attention(self):
    with tf.variable_scope('attention'):
        # Trainable context vector used to score each time step.
        attention_vector = tf.get_variable(name='attention_vector',
                                           shape=[self.params.ATTENTION_DIM],
                                           dtype=tf.float32)

        # Project RNN outputs [batch, time, hidden] -> [batch, time, ATTENTION_DIM].
        mlp_layer_projection = tf.layers.dense(inputs=self.rnn_outputs,
                                               units=self.params.ATTENTION_DIM,
                                               activation=tf.nn.tanh,
                                               kernel_initializer=tf.contrib.layers.xavier_initializer(),
                                               name='fc_attn')

        # Score each time step against the context vector: [batch, time].
        attended_vector = tf.tensordot(mlp_layer_projection, attention_vector, axes=[[2], [0]])
        # Softmax over the time axis, then add a trailing axis: [batch, time, 1].
        attention_weights = tf.expand_dims(tf.nn.softmax(attended_vector), -1)

        # Weighted sum over time: [batch, hidden, 1], squeezed to [batch, hidden].
        weighted_input = tf.matmul(self.rnn_outputs, attention_weights, transpose_a=True)
        self.attention_output = tf.squeeze(weighted_input, axis=2)
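
For reference, here is a minimal NumPy sketch of the same computation, useful for checking the shape flow annotated in the comments above. The dimensions, kernel W, and bias b are made-up stand-ins for the dense layer's trained parameters, not values from the project.

python
import numpy as np

# Hypothetical dimensions for illustration only.
BATCH, TIME, HIDDEN, ATTENTION_DIM = 2, 5, 8, 4

rnn_outputs = np.random.randn(BATCH, TIME, HIDDEN)      # [batch, time, hidden]
W = np.random.randn(HIDDEN, ATTENTION_DIM)              # stand-in dense kernel
b = np.zeros(ATTENTION_DIM)                             # stand-in dense bias
attention_vector = np.random.randn(ATTENTION_DIM)

# tf.layers.dense with tanh: [batch, time, ATTENTION_DIM]
projection = np.tanh(rnn_outputs @ W + b)

# tf.tensordot over the attention dimension: [batch, time]
scores = projection @ attention_vector

# softmax over the time axis: [batch, time]
e = np.exp(scores - scores.max(axis=1, keepdims=True))
weights = e / e.sum(axis=1, keepdims=True)

# weighted sum over time (the transpose_a matmul plus squeeze): [batch, hidden]
output = np.einsum('bth,bt->bh', rnn_outputs, weights)

print(output.shape)  # (2, 8)

The einsum expresses the original matmul-with-transpose followed by squeeze in a single contraction over the time axis.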