multi_node_chain_list.py source code


Project: chainermn  Author: chainer
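```python
import chainer


class MultiNodeChainList(chainer.ChainList):
    # Excerpt: only add_link is shown. It expects self._rank_inouts to
    # already exist as a list (set up elsewhere in the class, not shown).

    def add_link(self, link, rank_in=None, rank_out=None):
        """Register one connected link with its input and output ranks.

        Args:
            link (chainer.Link): The link object to be registered.
            rank_in (int, list, or None):
                Ranks from which it receives data. If None is specified,
                the model does not receive from any machine.
            rank_out (int, list, or None):
                Ranks to which it sends data. If None is specified,
                the model does not send to any machine.
        """
        # Register the link as a child of this ChainList.
        super(MultiNodeChainList, self).add_link(link)
        # Normalize a single int rank to a one-element list.
        if isinstance(rank_in, int):
            rank_in = [rank_in]
        if isinstance(rank_out, int):
            rank_out = [rank_out]

        # At most one registered component may have rank_out=None;
        # adding a second one is rejected.
        if rank_out is None:
            for _, _rank_out in self._rank_inouts:
                if _rank_out is None:
                    raise ValueError(
                        'MultiNodeChainList cannot have more than one '
                        'computational graph component whose rank_out is None')

        self._rank_inouts.append((rank_in, rank_out))
```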
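To make the bookkeeping in `add_link` easier to follow, here is a minimal sketch that reproduces only the rank handling, assuming no Chainer links or MPI communication are involved. `RankSpec`, `normalize_ranks`, and `RankRegistry` are hypothetical names introduced for illustration; they are not part of chainermn.

```python
from typing import List, Optional, Tuple, Union

# Hypothetical alias: the rank forms that add_link accepts.
RankSpec = Union[int, List[int], None]


def normalize_ranks(ranks: RankSpec) -> Optional[List[int]]:
    """Wrap a single int rank in a list; pass lists and None through unchanged."""
    if isinstance(ranks, int):
        return [ranks]
    return ranks


class RankRegistry:
    """Hypothetical stand-in for the rank bookkeeping in add_link (no Chainer, no MPI)."""

    def __init__(self) -> None:
        self.rank_inouts: List[Tuple[Optional[List[int]], Optional[List[int]]]] = []

    def add(self, rank_in: RankSpec = None, rank_out: RankSpec = None) -> None:
        rank_in = normalize_ranks(rank_in)
        rank_out = normalize_ranks(rank_out)
        # Reject a second component that sends its output to no rank.
        if rank_out is None and any(out is None for _, out in self.rank_inouts):
            raise ValueError(
                "cannot have more than one component whose rank_out is None")
        self.rank_inouts.append((rank_in, rank_out))


registry = RankRegistry()
registry.add(rank_in=None, rank_out=1)       # receives nothing, sends to rank 1
registry.add(rank_in=[1, 2], rank_out=None)  # receives from ranks 1 and 2, sends nothing
print(registry.rank_inouts)  # [(None, [1]), ([1, 2], None)]

try:
    registry.add(rank_in=3, rank_out=None)  # a second rank_out=None component
except ValueError as err:
    print("rejected:", err)
```

Normalizing an int to a one-element list means every stored entry is either a list or None, so code that later reads the registered pairs can iterate over ranks without special-casing a bare int.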