def build_graph(self, *layers):
    """Build the weight-synchronisation sub-graph for ``layers``.

    Gathers ``layer.weight.node`` from every layer and then sets these
    attributes on ``self``, in this order:

    * ``ph_weights``  -- a ``graph.Placeholders`` built over the weights;
    * ``assign``      -- a ``graph.TfNode`` of ``tf.assign`` ops, each one
      writing an entry of ``ph_weights.checked`` into the paired weight;
    * ``check``       -- a ``graph.TfNode`` wrapping a ``tf.group`` of
      ``tf.check_numerics`` ops, one per flattened weight, each tagged
      with the message ``'weight_<index>'``;
    * ``global_norm`` -- ``tf.global_norm`` over the flattened weights.

    Args:
        *layers: objects exposing a ``.weight.node`` attribute.

    Returns:
        list: the collected ``layer.weight.node`` values, in ``layers`` order.
    """
    # NOTE(review): each `.weight.node` may itself be a nested collection of
    # variables (hence the flatten calls below) -- confirm against the layer
    # classes.
    weights = []
    for layer in layers:
        weights.append(layer.weight.node)

    placeholders = graph.Placeholders(variables=graph.TfNode(weights))
    self.ph_weights = placeholders

    # Presumably izip walks `weights` and `placeholders.checked` in parallel,
    # pairing each variable with its feed value -- verify in utils.
    assign_ops = [
        tf.assign(target, source)
        for target, source in utils.Utils.izip(weights, placeholders.checked)
    ]
    self.assign = graph.TfNode(assign_ops)

    # One numeric check per flattened weight; the index in the message
    # identifies which weight tripped the check.
    numeric_checks = []
    for index, tensor in enumerate(utils.Utils.flatten(weights)):
        numeric_checks.append(tf.check_numerics(tensor, 'weight_%d' % index))
    self.check = graph.TfNode(tf.group(*numeric_checks))

    self.global_norm = tf.global_norm(list(utils.Utils.flatten(weights)))
    return weights
评论列表
文章目录