def select_last(inputs, lengths, hidden_size):
    """Gather, per sequence, the forward features at its last valid step and
    the remaining features at step 0, concatenated into one vector.

    Assumes ``inputs`` is a time-major output where the first ``hidden_size``
    features are a forward direction and the rest a backward direction
    (e.g. a bidirectional RNN) -- confirm against the caller.

    :param inputs: tensor of shape [T, B, D], with D = 2 * hidden_size.
    :param lengths: length of each sequence, a sequence of ints or a 1-D
        integer tensor with at least B entries (only the first B are used).
        Every used length must lie in [1, T].
    :param hidden_size: number of leading features taken at the last valid
        step; features ``hidden_size:`` are taken at step 0.
    :return: tensor of shape [B, D].
    :raises ValueError: if fewer than B lengths are given or a length is
        outside [1, T].
    """
    import torch  # local import: no module-level import is guaranteed here

    seq_len, batch_size = inputs.size(0), inputs.size(1)

    # Only the first B lengths were ever read, so extra entries stay ignored.
    lengths = torch.as_tensor(lengths, dtype=torch.long, device=inputs.device)[:batch_size]
    if lengths.numel() != batch_size:
        raise ValueError(f"expected {batch_size} lengths, got {lengths.numel()}")
    # A length of 0 would index step -1 (the LAST timestep) and silently
    # return wrong data, so reject out-of-range lengths explicitly.
    if batch_size and (bool(lengths.min() < 1) or bool(lengths.max() > seq_len)):
        raise ValueError(f"lengths must lie in [1, {seq_len}]")

    batch_idx = torch.arange(batch_size, device=inputs.device)
    # Paired advanced indices pick (lengths[b] - 1, b) for each b -> [B, hidden_size].
    forward_last = inputs[lengths - 1, batch_idx, :hidden_size]
    # Step 0 of the remaining features for every sequence -> [B, D - hidden_size].
    backward_first = inputs[0, :, hidden_size:]
    return torch.cat((forward_last, backward_first), dim=1)
评论列表
文章目录