def create_multi_node_n_step_rnn(
        actual_link, communicator, rank_in=None, rank_out=None):
    """Create a multi-node stacked RNN link from a Chainer stacked RNN link.

    A multi-node stacked RNN link is used for model parallelism.
    The created link receives initial hidden states from the process
    specified by ``rank_in`` (or receives nothing if ``rank_in`` is
    ``None``), executes the original RNN computation, and then sends the
    resulting hidden states to the process specified by ``rank_out``.

    Compared with a Chainer stacked RNN link, a multi-node stacked RNN link
    returns an extra object called ``delegate_variable``.
    If ``rank_out`` is not ``None``, backward computation is expected
    to begin from ``delegate_variable``.
    For details, please refer to ``chainermn.functions.pseudo_connect``.

    The following RNN links can be passed to this function:

    - ``chainer.links.NStepBiGRU``
    - ``chainer.links.NStepBiLSTM``
    - ``chainer.links.NStepBiRNNReLU``
    - ``chainer.links.NStepBiRNNTanh``
    - ``chainer.links.NStepGRU``
    - ``chainer.links.NStepLSTM``
    - ``chainer.links.NStepRNNReLU``
    - ``chainer.links.NStepRNNTanh``

    Args:
        actual_link (chainer.Link): Chainer stacked RNN link to wrap.
        communicator: ChainerMN communicator.
        rank_in (int or None):
            Rank of the process which sends hidden RNN states to this
            process, or ``None`` to receive nothing.
        rank_out (int or None):
            Rank of the process to which this process sends hidden RNN
            states.

    Returns:
        The multi-node stacked RNN link based on ``actual_link``.
    """
    # Flag this entry point as an experimental Chainer(MN) feature before
    # constructing the wrapper.
    chainer.utils.experimental('chainermn.links.create_multi_node_n_step_rnn')
    return _MultiNodeNStepRNN(actual_link, communicator, rank_in, rank_out)
评论列表
文章目录