common_nn.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def __call__(self, test_case):
        """Run this module test using the helpers provided by ``test_case``.

        Steps, in order:

        1. Build the module from ``self.constructor`` and
           ``self.constructor_args`` and fetch an input via ``self._get_input``.
        2. If ``self.reference_fn`` is set, compare the module's forward output
           with the reference function's result.
        3. Save the module to a temporary file, load it back, and check that
           the original and the reloaded copy give equal forward outputs.
        4. Delegate the remaining checks to ``self._do_test``.
        """
        net = self.constructor(*self.constructor_args)
        sample = self._get_input()

        if self.reference_fn is not None:
            actual = test_case._forward(net, sample)
            # Compare on the underlying data when the output is a Variable.
            actual = actual.data if isinstance(actual, Variable) else actual
            # The reference function gets its own deep copy of the input.
            unpacked = self._unpack_input(deepcopy(sample))
            params = test_case._get_parameters(net)[0]
            expected = self.reference_fn(unpacked, params)
            test_case.assertEqual(actual, expected)

        # TODO: switch to an in-memory file once torch.save supports it
        with TemporaryFile() as tmp:
            # Run one forward pass before serializing the module.
            test_case._forward(net, sample)
            torch.save(net, tmp)
            tmp.seek(0)  # rewind so torch.load reads from the beginning
            restored = torch.load(tmp)
            out_original = test_case._forward(net, sample)
            out_restored = test_case._forward(restored, sample)
            test_case.assertEqual(out_original, out_restored)

        self._do_test(test_case, net, sample)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号