def test_multi_gpu(self):
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.nn.parallel.data_parallel import data_parallel
from inferno.extensions.containers.graph import Graph
input_shape = [8, 1, 3, 128, 128]
model = Graph() \
.add_input_node('input') \
.add_node('conv0', nn.Conv3d(1, 10, 3, padding=1), previous='input') \
.add_node('conv1', nn.Conv3d(10, 1, 3, padding=1), previous='conv0') \
.add_output_node('output', previous='conv1')
model.cuda()
input = Variable(torch.rand(*input_shape).cuda())
output = data_parallel(model, input, device_ids=[0, 1, 2, 3])
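# Stray page-navigation labels left over from the scraped source page:
# "Comment list" and "Table of contents".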