AveragePolicyNetwork.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:RL_NFSP 作者: Richard-An 项目源码 文件源码
def getAction(self, action_space, state):
        """Return the index of the action to play in ``state``.

        The network output for ``state`` is fetched from ``self.out`` and
        stored on ``self.QValue``. It is multiplied element-wise by
        ``action_space`` so that only entries marked legal keep their value.
        If the best masked value exceeds a small threshold its index is
        returned; otherwise a legal action is drawn uniformly at random.

        No checkpoint is restored here; whatever weights ``self.session``
        currently holds are used.

        Args:
            action_space: Sequence indexed by action; an entry equal to ``1``
                marks that action as legal. Presumably a 0/1 mask with
                ``self.ACTION_NUM`` entries -- confirm against the caller.
            state: Observation to evaluate. It gets a trailing axis added and
                is fed as a batch of one.

        Returns:
            The chosen action index.

        Raises:
            ValueError: If the random fallback is needed but ``action_space``
                marks no action as legal (the previous rejection-sampling
                loop would spin forever in that case).
        """
        self.train_phase = False

        # Add a trailing axis, then wrap in a list so the feed is a batch of one.
        state = np.expand_dims(state, -1)
        self.QValue = self.session.run(
            self.out,
            feed_dict={self.stateInput: [state], self.keep_probability: 1.0},
        )[0]

        # Entries with a zero mask value become 0 and cannot win the argmax
        # once the threshold test below has passed.
        masked_values = self.QValue * action_space
        if max(masked_values) > 0.0000001:
            return np.argmax(masked_values)

        # No legal action has a usable score: fall back to a uniform draw over
        # the legal indices instead of retrying random indices indefinitely.
        legal_actions = [
            index for index in range(self.ACTION_NUM) if action_space[index] == 1
        ]
        if not legal_actions:
            raise ValueError("action_space contains no legal action")
        return random.choice(legal_actions)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号