Binds the calling distributed process to a CUDA device. Args: device_ordinal (int, optional): the device id to bind to; if omitted, the global rank modulo the number of visible CUDA devices is used
(self, device_ordinal: int = None)
| 488 | self._groups.clear() |
| 489 | |
def set_device(self, device_ordinal: int = None):
    """Bind the calling process to a CUDA device.

    Args:
        device_ordinal (int, optional): the device id to bind to. When
            omitted, it is derived as the global rank modulo the number
            of CUDA devices reported by ``torch.cuda.device_count()``.

    Raises:
        RuntimeError: if ``device_ordinal`` is omitted and no CUDA device
            is visible, so no ordinal can be derived.
    """
    global_rank = self.get_global_rank()
    if device_ordinal is None:
        devices_per_node = torch.cuda.device_count()
        if devices_per_node == 0:
            # Fail with an actionable message instead of the bare
            # ZeroDivisionError the modulo below would otherwise raise.
            raise RuntimeError(
                f"cannot bind process rank {global_rank} to a device: no CUDA devices are visible"
            )
        device_ordinal = global_rank % devices_per_node

    torch.cuda.set_device(device_ordinal)
    logger.info(f"process rank {global_rank} is bound to host:{socket.gethostname()} device: {device_ordinal}")
| 503 | |
| 504 | def set_seed(self, seed: int, dpseed_with_tpoffset: bool = False): |
| 505 | """Sets seeds for all random libraries. |
no test coverage detected