def _check_capability():
    """Warn for each GPU that wants a newer CUDA build than PyTorch has.

    Reads the CUDA version PyTorch was compiled with from
    ``torch._C._cuda_getCompiledVersion()`` and, for every device index in
    ``range(device_count())``, compares it with the version required by the
    device's major capability (first element of ``get_device_capability``):

    * major >= 7 requires CUDA_VERSION >= 9000
    * major >= 6 requires CUDA_VERSION >= 8000

    At most one warning is emitted per device, and it always names the
    strictest requirement that applies. Checking the 8000 rule first would
    tell a major >= 7 device on a pre-8000 build that 8000 is enough, even
    though the 9000 rule also applies to it.

    Returns:
        None. The only effect is zero or more ``warnings.warn`` calls.
    """
    # Same text as the original triple-quoted template (leading and trailing
    # newline included), written with explicit "\n" so it can be indented.
    error_str = (
        "\n"
        "Found GPU%d %s which requires CUDA_VERSION >= %d for\n"
        "optimal performance and fast startup time, but your PyTorch was compiled\n"
        "with CUDA_VERSION %d. Please install the correct PyTorch binary\n"
        "using instructions from http://pytorch.org\n"
    )
    compiled_version = torch._C._cuda_getCompiledVersion()

    # (minimum major capability, required CUDA_VERSION), strictest first so
    # the first matching row is the binding requirement for the device.
    # NOTE(review): 8000/9000 presumably encode CUDA 8.0/9.0 -- confirm.
    requirements = ((7, 9000), (6, 8000))

    for d in range(device_count()):
        major = get_device_capability(d)[0]
        for min_major, required_version in requirements:
            if major < min_major:
                continue
            if compiled_version < required_version:
                # The device name is only needed for the message, so it is
                # looked up only when a warning is actually emitted.
                warnings.warn(
                    error_str
                    % (d, get_device_name(d), required_version, compiled_version)
                )
            # Weaker rows are implied by this one; stop after the first match.
            break
评论列表
文章目录