def __init__(self):
    """Create the replay buffer, pick the device, and build the MXNet modules.

    Reads the module-level ``FLG_GPU`` flag and the parsed ``args``
    (``args.mode`` and ``args.pretrain``).

    * ``self.q_net`` is built only when ``args.mode == 'train'``. It is bound
      for training on batches of ``BATCH`` samples with the inputs ``frame``
      and ``act_mul`` and the label ``target``, and gets an Adam optimizer.
    * ``self.tg_net`` is bound with ``for_training=False`` on a single
      ``(1, FRAME, HEIGHT, WIDTH)`` input.

    When ``args.pretrain`` is truthy, it is passed to ``load_params`` on each
    module that was built, replacing the Xavier-initialised parameters.
    """
    self.replayMemory = deque()
    self.timestep = 0
    self.ctx = mx.gpu() if FLG_GPU else mx.cpu()

    if args.mode == 'train':
        # NOTE(review): the meaning of the positional argument to createNet
        # is not visible here; it is kept exactly as in the original call.
        self.q_net = mx.mod.Module(
            symbol=self.createNet(1),
            data_names=['frame', 'act_mul'],
            label_names=['target', ],
            context=self.ctx)
        self.q_net.bind(
            data_shapes=[('frame', (BATCH, FRAME, HEIGHT, WIDTH)),
                         ('act_mul', (BATCH, ACTIONS))],
            label_shapes=[('target', (BATCH,))],
            for_training=True)
        self.q_net.init_params(
            initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
        self.q_net.init_optimizer(
            optimizer='adam',
            optimizer_params={'learning_rate': 0.0002, 'wd': 0.0, 'beta1': 0.5})
        if args.pretrain:
            self.q_net.load_params(args.pretrain)
            # Parenthesised single-argument form prints the same text on
            # Python 2 and Python 3.
            print("load pretrained file......")

    # NOTE(review): tg_net is assumed to be needed in every mode, so it is
    # built outside the 'train' branch -- confirm against the callers.
    self.tg_net = mx.mod.Module(
        symbol=self.createNet(),
        data_names=['frame', ],
        label_names=[],
        context=self.ctx)
    self.tg_net.bind(
        data_shapes=[('frame', (1, FRAME, HEIGHT, WIDTH))],
        for_training=False)
    self.tg_net.init_params(
        initializer=mx.init.Xavier(factor_type='in', magnitude=2.34))
    if args.pretrain:
        self.tg_net.load_params(args.pretrain)
        print("load pretrained file......")
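# [web page residue, not part of the program] Comment list
# [web page residue, not part of the program] Table of contents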