hub / github.com/hpcaitech/ColossalAI / bench

Function bench

examples/tutorial/auto_parallel/bench_utils.py:14–55

Benchmarks a given graph module.

Args:
    gm (torch.fx.GraphModule): The graph module to benchmark.
    criterion (torch.nn.Module): Loss function.
    data_gen (Callable): Data generator.
    num_steps (int, optional): Number of test steps. Defaults to 5.

Returns:
    Tuple[int, int]: peak memory in MB and step time in ms.

(
    gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callable, num_steps: int = 5
)
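The signature does not spell out what `data_gen` must return, but the function body calls it with no arguments, unpacks the result as `args, label`, passes `*args` to `gm`, and passes `label` to `criterion`. Below is a minimal sketch of that calling convention. It uses plain Python stand-ins instead of tensors so it runs without a GPU; `run_step`, `toy_data_gen`, `toy_model`, and `toy_criterion` are hypothetical names for illustration and are not part of ColossalAI.

```python
from typing import Any, Callable, Tuple


def run_step(
    model: Callable[..., Any],
    criterion: Callable[[Any, Any], Any],
    data_gen: Callable[[], Tuple[tuple, Any]],
) -> Any:
    """One forward pass using the same calling convention as bench()."""
    args, label = data_gen()          # data_gen takes no arguments
    output = model(*args)             # args is unpacked positionally
    return criterion(output, label)   # loss computed from output and label


# Hypothetical stand-ins: floats and lists instead of CUDA tensors.
def toy_data_gen() -> Tuple[tuple, float]:
    return ([1.0, 2.0, 3.0],), 12.0   # (args tuple, label)


def toy_model(xs: list) -> float:
    return 2.0 * sum(xs)


def toy_criterion(output: float, label: float) -> float:
    return (output - label) ** 2


loss = run_step(toy_model, toy_criterion, toy_data_gen)
```

In real use, `args` would presumably hold input tensors and `label` a target tensor, since `bench` goes on to call `loss.backward()`.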

Source from the content-addressed store, hash-verified

def bench(
    gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callable, num_steps: int = 5
) -> Tuple[int, int]:
    """Benchmarking a given graph module
    Args:
        gm (torch.fx.GraphModule): The graph module to benchmark.
        criterion (torch.nn.Module): Loss function.
        data_gen (Callable): Data generator.
        num_steps (int, optional): Number of test steps. Defaults to 5.
    Returns:
        Tuple[int, int]: peak memory in MB and step time in MS.
    """
    gm.train()
    gm.cuda()
    step_time = float("inf")
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    cached = torch.cuda.max_memory_allocated(device="cuda")
    try:
        for _ in range(num_steps):
            args, label = data_gen()
            output, loss = None, None

            torch.cuda.synchronize(device="cuda")
            start = time.time()
            output = gm(*args)
            loss = criterion(output, label)
            loss.backward()
            torch.cuda.synchronize(device="cuda")
            step_time = min(step_time, time.time() - start)

            for child in gm.children():
                for param in child.parameters():
                    param.grad = None
            del args, label, output, loss
    except:
        del args, label, output, loss
    gm.to("cpu")
    torch.cuda.empty_cache()
    peak_mem = (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2
    return peak_mem, step_time * 1.0e3
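The listing reports the fastest of `num_steps` iterations rather than the mean, and converts the peak-memory delta from bytes to MB by dividing by `1024**2`. The sketch below isolates those two pieces of arithmetic in a device-agnostic form. It is an illustration only: `best_step_time_ms` and `bytes_to_mib` are hypothetical helper names, `time.perf_counter()` is swapped in for the `time.time()` used in the source, and the `torch.cuda.synchronize()` calls that make GPU timings meaningful are omitted.

```python
import time
from typing import Any, Callable


def best_step_time_ms(step: Callable[[], Any], num_steps: int = 5) -> float:
    """Return the fastest of num_steps runs of step(), in milliseconds."""
    best = float("inf")  # same sentinel as step_time in bench()
    for _ in range(num_steps):
        start = time.perf_counter()
        step()
        best = min(best, time.perf_counter() - start)
    return best * 1.0e3  # seconds to milliseconds


def bytes_to_mib(peak_bytes: int, baseline_bytes: int = 0) -> float:
    """Convert a peak-minus-baseline byte count to units of 1024**2 bytes."""
    return (peak_bytes - baseline_bytes) / 1024**2


elapsed_ms = best_step_time_ms(lambda: sum(range(10_000)), num_steps=3)
delta = bytes_to_mib(3 * 1024**2, baseline_bytes=1024**2)  # 2.0
```

Because the sentinel starts at infinity, a run in which no step completes yields an infinite step time. The same holds in `bench` when the first iteration raises inside the `try` block.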

Callers 2

_benchmark · Function · 0.90
bench_rotor · Function · 0.85

Calls 11

data_gen · Function · 0.50
criterion · Function · 0.50
train · Method · 0.45
cuda · Method · 0.45
synchronize · Method · 0.45
empty_cache · Method · 0.45
max_memory_allocated · Method · 0.45
backward · Method · 0.45
parameters · Method · 0.45
to · Method · 0.45

Tested by 1

_benchmark · Function · 0.72
