def __call__(self, test_case):
    """Run this test's checks against one freshly constructed module.

    Builds the module from ``self.constructor`` and
    ``self.constructor_args``, obtains an input from ``self._get_input()``
    and then:

    1. if ``self.reference_fn`` is set, compares the module's forward
       output with the output of the reference function;
    2. saves the module with ``torch.save``, loads it back with
       ``torch.load`` and checks that the loaded copy produces the same
       forward output as the original module;
    3. passes the module and the input on to ``self._do_test``.

    Args:
        test_case: Test-case object that provides ``_forward``,
            ``_get_parameters`` and ``assertEqual``.
    """
    module = self.constructor(*self.constructor_args)
    # Named ``module_input`` so the builtin ``input`` is not shadowed.
    module_input = self._get_input()

    if self.reference_fn is not None:
        out = test_case._forward(module, module_input)
        if isinstance(out, Variable):
            # Compare the wrapped data, not the Variable object itself.
            out = out.data
        # The reference function is given its own deep copy of the input.
        ref_input = self._unpack_input(deepcopy(module_input))
        expected_out = self.reference_fn(
            ref_input, test_case._get_parameters(module)[0]
        )
        test_case.assertEqual(out, expected_out)

    # TODO: do this with in-memory files as soon as torch.save will support it
    with TemporaryFile() as f:
        # Forward once before saving; the result is discarded.
        # NOTE(review): presumably this lets state created during forward be
        # part of what gets serialized -- confirm against the modules tested.
        test_case._forward(module, module_input)
        torch.save(module, f)
        # Rewind so that torch.load reads from the start of the file.
        f.seek(0)
        # Only ever loads the file written a few lines above.
        module_copy = torch.load(f)
        test_case.assertEqual(
            test_case._forward(module, module_input),
            test_case._forward(module_copy, module_input),
        )

    self._do_test(test_case, module, module_input)
# [web-page residue, not code] Comment list
# [web-page residue, not code] Article table of contents