import torch


def padding(seqs, pad, batch_first=False):
    """Pad a batch of variable-length index sequences to a common length.

    :param seqs: list/tuple of 1-D sequences of token indices
    :param pad: value used to fill the padded positions
    :param batch_first: if True, return a (batch x max_seq_length) tensor
    :return: (max_seq_length x batch) LongTensor and the list of original lengths
    """
    lengths = [len(s) for s in seqs]
    seqs = [torch.LongTensor(s) for s in seqs]
    batch_length = max(lengths)
    # Allocate the output tensor and fill every position with the pad value.
    seq_tensor = torch.LongTensor(batch_length, len(seqs)).fill_(pad)
    # Copy each sequence into its column; the tail stays padded.
    for i, s in enumerate(seqs):
        end_seq = lengths[i]
        seq_tensor[:end_seq, i].copy_(s[:end_seq])
    if batch_first:
        seq_tensor = seq_tensor.t()
    return seq_tensor, lengths
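A minimal usage sketch (the toy index sequences and pad value below are made up for illustration): padding three sequences of different lengths and checking the resulting shapes.

if __name__ == "__main__":
    # Hypothetical toy batch: three index sequences of different lengths.
    batch = [[4, 9, 2], [7, 1], [3, 5, 8, 6]]
    padded, lengths = padding(batch, pad=0)
    print(padded.shape)      # torch.Size([4, 3]): max_seq_length x batch
    print(lengths)           # [3, 2, 4]
    padded_bf, _ = padding(batch, pad=0, batch_first=True)
    print(padded_bf.shape)   # torch.Size([3, 4]): batch x max_seq_length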