modules.py source code

python

Project: braindecode  Author: robintibor
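```python
import numpy as np
import torch.nn.functional as F
from braindecode.torch_ext.util import np_to_var  # helper from older braindecode releases


def forward(self, x):
    # Create the averaging weights for the convolution on demand, i.e. on the
    # first call (self.weights starts as None) or whenever the number of
    # channels, the device or the tensor type of x changed.
    in_channels = x.size()[1]
    # One (kh x kw) filter per input channel, used as a depthwise convolution.
    weight_shape = (in_channels, 1,
                    self.kernel_size[0], self.kernel_size[1])
    if self.weights is None or (
            (tuple(self.weights.size()) != tuple(weight_shape)) or
            (self.weights.is_cuda != x.is_cuda) or
            (self.weights.data.type() != x.data.type())):
        # Every filter entry is 1 / (kh * kw), so each output is a window mean.
        n_pool = np.prod(self.kernel_size)
        weights = np_to_var(
            np.ones(weight_shape, dtype=np.float32) / float(n_pool))
        # type_as already matches x's type and device; .cuda() is a no-op then.
        weights = weights.type_as(x)
        if x.is_cuda:
            weights = weights.cuda()
        self.weights = weights

    # groups=in_channels: each channel is averaged independently (no padding).
    pooled = F.conv2d(x, self.weights, bias=None, stride=self.stride,
                      dilation=self.dilation,
                      groups=in_channels)
    return pooled
```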
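The method above pools by convolving each channel with a constant filter of 1 / (kh * kw), so every output is the mean of a (possibly dilated) kh x kw window; presumably a convolution is used because `F.avg_pool2d` takes no `dilation` argument. The following is a minimal NumPy sketch of what that computes, written only to make the arithmetic explicit. `dilated_avg_pool2d` is a hypothetical name, not part of braindecode, and the sketch assumes no padding (the `F.conv2d` call above passes none, so its default of 0 applies).

```python
import numpy as np


def dilated_avg_pool2d(x, kernel_size, stride, dilation=1):
    """NumPy reference for average pooling done as a depthwise convolution.

    x has shape (batch, channels, height, width). Each output value is the
    mean of kernel_size[0] * kernel_size[1] inputs spaced `dilation` apart,
    computed separately for every channel, with no padding.
    """
    x = np.asarray(x, dtype=np.float32)
    kh, kw = kernel_size
    sh, sw = (stride, stride) if isinstance(stride, int) else stride
    dh, dw = (dilation, dilation) if isinstance(dilation, int) else dilation
    n, c, h, w = x.shape
    # Output size of an unpadded convolution with stride and dilation.
    out_h = (h - dh * (kh - 1) - 1) // sh + 1
    out_w = (w - dw * (kw - 1) - 1) // sw + 1
    out = np.empty((n, c, out_h, out_w), dtype=np.float32)
    for i in range(out_h):
        for j in range(out_w):
            # Row/column indices of the dilated window for output (i, j).
            rows = i * sh + dh * np.arange(kh)
            cols = j * sw + dw * np.arange(kw)
            # Fancy indexing gives shape (n, c, kh, kw); averaging the last
            # two axes equals convolving with a filter of 1 / (kh * kw).
            window = x[:, :, rows[:, None], cols[None, :]]
            out[:, :, i, j] = window.mean(axis=(2, 3))
    return out


x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
print(dilated_avg_pool2d(x, (2, 2), stride=1, dilation=2)[0, 0])
```

For a 4 x 4 input holding 0..15 with a (2, 2) kernel, stride 1 and dilation 2, each output averages four entries two apart (for example 0, 2, 8, 10), giving [[5, 6], [9, 10]]. With dilation 1 the same function reduces to ordinary average pooling without padding.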