def lstm_func(x, h, c, wx, wh, b):
    """Run a single LSTM time step.

    Args:
        x: input for this step, shape (N, D).
        h: previous hidden state, shape (N, H).
        c: previous cell state, shape (N, H).
        wx: input-to-hidden weights, shape (D, 4H).
        wh: hidden-to-hidden weights, shape (H, 4H).
        b: bias, shape (4H,).

    The 4H pre-activation columns are read as four consecutive H-wide gate
    blocks in the order [i, f, o, g]: input gate, forget gate, output gate,
    candidate cell values.
    NOTE(review): confirm wx / wh / b were packed with this same gate order.

    Returns:
        (next_h, next_c): the new hidden and cell states, each of shape (N, H).
    """
    # Pre-activations for all four gates in one affine transform: (N, 4H).
    a = tf.matmul(x, wx) + tf.matmul(h, wh) + b
    # Cut the last axis into four equal (N, H) blocks -> columns
    # [0:H], [H:2H], [2H:3H], [3H:4H] become i, f, o, g respectively.
    i, f, o, g = tf.split(a, num_or_size_splits=4, axis=1)
    i = tf.sigmoid(i)
    f = tf.sigmoid(f)
    o = tf.sigmoid(o)
    g = tf.tanh(g)
    # Cell update: keep part of the old cell (forget gate) and add gated
    # candidate values (input gate).
    next_c = f * c + i * g
    # Hidden state exposes a squashed view of the cell through the output gate.
    next_h = o * tf.tanh(next_c)
    return next_h, next_c
评论列表
文章目录