def build_graph(self, kl_first_fixed, weights):
    """Build the op multiplying the Hessian of ``kl_first_fixed.node`` by a vector.

    The vector is fed at run time through a flat placeholder, which is
    stored on ``self.ph_tangent``. Its entries are consumed in the order of
    the flattened weight list, each weight taking as many entries as it has
    elements.

    Args:
        kl_first_fixed: object whose ``.node`` is the scalar tensor to
            differentiate twice (presumably a KL divergence with its first
            argument held constant -- confirm against the caller).
        weights: object whose ``.node`` is a (possibly nested) collection of
            weight tensors; it is flattened with ``utils.Utils.flatten``.

    Returns:
        A rank-1 tensor: the concatenation, in weight order, of the
        flattened second-order gradients (presumably used as a
        Fisher-vector product -- confirm against the caller).
    """
    weight_list = list(utils.Utils.flatten(weights.node))
    gradients1 = tf.gradients(kl_first_fixed.node, weight_list)
    ph_tangent = graph.Placeholder(np.float32, shape=(None,))

    gvp = []
    start = 0
    for g in gradients1:
        # np.prod([]) is the float 1.0, so a scalar weight would otherwise
        # turn `start`/`size` into floats, which are not valid slice bounds.
        size = int(np.prod(g.shape.as_list()))
        tangent_slice = ph_tangent.node[start:start + size]
        # Inner product of this weight's gradient with its tangent slice.
        gvp.append(tf.reduce_sum(tf.reshape(g, [-1]) * tangent_slice))
        start += size

    # tf.gradients sums the gradients over all entries of `gvp`, so this
    # differentiates the full gradient-vector product once more.
    gradients2 = tf.gradients(gvp, weight_list)
    fvp = tf.concat([tf.reshape(g, [-1]) for g in gradients2], axis=0)

    self.ph_tangent = ph_tangent
    return fvp
评论列表
文章目录