r"""DeepSpeed engine for training.
| 206 | |
| 207 | |
| 208 | class DeepSpeedEngine(Module): |
| 209 | r"""DeepSpeed engine for training.""" |
| 210 | |
| 211 | def __init__(self, |
| 212 | args, |
| 213 | model, |
| 214 | optimizer=None, |
| 215 | model_parameters=None, |
| 216 | training_data=None, |
| 217 | lr_scheduler=None, |
| 218 | mpu=None, |
| 219 | dist_init_required=None, |
| 220 | collate_fn=None, |
| 221 | config=None, |
| 222 | config_class=None, |
| 223 | mesh_device=None, |
| 224 | dont_change_device=False): |
| 225 | super(DeepSpeedEngine, self).__init__() |
| 226 | self.dont_change_device = dont_change_device |
| 227 | self.client_optimizer = optimizer |
| 228 | self.client_lr_scheduler = lr_scheduler |
| 229 | self.training_data = training_data |
| 230 | self.collate_fn = collate_fn |
| 231 | self.mpu = mpu |
| 232 | self.all_to_all_group = None |
| 233 | self.data_parallel_group = None |
| 234 | self.global_steps = 0 |
| 235 | self.global_samples = 0 |
| 236 | self.micro_steps = 0 |
| 237 | self.skipped_steps = 0 |
| 238 | self.gradient_average = True |
| 239 | self.warn_unscaled_loss = True |
| 240 | self.config = config |
| 241 | self._config = config_class |
| 242 | self.loaded_checkpoint_mp_world_size = None |
| 243 | self.loaded_checkpoint_dp_world_size = None |
| 244 | self.enable_backward_allreduce = True |
| 245 | self.inside_no_sync_ctxt = False |
| 246 | self.progressive_layer_drop = None |
| 247 | self.eigenvalue = None |
| 248 | self.block_eigenvalue = None |
| 249 | self.gas_boundary_ctr = 0 |
| 250 | self.dist_backend = get_accelerator().communication_backend_name() |
| 251 | self.has_moe_layers = False |
| 252 | self.num_experts = [] |
| 253 | self.gate_modules = [] |
| 254 | self.moe_layers = [] |
| 255 | self._step_applied = False |
| 256 | self._global_grad_norm = None |
| 257 | self.use_ds_comm = False # False --> Use torch.dist, True --> Use ds.comm backend. |
| 258 | self.checkpoint_engine = None |
| 259 | self.optimizer = None |
| 260 | self.basic_optimizer = None |
| 261 | self.lr_scheduler = None |
| 262 | |
| 263 | self._is_gradient_accumulation_boundary = None |
| 264 | self.scale_wrt_gas = None |
| 265 | self.losses = None |
no outgoing calls
no test coverage detected
searching dependent graphs…