MCPcopy
hub / github.com/openai/guided-diffusion / MixedPrecisionTrainer

Class MixedPrecisionTrainer

guided_diffusion/fp16_util.py:148–233  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

146
147
148class MixedPrecisionTrainer:
149 def __init__(
150 self,
151 *,
152 model,
153 use_fp16=False,
154 fp16_scale_growth=1e-3,
155 initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
156 ):
157 self.model = model
158 self.use_fp16 = use_fp16
159 self.fp16_scale_growth = fp16_scale_growth
160
161 self.model_params = list(self.model.parameters())
162 self.master_params = self.model_params
163 self.param_groups_and_shapes = None
164 self.lg_loss_scale = initial_lg_loss_scale
165
166 if self.use_fp16:
167 self.param_groups_and_shapes = get_param_groups_and_shapes(
168 self.model.named_parameters()
169 )
170 self.master_params = make_master_params(self.param_groups_and_shapes)
171 self.model.convert_to_fp16()
172
173 def zero_grad(self):
174 zero_grad(self.model_params)
175
176 def backward(self, loss: th.Tensor):
177 if self.use_fp16:
178 loss_scale = 2 ** self.lg_loss_scale
179 (loss * loss_scale).backward()
180 else:
181 loss.backward()
182
183 def optimize(self, opt: th.optim.Optimizer):
184 if self.use_fp16:
185 return self._optimize_fp16(opt)
186 else:
187 return self._optimize_normal(opt)
188
189 def _optimize_fp16(self, opt: th.optim.Optimizer):
190 logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
191 model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
192 grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
193 if check_overflow(grad_norm):
194 self.lg_loss_scale -= 1
195 logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
196 zero_master_grads(self.master_params)
197 return False
198
199 logger.logkv_mean("grad_norm", grad_norm)
200 logger.logkv_mean("param_norm", param_norm)
201
202 for p in self.master_params:
203 p.grad.mul_(1.0 / (2 ** self.lg_loss_scale))
204 opt.step()
205 zero_master_grads(self.master_params)

Callers 2

main · Function · 0.90
__init__ · Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected