Backward step.
(optimizer, model, lm_loss, nsp_loss, args)
| 228 | |
| 229 | |
def backward_step(optimizer, model, lm_loss, nsp_loss, args):
    """Run the backward pass, reduce losses/gradients across processes, clip.

    No optimizer step is taken here. This function only zeroes, computes,
    reduces and (optionally) clips gradients so the caller can step next.

    Args:
        optimizer: Optimizer whose gradients are zeroed. When ``args.fp16``
            is set it must also provide ``backward``,
            ``update_master_grads`` and ``clip_master_grads`` (presumably an
            FP16 optimizer wrapper -- confirm against caller).
        model: Model being trained. When ``USE_TORCH_DDP`` is false it must
            provide ``allreduce_params``.
        lm_loss: Single-element tensor (``.view(1)`` requires exactly one
            element), presumably the language-model loss.
        nsp_loss: Single-element tensor, presumably the
            next-sentence-prediction loss.
        args: Namespace providing ``fp16``, ``world_size``,
            ``fp32_allreduce`` and ``clip_grad``.

    Returns:
        tuple: ``(lm_loss_reduced, nsp_loss_reduced)``, each loss summed
        over all processes and divided by ``args.world_size``. The backward
        pass has already run on the local (unreduced) losses, so these
        values do not affect the gradients.
    """
    # Gradients come from the sum of both losses.
    loss = lm_loss + nsp_loss

    optimizer.zero_grad()
    if args.fp16:
        # Master gradients are updated explicitly further down, after the
        # (non-DDP) gradient all-reduce.
        optimizer.backward(loss, update_master_grads=False)
    else:
        loss.backward()

    # Average both losses across processes with one collective call.
    # all_reduce defaults to SUM, hence the division by world size.
    # Working on ``.data`` keeps the in-place collective off the autograd
    # graph.
    reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1)))
    torch.distributed.all_reduce(reduced_losses.data)
    reduced_losses.data = reduced_losses.data / args.world_size

    # Without torch DDP (which reduces gradients during backward), the model
    # wrapper is asked to all-reduce its parameters' gradients itself.
    if not USE_TORCH_DDP:
        model.allreduce_params(reduce_after=False,
                               fp32_allreduce=args.fp32_allreduce)

    if args.fp16:
        optimizer.update_master_grads()

    # Clipping gradients helps prevent the exploding gradient.
    if args.clip_grad > 0:
        if args.fp16:
            optimizer.clip_master_grads(args.clip_grad)
        else:
            mpu.clip_grad_norm(model.parameters(), args.clip_grad)

    return reduced_losses[0], reduced_losses[1]
| 268 | |
| 269 | |
| 270 | def train_step(data_iterator, model, optimizer, lr_scheduler, |
no test coverage detected