def __call__(self, inputs, state, time_mask, scope=None):
    """Gated recurrent unit (GRU) step with state_size-dimensional cells.

    Args:
        inputs: input tensor — assumed shape [batch, input_size]; TODO confirm.
        state: previous hidden state — assumed shape [batch, state_size].
        time_mask: float tensor broadcastable against the state; entries > 0
            mark real timesteps, entries <= 0 mark padding (previous state is
            carried through unchanged there).
        scope: unused; the variable scope is taken from self._scope.

    Returns:
        (new_h, new_h): the updated hidden state, duplicated as
        (output, next_state) in the usual RNN-cell convention.
    """
    with tf.variable_scope(self._scope or type(self).__name__):  # "GRUCell"
        input_size = self._input_size
        state_size = self._state_size
        # Concatenate along the feature axis (TF pre-1.0 argument order).
        hidden = tf.concat(1, [state, inputs])
        with tf.variable_scope("Gates"):  # Reset gate and update gate.
            # Bias starts at 1.0 so the cell initially neither resets nor
            # updates (gates saturate toward 1 early in training).
            self.W_reset = tf.get_variable(
                name="reset_weight",
                shape=[state_size + input_size, state_size],
                initializer=tf.random_normal_initializer(mean=0.0, stddev=0.1))
            self.W_update = tf.get_variable(
                name="update_weight",
                shape=[state_size + input_size, state_size],
                initializer=tf.random_normal_initializer(mean=0.0, stddev=0.1))
            self.b_reset = tf.get_variable(
                name="reset_bias", shape=[state_size],
                initializer=tf.constant_initializer(1.0))
            self.b_update = tf.get_variable(
                name="update_bias", shape=[state_size],
                initializer=tf.constant_initializer(1.0))
            reset = sigmoid(tf.matmul(hidden, self.W_reset) + self.b_reset)
            update = sigmoid(tf.matmul(hidden, self.W_update) + self.b_update)
        with tf.variable_scope("Candidate"):
            self.W_candidate = tf.get_variable(
                name="candidate_weight",
                shape=[state_size + input_size, state_size],
                initializer=tf.random_normal_initializer(mean=0.0, stddev=0.1))
            self.b_candidate = tf.get_variable(
                name="candidate_bias", shape=[state_size],
                initializer=tf.random_normal_initializer(mean=0.0, stddev=0.1))
            reset_input = tf.concat(1, [reset * state, inputs])
            # BUG FIX: the candidate projection must use W_candidate;
            # the original multiplied by W_reset, leaving W_candidate unused
            # and breaking the GRU candidate computation.
            candidate = self._activation(
                tf.matmul(reset_input, self.W_candidate) + self.b_candidate)
        # Complement of time_mask: 1.0 exactly where time_mask <= 0 (padding).
        anti_time_mask = tf.cast(time_mask <= 0, tf.float32)
        # update==1 keeps the old state ("not update"), matching the 1.0 bias.
        new_h = update * state + (1 - update) * candidate
        # Padded timesteps pass the previous state through unchanged.
        new_h = time_mask * new_h + anti_time_mask * state
        return new_h, new_h
def zero_state(self, batch_size):
    """Return an all-zeros initial state of shape [batch_size, state_size].

    Args:
        batch_size: number of sequences in the batch.

    Returns:
        A float32 tf.Variable initialized to zeros, shaped
        [batch_size, self._state_size].
    """
    # BUG FIX: `state_size` was an undefined bare name here (NameError at
    # call time); read it from the cell attribute as __call__ does.
    return tf.Variable(tf.zeros([batch_size, self._state_size]),
                       dtype=tf.float32)
# (Non-code residue from web scrape removed: "评论列表" = "Comment list",
#  "文章目录" = "Article table of contents" — blog page boilerplate.)