def __init__(self, session,
             optimizer,
             q_network,
             state_dim,
             num_actions,
             batch_size=32,
             init_exp=0.5,             # initial exploration prob
             final_exp=0.1,            # final exploration prob
             anneal_steps=10000,       # N steps for annealing exploration
             replay_buffer_size=10000,
             store_replay_every=5,     # how frequent to store experience
             discount_factor=0.9,      # discount future rewards
             target_update_rate=0.01,
             reg_param=0.01,           # regularization constants
             max_gradient=5,           # max gradient norms
             double_q_learning=False,
             summary_writer=None,
             summary_every=100):
    """Store the agent's configuration, build its variables and initialize them.

    Args:
        session: TensorFlow session used to run the initialization ops.
        optimizer: optimizer object; only stored here.
        q_network: Q-network component; only stored here.
        state_dim: state dimensionality; only stored here.
        num_actions: number of actions; only stored here.
        batch_size: stored as ``self.batch_size``
            (presumably the replay mini-batch size -- confirm at use site).
        init_exp: initial exploration probability; also the starting value
            of ``self.exploration``.
        final_exp: final exploration probability.
        anneal_steps: number of steps over which exploration is annealed.
        replay_buffer_size: passed to ``ReplayBuffer`` as ``buffer_size``.
        store_replay_every: how frequently experience is stored.
        discount_factor: discount applied to future rewards.
        target_update_rate: stored as ``self.target_update_rate``
            (presumably the target-network update rate -- confirm at use site).
        reg_param: regularization constant.
        max_gradient: maximum gradient norm.
        double_q_learning: flag stored as ``self.double_q_learning``
            (presumably toggles double Q-learning -- confirm at use site).
        summary_writer: optional summary writer; when not ``None`` the
            session graph is added to it.
        summary_every: stored as ``self.summary_every``
            (presumably the summary interval in training iterations --
            confirm at use site).
    """
    # tensorflow machinery
    self.session = session
    self.optimizer = optimizer
    self.summary_writer = summary_writer
    # Set unconditionally so the attribute exists even when no summary
    # writer is supplied; reading it can then never raise AttributeError.
    self.summary_every = summary_every

    # model components
    self.q_network = q_network
    self.replay_buffer = ReplayBuffer(buffer_size=replay_buffer_size)

    # Q learning parameters
    self.batch_size = batch_size
    self.state_dim = state_dim
    self.num_actions = num_actions
    self.exploration = init_exp  # current exploration prob starts at init_exp
    self.init_exp = init_exp
    self.final_exp = final_exp
    self.anneal_steps = anneal_steps
    self.discount_factor = discount_factor
    self.target_update_rate = target_update_rate
    self.double_q_learning = double_q_learning

    # training parameters
    self.max_gradient = max_gradient
    self.reg_param = reg_param

    # counters
    self.store_replay_every = store_replay_every
    self.store_experience_cnt = 0
    self.train_iteration = 0

    # create and initialize variables
    self.create_variables()
    # NOTE(review): this initializes every variable in the GLOBAL_VARIABLES
    # collection, not only the ones created by create_variables() -- confirm
    # that re-initializing variables created elsewhere is intended.
    var_lists = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    self.session.run(tf.variables_initializer(var_lists))

    # make sure all variables are initialized
    self.session.run(tf.assert_variables_initialized())

    if self.summary_writer is not None:
        # graph was not available when the writer was created, so attach it now
        self.summary_writer.add_graph(self.session.graph)
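# 评论列表 (comment list) -- web-page navigation text, not part of the code
# 文章目录 (article table of contents) -- web-page navigation text, not part of the code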