MCPcopy Index your code
hub / github.com/zai-org/CodeGeeX / __init__

Method __init__

codegeex/megatron/model/distributed.py:96–160  ·  view source on GitHub ↗
(
            self, module, accumulate_allreduce_grads_in_fp32, use_contiguous_buffers
    )

Source from the content-addressed store, hash-verified

94 """
95
96 def __init__(
97 self, module, accumulate_allreduce_grads_in_fp32, use_contiguous_buffers
98 ):
99
100 super(DistributedDataParallel, self).__init__(module)
101
102 self.accumulate_allreduce_grads_in_fp32 = accumulate_allreduce_grads_in_fp32
103 self.use_contiguous_buffers = use_contiguous_buffers
104 # If we are using fp32-accumulate-allreduce explicitly
105 # this means we need main grads in a contiguous buffer.
106 if self.accumulate_allreduce_grads_in_fp32:
107 assert self.use_contiguous_buffers
108
109 # ===================================
110 # The rest of this part applies only to
111 # the case where we use contiguous buffers.
112 # ===================================
113 self._grad_buffers = None
114 if self.use_contiguous_buffers:
115 self._grad_buffers = {}
116
117 # Simple function to define buffer type.
118 def _get_buffer_type(param):
119 return (
120 torch.float
121 if self.accumulate_allreduce_grads_in_fp32
122 else param.dtype
123 )
124
125 # First calculate total number of elements per type.
126 type_num_elements = {}
127 for param in self.module.parameters():
128 if param.requires_grad:
129 dtype = _get_buffer_type(param)
130 type_num_elements[dtype] = (
131 type_num_elements.get(dtype, 0) + param.data.nelement()
132 )
133
134 # Allocate the buffer.
135 for dtype, num_elements in type_num_elements.items():
136 self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)
137
138 # Assume the back prop order is the reverse of the params order,
139 # store the start index for the gradients.
140 for param in self.module.parameters():
141 if param.requires_grad:
142 dtype = _get_buffer_type(param)
143 type_num_elements[dtype] -= param.data.nelement()
144 param.main_grad = self._grad_buffers[dtype].get(
145 param.data.shape, type_num_elements[dtype]
146 )
147
148 # Backward hook.
149 # Accumulation function for the gradients. We need
150 # to store them so they don't go out of scope.
151 self.grad_accs = []
152 # Loop over all the parameters in the model.
153 for param in self.module.parameters():

Callers

nothing calls this directly

Calls 4

_make_param_hook · Method · 0.95
MemoryBuffer · Class · 0.70
__init__ · Method · 0.45
get · Method · 0.45

Tested by

no test coverage detected