def animate(self, X, y, useTqdm=0, filename=None, return_anim=True):
    """Animate the successive 2-D layouts returned by ``self.getSteps``.

    Every snapshot from ``self.getSteps(X, y)`` becomes one animation frame.
    Points are split into one series per distinct label in ``y``, each drawn
    with its own colour from the ``jet`` colormap.

    Parameters
    ----------
    X : array-like
        Input data, forwarded unchanged to ``self.getSteps``.
    y : array-like
        One label per point. Labels do not have to be the integers
        ``0..K-1``; series are ordered by sorted label value.
    useTqdm : {0, 1, 2}
        Progress bar while frames are produced: 0 = none, 1 = ``tqdm.tqdm``,
        2 = ``tqdm.tqdm_notebook``.
    filename : str or None
        Only used when ``return_anim`` is false. ``None`` displays the
        animation with ``plt.show()``; otherwise it is saved to this path
        with the ``imagemagick`` writer.
    return_anim : bool
        If true (the default), return the animation without showing or
        saving it.

    Returns
    -------
    FuncAnimation or None
        The animation when ``return_anim`` is true, otherwise ``None``.

    Raises
    ------
    ValueError
        If ``useTqdm`` is not 0, 1 or 2, or ``getSteps`` returns no snapshots.
    """
    # Validate before creating a figure so a bad argument leaves nothing open.
    if useTqdm not in (0, 1, 2):
        raise ValueError("useTqdm must be 0, 1 or 2, got %r" % (useTqdm,))

    pos = self.getSteps(X, y)
    if len(pos) == 0:
        raise ValueError("getSteps returned no snapshots to animate")

    # ``y == label`` only yields an element-wise mask on an ndarray; on a
    # plain list it would collapse to a single bool.
    y = np.asarray(y)
    labels = np.unique(y)
    n_colors = len(labels)
    # Each snapshot is reshaped to rows of (x, y); row order is assumed to
    # match ``y`` (the label masks index those rows).
    masks = [y == label for label in labels]

    # Axis limits come from the final layout, ordered (min, max) so that
    # neither axis is inverted.
    last_iter = np.asarray(pos[-1]).reshape(-1, 2)
    lo, hi = last_iter.min(axis=0), last_iter.max(axis=0)

    fig = plt.figure(tight_layout=True)
    ax = fig.add_subplot(111)
    scalar_map = cmx.ScalarMappable(
        norm=colors.Normalize(vmin=0, vmax=n_colors),
        cmap=plt.get_cmap('jet'),
    )

    # One marker series per label. Column 0 is x and column 1 is y,
    # consistently with the axis limits and with ``update`` below.
    first = np.asarray(pos[0]).reshape(-1, 2)
    dots_list = []
    for color_idx, mask in enumerate(masks):
        dots, = ax.plot(first[mask, 0], first[mask, 1], 'o',
                        color=scalar_map.to_rgba(color_idx))
        dots_list.append(dots)

    def init():
        ax.set_xlim([lo[0], hi[0]])
        ax.set_ylim([lo[1], hi[1]])
        return list(dots_list)

    def update(frame):
        # Reshape once per frame, not once per label.
        coords = np.asarray(pos[frame]).reshape(-1, 2)
        for dots, mask in zip(dots_list, masks):
            dots.set_data(coords[mask, 0], coords[mask, 1])
        return list(dots_list)

    # Include the final snapshot; the axis limits were derived from it.
    frames = range(len(pos))
    if useTqdm == 1:
        from tqdm import tqdm
        frames = tqdm(frames)
    elif useTqdm == 2:
        from tqdm import tqdm_notebook
        frames = tqdm_notebook(frames)

    anim = FuncAnimation(fig, update, frames=frames, init_func=init, interval=50)
    if return_anim:
        return anim
    if filename is None:
        plt.show()
    else:
        anim.save(filename, dpi=80, writer='imagemagick')
评论列表
文章目录