def plot_2d_animation(input, mask, predictions, threshold=0.3):
    """Show an interactive animation of a frame sequence with an overlay.

    Three copies of ``input`` are stacked along axis 0 to form colour
    channels, and the overlay arrays are stacked the same way:

    * channel 0 (red)   -- always ``input``; the overlay for this channel is
      all zeros, so nothing is ever painted here,
    * channel 1 (green) -- ``input``, replaced by ``mask`` wherever
      ``mask > threshold`` (targets),
    * channel 2 (blue)  -- ``input``, replaced by ``predictions`` wherever
      ``predictions > threshold`` (predictions).

    Where targets and predictions overlap, both green and blue are replaced,
    so overlapping pixels show green+blue (not red).

    The stacked axis is then moved to the end, and each index along the new
    axis 0 is shown as one RGB frame. Clicking on the figure pauses and
    resumes playback. This call blocks in ``plt.show()``.

    Args:
        input: Grey-scale frames. NOTE(review): presumably shaped
            ``(1, T, H, W)`` so that stacking yields 3 channels -- confirm
            with callers. The code requires a 4-D result after stacking.
        mask: Target overlay, stacked alongside ``input`` (same shape).
        predictions: Predicted overlay, stacked alongside ``input``
            (same shape).
        threshold: Overlay values strictly greater than this replace the
            underlying image value. Defaults to 0.3, the previously
            hard-coded value.

    Raises:
        ValueError: If there are no frames to animate.
    """
    rgb_image = np.concatenate((input, input, input), axis=0)
    overlay = np.concatenate((np.zeros_like(input), mask, predictions), axis=0)
    # Paint overlay values over the grey-scale image wherever they exceed the
    # threshold.
    # NOTE(review): if ``input`` has an integer dtype, float overlay values are
    # truncated on assignment -- confirm callers pass float arrays.
    hit = overlay > threshold
    rgb_image[hit] = overlay[hit]
    # Move the stacked (channel) axis from position 0 to the end, so axis 0
    # indexes frames and the last axis is RGB, as imshow expects.
    rgb_image = np.rollaxis(rgb_image, axis=0, start=4)
    print(rgb_image.shape)

    # After the roll, the frame count is axis 0. Axis 1 is a spatial
    # dimension, which the previous ``frames=shape[1]`` mistakenly used.
    n_frames = rgb_image.shape[0]
    if n_frames == 0:
        raise ValueError("plot_2d_animation: no frames to animate")

    def get_data_step(step):
        """Return frame ``step`` as an (H, W, channels) image."""
        return rgb_image[step, :, :, :]

    fig = plt.figure()
    im = fig.gca().imshow(get_data_step(0))

    def init():
        im.set_data(get_data_step(0))
        return im,

    def animate(i):
        im.set_data(get_data_step(i))
        return im,

    # The interval is in milliseconds per frame, so the whole sequence plays
    # in roughly 20 seconds regardless of its length.
    anim = animation.FuncAnimation(fig, animate, init_func=init,
                                   frames=n_frames,
                                   interval=20000 / n_frames,
                                   blit=True)

    # Play/pause state lives in a mutable container local to this call. The
    # previous module-level ``global`` was never initialised here, so the
    # first click raised NameError. A dict (rather than ``nonlocal``) keeps
    # this Python 2 compatible.
    state = {"running": True}

    def on_click(event):
        """Toggle the animation between playing and paused."""
        if state["running"]:
            anim.event_source.stop()
        else:
            anim.event_source.start()
        state["running"] = not state["running"]

    fig.canvas.mpl_connect('button_press_event', on_click)
    try:
        plt.show()
    except AttributeError:
        # NOTE(review): presumably a deliberate workaround for an
        # AttributeError some matplotlib backends raise when a blitted
        # animation's window is closed -- confirm. Note that this also hides
        # any unrelated AttributeError raised during display.
        pass
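# Comment list  (stray page-navigation label from the source web page)
# Table of contents  (stray page-navigation label from the source web page)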