conv.py 文件源码

python
阅读 41 收藏 0 点赞 0 评论 0

项目: pytorch-dist 作者: apaszke 项目源码 文件源码
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
             padding=0, dilation=None, groups=1, bias=True):
        """Record convolution settings and allocate the weight/bias tensors.

        Args:
            in_channels: second dimension of the allocated weight tensor.
            out_channels: first dimension of the weight tensor and the length
                of the bias tensor (when one is allocated).
            kernel_size: normalised with ``_pair`` into ``self.kh``/``self.kw``.
            stride: normalised with ``_pair`` into ``self.dh``/``self.dw``.
            padding: normalised with ``_pair`` into ``self.padh``/``self.padw``.
            dilation: ``None`` leaves ``self.is_dilated`` False and does NOT
                define ``self.dilh``/``self.dilw``; any other value sets
                ``self.is_dilated`` True and is normalised with ``_pair`` into
                those two attributes.
            groups: stored unchanged as ``self.groups``.
            bias: if truthy, a bias tensor of length ``out_channels`` is
                allocated; otherwise ``None`` is passed as the bias.
        """
        # Hyper-parameters are stored before the base-class constructor runs,
        # matching the original ordering of side effects.
        self.in_channels, self.out_channels = in_channels, out_channels
        self.kh, self.kw = _pair(kernel_size)
        self.dh, self.dw = _pair(stride)
        self.padh, self.padw = _pair(padding)

        self.is_dilated = dilation is not None
        if dilation is not None:
            self.dilh, self.dilw = _pair(dilation)
        self.groups = groups

        # NOTE(review): the weight's second dimension is the full in_channels
        # regardless of `groups`; grouped convolutions normally use
        # in_channels // groups. Confirm forward()/reset_parameters() (not
        # visible here) expect this shape before changing it.
        weight_shape = (self.out_channels, self.in_channels, self.kh, self.kw)
        weight = torch.Tensor(*weight_shape)  # uninitialised storage

        # A separate name avoids rebinding the `bias` flag argument.
        if bias:
            bias_tensor = torch.Tensor(self.out_channels)
        else:
            bias_tensor = None

        super(Conv2d, self).__init__(weight=weight, bias=bias_tensor)

        # Called last, presumably so it can initialise the tensors handed to
        # the base class above — TODO confirm against reset_parameters().
        self.reset_parameters()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号