def add_link(self, link, rank_in=None, rank_out=None):
    """Register one connected link with its inout rank.

    Args:
        link (chainer.Link): The link object to be registered.
        rank_in (int, list, or None):
            Ranks from which it receives data. If None is specified,
            the model does not receive from any machines.
        rank_out (int, list, or None):
            Ranks to which it sends data. If None is specified,
            the model will not send to any machine.

    Raises:
        ValueError: If ``rank_out`` is None while a previously registered
            component also has ``rank_out`` None. The check runs before
            ``link`` is registered, so a rejected call leaves both the
            chain and ``self._rank_inouts`` unchanged.
    """
    # Normalize a scalar rank to a single-element list; lists and None
    # are stored as given.
    if isinstance(rank_in, int):
        rank_in = [rank_in]
    if isinstance(rank_out, int):
        rank_out = [rank_out]

    # At most one component may have rank_out None. Validate before
    # touching any state: registering the link first (as before) left it
    # in the parent ChainList without a matching _rank_inouts entry
    # whenever this check failed.
    if rank_out is None and any(
            _rank_out is None for _, _rank_out in self._rank_inouts):
        raise ValueError(
            'MultiNodeChainList cannot have more than one '
            'computational graph component whose rank_out is None')

    super(MultiNodeChainList, self).add_link(link)
    self._rank_inouts.append((rank_in, rank_out))
评论列表
文章目录