torch_util.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:multiNLI_encoder 作者: easonnie 项目源码 文件源码
def select_last(inputs, lengths, hidden_size):
    """
    :param inputs: [T * B * D] D = 2 * hidden_size
    :param lengths: [B]
    :param hidden_size: dimension 
    :return:  [B * D]
    """
    batch_size = inputs.size(1)
    batch_out_list = []
    for b in range(batch_size):
        batch_out_list.append(torch.cat((inputs[lengths[b] - 1, b, :hidden_size],
                                         inputs[0, b, hidden_size:])
                                        )
                              )

    out = torch.stack(batch_out_list)
    return out
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号