plot_figures.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:IDNNs 作者: ravidziv 项目源码 文件源码
def update_line(num, print_loss, data, axes, epochsInds, test_error, test_data, epochs_bins, loss_train_data, loss_test_data, colors,
                font_size=18, axis_font=16, x_lim=None, y_lim=None, x_ticks=None, y_ticks=None):
    """Update the figure of the information plane for the movie (one frame).

    Clears the axes, scatters every layer's point for epoch index ``num`` on
    ``axes[0]``. When a second axes is supplied, it also plots the precision
    curve (``1 - mean test error``) on ``axes[1]``. Both axes are then styled
    via ``utils.adjustAxes``.

    Args:
        num: Index of the frame; used to index axis 2 of ``data`` and
            ``epochsInds`` (for the title), and as the slice end for the
            precision curve.
        print_loss: Unused.
        data: Array indexed as ``data[coord, i, num, layer]``. ``coord == 0``
            is plotted on the x axis (labelled $I(X;T)$) and ``coord == 1`` on
            the y axis (labelled $I(T;Y)$). Axis 3 selects the layer (and its
            color). NOTE(review): axis 1 presumably enumerates repeated
            runs/networks — confirm against the caller.
        axes: Indexable collection of axes (presumably matplotlib ``Axes``).
            ``axes[0]`` gets the information plane. ``axes[1]``, if present,
            gets the precision curve.
        epochsInds: Epoch numbers, parallel to axis 1 of ``test_error``; used
            as x values of the precision curve and in the plane's title.
        test_error: 2-D array; its mean over axis 0 for columns ``[0, num)``
            is turned into precision as ``1 - mean``.
        test_data, epochs_bins, loss_train_data, loss_test_data: Unused.
        colors: Colors indexed by layer number (axis 3 of ``data``).
        font_size: Unused.
        axis_font: Forwarded to ``utils.adjustAxes``.
        x_lim: Axis x-limits; defaults to ``[0, 12.2]``.
        y_lim: Axis y-limits; defaults to ``[0, 1.08]``.
        x_ticks: X ticks; defaults to ``[]``.
        y_ticks: Y ticks; defaults to ``[]``.

    The same limits/ticks are forwarded to both axes, as in the original code.
    """
    # Fresh lists per call: list literals as defaults would be shared across
    # every frame (and across callers) if anything downstream mutated them.
    x_lim = [0, 12.2] if x_lim is None else x_lim
    y_lim = [0, 1.08] if y_lim is None else y_lim
    x_ticks = [] if x_ticks is None else x_ticks
    y_ticks = [] if y_ticks is None else y_ticks

    # The precision panel is optional. Every access to axes[1] is guarded, not
    # only the clear(), so a single-axes call no longer raises IndexError.
    has_precision_axis = len(axes) > 1

    axes[0].clear()
    if has_precision_axis:
        axes[1].clear()

    # Print the points: one scatter call per layer, colored by layer.
    for layer_num in range(data.shape[3]):
        axes[0].scatter(data[0, :, num, layer_num], data[1, :, num, layer_num], color=colors[layer_num], s=35,
                        edgecolors='black', alpha=0.85)
    if has_precision_axis:
        # NOTE(review): the slice [:num] excludes the current epoch shown in the
        # title (epochsInds[num]); confirm whether [:num + 1] was intended.
        axes[1].plot(epochsInds[:num], 1 - np.mean(test_error[:, :num], axis=0), color='r')

    title_str = 'Information Plane - Epoch number - ' + str(epochsInds[num])
    utils.adjustAxes(axes[0], axis_font, title_str, x_ticks, y_ticks, x_lim, y_lim, set_xlabel=True, set_ylabel=True,
                     x_label='$I(X;T)$', y_label='$I(T;Y)$')
    if has_precision_axis:
        title_str = 'Precision as function of the epochs'
        utils.adjustAxes(axes[1], axis_font, title_str, x_ticks, y_ticks, x_lim, y_lim, set_xlabel=True, set_ylabel=True,
                         x_label='# Epochs', y_label='Precision')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号