def plot2d(self, title=None, domain=(-1, 1), codomain=(-1, 1), predict=True,
           *, resolution=100):
    """Plot the sample graph in 2-D, optionally shading the model's predictions.

    Each row of ``self.X_`` becomes a node drawn at that row's coordinates and
    coloured by ``self.y_``; the graph itself is built from ``self.A_``.

    Parameters
    ----------
    title : str, optional
        Axes title; no title is set when ``None``.
    domain : pair of float
        ``(min, max)`` limits of the horizontal axis ``x_1``; also the range
        of the prediction grid along ``x_1``.
    codomain : pair of float
        ``(min, max)`` limits of the vertical axis ``x_2``; also the range
        of the prediction grid along ``x_2``.
    predict : bool
        When true, evaluate ``self.predict`` on a regular grid over
        ``domain x codomain`` and overlay it as filled contours. Graph edges
        are drawn only in this mode, and only when
        ``self.params['gamma_i'] != 0.0``.
    resolution : int, keyword-only
        Number of grid points per axis for the prediction grid (default 100,
        the previously hard-coded value).

    Returns
    -------
    tuple
        ``(f, ax)``: the matplotlib figure and axes that were drawn on.
    """
    f, ax = plt.subplots()

    # NetworkX 3.0 removed ``from_scipy_sparse_matrix``; its replacement
    # ``from_scipy_sparse_array`` exists since NetworkX 2.7. Use the new name
    # when available and fall back for older installs.
    # Assumes ``self.A_`` is a SciPy sparse adjacency matrix -- TODO confirm.
    from_sparse = getattr(nx, 'from_scipy_sparse_array', None)
    if from_sparse is None:
        from_sparse = nx.from_scipy_sparse_matrix
    G = from_sparse(self.A_)

    # Node i is placed at the coordinates of sample i
    # (presumably X_ has exactly two feature columns -- TODO confirm).
    pos = {i: self.X_[i] for i in range(self.X_.shape[0])}
    cm_sc = ListedColormap(['#AAAAAA', '#FF0000', '#0000FF'])

    if title is not None:
        ax.set_title(title)
    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')
    ax.set_xlim(domain)
    ax.set_ylim(codomain)

    nx.draw_networkx_nodes(G, pos, ax=ax, node_size=25, node_color=self.y_,
                           cmap=cm_sc)

    if predict:
        x1 = np.linspace(*domain, resolution)
        x2 = np.linspace(*codomain, resolution)
        xx1, xx2 = np.meshgrid(x1, x2)
        xfull = np.c_[xx1.ravel(), xx2.ravel()]
        # Fold the flat predictions back onto the grid (rows follow x2,
        # columns follow x1, matching meshgrid's default 'xy' indexing).
        z = np.asarray(self.predict(xfull)).reshape(xx1.shape)

        # Two shaded bands split at 0; values outside [-1, 1] stay unshaded.
        levels = np.array([-1, 0, 1])
        cm_cs = plt.cm.RdYlBu

        # Presumably gamma_i weights a graph-based regularization term, so
        # edges are only informative when it is non-zero.
        # NOTE(review): edges are drawn only when predict=True -- confirm
        # that this nesting is intended.
        if self.params['gamma_i'] != 0.0:
            nx.draw_networkx_edges(G, pos, ax=ax, edge_color='#AAAAAA')

        ax.contourf(xx1, xx2, z, levels, cmap=cm_cs, alpha=0.25)

    return (f, ax)
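# Scraped web-page residue, not part of the code: "评论列表" ("Comment list")
# and "文章目录" ("Table of contents").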