def __init__(
        self, comm, n_layers, n_source_vocab, n_target_vocab, n_units):
    """Build the source embedding and the multi-node encoder LSTM.

    Args:
        comm: ChainerMN communicator, handed to
            ``chainermn.links.create_multi_node_n_step_rnn``.
        n_layers (int): Number of stacked layers in the ``NStepLSTM``.
        n_source_vocab (int): Source vocabulary size for ``EmbedID``.
        n_target_vocab (int): Not used by this constructor; presumably kept
            so the signature mirrors the decoder's (TODO confirm).
        n_units (int): Embedding size, and the LSTM input and hidden size.
    """
    super(Encoder, self).__init__()
    # Register child links inside init_scope (the Chainer v2+ idiom) instead
    # of passing them as keyword arguments to Chain.__init__. The child names
    # ("embed_x", "mn_encoder") and registration order are unchanged.
    with self.init_scope():
        self.embed_x = L.EmbedID(n_source_vocab, n_units)
        # Corresponding decoder LSTM will be invoked on process 1, hence
        # rank_out=1. rank_in=None presumably means no other rank feeds
        # input into this RNN (verify against chainermn docs).
        self.mn_encoder = chainermn.links.create_multi_node_n_step_rnn(
            # NStepLSTM(n_layers, in_size, out_size, dropout=0.1)
            L.NStepLSTM(n_layers, n_units, n_units, 0.1),
            comm, rank_in=None, rank_out=1,
        )
    # Plain (non-Link) attributes; they need not be inside init_scope.
    self.comm = comm
    self.n_layers = n_layers
    self.n_units = n_units
评论列表
文章目录