def squared_distance_matrix(X):
n = X.shape[0]
XX = F.sum(X ** 2.0, axis=1)
distances = -2.0 * F.linear(X, X)
distances = distances + F.broadcast_to(XX, (n, n))
distances = distances + F.broadcast_to(F.expand_dims(XX, 1), (n, n))
return distances
lifted_struct_loss.py 文件源码
python
阅读 16
收藏 0
点赞 0
评论 0
评论列表
文章目录