nn.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:WGAN_mnist 作者: rajeswar18 项目源码 文件源码
def get_output_for(self, input, init=False, deterministic=False, **kwargs):
        """Return ``nonlinearity(input . W + b)``, or run the init pass.

        Parameters
        ----------
        input : symbolic tensor
            Layer input. Any input with more than two dimensions is
            flattened to 2-D (leading axis kept, the rest merged) before
            the dot product with ``self.W``.
        init : bool
            When True, skip the bias and instead standardize the
            pre-activations over axis 0 (presumably the batch axis --
            TODO confirm). This is likely a weight-normalization-style
            data-dependent initialization. It records the matching parameter
            updates in ``self.init_updates``.
        deterministic : bool
            Accepted for interface compatibility; not used here.
        **kwargs
            Ignored.

        Returns
        -------
        symbolic tensor
            ``self.nonlinearity`` applied to the (biased or standardized)
            pre-activations.
        """
        # Merge every trailing axis so T.dot sees a 2-D matrix.
        flat = input.flatten(2) if input.ndim > 2 else input
        pre_act = T.dot(flat, self.W)

        # Regular forward pass: add the per-unit bias, broadcast across axis 0.
        if not init:
            return self.nonlinearity(pre_act + self.b.dimshuffle('x', 0))

        # Init pass: zero-mean / unit-RMS pre-activations per unit over axis 0.
        batch_mean = T.mean(pre_act, axis=0)
        centered = pre_act - batch_mean.dimshuffle('x', 0)
        batch_std = T.sqrt(T.mean(T.square(centered), axis=0))
        # NOTE(review): a unit whose pre-activation is constant over axis 0
        # gives batch_std == 0, and the divisions below then yield inf/NaN.
        standardized = centered / batch_std.dimshuffle('x', 0)

        # Updates that bake the standardization into the parameters: rescale
        # weight_scale by 1/std and set b to -mean/std. They are only stored
        # here; this method does not apply them.
        self.init_updates = [
            (self.weight_scale, self.weight_scale / batch_std),
            (self.b, -batch_mean / batch_std),
        ]
        return self.nonlinearity(standardized)


# Comes from Ishmael's code base.
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号