(self, root, concrete_args=None)
| 64 | |
class ProfilerTracer(torch.fx.Tracer):
    """FX tracer that captures ``record_function`` profiling ranges as graph nodes.

    While :meth:`trace` runs, ``torch.autograd.profiler.record_function``'s
    ``__enter__``/``__exit__`` are temporarily replaced. Entering a range
    emits a ``torch.ops.profiler._record_function_enter`` call node. Leaving
    it emits a ``torch.ops.profiler._record_function_exit`` call node whose
    argument is the enter node. The original methods are restored
    afterwards, even if tracing raises.
    """

    def trace(self, root, concrete_args=None):
        """Trace ``root``, recording profiler ranges into the graph.

        Args:
            root: module or callable to trace; forwarded to ``Tracer.trace``.
            concrete_args: forwarded unchanged to ``Tracer.trace``.

        Returns:
            Whatever ``torch.fx.Tracer.trace`` returns (a ``torch.fx.Graph``).
        """
        # Capture the class once so the same object is patched and restored.
        record_function = torch.autograd.profiler.record_function
        orig_record_function_enter = record_function.__enter__
        orig_record_function_exit = record_function.__exit__

        def fake_profiler_enter(_self):
            # ``self`` (the tracer) is read from the enclosing scope; no
            # ``nonlocal`` is needed because it is never rebound here.
            handle_proxy = self.create_proxy(
                kind='call_function',
                target=torch.ops.profiler._record_function_enter,
                args=(_self.name,),
                kwargs={})

            # The same record_function instance must not be entered again
            # before it has been exited.
            assert getattr(_self, '_fx_profiler_ctx', None) is None
            _self._fx_profiler_ctx = handle_proxy
            # NOTE(review): the real ``__enter__`` returns the record_function
            # object; this returns the handle proxy, so a ``with ... as``
            # target binds to the proxy. Kept as-is.
            return handle_proxy

        def fake_profiler_exit(_self, exc_type, exc_value, traceback):
            handle_proxy = getattr(_self, '_fx_profiler_ctx', None)
            assert handle_proxy is not None, (
                'record_function exited without a matching enter')
            # Emit the exit node explicitly, mirroring the enter path, instead
            # of relying on Proxy.__torch_function__ to intercept the op call.
            self.create_proxy(
                kind='call_function',
                target=torch.ops.profiler._record_function_exit,
                args=(handle_proxy,),
                kwargs={})
            _self._fx_profiler_ctx = None
            # Implicitly returns None, so exceptions raised inside the
            # ``with`` body are not suppressed.

        record_function.__enter__ = fake_profiler_enter
        record_function.__exit__ = fake_profiler_exit

        try:
            return super().trace(root, concrete_args)
        finally:
            record_function.__enter__ = orig_record_function_enter
            record_function.__exit__ = orig_record_function_exit


pt = ProfilerTracer()
| 98 |
no outgoing calls
no test coverage detected