def get_copy_var_ops(*, dest_scope_name: str, src_scope_name: str) -> List[tf.Operation]:
    """Create TF ops that copy trainable variables from one scope to another.

    Variables are paired positionally, in the order TensorFlow returns them
    from the TRAINABLE_VARIABLES collection. This assumes both scopes were
    built by the same code, so that the i-th variable of each scope
    corresponds.

    NOTE: `tf.get_collection(..., scope=...)` filters names with `re.match`,
    so a scope such as "main" also matches "main_2/...". Choose scope names
    that are not prefixes of one another.

    Args:
        dest_scope_name (str): Scope whose trainable variables are overwritten
            (copy to).
        src_scope_name (str): Scope whose trainable variables are read
            (copy from).

    Returns:
        List[tf.Operation]: One assign op per (src, dest) variable pair. In
            graph mode the copy only happens when these ops are run
            (e.g. via `Session.run`).

    Raises:
        ValueError: If the two scopes contain a different number of trainable
            variables, so they cannot be paired one-to-one.
    """
    src_vars = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope=src_scope_name)
    dest_vars = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope=dest_scope_name)

    # zip() would silently drop the unmatched tail and leave some destination
    # variables stale, so refuse mismatched scopes outright.
    if len(src_vars) != len(dest_vars):
        raise ValueError(
            "Cannot pair variables: scope {!r} has {} trainable variables "
            "but scope {!r} has {}.".format(
                src_scope_name, len(src_vars),
                dest_scope_name, len(dest_vars)))

    # Copy each source variable's current value into its destination partner.
    return [dest_var.assign(src_var.value())
            for src_var, dest_var in zip(src_vars, dest_vars)]
# Returns a pysc2.env.environment.TimeStep once the game has ended
评论列表
文章目录