mlstm.py source code

python

Project: benchmark  Author: pytorch

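import torch
import torch.nn.functional as F


def KrauseLSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    # Terminology matchup:
    #   - This implementation concatenates the weights of all four gates into
    #     a single tensor, so one matrix multiply computes every gate at once.
    #   - Thus, in the order the gates are chunked below (i, f, h, o),
    #       w_ih stacks W_ix, W_fx, W_hx, W_ox  (shape: 4*hidden_size x input_size)
    #       w_hh stacks W_ih, W_fh, W_hh, W_oh  (shape: 4*hidden_size x hidden_size)
    #   - The index order is reversed relative to the math (w_ih vs. W_hx)
    #     because F.linear takes the input first: F.linear(x, W) is x @ W.T.
    #     Either way, the "cancelling" indices sit next to each other.
    hx, cx = hidden
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
    ingate, forgetgate, hiddengate, outgate = gates.chunk(4, 1)

    # torch.sigmoid / torch.tanh replace the deprecated F.sigmoid / F.tanh.
    ingate = torch.sigmoid(ingate)
    outgate = torch.sigmoid(outgate)
    forgetgate = torch.sigmoid(forgetgate)

    # Unlike torch.nn.LSTMCell, the candidate (hiddengate) is not squashed by
    # tanh, and tanh is applied after the output gate rather than before it.
    cy = (forgetgate * cx) + (ingate * hiddengate)
    hy = torch.tanh(cy * outgate)

    return hy, cy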
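For reference, the update computed by KrauseLSTMCell can be written out with the W_·x / W_·h names from its comment. The per-gate biases b_i, b_f, b_h, b_o are notation introduced here: each stands for the sum of the matching chunks of b_ih and b_hh (zero when both are None). Note the two departures from the standard LSTM as documented for torch.nn.LSTMCell (candidate tanh(ĥ_t), output h_t = o_t ⊙ tanh(c_t)): here ĥ_t is left linear and tanh is applied to c_t ⊙ o_t. The function name suggests this follows a specific published formulation rather than being a bug, but that is an inference from the name only.

```latex
\begin{aligned}
i_t       &= \sigma\left(W_{ix} x_t + W_{ih} h_{t-1} + b_i\right) \\
f_t       &= \sigma\left(W_{fx} x_t + W_{fh} h_{t-1} + b_f\right) \\
\hat{h}_t &= W_{hx} x_t + W_{hh} h_{t-1} + b_h \\
o_t       &= \sigma\left(W_{ox} x_t + W_{oh} h_{t-1} + b_o\right) \\
c_t       &= f_t \odot c_{t-1} + i_t \odot \hat{h}_t \\
h_t       &= \tanh\left(c_t \odot o_t\right)
\end{aligned}
```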
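The fused-gate trick and the F.linear convention described in the function's comment can be checked directly: stacking four per-gate matrices, doing one F.linear, and chunking gives the same pre-activations as four separate F.linear calls. This is a minimal sketch; fused_gates and separate_gates are hypothetical helper names (not part of the benchmark project), and random weights stand in for learned ones.

```python
import torch
import torch.nn.functional as F


def fused_gates(x, per_gate_weights):
    """Hypothetical helper: stack per-gate weights along dim 0, do ONE
    F.linear, then split the result back into one tensor per gate.

    Each weight is (hidden_size, input_size); F.linear(x, W) computes x @ W.T.
    """
    w_stacked = torch.cat(per_gate_weights, dim=0)  # (4*hidden, input)
    return F.linear(x, w_stacked).chunk(len(per_gate_weights), dim=1)


def separate_gates(x, per_gate_weights):
    """Hypothetical reference: one F.linear call per gate."""
    return [F.linear(x, w) for w in per_gate_weights]


torch.manual_seed(0)
batch, input_size, hidden_size = 3, 5, 4
x = torch.randn(batch, input_size)
# Assumed per-gate matrices, in the order the cell chunks them: i, f, h, o.
weights = [torch.randn(hidden_size, input_size) for _ in range(4)]

for fused, ref in zip(fused_gates(x, weights), separate_gates(x, weights)):
    assert fused.shape == (batch, hidden_size)
    assert torch.allclose(fused, ref, atol=1e-5)

# The index "swap": F.linear(x, W) is x @ W.T, the row-vector form of W x.
assert torch.allclose(F.linear(x, weights[0]), x @ weights[0].T, atol=1e-5)
```

The payoff of the fused form is a single larger matmul per input instead of four small ones, which is why the cell keeps all gate weights in one w_ih and one w_hh.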