def __call__(self, inputs, state, scope=None):
    """Run one time step of the fast-weights RNN cell.

    A slow-weight pre-activation ``Wh(t) + Cx(t)`` is computed once. The
    hidden state is then refined ``self._S`` times through the fast-weight
    matrix (Eqn (2) in the paper). Finally, the fast weights are decayed by
    ``self._lambda`` and updated with ``self._eta`` times the outer product
    of the new hidden state.

    Args:
        inputs: input x(t) for this step.
        state: pair ``(h(t), A(t))``, i.e. the previous hidden state in
            vector form and the fast-weight matrix. Exact shapes depend on
            ``_fwlinear`` / ``_vector2matrix`` (not visible here). Presumably
            h is (batch, num_units) and A is (batch, num_units, num_units).
            TODO confirm.
        scope: optional variable-scope name; defaults to the class name.

    Returns:
        ``(h, (h, new_fast_weights))``: the new hidden state h(t+1) as the
        output, plus the new state pair (h(t+1), A(t+1)).
    """
    state, fast_weights = state
    with vs.variable_scope(scope or type(self).__name__):

        def norm(x, step):
            # Unless norm parameters are shared, every refinement step gets
            # its own scope ("Norm0", "Norm1", ..., "Norm<S>").
            if self._reuse_norm:
                return self._norm(x)
            return self._norm(x, scope="Norm%d" % step)

        # Slow-weight pre-activation Wh(t) + Cx(t); reused by every inner step.
        linear = self._fwlinear([state, inputs], self._num_units, False)

        # h_0(t+1) = f(Norm(Wh(t) + Cx(t)))
        h = self._activation(norm(linear, 0))

        # Switch to matrix form so the fast weights can be applied with a
        # batched matmul.
        h = self._vector2matrix(h)
        linear = self._vector2matrix(linear)

        # h_{s+1}(t+1) = f(Norm([Wh(t) + Cx(t)] + A(t) h_s(t+1))), S times.
        # See Eqn (2) in the paper.
        # NOTE(review): math_ops.batch_matmul was removed in TF 1.0; there,
        # math_ops.matmul batches and takes adjoint_b instead of adj_y.
        for step in range(1, self._S + 1):
            h = self._activation(
                norm(linear + math_ops.batch_matmul(fast_weights, h), step))

        # A(t+1) = lambda * A(t) + eta * h(t+1) h(t+1)^T.
        # Use the NEW hidden state h(t+1) (already in matrix form) so its
        # association is present in the fast weights when h(t+2) is computed.
        # Using h(t) here would delay every association by one time step.
        new_fast_weights = (self._lambda * fast_weights
                            + self._eta * math_ops.batch_matmul(h, h, adj_y=True))

        h = self._matrix2vector(h)
        return h, (h, new_fast_weights)
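# [scraped web-page residue, not part of the program] 评论列表 ("Comment list")
# [scraped web-page residue, not part of the program] 文章目录 ("Table of contents")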