feedforward.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:pytorch2c 作者: lantiga 项目源码 文件源码
def feedforward_test():

    import torch.nn as nn
    import torch.nn.functional as F

    fc1 = nn.Linear(10,20)
    fc1.weight.data.normal_(0.0,1.0)
    fc1.bias.data.normal_(0.0,1.0)

    fc2 = nn.Linear(20,2)
    fc2.weight.data.normal_(0.0,1.0)
    fc2.bias.data.normal_(0.0,1.0)

    model = lambda x: F.log_softmax(fc2(F.relu(fc1(x))))

    data = Variable(torch.rand(10,10))

    out_path = 'out'
    if not os.path.isdir(out_path):
        os.mkdir(out_path)
    uid = str(uuid.uuid4())

    torch2c.compile(model(data),'feedforward',os.path.join(out_path,uid),compile_test=True)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号