def test_mlpg_gradcheck():
    """Numerically check the gradients of the ``MLPG`` autograd function.

    For every window set yielded by ``_get_windows_set()``, random means of
    shape ``(T, static_dim * len(windows))`` are passed through
    ``torch.autograd.gradcheck`` under two variance settings:

    * unit variances (``torch.ones``), and
    * random variances (``torch.rand``).

    In both cases a single row of per-feature variances is broadcast over
    all ``T`` frames with ``expand``.
    """
    # MLPG is performed dimension by dimension, so static_dim = 1 would be
    # enough; 2 is used just in case.
    static_dim = 2
    T = 10
    # Variance constructors for the two cases: unit variances, then random
    # variances.
    variance_factories = (torch.ones, torch.rand)
    for windows in _get_windows_set():
        # Re-seed per window set so every window set sees the same draws.
        torch.manual_seed(1234)
        D = static_dim * len(windows)
        means = Variable(torch.rand(T, D), requires_grad=True)
        inputs = (means,)
        for make_variances in variance_factories:
            # One variance per feature, shared (broadcast) across all frames.
            variances = make_variances(D).expand(T, D)
            assert gradcheck(MLPG(variances, windows),
                             inputs, eps=1e-3, atol=1e-3)
评论列表
文章目录