ColoProxy is a proxy class which uses a meta tensor to handle data-dependent control flow. The original torch.fx proxy cannot be used to evaluate a condition statement; with this proxy, torch.fx can still run even when the traced code contains if statements. Example:: proxy = tracer.create_proxy(...)
| 9 | |
| 10 | |
| 11 | class ColoProxy(Proxy): |
| 12 | """ |
| 13 | ColoProxy is a proxy class which uses meta tensor to handle data-dependent control flow. The original torch.fx proxy |
| 14 | cannot be used to infer the condition statement, with this proxy, torch.fx can still run even with if statements. |
| 15 | |
| 16 | Example:: |
| 17 | |
| 18 | proxy = tracer.create_proxy(...) |
| 19 | proxy.meta_data = torch.empty(4, 2, device='meta') |
| 20 | print(len(proxy)) # expect output 4 |
| 21 | |
| 22 | """ |
| 23 | |
def __init__(self, *args, **kwargs):
    """Initialise the proxy and start its node with no meta data.

    All positional and keyword arguments are forwarded unchanged to the
    parent ``Proxy`` initialiser.
    """
    super().__init__(*args, **kwargs)
    # Meta data lives on the node, not on the proxy object itself.
    setattr(self.node, "_meta_data", None)
| 27 | |
@property
def meta_data(self):
    """Meta data attached to this proxy's node (``None`` until it is set)."""
    owner = self.node
    return owner._meta_data

@meta_data.setter
def meta_data(self, data: Any):
    """Store ``data`` as the meta data on this proxy's node."""
    owner = self.node
    owner._meta_data = data
| 35 | |
@property
def has_meta_data(self):
    """Whether meta data has been set for this proxy.

    Returns:
        ``True`` if the node's meta data is not ``None``, else ``False``.
    """
    # The value is stored on the node and exposed through the ``meta_data``
    # property. Reading ``self._meta_data`` would fall through to
    # ``__getattr__`` and yield an attribute proxy, which is never ``None``.
    return self.meta_data is not None
| 39 | |
| 40 | def _assert_meta_data_is_tensor(self): |
| 41 | assert ( |
| 42 | torch.is_tensor(self._meta_data) and self._meta_data.is_meta |
| 43 | ), f"Meta data is not a meta tensor for {self.node.name}" |
| 44 | |
| 45 | def _assert_has_meta_data(self): |
| 46 | assert self._meta_data is not None, f"Meta data is not set for {self.node.name}" |
| 47 | |
| 48 | def __len__(self): |
| 49 | self._assert_has_meta_data() |
| 50 | return len(self.meta_data) |
| 51 | |
| 52 | def __int__(self): |
| 53 | self._assert_has_meta_data() |
| 54 | return int(self.meta_data) |
| 55 | |
| 56 | def __float__(self): |
| 57 | self._assert_has_meta_data() |
| 58 | return float(self.meta_data) |
| 59 | |
| 60 | def __bool__(self): |
| 61 | self._assert_has_meta_data() |
| 62 | return self.meta_data |
| 63 | |
def __getattr__(self, k):
    """Wrap access to attribute ``k`` in a ``ColoAttribute`` bound to this proxy.

    Python only calls ``__getattr__`` when normal attribute lookup fails, so
    real instance attributes and class-level properties are not affected.
    """
    attribute = ColoAttribute(self, k)
    return attribute
| 66 | |
| 67 | def __contains__(self, key): |
| 68 | if self.node.op == "placeholder": |
no outgoing calls
searching dependent graphs…