dqn_keras.py source file

python

Project: Attention-DQN  Author: chasewind007
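# Imports this excerpt relies on (Keras 2.x on the TensorFlow 1.x backend).
# mean_huber_loss is defined elsewhere in the project and is not shown here.
import tensorflow as tf
from keras.layers import Input, Lambda, Multiply
from keras.models import Model
from keras.optimizers import Adam


# Method of the DQN agent class; the class body is not part of this excerpt.
def compile(self, optimizer=None, loss_func=None):
    """Set up all of the TF graph variables/ops.

    This is inspired by the compile method on the
    keras.models.Model class.

    This is the place to create the target network, set up the
    loss function and any placeholders.
    """
    if loss_func is None:
        loss_func = mean_huber_loss
        # loss_func = 'mse'
    if optimizer is None:
        optimizer = Adam(lr=self.learning_rate)
        # optimizer = RMSprop(lr=0.00025)
    with tf.variable_scope("Loss"):
        state = Input(shape=(self.frame_height, self.frame_width, self.num_frames), name="states")
        # One-hot mask that selects the action actually taken in each sample.
        action_mask = Input(shape=(self.num_actions,), name="actions")
        qa_value = self.q_network(state)
        # Multiply replaces the Keras 1 call merge(..., mode='mul'), which
        # was removed from later Keras 2 releases.
        qa_value = Multiply(name="multiply")([qa_value, action_mask])
        # Sum over the action axis so only Q(s, a) of the chosen action remains.
        qa_value = Lambda(lambda x: tf.reduce_sum(x, axis=1, keep_dims=True), name="sum")(qa_value)

    self.final_model = Model(inputs=[state, action_mask], outputs=qa_value)
    self.final_model.compile(loss=loss_func, optimizer=optimizer)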
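
The "multiply" and "sum" layers above reduce the network's per-action output to the single Q-value of the action that was taken, and that value is what the loss compares against the target. The sketch below reproduces that arithmetic in plain Python so it can be checked by hand. The names select_q_values and mean_huber are hypothetical, and mean_huber is only an assumption about what the project's mean_huber_loss might compute (a standard Huber loss with delta = 1); the real definition is not shown in this excerpt.

```python
def select_q_values(q_values, action_mask):
    """Return Q(s, a) for the action marked in each one-hot mask row."""
    # Multiplying by the mask zeroes every action except the chosen one;
    # summing across the row then leaves that single Q-value.
    return [sum(q * m for q, m in zip(q_row, m_row))
            for q_row, m_row in zip(q_values, action_mask)]


def mean_huber(y_true, y_pred, delta=1.0):
    """Mean Huber loss (assumed form): quadratic near zero, linear beyond delta."""
    losses = []
    for t, p in zip(y_true, y_pred):
        error = abs(t - p)
        if error <= delta:
            losses.append(0.5 * error ** 2)
        else:
            losses.append(delta * (error - 0.5 * delta))
    return sum(losses) / len(losses)


# Hypothetical batch: two states, three actions each.
q_values = [[1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0]]
action_mask = [[0.0, 1.0, 0.0],   # action 1 was taken in the first state
               [0.0, 0.0, 1.0]]   # action 2 was taken in the second state

selected = select_q_values(q_values, action_mask)
targets = [2.5, 3.0]              # made-up target values for illustration
loss = mean_huber(targets, selected)
```

Because the mask zeroes the outputs of the other actions, only the chosen action's Q-value contributes to the loss, so in the Keras model only that output receives a gradient for a given sample.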