def test_symbolic_mismatch(self):
    """Check that a symbolic/forward arity mismatch is reported as a TypeError.

    ``MyFun.forward`` takes two inputs (``x`` and ``y``), but
    ``MyFun.symbolic`` accepts only one input (``x``) besides the graph
    ``g``. Exporting the function must raise a ``TypeError`` whose message
    contains "occurred when translating MyFun". The body of ``symbolic``
    must never run.
    """
    class MyFun(Function):
        @staticmethod
        def symbolic(g, x):
            # The inside of this function should never be invoked, because
            # we will fail due to an argument mismatch first. Raise
            # explicitly rather than using ``assert False``, so the guard
            # still fires under ``python -O``. An AssertionError is not the
            # TypeError expected below, so reaching here fails the test.
            raise AssertionError("MyFun.symbolic should not be invoked")

        @staticmethod
        def forward(ctx, x, y):
            return x + y

    # All-ones inputs. The values are irrelevant; only the arity matters.
    x = Variable(torch.ones(2, 2))
    y = Variable(torch.ones(2, 2))
    with self.assertRaisesRegex(TypeError, "occurred when translating MyFun"):
        # Call ``apply`` on the class itself. Autograd Functions with static
        # methods are not meant to be instantiated.
        export_to_string(FuncModule(MyFun.apply), (x, y))
# TODO: Do an nn style test for these
# [scraped web-page label, not code] "Comment list"
# [scraped web-page label, not code] "Table of contents"