lstm.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:rnnprop 作者: vfleaking 项目源码 文件源码
def lstm_func(x, h, c, wx, wh, b):
    """
        x: (N, D)
        h: (N, H)
        c: (N, H)
        wx: (D, 4H)
        wh: (H, 4H)
        b: (4H, )
    """
    N, H = tf.shape(h)[0], tf.shape(h)[1]
    a = tf.reshape(tf.matmul(x, wx) + tf.matmul(h, wh) + b, (N, -1, H))
    i, f, o, g = a[:,0,:], a[:,1,:], a[:,2,:], a[:,3,:]
    i = tf.sigmoid(i)
    f = tf.sigmoid(f)
    o = tf.sigmoid(o)
    g = tf.tanh(g)
    next_c = f * c + i * g
    next_h = o * tf.tanh(next_c)
    return next_h, next_c
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号