def __call__(self, inputs, state, scope=None):
    """Run one recurrence step of the cell.

    Args:
        inputs: input tensor for the current time step.
        state: tuple ``(h, n, d)`` — previous hidden state plus the running
            numerator and denominator of the attention-weighted average.
        scope: optional variable-scope name; defaults to "rwa_cell".

    Returns:
        ``(output, new_state)`` where ``new_state`` is an ``RDACellTuple``
        carrying the new hidden state and the updated running sums.
    """
    with _checked_scope(self, scope or "rwa_cell", reuse=self._reuse):
        h_prev, num, den = state

        # Unbounded feature projection of the raw input only.
        with vs.variable_scope("u"):
            feat = linear(inputs, self._num_units, True, normalize=self._normalize)

        # Bounded gate computed from the input and previous hidden state.
        with vs.variable_scope("g"):
            gate = linear([inputs, h_prev], self._num_units, True, normalize=self._normalize)

        # Attention weights. A bias term here would cancel when factored out
        # of both numerator and denominator, so none is needed.
        with vs.variable_scope("a"):
            attn = tf.exp(linear([inputs, h_prev], self._num_units, True, normalize=self._normalize))

        # Per-unit discount in (0, 1) that decays the running sums.
        with vs.variable_scope("discount_factor"):
            gamma = tf.nn.sigmoid(linear([inputs, h_prev], self._num_units, True, normalize=self._normalize))

        candidate = tf.multiply(feat, tanh(gate))

        # Discounted updates of the weighted-average numerator/denominator.
        num = tf.multiply(num, gamma) + tf.multiply(candidate, attn)
        den = tf.multiply(den, gamma) + attn

        h_new = self._activation(tf.div(num, den))
        new_state = RDACellTuple(h_new, num, den)
        return h_new, new_state
# NOTE(review): the lines below were web-page navigation artifacts
# ("评论列表" = comment list, "文章目录" = article table of contents) left over
# from scraping; commented out so the module remains valid Python.
# 评论列表
# 文章目录