r"""DeepSpeed engine for training and inference."""
| 28 | |
| 29 | |
| 30 | class DeepSpeedHybridEngine(DeepSpeedEngine): |
| 31 | r"""DeepSpeed engine for training and inference.""" |
| 32 | inference_mp_group = None |
| 33 | |
def __init__(self, args, model, **kwargs):
    """Initialise the parent engine, then set up the hybrid-engine state.

    Args:
        args: Passed through unchanged to the parent constructor.
        model: Passed through unchanged to the parent constructor.
        **kwargs: Extra keyword arguments passed to the parent constructor.
    """
    super().__init__(args, model, **kwargs)

    # Synchronise the RNG state across all GPUs: move the local state to the
    # current device, broadcast it (second argument 0, presumably the source
    # rank -- confirm against the comm API), then install a CPU copy of it.
    local_state = get_accelerator().get_rng_state()
    device_name = get_accelerator().current_device_name()
    shared_state = local_state.to(device_name)
    dist.broadcast(shared_state, 0)
    get_accelerator().set_rng_state(shared_state.cpu())

    # Flags read from the engine configuration.
    zero_stage = self._config.zero_config.stage
    self.Z3_enabled = (zero_stage == 3)
    self.gather_all_layers = self._config.hybrid_engine.pin_parameters

    # Inference containers / forwards. These lists are created before
    # create_inference_module() is invoked; keep that order.
    self._inference_containers = []
    self._orig_modules = []
    self._orig_fwds = []
    self.create_inference_module()

    # Performance stats: markers and sizes start unset, accumulators at zero.
    self._t_start = self._training_start_time = self._total_batch_size = None
    self._total_latency = self._iters = 0
    self._generate_latency = self._training_latency = self._gather_latency = 0

    self.is_lora_fused = False
    self.workspace = WorkspaceOp()
| 64 | |
def convert_to_linear_transposed(self, model):
    """Replace selected ``torch.nn.Linear`` children of ``model`` with ``TLinear``.

    The module tree is walked recursively through ``named_children()``. A
    child is replaced, in place via ``setattr``, when its class is
    ``torch.nn.Linear`` itself and the type of its grandparent or
    great-grandparent module is ``torch.nn.ModuleList``. Every other child
    is descended into instead.

    Args:
        model: Root module of the tree to rewrite.

    Returns:
        None. ``model`` is modified in place.
    """
    module_list_type = torch.nn.ModuleList

    def _swap_linears(module, parent_type=None, prev_type=None):
        # parent_type is the type of the module that contains `module`;
        # prev_type is the type one level further up. Neither changes while
        # iterating over the children, so the test is evaluated once.
        ancestor_is_list = parent_type is module_list_type or prev_type is module_list_type
        for attr_name, submodule in module.named_children():
            if ancestor_is_list and submodule.__class__ in (torch.nn.Linear,):
                setattr(module, attr_name, TLinear(submodule, attr_name))
                continue
            _swap_linears(submodule, type(module), prev_type=parent_type)
        return module

    _swap_linears(model)
| 77 | |
| 78 | def new_inference_container(self, orig_layer, policy_cls, layer_id): |
| 79 | policy = policy_cls(orig_layer, inference=True) |
| 80 | |
| 81 | if self._config.float16_config.enabled: |
| 82 | inference_dtype = torch.float16 |
| 83 | elif self._config.bfloat16_config.enabled: |
| 84 | inference_dtype = torch.bfloat16 |
| 85 | else: |
| 86 | inference_dtype = torch.float32 |
| 87 |
no outgoing calls
no test coverage detected
searching dependent graphs…