def MultiplicativeLSTMCell(input, hidden, w_xm, w_hm, w_ih, w_mh, b_xm=None, b_hm=None, b_ih=None, b_mh=None):
    """Compute one time step of a multiplicative LSTM cell.

    Args:
        input: Input tensor for this step.
        hidden: Pair ``(hx, cx)`` holding the previous hidden and cell state.
        w_xm: Weight projecting ``input`` into the multiplicative state ``m``.
        w_hm: Weight projecting ``hx`` into the multiplicative state ``m``.
        w_ih: Stacked input-to-gate weights (W_ix, W_fx, W_hx, W_ox).
        w_mh: Stacked m-to-gate weights (W_im, W_fm, W_hm, W_om).
        b_xm, b_hm, b_ih, b_mh: Optional biases for the matching weights;
            ``None`` means no bias is added.

    Returns:
        Tuple ``(hy, cy)`` with the new hidden state and new cell state.

    Notes:
        The stacked gate pre-activations are split into four equal chunks
        along dim 1 in the order: input gate, forget gate, hidden
        (candidate) update, output gate. The stacking order listed above
        follows that split. Tensors are presumably 2-D
        ``(batch, features)`` -- TODO confirm against the caller.
    """
    hx, cx = hidden

    # Key difference (per the original author's note): the previous hidden
    # state reaches the gates only through the elementwise product ``m`` of
    # the projected input and the projected hidden state.
    m = F.linear(input, w_xm, b_xm) * F.linear(hx, w_hm, b_hm)
    gates = F.linear(input, w_ih, b_ih) + F.linear(m, w_mh, b_mh)

    ingate, forgetgate, hiddengate, outgate = gates.chunk(4, 1)

    # Tensor methods are used instead of the deprecated F.sigmoid / F.tanh
    # aliases; the computed values are the same.
    ingate = ingate.sigmoid()
    outgate = outgate.sigmoid()
    forgetgate = forgetgate.sigmoid()

    # NOTE(review): the candidate (hiddengate) is used without a squashing
    # nonlinearity and the output gate is applied inside the tanh. This
    # mirrors the original code exactly -- confirm it is the intended
    # formulation before changing it.
    cy = (forgetgate * cx) + (ingate * hiddengate)
    hy = (cy * outgate).tanh()
    return hy, cy
# [non-code page text] Comment list
# [non-code page text] Table of contents