def one_step(self, x, h_tm1, s_tm1):
    """Compute a single LSTM timestep.

    Builds the candidate values and the input, forget and output gates from
    the current input ``x`` and the previous hidden state ``h_tm1``. It then
    derives the new cell state and the new hidden state from them.

    ``T`` is expected to be the tensor module in scope. The use of
    ``T.nnet.sigmoid`` suggests ``theano.tensor``; confirm against the
    file's imports.

    Args:
        x: input at the current timestep.
        h_tm1: hidden state from the previous timestep (t - 1).
        s_tm1: cell state from the previous timestep (t - 1).

    Returns:
        A ``(h, s)`` pair: the new hidden state and the new cell state.
    """
    def pre_activation(w_x, w_h, bias):
        # Affine mix of the current input and the previous hidden state,
        # shared by the candidate and all three gates.
        return T.dot(x, w_x) + T.dot(h_tm1, w_h) + bias

    candidate = T.tanh(pre_activation(self.W_gx, self.W_gh, self.b_g))
    input_gate = T.nnet.sigmoid(pre_activation(self.W_ix, self.W_ih, self.b_i))
    forget_gate = T.nnet.sigmoid(pre_activation(self.W_fx, self.W_fh, self.b_f))
    output_gate = T.nnet.sigmoid(pre_activation(self.W_ox, self.W_oh, self.b_o))

    # New cell state: gated candidate plus gated carry-over of the old cell.
    cell = input_gate * candidate + s_tm1 * forget_gate
    # New hidden state: squashed cell state filtered by the output gate.
    hidden = T.tanh(cell) * output_gate
    return hidden, cell
# Leftover web-page navigation labels (not code):
# "评论列表" (comment list), "文章目录" (table of contents).