def step(self, x, states):
    """Advance the layer-normalized LSTM cell by one step.

    The input projection and the recurrent projection are each passed
    through the ``LN`` helper (defined outside this method) with their own
    gain/shift parameters, then summed. The bias is added only when
    ``self.use_bias`` is set. The new cell state is passed through ``LN``
    again before the final activation.

    Args:
        x: input tensor for this step (shape not established here).
        states: indexable sequence read as
            ``[h_tm1, c_tm1, mask_for_h, mask_for_x]``. Each mask entry is
            itself indexable, and only its element ``[0]`` is used.

    Returns:
        ``(h, [h, c])``: the computed ``h``, plus ``h`` and the new cell
        state ``c`` to carry into the next step.
    """
    h_prev, c_prev = states[0], states[1]
    # Index 2 multiplies the previous h, index 3 multiplies x.
    # NOTE(review): stock Keras 2 ``LSTM.get_constants`` emits the input
    # dropout mask before the recurrent one. Confirm this class supplies the
    # masks in the order used here (Keras 1 order: recurrent first).
    recurrent_mask, input_mask = states[2], states[3]

    # Normalize each projection separately so the input and recurrent
    # contributions are rescaled independently before they are summed.
    input_part = LN(K.dot(x * input_mask[0], self.kernel),
                    self.gamma_1, self.beta_1)
    recurrent_part = LN(K.dot(h_prev * recurrent_mask[0],
                              self.recurrent_kernel),
                        self.gamma_2, self.beta_2)
    pre = input_part + recurrent_part
    if self.use_bias:
        pre = K.bias_add(pre, self.bias)

    # Gate pre-activations are packed along axis 1 in ``units``-wide slices:
    # input gate, forget gate, candidate, then output gate (last slice runs
    # to the end).
    n = self.units
    input_gate = self.recurrent_activation(pre[:, :n])
    forget_gate = self.recurrent_activation(pre[:, n:2 * n])
    candidate = self.activation(pre[:, 2 * n:3 * n])
    c = forget_gate * c_prev + input_gate * candidate
    output_gate = self.recurrent_activation(pre[:, 3 * n:])

    # The cell state is normalized before the output nonlinearity.
    h = output_gate * self.activation(LN(c, self.gamma_3, self.beta_3))
    return h, [h, c]
评论列表
文章目录