def reset_parameters(self):
    """Re-initialize this module's parameters in place.

    - ``self.embeddings.weight`` is drawn from a normal distribution with
      mean 0 and standard deviation 0.01.
    - ``self.W_i.weight`` and ``self.W_o.weight`` get Xavier (Glorot)
      normal initialization. Only the weights are touched here; any biases
      on these layers are left as they are.
    - ``self.encoder`` and each of the ``self.n_decoders`` decoders, stored
      as attributes ``decoder0`` .. ``decoder{n_decoders - 1}``, are passed
      to ``init_rnn_cell``.

    NOTE(review): assumes ``I`` is ``torch.nn.init`` (e.g.
    ``import torch.nn.init as I``) — confirm against the file's imports.
    """
    # Use the in-place, trailing-underscore initializers. The old
    # ``I.normal`` / ``I.xavier_normal`` names are deprecated aliases of
    # these. The ``*_`` versions already run under ``torch.no_grad()``, so
    # the Parameters are passed directly instead of through ``.data``.
    I.normal_(self.embeddings.weight, mean=0, std=0.01)
    I.xavier_normal_(self.W_i.weight)
    I.xavier_normal_(self.W_o.weight)

    init_rnn_cell(self.encoder)
    # Decoders are looked up by numbered attribute name (decoder0, decoder1, ...).
    for i in range(self.n_decoders):
        decoder = getattr(self, "decoder{}".format(i))
        init_rnn_cell(decoder)
评论列表
文章目录