import tensorflow as tf  # TensorFlow 1.x graph-mode API


def eps_greedy(inputs_t, q_preds_t, nb_actions, N0, min_eps, nb_state=None):
    """Build an epsilon-greedy action op with a count-based decaying epsilon.

    Epsilon follows the schedule eps = max(N0 / (N0 + N), min_eps), where N
    is either a global step counter (nb_state is None) or a per-state visit
    count indexed by inputs_t.
    """
    reusing_scope = tf.get_variable_scope().reuse

    N0_t = tf.constant(N0, tf.float32, name='N0')
    min_eps_t = tf.constant(min_eps, tf.float32, name='min_eps')

    if nb_state is None:
        # Single global counter: epsilon decays with the total number of steps.
        N = tf.Variable(1., trainable=False, dtype=tf.float32, name='N')
        eps = tf.maximum(N0_t / (N0_t + N), min_eps_t, name="eps")
        update_N = tf.assign(N, N + 1)
        if reusing_scope is False:
            tf.summary.scalar('N', N)
    else:
        # One counter per state: epsilon decays with the visit count of the
        # current state, so rarely visited states stay more exploratory.
        N = tf.Variable(tf.ones(shape=[nb_state]), name='N', trainable=False)
        eps = tf.maximum(N0_t / (N0_t + N[inputs_t]), min_eps_t, name="eps")
        update_N = tf.scatter_add(N, inputs_t, 1.)  # float update to match N's dtype
        if reusing_scope is False:
            tf.summary.histogram('N', N)

    # With probability 1 - eps act greedily, otherwise pick a uniform
    # random action.
    cond = tf.greater(tf.random_uniform([], 0, 1), eps)
    pred_action = tf.cast(tf.argmax(q_preds_t, 0), tf.int32)
    random_action = tf.random_uniform([], 0, nb_actions, dtype=tf.int32)
    with tf.control_dependencies([update_N]):  # Force the counter update
        action_t = tf.where(cond, pred_action, random_action)

    return action_t
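
For reference, here is a minimal sketch of how this op might be wired up and driven from a session; the placeholder names, the dummy Q-values, and the loop below are illustrative assumptions, not part of the original:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API assumed

nb_actions = 4

# Hypothetical feeds: a scalar state id and per-action Q-value predictions.
inputs_t = tf.placeholder(tf.int32, shape=[], name='state')
q_preds_t = tf.placeholder(tf.float32, shape=[nb_actions], name='q_preds')

# Global-counter variant (nb_state=None); pass nb_state to switch to
# per-state visit counts instead.
action_t = eps_greedy(inputs_t, q_preds_t, nb_actions, N0=100, min_eps=0.01)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(5):
        a = sess.run(action_t, feed_dict={
            inputs_t: 0,
            q_preds_t: np.random.rand(nb_actions),
        })
        print('step %d -> action %d' % (step, a))

Each sess.run of action_t also fires update_N through the control dependency, so epsilon shrinks toward min_eps as the counter grows.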