def get_num_gpu():
    """Get the number of available GPUs.

    If the ``CUDA_VISIBLE_DEVICES`` environment variable is set, the count is
    derived from it without importing TensorFlow. Otherwise TensorFlow is
    asked for its local devices and those with ``device_type == 'GPU'`` are
    counted.

    Returns:
        an `int`, the number of GPUs listed in CUDA_VISIBLE_DEVICES, or the
        number of GPU devices TensorFlow reports on this system.
    """
    env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    if env is not None:
        count = 0
        for entry in env.split(','):
            entry = entry.strip()
            # An empty value (``CUDA_VISIBLE_DEVICES=``) or a negative index
            # (the common ``-1`` idiom) hides GPUs rather than naming one.
            # CUDA ignores every entry after the first invalid one, so stop
            # counting there. UUID entries ("GPU-..."/"MIG-...") do not start
            # with '-' and are counted normally.
            if not entry or entry.startswith('-'):
                break
            count += 1
        return count

    # Imported lazily so that the environment-variable path above never pays
    # the cost of loading TensorFlow.
    from tensorflow.python.client import device_lib
    device_protos = device_lib.list_local_devices()
    return sum(1 for dev in device_protos if dev.device_type == 'GPU')
# Comment list
# Table of contents