def update_line_each_neuron(num, print_loss, Ix, axes, Iy, train_data, accuracy_test, epochs_bins, loss_train_data,
                            loss_test_data, colors, epochsInds, font_size=18, axis_font=16, x_lim=None, y_lim=None,
                            x_ticks=None, y_ticks=None):
    """Update the information-plane figure for one frame of the movie.

    Draws on ``axes[0]`` the (I(X;T), I(T;Y)) point of every network and
    every layer at frame index ``num``. When a second axis is supplied, it
    also draws the mean test error (``1 - accuracy``) and, optionally, the
    mean test loss for the frames before ``num``.

    Args:
        num: Frame index; selects the middle axis of ``Ix``/``Iy`` and the
            entry of ``epochsInds`` shown in the title.
        print_loss: If true, also plot the mean test loss on ``axes[1]``.
        Ix, Iy: Arrays indexed ``[i, num, layer]`` holding I(X;T) and I(T;Y);
            presumably the first axis enumerates trained networks.
        axes: Sequence of matplotlib-style axes; ``axes[1]`` is optional.
        train_data, loss_train_data, font_size: Not used by this function;
            kept for call-signature compatibility.
        accuracy_test: 2-D array whose first axis is averaged over and whose
            second axis is indexed by frame.
        epochs_bins: Sorted epoch values used to pick the x-limit of ``axes[1]``.
        loss_test_data: 2-D array laid out like ``accuracy_test``.
        colors: One colour per layer (indexed by the last axis of ``Ix``).
        epochsInds: Epoch number for each frame index.
        axis_font, x_lim, y_lim, x_ticks, y_ticks: Forwarded to
            ``utils.adjustAxes``; ``None`` selects the historical defaults
            ``[0, 12.2]``, ``[0, 1.08]``, ``[]`` and ``[]``.
    """
    # Build fresh lists per call: list defaults in the signature would be
    # shared between calls and could be mutated by the helper they go to.
    if x_lim is None:
        x_lim = [0, 12.2]
    if y_lim is None:
        y_lim = [0, 1.08]
    if x_ticks is None:
        x_ticks = []
    if y_ticks is None:
        y_ticks = []

    # Information plane: one scatter call per layer covering every network.
    axes[0].clear()
    for layer_num in range(Ix.shape[2]):
        axes[0].scatter(Ix[:, num, layer_num], Iy[:, num, layer_num], color=colors[layer_num], s=35,
                        edgecolors='black', alpha=0.85)
    title_str = 'Information Plane - Epoch number - ' + str(epochsInds[num])
    utils.adjustAxes(axes[0], axis_font, title_str, x_ticks, y_ticks, x_lim, y_lim, set_xlabel=True,
                     set_ylabel=True, x_label='$I(X;T)$', y_label='$I(T;Y)$')

    if len(axes) <= 1:
        return

    # Error (1 - accuracy) and optional loss history for frames before `num`.
    err_ax = axes[1]
    err_ax.clear()
    err_ax.plot(epochsInds[:num], 1 - np.mean(accuracy_test[:, :num], axis=0), color='g', label='Error')
    if print_loss:
        err_ax.plot(epochsInds[:num], np.mean(loss_test_data[:, :num], axis=0), color='y',
                    label='Loss Function')
    # searchsorted(side='right') returns len(epochs_bins) once the current epoch
    # reaches the last bin; clamp so the lookup cannot go out of range.
    nearest_ind = min(int(np.searchsorted(epochs_bins, epochsInds[num], side='right')), len(epochs_bins) - 1)
    err_ax.set_xlim([0, epochs_bins[nearest_ind]])
    # Labels are attached to each line, so the legend stays correct whether
    # or not the loss curve was drawn.
    err_ax.legend(loc='best')
评论列表
文章目录