plot_figures.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:IDNNs 作者: ravidziv 项目源码 文件源码
def plot_animation(name_s, save_name, train_data=None, test_data=None):
    """Render a movie of the networks moving through the information plane.

    The information values loaded for ``name_s[0][0]`` exist only at the
    sampled epochs ``epochsInds``. They are linearly interpolated onto every
    integer epoch from the first sample through the last one, so the movie
    advances one epoch per frame. Each frame is drawn by ``update_line`` on a
    two-row figure. The movie is saved with the ffmpeg writer and the figure
    is then shown.

    Args:
        name_s: Nested sequence of run names. Only ``name_s[0][0]`` is used;
            it is passed to ``utils.get_data``.
        save_name: Output prefix. The movie is written to
            ``save_name + '_movie2.mp4'``.
        train_data: Optional train curve given at the sampled epochs. After
            ``np.squeeze`` its axis 1 must have length ``len(epochsInds)``.
            Defaults to the loaded ``'train_error'`` entry (assumed key).
        test_data: Same as ``train_data`` for the test curve. Defaults to the
            loaded ``'test_error'`` entry (assumed key).

    Returns:
        None.
    """
    # Set to True to also interpolate the loss curves handed to update_line.
    print_loss = False
    # Successive x-axis limits that the accuracy panel is extended to.
    epochs_bins = [0, 500, 1500, 3000, 6000, 10000, 20000]

    data_array = utils.get_data(name_s[0][0])
    # NOTE(review): 'infomration' and 'loss_test_data' do not follow the
    # spelling/naming of the neighbouring keys (e.g. 'loss_train'); confirm
    # they match the keys actually written by the training code.
    data = data_array['infomration']
    epochsInds = data_array['epochsInds']
    loss_train_data = data_array['loss_train']
    loss_test_data = data_array['loss_test_data']
    # update_line receives train/test curves via fargs below. Without an
    # assignment here that call raises NameError, so fall back to the stored
    # error curves when the caller does not supply them.
    # NOTE(review): the 'train_error' / 'test_error' key names are assumed
    # and should be confirmed against the saved data format.
    if train_data is None:
        train_data = data_array['train_error']
    if test_data is None:
        test_data = data_array['test_error']

    fig, axes = plt.subplots(2, 1)
    fig.subplots_adjust(left=0.14, bottom=0.1, right=.928, top=0.94, wspace=0.13, hspace=0.55)

    # Take index -1 on data axes 2 and 3 and squeeze away singleton axes. The
    # result must be 2-D; the re-added leading axis makes axis 1 the
    # sampled-epoch axis that the interpolation below works along.
    Ix = np.squeeze(data[0, :, -1, -1, :, :])[np.newaxis, :, :]
    Iy = np.squeeze(data[1, :, -1, -1, :, :])[np.newaxis, :, :]

    # Information is not computed at every epoch, so interpolate onto every
    # epoch from the first sample to the last one inclusive. Starting at
    # epochsInds[0] keeps every point inside interp1d's range (it raises on
    # points outside it by default), and the +1 keeps the final sampled epoch.
    new_x = np.arange(epochsInds[0], epochsInds[-1] + 1)

    def resample(values):
        # Linear interpolation along axis 1 (the sampled epochs) onto new_x.
        return interp1d(epochsInds, values, axis=1)(new_x)

    new_data = np.array([resample(Ix), resample(Iy)])
    # NOTE(review): assumes the error curves are stored at the sampled epochs
    # (epoch on axis 1 after squeeze), like the information values.
    train_data = resample(np.squeeze(train_data))
    test_data = resample(np.squeeze(test_data))
    if print_loss:
        loss_train_data = resample(np.squeeze(loss_train_data))
        loss_test_data = resample(np.squeeze(loss_test_data))

    # One frame per interpolated epoch; the fargs order is what update_line
    # is called with after the frame number.
    line_ani = animation.FuncAnimation(
        fig, update_line, len(new_x), repeat=False, interval=1, blit=False,
        fargs=(print_loss, new_data, axes, new_x, train_data, test_data,
               epochs_bins, loss_train_data, loss_test_data, LAYERS_COLORS))
    writer = animation.writers['ffmpeg'](fps=100)
    # Save the movie, then display the figure.
    line_ani.save(save_name + '_movie2.mp4', writer=writer, dpi=250)
    plt.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号