| 106 | |
| 107 | |
def select_device(device='', batch_size=0, newline=True):
    """Choose the torch device to run on and log an environment summary line.

    Args:
        device: Requested device. Accepts None, '' (auto), 'cpu', 'mps', a CUDA
            index such as 0 / '0' / 'cuda:0', or a comma-separated CUDA list such
            as '0,1,2,3'. Converted to a lower-cased, stripped string first.
        batch_size: When > 0 and more than one CUDA device is requested, it must
            be a multiple of the number of requested devices.
        newline: When False, trailing whitespace (including the final newline)
            is stripped from the logged summary.

    Returns:
        torch.device: 'cuda:0' when CUDA is used, 'mps' when MPS was requested
        and is available, otherwise 'cpu'.

    Raises:
        AssertionError: If an explicit CUDA device list is requested but CUDA is
            unavailable or fewer devices are visible than requested, or if
            batch_size is not divisible by the requested CUDA device count.

    Side effects:
        Sets os.environ['CUDA_VISIBLE_DEVICES'] to '-1' for 'cpu'/'mps', or to
        the requested device string for an explicit CUDA request. Emits the
        summary via LOGGER.info.
    """
    s = f'YOLOv5 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
    device = str(device).strip().lower().replace('cuda:', '').replace('none', '')  # to string, 'cuda:0' to '0'
    cpu = device == 'cpu'
    mps = device == 'mps'  # Apple Metal Performance Shaders (MPS)
    if cpu or mps:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable - must be before assert is_available()
        # Count requested devices by splitting on ',' rather than counting characters,
        # so multi-digit IDs such as '10' or '0,12' are not over-counted.
        assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(',')), \
            f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"

    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
        devices = device.split(',') if device else ['0']  # i.e. 0,1,6,7; default to the first GPU
        n = len(devices)  # device count
        if n > 1 and batch_size > 0:  # check batch_size is divisible by device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * (len(s) + 1)  # indent for the 2nd and later device lines
        for i, d in enumerate(devices):
            # Index by position i, not by the requested ID d: once CUDA_VISIBLE_DEVICES
            # is set, the visible devices are re-numbered starting from 0.
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"  # bytes to MB
        arg = 'cuda:0'  # first visible device
    elif mps and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():  # prefer MPS if available
        # Check for the backend module itself rather than the deprecated `torch.has_mps` flag.
        s += 'MPS\n'
        arg = 'mps'
    else:  # revert to CPU
        s += 'CPU\n'
        arg = 'cpu'

    if not newline:
        s = s.rstrip()
    LOGGER.info(s)
    return torch.device(arg)
| 142 | |
| 143 | |
| 144 | def time_sync(): |