def dump_vars(sess):
    """Print the current value of every variable in the default graph.

    Output is a header, then a "Trainable vars" section and a
    "Non Trainable vars" section. Each section lists one ``name=value``
    line per variable, sorted by variable name.

    Args:
        sess: session used to evaluate the variables; only its ``run``
            method is used, called with a list of variables as fetches.
            NOTE(review): assumes TF1-style graph mode (``tf`` is TF 1.x
            or ``tf.compat.v1``) -- confirm against the module's import.
    """
    # ``tf.all_variables`` was deprecated in TF 0.12 and removed in TF 1.0
    # in favour of ``tf.global_variables``; use the new name and fall back
    # to the old one only on releases that predate it.
    list_all_vars = getattr(tf, "global_variables", None) or tf.all_variables
    all_vars = set(list_all_vars())
    trainable_vars = set(tf.trainable_variables())
    non_trainable_vars = all_vars - trainable_vars

    def _dump_set(var_set):
        # Sets have no stable order; sort by name for deterministic output.
        ordered = sorted(var_set, key=lambda v: v.name)
        if not ordered:
            return
        # Fetch every value in one session call rather than one round
        # trip per variable; Session.run returns values in fetch order.
        values = sess.run(ordered)
        for var, value in zip(ordered, values):
            print("%s=%s" % (var.name, value))

    print("Variable values:")
    print("-----------")
    print("\n---Trainable vars:")
    _dump_set(trainable_vars)
    print("\n---Non Trainable vars:")
    _dump_set(non_trainable_vars)
    print("-----------")
评论列表
文章目录