def cov(self):
    """Return the covariance matrix of this distribution.

    This should only be called when NormalDistribution represents one
    sample (``self.v`` must be a 1-D tensor).

    When both ``self.v`` and ``self.r`` are set, the covariance is
    ``A @ diag(sigma**2) @ A.T`` with ``A = I + v r^T`` (``I`` is the
    D x D identity, D = number of elements in ``v``). Otherwise it is
    ``diag(sigma**2)``.

    Returns:
        A square 2-D tensor. Assumes ``self.sigma`` is 1-D with length D
        (and ``self.r`` has length D in the ``A`` branch) -- TODO confirm
        against the constructor.
    """
    if self.v is not None and self.r is not None:
        assert self.v.dim() == 1
        # D is the number of elements in v. Using v.dim() here (always 1)
        # would yield eye(1), which broadcasts against the D x D outer
        # product and adds 1 to every entry instead of only the diagonal.
        dim = self.v.size(0)
        v = self.v.unsqueeze(1)  # D x 1 column vector
        rt = self.r.unsqueeze(0)  # 1 x D row vector
        # Create the identity on v's device/dtype so CUDA / float64 inputs work.
        eye = torch.eye(dim, dtype=self.v.dtype, device=self.v.device)
        A = eye + v.mm(rt)
        return A.mm(torch.diag(self.sigma.pow(2)).mm(A.t()))
    else:
        return torch.diag(self.sigma.pow(2))
评论列表
文章目录