def _step_batch(self, x_t, mask, h_t_1, w, u, b):
    """
    Compute one masked recurrent step for a whole batch.

    The input and the previous state are each viewed as ``(batch, 8, 8)``.
    The input is divided by its 2-norm taken along axis 1, the candidate
    state ``act(x . w^T + h_prev . u^T + b)`` is computed and scaled the same
    way, and both states are flattened to ``(batch, 64)``. The result is
    ``new * mask + previous * (1 - mask)`` per row, so a mask of 1 selects
    the new state and a mask of 0 carries the previous state over.

    NOTE(review): the 8/8/64 sizes are hard-coded, and no epsilon protects
    the divisions against a zero norm.

    :param x_t: input at this step; reshaped to (batch, 8, 8), so presumably
        (batch, 64) -- TODO confirm against the caller
    :param mask: (batch, ) per-row mask, broadcast over the feature axis
    :param h_t_1: previous state; reshaped to (batch, 8, 8), so presumably
        (batch, 64) -- TODO confirm against the caller
    :param w: input weights, used as ``w.T`` in a dot with the (batch, 8, 8)
        input; earlier documentation listed (hidden, in) -- TODO confirm
    :param u: recurrent weights, used as ``u.T`` in a dot with the
        (batch, 8, 8) state; earlier documentation listed (hidden, hidden)
        -- TODO confirm
    :param b: bias added to the sum of both dot products; must broadcast
        against their result -- TODO confirm shape
    :return: (batch, 64) next state
    """
    grid = (8, 8)
    flat_width = 64

    def _scale_by_axis1_norm(tensor):
        # The norm over axis 1 drops that axis; re-insert it so the division
        # broadcasts against the (batch, 8, 8) tensor.
        return tensor / tensor.norm(2, axis=1)[:, None, :]

    prev_grid = T.reshape(h_t_1, (h_t_1.shape[0],) + grid)
    in_grid = _scale_by_axis1_norm(T.reshape(x_t, (x_t.shape[0],) + grid))

    input_term = T.dot(in_grid, w.T)
    recurrent_term = T.dot(prev_grid, u.T)
    new_grid = _scale_by_axis1_norm(
        self.act.activate(input_term + recurrent_term + b)
    )

    # Back to one flat feature axis per row before blending.
    prev_flat = T.reshape(prev_grid, (prev_grid.shape[0], flat_width))
    new_flat = T.reshape(new_grid, (new_grid.shape[0], flat_width))

    # (batch, 64) * (batch, 1) + (batch, 64) * (batch, 1) -> (batch, 64)
    keep_new = mask[:, None]
    return new_flat * keep_new + prev_flat * (1 - keep_new)
评论列表
文章目录