multiclass.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:TextCategorization 作者: Y-oHr-N 项目源码 文件源码
def plot2d(self, title=None, domain=(-1, 1), codomain=(-1, 1), predict=True):
    """Plot the fitted samples as graph nodes and, optionally, the predictions.

    The samples in ``self.X_`` are drawn as the nodes of the graph built
    from the sparse matrix ``self.A_``, coloured by ``self.y_``.  When
    ``predict`` is true, ``self.predict`` is evaluated on a regular grid
    spanning ``domain`` x ``codomain`` and shown as filled contours.

    Parameters
    ----------
    title : str or None
        Axes title; no title is set when ``None``.
    domain : sequence of two numbers
        Lower and upper bound of the horizontal axis (also the grid range).
    codomain : sequence of two numbers
        Lower and upper bound of the vertical axis (also the grid range).
    predict : bool
        Whether to evaluate ``self.predict`` on the grid and draw the
        result (and, if ``self.params['gamma_i']`` is non-zero, the edges).

    Returns
    -------
    tuple
        The ``(figure, axes)`` pair created by ``plt.subplots()``.
    """
    # Number of grid points per axis for the prediction surface.
    n_grid = 100

    f, ax = plt.subplots()

    # Unpacking keeps the implicit requirement that X_ is two-dimensional.
    n_samples, _ = self.X_.shape

    # networkx 3.0 removed ``from_scipy_sparse_matrix``; its replacement
    # ``from_scipy_sparse_array`` exists since 2.7.  Prefer the new name and
    # fall back to the old one so both old and new releases keep working.
    build_graph = getattr(nx, 'from_scipy_sparse_array', None)
    if build_graph is None:
        build_graph = nx.from_scipy_sparse_matrix
    G = build_graph(self.A_)

    # Each node is placed at its sample's coordinates.
    # NOTE(review): presumably X_ has exactly two columns -- confirm.
    pos = {i: self.X_[i] for i in range(n_samples)}
    cm_sc = ListedColormap(['#AAAAAA', '#FF0000', '#0000FF'])

    if title is not None:
        ax.set_title(title)

    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')
    ax.set_xlim(domain)
    ax.set_ylim(codomain)

    nx.draw_networkx_nodes(
        G, pos, ax=ax, node_size=25, node_color=self.y_, cmap=cm_sc
    )

    if predict:
        # The grid is only needed for the prediction surface, so build it here.
        x1 = np.linspace(*domain, n_grid)
        x2 = np.linspace(*codomain, n_grid)
        xx1, xx2 = np.meshgrid(x1, x2)
        xfull = np.c_[xx1.ravel(), xx2.ravel()]
        z = self.predict(xfull).reshape(xx1.shape)

        levels = np.array([-1, 0, 1])
        cm_cs = plt.cm.RdYlBu

        # NOTE(review): edges are drawn only inside the ``predict`` branch,
        # as in the original code -- confirm this coupling is intended.
        if self.params['gamma_i'] != 0.0:
            nx.draw_networkx_edges(G, pos, ax=ax, edge_color='#AAAAAA')

        ax.contourf(xx1, xx2, z, levels, cmap=cm_cs, alpha=0.25)

    return (f, ax)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号