Returns: int: the number of available GPUs listed in CUDA_VISIBLE_DEVICES, or otherwise present in the system.
()
| 27 | |
| 28 | |
def get_num_gpu():
    """
    Count the GPUs available to this process.

    Sources are consulted in the following order; the first one that
    produces an answer wins:

    1. A non-empty ``CUDA_VISIBLE_DEVICES``: the number of comma-separated
       entries. When NVML reported a device count, the entries must not
       outnumber it.
    2. ``nvidia-smi -L``: the number of non-blank lines it prints.
    3. NVML: the device count reported through ``NVMLContext``.
    4. TensorFlow: the number of GPU devices it lists.

    Returns:
        int: #available GPUs in CUDA_VISIBLE_DEVICES, or in the system.

    Raises:
        AssertionError: if CUDA_VISIBLE_DEVICES lists more entries than the
            number of devices NVML reported.
    """

    def warn_return(ret, message):
        # Return `ret` unchanged; additionally warn when TensorFlow is
        # importable but was built without CUDA and `ret` is positive.
        try:
            import tensorflow as tf
        except ImportError:
            return ret

        built_with_cuda = tf.test.is_built_with_cuda()
        if not built_with_cuda and ret > 0:
            logger.warn(message + "But TensorFlow was not built with CUDA support and could not use GPUs!")
        return ret

    try:
        # Use NVML to query device properties
        with NVMLContext() as ctx:
            nvml_num_dev = ctx.num_devices()
    except Exception:
        # Best effort: any NVML failure just means "count unknown".
        nvml_num_dev = None

    env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    if env:
        num_dev = len(env.split(','))
        # Only cross-check against NVML when it actually produced a count;
        # comparing an int with None raises TypeError on Python 3.
        if nvml_num_dev is not None:
            assert num_dev <= nvml_num_dev, \
                "Only {} GPU(s) available, but CUDA_VISIBLE_DEVICES is set to {}".format(nvml_num_dev, env)
        return warn_return(num_dev, "Found non-empty CUDA_VISIBLE_DEVICES. ")

    output, code = subproc_call("nvidia-smi -L", timeout=5)
    if code == 0:
        output = output.decode('utf-8')
        # Count non-blank lines so that empty output yields 0, not 1
        # (''.split('\n') is [''], whose length is 1).
        num_dev = sum(1 for line in output.splitlines() if line.strip())
        return warn_return(num_dev, "Found nvidia-smi. ")

    if nvml_num_dev is not None:
        return warn_return(nvml_num_dev, "NVML found nvidia devices. ")

    # Fallback to TF
    logger.info("Loading local devices by TensorFlow ...")

    try:
        import tensorflow as tf
        # available since TF 1.14
        gpu_devices = tf.config.experimental.list_physical_devices('GPU')
    except AttributeError:
        from tensorflow.python.client import device_lib
        local_device_protos = device_lib.list_local_devices()
        # Note this will initialize all GPUs and therefore has side effect
        # https://github.com/tensorflow/tensorflow/issues/8136
        gpu_devices = [x.name for x in local_device_protos if x.device_type == 'GPU']
    return len(gpu_devices)


# `get_nr_gpu` is a second name bound to the same function object.
get_nr_gpu = get_num_gpu
no test coverage detected