skip_rnn_cells.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:skiprnn-2017-telecombcn 作者: imatge-upc 项目源码 文件源码
def __call__(self, inputs, state, scope=None):
        """Advance the skip-GRU cell by one time step.

        Args:
            inputs: Input tensor for the current step.
            state: Three-part state unpacked as (previous hidden state,
                previous update probability, previous cumulative update
                probability) -- presumably a ``SkipGRUStateTuple``.
            scope: Optional variable-scope name; falls back to the class
                name when not given.

        Returns:
            ``(SkipGRUOutputTuple(new_h, skip_gate), next_state)``, where
            ``skip_gate`` is ``_binary_round`` applied to the cumulative
            update probability and ``next_state`` is
            ``SkipGRUStateTuple(new_h, new_update_prob, new_cum_update_prob)``.

        Note:
            The GRU candidate state and the next update probability are built
            unconditionally. The skip gate only blends them with the
            carried-over values; it does not avoid computing them.
        """
        with tf.variable_scope(scope or type(self).__name__):
            prev_h, prev_update_prob, prev_cum_update_prob = state

            # A single fused projection yields the pre-activations of both
            # GRU gates (reset and update).
            with tf.variable_scope("gates"):
                gate_logits = rnn_ops.linear([inputs, prev_h], 2 * self._num_units, bias=True, bias_start=1.0)
            reset_logits, update_logits = tf.split(value=gate_logits, num_or_size_splits=2, axis=1)

            # Optional layer norm acts on the pre-activations; the sigmoid is
            # applied only afterwards.
            if self._layer_norm:
                reset_logits = rnn_ops.layer_norm(reset_logits, name="r")
                update_logits = rnn_ops.layer_norm(update_logits, name="u")
            gru_reset = tf.sigmoid(reset_logits)
            gru_update = tf.sigmoid(update_logits)

            # Candidate activation sees the reset-gated previous hidden state.
            with tf.variable_scope("candidate"):
                gated_prev_h = gru_reset * prev_h
                candidate = self._activation(rnn_ops.linear([inputs, gated_prev_h], self._num_units, True))

            # GRU interpolation: keep part of the old state, write part of the
            # candidate.
            kept_h = gru_update * prev_h
            written_h = (1 - gru_update) * candidate
            proposed_h = kept_h + written_h

            # Fresh update probability, projected from the proposed state.
            with tf.variable_scope('state_update_prob'):
                update_prob_logit = rnn_ops.linear(proposed_h, 1, True, bias_start=self._update_bias)
                proposed_update_prob = tf.sigmoid(update_prob_logit)

            # Accumulate the previous update probability, capped so that the
            # running total cannot exceed 1.0, then binarize it.
            headroom = 1. - prev_cum_update_prob
            cum_update_prob = prev_cum_update_prob + tf.minimum(prev_update_prob, headroom)
            skip_gate = _binary_round(cum_update_prob)

            # Select between the freshly computed values and the carried-over
            # ones. Where skip_gate is 1 (assuming _binary_round yields 0/1),
            # the cumulative probability is reset to zero.
            new_h = skip_gate * proposed_h + (1. - skip_gate) * prev_h
            new_update_prob = skip_gate * proposed_update_prob + (1. - skip_gate) * prev_update_prob
            new_cum_update_prob = skip_gate * 0. + (1. - skip_gate) * cum_update_prob

            next_state = SkipGRUStateTuple(new_h, new_update_prob, new_cum_update_prob)
            output = SkipGRUOutputTuple(new_h, skip_gate)
            return output, next_state
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号