hub / github.com/stas00/ml-engineering / timed_allreduce

Function timed_allreduce

network/benchmarks/all_reduce_latency_comp.py:20–48
(mat, repeat_times, id, start_event, end_event)

Source from the content-addressed store, hash-verified

M = 2000

def timed_allreduce(mat, repeat_times, id, start_event, end_event):
    start_event.record()
    for i in range(repeat_times):
        dist.all_reduce(mat)
    end_event.record()

    torch.cuda.synchronize()
    duration = start_event.elapsed_time(end_event) / 1000

    size = M * N * 4 # 4 is fp32
    algbw = (size / duration) * 8 # 8 is bytes to bits
    n = dist.get_world_size()
    # the 2*(n-1)/n busbw correction factor specific to all-reduce is explained here:
    # https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#allreduce
    # busbw reflects how optimally the hardware is used
    busbw = algbw * (2*(n - 1) / n)

    # gather all data on global-rank-0 and print the results from there to avoid interleaved prints
    data = [id, duration, algbw, busbw]
    output = [None for _ in range(dist.get_world_size())] if dist.get_rank() == 0 else None
    dist.gather_object(data, output, dst=0)
    if dist.get_rank() == 0:
        for data in output:
            id, duration, algbw, busbw = data
            print(f"{id}:\n",
                  f"duration: {duration:.3f} sec\n",
                  f"algbw: {algbw/1e9:.3f} Gbps\n",
                  f"busbw: {busbw / 1e9:.3f} Gbps"
            )
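The bandwidth arithmetic in the function can be tried without GPUs or a process group. The following is a minimal sketch, not part of the repository: the helper name `allreduce_bandwidth` and the example inputs (a 4 GB payload, 2 seconds, 8 ranks) are assumptions chosen for illustration, and only the two formulas are taken from the source above.

```python
def allreduce_bandwidth(size_bytes, duration_sec, world_size):
    """Hypothetical helper: return (algbw, busbw) in bits per second."""
    # Algorithm bandwidth: payload bytes per second, converted to bits.
    algbw = (size_bytes / duration_sec) * 8
    # Bus bandwidth: apply the all-reduce correction factor 2*(n-1)/n.
    busbw = algbw * (2 * (world_size - 1) / world_size)
    return algbw, busbw


# Illustrative inputs only: 4e9 bytes moved in 2.0 s across 8 ranks.
algbw, busbw = allreduce_bandwidth(4_000_000_000, 2.0, 8)
print(f"algbw: {algbw / 1e9:.3f} Gbps")  # algbw: 16.000 Gbps
print(f"busbw: {busbw / 1e9:.3f} Gbps")  # busbw: 28.000 Gbps
```

With 8 ranks the correction factor is 2*7/8 = 1.75, so busbw is 1.75 times algbw. The factor approaches 2 as the world size grows.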

Callers 1

run · Function · 0.70

Calls 4

print · Function · 0.85
record · Method · 0.80
elapsed_time · Method · 0.80
synchronize · Method · 0.45
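The `record`, `elapsed_time`, and `synchronize` calls listed above together form the timing pattern of the function: record a start event, run the work, record an end event, wait for completion, then convert the elapsed milliseconds to seconds. The sketch below mirrors that shape on the CPU so it runs anywhere. `WallClockEvent` and `timed_repeat` are hypothetical names; the class only imitates the `record()` / `elapsed_time()` interface of `torch.cuda.Event` and is not a substitute for real CUDA events.

```python
import time


class WallClockEvent:
    """Hypothetical CPU stand-in with the record()/elapsed_time() shape of torch.cuda.Event."""

    def __init__(self):
        self.timestamp = None

    def record(self):
        # Store the current time; a CUDA event would enqueue a marker on the stream instead.
        self.timestamp = time.perf_counter()

    def elapsed_time(self, end_event):
        # Milliseconds between this event and end_event, the unit CUDA events report.
        return (end_event.timestamp - self.timestamp) * 1000.0


def timed_repeat(fn, repeat_times, start_event, end_event, synchronize=lambda: None):
    """Run fn repeat_times between two events and return the elapsed seconds."""
    start_event.record()
    for _ in range(repeat_times):
        fn()
    end_event.record()
    # On a GPU, torch.cuda.synchronize would be passed here so that queued
    # kernels finish before the elapsed time is read.
    synchronize()
    return start_event.elapsed_time(end_event) / 1000  # milliseconds to seconds


calls = []
duration = timed_repeat(lambda: calls.append(1), 5, WallClockEvent(), WallClockEvent())
print(len(calls), duration >= 0)
```

On a GPU the same shape would presumably use `torch.cuda.Event(enable_timing=True)` for both events, as the original function expects its caller to supply them.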

Tested by

no test coverage detected