def __init__(self, n_options=10, logger=None, plotting=False,
             log_tf_graph=False):
    """Initialise logging, the grid-world environment, plotting and the graph.

    Args:
        n_options: Stored unchanged on ``self.n_options``.
        logger: Logger to use. When ``None``, the logger named "logger" is
            fetched via ``logging.getLogger`` and its level is set to INFO.
        plotting: When truthy, three ``PlotRobot`` objects are created
            ('dqn loss', 'q loss', 'rewards'); otherwise three ``None``
            placeholders are stored instead.
        log_tf_graph: Passed straight through to ``self.build_graph``.
    """
    if logger is None:
        logger = logging.getLogger("logger")
        logger.setLevel(logging.INFO)
    self.logger = logger
    self.n_options = n_options

    self.env = gym.make("deterministic-grid-world-v0")
    self.n_actions = self.env.action_space.n
    # One plus the product of the ``n`` of every sub-space of the
    # observation space.
    # NOTE(review): the extra +1 state is presumably a terminal/sentinel
    # state -- confirm.
    sizes = (space.n for space in self.env.observation_space.spaces)
    self.n_states = 1 + reduce(lambda acc, size: acc * size, sizes)

    if not plotting:
        self.plot_robots = [None] * 3
    else:
        self.plot_robots = [
            PlotRobot('dqn loss', 0, log_scale=True),
            PlotRobot('q loss', 1),
            PlotRobot('rewards', 2),
        ]
    # Note: this stores the 'rewards' robot (or None when plotting is off),
    # not the boolean ``plotting`` argument.
    self.plotting = self.plot_robots[2]
    self.colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k',
                   'magenta', 'lime', 'gray']
    self.build_graph(log_tf_graph)
评论列表
文章目录