def recurrent_fn_hred(x_t, h_tm1, hidden_dim, W1, b1, W2, b2):
    """Compute one GRU-style recurrent step as a symbolic expression.

    ``T`` is presumably ``theano.tensor``. The argument order (current input,
    previous state, then weights) looks like a ``theano.scan`` step function;
    confirm against the caller.

    Args:
        x_t: input at the current step. It is concatenated with ``h_tm1``
            along axis 1, so it is presumably a (batch, input_dim) matrix.
            TODO confirm.
        h_tm1: previous hidden state, presumably (batch, hidden_dim).
        hidden_dim: width of the hidden state. The first ``hidden_dim``
            columns of the gate activations form the update gate ``z``. The
            remaining columns form the reset gate ``r``.
        W1, b1: gate weights and bias, applied to ``[x_t, h_tm1]``.
        W2, b2: candidate-state weights and bias, applied to
            ``[r * h_tm1, x_t]``. Note that this concatenation order differs
            from the one used with ``W1``. Keep it as-is so that existing
            parameters stay compatible.

    Returns:
        The new hidden state ``z * h + (1 - z) * h_tm1``, cast to float32.
    """
    # Compute both gates with a single matmul, then split them column-wise.
    gates = T.nnet.sigmoid(
        T.dot(T.concatenate((x_t, h_tm1), axis=1), W1) + b1
    )
    z = gates[:, :hidden_dim]  # update gate: weight given to the candidate
    r = gates[:, hidden_dim:]  # reset gate: scales h_tm1 for the candidate

    # The candidate state is built from the reset-scaled previous state.
    scaled_hidden = r * h_tm1
    h = T.tanh(T.dot(T.concatenate((scaled_hidden, x_t), axis=1), W2) + b2)

    # lib.floatX presumably casts the constant to the configured float dtype,
    # so that `one - z` does not upcast. Confirm against lib.
    one = lib.floatX(1.0)
    return ((z * h) + ((one - z) * h_tm1)).astype('float32')
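# NOTE(review): the following web-page navigation labels were captured along
# with this snippet and are not code: "评论列表" ("Comment list"),
# "文章目录" ("Table of contents").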