def testCorrectlyMakesNoBatchLowerTril(self):
    """Check `fill_lower_triangular` on a single, unbatched vector.

    A random length-10 vector is packed into a lower-triangular matrix. The
    static shape and the values are compared with the reference helper
    `self._fill_lower_triangular`. The gradient with respect to the input
    vector is then compared with a flattened `np.tri(4)` mask (10 = 4*5/2, so
    the matrix is presumably 4x4 -- confirm against the helper).
    """
    with self.test_session():
        vec = ops.convert_to_tensor(self._rng.randn(10))

        # Forward pass: shape and values must agree with the reference.
        want = self._fill_lower_triangular(tensor_util.constant_value(vec))
        got = distribution_util.fill_lower_triangular(vec, validate_args=True)
        self.assertAllEqual(want.shape, got.get_shape())
        self.assertAllEqual(want, got.eval())

        # Backward pass: a fresh op (default validate_args) is differentiated.
        # The result exposes `.values` -- presumably an IndexedSlices produced
        # by a gather inside the op; verify against distribution_util.
        tril_mask_flat = np.tri(4).reshape(-1)
        grads = gradients_impl.gradients(
            distribution_util.fill_lower_triangular(vec), vec)
        self.assertAllEqual(tril_mask_flat, grads[0].values.eval())
distribution_util_test.py 文件源码
python
阅读 24
收藏 0
点赞 0
评论 0
评论列表
文章目录