def Frobenius(mat):
size = mat.size()
if len(size) == 3: # batched matrix
ret = (torch.sum(torch.sum((mat ** 2), 1), 2).squeeze() + 1e-10) ** 0.5
return torch.sum(ret) / size[0]
else:
raise Exception('matrix for computing Frobenius norm should be with 3 dims')
train.py 文件源码
python
阅读 47
收藏 0
点赞 0
评论 0
评论列表
文章目录