data_visualization.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:Oedipus 作者: tum-i22 项目源码 文件源码
def plotReductionGraph(dataSamples, dataLabels, classNames, dimension=2, graphTitle="Test Graph", filename="reduction.pdf"):
    """Plot data samples as a per-class scatter graph and save it to a file.

    Row ``i`` of ``dataSamples`` belongs to class ``c`` when
    ``dataLabels[i] == c``, where ``c`` is a position in ``classNames``.
    Rows whose label matches no class position are not drawn. Each class is
    drawn with its own color and marker (cycled when there are more classes
    than styles) and gets a legend entry named ``classNames[c]``.

    :param dataSamples: 2-D array indexed as ``dataSamples[i, :]``; only the
        first ``dimension`` columns of each row are plotted (presumably a
        numpy ndarray with one row per sample -- TODO confirm with callers).
    :param dataLabels: sequence of class indices, one per row of ``dataSamples``.
    :param classNames: class names; entry ``c`` labels class ``c`` in the legend.
    :param dimension: 2 for a planar plot, 3 for a 3-D scatter plot.
    :param graphTitle: title drawn above the graph.
    :param filename: output file name, saved as ``./<filename>``.
    :return: True if the graph was saved, False if any error occurred (the
        error is reported through ``prettyPrint`` instead of being raised).
    """
    colors = ['DarkRed', 'DarkGreen', 'DarkBlue', 'DarkOrange', 'DarkMagenta', 'DarkCyan', 'Gray', 'Black']
    markers = ['*', 'o', 'v', '^', 's', 'd', 'D', 'p', 'h', 'H', '<', '>', '.', ',', '|', '_']
    fig = None
    try:
        # Any other value used to fall into the 3-D branch with no axes object;
        # reject it up front so the logged error says what is actually wrong.
        if dimension not in (2, 3):
            raise ValueError("dimension must be 2 or 3, got %r" % (dimension,))

        # Group coordinates per class in a single pass over the samples.
        # Keys are the class positions 0..len(classNames)-1; the dict lookup
        # matches labels by equality/hash, so 1, 1.0 and numpy integers all hit key 1.
        coordsByClass = {c: [[] for _ in range(dimension)] for c in range(len(classNames))}
        for sampleIndex, label in enumerate(dataLabels):
            coords = coordsByClass.get(label)
            if coords is None:
                continue  # label does not correspond to any entry of classNames
            # Convert the row once instead of once per coordinate.
            row = dataSamples[sampleIndex, :].tolist()
            for axisIndex in range(dimension):
                coords[axisIndex].append(row[axisIndex])

        fig = P.figure(figsize=(8, 5))
        if dimension == 3:
            ax = fig.add_subplot(111, projection='3d')
        else:
            ax = fig.add_subplot(111)
        ax.set_title(graphTitle, fontname='monospace')
        ax.set_xlabel('x1', fontsize=12, fontname='monospace')
        ax.set_ylabel('x2', fontsize=12, fontname='monospace')
        if dimension == 3:
            ax.set_zlabel('x3', fontsize=12, fontname='monospace')
        ax.grid(color='DarkGray', linestyle='--', linewidth=0.1, axis='both')

        handles = []
        for c, name in enumerate(classNames):
            coords = coordsByClass[c]
            color = colors[c % len(colors)]
            marker = markers[c % len(markers)]
            if dimension == 2:
                # x1 goes on the horizontal axis and x2 on the vertical one,
                # matching the axis labels and the 3-D branch (previously swapped).
                handle, = ax.plot(coords[0], coords[1], color=color, marker=marker, markersize=5.0, linestyle='None', label=name)
            else:
                handle = ax.scatter(coords[0], coords[1], coords[2], c=color, marker=marker, label=name)
            handles.append(handle)

        # Pair each artist with its class name explicitly instead of relying on artist order.
        ax.legend(handles, list(classNames), fontsize='xx-small', numpoints=1, fancybox=True)

        prettyPrint("Saving results to ./%s" % filename)
        fig.tight_layout()
        fig.savefig("./%s" % filename)

    except Exception as e:
        # Best-effort contract: report the failure and signal it via the return value.
        prettyPrint("Error encountered in \"plotReductionGraph\": %s" % e, "error")
        return False
    finally:
        # Close the figure so repeated calls do not accumulate open figures
        # (assumes P is matplotlib.pyplot / pylab, which provides close()).
        if fig is not None:
            P.close(fig)

    return True
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号