ops.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:WaveNet 作者: ritheshkumar95 项目源码 文件源码
def init_weights(fan_in, fan_out, init='he'):
    """Return a randomly initialised weight matrix of shape (fan_in, fan_out).

    The result is a numpy array cast to ``theano.config.floatX``.

    Args:
        fan_in: Number of rows of the matrix; also the divisor used to scale
            the 'lecun' and 'he' schemes.
        fan_out: Number of columns of the matrix.
        init: Initialisation scheme.
            'lecun'      -- uniform with standard deviation sqrt(1 / fan_in).
            'he'         -- uniform with standard deviation sqrt(2 / fan_in).
            'orthogonal' -- a factor of the reduced SVD of a standard-normal
                            matrix (adapted from lasagne).
            None         -- 'orthogonal' when fan_in == fan_out, otherwise
                            'lecun'.

    Returns:
        The (fan_in, fan_out) array of initial weight values.

    Raises:
        ValueError: If ``init`` is not one of the schemes listed above.
    """
    if init is None:
        # No scheme requested: square matrices get the orthogonal scheme,
        # everything else falls back to LeCun.
        init = 'orthogonal' if fan_in == fan_out else 'lecun'

    if init not in ('lecun', 'he', 'orthogonal'):
        # Fail loudly here instead of hitting an unbound local at return time.
        raise ValueError(
            "Unknown init scheme %r; expected 'lecun', 'he', 'orthogonal' "
            "or None." % (init,)
        )

    shape = (fan_in, fan_out)

    if init == 'orthogonal':
        # From lasagne
        # TODO: why normal and not uniform?
        gaussian = numpy.random.normal(0.0, 1.0, shape)
        u, _, v = numpy.linalg.svd(gaussian, full_matrices=False)
        # The reduced SVD yields u of shape (fan_in, k) and v of shape
        # (k, fan_out) with k = min(fan_in, fan_out); pick the one with the
        # correct shape.
        q = u if u.shape == shape else v
        return q.astype(theano.config.floatX)

    # Uniform schemes: U(-a, a) has standard deviation a / sqrt(3), so the
    # bound is stdev * sqrt(3) for the requested standard deviation.
    gain = 1. if init == 'lecun' else 2.
    stdev = numpy.sqrt(gain / fan_in)
    limit = stdev * numpy.sqrt(3)
    return numpy.random.uniform(
        low=-limit,
        high=limit,
        size=shape
    ).astype(theano.config.floatX)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号