ops.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:WaveNet 作者: ritheshkumar95 项目源码 文件源码
def recurrent_fn_hred(x_t, h_tm1,hidden_dim,W1,b1,W2,b2):
    global DIM
    #A1 = T.nnet.sigmoid(lib.ops.BatchNorm(T.dot(T.concatenate((x_t,h_tm1),axis=1),W1),name="FrameLevel.GRU"+str(name)+".Input.",length=2*512) + b1)
    A1 = T.nnet.sigmoid(T.dot(T.concatenate((x_t,h_tm1),axis=1),W1) + b1)

    z = A1[:,:hidden_dim]

    r = A1[:,hidden_dim:]

    scaled_hidden = r*h_tm1

    #h = T.tanh(lib.ops.BatchNorm(T.dot(T.concatenate((scaled_hidden,x_t),axis=1),W2),name="FrameLevel.GRU"+str(name)+".Output.",length=512)+b2)
    h = T.tanh(T.dot(T.concatenate((scaled_hidden,x_t),axis=1),W2) + b2)

    one = lib.floatX(1.0)
    return ((z * h) + ((one - z) * h_tm1)).astype('float32')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号