qlearning.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:malmo-challenge 作者: Kaixhin 项目源码 文件源码
def __init__(self, model, target, device_id=-1,
                 learning_rate=0.00025, momentum=.9,
                 minibatch_size=32, update_interval=10000):
        """Set up an online Q-network, its target network and an Adam optimizer.

        Args:
            model: Online Q-network; must be an instance of ``ChainerModel``.
                Its ``input_shape`` and ``output_shape`` are forwarded to the
                parent constructor, and it is the link the optimizer trains.
            target: Target network. Its parameters are overwritten with a
                copy of ``model``'s parameters via ``copyparams``.
            device_id: GPU device index. When ``>= 0`` both networks are
                moved to that device and the device object is kept in
                ``self._gpu_device``; a negative value skips the move and
                leaves ``self._gpu_device`` as ``None``.
            learning_rate: Adam step size (passed as ``alpha``).
            momentum: Passed to Adam as ``beta1`` (first-moment decay rate).
            minibatch_size: Stored as ``self._minibatch_size``.
            update_interval: Stored as ``self._target_update_interval``;
                presumably compared against ``self._steps`` elsewhere to
                decide when to refresh the target -- confirm in the caller.

        Raises:
            TypeError: If ``model`` is not a ``ChainerModel``.
        """
        # Explicit raise instead of ``assert`` so the check is not stripped
        # when Python runs with -O.
        if not isinstance(model, ChainerModel):
            raise TypeError('model should inherit from ChainerModel')

        super(QNeuralNetwork, self).__init__(model.input_shape,
                                             model.output_shape)

        self._gpu_device = None
        self._loss_val = 0

        # Bookkeeping for periodic target-network updates.
        self._steps = 0
        self._target_update_interval = update_interval

        # Online and target networks; the target starts as a parameter copy
        # of the online model.
        self._minibatch_size = minibatch_size
        self._model = model
        self._target = target
        self._target.copyparams(self._model)

        # Move both networks to the requested GPU, if any.
        if device_id >= 0:
            with cuda.get_device(device_id) as device:
                self._gpu_device = device
                self._model.to_gpu(device)
                self._target.to_gpu(device)

        # Keyword arguments make explicit that ``momentum`` is Adam's beta1
        # rather than relying on positional order.
        self._optimizer = Adam(alpha=learning_rate, beta1=momentum,
                               beta2=0.999)
        self._optimizer.setup(self._model)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号