def forward(self, input):
    """Encode ``input`` with two stacked RNNs and return one vector per sample.

    The input is embedded, squashed with ``tanh`` and fed through
    ``self.rnn_0``; the full output of ``rnn_0`` is then fed through
    ``self.rnn_1``.  ``select_last`` picks one vector per sample from each
    layer's output, dropout is applied to both, and the two vectors are
    concatenated along dimension 1.

    Args:
        input: Batch passed to ``process_lengths`` and ``self.embedding``.
            NOTE(review): presumably a padded tensor of token indices --
            confirm against the caller.

    Returns:
        The concatenation ``torch.cat((vec_0, vec_1), 1)`` of the selected
        (and dropped-out) outputs of both RNN layers.
    """
    dropout_p = 0.3  # same probability for both layers

    # NOTE(review): presumably the unpadded length of each sequence --
    # confirm against process_lengths.
    lengths = process_lengths(input)

    x = self.embedding(input)
    x = torch.tanh(x)

    # Only the output sequences are needed; the second value returned by
    # each RNN is discarded.
    x_0, _ = self.rnn_0(x)
    # NOTE(review): select_last presumably extracts the output at each
    # sequence's final valid step -- confirm against its definition.
    vec_0 = select_last(x_0, lengths)

    x_1, _ = self.rnn_1(x_0)
    vec_1 = select_last(x_1, lengths)

    # Dropout is applied to the selected vectors (not to the sequence fed to
    # rnn_1) and is a no-op unless self.training is true.
    vec_0 = F.dropout(vec_0, p=dropout_p, training=self.training)
    vec_1 = F.dropout(vec_1, p=dropout_p, training=self.training)

    return torch.cat((vec_0, vec_1), 1)
评论列表
文章目录