def __init__(
        self, comm, n_layers, n_source_vocab, n_target_vocab, n_units):
    """Create the decoder's child links and record its configuration.

    Three links are registered on the chain: ``embed_y`` (an
    ``L.EmbedID``), ``mn_decoder`` (an ``L.NStepLSTM`` wrapped by
    ``chainermn.links.create_multi_node_n_step_rnn``) and ``W`` (an
    ``L.Linear``).

    Args:
        comm: Communicator passed to the multi-node RNN wrapper and
            stored as ``self.comm``.
        n_layers: Layer count passed to ``L.NStepLSTM``; stored as
            ``self.n_layers``.
        n_source_vocab: Accepted for signature compatibility; this
            constructor does not read it.
        n_target_vocab: Passed as the first argument of ``L.EmbedID``
            and the second argument of ``L.Linear``.
        n_units: Shared size used by the embedding, the LSTM and the
            linear layer; stored as ``self.n_units``.
    """
    # Links are built in the same order as they are registered below.
    target_embedding = L.EmbedID(n_target_vocab, n_units)

    # NOTE(review): the trailing 0.1 is presumably the dropout ratio --
    # confirm against the NStepLSTM signature.
    stacked_lstm = L.NStepLSTM(n_layers, n_units, n_units, 0.1)
    # Corresponding encoder LSTM will be invoked on process 0.
    multi_node_lstm = chainermn.links.create_multi_node_n_step_rnn(
        stacked_lstm, comm, rank_in=0, rank_out=None)

    output_projection = L.Linear(n_units, n_target_vocab)

    super(Decoder, self).__init__(
        embed_y=target_embedding,
        mn_decoder=multi_node_lstm,
        W=output_projection,
    )

    self.comm, self.n_layers, self.n_units = comm, n_layers, n_units
# [web-page residue, not code] "Comment list"
# [web-page residue, not code] "Table of contents"