def forward(self, lstm_out, lengths):
"""
Args:
lstm_out: A Variable containing a 3D tensor of dimension
(seq_len, batch_size, hidden_x_dirs)
lengths: A Variable containing 1D LongTensor of dimension
(batch_size)
Return:
A Variable containing a 2D tensor of the same type as lstm_out of
dim (batch_size, hidden_x_dirs) corresponding to the concatenated
last hidden states of the forward and backward parts of the input.
"""
    seq_len = lstm_out.size(0)
    batch_size = lstm_out.size(1)
    hidden_x_dirs = lstm_out.size(2)
    # Integer division (// not /): each direction owns half of the
    # concatenated hidden size.
    single_dir_hidden = hidden_x_dirs // 2

    # gather() needs an index tensor with the same number of dimensions as
    # its input, so expand (batch_size,) -> (1, batch_size, single_dir_hidden).
    # Forward direction: with end-padded sequences, the last valid time
    # step of sequence b is lengths[b] - 1.
    idx_fw = (lengths - 1).view(1, batch_size, 1)
    idx_fw = idx_fw.repeat(1, 1, single_dir_hidden)
    # Backward direction: the backward LSTM reads each sequence from its
    # end towards time step 0, so (with a packed input) its final hidden
    # state is emitted at index 0 for every sequence.
    idx_bw = torch.zeros_like(idx_fw)

    # Split the two directions: we want 2 chunks in the last dimension.
    out_fw, out_bw = torch.chunk(lstm_out, 2, dim=2)
    h_t_fw = torch.gather(out_fw, 0, idx_fw)
    h_t_bw = torch.gather(out_bw, 0, idx_bw)

    # (1, batch_size, hidden_x_dirs) -> (batch_size, hidden_x_dirs);
    # squeeze(0) rather than squeeze() so a batch of size 1 keeps its
    # batch dimension.
    last_hidden_out = torch.cat([h_t_fw, h_t_bw], dim=2).squeeze(0)
    return last_hidden_out
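
As a sanity check, here is a minimal usage sketch. It assumes the forward() above is the forward method of a module called LastHiddenPool (the class name, the sizes, and the toy batch are illustrative assumptions, not part of the original code). Because the LSTM consumes a packed, end-padded batch, its h_n output already holds the final hidden states of both directions, so the pooled result must match their concatenation.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
seq_len, batch_size, input_dim, hidden_dim = 6, 3, 4, 5

lstm = nn.LSTM(input_dim, hidden_dim, bidirectional=True)
pooler = LastHiddenPool()  # hypothetical module wrapping the forward() above

# Toy batch padded at the end; lengths sorted for pack_padded_sequence.
x = torch.randn(seq_len, batch_size, input_dim)
lengths = torch.tensor([6, 4, 2])

packed_out, (h_n, _) = lstm(pack_padded_sequence(x, lengths))
lstm_out, _ = pad_packed_sequence(packed_out)  # (6, 3, 2 * hidden_dim)

pooled = pooler(lstm_out, lengths)             # (3, 2 * hidden_dim)

# h_n[0] / h_n[1] are the forward / backward final hidden states.
assert torch.allclose(pooled, torch.cat([h_n[0], h_n[1]], dim=1))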