def compile(self, optimizer=None, loss_func=None):
    """Set up the TF graph ops for training and compile the Keras model.

    Inspired by ``keras.models.Model.compile``. Builds a two-input model:
    ``states`` (shape ``(frame_height, frame_width, num_frames)``) and
    ``actions`` (shape ``(num_actions,)``). The output of ``self.q_network``
    on ``states`` is multiplied element-wise by ``actions``. The product is
    summed over the action axis, giving one value per sample, shape
    ``(batch, 1)``. The resulting model is stored in ``self.final_model``
    and compiled with the given loss and optimizer.

    No target network is created here; only ``self.q_network`` is wired in.

    Args:
        optimizer: Keras optimizer instance. Defaults to
            ``Adam(lr=self.learning_rate)`` (``RMSprop(lr=0.00025)`` is a
            possible alternative).
        loss_func: Keras loss, either a callable or a string name such as
            ``'mse'``. Defaults to ``mean_huber_loss``.
    """
    if loss_func is None:
        loss_func = mean_huber_loss
    if optimizer is None:
        # ``lr`` (not ``learning_rate``) keeps compatibility with older Keras.
        optimizer = Adam(lr=self.learning_rate)

    with tf.variable_scope("Loss"):
        state = Input(
            shape=(self.frame_height, self.frame_width, self.num_frames),
            name="states",
        )
        # NOTE(review): presumably a one-hot mask of the action taken, so the
        # masked sum below selects Q(s, a) -- confirm against the caller.
        action_mask = Input(shape=(self.num_actions,), name="actions")
        q_values = self.q_network(state)
        # Mask-and-sum in a single Lambda. This replaces the legacy Keras-1
        # ``merge(mode='mul')``, which is deprecated in Keras 2 and removed in
        # Keras 2.2. ``output_shape=(1,)`` excludes the batch dimension and
        # matches ``keep_dims=True``. ``keep_dims`` is the spelling accepted
        # by every TF 1.x release.
        selected_q = Lambda(
            lambda tensors: tf.reduce_sum(
                tensors[0] * tensors[1], axis=1, keep_dims=True
            ),
            output_shape=(1,),
            name="selected_q_value",
        )([q_values, action_mask])

    self.final_model = Model(inputs=[state, action_mask], outputs=selected_q)
    self.final_model.compile(loss=loss_func, optimizer=optimizer)
评论列表
文章目录