def cost_matrix(x, y, p=2):
    """Return the pairwise cost matrix between the rows of ``x`` and ``y``.

    Entry ``C[..., i, j]`` is ``sum_k |x[..., i, k] - y[..., j, k]| ** p``,
    i.e. the p-th power of the l_p distance between point ``i`` of ``x`` and
    point ``j`` of ``y`` (for the default ``p=2``, the squared Euclidean
    distance).

    Args:
        x: Tensor of shape ``(..., n, d)``.
        y: Tensor of shape ``(..., m, d)``; any leading (batch) dimensions
            must broadcast against those of ``x``.
        p: Exponent applied element-wise to the absolute differences.

    Returns:
        Tensor of shape ``(..., n, m)``.
    """
    # (..., n, 1, d) - (..., 1, m, d) broadcasts to (..., n, m, d).
    # Axes are counted from the end so that 2-D inputs behave exactly as
    # with unsqueeze(1)/unsqueeze(0)/sum(dim=2), while batched inputs no
    # longer have their batch dimension crossed with itself.
    x_col = x.unsqueeze(-2)
    y_lin = y.unsqueeze(-3)
    return torch.sum(torch.abs(x_col - y_lin) ** p, dim=-1)