ops.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:WaveNet 作者: ritheshkumar95 项目源码 文件源码
def __ConvLSTMStep(
        name,
        seq_len,
        input_dim,
        hidden_dim,
        current_input,
        last_hidden,
        last_cell,
        dilation_depth=10,
        inp_bias_init=0.,
        forget_bias_init=3.,
        out_bias_init=0.,
        g_bias_init=0.):
    """One step of a convolutional LSTM whose input path is a dilated WaveNet stack.

    The input is passed through ``dilation_depth`` ``lib.ops.WaveNetConv1d``
    blocks with dilations 1, 2, 4, ..., 2**(dilation_depth - 1).  The result
    is concatenated with ``last_hidden`` on axis 1 and fed to a single
    ``lib.ops.conv1d`` call ("<name>.ConvGates") producing 4*hidden_dim
    channels of gate pre-activations, laid out as
    [input | forget | candidate | output], each hidden_dim wide.
    Peephole weights ("<name>.CellWeights") connect the previous cell state
    to the input/forget gates and the new cell state to the output gate.

    Args:
        name: prefix for the ConvGates / CellWeights parameter names.
        seq_len: size of the second axis of the peephole weight tensor.
        input_dim: channel count passed to the WaveNet blocks; the gate conv
            is declared with ``2*input_dim`` input channels.
        hidden_dim: number of channels per gate / per state.
        current_input, last_hidden, last_cell: symbolic Theano tensors with
            channels on axis 1.  Assumed layout (batch, channels, seq_len, 1),
            inferred from the peephole weight shape -- TODO confirm.
        dilation_depth: number of dilated WaveNet blocks.
        inp_bias_init, forget_bias_init, out_bias_init, g_bias_init:
            accepted for interface compatibility; not used by this body.

    Returns:
        (H_t, C_t): the new hidden state and the new cell state.
    """
    dilations = [2 ** k for k in xrange(dilation_depth)]

    # Dilated WaveNet stack over the input.  Only the first output of each
    # block is chained forward; the second output is discarded here.
    # NOTE(review): unlike the ConvGates/CellWeights params below, these
    # block names are not prefixed with `name`, so two ConvLSTM layers in one
    # model would presumably share WaveNet parameters if lib keys params by
    # name.  Kept as-is so existing parameter names / checkpoints stay valid.
    features = current_input
    for block_idx, dilation in enumerate(dilations, 1):
        features, _ = lib.ops.WaveNetConv1d(
            "WaveNetBlock-%d" % block_idx,
            features,
            2,
            hidden_dim,
            input_dim,
            bias=True,
            batchnorm=False,
            dilation=dilation)

    # Fuse the input features with the previous hidden state and compute all
    # four gate pre-activations with a single conv call.
    # NOTE(review): the declared input width 2*input_dim presumably assumes
    # both halves of the concatenation have input_dim channels
    # (i.e. hidden_dim == input_dim) -- confirm.
    gates = T.concatenate((features, last_hidden), axis=1)
    gates = lib.ops.conv1d(name + ".ConvGates", gates, 1, 1,
                           4 * hidden_dim, 2 * input_dim, True, False)

    # Peephole weights: rows [0, 2h) scale the previous cell state for the
    # input/forget gates; rows [2h, 3h) scale the new cell state for the
    # output gate.
    W_cell = lib.param(
        name + '.CellWeights',
        lasagne.init.HeNormal().sample((3 * hidden_dim, seq_len, 1)))

    # Duplicate the previous cell state along the channel axis so a single
    # sigmoid covers both the input and the forget gate.
    last_cell_stack = T.concatenate((last_cell, last_cell), axis=1)
    inp_forget = T.nnet.sigmoid(
        gates[:, :2 * hidden_dim] + W_cell[:2 * hidden_dim] * last_cell_stack)
    i_t = inp_forget[:, :hidden_dim]
    f_t = inp_forget[:, hidden_dim:]

    # Candidate values and cell-state update.
    g_t = T.tanh(gates[:, 2 * hidden_dim:3 * hidden_dim])
    C_t = f_t * last_cell + i_t * g_t

    # Output gate peeks at the *new* cell state.
    o_t = T.nnet.sigmoid(
        gates[:, 3 * hidden_dim:] + W_cell[2 * hidden_dim:] * C_t)
    H_t = o_t * T.tanh(C_t)

    return H_t, C_t
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号