def forward(self, x, mask, hc):
    # One LSTM step. `x` is the input at the current time step, `mask` marks
    # real (1) vs. padded (0) positions in the batch, and `hc` is the previous
    # state packed as [c_{t-1}, h_{t-1}].
    # Requires: import theano.tensor as T
    n_in, n_out, activation = self.n_in, self.n_out, self.activation

    # Split the packed state into cell state c_{t-1} and hidden state h_{t-1}
    # (batched 2-D case vs. single-example 1-D case).
    if hc.ndim > 1:
        c_tm1 = hc[:, :n_out]
        h_tm1 = hc[:, n_out:]
    else:
        c_tm1 = hc[:n_out]
        h_tm1 = hc[n_out:]

    # Input, forget and output gates, each computed from x and h_{t-1}.
    in_t = self.in_gate.forward(x, h_tm1)
    forget_t = self.forget_gate.forward(x, h_tm1)
    out_t = self.out_gate.forward(x, h_tm1)

    # New cell state: retain part of the old cell and add the gated candidate.
    c_t = forget_t * c_tm1 + in_t * self.input_layer.forward(x, h_tm1)
    # Zero the state at padded positions (mask == 0); mask is broadcast
    # over the feature dimension.
    c_t = c_t * mask.dimshuffle(0, 'x')
    c_t = T.cast(c_t, 'float32')

    # New hidden state, masked and cast the same way.
    h_t = out_t * T.tanh(c_t)
    h_t = h_t * mask.dimshuffle(0, 'x')
    h_t = T.cast(h_t, 'float32')

    # Return the state packed again as [c_t, h_t] for the next step.
    if hc.ndim > 1:
        return T.concatenate([c_t, h_t], axis=1)
    else:
        return T.concatenate([c_t, h_t])
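
Since this `forward` computes a single time step and returns the packed state [c_t, h_t], it is normally driven by `theano.scan` over the whole sequence. Below is a minimal sketch of such a wrapper, assuming an `lstm` layer object exposing the method above and an `n_out` attribute; the names `forward_all`, `x`, and `masks` are illustrative, not part of the original code.

import theano
import theano.tensor as T

def forward_all(lstm, x, masks):
    # x: (seq_len, batch, n_in) input sequence; masks: (seq_len, batch) 0/1 matrix.
    # Initial state packs c_0 and h_0 into one vector per example: [c_0, h_0].
    h0 = T.zeros((x.shape[1], lstm.n_out * 2), dtype=theano.config.floatX)
    hc, _ = theano.scan(
        fn=lstm.forward,        # step function above: (x_t, mask_t, hc_{t-1}) -> hc_t
        sequences=[x, masks],   # scan slices x and masks along the first (time) axis
        outputs_info=[h0],      # hc_t is fed back as the third argument next step
    )
    # hc has shape (seq_len, batch, 2*n_out); keep only the hidden half h_t.
    return hc[:, :, lstm.n_out:]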