def __call__(self, inputs, state, scope=None):
    """Run one step of the cell and return the new output and state.

    Args:
        inputs: Input for the current step; fed to `_linear`.
        state: 4-tuple unpacked as `(h, n, d, a_max)` -- the previous
            output, the running numerator, the running denominator, and
            the running maximum of the attention logits `a`.
        scope: Optional variable scope; defaults to "rwa_cell".

    Returns:
        A pair `(h_new, new_state)`, where `new_state` is
        `RWACellTuple(h_new, n, d, a_newmax)` built from the updated
        numerator, denominator and running maximum.
    """
    with _checked_scope(self, scope or "rwa_cell", reuse=self._reuse):
        prev_h, numer, denom, prev_max = state

        # Three separate linear maps, each with its own variable scope.
        with vs.variable_scope("u"):
            u = _linear(inputs, self._num_units, True)
        with vs.variable_scope("g"):
            g = _linear([inputs, prev_h], self._num_units, True)
        with vs.variable_scope("a"):
            # No bias here: factored out of the numerator and the
            # denominator it cancels, so it is unnecessary.
            a = _linear([inputs, prev_h], self._num_units, False)

        z = tf.multiply(u, tanh(g))

        # Shift every exponent by the running maximum so that tf.exp only
        # ever sees non-positive arguments.
        new_max = tf.maximum(prev_max, a)
        decay = tf.exp(prev_max - new_max)
        weight = tf.exp(a - new_max)

        # Numerically stable update of the numerator.
        decayed_numer = tf.multiply(numer, decay)
        weighted_z = tf.multiply(z, weight)
        numer = decayed_numer + weighted_z

        # Numerically stable update of the denominator.
        decayed_denom = tf.multiply(denom, decay)
        denom = decayed_denom + weight

        ratio = tf.div(numer, denom)
        h_new = self._activation(ratio)
        new_state = RWACellTuple(h_new, numer, denom, new_max)
    return h_new, new_state
评论列表
文章目录