| 40 | |
| 41 | |
def select_device(device='', batch_size=None):
    """Select the torch device to run on and log the choice.

    Args:
        device: Device request. ``'cpu'`` (case-insensitive) forces CPU.
            A CUDA index string such as ``'0'`` or ``'0,1,2,3'`` is written
            to ``CUDA_VISIBLE_DEVICES``. ``''`` (the default) uses CUDA when
            it is available, else CPU. Integers (e.g. ``0``) and surrounding
            whitespace are also accepted, and ``None`` is treated like ``''``.
        batch_size: Optional batch size. When more than one CUDA device is
            visible, it must be a multiple of the device count.

    Returns:
        ``torch.device('cuda:0')`` when CUDA is used, otherwise
        ``torch.device('cpu')``.

    Raises:
        AssertionError: if a non-CPU device was requested but CUDA is
            unavailable, or if ``batch_size`` is not a multiple of the
            visible CUDA device count.
    """
    # Normalise the request so ints (0) and padded strings (' cpu ') work;
    # previously a non-str argument crashed on .lower().
    device = '' if device is None else str(device).strip()
    cpu_request = device.lower() == 'cpu'
    if device and not cpu_request:  # a specific CUDA device set was requested
        # NOTE: presumably only effective if CUDA has not been initialised yet
        # in this process — confirm against CUDA docs for CUDA_VISIBLE_DEVICES.
        os.environ['CUDA_VISIBLE_DEVICES'] = device
        assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device  # check availability

    cuda = False if cpu_request else torch.cuda.is_available()
    if cuda:
        c = 1024 ** 2  # bytes to MB
        ng = torch.cuda.device_count()
        if ng > 1 and batch_size:  # check that batch_size is compatible with device_count
            assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
        s = f'Using torch {torch.__version__} '
        for i in range(ng):
            props = torch.cuda.get_device_properties(i)
            if i == 1:
                # Subsequent GPU lines are indented to align under the first one.
                s = ' ' * len(s)
            logger.info("%sCUDA:%g (%s, %dMB)" % (s, i, props.name, props.total_memory / c))
    else:
        logger.info(f'Using torch {torch.__version__} CPU')

    logger.info('')  # skip a line
    return torch.device('cuda:0' if cuda else 'cpu')
| 66 | |
| 67 | |
| 68 | def time_synchronized(): |