MCPcopy
hub / github.com/ultralytics/yolov5 / select_device

Function select_device

utils/torch_utils.py:108–141  ·  view source on GitHub ↗
(device='', batch_size=0, newline=True)

Source from the content-addressed store, hash-verified

106
107
def select_device(device='', batch_size=0, newline=True):
    """Select the torch device to run on and log a one-line environment summary.

    Args:
        device: Requested device. Accepts None or '' (auto: CUDA GPU 0 if CUDA is available,
            otherwise CPU), 'cpu', 'mps', a CUDA index such as 0 or '0', a 'cuda:'-prefixed
            index, or a comma-separated list of indices such as '0,1,2,3'. Case and all
            whitespace are ignored, so ' 0, 1 ' is the same as '0,1'.
        batch_size: When > 0 and more than one CUDA device is used, must be a multiple of the
            number of devices.
        newline: If False, trailing whitespace (including the final newline) is stripped from
            the logged summary.

    Returns:
        torch.device: 'cuda:0' when CUDA is used, 'mps' when MPS was requested and is
        available, otherwise 'cpu'.

    Raises:
        AssertionError: if the requested CUDA device(s) are unavailable, or if batch_size is
            not a multiple of the GPU count.

    Side effects:
        Sets the CUDA_VISIBLE_DEVICES environment variable ('-1' for 'cpu'/'mps', the
        normalized device list otherwise) whenever a device is given.
    """
    s = f'YOLOv5 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
    # Normalize to a bare string with no whitespace: 'cuda:0' -> '0', None -> '', ' 0, 1 ' -> '0,1'
    device = ''.join(str(device).lower().split()).replace('cuda:', '').replace('none', '')
    cpu = device == 'cpu'
    mps = device == 'mps'  # Apple Metal Performance Shaders (MPS)
    # Requested CUDA indices as a list of entries, e.g. ['0', '1', '6', '7']; empty entries dropped
    cuda_ids = [] if cpu or mps else [d for d in device.split(',') if d]

    if cpu or mps:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable - must be before assert is_available()
        # Compare against the number of requested entries, not characters ('10' is one GPU)
        assert torch.cuda.is_available() and torch.cuda.device_count() >= len(cuda_ids), \
            f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"

    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
        devices = cuda_ids or ['0']  # default to the first visible GPU
        n = len(devices)  # device count
        if n > 1 and batch_size > 0:  # check batch_size is divisible by device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * (len(s) + 1)  # align continuation lines under the first device entry
        for i, d in enumerate(devices):
            # Index i (not d): the visible devices are enumerated from 0 after CUDA_VISIBLE_DEVICES is applied
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"  # bytes to MB
        arg = 'cuda:0'
    elif mps and getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
        # prefer MPS if available; guard on torch.backends.mps instead of the deprecated torch.has_mps
        s += 'MPS\n'
        arg = 'mps'
    else:  # revert to CPU
        s += 'CPU\n'
        arg = 'cpu'

    if not newline:
        s = s.rstrip()
    LOGGER.info(s)
    return torch.device(arg)
142
143
144def time_sync():

Callers 15

runFunction · 0.90
mainFunction · 0.90
runFunction · 0.90
testFunction · 0.90
runFunction · 0.90
runFunction · 0.90
_createFunction · 0.90
notebook_initFunction · 0.90
sweepFunction · 0.90
runFunction · 0.90
runFunction · 0.90
runFunction · 0.90

Calls 3

git_describeFunction · 0.90
file_dateFunction · 0.90
infoMethod · 0.80

Tested by

no test coverage detected