pytorch_misc.py source code


Project: pytorch-seq2seq, Author: rowanz
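```python
import torch
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence


def transpose_batch_sizes(batch_sizes):
    # Assumed stand-in for the helper defined elsewhere in pytorch_misc.py (not shown):
    # batch_sizes[t] counts sequences longer than t, so len(seq b) = #{t : batch_sizes[t] > b}
    batch_sizes = torch.as_tensor(batch_sizes, dtype=torch.long)
    seq_ids = torch.arange(int(batch_sizes[0]))
    return (batch_sizes.unsqueeze(0) > seq_ids.unsqueeze(1)).sum(1)


def transpose_packed_sequence(ps):
    """
    Goes from a TxB packed sequence to a BxT one, or vice versa. Assumes ps.data is a plain tensor.
    :param ps: PackedSequence
    :return: PackedSequence with the time and batch dimensions swapped
    """
    data, batch_sizes = ps.data, ps.batch_sizes  # PackedSequence has 4 fields; read by name
    seq_lens = transpose_batch_sizes(batch_sizes)

    # Put things in the permutation matrix one way (batch-major), take them out another (time-major)
    perm_mat = torch.zeros(len(seq_lens), len(batch_sizes), dtype=torch.long)
    cur = 0
    for i, sl in enumerate(seq_lens.tolist()):
        perm_mat[i, :sl] = torch.arange(cur, cur + sl)
        cur += sl
    # perm[k] is the batch-major position of the k-th packed (time-major) element
    perm = pack_padded_sequence(perm_mat, seq_lens, batch_first=True).data
    # Gather with the inverse permutation so the output data is in batch-major order
    return PackedSequence(data[torch.argsort(perm).to(data.device)], seq_lens)
```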
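The same transpose can also be done without the Python loop. The sketch below is an alternative technique, not the author's code: for every cell (sequence b, step t) it computes that element's index in the time-major packed data, then reads those indices in batch-major order by packing the matrix with `batch_sizes` as the lengths. The name `transpose_packed_sequence_vectorized` is hypothetical, and the sketch assumes `ps` was built from length-sorted sequences (no `sorted_indices`), as the original function does.

```python
import torch
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pack_sequence


def transpose_packed_sequence_vectorized(ps):
    """Hypothetical loop-free variant: swap the time and batch axes of a sorted PackedSequence."""
    batch_sizes = ps.batch_sizes  # CPU int64, non-increasing
    num_seqs = int(batch_sizes[0])
    seq_ids = torch.arange(num_seqs)
    # seq_lens[b] = number of time steps t with batch_sizes[t] > b
    seq_lens = (batch_sizes.unsqueeze(0) > seq_ids.unsqueeze(1)).sum(1)
    # offsets[t] = packed index where time step t begins
    offsets = torch.cumsum(batch_sizes, 0) - batch_sizes
    # src[b, t] = time-major packed index of element (b, t); cells past a sequence's end are never read
    src = offsets.unsqueeze(0) + seq_ids.unsqueeze(1)
    # Treat b as "time" and t as "batch" (length batch_sizes[t]) so packing reads src batch-major
    gather_idx = pack_padded_sequence(src, batch_sizes, batch_first=False).data
    return PackedSequence(ps.data[gather_idx.to(ps.data.device)], seq_lens)


seqs = [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5])]
ps = pack_sequence(seqs)  # time-major: data [1, 3, 5, 2, 4], batch_sizes [3, 2]
out = transpose_packed_sequence_vectorized(ps)
print(out.data.tolist(), out.batch_sizes.tolist())  # [1, 2, 3, 4, 5] [2, 2, 1]
```

Transposing `out` again restores the original packed data `[1, 3, 5, 2, 4]` with `batch_sizes` `[3, 2]`, which is the "vice versa" case named in the docstring.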