ops.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目：WaveNet　作者：ritheshkumar95　项目源码　文件源码
def myGRU(name, input_dim, hidden_dim, inputs, h0=None):
    """Build a single-layer GRU over a batch of sequences using theano.scan.

    Four parameters are registered through ``lib.param``, all prefixed with
    ``name``:

    * ``<name>.Gates.W`` -- shape (input_dim + hidden_dim, 2 * hidden_dim)
    * ``<name>.Gates.b`` -- shape (2 * hidden_dim,), initialised to ones
    * ``<name>.Candidate.W`` -- shape (input_dim + hidden_dim, hidden_dim)
    * ``<name>.Candidate.b`` -- shape (hidden_dim,), initialised to zeros

    The per-timestep update itself is delegated to ``recurrent_fn``.

    Args:
        name: Prefix for the parameter names and for the output's ``.name``.
        input_dim: Feature size of each timestep of ``inputs``.
        hidden_dim: Size of the recurrent hidden state.
        inputs: Symbolic rank-3 tensor. It is presumably laid out as
            (batch_size, N_FRAMES, FRAME_SIZE); TODO confirm with callers.
        h0: Optional initial hidden state, presumably (batch_size,
            hidden_dim). When None, an all-zeros state of that shape is used.

    Returns:
        The symbolic sequence of hidden states produced by scan, with its
        first two axes swapped back. That is presumably
        (batch_size, N_FRAMES, hidden_dim). Its ``.name`` is set to
        ``name + '.output'``.
    """
    import theano.tensor as T

    # scan iterates over the leading axis, so make time the leading axis:
    # (batch, time, features) -> (time, batch, features).
    inputs = inputs.transpose(1, 0, 2)

    if h0 is None:
        # Passing outputs_info=[None] would tell scan the output is NOT
        # recurrent, so `step` would be called without h_tm1 and fail.
        # Start from a zero state instead, sized from the batch axis
        # (axis 1 after the transpose above).
        h0 = T.zeros((inputs.shape[1], hidden_dim), dtype=theano.config.floatX)

    # Weights for the two gates, stacked column-wise.
    weight_values = lasagne.init.GlorotUniform().sample(
        (input_dim + hidden_dim, 2 * hidden_dim)
    )
    W1 = lib.param(name + '.Gates.W', weight_values)

    # Gate biases start at 1 (not 0).
    b1 = lib.param(
        name + '.Gates.b',
        np.ones(2 * hidden_dim).astype(theano.config.floatX),
    )

    # Weights and zero-initialised bias for the candidate state.
    weight_values = lasagne.init.GlorotUniform().sample(
        (input_dim + hidden_dim, hidden_dim)
    )
    W2 = lib.param(name + '.Candidate.W', weight_values)

    b2 = lib.param(
        name + '.Candidate.b',
        np.zeros(hidden_dim).astype(theano.config.floatX),
    )

    def step(x_t, h_tm1):
        # One timestep: the current input slice and the previous hidden state.
        return recurrent_fn(
            x_t,
            h_tm1,
            name,
            input_dim,
            hidden_dim,
            W1, b1, W2, b2,
        )

    outputs, _ = theano.scan(
        step,
        sequences=[inputs],
        outputs_info=[h0],
    )

    # Swap the first two axes back so the batch axis leads again.
    out = outputs.dimshuffle(1, 0, 2)
    out.name = name + '.output'
    return out
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号