MCPcopy
hub / github.com/Vchitect/Latte / _sync

Function _sync

tools/torch_utils/training_stats.py:234–266  ·  view source on GitHub ↗

Synchronize the global cumulative counters across devices and processes. Called internally by `Collector.update()`.

(names)

Source from the content-addressed store, hash-verified

232#----------------------------------------------------------------------------
233
def _sync(names):
    r"""Fold pending per-rank counter values into the global cumulative
    totals and return the updated totals. Called internally by
    `Collector.update()`.

    Args:
        names: Sequence of counter names to synchronize. Each name is
            looked up in `_counters`.

    Returns:
        A list of `(name, cumulative)` pairs in the same order as `names`,
        where `cumulative` is the tensor stored in `_cumulative` for that
        name (the stored object itself, not a copy). An empty `names`
        yields an empty list and leaves all module state untouched.
    """
    global _sync_called
    if len(names) == 0:
        return []
    _sync_called = True

    # Gather this rank's pending values, one row per name. Each counter is
    # reset in place right after it has been added to the row.
    target = torch.device('cpu') if _sync_device is None else _sync_device
    rows = []
    for key in names:
        acc = torch.zeros([_num_moments], dtype=_counter_dtype, device=target)
        for pending in _counters[key].values():
            acc.add_(pending.to(target))
            pending.zero_()
        rows.append(acc)
    stacked = torch.stack(rows)

    # Reduce across ranks; skipped when no sync device is configured.
    if _sync_device is not None:
        torch.distributed.all_reduce(stacked)

    # Accumulate the reduced rows into the cumulative totals on the CPU,
    # creating a zero-initialized total the first time a name is seen.
    stacked = stacked.cpu()
    for row, key in zip(stacked, names):
        total = _cumulative.get(key)
        if total is None:
            total = torch.zeros([_num_moments], dtype=_counter_dtype)
            _cumulative[key] = total
        total.add_(row)

    # Hand back name-value pairs in caller order.
    return [(key, _cumulative[key]) for key in names]
267
268#----------------------------------------------------------------------------

Callers 1

updateMethod · 0.85

Calls 1

appendMethod · 0.80

Tested by

no test coverage detected