def add_variable(var, restore=True):
    """Register a variable in the MODEL_VARIABLES collection.

    When `restore` is truthy the variable is also registered in the
    VARIABLES_TO_RESTORE collection. A collection that already contains
    the variable is left untouched, so repeated calls do not create
    duplicate entries.

    Args:
      var: the variable to register.
      restore: whether the variable should additionally be placed in the
        VARIABLES_TO_RESTORE collection.
    """
    # MODEL_VARIABLES is always a target; VARIABLES_TO_RESTORE only on request.
    target_keys = (
        (MODEL_VARIABLES, VARIABLES_TO_RESTORE) if restore else (MODEL_VARIABLES,)
    )
    for key in target_keys:
        # Skip collections that already hold this variable.
        if var in tf.get_collection(key):
            continue
        tf.add_to_collection(key, var)
# Comment list
# Article table of contents