rnn_cell.py source code

python

Project: qrn  Author: uwnlp
def __call__(self, inputs, state, scope=None):
        # Note: tf.split(split_dim, num_split, value) and tf.concat(concat_dim, values)
        # use the pre-1.0 TensorFlow argument order. `linear` is not defined in this excerpt.
        gate_size = self._gate_size
        with tf.variable_scope(scope or type(self).__name__):  # "RSMCell"
            with tf.name_scope("Split"):  # Reset gate and update gate.
                # inputs is laid out column-wise as [a | x | u | v_t].
                a = tf.slice(inputs, [0, 0], [-1, gate_size])
                x, u, v_t = tf.split(1, 3, tf.slice(inputs, [0, gate_size], [-1, -1]))
                # state is laid out column-wise as [o | h | v].
                o = tf.slice(state, [0, 0], [-1, 1])
                h, v = tf.split(1, 2, tf.slice(state, [0, gate_size], [-1, -1]))

            with tf.variable_scope("Main"):
                # r is a single sigmoid unit computed from the elementwise product x * u.
                r_raw = linear([x * u], 1, True, scope='r_raw', var_on_cpu=self._var_on_cpu,
                               initializer=self._initializer)
                r = tf.sigmoid(r_raw, name='a')
                # a interpolates between the candidate values and the previous state.
                new_o = a * r + (1 - a) * o
                new_v = a * v_t + (1 - a) * v
                g = r * v_t
                new_h = a * g + (1 - a) * h

            with tf.name_scope("Concat"):
                new_state = tf.concat(1, [new_o, new_h, new_v])
                outputs = tf.concat(1, [a, r, x, new_h, new_v, g])

        return outputs, new_state
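
To make the update rule above easier to follow, here is a minimal NumPy sketch of a single step. It is not the project's implementation. It assumes that `linear` is an affine map to one output unit; the names `rsm_step`, `w`, and `b` are hypothetical stand-ins for that helper and its parameters. It also assumes `gate_size = 1`, which makes the widths of `a` and `o` agree.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def rsm_step(inputs, state, gate_size, w, b):
    """One step of the cell above in plain NumPy (illustrative sketch only).

    w and b are hypothetical parameters standing in for the project's
    `linear` helper, assumed here to compute (x * u) @ w + b.
    """
    # inputs = [a | x | u | v_t], state = [o | h | v]
    a = inputs[:, :gate_size]
    x, u, v_t = np.split(inputs[:, gate_size:], 3, axis=1)
    o = state[:, :1]
    h, v = np.split(state[:, gate_size:], 2, axis=1)

    r = sigmoid((x * u) @ w + b)      # shape [batch, 1]
    new_o = a * r + (1 - a) * o       # a blends the new value with the old one
    new_v = a * v_t + (1 - a) * v
    g = r * v_t
    new_h = a * g + (1 - a) * h

    new_state = np.concatenate([new_o, new_h, new_v], axis=1)
    outputs = np.concatenate([a, r, x, new_h, new_v, g], axis=1)
    return outputs, new_state


# Toy sizes (assumptions): batch of 2, hidden size 4, gate_size 1.
rng = np.random.default_rng(0)
batch, d, gate_size = 2, 4, 1
inputs = rng.random((batch, gate_size + 3 * d))
state = rng.random((batch, gate_size + 2 * d))
w = rng.standard_normal((d, 1))
b = np.zeros(1)

outputs, new_state = rsm_step(inputs, state, gate_size, w, b)
print(outputs.shape, new_state.shape)  # (2, 18) (2, 9)
```

With `a = 0` the step returns the previous state unchanged, and with `a = 1` it overwrites `v` with `v_t` and `h` with `r * v_t`, which is the interpolation the three `new_*` lines express.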