def _plot_value_function(self, value_function, n_iter):
    """Render the value function as a 3D surface and save it as an image.

    Each entry of ``self.states`` is a string of '#'-separated integers; the
    first two integers are used as the (row, column) position of that
    state's value in a ``(max_car + 1) x (max_car + 1)`` grid. The grid is
    drawn as a surface whose axes are labelled '#cars at A' and
    '#cars at B', and the figure is written to ``experiments/value<n_iter>``
    (no extension is given, so matplotlib chooses its default output format).

    Args:
        value_function: Sequence indexable by state id, giving the value of
            ``self.states[stateid]``.
        n_iter: Integer iteration number, used in the title and file name.
    """
    # Local import: only needed to make sure the output directory exists.
    import os

    grid_size = self.max_car + 1
    value_matrix = numpy.zeros((grid_size, grid_size), dtype='float')
    for stateid in range(len(self.states)):
        state = [int(t) for t in self.states[stateid].split('#')]
        value_matrix[state[0], state[1]] = value_function[stateid]

    fig = plt.figure()
    try:
        # Create the 3D axes through the figure. Constructing Axes3D(fig)
        # directly no longer attaches the axes to the figure on recent
        # matplotlib releases, which would produce an empty image.
        ax = fig.add_subplot(111, projection='3d')

        # meshgrid uses 'xy' indexing by default: X varies along columns and
        # Y along rows. value_matrix rows hold the first state component
        # (plotted on the axis labelled '#cars at A'), so Y is passed as the
        # x-coordinate and X as the y-coordinate.
        X, Y = numpy.meshgrid(range(grid_size), range(grid_size))
        ax.plot_surface(Y, X, value_matrix, rstride=1, cstride=1, cmap='coolwarm')
        ax.set_title('value function in iteration %i' % n_iter)
        ax.set_xlabel('#cars at A')
        ax.set_ylabel('#cars at B')
        ax.set_zlabel('value function')

        # savefig does not create missing directories.
        if not os.path.isdir('experiments'):
            os.makedirs('experiments')
        fig.savefig('experiments/value%i' % n_iter)
    finally:
        # This method is called once per iteration; without an explicit close
        # pyplot keeps every figure alive for the life of the process.
        plt.close(fig)
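# (web page residue, not part of the program: "Comment list")
# (web page residue, not part of the program: "Article table of contents")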