(model_name: str, memory_budget: float, solver_name: str)
| 17 | @parameterize("memory_budget", [4000]) |
| 18 | @parameterize("solver_name", ["syn", "asyn"]) |
def solver_test(model_name: str, memory_budget: float, solver_name: str):
    """Check that an offload solver finds a plan that fits a memory budget.

    Builds the registered model ``model_name`` and its sample inputs, casts both
    to fp16 on CPU, traces the model with meta-device copies of the inputs,
    partitions the traced graph into regions and runs the ``solver_name`` solver
    over them. The best plan's peak memory must be strictly below the budget.
    The per-region forward and backward offload/prefetch decisions are then
    printed.

    Args:
        model_name: key looked up in ``non_distributed_component_funcs``.
        memory_budget: budget in MiB. It is scaled by ``1024 * 1024`` before
            being handed to the solver (presumably converting to bytes --
            confirm against the solver implementation).
        solver_name: solver key, forwarded to both ``RegionManager`` and
            ``SolverFactory.create``.
    """

    def _to_half(x):
        # Cast floating-point tensors to fp16; everything else passes through unchanged.
        if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
            return x.to(dtype=torch.half)
        return x

    def _to_meta(x):
        # Move tensors to the meta device (shape/dtype only, no data) for tracing.
        return x.to("meta") if isinstance(x, torch.Tensor) else x

    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, data_gen = get_components_func()
    data_args = tree_map(_to_half, data_gen(device="cpu"))
    model = model_builder()
    model.train()
    model = model.cpu().half()

    tracer = ColoTracer()
    assert is_compatible_with_meta()
    meta_args = tree_map(_to_meta, data_args)
    graph = tracer.trace(model, meta_args=meta_args)
    gm = GraphModule(model, graph, model.__class__.__name__)

    # NOTE(review): presumably annotates graph nodes with the meta info that
    # RegionManager consumes. The positional call assumes meta_args is a dict
    # whose insertion order matches forward()'s parameters -- TODO confirm.
    interp = MetaInfoProp(gm)
    interp.propagate(*meta_args.values())

    region_manager = RegionManager(graph, solver_name=solver_name)
    region_manager._pre_process()
    region_list = region_manager.region_list

    solver_cls = SolverFactory.create(solver_name)
    memory_budget = memory_budget * 1024 * 1024  # MiB -> solver units (presumably bytes)
    solver = solver_cls(region_list, memory_budget)
    solver._call_solver()

    peak_mem = solver.best_ts.peak_mem
    assert peak_mem < memory_budget, f"peak memory {peak_mem} is not below budget {memory_budget}"

    print("****************** execution plan *******************")
    for region in region_list:
        need_offload = region.need_offload
        prefetch_region = region.fwd_prefetch_region
        to_prefetch = prefetch_region.r_id if prefetch_region is not None else None
        print(
            f"| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}"
        )
    # The backward plan is reported in reverse region order.
    for region in reversed(region_list):
        need_offload = region.need_offload
        prefetch_region = region.bwd_prefetch_region
        to_prefetch = prefetch_region.r_id if prefetch_region is not None else None
        print(
            f"| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}"
        )
| 63 | |
| 64 | |
| 65 | if __name__ == "__main__": |
no test coverage detected
searching dependent graphs…