Repository: github.com/hpcaitech/ColossalAI

Class ColoTracer

Defined in colossalai/fx/tracer/tracer.py, lines 40–482.

ColoTracer is a symbolic tracer for the `colossalai.fx` module that uses meta tensors to support dynamic control flow. It is initialized in the same way as the original `torch.fx.Tracer`; a usage example appears in the class docstring.
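Why meta tensors help can be seen with plain PyTorch. The sketch below assumes only that PyTorch with `torch.fx` is installed; the helper names `trace_with_plain_fx` and `branch_condition_on_meta` are hypothetical. It reuses the docstring's `Model`: stock `torch.fx.symbolic_trace` cannot evaluate `if x1.dim() == 2`, because the condition is itself a Proxy, whereas the same computation on meta tensors yields a concrete shape and therefore an ordinary Python `bool`. It illustrates the idea only and does not reproduce ColoTracer's implementation.

```python
import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F
from torch.fx.proxy import TraceError


class Model(nn.Module):
    # Same module as in the ColoTracer docstring.
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 10)
        self.linear2 = nn.Linear(10, 10)

    def forward(self, x, y):
        x1 = self.linear1(x)
        y1 = self.linear2(y)
        if x1.dim() == 2:
            return x1 + y1
        else:
            return x1 - y1


def trace_with_plain_fx(model: nn.Module) -> str:
    """Return 'ok' if stock torch.fx can trace the model, else the TraceError message."""
    try:
        torch.fx.symbolic_trace(model)
        return "ok"
    except TraceError as err:
        # `x1.dim() == 2` is a Proxy; using it in an `if` raises TraceError.
        return f"TraceError: {err}"


def branch_condition_on_meta() -> bool:
    """Evaluate the branch condition on meta tensors (shape and dtype only, no data)."""
    x = torch.empty(4, 10, device="meta")
    weight = torch.empty(10, 10, device="meta")
    bias = torch.empty(10, device="meta")
    x1 = F.linear(x, weight, bias)  # output shape (4, 10) is inferred; nothing is computed
    return x1.dim() == 2            # a plain Python bool, usable in control flow


print(trace_with_plain_fx(Model()))
print(branch_condition_on_meta())
```

This matches the docstring's call, where `x` is passed through `meta_args` as a meta tensor so that its shape is known while the graph is being built.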

Source excerpt (tracer.py, lines 40–97; the rest of the class is not shown):
class ColoTracer(Tracer):
    """
    ColoTracer is a symbolic tracer designed to support dynamic control flow by using meta tensors for the `colossalai.fx` module.
    This tracer is initialized in the same way as the original torch.fx.Tracer.

    Usage::

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = nn.Linear(10, 10)
                self.linear2 = nn.Linear(10, 10)

            def forward(self, x, y):
                x1 = self.linear1(x)
                y1 = self.linear2(y)

                if x1.dim() == 2:
                    return x1 + y1
                else:
                    return x1 - y1

        model = Model()
        tracer = ColoTracer()
        graph = tracer.trace(model, concrete_args={'y': torch.rand(4, 10)}, meta_args={'x': torch.rand(4, 10, device='meta')})
    """

    def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tracer_type = TracerType.META
        self.proxy_cls = ColoProxy

        # whether the tracer will record the usage of torch.utils.checkpoint
        self.trace_act_ckpt = trace_act_ckpt
        # whether the current tracing occurs within the activation checkpoint functions
        self.inside_torch_checkpoint_func = False
        self.act_ckpt_region_count = 0

    # Feature flag for proxying accesses to buffer values
    proxy_buffer_attributes: bool = True

    _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor", "finfo"]

    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None) -> ColoProxy:
        """
        Create a proxy for different kinds of operations.
        """

        if self.tracer_type == TracerType.DEFAULT:
            # since meta_args is not given
            # we just fall back to the original torch.fx.Tracer
            proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
            return proxy

        # if graph is traced for auto parallelism module, some extra node will be added during
        # graph construction to deal with the compatibility between bias addition and all reduce.

        # if no extra manipulation is applied, we just pass the origin arguments to create_proxy function
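The `create_proxy` override follows the usual `torch.fx.Tracer` extension pattern: intercept the call, do any extra work, and otherwise defer to `super().create_proxy`, which is what ColoTracer does when `tracer_type` is `TracerType.DEFAULT`. The minimal sketch below uses plain PyTorch only; `CountingTracer` and `Tiny` are hypothetical names, and the "extra work" is just recording each proxy's `kind`, not anything ColoTracer itself does.

```python
import torch
import torch.fx
import torch.nn as nn


class CountingTracer(torch.fx.Tracer):
    """Hypothetical tracer: records the kind of every proxy, then defers to torch.fx."""

    def __init__(self):
        super().__init__()
        self.kinds = []

    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
        # Extra bookkeeping, then fall back to the stock torch.fx implementation.
        self.kinds.append(kind)
        return super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


tracer = CountingTracer()
graph = tracer.trace(Tiny())
print(tracer.kinds)                      # kinds seen by create_proxy during tracing
print([node.op for node in graph.nodes])  # ops recorded in the resulting graph
```

Keeping the fallback path identical to the parent class means a subclass only changes behavior where it needs to; in ColoTracer the non-default path, which lies beyond this excerpt, is annotated to return a `ColoProxy`, the class stored in `self.proxy_cls`.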

Callers (15; 9 listed)

initialize_model (function) · 0.90
memory_optimize (function) · 0.90
test_coloproxy (function) · 0.90
test_graph_manipulation (function) · 0.90
test_torchvision_models (function) · 0.90
split_model_and_get_DAG (function) · 0.90
test_gpt_meta_info_prop (function) · 0.90
test_linear_module (function) · 0.90
test_conv_module (function) · 0.90

Calls

no outgoing calls

Tested by (15; 9 listed)

test_coloproxy (function) · 0.72
test_graph_manipulation (function) · 0.72
test_torchvision_models (function) · 0.72
test_gpt_meta_info_prop (function) · 0.72
test_linear_module (function) · 0.72
test_conv_module (function) · 0.72
_run_act_ckpt_codegen (function) · 0.72
_run_act_ckpt_codegen (function) · 0.72
_run_offload_codegen (function) · 0.72
