def plotValueFunction(self, valueFunction, prefix):
    '''Render a value function as a 3d surface plot and save it as a PNG.

    The values are laid out on a ``self.numRows`` x ``self.numCols`` grid
    and drawn with the 'jet' colormap, viewed from elev=30, azim=30. The
    image is written to ``self.outputPath + prefix + 'value_function.png'``
    (plain string concatenation, so ``self.outputPath`` must already end
    with a path separator if it names a directory).

    :param valueFunction: array-like with ``numRows * numCols`` entries,
        reshaped in row-major (C) order. Presumably one value per grid
        state -- TODO confirm the state-index ordering against the caller.
    :param prefix: string inserted between ``self.outputPath`` and the
        fixed file name.
    '''
    # Build the plot data before opening a figure, so a bad input (wrong
    # size) fails without leaving an open figure behind.
    # Row-major reshape: entry k lands at (row k // numCols, col k % numCols).
    Z = np.asarray(valueFunction).reshape(self.numRows, self.numCols)
    X, Y = np.meshgrid(np.arange(self.numCols), np.arange(self.numRows))
    # Mirror the column coordinates left-to-right (identical to the old
    # element-swapping loop). Presumably this makes the rendered surface
    # match the grid's on-screen orientation from this viewpoint -- confirm.
    X = np.fliplr(X)

    fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
    try:
        ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
                        cmap=plt.get_cmap('jet'))
        ax.view_init(elev=30, azim=30)
        fig.savefig(self.outputPath + prefix + 'value_function.png')
    finally:
        # Always release this figure, even if drawing or saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
# (Scraped web-page residue: "comment list" and "table of contents" labels.)