test_nn.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def test_ConvTranspose2d_output_size(self):
        m = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2)
        i = Variable(torch.randn(2, 3, 6, 6))
        for h in range(15, 22):
            for w in range(15, 22):
                if 18 <= h <= 20 and 18 <= w <= 20:
                    size = (h, w)
                    if h == 19:
                        size = torch.LongStorage(size)
                    elif h == 2:
                        size = torch.LongStorage((2, 4) + size)
                    m(i, output_size=(h, w))
                else:
                    self.assertRaises(ValueError, lambda: m(i, (h, w)))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号