def copy_all_vars(from_namespace, to_namespace, affine_coefficient=1.0):
    """Build one op that blends every variable of a namespace into its twin.

    For each global variable whose name lies in ``from_namespace`` (the scope
    name itself, or anything under ``from_namespace + "/"``), the variable with
    the same relative name under ``to_namespace`` is fetched and updated as::

        target = alpha * src + (1.0 - alpha) * target

    where ``alpha`` is ``affine_coefficient``. ``alpha == 1.0`` overwrites the
    target with the source; ``alpha == 0.0`` leaves it unchanged.

    Args:
        from_namespace: Scope name of the source variables, e.g. ``"net"``.
            A trailing ``"/"`` is ignored.
        to_namespace: Scope name of the target variables. A trailing ``"/"``
            is ignored. Each source variable must already have a counterpart
            here with a compatible shape and dtype, created through
            ``tf.get_variable`` (only such variables can be fetched by name).
        affine_coefficient: Blend factor ``alpha``; must lie in [0.0, 1.0].

    Returns:
        A single ``tf.group`` op that runs all per-variable assignments.

    Raises:
        ValueError: if ``affine_coefficient`` is outside [0.0, 1.0] (or NaN),
            if either namespace is empty, or if a source variable's name does
            not end in ``":0"``. ``tf.get_variable`` also raises ValueError
            when a target variable is missing or incompatible.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    # The chained comparison is False for NaN as well, so NaN is rejected.
    if not 0.0 <= affine_coefficient <= 1.0:
        raise ValueError(
            "affine_coefficient must be in [0.0, 1.0], got %r" % (affine_coefficient,))
    from_namespace = from_namespace.rstrip("/")
    to_namespace = to_namespace.rstrip("/")
    if not from_namespace or not to_namespace:
        raise ValueError("from_namespace and to_namespace must be non-empty")
    source_prefix = from_namespace + "/"

    copy_ops = []
    # reuse=True at the root scope makes tf.get_variable look up existing
    # variables by their full name rather than create new ones.
    with tf.variable_scope("", reuse=True):
        for src_var in tf.global_variables():
            # Split "scope/name:0" into the op name and its output index.
            base_name, _, output_index = src_var.name.rpartition(":")
            # Match whole scope components only, so "net" does not also
            # select the variables of a sibling scope like "net_target".
            if base_name != from_namespace and not base_name.startswith(source_prefix):
                continue
            if output_index != "0":
                raise ValueError(
                    "unexpected variable name %r (expected a ':0' suffix)" % (src_var.name,))
            # Swap only the leading namespace; later occurrences of the same
            # text inside the name must be kept as-is.
            target_var_name = to_namespace + base_name[len(from_namespace):]
            # Pass the source dtype: get_variable defaults to float32 and would
            # otherwise refuse to fetch e.g. a float64 target.
            target_var = tf.get_variable(
                target_var_name,
                shape=src_var.get_shape(),
                dtype=src_var.dtype.base_dtype)
            # target - alpha * (target - src) == alpha * src + (1 - alpha) * target
            copy_ops.append(
                target_var.assign_sub(affine_coefficient * (target_var - src_var)))
    return tf.group(*copy_ops)
评论列表
文章目录