graph.py source code

python

Project: multi-task-learning   Author: jg8610
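    def _chunk_private(self, encoder_units, pos_prediction, config, is_training):
        """Decode model for chunks.

        Args:
            encoder_units: time-major encoder outputs of shape
                [num_steps, batch_size, encoder_size]; they are transposed
                to batch-major before being joined with the POS predictions.
            pos_prediction: POS predictions, reshaped to
                [batch_size, num_steps, pos_embedding_size]. Only the batch
                and time dimensions must match encoder_units; the feature
                sizes may differ because the two are concatenated on axis 2.
            config: model configuration (chunk_decoder_size, keep_prob,
                num_chunk_tags).
            is_training: if True and config.keep_prob < 1, dropout is applied
                to the decoder outputs.

        Returns:
            logits: [batch_size * num_steps, num_chunk_tags]
            decoder_states: final state of the chunk decoder LSTM
        """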
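        # POS predictions -> [batch_size, num_steps, pos_embedding_size]
        pos_prediction = tf.reshape(pos_prediction,
                                    [self.batch_size, self.num_steps, self.pos_embedding_size])
        # time-major encoder outputs -> batch-major [batch_size, num_steps, encoder_size]
        encoder_units = tf.transpose(encoder_units, [1, 0, 2])
        # concatenate the pos_prediction and the encoder_units along the feature axis
        chunk_inputs = tf.concat([pos_prediction, encoder_units], 2)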

        with tf.variable_scope("chunk_decoder"):
            cell = rnn.BasicLSTMCell(config.chunk_decoder_size, forget_bias=1.0, reuse=tf.get_variable_scope().reuse)

            if is_training and config.keep_prob < 1:
                cell = rnn.DropoutWrapper(
                    cell, output_keep_prob=config.keep_prob)

            decoder_outputs, decoder_states = tf.nn.dynamic_rnn(cell,
                                                                chunk_inputs,
                                                                dtype=tf.float32,
                                                                scope="chunk_rnn")

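            # dynamic_rnn (batch-major by default) returns a single tensor of shape
            # [batch_size, num_steps, chunk_decoder_size], so no concat is needed;
            # flatten batch and time so each row is one (example, step) pair.
            output = tf.reshape(decoder_outputs, [-1, config.chunk_decoder_size])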

            softmax_w = tf.get_variable("softmax_w",
                                        [config.chunk_decoder_size,
                                         config.num_chunk_tags])
            softmax_b = tf.get_variable("softmax_b", [config.num_chunk_tags])
            logits = tf.matmul(output, softmax_w) + softmax_b

        return logits, decoder_states
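To make the shape bookkeeping in `_chunk_private` easier to follow, here is a minimal NumPy sketch that traces the same reshape, transpose, concat and flatten steps. It is not the original model: the function name `chunk_decoder_shapes`, the example dimensions, and the batch-major row order of the flat POS predictions are assumptions, and the LSTM is replaced by a hypothetical per-step tanh projection that only preserves the output shape.

```python
import numpy as np


def chunk_decoder_shapes(encoder_units, pos_prediction, batch_size, num_steps,
                         pos_embedding_size, decoder_size, num_chunk_tags,
                         seed=0):
    """NumPy trace of the tensor shapes in _chunk_private (not the real model)."""
    rng = np.random.default_rng(seed)

    # tf.reshape(pos_prediction, [batch, steps, pos_emb]); assumes the flat
    # POS rows are ordered batch-major (all steps of example 0 first, ...).
    pos = pos_prediction.reshape(batch_size, num_steps, pos_embedding_size)

    # tf.transpose(encoder_units, [1, 0, 2]): [steps, batch, enc] -> [batch, steps, enc]
    enc = np.transpose(encoder_units, (1, 0, 2))

    # tf.concat([...], 2): batch and time dims must match; feature sizes add up.
    chunk_inputs = np.concatenate([pos, enc], axis=2)

    # Hypothetical stand-in for the LSTM: a per-step tanh projection.
    # It ignores recurrence entirely and only preserves the output shape.
    w_in = rng.standard_normal((chunk_inputs.shape[2], decoder_size))
    decoder_outputs = np.tanh(chunk_inputs @ w_in)  # [batch, steps, decoder_size]

    # tf.reshape(..., [-1, chunk_decoder_size]): row b * num_steps + t is (b, t).
    output = decoder_outputs.reshape(-1, decoder_size)

    # Softmax projection to tag scores, mirroring softmax_w / softmax_b.
    softmax_w = rng.standard_normal((decoder_size, num_chunk_tags))
    softmax_b = np.zeros(num_chunk_tags)
    logits = output @ softmax_w + softmax_b
    return chunk_inputs, logits


batch_size, num_steps, enc_size, pos_size = 2, 5, 8, 4
rng = np.random.default_rng(1)
encoder_units = rng.standard_normal((num_steps, batch_size, enc_size))
pos_prediction = rng.standard_normal((batch_size * num_steps, pos_size))

chunk_inputs, logits = chunk_decoder_shapes(
    encoder_units, pos_prediction, batch_size, num_steps,
    pos_embedding_size=pos_size, decoder_size=6, num_chunk_tags=3)
print(chunk_inputs.shape)  # (2, 5, 12)
print(logits.shape)        # (10, 3)
```

Because the decoder outputs are flattened batch-major, row `b * num_steps + t` of the logits scores step `t` of example `b`; any chunk-tag labels used in a cross-entropy loss have to be flattened in the same order.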