(tensor)
| 174 | |
| 175 | |
def reduce_tensor(tensor):
    """Return the cross-process mean of ``tensor``.

    A copy of ``tensor`` is summed over all participating processes with
    ``dist.all_reduce`` and then divided in place by
    ``dist.get_world_size()``. The argument itself is not modified.

    Args:
        tensor: The local tensor to average across processes.

    Returns:
        A new tensor holding the sum over processes divided by the world size.
    """
    # Work on a copy: all_reduce writes its result into the tensor it is given.
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    # Turn the summed values into a mean.
    averaged.div_(dist.get_world_size())
    return averaged
| 181 | |
| 182 | |
| 183 | def ampscaler_get_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor: |