def use_cuda(enabled, device_id=0):
    """Optionally switch PyTorch to CUDA and select the GPU to use.

    When ``enabled`` is truthy, this changes global PyTorch state:
    - It sets the default tensor type to ``torch.cuda.FloatTensor``.
    - It makes ``device_id`` the current CUDA device.

    Args:
        enabled: If falsy, nothing is changed and ``None`` is returned.
        device_id: Index of the CUDA device to make current. Defaults to 0.

    Returns:
        ``device_id`` when CUDA was enabled, otherwise ``None``.

    Raises:
        AssertionError: If ``enabled`` is truthy but CUDA is not available.
    """
    if not enabled:
        return None
    # Explicit raise instead of an ``assert`` statement: asserts are removed
    # under ``python -O``, which would silently skip this check. The
    # AssertionError type is kept so existing callers that catch it still work.
    if not torch.cuda.is_available():
        raise AssertionError('CUDA is not available')
    # NOTE(review): torch.set_default_tensor_type is deprecated in recent
    # PyTorch releases in favor of torch.set_default_dtype /
    # torch.set_default_device; kept unchanged to preserve existing behavior.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.set_device(device_id)
    return device_id
# Comment list
# Table of contents