def SkipConnectFastLSTMCell(input, hidden, hidden_skip, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
    """Run one LSTM step whose recurrent input also includes a skip-connected hidden state.

    The previous hidden state ``hx`` is concatenated with ``hidden_skip`` along
    dim 1 before the hidden-to-gates projection, so ``w_hh`` must accept that
    concatenated width.

    Args:
        input: Input for the current step; projected with ``w_ih`` via ``F.linear``.
        hidden: Tuple ``(hx, cx)`` holding the previous hidden and cell state.
        hidden_skip: Extra hidden state appended to ``hx`` along dim 1.
        w_ih: Input-to-gates weight.
        w_hh: Hidden-to-gates weight, applied to ``cat([hx, hidden_skip], dim=1)``.
        b_ih: Optional input-to-gates bias.
        b_hh: Optional hidden-to-gates bias.
        noise_in: Optional tensor multiplied element-wise into ``input``
            (presumably a dropout mask -- TODO confirm with callers).
        noise_hidden: Optional tensor multiplied element-wise into the
            concatenated hidden input.

    Returns:
        Tuple ``(hy, cy)``: the new hidden state and the new cell state.
    """
    if noise_in is not None:
        input = input * noise_in

    hx, cx = hidden
    # The skip-connected state is appended to the recurrent input, so the
    # hidden-to-gates projection sees both the previous and the skipped state.
    hx = torch.cat([hx, hidden_skip], dim=1)
    if noise_hidden is not None:
        hx = hx * noise_hidden

    if input.is_cuda:
        # Fused CUDA path: project without biases and hand the biases to the
        # fused op separately (presumably added inside the kernel).
        # NOTE(review): ``fusedBackend`` is not defined in this function
        # (presumably a module-level import). Instantiating ``LSTMFused()`` is
        # the legacy autograd.Function calling convention; newer PyTorch
        # releases require ``LSTMFused.apply``. Confirm against the installed
        # version.
        igates = F.linear(input, w_ih)
        hgates = F.linear(hx, w_hh)
        state = fusedBackend.LSTMFused()
        return state(igates, hgates, cx) if b_ih is None else state(igates, hgates, cx, b_ih, b_hh)

    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
    # Gate order along dim 1: input, forget, cell candidate, output.
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    # torch.sigmoid / torch.tanh replace the deprecated F.sigmoid / F.tanh.
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)
    return hy, cy
评论列表
文章目录