ddqn_agent.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:doubleDQN 作者: masataka46 项目源码 文件源码
def __init__(self, enable_controller=None):
        """Build the DDQN networks, the optimizer and the replay memory.

        Args:
            enable_controller: Sequence of action codes the agent may choose
                from.  When omitted, ``[0, 1, 3, 4]`` is used (the "Breakout"
                setting).

        Reads the ``args`` object from the enclosing scope for the optional
        checkpoint paths ``resumemodel``, ``resumeD1`` and ``resumeD2``, and
        reads ``self.data_size`` for the replay-memory capacity (it is not set
        in this method -- presumably a class attribute; TODO confirm).
        """
        # Create the default per call: a list literal in the signature would be
        # a single object shared by every agent constructed with the default.
        if enable_controller is None:
            enable_controller = [0, 1, 3, 4]  # Default setting : "Breakout"
        self.num_of_actions = len(enable_controller)
        self.enable_controller = enable_controller

        # print("...") with a single string prints the same text under the
        # Python 2 print statement and the Python 3 print function.
        print("Initializing DDQN...")
        # NOTE: Chainer 1.1.0 or older additionally required cuda.init() here.

        print("Model Building")
        self.model = FunctionSet(
            l1=F.Convolution2D(4, 32, ksize=8, stride=4, nobias=False, wscale=np.sqrt(2)),
            l2=F.Convolution2D(32, 64, ksize=4, stride=2, nobias=False, wscale=np.sqrt(2)),
            l3=F.Convolution2D(64, 64, ksize=3, stride=1, nobias=False, wscale=np.sqrt(2)),
            l4=F.Linear(3136, 512, wscale=np.sqrt(2)),
            # The output layer starts from all-zero weights.
            q_value=F.Linear(512, self.num_of_actions,
                             initialW=np.zeros((self.num_of_actions, 512),
                                               dtype=np.float32))
        ).to_gpu()

        if args.resumemodel:
            # load saved model
            serializers.load_npz(args.resumemodel, self.model)
            print("load model from resume.model")

        # The target network is copied after the optional resume so that it
        # starts out identical to the (possibly restored) online network.
        self.model_target = copy.deepcopy(self.model)

        print("Initizlizing Optimizer")
        self.optimizer = optimizers.RMSpropGraves(lr=0.00025, alpha=0.95, momentum=0.95, eps=0.0001)
        self.optimizer.setup(self.model.collect_parameters())

        # History Data :  D=[s, a, r, s_dash, end_episode_flag]
        if args.resumeD1 and args.resumeD2:
            # load saved D1 and D2; the first archive holds D0-D2 and the
            # second holds D3-D4.  try/finally guarantees both archives are
            # closed even when a key is missing or the second load fails.
            npz_tmp1 = np.load(args.resumeD1)
            try:
                npz_tmp2 = np.load(args.resumeD2)
                try:
                    self.D = [npz_tmp1['D0'],
                              npz_tmp1['D1'],
                              npz_tmp1['D2'],
                              npz_tmp2['D3'],
                              npz_tmp2['D4']]
                finally:
                    npz_tmp2.close()
            finally:
                npz_tmp1.close()
            print("loaded stored D1 and D2")
        else:
            self.D = [np.zeros((self.data_size, 4, 84, 84), dtype=np.uint8),
                      np.zeros(self.data_size, dtype=np.uint8),
                      np.zeros((self.data_size, 1), dtype=np.int8),
                      np.zeros((self.data_size, 4, 84, 84), dtype=np.uint8),
                      # np.bool_ rather than np.bool: the latter alias was
                      # deprecated and later removed from NumPy.
                      np.zeros((self.data_size, 1), dtype=np.bool_)]
            print("initialize D data")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号