def createQNetwork(self, player):
    """Build the Q-network graph for `player` and start a TF session.

    Constructs input placeholders, a 3-layer MLP
    (STATE_NUM -> 256 -> 512 -> ACTION_NUM), the MSE training op on the
    TD targets, then creates a Saver/InteractiveSession and restores
    weights from 'saved_QNetworks_new_<player>/' when a checkpoint
    exists, otherwise initializes all variables.

    Args:
        player (str): identifier used to select the checkpoint directory.
    """
    # input placeholders: batch of states, one-hot actions, TD targets
    self.stateInput = tf.placeholder(dtype=tf.float32, shape=[None, self.STATE_NUM])
    self.actionInput = tf.placeholder(dtype=tf.float32, shape=[None, self.ACTION_NUM])
    self.yInput = tf.placeholder(dtype=tf.float32, shape=[None])
    # weights and biases (helpers defined elsewhere on this class)
    W1 = self.weight_variable([self.STATE_NUM, 256])
    b1 = self.bias_variable([256])
    W2 = self.weight_variable([256, 512])
    b2 = self.bias_variable([512])
    W3 = self.weight_variable([512, self.ACTION_NUM])
    b3 = self.bias_variable([self.ACTION_NUM])
    # hidden layers
    h_layer1 = tf.nn.relu(tf.nn.bias_add(tf.matmul(self.stateInput, W1), b1))
    # h_layer1 = self.batch_norm(h_layer1)
    h_layer2 = tf.nn.relu(tf.nn.bias_add(tf.matmul(h_layer1, W2), b2))
    # h_layer2 = self.batch_norm(h_layer2)
    # Q output: raw (unbounded) action values.
    # FIX(review): the original applied tf.nn.softmax here, which forces the
    # outputs onto the probability simplex; an MSE regression against
    # arbitrary TD targets (yInput) can then never fit values outside [0, 1].
    # The softmax has been removed so QValue is a true action-value estimate.
    self.QValue = tf.nn.bias_add(tf.matmul(h_layer2, W3), b3)
    # Q(s, a) for the one-hot action actually taken
    # (axis=1 replaces the deprecated reduction_indices=-1; equivalent for 2-D)
    Q_action = tf.reduce_sum(tf.multiply(self.QValue, self.actionInput), axis=1)
    # mean-squared TD error
    self.cost = tf.reduce_mean(tf.square(self.yInput - Q_action))
    self.trainStep = tf.train.GradientDescentOptimizer(1e-6).minimize(self.cost)
    # saving and loading networks
    self.saver = tf.train.Saver()
    self.session = tf.InteractiveSession()
    checkpoint = tf.train.get_checkpoint_state('saved_QNetworks_new_' + player + '/')
    if checkpoint and checkpoint.model_checkpoint_path:
        self.saver.restore(self.session, checkpoint.model_checkpoint_path)
        print("Successfully loaded:", checkpoint.model_checkpoint_path)
    else:
        print("Could not find old network weights")
        # tf.initialize_all_variables() is deprecated; this is the
        # behavior-identical replacement.
        self.session.run(tf.global_variables_initializer())
# (removed web-scrape residue: "评论列表" / "文章目录" — blog-page navigation
#  text accidentally pasted into the source; not valid Python)