ColoTracer is a symbolic tracer designed to support dynamic control flow by using meta tensors for the `colossalai.fx` module. This tracer is initialized in the same way as the original `torch.fx.Tracer`. Usage (excerpt, truncated):: class Model(nn.Module): def __init__(self): ...
| 38 | |
| 39 | |
| 40 | class ColoTracer(Tracer): |
| 41 | """ |
| 42 | ColoTracer is a symbolic tracer designed to support dynamic control flow by using meta tensors for the `colossalai.fx` module. |
| 43 | This tracer is initialized in the same way as the original torch.fx.Tracer. |
| 44 | |
| 45 | Usage:: |
| 46 | |
| 47 | class Model(nn.Module): |
| 48 | def __init__(self): |
| 49 | super().__init__() |
| 50 | self.linear1 = nn.Linear(10, 10) |
| 51 | self.linear2 = nn.Linear(10, 10) |
| 52 | |
| 53 | def forward(self, x, y): |
| 54 | x1 = self.linear1(x) |
| 55 | y1 = self.linear2(y) |
| 56 | |
| 57 | if x1.dim() == 2: |
| 58 | return x1 + y1 |
| 59 | else: |
| 60 | return x1 - y1 |
| 61 | |
| 62 | model = Model() |
| 63 | tracer = ColoTracer() |
| 64 | graph = tracer.trace(model, concrete_args={'y': torch.rand(4, 10)}, meta_args={'x': torch.rand(4, 10, device='meta')}) |
| 65 | """ |
| 66 | |
def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs):
    """
    Initialize the tracer in meta-tensor mode.

    Args:
        trace_act_ckpt (bool): whether the tracer will record the usage of
            torch.utils.checkpoint. Defaults to False.
        *args: positional arguments forwarded unchanged to the parent constructor.
        **kwargs: keyword arguments forwarded unchanged to the parent constructor.
    """
    # the parent constructor runs first so that its own state exists before ours is attached
    super().__init__(*args, **kwargs)

    # activation-checkpoint bookkeeping
    self.act_ckpt_region_count = 0
    # flag: is the current tracing happening inside the activation checkpoint functions?
    self.inside_torch_checkpoint_func = False
    # flag: should usages of torch.utils.checkpoint be recorded?
    self.trace_act_ckpt = trace_act_ckpt

    # proxy class and tracing mode used by this tracer
    self.proxy_cls = ColoProxy
    self.tracer_type = TracerType.META
| 77 | |
# Feature flag: when True, accesses to buffer values are proxied
proxy_buffer_attributes: bool = True

# NOTE(review): names of torch-level callables; presumably the ones patched while
# tracing -- the code consuming this list is not visible here, confirm against it.
_TORCH_METHODS_TO_PATCH = [
    "arange",
    "zeros",
    "ones",
    "full",
    "full_like",
    "eye",
    "empty",
    "tensor",
    "finfo",
]
| 82 | |
| 83 | def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None) -> ColoProxy: |
| 84 | """ |
| 85 | Create a proxy for different kinds of operations. |
| 86 | """ |
| 87 | |
| 88 | if self.tracer_type == TracerType.DEFAULT: |
| 89 | # since meta_args is not given |
| 90 | # we just fall back to the original torch.fx.Tracer |
| 91 | proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) |
| 92 | return proxy |
| 93 | |
| 94 | # if graph is traced for auto parallelism module, some extra node will be added during |
| 95 | # graph construction to deal with the compatibility between bias addition and all reduce. |
| 96 | |
| 97 | # if no extra manipulation is applied, we just pass the origin arguments to create_proxy function |
no outgoing calls
searching dependent graphs…