def test_modspec_gradcheck():
    """Check ModSpec's analytic gradient against numerical differentiation.

    Runs ``torch.autograd.gradcheck`` on ``ModSpec(n=16, norm=norm)`` for
    ``norm=None`` and ``norm="ortho"``, using one random input of shape
    ``(T, static_dim) = (16, 12)`` that requires a gradient.
    """
    static_dim = 12
    T = 16
    # Fixed seed so the random input, and therefore the test, is reproducible.
    torch.manual_seed(1234)
    # gradcheck takes a tuple of inputs; this single input needs a gradient.
    inputs = (Variable(torch.rand(T, static_dim), requires_grad=True),)
    n = 16
    # NOTE(review): torch.rand produces float32 by default, while PyTorch
    # recommends double precision for gradcheck; presumably that is why the
    # tolerances are loosened to eps=atol=1e-4. Consider .double() if ModSpec
    # supports float64 inputs — TODO confirm.
    for norm in [None, "ortho"]:
        func = ModSpec(n=n, norm=norm)
        # Name the failing normalization mode so a loop failure is traceable.
        assert gradcheck(func, inputs, eps=1e-4, atol=1e-4), (
            "ModSpec gradcheck failed for norm=%r" % (norm,))
评论列表
文章目录