graph.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:inferno 作者: inferno-pytorch 项目源码 文件源码
def test_multi_gpu(self):
        import torch
        from torch.autograd import Variable
        import torch.nn as nn
        from torch.nn.parallel.data_parallel import data_parallel
        from inferno.extensions.containers.graph import Graph

        input_shape = [8, 1, 3, 128, 128]
        model = Graph() \
            .add_input_node('input') \
            .add_node('conv0', nn.Conv3d(1, 10, 3, padding=1), previous='input') \
            .add_node('conv1', nn.Conv3d(10, 1, 3, padding=1), previous='conv0') \
            .add_output_node('output', previous='conv1')

        model.cuda()
        input = Variable(torch.rand(*input_shape).cuda())
        output = data_parallel(model, input, device_ids=[0, 1, 2, 3])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号