flappybird.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:DQN-FlappyBird-mxnet 作者: foolyc 项目源码 文件源码
def __init__(self):
        """Set up replay memory, the compute context and the MXNet networks.

        Reads the module-level ``FLG_GPU`` flag and the parsed command-line
        ``args`` (``args.mode`` and ``args.pretrain``).

        Attributes set:

        * ``self.replayMemory`` -- an empty ``collections.deque``.
        * ``self.timestep`` -- starts at 0.
        * ``self.ctx`` -- ``mx.gpu()`` when ``FLG_GPU`` is truthy, otherwise
          ``mx.cpu()``.
        * ``self.q_net`` -- created only when ``args.mode == 'train'``. It is
          a trainable ``mx.mod.Module`` with data inputs ``frame`` and
          ``act_mul`` and label ``target``, bound for batch size ``BATCH``,
          Xavier-initialised, with an Adam optimizer attached.
        * ``self.tg_net`` -- always created. It is an inference-only
          ``mx.mod.Module`` with a single ``frame`` input and batch size 1.

        If ``args.pretrain`` is set, the same parameter file is loaded into
        every network created here, overwriting the Xavier initialisation.
        """
        # NOTE(review): the deque is unbounded here; presumably its size is
        # capped wherever transitions are appended -- confirm.
        self.replayMemory = deque()
        self.timestep = 0
        self.ctx = mx.gpu() if FLG_GPU else mx.cpu()

        # Both networks use the same weight-initialisation scheme.
        initializer = mx.init.Xavier(factor_type="in", magnitude=2.34)

        if args.mode == 'train':
            # Training graph: extra 'act_mul' input and a 'target' label.
            self.q_net = mx.mod.Module(
                symbol=self.createNet(1),
                data_names=['frame', 'act_mul'],
                label_names=['target'],
                context=self.ctx,
            )
            self.q_net.bind(
                data_shapes=[('frame', (BATCH, FRAME, HEIGHT, WIDTH)),
                             ('act_mul', (BATCH, ACTIONS))],
                label_shapes=[('target', (BATCH,))],
                for_training=True,
            )
            self.q_net.init_params(initializer=initializer)
            self.q_net.init_optimizer(
                optimizer='adam',
                optimizer_params={'learning_rate': 0.0002, 'wd': 0.0, 'beta1': 0.5},
            )
            if args.pretrain:
                self.q_net.load_params(args.pretrain)
                # Function-call form works under both Python 2 and Python 3.
                print("load pretrained file......")

        # NOTE(review): without a pretrain file, tg_net gets a random
        # initialisation independent of q_net; presumably weights are copied
        # between the two elsewhere -- confirm.
        self.tg_net = mx.mod.Module(
            symbol=self.createNet(),
            data_names=['frame'],
            label_names=[],
            context=self.ctx,
        )
        self.tg_net.bind(data_shapes=[('frame', (1, FRAME, HEIGHT, WIDTH))],
                         for_training=False)
        self.tg_net.init_params(initializer=initializer)
        if args.pretrain:
            self.tg_net.load_params(args.pretrain)
            print("load pretrained file......")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号