test.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:PyTorch-Encoding 作者: zhanghang1989 项目源码 文件源码
def test_scaledL2():
    B,N,K,D = 2,3,4,5
    X = Variable(torch.cuda.DoubleTensor(B,N,D).uniform_(-0.5,0.5), 
        requires_grad=True)
    C = Variable(torch.cuda.DoubleTensor(K,D).uniform_(-0.5,0.5), 
        requires_grad=True)
    S = Variable(torch.cuda.DoubleTensor(K).uniform_(-0.5,0.5), 
        requires_grad=True)
    input = (X, C, S)
    test = gradcheck(encoding.functions.scaledL2, input, eps=1e-6, atol=1e-4)
    print('Testing scaledL2(): {}'.format(test))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号