def __init__(self, game_params, arch_params, solver_params, trained_model, sn_dir):
    """Declare the symbolic inputs and build the two-controller CONTROLLER model.

    Args:
        game_params: game configuration; stored on ``self`` and forwarded to
            ``CONTROLLER``.
        arch_params: architecture configuration; stored on ``self`` and
            forwarded to ``CONTROLLER``.
        solver_params: solver configuration. It must contain the keys
            ``'controler_0'`` and ``'controler_1'``, which feed
            ``create_learning_rate_func``. It is also stored and forwarded to
            ``CONTROLLER``.
        trained_model: indexable with at least two entries (indices 0 and 1).
            Each truthy entry is passed to ``common.load_params``. A falsy
            entry leaves the corresponding slot as ``None``. Presumably each
            entry is a path to saved weights; TODO confirm against callers.
        sn_dir: stored as ``self.sn_dir``. Presumably a snapshot directory;
            verify against callers.
    """
    # Optional pretrained parameters, one slot per controller (0 then 1).
    params = [
        common.load_params(trained_model[idx]) if trained_model[idx] else None
        for idx in (0, 1)
    ]

    # One learning-rate function per controller, built in index order.
    self.lr_func = [
        create_learning_rate_func(solver_params[key])
        for key in ('controler_0', 'controler_1')
    ]

    # Symbolic graph inputs. `tt` is presumably theano.tensor (not visible
    # here); each variable's name matches its attribute name.
    self.x_host_0 = tt.fvector('x_host_0')
    self.v_host_0 = tt.fvector('v_host_0')
    self.x_target_0 = tt.fvector('x_target_0')
    self.v_target_0 = tt.fvector('v_target_0')
    self.x_mines_0 = tt.fmatrix('x_mines_0')
    self.mines_map = tt.fmatrix('mines_map')
    self.time_steps = tt.fvector('time_steps')
    self.force = tt.fmatrix('force')
    self.n_steps_0 = tt.iscalar('n_steps_0')
    self.n_steps_1 = tt.iscalar('n_steps_1')
    self.lr = tt.fscalar('lr')
    self.goal_1 = tt.fvector('goal_1')
    self.trnsprnt = tt.fscalar('trnsprnt')
    self.rand_goals = tt.fmatrix('rand_goals')

    # Keep the raw configuration around for later use by this object.
    self.game_params = game_params
    self.arch_params = arch_params
    self.solver_params = solver_params
    self.sn_dir = sn_dir

    # CONTROLLER takes every input positionally, so argument order matters.
    self.model = CONTROLLER(
        self.x_host_0, self.v_host_0,
        self.x_target_0, self.v_target_0,
        self.x_mines_0, self.mines_map,
        self.time_steps, self.force,
        self.n_steps_0, self.n_steps_1,
        self.lr, self.goal_1,
        self.trnsprnt, self.rand_goals,
        self.game_params, self.arch_params, self.solver_params,
        params,
    )
评论列表
文章目录