Benchmark a given graph module. Args: gm (torch.fx.GraphModule): the graph module to benchmark; criterion (torch.nn.Module): loss function; data_gen (Callable): data generator; num_steps (int, optional): number of test steps, defaults to 5. Returns: Tuple[int, int]: peak memory in MB and step time in ms.
(
gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callable, num_steps: int = 5
)
| 12 | |
| 13 | |
def bench(
    gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callable, num_steps: int = 5
) -> Tuple[float, float]:
    """Benchmark a graph module's training step on the current CUDA device.

    Runs ``num_steps`` forward/backward passes and records the fastest step
    together with the peak CUDA memory allocated during the run.

    Args:
        gm (torch.fx.GraphModule): The graph module to benchmark. It is put in
            training mode, moved to CUDA for the run and moved back to CPU
            before returning.
        criterion (torch.nn.Module): Loss function, called as
            ``criterion(output, label)``.
        data_gen (Callable): Data generator; each call must return
            ``(args, label)`` where ``args`` is unpacked into ``gm(*args)``.
        num_steps (int, optional): Number of test steps. Defaults to 5.

    Returns:
        Tuple[float, float]: Peak memory in MB and the minimum step time in
        ms. The step time is ``inf`` if no step completed (for example when
        ``num_steps`` is 0 or the first step failed).
    """
    gm.train()
    gm.cuda()
    step_time = float("inf")
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    # Baseline taken right after the reset so that memory already held
    # (e.g. the module's own parameters) is excluded from the reported peak.
    cached = torch.cuda.max_memory_allocated(device="cuda")

    # Bind the per-step names up front so the cleanup below is always valid,
    # even when ``data_gen`` itself fails on the very first step.
    args = label = output = loss = None
    try:
        for _ in range(num_steps):
            args, label = data_gen()

            # Synchronize on both sides of the timed region so that only this
            # step's kernels are measured.
            torch.cuda.synchronize(device="cuda")
            start = time.perf_counter()
            output = gm(*args)
            loss = criterion(output, label)
            loss.backward()
            torch.cuda.synchronize(device="cuda")
            step_time = min(step_time, time.perf_counter() - start)

            # Drop gradients of every parameter, including those registered
            # directly on ``gm`` rather than on a child module.
            for param in gm.parameters():
                param.grad = None
            # Release this step's tensors before the next step allocates.
            args = label = output = loss = None
    except Exception:
        # Deliberate best effort: a failing step (e.g. CUDA out of memory)
        # ends the benchmark early and the results gathered so far are
        # returned. KeyboardInterrupt / SystemExit still propagate.
        pass
    finally:
        # Release tensor references so ``empty_cache`` below can free them.
        args = label = output = loss = None

    gm.to("cpu")
    torch.cuda.empty_cache()
    peak_mem = (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2
    return peak_mem, step_time * 1.0e3
| 56 | |
| 57 | |
| 58 | def bench_rotor( |
searching dependent graphs…