MCPcopy
hub / github.com/hpcaitech/ColossalAI / ColoProxy

Class ColoProxy

colossalai/fx/proxy.py:11–73  ·  view source on GitHub ↗

ColoProxy is a proxy class that uses meta tensors to handle data-dependent control flow. The original torch.fx proxy cannot be used to evaluate conditional statements; with this proxy, torch.fx can still trace code that contains if statements. Example: `proxy = tracer.create_proxy(...)`

Source from the content-addressed store, hash-verified

9
10
11class ColoProxy(Proxy):
12 """
13 ColoProxy is a proxy class which uses meta tensor to handle data-dependent control flow. The original torch.fx proxy
14 cannot be used to infer the condition statement, with this proxy, torch.fx can still run even with if statements.
15
16 Example::
17
18 proxy = tracer.create_proxy(...)
19 proxy.meta_data = torch.empty(4, 2, device='meta')
20 print(len(proxy)) # expect output 4
21
22 """
23
def __init__(self, *args, **kwargs):
    """Construct the proxy and reset the traced node's meta data slot.

    All positional and keyword arguments are forwarded unchanged to the parent
    ``Proxy`` constructor.
    """
    super().__init__(*args, **kwargs)
    # The node starts without meta data; it can be attached later through
    # the ``meta_data`` property.
    traced_node = self.node
    traced_node._meta_data = None
27
@property
def meta_data(self):
    """Meta data stored on the traced node (``self.node._meta_data``)."""
    traced_node = self.node
    return traced_node._meta_data

@meta_data.setter
def meta_data(self, data: Any):
    """Store ``data`` as the traced node's meta data.

    Args:
        data: value to attach; the class docstring's example uses a tensor on the
            ``meta`` device.
    """
    setattr(self.node, "_meta_data", data)
35
@property
def has_meta_data(self):
    """Whether meta data has been attached to this proxy.

    Returns:
        bool: ``True`` if ``self.meta_data`` is not ``None``.
    """
    # Read through the ``meta_data`` property. The value lives on ``self.node``,
    # so ``self._meta_data`` is not an instance attribute. That lookup would fall
    # through to ``__getattr__``, yield a ColoAttribute, and never be None.
    return self.meta_data is not None
39
def _assert_meta_data_is_tensor(self):
    """Assert that the attached meta data is a tensor on the ``meta`` device.

    Raises:
        AssertionError: if ``self.meta_data`` is not a meta tensor.
    """
    # Use the ``meta_data`` property. ``self._meta_data`` would be resolved by
    # ``__getattr__`` into a ColoAttribute, so this check could never pass.
    data = self.meta_data
    assert torch.is_tensor(data) and data.is_meta, f"Meta data is not a meta tensor for {self.node.name}"
44
def _assert_has_meta_data(self):
    """Assert that meta data has been attached to this proxy.

    Raises:
        AssertionError: if ``self.meta_data`` is ``None``.
    """
    # Use the ``meta_data`` property. ``self._meta_data`` resolves through
    # ``__getattr__`` to a ColoAttribute, which made this assertion pass
    # unconditionally.
    assert self.meta_data is not None, f"Meta data is not set for {self.node.name}"
47
def __len__(self):
    """Return ``len(self.meta_data)`` after checking that meta data is set."""
    self._assert_has_meta_data()
    data = self.meta_data
    return len(data)
51
def __int__(self):
    """Return ``int(self.meta_data)`` after checking that meta data is set."""
    self._assert_has_meta_data()
    data = self.meta_data
    return int(data)
55
def __float__(self):
    """Return ``float(self.meta_data)`` after checking that meta data is set."""
    self._assert_has_meta_data()
    data = self.meta_data
    return float(data)
59
def __bool__(self):
    """Return the truth value of the attached meta data.

    Lets an ``if proxy:`` condition be evaluated during tracing once meta data
    has been attached.

    Returns:
        bool: ``bool(self.meta_data)``.
    """
    self._assert_has_meta_data()
    # ``__bool__`` must return a real ``bool``. Returning the raw meta data
    # (e.g. a tensor or list) makes Python raise
    # ``TypeError: __bool__ should return bool``.
    return bool(self.meta_data)
63
def __getattr__(self, k):
    """Wrap any attribute that normal lookup cannot find in a ``ColoAttribute``.

    Python calls ``__getattr__`` only after regular attribute lookup fails.
    Real instance attributes and class properties such as ``meta_data`` are
    therefore unaffected, but any misspelled or missing name yields a
    ``ColoAttribute`` rather than raising ``AttributeError``.
    """
    attr_proxy = ColoAttribute(self, k)
    return attr_proxy
66
67 def __contains__(self, key):
68 if self.node.op == "placeholder":

Callers 3

test_coloproxyFunction · 0.90
nodeMethod · 0.70
__call__Method · 0.70

Calls

no outgoing calls

Tested by 1

test_coloproxyFunction · 0.72

Used in the wild real call sites across dependent graphs

searching dependent graphs…