def tabular_eps_greedy(inputs_t, q_preds_t, nb_states, nb_actions, N0, min_eps):
    """Build a tabular epsilon-greedy policy with count-based epsilon decay (TF1 graph mode).

    Epsilon for each sampled state is eps = max(N0 / (N0 + N(s)), min_eps),
    where N(s) is a per-state visit counter stored in the non-trainable
    variable ``Ns`` (so eps starts at 1 for unvisited states and decays as
    the state is revisited).

    Args:
        inputs_t: int tensor of state indices, shape [batch] — used to index
            ``Ns`` via gather/scatter_add (assumed 1-D of ints in
            [0, nb_states); TODO confirm against callers).
        q_preds_t: float tensor of Q-value predictions, shape
            [batch, nb_actions] (rank 2 is implied by ``tf.argmax(..., 1)``).
        nb_states: total number of discrete states (sizes the ``Ns`` variable).
        nb_actions: number of discrete actions.
        N0: decay constant for the epsilon schedule.
        min_eps: floor applied to epsilon.

    Returns:
        actions_t: int32 tensor [batch] — greedy action with prob (1 - eps),
            uniform-random action with prob eps, decided per sample.
        probs_t: float tensor [batch, nb_actions] — the corresponding action
            distribution: eps / nb_actions everywhere plus (1 - eps) extra
            mass on the argmax action (rows sum to 1; argmax ties go to the
            first maximal index).
    """
    reusing_scope = tf.get_variable_scope().reuse
    # Per-state visit counts; non-trainable bookkeeping, shared across reuses
    # of the enclosing variable scope.
    Ns = tf.get_variable('Ns', shape=[nb_states], dtype=tf.float32, trainable=False, initializer=tf.zeros_initializer())
    # ``is False`` (not ``not reusing_scope``) presumably so the histogram is
    # only attached on the first, non-reusing graph build — note that
    # tf.AUTO_REUSE would also pass this check; TODO confirm that is intended.
    if reusing_scope is False:
        tf.summary.histogram('Ns', Ns)
    # Increment N(s) by 1 for every state in the batch (duplicate indices
    # accumulate under scatter_add).
    update_Ns = tf.scatter_add(Ns, inputs_t, tf.ones_like(inputs_t, dtype=tf.float32))
    # eps = max(N0 / (N0 + N(s)), min_eps), one value per batch element.
    # NOTE(review): this gather has no control dependency on ``update_Ns``,
    # so whether eps sees pre- or post-increment counts is not fixed by the
    # graph — only ``actions_t`` below is ordered after the update. Confirm
    # which ordering is intended.
    eps = tf.maximum(
        N0 / (N0 + tf.gather(Ns, inputs_t))
        , min_eps
        , name="eps"
    )
    nb_samples = tf.shape(q_preds_t)[0]
    max_actions = tf.cast(tf.argmax(q_preds_t, 1), tf.int32)
    # Dense policy distribution: start from a one-hot of the greedy action
    # scaled by (1 - eps), then add eps / nb_actions to every entry.
    # (tf.sparse_to_dense is deprecated in later TF releases — tf.scatter_nd
    # / one_hot are the modern equivalents; left as-is here.)
    probs_t = tf.sparse_to_dense(
        sparse_indices=tf.stack([tf.range(nb_samples), max_actions], 1)
        , output_shape=[nb_samples, nb_actions]
        , sparse_values=1 - eps
        , default_value=0.
    ) + tf.expand_dims(eps / nb_actions, 1)
    # Per-sample coin flip: uniform(0,1) > eps -> exploit, else explore.
    conditions = tf.greater(tf.random_uniform([nb_samples], 0, 1), eps)
    random_actions = tf.random_uniform(shape=[nb_samples], minval=0, maxval=nb_actions, dtype=tf.int32)
    with tf.control_dependencies([update_Ns]): # Force the update call
        actions_t = tf.where(conditions, max_actions, random_actions)
    return actions_t, probs_t
# (stray web-page residue removed here: "评论列表" = comment list, "文章目录" = article table of contents)