def test_mlpg_variance_expand():
    """AF.mlpg must agree for a 1-D variance vector and its time-expanded copy.

    For every window set returned by ``_get_windows_set()``, one variance
    vector of length ``static_dim * len(windows)`` is passed to ``AF.mlpg``
    twice: once as-is, and once expanded to shape
    ``(T, static_dim * len(windows))``. The two outputs must match under
    ``np.allclose``.
    """
    static_dim = 2
    T = 10
    for windows in _get_windows_set():
        # Re-seed on each iteration so every window set's inputs are reproducible.
        torch.manual_seed(1234)
        feat_dim = static_dim * len(windows)

        means = Variable(torch.rand(T, feat_dim), requires_grad=True)
        var_1d = torch.rand(feat_dim)
        # expand() broadcasts the same vector to all T rows (a view, no copy).
        var_2d = var_1d.expand(T, feat_dim)

        out_1d = AF.mlpg(means, var_1d, windows)
        out_2d = AF.mlpg(means, var_2d, windows)
        assert np.allclose(out_1d.data.numpy(), out_2d.data.numpy())
评论列表
文章目录