test_autograd.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:pytorch 作者: tylergenter 项目源码 文件源码
def gather_variable(shape, index_dim, max_indices):
    assert len(shape) == 2
    assert index_dim < 2
    batch_dim = 1 - index_dim
    index = torch.LongTensor(*shape)
    for i in range(shape[index_dim]):
        index.select(index_dim, i).copy_(
            torch.randperm(max_indices)[:shape[batch_dim]])
    return Variable(index, requires_grad=False)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号