convolution_rbm.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:SeRanet 作者: corochann 项目源码 文件源码
def __init__(self, in_channels, out_channels, ksize, stride=1, real=0, wscale=1.0):
        super(ConvolutionRBM, self).__init__(
            conv=L.Convolution2D(in_channels, out_channels, ksize, stride=stride, wscale=wscale),
        )

#        if gpu >= 0:
#            cuda.check_cuda_available()
#            xp = cuda.cupy # if gpu >= 0 else np
        self.conv.add_param("a", in_channels)  # dtype=xp.float32
        self.conv.a.data.fill(0.)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.ksize = ksize
        self.real = real

        self.rbm_train = False  # default value is false
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号