def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
    """Plot the matrix G in 2D as lines whose opacity follows G's values.

    Draw a line segment between each source sample ``xs[i]`` and target
    sample ``xt[j]`` whose normalized weight ``G[i, j] / G.max()`` is
    strictly greater than ``thr``. The line's alpha is that normalized
    weight (optionally scaled, see ``alpha`` below).

    Parameters
    ----------
    xs : ndarray, shape (ns, 2)
        Source sample positions.
    xt : ndarray, shape (nt, 2)
        Target sample positions.
    G : ndarray, shape (ns, nt)
        OT matrix (expected to be nonnegative).
    thr : float, optional
        Threshold on the normalized weight above which a line is drawn.
    **kwargs : dict
        Parameters given to the plot function (default color is black if
        neither ``color`` nor ``c`` is given). If ``alpha`` is given, it is
        not forwarded as-is: it is used as a global factor multiplying the
        per-line alpha.

    Returns
    -------
    None
        Lines are drawn on the current figure through ``pl.plot``.
    """
    if ('color' not in kwargs) and ('c' not in kwargs):
        kwargs['color'] = 'k'
    # Forwarding a user-supplied alpha together with the per-line alpha
    # below would raise "TypeError: got multiple values for keyword
    # argument 'alpha'", so treat it as a global scale factor instead.
    scale = kwargs.pop('alpha', None)
    if scale is None:
        scale = 1
    mx = G.max()
    if mx <= 0:
        # No positive mass to draw; also avoids 0/0 divisions (and the
        # resulting NaN warnings) for an all-zero matrix.
        return
    for i in range(xs.shape[0]):
        for j in range(xt.shape[0]):
            weight = G[i, j] / mx
            if weight > thr:
                pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]],
                        alpha=weight * scale, **kwargs)
评论列表
文章目录