graph.py source code

python

Project: multi-task-learning  Author: jg8610
def _pos_private(self, encoder_units, config, is_training):
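        """Decoder model for POS tagging.

        Args:
            encoder_units: encoder outputs that are fed to the POS decoder.
            config: configuration object; this method reads
                pos_decoder_size, keep_prob and num_pos_tags from it.
            is_training: if True and config.keep_prob < 1, dropout is
                applied to the decoder cell outputs.

        Returns:
            logits: tensor of shape [-1, config.num_pos_tags], one row per
                decoder output position.
            decoder_states: final state returned by tf.nn.dynamic_rnn.
        """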
        with tf.variable_scope("pos_decoder"):
            pos_decoder_cell = rnn.BasicLSTMCell(config.pos_decoder_size,
                                     forget_bias=1.0, reuse=tf.get_variable_scope().reuse)

            if is_training and config.keep_prob < 1:
                pos_decoder_cell = rnn.DropoutWrapper(
                    pos_decoder_cell, output_keep_prob=config.keep_prob)

            encoder_units = tf.transpose(encoder_units, [1, 0, 2])

            decoder_outputs, decoder_states = tf.nn.dynamic_rnn(pos_decoder_cell,
                                                                encoder_units,
                                                                dtype=tf.float32,
                                                                scope="pos_rnn")

            output = tf.reshape(tf.concat(decoder_outputs, 1),
                                [-1, config.pos_decoder_size])

            softmax_w = tf.get_variable("softmax_w",
                                        [config.pos_decoder_size,
                                         config.num_pos_tags])
            softmax_b = tf.get_variable("softmax_b", [config.num_pos_tags])
            logits = tf.matmul(output, softmax_w) + softmax_b

        return logits, decoder_states
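
The last step of the method flattens the decoder outputs into one row per position and applies a single shared linear layer to get tag logits. The sketch below mirrors only that reshape-and-project step in plain Python, without TensorFlow. The helper name `project_to_logits` and all sizes and values are illustrative assumptions, not part of the original project.

```python
def project_to_logits(decoder_outputs, softmax_w, softmax_b):
    """Flatten [batch, time, hidden] outputs to rows, then apply one linear layer."""
    # Counterpart of tf.reshape(..., [-1, hidden]): one row per (batch, time) position.
    flat = [step for sequence in decoder_outputs for step in sequence]
    logits = []
    for row in flat:
        # Counterpart of tf.matmul(row, softmax_w) + softmax_b for a single row.
        logits.append([
            sum(x * w_row[j] for x, w_row in zip(row, softmax_w)) + softmax_b[j]
            for j in range(len(softmax_b))
        ])
    return logits


# Illustrative sizes only (assumptions): batch=2, time=3, hidden=2, num_tags=2.
decoder_outputs = [
    [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    [[2.0, 0.0], [0.0, 2.0], [2.0, 2.0]],
]
softmax_w = [[1.0, 2.0], [3.0, 4.0]]  # [hidden, num_tags]
softmax_b = [0.5, -0.5]               # [num_tags]

logits = project_to_logits(decoder_outputs, softmax_w, softmax_b)
print(len(logits), len(logits[0]))  # 6 2
```

Row `b * time + t` of the result holds the logits for time step `t` of sequence `b`, so any labels compared against these logits would need to be flattened in the same order.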