graph.py source code

python

Project: multi-task-learning  Author: jg8610
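# Assumed imports (not shown in the original snippet; TensorFlow 1.x API).
import tensorflow as tf
from tensorflow.contrib import rnn


def _shared_layer(self, input_data, config, is_training):
    """Build the shared encoder, i.e. the model up until decoding.

    Args:
        input_data: float tensor of shape
            [batch_size, num_steps, embedding_size].
        config: configuration object providing encoder_size and keep_prob.
        is_training: whether the training graph is being built; dropout
            is only applied when this is true.

    Returns:
        The encoder outputs at every time step, of shape
        [batch_size, num_steps, config.encoder_size].
    """
    with tf.variable_scope('encoder'):
        # Single LSTM layer that inherits the reuse flag of the enclosing scope.
        lstm_cell = rnn.BasicLSTMCell(
            config.encoder_size,
            forget_bias=1.0,
            reuse=tf.get_variable_scope().reuse)
        # Apply dropout to the cell outputs during training only.
        if is_training and config.keep_prob < 1:
            lstm_cell = rnn.DropoutWrapper(
                lstm_cell, output_keep_prob=config.keep_prob)
        # The final LSTM state is not used here, so it is discarded.
        encoder_outputs, _ = tf.nn.dynamic_rnn(
            lstm_cell,
            input_data,
            dtype=tf.float32,
            scope="encoder_rnn")

    return encoder_outputs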
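
The dropout guard in the encoder (`is_training and config.keep_prob < 1`) can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, not the project's code: `apply_output_dropout` is a hypothetical helper, and it assumes the wrapper behaves like standard inverted dropout, where each output is kept with probability `keep_prob` and surviving values are rescaled by `1 / keep_prob`.

```python
import random


def apply_output_dropout(outputs, keep_prob, is_training, rng=None):
    """Hypothetical helper illustrating output dropout on a list of floats."""
    # Same guard as in _shared_layer: dropout is active only while
    # training and only when keep_prob is below 1.
    if not (is_training and keep_prob < 1):
        return list(outputs)
    rng = rng or random.Random()
    # Keep each value with probability keep_prob and rescale survivors by
    # 1 / keep_prob so the expected value stays unchanged (assumption:
    # inverted dropout).
    return [x / keep_prob if rng.random() < keep_prob else 0.0
            for x in outputs]


outputs = [1.0, 2.0, 3.0, 4.0]
# At evaluation time the outputs pass through untouched.
eval_out = apply_output_dropout(outputs, keep_prob=0.5, is_training=False)
# During training each value is either zeroed or doubled (1 / 0.5).
train_out = apply_output_dropout(outputs, keep_prob=0.5, is_training=True,
                                 rng=random.Random(0))
```

With `keep_prob=1` or `is_training=False` the helper returns the inputs unchanged, which mirrors why the snippet skips wrapping the cell in those cases.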