torch_util.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:multiNLI_encoder 作者: easonnie 项目源码 文件源码
def max_along_time(inputs, lengths):
    """
    :param inputs: [T * B * D] 
    :param lengths:  [B]
    :return: [B * D] max_along_time
    """
    ls = list(lengths)

    b_seq_max_list = []
    for i, l in enumerate(ls):
        seq_i = inputs[:l, i, :]
        seq_i_max, _ = seq_i.max(dim=0)
        seq_i_max = seq_i_max.squeeze()
        b_seq_max_list.append(seq_i_max)

    return torch.stack(b_seq_max_list)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号