trpo_model.py source code

python

Project: relaax  Author: deeplearninc
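def build_graph(self, kl_first_fixed, weights):
    # Flatten the (possibly nested) weight structure into a list of tensors.
    weight_list = list(utils.Utils.flatten(weights.node))
    # First-order gradients of the KL term with respect to every weight.
    gradients1 = tf.gradients(kl_first_fixed.node, weight_list)
    # Flat tangent vector fed at run time, one entry per weight element.
    ph_tangent = graph.Placeholder(np.float32, shape=(None,))

    # Gradient-vector product: dot each gradient with its slice of the tangent.
    gvp = []
    start = 0
    for g in gradients1:
        size = int(np.prod(g.shape.as_list()))
        gvp.append(tf.reduce_sum(tf.reshape(g, [-1]) * ph_tangent.node[start:start + size]))
        start += size

    # Differentiating the products a second time gives the Hessian-vector
    # product of the KL term, returned as one flat vector (fvp).
    gradients2 = tf.gradients(gvp, weight_list)
    fvp = tf.concat([tf.reshape(g, [-1]) for g in gradients2], axis=0)

    self.ph_tangent = ph_tangent
    return fvp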
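
The method above slices one flat tangent vector into per-weight pieces and differentiates a gradient-vector product a second time. The NumPy sketch below illustrates both ideas outside TensorFlow. It is not relaax code: `split_tangent`, `hessian_vector_product`, the matrix `A`, and the quadratic test function are all hypothetical names chosen for illustration, and a central finite difference stands in for the second `tf.gradients` call.

```python
import numpy as np


def split_tangent(tangent, shapes):
    """Slice a flat tangent vector into pieces shaped like each weight."""
    pieces, start = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        pieces.append(tangent[start:start + size].reshape(shape))
        start += size
    return pieces


def hessian_vector_product(grad_fn, w, v, eps=1e-5):
    """Central-difference estimate of H @ v, i.e. d/dw (grad(w) . v)."""
    return (grad_fn(w + eps * v) - grad_fn(w - eps * v)) / (2.0 * eps)


# Hypothetical quadratic f(w) = 0.5 * w^T A w, so grad(w) = A @ w and H = A.
A = np.array([[2.0, 1.0], [1.0, 3.0]])


def grad_fn(w):
    return A @ w


w = np.array([0.5, -1.0])
v = np.array([1.0, 2.0])
hvp = hessian_vector_product(grad_fn, w, v)

# Slicing a 9-element tangent for weights of shape (2, 3) and (3,).
pieces = split_tangent(np.arange(9.0), [(2, 3), (3,)])
```

For this quadratic the Hessian is `A` itself, so `hvp` should agree with `A @ v` up to floating-point error. In the TensorFlow version the same product is obtained exactly by automatic differentiation, without ever forming the full Hessian.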