processData.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:anime_recs 作者: Cpierse 项目源码 文件源码
def plot_top_sims(cos_sim_mat, aids, N_recs, aid_counts, aid_dict,
                  img_name='results\\Sample_similarities.png', threshold=0,
                  main_name=None, var_name='Sim',
                  plot_histograms=True):
    ''' Plot query anime next to the anime with the highest similarity scores.

    One grid row is drawn per query id in ``aids``. Column 0 holds the query's
    image and columns 1..N_recs hold the results returned by
    ``get_highest_cos``. Each result is titled with its score, e.g.
    ``'Sim = 0.93'``. Thin horizontal lines separate the rows. Optionally,
    a histogram of each query's row of ``cos_sim_mat`` is shown afterwards.

    Parameters
    ----------
    cos_sim_mat : 2-D array-like
        Similarity matrix, indexed as ``cos_sim_mat[aid, :]`` for histograms.
        Presumably rows are indexed by anime id -- TODO confirm.
    aids : sequence
        Query ids; each one becomes a grid row.
    N_recs : int
        Number of result columns per row.
    aid_counts, threshold :
        Passed through unchanged to ``get_highest_cos``.
    aid_dict : mapping
        Looked up with each id; the value is passed to ``get_image``.
    img_name : str
        Currently unused because saving is disabled. Kept for API
        compatibility.
    main_name : str or None
        Optional figure super-title.
    var_name : str
        Score label used in each result's title.
    plot_histograms : bool
        If true, show one histogram figure per query after the grid.

    Returns
    -------
    None
    '''
    n_rows = len(aids)
    if n_rows == 0:
        # plt.subplots(0, ...) raises; there is nothing to plot.
        return None
    # figsize replaces the old separate plt.figure(...) call, which only
    # produced an extra blank figure.
    # squeeze=False keeps axarr 2-D even with a single query, so
    # axarr[row, col] indexing always works.
    f, axarr = plt.subplots(n_rows, N_recs + 1, figsize=(6, 10), squeeze=False)
    f.tight_layout()
    # Hide every axis up front so cells left empty (no image, or fewer than
    # N_recs results) don't show bare tick marks.
    for ax in axarr.flat:
        ax.axis("off")
    for (idx_0, aid) in enumerate(aids):
        image = get_image(aid_dict[aid])
        # `is not None`: `!= None` is element-wise (and ambiguous in `if`)
        # when the image is a NumPy array.
        if image is not None:
            axarr[idx_0, 0].imshow(image)
        axarr[idx_0, 0].set_title('Query ' + str(idx_0+1), size=10)
        top_aids, top_sims = get_highest_cos(cos_sim_mat, aid, aid_dict, N_recs,
                                             aid_counts, threshold)
        for (idx_1, (aid_1, sim)) in enumerate(zip(top_aids, top_sims)):
            if idx_1 >= N_recs:
                # Only N_recs result columns exist in the grid.
                break
            image = get_image(aid_dict[aid_1])
            if image is not None:
                axarr[idx_0, idx_1+1].imshow(image)
            axarr[idx_0, idx_1+1].set_title(var_name + ' = {:.2f}'.format(sim), size=10)
        # Add a horizontal separator above every row except the first.
        if idx_0 > 0:
            y = 1 - 1.0/n_rows*idx_0*0.98
            line = lines.Line2D((0, 1), (y, y), transform=f.transFigure, color=[0, 0, 0])
            f.lines.append(line)
            # TODO: the 0.98 shouldn't be necessary. Fixit.
    if main_name:
        plt.suptitle(main_name)
    # Saving is disabled; to enable it:
    # f.savefig(img_name, dpi=300, format='png')
    plt.show()
    # Plot a histogram of each query's similarities.
    if plot_histograms:
        for (idx_0, aid) in enumerate(aids):
            # Start a fresh figure. If plt.show() returned without closing the
            # grid (interactive mode), plt.hist would otherwise draw onto the
            # grid's current image axes.
            plt.figure()
            plt.hist(cos_sim_mat[aid, :])
            plt.title('Query ' + str(idx_0+1))
            plt.xlabel(var_name)
            plt.show()
    return None
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号