(
self, module, accumulate_allreduce_grads_in_fp32, use_contiguous_buffers
)
| 94 | """ |
| 95 | |
| 96 | def __init__( |
| 97 | self, module, accumulate_allreduce_grads_in_fp32, use_contiguous_buffers |
| 98 | ): |
| 99 | |
| 100 | super(DistributedDataParallel, self).__init__(module) |
| 101 | |
| 102 | self.accumulate_allreduce_grads_in_fp32 = accumulate_allreduce_grads_in_fp32 |
| 103 | self.use_contiguous_buffers = use_contiguous_buffers |
| 104 | # If we are using fp32-accumulate-allreduce explicitly |
| 105 | # this means we need main grads in a contiguous buffer. |
| 106 | if self.accumulate_allreduce_grads_in_fp32: |
| 107 | assert self.use_contiguous_buffers |
| 108 | |
| 109 | # =================================== |
| 110 | # Rest of this part applies only to |
| 111 | # the case we use contiguous buffers. |
| 112 | # =================================== |
| 113 | self._grad_buffers = None |
| 114 | if self.use_contiguous_buffers: |
| 115 | self._grad_buffers = {} |
| 116 | |
| 117 | # Simple function to define buffer type. |
| 118 | def _get_buffer_type(param): |
| 119 | return ( |
| 120 | torch.float |
| 121 | if self.accumulate_allreduce_grads_in_fp32 |
| 122 | else param.dtype |
| 123 | ) |
| 124 | |
| 125 | # First calculate total number of elements per type. |
| 126 | type_num_elements = {} |
| 127 | for param in self.module.parameters(): |
| 128 | if param.requires_grad: |
| 129 | dtype = _get_buffer_type(param) |
| 130 | type_num_elements[dtype] = ( |
| 131 | type_num_elements.get(dtype, 0) + param.data.nelement() |
| 132 | ) |
| 133 | |
| 134 | # Allocate the buffer. |
| 135 | for dtype, num_elements in type_num_elements.items(): |
| 136 | self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype) |
| 137 | |
| 138 | # Assume the back prop order is the reverse of the params order, |
| 139 | # store the start index for the gradients. |
| 140 | for param in self.module.parameters(): |
| 141 | if param.requires_grad: |
| 142 | dtype = _get_buffer_type(param) |
| 143 | type_num_elements[dtype] -= param.data.nelement() |
| 144 | param.main_grad = self._grad_buffers[dtype].get( |
| 145 | param.data.shape, type_num_elements[dtype] |
| 146 | ) |
| 147 | |
| 148 | # Backward hook. |
| 149 | # Accumulation function for the gradients. We need |
| 150 | # to store them so they don't go out of scope. |
| 151 | self.grad_accs = [] |
| 152 | # Loop over all the parameters in the model. |
| 153 | for param in self.module.parameters(): |
nothing calls this directly
no test coverage detected