def update_line(num, print_loss, data, axes, epochsInds, test_error, test_data, epochs_bins, loss_train_data,
                loss_test_data, colors, font_size=18, axis_font=16, x_lim=None, y_lim=None, x_ticks=None,
                y_ticks=None):
    """Update the information-plane figure for one frame of the movie.

    Redraws the per-layer scatter points for frame ``num`` on ``axes[0]`` and,
    when a second axis is supplied, the precision-vs-epoch curve on ``axes[1]``.

    Args:
        num: Frame index; selects ``data[:, :, num, :]`` and ``epochsInds[num]``.
        print_loss: Unused; kept for interface compatibility.
        data: Array indexed as ``data[coord, i, num, layer]``. ``coord`` 0 is
            plotted on the x axis (labelled I(X;T)) and ``coord`` 1 on the y axis
            (labelled I(T;Y)). The meaning of axis ``i`` is not visible here
            (presumably repeated runs -- TODO confirm).
        axes: Sequence of matplotlib axes. ``axes[0]`` gets the information
            plane; the optional ``axes[1]`` gets the precision curve.
        epochsInds: Epoch number for each frame.
        test_error: 2-D array; ``1 - mean(test_error[:, :k], axis=0)`` is
            plotted as precision.
        test_data, epochs_bins, loss_train_data, loss_test_data, font_size:
            Unused; kept for interface compatibility.
        colors: One color per layer, indexed by the last axis of ``data``.
        axis_font: Font size forwarded to ``utils.adjustAxes``.
        x_lim, y_lim: Axis limits; default to ``[0, 12.2]`` and ``[0, 1.08]``.
        x_ticks, y_ticks: Tick positions; default to empty lists.
    """
    # Fresh lists per call: mutable defaults would be shared across every call
    # and are handed to utils.adjustAxes, which could mutate them.
    x_lim = [0, 12.2] if x_lim is None else x_lim
    y_lim = [0, 1.08] if y_lim is None else y_lim
    x_ticks = [] if x_ticks is None else x_ticks
    y_ticks = [] if y_ticks is None else y_ticks

    # Build the line segments between consecutive layers' points.
    cmap = ListedColormap(LAYERS_COLORS)
    segs = []
    for i in range(data.shape[1]):
        x = data[0, i, num, :]
        y = data[1, i, num, :]
        points = np.array([x, y]).T.reshape(-1, 1, 2)
        segs.append(np.concatenate([points[:-1], points[1:]], axis=1))
    segs = np.array(segs).reshape(-1, 2, 2)

    # The precision plot is optional; guard every use of axes[1] consistently.
    has_precision_axis = len(axes) > 1
    axes[0].clear()
    if has_precision_axis:
        axes[1].clear()

    # NOTE(review): lc is built but never added to an axes (no add_collection
    # call), so the connecting lines are not drawn. Confirm whether
    # axes[0].add_collection(lc) was intended before enabling it.
    lc = LineCollection(segs, cmap=cmap, linestyles='solid', linewidths=0.3, alpha=0.6)
    lc.set_array(np.arange(0, 5))

    # Draw the points, one color per layer.
    for layer_num in range(data.shape[3]):
        axes[0].scatter(data[0, :, num, layer_num], data[1, :, num, layer_num], color=colors[layer_num], s=35,
                        edgecolors='black', alpha=0.85)

    if has_precision_axis:
        # Include frame num itself so the curve reaches the epoch named in the title.
        upto = num + 1
        axes[1].plot(epochsInds[:upto], 1 - np.mean(test_error[:, :upto], axis=0), color='r')

    title_str = 'Information Plane - Epoch number - ' + str(epochsInds[num])
    utils.adjustAxes(axes[0], axis_font, title_str, x_ticks, y_ticks, x_lim, y_lim, set_xlabel=True, set_ylabel=True,
                     x_label='$I(X;T)$', y_label='$I(T;Y)$')

    if has_precision_axis:
        title_str = 'Precision as function of the epochs'
        # NOTE(review): this reuses the information-plane limits/ticks for an
        # epochs-vs-precision axis -- confirm that is intended.
        utils.adjustAxes(axes[1], axis_font, title_str, x_ticks, y_ticks, x_lim, y_lim, set_xlabel=True,
                         set_ylabel=True, x_label='# Epochs', y_label='Precision')
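# Scraped page labels captured with this snippet, not part of the program:
# "评论列表" ("Comment list") and "文章目录" ("Table of contents").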