def add_role(variables, roles):
    r"""Add one or more roles to the given variable(s).

    Parameters
    ----------
    variables : :class:`~tensor.TensorVariable` or a sequence of them
        The variable(s) to assign the new roles to.
    roles : :class:`Role` subclass, role name, a sequence of them, or None
        The roles to add. Each entry is passed through ``name_to_roles``.
        The result is then combined with the roles the variable already has
        (``get_roles``) and with the roles of the current role scope
        (``get_current_role_scope``). If ``None``, nothing is done.

    Returns
    -------
    variables
        The ``variables`` argument itself, so calls can be chained.

    Notes
    -----
    String entries that remain after ``name_to_roles`` are used directly as
    collection names.

    Some roles are subroles of others (e.g. :class:`Weight` is a subrole
    of :class:`Parameter`). Only the most specific roles are kept. If a
    variable ends up with both a role and one of its subroles, it is removed
    from the parent role's collection (if present) and added only to the
    subrole's collection. If you need to replace a role with a parent role
    (e.g. replace :class:`Weight` with :class:`Parameter`) you must do so
    manually.
    """
    if roles is None:
        return variables
    roles = tuple(name_to_roles(r) for r in as_tuple(roles))
    for var in as_tuple(variables):
        # Existing roles + requested roles + roles of the active role scope.
        # Each operand is normalised to a tuple because ``list + tuple``
        # raises TypeError if one of the helpers returns a list.
        candidates = (tuple(get_roles(var, return_string=False)) +
                      roles +
                      tuple(get_current_role_scope()))
        # ====== handle string roles first ====== #
        role_types = []
        for r in candidates:
            if isinstance(r, string_types):
                # Plain string: treated directly as a collection name.
                _add_to_collection_no_duplication(r, var)
            elif isinstance(r, type) and issubclass(r, Role):
                # De-duplicate so a parent role is only processed once below.
                if r not in role_types:
                    role_types.append(r)
        # ====== shrink the roles so there is NO subrole ====== #
        most_specific = []
        for r in role_types:
            if any(r0 != r and issubclass(r0, r) for r0 in role_types):
                # A more specific subrole is present, so drop this parent
                # role. The variable may never have been added to the
                # parent's collection (e.g. both roles were requested in
                # this same call). In that case list.remove raises
                # ValueError, and there is nothing to undo.
                try:
                    tf.get_collection_ref(r.__name__).remove(var)
                except ValueError:
                    pass
            else:
                most_specific.append(r)
        # ====== adding new role ====== #
        for r in most_specific:
            _add_to_collection_no_duplication(r.__name__, var)
    return variables
评论列表
文章目录