Tensorflow变量范围:如果变量存在则重用

发布于 2021-01-29 18:29:45

我想要一段代码,如果它不存在,则在范围内创建一个变量,如果它已经存在,则访问该变量。我需要它是 相同的 代码,因为它将被多次调用。

但是,Tensorflow需要我指定是要创建还是重用该变量,如下所示:

with tf.variable_scope("foo"): #create the first time
    v = tf.get_variable("v", [1])

with tf.variable_scope("foo", reuse=True): #reuse the second time
    v = tf.get_variable("v", [1])

我怎样才能弄清楚是自动创建还是重用它?即,我希望以上两个代码块 相同, 并运行程序。

关注者
0
被浏览
44
1 个回答
  • 面试哥
    面试哥 2021-01-29
    为面试而生,有面试问题,就找面试哥。

    创建新变量且未声明形状时,或在变量创建过程中违反重用时,ValueError将引发A。get_variable()因此,您可以尝试以下操作:

    def get_scope_variable(scope_name, var, shape=None):
        with tf.variable_scope(scope_name) as scope:
            try:
                v = tf.get_variable(var, shape)
            except ValueError:
                scope.reuse_variables()
                v = tf.get_variable(var)
        return v
    
    v1 = get_scope_variable('foo', 'v', [1])
    v2 = get_scope_variable('foo', 'v')
    assert v1 == v2
    

    请注意,以下内容也适用:

    v1 = get_scope_variable('foo', 'v', [1])
    v2 = get_scope_variable('foo', 'v', [1])
    assert v1 == v2
    

    更新。 新的API现在支持自动重用:

    def get_scope_variable(scope, var, shape=None):
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            v = tf.get_variable(var, shape)
        return v
    


知识点
面圈网VIP题库

面圈网VIP题库全新上线,海量真题题库资源。 90大类考试,超10万份考试真题开放下载啦

去下载看看