def encode(self, X, skip_mask=None):
    """Run the multi-layer (optionally densely connected) encoder over X.

    Args:
        X: batch of token-id sequences; presumably shape (batch, seq_length)
           given the embed-then-swapaxes pattern — TODO confirm against caller.
        skip_mask: optional mask forwarded unchanged to every encoder layer.

    Returns:
        list: the last hidden state of each encoder layer, ordered from
        layer 0 upward (one entry per layer, ``self.num_layers`` total).
    """
    # Embed tokens, then swap axes 1 and 2 — presumably moving to a
    # (batch, feature, time) layout for convolutional layers; verify.
    embedding = self.encoder_embed(X)
    embedding = F.swapaxes(embedding, 1, 2)

    # First layer consumes the embedding directly.
    layer_outputs = [self._forward_encoder_layer(0, embedding, skip_mask=skip_mask)]

    # Remaining layers: with dense connectivity each layer sees the
    # concatenation of all previous outputs, otherwise only the last one.
    for layer_index in range(1, self.num_layers):
        layer_input = F.concat(layer_outputs) if self.densely_connected else layer_outputs[-1]
        layer_outputs.append(
            self._forward_encoder_layer(layer_index, layer_input, skip_mask=skip_mask))

    # NOTE(review): the original code computed one final concat (+ dropout)
    # here and stored it in a local that was never used — the method returns
    # per-layer hidden states, not that concatenation. The dead computation
    # has been removed; forward passes above already set each layer's state.

    # Collect each layer's last hidden state (recorded as a side effect of
    # the forward passes above).
    return [self.get_encoder(i).get_last_hidden_state()
            for i in range(self.num_layers)]
# NOTE(review): removed stray web-page artifacts ("评论列表" = comment list,
# "文章目录" = article table of contents) left over from blog extraction;
# they were not code and broke the file's syntax.