Repository: github.com/InternLM/InternLM

Class PipelineSharedModuleGradientHandler

Defined in internlm/core/gradient_handler.py, lines 35–76.

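A helper class that handles all-reduce operations in sub pipeline-parallel groups. When pipeline parallelism is enabled (pipeline_parallel_size > 1), :func:`handle_gradient` runs an all-reduce (sum) within each sub pipeline-parallel group, i.e. the process group stored on a shared parameter as its pipeline_shared_module_pg attribute. For better performance, it buckets the gradients of all parameters that have the same tensor type, so each bucket needs only one collective call.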
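The packing step groups eligible parameters first by process group and then by tensor type. The following is a minimal pure-Python sketch of that two-level bucketing, assuming a hypothetical FakeParam stand-in: its string fields replace real torch parameters, `param.data.type()` strings, and process-group handles, and the group name "pg_shared" is illustrative only.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeParam:
    """Hypothetical stand-in for a torch parameter (illustration only)."""
    name: str
    dtype: str                   # plays the role of param.data.type()
    requires_grad: bool = True
    has_grad: bool = True        # plays the role of `param.grad is not None`
    pipeline_shared_module_pg: Optional[str] = None  # stand-in for a process group


def pack_buckets(params):
    """Bucket eligible parameters as buckets[group][dtype] -> [names]."""
    buckets = defaultdict(lambda: defaultdict(list))
    for p in params:
        group = p.pipeline_shared_module_pg
        # Same filter as the handler: trainable, tagged with a group, has a grad.
        if p.requires_grad and group is not None and p.has_grad:
            buckets[group][p.dtype].append(p.name)
    return buckets


params = [
    FakeParam("shared_a.weight", "torch.cuda.HalfTensor", pipeline_shared_module_pg="pg_shared"),
    FakeParam("shared_b.weight", "torch.cuda.HalfTensor", pipeline_shared_module_pg="pg_shared"),
    FakeParam("shared_c.bias", "torch.cuda.FloatTensor", pipeline_shared_module_pg="pg_shared"),
    FakeParam("local.weight", "torch.cuda.HalfTensor"),  # no shared group: skipped
    FakeParam("frozen.weight", "torch.cuda.HalfTensor", requires_grad=False,
              pipeline_shared_module_pg="pg_shared"),    # frozen: skipped
]

buckets = pack_buckets(params)
for group, by_dtype in buckets.items():
    for dtype, names in by_dtype.items():
        print(group, dtype, names)
```

Keying the inner dictionary by tensor type matters because the flattened buffer that follows must hold a single dtype; mixing half and float gradients in one buffer would not work.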


```python
# Module-level dependencies of this excerpt (imported at the top of
# internlm/core/gradient_handler.py; BaseGradientHandler is defined earlier in that file).
from collections import defaultdict

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from internlm.core.context import global_context as gpc


class PipelineSharedModuleGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in sub parallel groups.
    An all-reduce collective communication will be operated in
    :func:`handle_gradient` among all sub pipeline parallel groups.
    For better performance, it bucketizes the gradients of all parameters that are
    the same type to improve the efficiency of communication.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def handle_gradient(self):
        """A method running an all-reduce operation in sub pipeline parallel groups."""
        if gpc.pipeline_parallel_size > 1:
            # bucketize and all-reduce
            buckets = defaultdict(lambda: defaultdict(list))
            # Pack the buckets: group -> tensor type -> params. Only trainable
            # parameters tagged with a shared-module process group and holding
            # a gradient are included.
            for param in self._model.parameters():
                group = getattr(param, "pipeline_shared_module_pg", None)
                if (
                    param.requires_grad
                    and group is not None
                    and (
                        (hasattr(param, "colo_attr") and not param.colo_attr.saved_grad.is_null())
                        or param.grad is not None
                    )
                ):
                    tp = param.data.type()
                    buckets[group][tp].append(param)

            # For each bucket, all-reduce and copy all-reduced grads.
            for group, group_buckets in buckets.items():
                for tp, bucket in group_buckets.items():
                    grads = [
                        param.colo_attr.grad_payload if hasattr(param, "colo_attr") else param.grad.data
                        for param in bucket
                    ]
                    # One contiguous buffer per bucket, so a single collective covers it.
                    coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device())
                    dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
                    # Write the summed values back into each original gradient tensor.
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)
```
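The flatten, all-reduce, unflatten, copy-back sequence can be simulated without GPUs or a process group. The sketch below is a pure-Python illustration under stated assumptions: flatten, unflatten, and simulated_all_reduce_sum are hypothetical stand-ins for `_flatten_dense_tensors`, `_unflatten_dense_tensors`, and `dist.all_reduce(..., op=dist.ReduceOp.SUM)`, and Python lists play the role of gradient tensors on two ranks that share parameters.

```python
from typing import List


def flatten(tensors: List[List[float]]) -> List[float]:
    """Hypothetical stand-in: concatenate 1-D 'tensors' into one buffer."""
    return [x for t in tensors for x in t]


def unflatten(flat: List[float], like: List[List[float]]) -> List[List[float]]:
    """Hypothetical stand-in: split a flat buffer into pieces sized like `like`."""
    out, offset = [], 0
    for t in like:
        out.append(flat[offset:offset + len(t)])
        offset += len(t)
    return out


def simulated_all_reduce_sum(buffers: List[List[float]]) -> List[float]:
    """Element-wise sum across ranks; every rank ends up with the same result."""
    return [sum(vals) for vals in zip(*buffers)]


# Gradients of the same shared parameters held by two ranks in one group.
rank_grads = [
    [[1.0, 2.0], [3.0]],     # rank 0: param A (2 elements), param B (1 element)
    [[10.0, 20.0], [30.0]],  # rank 1
]

# One flat buffer per rank, so one collective per bucket instead of per parameter.
coalesced = [flatten(g) for g in rank_grads]
reduced = simulated_all_reduce_sum(coalesced)

# Copy the reduced values back into each rank's per-parameter buffers in place,
# mirroring buf.copy_(synced) in the handler.
for grads in rank_grads:
    for buf, synced in zip(grads, unflatten(reduced, grads)):
        buf[:] = synced

print(rank_grads[0])  # [[11.0, 22.0], [33.0]]
```

The in-place copy is the important design point: the optimizer keeps references to the original gradient tensors, so the synchronized values must be written into those buffers rather than into new objects.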

Callers (1)

initialize_trainer (function)

Calls

no outgoing calls

Tested by

no test coverage detected