train.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:tf_classification 作者: visipedia 项目源码 文件源码
def get_trainable_variables(trainable_scopes):
    """Returns a list of variables to train.
    Returns:
        A list of variables to train by the optimizer.
    """

    if trainable_scopes is None:
        return tf.trainable_variables()

    trainable_scopes = [scope.strip() for scope in trainable_scopes]

    variables_to_train = []
    for scope in trainable_scopes:
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
        variables_to_train.extend(variables)
    return variables_to_train
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号