All-reduce a Python scalar across ranks (sum or mean). No-op when single-process.
(value: float, ctx: DDPContext, average: bool = True)
| 80 | |
| 81 | |
def reduce_scalar(value: float, ctx: DDPContext, average: bool = True) -> float:
    """All-reduce a Python scalar across ranks (sum or mean).

    Args:
        value: This rank's local scalar.
        ctx: Distributed context. Its ``enabled``, ``device`` and
            ``world_size`` attributes are read here.
        average: If true, divide the cross-rank sum by ``ctx.world_size``
            (mean); otherwise return the plain sum.

    Returns:
        ``value`` unchanged when ``ctx.enabled`` is false; otherwise the
        reduced result as a Python float.

    Note:
        When ``ctx.enabled`` is true this issues ``dist.all_reduce``, which is
        a collective operation, so every participating rank has to make the
        matching call.
    """
    if not ctx.enabled:
        return value
    # A Python float is 64-bit. Reducing in float64 keeps the distributed
    # result as precise as the single-process path (float32 would round to
    # ~7 significant digits and lose exactness for large counts).
    t = torch.tensor([value], device=ctx.device, dtype=torch.float64)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    if average:
        t /= ctx.world_size
    return t.item()
| 91 | |
| 92 | |
| 93 | def barrier(ctx: DDPContext) -> None: |