def forward(self, input, *args, **kwargs):
    """Run the recurrent layer over a batch of sequences.

    For each timestep ``t`` the layer computes (``*`` is element-wise)::

        z_t = gate_activation(x_t U_z + s_{t-1} W_z + b_z)
        r_t = gate_activation(x_t U_r + s_{t-1} W_r + b_r)
        h_t = activation(x_t U_h + (s_{t-1} * r_t) W_h + b_h)
        s_t = (1 - z_t) * h_t + z_t * s_{t-1}

    The state used for the first timestep is all zeros.

    Args:
        input: Array-like of rank 3, unpacked as
            ``(nb_batch, nb_timesteps, nb_in)``.
        *args, **kwargs: Accepted for interface compatibility; unused.

    Returns:
        The stacked states of shape ``(nb_batch, nb_timesteps, self.n_out)``
        when ``self.return_sequence`` is true, otherwise only the last
        timestep's state, of shape ``(nb_batch, self.n_out)``.

    Raises:
        ValueError: If ``input`` is not 3-dimensional.

    Side effects:
        Stores the (array-converted) input in ``self.last_input`` and the
        full state sequence in ``self.last_output`` (presumably consumed by
        the backward pass -- not visible here).
    """
    # Accept any array-like. asanyarray returns ndarrays (and subclasses)
    # unchanged, so existing callers observe the very same object as before.
    input = np.asanyarray(input)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if input.ndim != 3:
        raise ValueError(
            'Only support batch training: expected a 3-D input of shape '
            '(nb_batch, nb_timesteps, nb_in), got %d-D.' % input.ndim)

    # Keep the input for later use by this layer.
    self.last_input = input

    nb_batch, nb_timesteps, _ = input.shape
    # NOTE(review): assumes U_* map nb_in -> n_out and W_* map n_out -> n_out
    # (implied by the dot products below) -- confirm against the layer init.
    output = _zero((nb_batch, nb_timesteps, self.n_out))

    # Initial state is zeros; after each step it becomes a view of the row
    # just written to `output`, i.e. the previous timestep's state.
    s_pre = _zero((nb_batch, self.n_out))
    for i in range(nb_timesteps):
        x_now = input[:, i, :]
        # Gate that interpolates between the candidate and the previous state.
        z_now = self.gate_activation.forward(np.dot(x_now, self.U_z) +
                                             np.dot(s_pre, self.W_z) +
                                             self.b_z)
        # Gate that scales the previous state inside the candidate.
        r_now = self.gate_activation.forward(np.dot(x_now, self.U_r) +
                                             np.dot(s_pre, self.W_r) +
                                             self.b_r)
        h_now = self.activation.forward(np.dot(x_now, self.U_h) +
                                        np.dot(s_pre * r_now, self.W_h) +
                                        self.b_h)
        output[:, i, :] = (1 - z_now) * h_now + z_now * s_pre
        s_pre = output[:, i, :]

    self.last_output = output
    if self.return_sequence:
        return self.last_output
    # Only the final timestep's state, shape (nb_batch, n_out).
    return self.last_output[:, -1, :]
评论列表
文章目录