network_continous_rnn.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:trpo 作者: jjkke88 项目源码 文件源码
def __call__(self , inputs , state , scope=None):
        """
            Long short-term memory cell (LSTM).
            implement from BasicLSTMCell.__call__
        """
        with tf.variable_scope(scope or type(self).__name__):  # "BasicLSTMCell"
            # Parameters of gates are concatenated into one multiply for efficiency.
            c , h = tf.split(1 , 2 , state)
            concat = self.linear([inputs , h] , 4 * self._num_units , True)

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i , j , f , o = tf.split(1 , 4 , concat)

            new_c = c * tf.sigmoid(f + self._forget_bias) + tf.sigmoid(i) * tf.tanh(j)
            new_h = tf.tanh(new_c) * tf.sigmoid(o)

            return new_h , tf.concat(1 , [new_c , new_h])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号