Tensorflow变量范围:如果变量存在则重用
我想要一段代码,如果它不存在,则在范围内创建一个变量,如果它已经存在,则访问该变量。我需要它是 相同的 代码,因为它将被多次调用。
但是,Tensorflow需要我指定是要创建还是重用该变量,如下所示:
with tf.variable_scope("foo"): #create the first time
v = tf.get_variable("v", [1])
with tf.variable_scope("foo", reuse=True): #reuse the second time
v = tf.get_variable("v", [1])
我怎样才能弄清楚是自动创建还是重用它?即,我希望以上两个代码块 相同, 并运行程序。
-
创建新变量且未声明形状时,或在变量创建过程中违反重用时,
ValueError
将引发A。get_variable()
因此,您可以尝试以下操作:def get_scope_variable(scope_name, var, shape=None): with tf.variable_scope(scope_name) as scope: try: v = tf.get_variable(var, shape) except ValueError: scope.reuse_variables() v = tf.get_variable(var) return v v1 = get_scope_variable('foo', 'v', [1]) v2 = get_scope_variable('foo', 'v') assert v1 == v2
请注意,以下内容也适用:
v1 = get_scope_variable('foo', 'v', [1]) v2 = get_scope_variable('foo', 'v', [1]) assert v1 == v2
更新。 新的API现在支持自动重用:
def get_scope_variable(scope, var, shape=None): with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): v = tf.get_variable(var, shape) return v