model_utils.py source code

python

Project: lm  Author: rafaljozefowicz
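def __call__(self, inputs, state, scope=None):
    # Excerpt: a method of an LSTM cell class, written against the pre-1.0 TensorFlow API.
    # Output width is the projection size if one is set, otherwise the cell size.
    num_proj = self._num_units if self._num_proj is None else self._num_proj

    # `state` packs the cell state c and the output m side by side along axis 1.
    c_prev = tf.slice(state, [0, 0], [-1, self._num_units])
    m_prev = tf.slice(state, [0, self._num_units], [-1, num_proj])

    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
        raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    with tf.variable_scope(type(self).__name__,
                           initializer=self._initializer):  # "LSTMCell"
        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        # Pre-1.0 argument order: tf.concat(concat_dim, values) and
        # tf.split(split_dim, num_split, value).
        cell_inputs = tf.concat(1, [inputs, m_prev])
        lstm_matrix = tf.nn.bias_add(tf.matmul(cell_inputs, self._concat_w), self._b)
        i, j, f, o = tf.split(1, 4, lstm_matrix)

        # The constant 1.0 added to f acts as a fixed forget-gate bias.
        c = tf.sigmoid(f + 1.0) * c_prev + tf.sigmoid(i) * tf.tanh(j)
        m = tf.sigmoid(o) * tf.tanh(c)

        # Optional linear projection of the output down to num_proj columns.
        if self._num_proj is not None:
            m = tf.matmul(m, self._concat_w_proj)

    # The new state uses the same [c, m] layout that was sliced apart above.
    new_state = tf.concat(1, [c, m])
    return m, new_state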
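
The arithmetic in the method above can be traced without TensorFlow. What follows is a minimal, dependency-free sketch of the same update for a single example (no batch dimension). It is not the author's implementation: the names `lstm_step`, `sigmoid`, `w`, `b`, and `w_proj` are illustrative stand-ins for `self._concat_w`, `self._b`, and `self._concat_w_proj`, and plain Python lists replace tensors.

```python
import math


def sigmoid(x):
    """Logistic function for a single float."""
    return 1.0 / (1.0 + math.exp(-x))


def lstm_step(x, c_prev, m_prev, w, b, w_proj=None):
    """One LSTM step for a single example, mirroring the excerpt's arithmetic.

    x:      input_size floats
    c_prev: num_units floats (previous cell state)
    m_prev: num_proj floats (previous output)
    w:      (input_size + num_proj) rows by 4 * num_units columns
    b:      4 * num_units floats
    w_proj: optional num_units rows by num_proj columns
    """
    num_units = len(c_prev)
    cell_inputs = x + m_prev  # list concatenation, like concat along axis 1

    # Affine map: cell_inputs @ w + b, computed column by column.
    z = [
        sum(v * w[r][k] for r, v in enumerate(cell_inputs)) + b[k]
        for k in range(4 * num_units)
    ]

    # Four equal column chunks in the order i, j, f, o.
    i, j, f, o = (z[n * num_units:(n + 1) * num_units] for n in range(4))

    # Cell update with the constant +1.0 on the forget gate.
    c = [
        sigmoid(fk + 1.0) * cp + sigmoid(ik) * math.tanh(jk)
        for fk, cp, ik, jk in zip(f, c_prev, i, j)
    ]
    m = [sigmoid(ok) * math.tanh(ck) for ok, ck in zip(o, c)]

    # Optional linear projection of the output.
    if w_proj is not None:
        m = [
            sum(mv * w_proj[r][k] for r, mv in enumerate(m))
            for k in range(len(w_proj[0]))
        ]

    # New state keeps the [c, m] layout.
    return m, c + m


# Usage: 3 inputs, 2 units, no projection, all-zero weights and biases.
w = [[0.0] * 8 for _ in range(5)]
b = [0.0] * 8
m, state = lstm_step([0.5, -0.5, 1.0], [1.0, -1.0], [0.0, 0.0], w, b)
```

With all-zero weights every gate pre-activation is 0, so the new cell state is just `sigmoid(1.0) * c_prev`. This makes the effect of the `+ 1.0` forget-gate offset easy to see: the cell retains roughly 73% of its previous value instead of 50%.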