skipconnect_rnn.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:NeuroNLP2 作者: XuezheMax 项目源码 文件源码
def SkipConnectFastGRUCell(input, hidden, hidden_skip, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
    """Compute one GRU step whose recurrent input is ``[hidden, hidden_skip]``.

    The recurrent projection is applied to the concatenation of ``hidden``
    and ``hidden_skip`` along dim 1, while the final interpolation uses
    ``hidden`` alone.

    Args:
        input: Input tensor for this step (presumably ``(batch, input_size)``
            -- confirm against the caller).
        hidden: Previous hidden state; also the state that is blended with
            the candidate state in the output.
        hidden_skip: Skip-connection hidden state, concatenated to ``hidden``
            on dim 1 before the recurrent projection.
        w_ih: Weight applied to ``input``. Its output is split into three
            equal chunks along dim 1 (reset, input, new).
        w_hh: Weight applied to ``cat([hidden, hidden_skip], dim=1)``. Its
            output is split the same way.
        b_ih: Optional bias for the input projection.
        b_hh: Optional bias for the recurrent projection.
        noise_in: Optional tensor multiplied element-wise into ``input``.
        noise_hidden: Optional tensor multiplied element-wise into the
            concatenated hidden state.

    Returns:
        The next hidden state ``hy``.
    """
    if noise_in is not None:
        input = input * noise_in

    # The recurrent projection sees both states, joined along dim 1.
    hx = torch.cat([hidden, hidden_skip], dim=1)
    if noise_hidden is not None:
        hx = hx * noise_hidden

    if input.is_cuda:
        # On CUDA the biases are passed to fusedBackend.GRUFused instead of
        # F.linear. NOTE(review): fusedBackend is defined outside this block;
        # presumably a fused GRU point-wise op -- confirm it still exists in
        # the PyTorch version in use.
        gi = F.linear(input, w_ih)
        gh = F.linear(hx, w_hh)
        state = fusedBackend.GRUFused()
        return state(gi, gh, hidden) if b_ih is None else state(gi, gh, hidden, b_ih, b_hh)

    gi = F.linear(input, w_ih, b_ih)
    gh = F.linear(hx, w_hh, b_hh)
    i_r, i_i, i_n = gi.chunk(3, 1)
    h_r, h_i, h_n = gh.chunk(3, 1)

    # torch.sigmoid / torch.tanh replace the F.sigmoid / F.tanh aliases,
    # which PyTorch has flagged as deprecated; results are identical.
    resetgate = torch.sigmoid(i_r + h_r)
    inputgate = torch.sigmoid(i_i + h_i)
    newgate = torch.tanh(i_n + resetgate * h_n)
    # Algebraically: inputgate * hidden + (1 - inputgate) * newgate.
    hy = newgate + inputgate * (hidden - newgate)

    return hy
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号