graphutils.py source code

python

Project: rnnlab  Author: phueb
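# Assumes TensorFlow 1.x: `tf` is `import tensorflow as tf`, `_linear` is TensorFlow's
# private RNN-cell helper (its module path varies across 1.x releases), and
# `LSTMStateTuple` is e.g. tf.contrib.rnn.LSTMStateTuple. This is a method of an RNN
# cell class that sets self._num_units and self._forget_bias.
def __call__(self, input, state, scope=None):  # TODO test
    with tf.variable_scope(scope or type(self).__name__):
        # unpack the previous cell state and hidden state
        c_prev, h_prev = state
        with tf.variable_scope('mul'):
            # a single affine map over [input, h_prev], producing 2 * num_units values
            concat = _linear([input, h_prev], 2 * self._num_units, True)
        proj_input, rec_input = tf.split(value=concat, num_or_size_splits=2, axis=1)
        # element-wise product of the two halves: the multiplicative intermediate state
        mul_input = proj_input * rec_input  # equation (18)
        with tf.variable_scope('rec_input'):
            # project the multiplicative state to the pre-activations of all four gates
            rec_mul_input = _linear(mul_input, 4 * self._num_units, True)
            # NOTE: _linear(..., True) already adds a bias, so `b` is a second bias term
            b = tf.get_variable('b', [self._num_units * 4])
        # NOTE: `input` is added without a projection, so its depth is expected
        # to equal 4 * num_units
        lstm_mat = input + rec_mul_input + b
        # i: input gate, j: candidate values, f: forget gate, o: output gate
        i, j, f, o = tf.split(value=lstm_mat, num_or_size_splits=4, axis=1)
    # standard LSTM updates of the cell state (new_c) and hidden state (new_h)
    new_c = c_prev * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) * tf.nn.tanh(j)
    new_h = tf.nn.tanh(new_c) * tf.nn.sigmoid(o)
    new_state = LSTMStateTuple(new_c, new_h)
    return new_h, new_state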
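The TensorFlow method above builds a graph, which makes the per-step arithmetic hard to trace. The following is a minimal, dependency-free sketch of the same step for a single example using plain Python lists. It is an illustration under stated assumptions, not the rnnlab implementation: `sigmoid`, `linear`, `init_params` and `mlstm_step` are hypothetical helpers, `linear` stands in for TensorFlow's private `_linear`, the parameter names (`W_mul`, `b_mul`, `W_rec`, `b_rec`, `b`) and the default `forget_bias=1.0` are assumptions, and the weights are random placeholders. It deliberately mirrors the code as written, including the extra bias `b` and the unprojected `input` term.

```python
import math
import random


def sigmoid(z):
    """Logistic function."""
    return 1.0 / (1.0 + math.exp(-z))


def linear(x, weights, bias):
    """Affine map W x + b on plain lists (hypothetical stand-in for TF's _linear)."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, bias)]


def init_params(input_dim, num_units, seed=0):
    """Random placeholder weights with the shapes the TF cell would create."""
    rng = random.Random(seed)

    def mat(rows, cols):
        return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]

    return {
        "W_mul": mat(2 * num_units, input_dim + num_units),  # acts on [x, h_prev]
        "b_mul": [0.0] * (2 * num_units),
        "W_rec": mat(4 * num_units, num_units),  # acts on the multiplicative state
        "b_rec": [0.0] * (4 * num_units),
        "b": [0.0] * (4 * num_units),  # the extra bias `b` from the TF code
    }


def mlstm_step(x, c_prev, h_prev, params, forget_bias=1.0):
    """One cell step for a single example; returns (new_h, (new_c, new_h))."""
    n = len(h_prev)
    # x is added unprojected to the 4n gate pre-activations, so its size must match
    if len(x) != 4 * n:
        raise ValueError("input size must equal 4 * num_units")
    # list `+` concatenates, mirroring _linear([input, h_prev], ...)
    concat = linear(x + h_prev, params["W_mul"], params["b_mul"])
    proj_input, rec_input = concat[:n], concat[n:]
    # element-wise product: the multiplicative intermediate state
    m = [p * r for p, r in zip(proj_input, rec_input)]
    rec = linear(m, params["W_rec"], params["b_rec"])
    pre = [xi + ri + bi for xi, ri, bi in zip(x, rec, params["b"])]
    i, j, f, o = pre[:n], pre[n:2 * n], pre[2 * n:3 * n], pre[3 * n:]
    new_c = [c * sigmoid(fk + forget_bias) + sigmoid(ik) * math.tanh(jk)
             for c, ik, jk, fk in zip(c_prev, i, j, f)]
    new_h = [math.tanh(c) * sigmoid(ok) for c, ok in zip(new_c, o)]
    return new_h, (new_c, new_h)


if __name__ == "__main__":
    num_units = 3
    params = init_params(4 * num_units, num_units)
    h, c = [0.0] * num_units, [0.0] * num_units
    # run a two-step toy sequence through the cell
    for x in ([0.5] * (4 * num_units), [-0.5] * (4 * num_units)):
        h, (c, _) = mlstm_step(x, c, h, params)
    print(h)
```

The explicit size check makes the shape requirement of the TF code visible: because `input` is summed directly with the `4 * num_units` gate pre-activations, any other input width either fails or broadcasts in a way the cell was probably not meant to handle.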