mlstm.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:benchmark 作者: pytorch 项目源码 文件源码
def MultiplicativeLSTMCell(input, hidden, w_xm, w_hm, w_ih, w_mh, b_xm=None, b_hm=None, b_ih=None, b_mh=None):
    # w_ih holds W_hx, W_ix, W_ox, W_fx
    # w_mh holds W_hm, W_im, W_om, W_fm

    hx, cx = hidden

    # Key difference:
    m = F.linear(input, w_xm, b_xm) * F.linear(hx, w_hm, b_hm)
    gates = F.linear(input, w_ih, b_ih) + F.linear(m, w_mh, b_mh)

    ingate, forgetgate, hiddengate, outgate = gates.chunk(4, 1)

    ingate = F.sigmoid(ingate)
    outgate = F.sigmoid(outgate)
    forgetgate = F.sigmoid(forgetgate)

    cy = (forgetgate * cx) + (ingate * hiddengate)
    hy = F.tanh(cy * outgate)

    return hy, cy
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号