MCPcopy
hub / github.com/deepspeedai/DeepSpeed / DeepSpeedEngine

Class DeepSpeedEngine

deepspeed/runtime/engine.py:208–5006  ·  view source on GitHub ↗

DeepSpeed engine for training.

Source from the content-addressed store, hash-verified

206
207
208class DeepSpeedEngine(Module):
209 r"""DeepSpeed engine for training."""
210
211 def __init__(self,
212 args,
213 model,
214 optimizer=None,
215 model_parameters=None,
216 training_data=None,
217 lr_scheduler=None,
218 mpu=None,
219 dist_init_required=None,
220 collate_fn=None,
221 config=None,
222 config_class=None,
223 mesh_device=None,
224 dont_change_device=False):
225 super(DeepSpeedEngine, self).__init__()
226 self.dont_change_device = dont_change_device
227 self.client_optimizer = optimizer
228 self.client_lr_scheduler = lr_scheduler
229 self.training_data = training_data
230 self.collate_fn = collate_fn
231 self.mpu = mpu
232 self.all_to_all_group = None
233 self.data_parallel_group = None
234 self.global_steps = 0
235 self.global_samples = 0
236 self.micro_steps = 0
237 self.skipped_steps = 0
238 self.gradient_average = True
239 self.warn_unscaled_loss = True
240 self.config = config
241 self._config = config_class
242 self.loaded_checkpoint_mp_world_size = None
243 self.loaded_checkpoint_dp_world_size = None
244 self.enable_backward_allreduce = True
245 self.inside_no_sync_ctxt = False
246 self.progressive_layer_drop = None
247 self.eigenvalue = None
248 self.block_eigenvalue = None
249 self.gas_boundary_ctr = 0
250 self.dist_backend = get_accelerator().communication_backend_name()
251 self.has_moe_layers = False
252 self.num_experts = []
253 self.gate_modules = []
254 self.moe_layers = []
255 self._step_applied = False
256 self._global_grad_norm = None
257 self.use_ds_comm = False # False --> Use torch.dist, True --> Use ds.comm backend.
258 self.checkpoint_engine = None
259 self.optimizer = None
260 self.basic_optimizer = None
261 self.lr_scheduler = None
262
263 self._is_gradient_accumulation_boundary = None
264 self.scale_wrt_gas = None
265 self.losses = None

Callers 1

initialize · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…