HAN_model.py source code

python

Project: hierarchical-attention-networks  Author: ematvey
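`_init_body` builds the model graph from Hierarchical Attention Networks (Yang et al., 2016): a word-level bidirectional RNN with attention pools each sentence's words into a sentence vector, a sentence-level bidirectional RNN with attention pools the sentence vectors into a document vector, and a final fully connected layer produces the class logits.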
def _init_body(self, scope):
    with tf.variable_scope(scope):

      # Collapse the document and sentence dimensions so each sentence
      # becomes an independent sequence of word embeddings:
      # [document_size * sentence_size, word_size, embedding_size].
      word_level_inputs = tf.reshape(self.inputs_embedded, [
        self.document_size * self.sentence_size,
        self.word_size,
        self.embedding_size
      ])
      word_level_lengths = tf.reshape(
        self.word_lengths, [self.document_size * self.sentence_size])

      # Word-level encoder: run a bidirectional RNN over the words of
      # each sentence, then pool the per-word outputs into a single
      # sentence vector with task-specific attention.
      with tf.variable_scope('word') as scope:
        word_encoder_output, _ = bidirectional_rnn(
          self.word_cell, self.word_cell,
          word_level_inputs, word_level_lengths,
          scope=scope)

        with tf.variable_scope('attention') as scope:
          word_level_output = task_specific_attention(
            word_encoder_output,
            self.word_output_size,
            scope=scope)

        # Dropout is applied only while is_training is True.
        with tf.variable_scope('dropout'):
          word_level_output = layers.dropout(
            word_level_output, keep_prob=self.dropout_keep_proba,
            is_training=self.is_training,
          )

      # Restore the document structure: one row of sentence vectors per
      # document, shape [document_size, sentence_size, word_output_size].
      sentence_inputs = tf.reshape(
        word_level_output, [self.document_size, self.sentence_size, self.word_output_size])

      # Sentence-level encoder: the same bidirectional-RNN-plus-attention
      # pattern, pooling the sentence vectors into one document vector.
      with tf.variable_scope('sentence') as scope:
        sentence_encoder_output, _ = bidirectional_rnn(
          self.sentence_cell, self.sentence_cell, sentence_inputs, self.sentence_lengths, scope=scope)

        with tf.variable_scope('attention') as scope:
          sentence_level_output = task_specific_attention(
            sentence_encoder_output, self.sentence_output_size, scope=scope)

        with tf.variable_scope('dropout'):
          sentence_level_output = layers.dropout(
            sentence_level_output, keep_prob=self.dropout_keep_proba,
            is_training=self.is_training,
          )

      # Classifier: project the document vector to unnormalized class
      # logits; the prediction is the argmax over classes.
      with tf.variable_scope('classifier'):
        self.logits = layers.fully_connected(
          sentence_level_output, self.classes, activation_fn=None)

        self.prediction = tf.argmax(self.logits, axis=-1)
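
The helpers `bidirectional_rnn` and `task_specific_attention` are defined elsewhere in the project. The sketch below shows plausible TF 1.x implementations, consistent with how they are called here and with the attention mechanism from the HAN paper; the repository's actual code may differ in details. In particular, the state concatenation assumes cells whose state is a single tensor (e.g. `GRUCell`).

import tensorflow as tf
from tensorflow.contrib import layers

def bidirectional_rnn(cell_fw, cell_bw, inputs, input_lengths, scope=None):
  """Bidirectional dynamic RNN whose outputs are the forward and
  backward outputs concatenated along the feature axis."""
  with tf.variable_scope(scope or 'birnn'):
    ((fw_outputs, bw_outputs),
     (fw_state, bw_state)) = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=cell_fw, cell_bw=cell_bw, inputs=inputs,
        sequence_length=input_lengths, dtype=tf.float32)
    outputs = tf.concat((fw_outputs, bw_outputs), 2)
    # Assumes a cell whose state is a single tensor (e.g. GRUCell).
    state = tf.concat((fw_state, bw_state), 1)
    return outputs, state

def task_specific_attention(inputs, output_size,
                            initializer=layers.xavier_initializer(),
                            activation_fn=tf.tanh, scope=None):
  """Attention pooling with a trainable task-specific context vector
  (Yang et al., 2016): project the inputs, score each timestep against
  the context vector, and return the softmax-weighted sum."""
  assert len(inputs.get_shape()) == 3 and inputs.get_shape()[-1].value is not None
  with tf.variable_scope(scope or 'attention') as scope:
    attention_context_vector = tf.get_variable(
      name='attention_context_vector', shape=[output_size],
      initializer=initializer, dtype=tf.float32)
    input_projection = layers.fully_connected(
      inputs, output_size, activation_fn=activation_fn, scope=scope)
    # Unnormalized attention scores: dot product of each projected
    # timestep with the context vector.
    vector_attn = tf.reduce_sum(
      tf.multiply(input_projection, attention_context_vector),
      axis=2, keep_dims=True)
    attention_weights = tf.nn.softmax(vector_attn, dim=1)
    # Weighted sum over the time axis -> one vector per sequence.
    return tf.reduce_sum(tf.multiply(input_projection, attention_weights), axis=1)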