def SkipConnectGRUCell(input, hidden, hidden_skip, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
    """Run one step of a GRU cell whose gates also see a skip-connected state.

    The recurrent input to all three gates is ``hx = cat([hidden, hidden_skip],
    dim=1)``, but the interpolation producing the new state uses only
    ``hidden``::

        r  = sigmoid(x @ W_ir + b_ir + hx @ W_hr + b_hr)
        z  = sigmoid(x @ W_iz + b_iz + hx @ W_hz + b_hz)
        n  = tanh(x @ W_in + b_in + r * (hx @ W_hn + b_hn))
        hy = n + z * (hidden - n)          # == (1 - z) * n + z * hidden

    Args:
        input: Input for this step. It is given a leading dim of size 3 (one
            slice per gate) either by ``expand`` or by multiplying with
            ``noise_in``. Presumably ``(batch, input_size)`` -- TODO confirm
            against callers.
        hidden: Previous hidden state. It is concatenated with
            ``hidden_skip`` on dim 1 for the gates and is also used directly
            in the output interpolation.
        hidden_skip: Skip-connected hidden state. It only feeds the gates.
        w_ih: Stacked input-to-gate weights, batch-multiplied (``bmm``) with
            the 3-way input. Slices are unpacked in the order reset, update,
            candidate.
        w_hh: Stacked recurrent weights applied to ``[hidden, hidden_skip]``,
            with the same gate layout as ``w_ih``.
        b_ih: Optional stacked input biases, one row per gate along dim 0.
            ``None`` means no bias.
        b_hh: Optional stacked recurrent biases, laid out like ``b_ih``.
            ``None`` means no bias.
        noise_in: Optional multiplicative mask broadcast against
            ``input.unsqueeze(0)`` (presumably a per-gate dropout mask).
            ``None`` means no masking.
        noise_hidden: Optional multiplicative mask broadcast against the
            concatenated recurrent state after ``unsqueeze(0)``. ``None``
            means no masking.

    Returns:
        The new hidden state ``hy``.
    """
    # Give every gate its own copy of the input (or its own noise mask) so a
    # single batched matmul computes all three gate pre-activations at once.
    if noise_in is None:
        x3 = input.expand(3, *input.size())
    else:
        x3 = input.unsqueeze(0) * noise_in

    # The recurrent path sees both the regular and the skip-connected state.
    hx = torch.cat([hidden, hidden_skip], dim=1)
    if noise_hidden is None:
        hx3 = hx.expand(3, *hx.size())
    else:
        hx3 = hx.unsqueeze(0) * noise_hidden

    # Biases are optional in the signature, so fall back to a plain bmm when
    # one is absent. Otherwise the bias is unsqueezed so it broadcasts over
    # dim 1 of the batched product.
    if b_ih is None:
        gi = torch.bmm(x3, w_ih)
    else:
        gi = torch.baddbmm(b_ih.unsqueeze(1), x3, w_ih)
    if b_hh is None:
        gh = torch.bmm(hx3, w_hh)
    else:
        gh = torch.baddbmm(b_hh.unsqueeze(1), hx3, w_hh)

    i_r, i_i, i_n = gi
    h_r, h_i, h_n = gh

    # torch.sigmoid / torch.tanh replace the deprecated F.sigmoid / F.tanh.
    resetgate = torch.sigmoid(i_r + h_r)
    inputgate = torch.sigmoid(i_i + h_i)
    newgate = torch.tanh(i_n + resetgate * h_n)
    # Interpolate between the candidate and the *regular* previous state only;
    # hidden_skip influences hy solely through the gates.
    hy = newgate + inputgate * (hidden - newgate)
    return hy
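# Leftover navigation labels from the source web page (not code):
# "Comment list" (评论列表), "Table of contents" (文章目录).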