Synchronize the global cumulative counters across devices and processes. Called internally by `Collector.update()`.
(names)
| 232 | #---------------------------------------------------------------------------- |
| 233 | |
def _sync(names):
    r"""Fold the pending per-device counter values for `names` into the
    global cumulative totals, summing across ranks when a sync device is
    configured. Called internally by `Collector.update()`.

    Every entry of `names` is looked up in `_counters`. The local counters
    that were read are reset to zero as part of the call. Returns a list of
    `(name, cumulative)` pairs in the order of `names`, where `cumulative`
    is the tensor stored in `_cumulative` (not a copy). An empty `names`
    yields an empty list and leaves all global state untouched.
    """
    global _sync_called
    if len(names) == 0:
        return []
    _sync_called = True

    # Accumulate on the sync device when one is set, otherwise on the CPU.
    reduce_device = torch.device('cpu') if _sync_device is None else _sync_device

    # Gather this rank's pending contributions, one row per name.
    local_rows = []
    for name in names:
        row = torch.zeros([_num_moments], dtype=_counter_dtype, device=reduce_device)
        for pending in _counters[name].values():
            row.add_(pending.to(reduce_device))
            # Reset in place so the same contribution is not counted twice.
            pending.copy_(torch.zeros_like(pending))
        local_rows.append(row)
    totals = torch.stack(local_rows)

    # Combine the rows of all ranks (only when a sync device is configured).
    if _sync_device is not None:
        torch.distributed.all_reduce(totals)
    totals = totals.cpu()

    # Add the combined rows onto the running totals, creating them on demand.
    for name, row in zip(names, totals):
        if name not in _cumulative:
            _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        _cumulative[name].add_(row)

    # Hand back name-value pairs in the caller's order.
    return [(name, _cumulative[name]) for name in names]
| 267 | |
| 268 | #---------------------------------------------------------------------------- |