def __call__(self, inputs, state, scope=None):
    """Advance the GRU cell by one time step.

    Args:
        inputs: A single tensor, or a list/tuple of tensors, fed to the cell
            at this step. A value that is not a list/tuple is wrapped in a
            one-element list.
        state: The hidden-state tensor from the previous step.
        scope: Optional variable-scope name. When None, the scope defaults
            to "gru_cell".

    Returns:
        A pair ``(new_state, new_state)``. The cell emits its new hidden
        state both as the step output and as the carried state.

    The update computed here is::

        r  = sigmoid(linear([*inputs, state]))       # scope "reset_gate"
        u  = sigmoid(linear([*inputs, state]))       # scope "update_gate"
        c  = linear([*inputs, r * state])            # scope "candidate"
        h' = (1 - u) * state + u * tanh(c)

    NOTE(review): `linear` is defined elsewhere. Its two positional boolean
    arguments look like flags such as "use bias"; the gates pass
    (False, False) and the candidate passes (True, False). Confirm this
    against `linear`'s signature.
    """
    with tf.variable_scope(scope, default_name="gru_cell",
                           values=[inputs, state]):
        # Normalise to a list so extra tensors can be appended below.
        if isinstance(inputs, (list, tuple)):
            input_list = list(inputs)
        else:
            input_list = [inputs]

        # Every projection below asks `linear` for this many output features.
        num_units = self._num_units

        # Both gates read the step inputs together with the previous state.
        gate_inputs = input_list + [state]
        reset_gate = tf.nn.sigmoid(
            linear(gate_inputs, num_units, False, False, scope="reset_gate"))
        update_gate = tf.nn.sigmoid(
            linear(gate_inputs, num_units, False, False, scope="update_gate"))

        # The candidate sees the previous state only through the reset gate.
        candidate_inputs = input_list + [reset_gate * state]
        candidate = linear(candidate_inputs, num_units, True, False,
                           scope="candidate")

        # Interpolate. As u approaches 1 the cell adopts tanh(candidate);
        # as u approaches 0 it keeps the previous state.
        new_state = ((1.0 - update_gate) * state
                     + update_gate * tf.tanh(candidate))

    return new_state, new_state
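# 评论列表 ("Comment list"): web-page navigation residue from the scraped source, not code.
# 文章目录 ("Table of contents"): web-page navigation residue from the scraped source, not code.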