()
| 81 | |
| 82 | |
def _get_gpu_ids():
    """Return the integer ids of the GPUs this process may use.

    Lookup order:

    1. If ``CUDA_VISIBLE_DEVICES`` is set, its comma-separated entries are
       parsed as integers. An empty value yields ``[]`` and blank entries
       (for example a trailing comma) are skipped.
    2. Otherwise, on Linux and Darwin, ``/dev`` is scanned for ``nvidia<N>``
       device nodes and the ids are returned in ascending order.
    3. On any other platform a hint is printed and ``[]`` is returned.

    Returns:
        A list of ints (possibly empty).

    Raises:
        ValueError: if ``CUDA_VISIBLE_DEVICES`` contains a non-blank entry
            that is not an integer (for example a ``GPU-<uuid>`` identifier).
    """
    visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    if visible is not None:
        # int('') raises ValueError, so blank entries must be dropped before
        # conversion; this also makes an empty variable mean "no GPUs".
        return [int(token) for token in visible.split(',') if token.strip()]

    if platform.system() in ('Darwin', 'Linux'):
        # Raw string: '\d' in a normal literal is an invalid escape sequence.
        device_node = re.compile(r'^nvidia(\d+)$')
        matches = (device_node.match(entry) for entry in os.listdir('/dev'))
        # os.listdir() returns entries in arbitrary order; sort so the
        # resulting id list is deterministic across runs.
        return sorted(int(m.group(1)) for m in matches if m)

    print('Please set CUDA_VISIBLE_DEVICES (see http://acceleware.com/blog/cudavisibledevices-masking-gpus)')
    return []
| 91 | |
| 92 | |
# GPU ids are detected once, when this module is imported, and cached here.
GPU_IDS = _get_gpu_ids()