rnn_cell.py 文件源码

python
阅读 192 收藏 0 点赞 0 评论 0

项目:reading-comprehension 作者: kellywzhang 项目源码 文件源码
def __call__(self, inputs, state, time_mask, scope=None):
    """Run one step of a gated recurrent unit (GRU) with masked updates.

    Computes the reset gate, update gate and candidate state from the
    concatenation of the previous state and the current inputs, then
    blends the candidate with the previous state. Rows whose
    ``time_mask`` is not positive keep their previous state unchanged.

    Side effect: the weight and bias variables are stored on ``self``
    (``W_reset``, ``W_update``, ``b_reset``, ``b_update``,
    ``W_candidate``, ``b_candidate``).

    Args:
        inputs: Input tensor for this time step; concatenated with
            ``state`` along axis 1 (presumably [batch, input_size] --
            TODO confirm against caller).
        state: Previous hidden state (presumably [batch, state_size] --
            TODO confirm against caller).
        time_mask: Float tensor multiplied element-wise with the new
            state; entries > 0 take the updated state, entries <= 0
            carry the previous state forward. Presumably a 0/1 mask
            broadcastable against ``state`` -- TODO confirm.
        scope: Unused; the variable scope comes from ``self._scope`` or
            the class name.

    Returns:
        A ``(output, new_state)`` pair; both are the same tensor.
    """
    with tf.variable_scope(self._scope or type(self).__name__):  # "GRUCell"
        input_size = self._input_size
        state_size = self._state_size
        weight_shape = [state_size + input_size, state_size]

        hidden = tf.concat(1, [state, inputs])

        with tf.variable_scope("Gates"):  # Reset gate and update gate.
            # We start with bias of 1.0 to not reset and not update.
            self.W_reset = tf.get_variable(
                name="reset_weight", shape=weight_shape,
                initializer=tf.random_normal_initializer(mean=0.0, stddev=0.1))
            self.W_update = tf.get_variable(
                name="update_weight", shape=weight_shape,
                initializer=tf.random_normal_initializer(mean=0.0, stddev=0.1))
            self.b_reset = tf.get_variable(
                name="reset_bias", shape=[state_size],
                initializer=tf.constant_initializer(1.0))
            self.b_update = tf.get_variable(
                name="update_bias", shape=[state_size],
                initializer=tf.constant_initializer(1.0))

            reset = sigmoid(tf.matmul(hidden, self.W_reset) + self.b_reset)
            update = sigmoid(tf.matmul(hidden, self.W_update) + self.b_update)

        with tf.variable_scope("Candidate"):
            self.W_candidate = tf.get_variable(
                name="candidate_weight", shape=weight_shape,
                initializer=tf.random_normal_initializer(mean=0.0, stddev=0.1))
            self.b_candidate = tf.get_variable(
                name="candidate_bias", shape=[state_size],
                initializer=tf.random_normal_initializer(mean=0.0, stddev=0.1))

            # The reset gate scales the previous state before it is mixed
            # with the inputs to form the candidate.
            reset_input = tf.concat(1, [reset * state, inputs])
            # Use the candidate's own weights here; projecting with
            # W_reset would leave W_candidate unused and untrained.
            candidate = self._activation(
                tf.matmul(reset_input, self.W_candidate) + self.b_candidate)

        # Complement of time_mask: 1.0 where the step is masked out.
        anti_time_mask = tf.cast(time_mask <= 0, tf.float32)
        new_h = update * state + (1 - update) * candidate
        # Masked-out rows carry the previous state forward unchanged.
        new_h = time_mask * new_h + anti_time_mask * state

    return new_h, new_h

    def zero_state(self, batch_size):
        """Return an all-zero initial state of shape [batch_size, state_size].

        Args:
            batch_size: Number of rows in the initial state.

        Returns:
            A float32 ``tf.Variable`` initialised with zeros.
        """
        # Read the size from the instance; a bare `state_size` is not
        # defined in this method's scope.
        # NOTE(review): wrapping the zeros in tf.Variable registers the
        # initial state as a graph variable -- confirm this is intended
        # rather than a plain tf.zeros tensor.
        return tf.Variable(tf.zeros([batch_size, self._state_size]), dtype=tf.float32)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号