def getAction(self, action_space, state):
    """Choose an action index for ``state`` among the entries allowed by ``action_space``.

    The network output for ``state`` is multiplied element-wise by
    ``action_space``, so every entry whose mask value is 0 contributes 0.
    If the largest masked value is not meaningfully positive, a uniformly
    random index whose mask entry equals 1 is returned; otherwise the index
    of the largest masked value is returned.

    Side effects: sets ``self.train_phase`` to ``False`` and stores the raw
    (unmasked) network output in ``self.QValue``.

    Args:
        action_space: Per-action mask with one entry per action. Entries
            equal to 1 are treated as selectable by the random fallback
            (presumably a 0/1 legality mask -- confirm against callers).
        state: Observation to evaluate. A trailing axis is appended before
            it is fed to the network.

    Returns:
        int: Index of the chosen action.

    Raises:
        ValueError: If the random fallback is needed but none of the first
            ``self.ACTION_NUM`` mask entries equals 1. The previous
            rejection-sampling loop never terminated in that situation.
    """
    # Masked values at or below this threshold are treated as "no usable
    # estimate" and trigger the random fallback.
    min_useful_q = 0.0000001

    self.train_phase = False

    state = np.expand_dims(state, -1)
    # keep_probability is fed as 1.0 (looks like a dropout keep rate, i.e.
    # dropout disabled for inference -- confirm against the graph definition).
    self.QValue = self.session.run(
        self.out,
        feed_dict={self.stateInput: [state], self.keep_probability: 1.0},
    )[0]

    mask = np.asarray(action_space)
    masked_q = self.QValue * mask

    if np.max(masked_q) <= min_useful_q:
        # Sample directly from the selectable indices instead of retrying
        # random draws: same uniform distribution, but guaranteed to finish.
        # Only the first ACTION_NUM entries are considered, matching the
        # range the original random draw used.
        selectable = np.flatnonzero(mask[:self.ACTION_NUM] == 1).tolist()
        if not selectable:
            raise ValueError("action_space contains no selectable action (no entry equal to 1)")
        return random.choice(selectable)

    # A value above the threshold can only come from a non-zero mask entry,
    # so the arg-max of the masked values is never a masked-out action.
    return int(np.argmax(masked_q))
评论列表
文章目录