def plot_all_epochs(gen_data, I_XT_array, I_TY_array, axes, epochsInds, f, index_i, index_j, size_ind,
                    font_size, y_ticks, x_ticks, colorbar_axis, title_str, axis_font, bar_font, save_name,
                    plot_error=True, index_to_emphasis=1000):
    """Plot the information plane with each epoch drawn in a different color.

    Args:
        gen_data: Mapping that provides the 'train_error', 'test_error',
            'loss_train' and 'loss_test' series. Only read when
            ``plot_error`` is true.
        I_XT_array: I(X;T) values. Squeezed, then indexed as
            ``[epoch_index, :]``; averaged over axis 0 first when it still
            has an extra leading dimension after squeezing.
        I_TY_array: I(T;Y) values, handled the same way as ``I_XT_array``.
        axes: Not used - replaced by a freshly created 1x1 axes grid.
        epochsInds: Epoch number of every row in the information arrays; it
            is also the x-axis of the error curves and selects the color.
        f: Not used - replaced by a freshly created figure.
        index_i: Row of the subplot to draw on inside the axes grid.
        index_j: Column of the subplot to draw on inside the axes grid.
        size_ind: Number of epochs to plot, or -1 to plot all of them.
        font_size: Forwarded to ``utils.adjustAxes`` as ``label_size``.
        y_ticks: Forwarded to ``utils.adjustAxes``.
        x_ticks: Forwarded to ``utils.adjustAxes``.
        colorbar_axis: Forwarded to ``utils.create_color_bar``.
        title_str: Forwarded to ``utils.adjustAxes``.
        axis_font: Forwarded to ``utils.adjustAxes``.
        bar_font: Forwarded to ``utils.create_color_bar``.
        save_name: Output path without extension; '.jpg' is appended.
        plot_error: When true, also plot the error/loss curves on a
            separate figure.
        index_to_emphasis: Epoch number that is drawn in green with larger
            markers instead of its colormap color.
    """
    # Optionally plot the train/test error and loss curves on their own figure.
    if plot_error:
        fig_strs = ['train_error', 'test_error', 'loss_train', 'loss_test']
        fig_data = [np.squeeze(gen_data[fig_str]) for fig_str in fig_strs]
        f1 = plt.figure(figsize=(12, 8))
        ax1 = f1.add_subplot(111)
        # More than one dimension left after squeezing: average over axis 0.
        if fig_data[0].ndim != 1:
            fig_data = [np.mean(curve, axis=0) for curve in fig_data]
        for label, curve in zip(fig_strs, fig_data):
            ax1.plot(epochsInds, curve, ':', linewidth=3, label=label)
        ax1.legend(loc='best')

    # NOTE(review): the `f` and `axes` arguments are overwritten here with a
    # new single-subplot figure, so whatever the caller passed is ignored.
    f = plt.figure(figsize=(12, 8))
    axes = np.array([[f.add_subplot(111)]])
    ax = axes[index_i, index_j]

    I_XT_array = np.squeeze(I_XT_array)
    I_TY_array = np.squeeze(I_TY_array)
    # An extra leading dimension is averaged away so rows are epochs.
    if I_TY_array[0].ndim > 1:
        I_XT_array = np.mean(I_XT_array, axis=0)
        I_TY_array = np.mean(I_TY_array, axis=0)
    max_index = size_ind if size_ind != -1 else I_XT_array.shape[0]

    cmap = plt.get_cmap('gnuplot')
    # One color per epoch number up to the last plotted epoch. The count is
    # cast to int because np.linspace rejects a non-integer `num`, and the
    # lookup below already indexes with int(epoch).
    num_colors = int(epochsInds[max_index - 1]) + 1
    colors = [cmap(i) for i in np.linspace(0, 1, num_colors)]

    # Go over all the epochs and plot each one with its color.
    for index_in_range in range(max_index):
        XT = I_XT_array[index_in_range, :]
        TY = I_TY_array[index_in_range, :]
        epoch = epochsInds[index_in_range]
        if epoch == index_to_emphasis:
            # Emphasized epoch: green, bigger markers, drawn on top.
            # NOTE(review): linestyle=None makes matplotlib fall back to its
            # default line style rather than hiding the line - confirm intent.
            ax.plot(XT, TY, marker='o', linestyle=None, markersize=19, markeredgewidth=0.04,
                    linewidth=2.1, color='g', zorder=10)
        else:
            ax.plot(XT, TY, marker='o', linestyle='-', markersize=12, markeredgewidth=0.01,
                    linewidth=0.2, color=colors[int(epoch)])

    # Axis labels are only requested on the last row / first column of the grid.
    utils.adjustAxes(ax, axis_font=axis_font, title_str=title_str, x_ticks=x_ticks,
                     y_ticks=y_ticks, x_lim=[0, 25.1], y_lim=None,
                     set_xlabel=index_i == axes.shape[0] - 1, set_ylabel=index_j == 0,
                     x_label='$I(X;T)$', y_label='$I(T;Y)$', set_xlim=False,
                     set_ylim=False, set_ticks=True, label_size=font_size)

    # Add the color bar and save the figure once the last subplot is drawn.
    if index_i == axes.shape[0] - 1 and index_j == axes.shape[1] - 1:
        utils.create_color_bar(f, cmap, colorbar_axis, bar_font, epochsInds, title='Epochs')
        f.savefig(save_name + '.jpg', dpi=500, format='jpg')
# Comment list (scraped page label: 评论列表)
# Table of contents (scraped page label: 文章目录)