utils.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:trpo 作者: jjkke88 项目源码 文件源码
def slice_2d(x, inds0, inds1):
    """Pick the elements ``x[inds0[i], inds1[i]]`` out of a 2-D tensor.

    ``x`` is flattened in row-major order, and each (row, column) pair is
    turned into the flat offset ``row * ncols + col``, where ``ncols`` is
    ``tf.shape(x)[1]``. A single ``tf.gather`` on the flattened tensor then
    fetches every requested element.

    Args:
        x: Tensor treated as a 2-D matrix (the code only reads dimension 1
            as the column count). A rank > 2 input would not be sliced
            correctly. NOTE(review): callers presumably always pass a
            matrix; confirm.
        inds0: Row indices. They are cast to int64.
        inds1: Column indices. They are cast to int64 and combined
            element-wise with ``inds0``.

    Returns:
        Tensor of the selected elements, shaped like ``inds0 * ncols + inds1``.

    Note:
        The original author's comment suggests that in TRPO ``x`` holds one
        row per step of a path and one column per action dimension. This is
        usage context, not something the code enforces.
    """
    rows = tf.cast(inds0, tf.int64)
    cols = tf.cast(inds1, tf.int64)
    # Column count of x, cast to int64 so it can be combined with the indices.
    num_cols = tf.cast(tf.shape(x), tf.int64)[1]
    # Row-major offset of each (row, col) pair inside the flattened tensor.
    flat_positions = rows * num_cols + cols
    flattened = tf.reshape(x, [-1])
    return tf.gather(flattened, flat_positions)


# def linesearch(f, x, fullstep, expected_improve_rate):
#     accept_ratio = .1
#     max_backtracks = 10
#     fval, old_kl, entropy = f(x)
#     for (_n_backtracks, stepfrac) in enumerate(.5**np.arange(max_backtracks)):
#         xnew = x + stepfrac * fullstep
#         newfval, new_kl, new_ent= f(xnew)
#         # actual_improve = newfval - fval # minimize target object
#         # expected_improve = expected_improve_rate * stepfrac
#         # ratio = actual_improve / expected_improve
#         # if ratio > accept_ratio and actual_improve > 0:
#         #     return xnew
#         if newfval<fval and new_kl<=pms.max_kl:
#             return xnew
#     return x
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号