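import math

import numpy as np


def _fill_lower_triangular(self, x):
    """NumPy implementation of `fill_lower_triangular`.

    Packs the last axis of `x`, of length d = n(n+1)/2, into the lower
    triangle of an n x n matrix; leading (batch) axes are preserved.
    """
    x = np.asarray(x)
    d = x.shape[-1]
    # Solve d = n(n+1)/2 for n: n = (sqrt(1 + 8d) - 1) / 2.
    n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
    # Row and column indices of the lower triangle, in row-major order.
    ids = np.tril_indices(n)
    y = np.zeros(list(x.shape[:-1]) + [n, n], dtype=x.dtype)
    y[..., ids[0], ids[1]] = x
    return y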
Source file: distribution_util_test.py (Python)
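To make the packing order concrete, here is a minimal standalone sketch of the same idea outside the test class. The names `fill_lower_triangular_np` and `extract_lower_triangular_np` are hypothetical, not part of any library. Two changes relative to the method above are my own assumptions: `math.isqrt` replaces the floating-point `sqrt` so `n` is exact for any size, and a check rejects a last dimension that is not a triangular number. The fill order is whatever `np.tril_indices` yields (row by row); this sketch does not claim that it matches any other library's op.

```python
import math

import numpy as np


def fill_lower_triangular_np(x):
    """Hypothetical helper: pack the last axis (length d = n(n+1)/2) into n x n lower-triangular matrices."""
    x = np.asarray(x)
    d = x.shape[-1]
    # Solve d = n(n+1)/2 for n with integer arithmetic (assumption: isqrt instead of float sqrt).
    n = (math.isqrt(1 + 8 * d) - 1) // 2
    if n * (n + 1) // 2 != d:
        raise ValueError(f"last dimension {d} is not a triangular number")
    rows, cols = np.tril_indices(n)  # row-major order over the lower triangle
    y = np.zeros(x.shape[:-1] + (n, n), dtype=x.dtype)
    y[..., rows, cols] = x
    return y


def extract_lower_triangular_np(y):
    """Hypothetical inverse: read the lower triangle back out in the same order."""
    y = np.asarray(y)
    rows, cols = np.tril_indices(y.shape[-1])
    return y[..., rows, cols]


# Single vector: d = 6 entries fill a 3 x 3 lower triangle row by row.
m = fill_lower_triangular_np([1, 2, 3, 4, 5, 6])
print(m)
# [[1 0 0]
#  [2 3 0]
#  [4 5 6]]

# Batched input: leading axes are preserved, and extraction round-trips.
batch = np.arange(12.0).reshape(2, 6)
out = fill_lower_triangular_np(batch)
print(out.shape)  # (2, 3, 3)
assert np.array_equal(extract_lower_triangular_np(out), batch)
```

Because `np.tril_indices` is used both to scatter and to gather, packing followed by extraction returns the original vector for any batch shape.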