Get the current device. When CUDA is available, return the local process index (an int) to enable multi-GPU training; otherwise return the string "cpu".
()
| 31 | |
| 32 | |
| 33 | def get_current_device() -> int: |
| 34 | """Get the current device. For GPU we return the local process index to enable multiple GPU training.""" |
| 35 | return Accelerator().local_process_index if torch.cuda.is_available() else "cpu" |
| 36 | |
| 37 | |
| 38 | def get_kbit_device_map() -> Dict[str, int] | None: |
no outgoing calls
no test coverage detected