def scatter(inputs, target_gpus, dim=0):
    """
    Slice variables into approximately equal chunks and distribute them
    across the given GPUs. Duplicate references to objects that are not
    variables. Does not support Tensors.

    Tuples, lists and dicts are traversed recursively. Each container is
    rebuilt once per target from the per-target pieces of its elements.
    Empty containers, like any other non-variable object, are duplicated
    by reference, once per entry in ``target_gpus``.

    Args:
        inputs: Object to scatter: a Variable, or a (possibly nested)
            tuple/list/dict that may contain Variables.
        target_gpus: Iterable of targets, presumably GPU device ids. It is
            forwarded to ``Scatter`` and iterated to duplicate
            non-variable objects.
        dim: Dimension along which Variables are split (forwarded to
            ``Scatter``).

    Returns:
        For a Variable, whatever ``Scatter`` returns (presumably one chunk
        per target; TODO confirm). For any other object, a tuple with one
        entry per target. Containers are regrouped with ``zip``, so their
        result is as long as the shortest per-element split.

    Raises:
        AssertionError: If a raw tensor (not wrapped in a Variable) is
            encountered.
    """
    def scatter_map(obj):
        if isinstance(obj, Variable):
            return Scatter(target_gpus, dim=dim)(obj)
        assert not torch.is_tensor(obj), "Tensors not supported in scatter."
        # Empty containers deliberately skip the branches below. zip(*[])
        # yields nothing, which would turn an empty container (and, through
        # zip's truncation, any container holding it) into an empty tuple
        # instead of one copy per target.
        if isinstance(obj, tuple) and obj:
            return tuple(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and obj:
            return tuple(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and obj:
            # Each (key, value) item is scattered as a 2-tuple. Keys are
            # duplicated and values split; the pairs are then regrouped per
            # target and rebuilt with the original mapping type.
            return tuple(map(type(obj), zip(*map(scatter_map, obj.items()))))
        # Anything else (including empty containers) is shared by reference.
        return tuple(obj for _ in target_gpus)

    # scatter_map is recursive, so its closure cell refers back to the
    # function itself, forming a reference cycle. Clearing the name once the
    # call is done breaks the cycle, so the closure can be freed by
    # refcounting instead of waiting for the cyclic garbage collector.
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None
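# NOTE(review): scraped web-page residue, originally the labels
# "评论列表" (comment list) and "文章目录" (table of contents); not code.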