def get_trainable_variables(trainable_scopes):
    """Returns a list of variables to train.

    Args:
      trainable_scopes: None, a comma-separated string of scope names
        (e.g. "scope_a,scope_b"), or an iterable of scope-name strings.
        Surrounding whitespace is stripped from each name and blank names
        are ignored. If None, every trainable variable is returned.

    Returns:
      A list of variables to train by the optimizer, grouped in the order
      their scopes were given. A variable matched by more than one scope
      appears only once (at its first match).
    """
    if trainable_scopes is None:
        return tf.trainable_variables()
    if isinstance(trainable_scopes, str):
        # Iterating a string directly would yield one "scope" per character.
        trainable_scopes = trainable_scopes.split(',')
    # tf.get_collection filters names with re.match(scope, name), so an empty
    # scope would match every trainable variable; drop blank entries (e.g.
    # those produced by a trailing comma).
    scopes = [scope.strip() for scope in trainable_scopes]
    scopes = [scope for scope in scopes if scope]

    variables_to_train = []
    seen_ids = set()
    for scope in scopes:
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
        for variable in variables:
            # Overlapping scopes (e.g. "a" and "a/b") can match the same
            # variable; keep only the first occurrence. Dedupe by id() since
            # some TensorFlow versions make variables unhashable.
            if id(variable) not in seen_ids:
                seen_ids.add(id(variable))
                variables_to_train.append(variable)
    return variables_to_train
评论列表
文章目录