skipconnect_rnn.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目：NeuroNLP2 作者：XuezheMax 项目源码 文件源码
def SkipConnectFastLSTMCell(input, hidden, hidden_skip, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
    """Run one LSTM-cell step whose recurrent input also includes a skip-connected hidden state.

    The previous hidden state ``hx`` is concatenated with ``hidden_skip`` along
    dim 1 before the hidden-to-hidden projection, so ``w_hh`` must accept
    ``hx.size(1) + hidden_skip.size(1)`` input features.

    Args:
        input: input tensor for this step. Assumed to be (batch, input_size):
            the dim-1 concatenation/chunking below only lines up with
            ``F.linear`` (which acts on the last dim) for 2-D tensors.
            TODO confirm against callers.
        hidden: tuple ``(hx, cx)`` holding the previous hidden and cell states.
        hidden_skip: extra hidden state concatenated to ``hx`` along dim 1.
        w_ih: input-to-hidden weight, applied with ``F.linear``.
        w_hh: hidden-to-hidden weight, applied to ``cat([hx, hidden_skip], 1)``.
        b_ih: optional bias for the input projection.
        b_hh: optional bias for the hidden projection.
        noise_in: optional mask multiplied element-wise into ``input``
            (presumably a dropout mask built by the caller).
        noise_hidden: optional mask multiplied element-wise into the
            concatenated ``[hx, hidden_skip]`` tensor.

    Returns:
        Tuple ``(hy, cy)`` with the new hidden and cell states.
    """
    if noise_in is not None:
        input = input * noise_in

    hx, cx = hidden
    # The skip-connected state goes through the same recurrent projection as hx.
    hx = torch.cat([hx, hidden_skip], dim=1)
    if noise_hidden is not None:
        hx = hx * noise_hidden

    if input.is_cuda:
        # Fused CUDA path: both projections are computed without bias, and the
        # biases (if any) are handed to the fused pointwise op instead.
        # NOTE(review): `fusedBackend` is a module-level name defined outside this
        # view. It appears to be the legacy torch fused-RNN pointwise backend,
        # which no longer exists in recent PyTorch releases -- confirm the import.
        igates = F.linear(input, w_ih)
        hgates = F.linear(hx, w_hh)
        state = fusedBackend.LSTMFused()
        return state(igates, hgates, cx) if b_ih is None else state(igates, hgates, cx, b_ih, b_hh)

    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)

    # Gate order along dim 1: input, forget, cell (candidate), output.
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    # torch.sigmoid / torch.tanh are the canonical tensor ops; the
    # F.sigmoid / F.tanh aliases emit deprecation warnings in many releases.
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号