def tabular_learning_with_lr(init_lr, decay_steps, Qs_t, states_t, actions_t, targets):
    """Build a tabular Q-learning update op with an exponentially decaying learning rate.

    Args:
        init_lr: Initial learning rate (Python float).
        decay_steps: Number of global steps between halvings of the learning rate
            (decay rate 0.5, staircase).
        Qs_t: Q-table; must be a ``tf.Variable`` of shape ``[n_states, n_actions]``
            since ``tf.scatter_nd_add`` mutates it in place — TODO confirm at call site.
        states_t: 1-D int tensor of visited state indices.
        actions_t: 1-D int tensor of taken action indices, same length as ``states_t``.
        targets: 1-D float tensor of TD targets, one per (state, action) pair.

    Returns:
        loss: Scalar mean *squared* TD error (monitoring only; the update itself
            uses the raw signed errors, as tabular TD learning requires).
        train_op: Op that applies Q[s, a] += lr * (target - Q[s, a]) for each pair
            and, via a control dependency, increments the global step.
    """
    reusing_scope = tf.get_variable_scope().reuse

    # Current Q(s, a) estimates for the visited (state, action) pairs.
    state_action_pairs = tf.stack([states_t, actions_t], 1)
    estimates = tf.gather_nd(Qs_t, state_action_pairs)
    err_estimates = targets - estimates
    # BUGFIX: the original reported reduce_mean of the *signed* errors, which
    # cancel each other out and make the loss meaningless as a training signal
    # to monitor. Report the mean squared TD error instead.
    loss = tf.reduce_mean(tf.square(err_estimates))

    global_step = tf.Variable(
        0, trainable=False, name="global_step",
        collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.GLOBAL_VARIABLES])
    lr = tf.train.exponential_decay(
        tf.constant(init_lr, dtype=tf.float32), global_step, decay_steps, 0.5,
        staircase=True)
    # Only emit the summary the first time this graph fragment is built.
    if reusing_scope is False:
        tf.summary.scalar('lr', lr)

    # BUGFIX: compute the scaled update *before* the control-dependency block.
    # The original created this multiply under control_dependencies on the step
    # increment, forcing the increment to run first and leaving it ambiguous
    # which step's learning rate the update uses. Hoisting it ties the update
    # to the pre-increment step (NOTE(review): variable read ordering in TF1
    # graphs is still scheduler-dependent — confirm if exact lr/step pairing
    # matters for reproducibility).
    updates = lr * err_estimates
    inc_global_step = global_step.assign_add(1)
    with tf.control_dependencies([inc_global_step]):
        train_op = tf.scatter_nd_add(Qs_t, state_action_pairs, updates)

    return loss, train_op
# (removed non-code web-scrape navigation artifacts: "评论列表" / "文章目录" —
#  blog "comment list" / "article table of contents" text; invalid Python)