(
self,
model: nn.Module,
output_to_fp32: bool = True,
parallel_mode: ParallelMode = ParallelMode.DATA,
sync_buffer: bool = True,
dtype=torch.float16,
)
| 29 | """ |
| 30 | |
def __init__(
    self,
    model: nn.Module,
    output_to_fp32: bool = True,
    parallel_mode: ParallelMode = ParallelMode.DATA,
    sync_buffer: bool = True,
    dtype=torch.float16,
):
    """Wrap ``model`` after converting it with ``model.to(dtype)``.

    Args:
        model: Module to wrap; stored on ``self.model`` after the
            ``.to(dtype)`` conversion.
        output_to_fp32: Stored as ``self._output_to_fp32``.
            NOTE(review): presumably controls casting outputs back to
            fp32 in the forward pass -- confirm where it is read.
        parallel_mode: Mode used to query ``gpc`` for the process group
            and the world size.
        sync_buffer: Requested buffer-sync flag. It is kept only when
            ``parallel_mode`` is initialized with a world size above 1;
            otherwise ``self._sync_buf`` is forced to ``False``.
        dtype: Target dtype passed to ``model.to`` and kept on
            ``self.dtype``.
    """
    super().__init__()

    self.dtype = dtype
    self.model = model.to(dtype)
    self._output_to_fp32 = output_to_fp32

    runs_in_parallel = (
        gpc.is_initialized(parallel_mode)
        and gpc.get_world_size(parallel_mode) > 1
    )
    if not runs_in_parallel:
        # Single-process fallback: no group to use, so syncing is disabled
        # regardless of what the caller asked for.
        self._process_group = None
        self._world_size = 1
        self._sync_buf = False
    else:
        self._process_group = gpc.get_group(parallel_mode)
        self._world_size = gpc.get_world_size(parallel_mode)
        self._sync_buf = sync_buffer

    self._first_eval_run = False
| 53 | |
| 54 | @property |
| 55 | def sync_buffer(self): |
nothing calls this directly
no test coverage detected