| 41 | # TODO: update notes/cuda.rst when this class handles 8+ GPUs well |
| 42 | |
def __init__(self, module, device_ids=None, output_device=None, dim=0, chunk_sizes=None):
    """Store the wrapped module and the device configuration.

    Args:
        module: Object to wrap; stored as ``self.module``.
        device_ids: CUDA device indices to use. When ``None`` and CUDA is
            available, defaults to every index below
            ``torch.cuda.device_count()``. Ignored (stored as ``[]``) when
            CUDA is unavailable.
        output_device: Stored as ``self.output_device``. When ``None`` and
            CUDA is available, defaults to ``device_ids[0]``.
        dim: Stored as ``self.dim``.
        chunk_sizes: Stored as ``self.chunk_sizes``.
    """
    super(DataParallel, self).__init__()

    # Set these before the CPU-only early return so the attributes exist on
    # every instance, not only on CUDA-capable hosts.
    self.dim = dim
    self.chunk_sizes = chunk_sizes

    if not torch.cuda.is_available():
        # No CUDA: keep the module as-is; an empty device list marks
        # the instance as having no devices to use.
        self.module = module
        self.device_ids = []
        self.output_device = output_device
        return

    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    if output_device is None:
        output_device = device_ids[0]
    self.module = module
    self.device_ids = device_ids
    self.output_device = output_device
    if len(self.device_ids) == 1:
        # Single device: move the module onto it right away.
        self.module.cuda(device_ids[0])
| 62 | |
| 63 | def forward(self, *inputs, **kwargs): |
| 64 | if not self.device_ids: |