capacities.py source code

python

Project: openai-rl    Author: morgangiraud
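import tensorflow as tf  # TensorFlow 1.x graph API; under TF 2.x use tensorflow.compat.v1 with v2 behavior disabled


def tabular_eps_greedy(inputs_t, q_preds_t, nb_states, nb_actions, N0, min_eps):
    """Count-based epsilon-greedy action selection for a tabular agent.

    inputs_t holds integer state indices; N(s) counts visits to state s.
    eps(s) = max(N0 / (N0 + N(s)), min_eps).
    Returns the sampled actions and the [batch, nb_actions] action-probability matrix.
    """
    reusing_scope = tf.get_variable_scope().reuse

    # Per-state visit counter N(s); excluded from training.
    Ns = tf.get_variable('Ns', shape=[nb_states], dtype=tf.float32, trainable=False, initializer=tf.zeros_initializer())
    if reusing_scope is False:
        # Register the histogram summary only when the variable is first created.
        tf.summary.histogram('Ns', Ns)
    # Add one to N(s) for every state in the batch (duplicate indices accumulate).
    update_Ns = tf.scatter_add(Ns, inputs_t, tf.ones_like(inputs_t, dtype=tf.float32))

    # Exploration rate decays with the visit count and is floored at min_eps.
    eps = tf.maximum(
        N0 / (N0 + tf.gather(Ns, inputs_t))
        , min_eps
        , name="eps"
    )

    nb_samples = tf.shape(q_preds_t)[0]
    max_actions = tf.cast(tf.argmax(q_preds_t, 1), tf.int32)
    # pi(a|s) = 1 - eps + eps / nb_actions for the greedy action, eps / nb_actions otherwise.
    probs_t = tf.sparse_to_dense(
        sparse_indices=tf.stack([tf.range(nb_samples), max_actions], 1)
        , output_shape=[nb_samples, nb_actions]
        , sparse_values=1 - eps
        , default_value=0.
    ) + tf.expand_dims(eps / nb_actions, 1)

    # Act greedily when u > eps (probability 1 - eps), otherwise pick a uniformly random action.
    conditions = tf.greater(tf.random_uniform([nb_samples], 0, 1), eps)
    random_actions = tf.random_uniform(shape=[nb_samples], minval=0, maxval=nb_actions, dtype=tf.int32)
    with tf.control_dependencies([update_Ns]):  # Force the counter update whenever actions_t is evaluated
        actions_t = tf.where(conditions, max_actions, random_actions)

    return actions_t, probs_t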
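For readers without a TensorFlow 1.x setup, the policy this function builds can be sketched in plain Python. The helper `tabular_eps_greedy_ref` below is hypothetical and not part of openai-rl. It assumes the Q-values arrive as a list of rows and that a mutable list of floats stands in for the `Ns` variable. It reads the counts for the whole batch before incrementing them. In the TensorFlow graph, the `tf.gather` read does not appear to be explicitly ordered with `update_Ns`, so this ordering is a choice made by the sketch.

```python
import random
from typing import List, Sequence, Tuple


def tabular_eps_greedy_ref(
    states: Sequence[int],
    q_preds: Sequence[Sequence[float]],
    counts: List[float],
    N0: float,
    min_eps: float,
    rng: random.Random,
) -> Tuple[List[int], List[List[float]]]:
    """Hypothetical pure-Python sketch of count-based epsilon-greedy selection."""
    nb_actions = len(q_preds[0])
    actions: List[int] = []
    probs: List[List[float]] = []
    for s, q in zip(states, q_preds):
        # eps(s) = max(N0 / (N0 + N(s)), min_eps), using counts from before this batch.
        eps = max(N0 / (N0 + counts[s]), min_eps)
        # Greedy action (first index among ties).
        greedy = max(range(nb_actions), key=lambda a: q[a])
        # Every action gets eps / nb_actions; the greedy action also gets 1 - eps.
        row = [eps / nb_actions] * nb_actions
        row[greedy] += 1.0 - eps
        probs.append(row)
        # Greedy when u > eps (probability 1 - eps), otherwise uniform over all actions.
        if rng.random() > eps:
            actions.append(greedy)
        else:
            actions.append(rng.randrange(nb_actions))
    # Stand-in for scatter_add: increment the visit count of every state in the batch.
    for s in states:
        counts[s] += 1.0
    return actions, probs
```

For example, with `counts = [3.0, 0.0, 9.0]` and `N0 = 1.0`, state 0 has eps = 1 / (1 + 3) = 0.25, so for Q-values `[1.0, 2.0]` the greedy action 1 gets probability 1 - 0.25 + 0.25 / 2 = 0.875 and action 0 gets 0.125. Because a random draw can also land on the greedy action, this row matches the rate at which the sampling step actually picks each action.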