MCPcopy
hub / github.com/deepspeedai/DeepSpeed / DeepSpeedHybridEngine

Class DeepSpeedHybridEngine

deepspeed/runtime/hybrid_engine.py:30–445  ·  view source on GitHub ↗

r"""DeepSpeed engine for training and inference.

Source from the content-addressed store, hash-verified

28
29
30class DeepSpeedHybridEngine(DeepSpeedEngine):
31 r"""DeepSpeed engine for training and inference."""
32 inference_mp_group = None
33
def __init__(self, args, model, **kwargs):
    """Build the underlying DeepSpeed training engine, then set up hybrid state.

    Steps, in order:
      1. Run ``DeepSpeedEngine.__init__`` with the given arguments.
      2. Broadcast the accelerator RNG state from rank 0 so every rank shares
         the same seed.
      3. Record the ZeRO-3 flag and the ``pin_parameters`` setting.
      4. Create the inference containers via ``create_inference_module()``.
      5. Reset timing counters, the LoRA-fusion flag and the ``WorkspaceOp``.
    """
    super().__init__(args, model, **kwargs)

    # Keep the seed identical on all GPUs: read the local RNG state, move it to
    # the current device for the collective, replace it with rank 0's copy,
    # then install that copy back from the CPU.
    rng_state = get_accelerator().get_rng_state()
    rng_state = rng_state.to(get_accelerator().current_device_name())
    dist.broadcast(rng_state, 0)
    get_accelerator().set_rng_state(rng_state.cpu())

    self.Z3_enabled = self._config.zero_config.stage == 3
    # Mirrors the hybrid_engine.pin_parameters config option.
    self.gather_all_layers = self._config.hybrid_engine.pin_parameters

    # Inference containers plus the original modules / forward functions.
    # NOTE(review): presumably populated by create_inference_module() — verify.
    self._inference_containers, self._orig_modules, self._orig_fwds = [], [], []
    self.create_inference_module()

    # Performance statistics: None = not yet recorded, 0 = accumulator.
    self._t_start = None
    self._total_latency = 0
    self._iters = 0
    self._training_start_time = None
    self._generate_latency = 0
    self._training_latency = 0
    self._total_batch_size = None
    self._gather_latency = 0

    self.is_lora_fused = False
    self.workspace = WorkspaceOp()
64
def convert_to_linear_transposed(self, model):
    """Replace qualifying ``torch.nn.Linear`` submodules of ``model`` with ``TLinear``.

    The traversal mutates ``model`` in place and returns ``None``.

    A child is wrapped as ``TLinear(child, name)`` only when both hold:
      * its class is exactly ``torch.nn.Linear`` (subclasses are not matched);
      * the module that owns it has a ``torch.nn.ModuleList`` as its parent or
        grandparent in the traversal.

    Every child that is not replaced is descended into recursively.
    """

    def _walk(module, parent_cls=None, grandparent_cls=None):
        # parent_cls / grandparent_cls are the types of ``module``'s parent and
        # grandparent (None above the root).
        under_module_list = torch.nn.ModuleList in (parent_cls, grandparent_cls)
        for child_name, child in module.named_children():
            if child.__class__ is torch.nn.Linear and under_module_list:
                setattr(module, child_name, TLinear(child, child_name))
            else:
                _walk(child, type(module), parent_cls)

    _walk(model)
77
78 def new_inference_container(self, orig_layer, policy_cls, layer_id):
79 policy = policy_cls(orig_layer, inference=True)
80
81 if self._config.float16_config.enabled:
82 inference_dtype = torch.float16
83 elif self._config.bfloat16_config.enabled:
84 inference_dtype = torch.bfloat16
85 else:
86 inference_dtype = torch.float32
87

Callers 1

initialize · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…