lstm.py source code

python

Project: relaax  Author: deeplearninc
def __call__(self, inputs, state, scope=None):
    """Run one step of the dilated LSTM cell.

    `state` holds the cell state c and hidden state h concatenated along axis 1.
    Returns the updated hidden state and the new concatenated state.
    """
    with tf.variable_scope(scope or type(self).__name__):  # "DilatedLSTMCell"
        # Parameters of gates are concatenated into one multiply for efficiency.
        c, h = tf.split(state, 2, axis=1)
        concat = self._linear([inputs, h], 4 * self._num_units, True)

        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = tf.split(concat, 4, axis=1)

        new_c = c * tf.sigmoid(f + self._forget_bias) + tf.sigmoid(i) * tf.tanh(j)
        new_h = tf.tanh(new_c) * tf.sigmoid(o)

        # Advance the step counter and pick the core to refresh on this step.
        timestep = tf.assign_add(self._timestep, 1)
        core_to_update = tf.mod(timestep, self._cores)

        # Blend old and new hidden state using the masks of the selected core.
        updated_h = self._hold_mask[core_to_update] * h + self._dilated_mask[core_to_update] * new_h

        return updated_h, tf.concat([new_c, updated_h], axis=1)
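
The masked update at the end of `__call__` can be illustrated without TensorFlow. The snippet does not show how `_hold_mask` and `_dilated_mask` are built, so the sketch below assumes the hidden units are split into `cores` equal groups and that each core's dilated mask is 1 for its own group and 0 elsewhere, with the hold mask as its complement. The helper names `make_masks` and `dilated_update` are hypothetical and do not come from relaax.

```python
def make_masks(num_units, cores):
    """Build per-core 0/1 masks (assumed layout; not shown in the source).

    Assumes num_units is divisible by cores.
    """
    per_core = num_units // cores
    dilated = [[1.0 if u // per_core == k else 0.0 for u in range(num_units)]
               for k in range(cores)]
    hold = [[1.0 - m for m in row] for row in dilated]
    return hold, dilated


def dilated_update(h, new_h, timestep, hold_mask, dilated_mask):
    """Keep the old value where the hold mask is 1, take the new one elsewhere."""
    core = timestep % len(dilated_mask)
    return [hm * old + dm * new
            for hm, dm, old, new in zip(hold_mask[core], dilated_mask[core], h, new_h)]


hold, dilated = make_masks(num_units=4, cores=2)
h = [0.0, 0.0, 0.0, 0.0]
new_h = [1.0, 2.0, 3.0, 4.0]

# The cell increments the counter before taking the modulus, so the first call sees 1.
step1 = dilated_update(h, new_h, 1, hold, dilated)  # core 1 active: units 2 and 3 change
step2 = dilated_update(h, new_h, 2, hold, dilated)  # core 0 active: units 0 and 1 change
```

Under these assumptions each group of units is refreshed only once every `cores` steps, which is what lets different parts of the hidden state track different time scales.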