utils.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目: tflearn 作者: tflearn 项目源码 文件源码
def fix_saver(collection_lists=None):
    """Stash and clear, or restore, graph collections that hinder serialization.

    Workaround to prevent serialization warnings by removing objects from a
    few graph collections. Call it once with no argument to empty those
    collections and receive their former contents. Then call it again with
    that return value to add the items back (the "0.7+ workaround").

    Args:
        collection_lists: ``None`` to stash and clear the collections, or the
            list returned by a previous ``fix_saver()`` call to restore them.

    Returns:
        When stashing: a list of five lists. They hold the former contents of,
        in order, ``"summary_tags"``, ``tf.GraphKeys.DATA_PREP``,
        ``tf.GraphKeys.DATA_AUG``, ``tf.GraphKeys.EXCL_RESTORE_VARS`` and
        ``tf.GraphKeys.GRAPH_CONFIG``.
        When restoring: ``None``.

    Raises:
        IndexError: on restore, if ``collection_lists`` has fewer than five
            entries.
    """
    # This order defines the layout of the list returned when stashing and
    # expected back when restoring; do not reorder without updating callers.
    keys = [
        "summary_tags",
        tf.GraphKeys.DATA_PREP,
        tf.GraphKeys.DATA_AUG,
        tf.GraphKeys.EXCL_RESTORE_VARS,  # do not save excluded variables
        tf.GraphKeys.GRAPH_CONFIG,
    ]

    if collection_lists is not None:
        # Restore: put each stashed item back into the collection it came from.
        for index, key in enumerate(keys):
            for item in collection_lists[index]:
                tf.add_to_collection(key, item)
        return None

    def _collection_ref(key):
        # Prefer the latest API, which returns the live collection list.
        # Older TensorFlow lacks get_collection_ref (AttributeError). The
        # fallback assumes get_collection returned the live list on those
        # versions; otherwise the clearing below would not affect the graph.
        try:
            return tf.get_collection_ref(key)
        except AttributeError:
            return tf.get_collection(key)

    # Resolve every collection before clearing any, so a failed lookup cannot
    # leave some collections already emptied with their contents lost.
    refs = [_collection_ref(key) for key in keys]
    stashed = [list(ref) for ref in refs]
    for ref in refs:
        # Clear in place so the graph's own list object is emptied.
        del ref[:]
    return stashed
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号