def forward(self, x, hidden):
    """Advance the LSTM cell by one time step.

    Args:
        x: Input for this step. It is flattened with ``x.view(x.size(1), -1)``,
            so it is presumably shaped ``(1, batch, input_size)`` -- TODO
            confirm with callers.
        hidden: Tuple ``(h, c)`` holding the previous hidden and cell state.
            Each is flattened the same way as ``x`` (presumably
            ``(1, batch, hidden_size)``).

    Returns:
        ``(output, (h_t, c_t))``, each tensor shaped ``(1, batch, hidden)``.
        ``output`` is ``h_t`` passed through dropout when ``self.dropout > 0``.
        Dropout is only active while ``self.training`` is true. ``h_t`` and
        ``c_t`` are the undropped recurrent state for the next step.
    """
    h, c = hidden
    # Drop the leading axis so th.mm can work on 2-D (batch, features) tensors.
    h = h.view(h.size(1), -1)
    c = c.view(c.size(1), -1)
    x = x.view(x.size(1), -1)

    # Gate pre-activations: x @ W_x* + h @ W_h* + b_*.
    i_t = th.mm(x, self.w_xi) + th.mm(h, self.w_hi) + self.b_i
    f_t = th.mm(x, self.w_xf) + th.mm(h, self.w_hf) + self.b_f
    o_t = th.mm(x, self.w_xo) + th.mm(h, self.w_ho) + self.b_o

    # Input / forget / output gate activations, applied in place.
    i_t.sigmoid_()
    f_t.sigmoid_()
    o_t.sigmoid_()

    # Candidate cell value, then the new cell state and hidden state.
    c_t = th.mm(x, self.w_xc) + th.mm(h, self.w_hc) + self.b_c
    c_t.tanh_()
    c_t = th.mul(c, f_t) + th.mul(i_t, c_t)
    h_t = th.mul(o_t, th.tanh(c_t))

    # Restore the leading singleton axis: (batch, hidden) -> (1, batch, hidden).
    h_t = h_t.view(1, h_t.size(0), -1)
    c_t = c_t.view(1, c_t.size(0), -1)

    # Dropout regularizes only the emitted output and is applied out-of-place.
    # An in-place call here would also zero entries of h_t, which is returned
    # as the recurrent state and feeds the next step's h @ W_h* products.
    output = h_t
    if self.dropout > 0.0:
        output = F.dropout(h_t, p=self.dropout, training=self.training)
    return output, (h_t, c_t)
评论列表
文章目录