def _saver_collection_keys():
    """Return the collection keys handled by ``fix_saver``.

    The order matches the layout of the list returned by ``fix_saver()``.
    Built at call time, so the ``tf.GraphKeys`` attribute lookups happen
    when ``fix_saver`` runs, just as in the original inline code.
    """
    return (
        "summary_tags",
        tf.GraphKeys.DATA_PREP,
        tf.GraphKeys.DATA_AUG,
        tf.GraphKeys.EXCL_RESTORE_VARS,  # do not save exclude variables
        tf.GraphKeys.GRAPH_CONFIG,
    )


def _saver_collection_ref(key):
    """Return the list backing the graph collection ``key``.

    Prefers ``tf.get_collection_ref`` (latest API). ``fix_saver`` clears the
    returned list in place, which only affects the graph when it is the
    collection's own list. TF builds without ``get_collection_ref`` fall back
    to ``tf.get_collection``.
    """
    try:
        return tf.get_collection_ref(key)
    except AttributeError:
        # Older TF: no get_collection_ref; use the legacy accessor.
        return tf.get_collection(key)


def fix_saver(collection_lists=None):
    """Strip, or restore, graph collections that trigger serialization warnings.

    Workaround to prevent serialization warnings by temporarily removing
    objects from a few graph collections. Use it in two calls:

        saved = fix_saver()   # empty the collections, keep their contents
        ...                   # save the model
        fix_saver(saved)      # put the contents back

    Args:
        collection_lists: ``None`` to empty the collections. Otherwise, the
            list returned by a previous ``fix_saver()`` call, whose contents
            are added back to their collections.

    Returns:
        When ``collection_lists`` is ``None``: a list of five lists with the
        removed contents of, in order, "summary_tags", DATA_PREP, DATA_AUG,
        EXCL_RESTORE_VARS and GRAPH_CONFIG. Otherwise ``None``.

    Raises:
        IndexError: if ``collection_lists`` has fewer than five entries.
    """
    keys = _saver_collection_keys()

    if collection_lists is None:
        # Resolve every collection before mutating any of them, so a failing
        # lookup cannot leave the graph only partially stripped.
        refs = [_saver_collection_ref(key) for key in keys]
        saved = [list(ref) for ref in refs]
        for ref in refs:
            del ref[:]
        return saved

    # 0.7+ workaround: restore the values captured by fix_saver().
    # Positional indexing keeps the original contract: extra entries are
    # ignored and a short list raises IndexError.
    for index, key in enumerate(keys):
        for value in collection_lists[index]:
            tf.add_to_collection(key, value)
# 评论列表 (web-page residue: "comment list")
# 文章目录 (web-page residue: "table of contents")