basic_rnn_cells.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:skiprnn-2017-telecombcn 作者: imatge-upc 项目源码 文件源码
def __call__(self, inputs, state, scope=None):
    """Run one LSTM step and return the new output together with the new state.

    Args:
        inputs: Input for the current step; it is fed to `rnn_ops.linear`
            alongside the previous hidden state.
        state: A pair `(c, h)` holding the previous cell state and the
            previous hidden state.
        scope: Optional variable scope. When falsy, the class name of this
            cell is used instead.

    Returns:
        A pair `(new_h, new_state)`, where `new_state` is a
        `tf.contrib.rnn.LSTMStateTuple(new_c, new_h)`.
    """
    with tf.variable_scope(scope or type(self).__name__):
        prev_c, prev_h = state

        # One linear map produces the pre-activations of all four gates at
        # once; the result is then cut into four equal slices along axis 1.
        # NOTE(review): the trailing `True` presumably enables the bias term
        # of `rnn_ops.linear` -- confirm against that helper.
        stacked = rnn_ops.linear([inputs, prev_h], 4 * self._num_units, True)
        in_pre, cand_pre, forget_pre, out_pre = tf.split(
            value=stacked, num_or_size_splits=4, axis=1)

        if self._layer_norm:
            # Each slice gets its own layer norm, created in the order
            # i (input gate), j (new input), f (forget gate), o (output gate).
            in_pre, cand_pre, forget_pre, out_pre = [
                rnn_ops.layer_norm(pre, name=label)
                for pre, label in zip(
                    (in_pre, cand_pre, forget_pre, out_pre),
                    ("i", "j", "f", "o"))
            ]

        # Part of the old cell state that survives the forget gate; the
        # forget bias is added to the pre-activation before the sigmoid.
        keep = tf.sigmoid(forget_pre + self._forget_bias)
        retained = prev_c * keep

        # Candidate values scaled by the input gate.
        admit = tf.sigmoid(in_pre)
        candidate = self._activation(cand_pre)
        written = admit * candidate

        new_c = retained + written

        # The activated cell state is scaled by the output gate.
        squashed = self._activation(new_c)
        new_h = squashed * tf.sigmoid(out_pre)

        return new_h, tf.contrib.rnn.LSTMStateTuple(new_c, new_h)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号