ops.py source code

python

Project: WaveNet  Author: ritheshkumar95
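import numpy
import theano
import theano.tensor as T

# `lib` (parameter registry) and `init_weights` are defined elsewhere in the project.

def Dense(name, input_dim, output_dim, inputs, bias=True, init=None, weightnorm=True, hidden_dim=None):
    # Fully connected layer with optional weight normalization.
    # hidden_dim: size of the middle axis to restore when `inputs` is 3-D.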

    weight_values = init_weights(input_dim,output_dim,init)

    weight = lib.param(
        name + '.W',
        weight_values
    )

    # Flatten a 3-D input to 2-D so a single matrix product covers every position.
    batch_size = None
    if inputs.ndim == 3:
        batch_size = inputs.shape[0]
        inputs = inputs.reshape((-1, input_dim))

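    if weightnorm:
        # Weight normalization: rescale each column of W to a learned norm g.
        # g starts at the column's initial norm, so the layer is unchanged at init.
        norm_values = numpy.linalg.norm(weight_values, axis=0)
        norms = lib.param(
            name + '.g',
            norm_values
        )

        normed_weight = weight * (norms / weight.norm(2, axis=0)).dimshuffle('x', 0)
        result = T.dot(inputs, normed_weight)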

    else:
        result = T.dot(inputs, weight)

    if bias:
        b = lib.param(
            name + '.b',
            numpy.zeros((output_dim,), dtype=theano.config.floatX)
        )
        result += b

    result.name = name+".output"
    if batch_size is not None:
        # Restore the 3-D shape; hidden_dim must be supplied for 3-D inputs.
        return result.reshape((batch_size, hidden_dim, output_dim))
    else:
        return result
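
The same computation can be followed without a Theano graph. The block below is a minimal NumPy sketch, not the project's code: `dense_forward` and its arguments `g` and `bias` are hypothetical names chosen for illustration, and the arrays are random. It mirrors the three steps of `Dense`: flatten a 3-D input, rescale each weight column to norm `g`, then restore the leading axes.

```python
import numpy as np

def dense_forward(inputs, weight, g=None, bias=None):
    """Illustrative NumPy mirror of the Dense forward pass (hypothetical helper)."""
    input_dim, output_dim = weight.shape
    lead_shape = inputs.shape[:-1]            # e.g. (batch, time) for a 3-D input
    flat = inputs.reshape(-1, input_dim)      # collapse leading axes into one

    if g is not None:
        # Weight normalization: each column of W is rescaled to have norm g[j].
        col_norms = np.linalg.norm(weight, axis=0)
        weight = weight * (g / col_norms)[np.newaxis, :]

    out = flat @ weight
    if bias is not None:
        out = out + bias
    # Restore the leading axes, with output_dim as the new last axis.
    return out.reshape(lead_shape + (output_dim,))

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 3))                   # input_dim=4, output_dim=3
g0 = np.linalg.norm(W, axis=0)                # same initialization as the '.g' parameter
x = rng.normal(size=(2, 5, 4))                # (batch, time, input_dim)

y_plain = dense_forward(x, W)
y_wn = dense_forward(x, W, g=g0)              # matches y_plain at initialization
y_scaled = dense_forward(x, W, g=2.0 * g0)    # doubling g doubles the output
```

Because `g` is initialized to the column norms of `W`, the normalized layer starts out identical to the plain one. Training can then change the scale (`g`) and the direction (`W`) of each column independently. Unlike the original, this sketch infers the middle axis from the input shape instead of taking a `hidden_dim` argument.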