Drawing.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:options 作者: mcmachado 项目源码 文件源码
def plotPolicy(self, policy, prefix):
        """Draw a policy on the grid and save it as a PNG image.

        Each state index of ``policy`` is mapped to a grid cell through
        ``self.env.getStateXY``. Cells whose ``self.matrixMDP`` entry is -1
        are painted as gray squares; every other cell gets an arrow whose
        direction depends on the action stored for that state:

            0 -> up, 1 -> right, 2 -> down, 3 -> left,
            4 -> termination (a small black circle is added).

        The figure is written to ``self.outputPath + prefix + 'policy.png'``.

        Args:
            policy: indexable sequence holding one action per state index.
            prefix: string inserted between the output path and the
                'policy.png' suffix of the saved file name.
        """
        plt.clf()
        # range() iterates the same way on Python 2 and Python 3, whereas
        # xrange() only exists on Python 2.
        for idx in range(len(policy)):
            # i is used as the row index and j as the column index below.
            i, j = self.env.getStateXY(idx)
            action = policy[idx]
            isFreeCell = self.matrixMDP[i][j] != -1

            # Centre of the cell in plot coordinates; rows are flipped so
            # that row 0 ends up at the top of the figure.
            cx = j + 0.5
            cy = self.numRows - i + 0.5 - 1

            dx = 0
            dy = 0
            if action == 0:  # up
                dy = 0.35
            elif action == 1:  # right
                dx = 0.35
            elif action == 2:  # down
                dy = -0.35
            elif action == 3:  # left
                dx = -0.35
            elif isFreeCell and action == 4:  # termination
                circle = plt.Circle((cx, cy), 0.025, color='k')
                plt.gca().add_artist(circle)

            if isFreeCell:
                # NOTE(review): for action 4 (or any other unrecognised
                # value) dx and dy are both 0 and the arrow is still drawn
                # with zero displacement -- confirm this is intended.
                plt.arrow(cx, cy, dx, dy,
                    head_width=0.05, head_length=0.05, fc='k', ec='k')
            else:
                plt.gca().add_patch(
                    patches.Rectangle(
                    (j, self.numRows - i - 1), # (x,y)
                    1.0,                   # width
                    1.0,                   # height
                    facecolor = "gray"
                    )
                )

        plt.xlim([0, self.numCols])
        plt.ylim([0, self.numRows])

        # Dotted grid lines on every cell boundary, outer borders included.
        for x in range(self.numCols + 1):
            plt.axvline(x, color='k', linestyle=':')

        for y in range(self.numRows + 1):
            plt.axhline(y, color='k', linestyle=':')

        plt.savefig(self.outputPath + prefix + 'policy.png')
        plt.close()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号