Args:
    root (nn.Module): module from which the copied module hierarchy is built
    train_graph (fx.Graph): the graph that should be used in train mode
    eval_graph (fx.Graph): the graph that should be used in eval mode
    class_name (str): name given to the module's class (default "GraphModule")
(
self, root: torch.nn.Module, train_graph: fx.Graph, eval_graph: fx.Graph, class_name: str = "GraphModule"
)
| 275 | """ |
| 276 | |
def __init__(
    self, root: torch.nn.Module, train_graph: fx.Graph, eval_graph: fx.Graph, class_name: str = "GraphModule"
):
    """
    Build a module that holds both a train-mode and an eval-mode graph.

    Args:
        root (nn.Module): module from which the copied module hierarchy is
            built
        train_graph (fx.Graph): the graph that should be used in train mode
        eval_graph (fx.Graph): the graph that should be used in eval mode
        class_name (str): value assigned to ``self.__class__.__name__``
            (default ``"GraphModule"``)

    Raises:
        TypeError: if ``train_graph`` and ``eval_graph`` were produced by
            different tracer classes, or if a ``get_attr``/``call_module``
            node has a non-string target.
    """
    # Validate the two graphs before copying anything from ``root`` or
    # assigning ``self.graph``, so a mismatch fails fast instead of after
    # that work has already been done.
    if eval_graph._tracer_cls != train_graph._tracer_cls:
        raise TypeError(
            f"Train mode and eval mode should use the same tracer class. Instead got {eval_graph._tracer_cls} for eval vs {train_graph._tracer_cls} for train"
        )

    # Two-argument super() deliberately skips fx.GraphModule.__init__ and
    # runs the next __init__ in the MRO; this class wires up its graphs
    # itself below.
    super(fx.GraphModule, self).__init__()

    self.__class__.__name__ = class_name

    self.train_graph = train_graph
    self.eval_graph = eval_graph

    # Copy every attribute / submodule referenced by a get_attr or
    # call_module node in EITHER graph (the union of both), so that each
    # mode can resolve its targets on ``self``.
    for node in chain(train_graph.nodes, eval_graph.nodes):
        if node.op in ("get_attr", "call_module"):
            if not isinstance(node.target, str):
                raise TypeError(f"node.target should be of type str instead of {type(node.target)}")
            _copy_attr(root, self, node.target)

    # Start in train mode, using the train graph.
    self.train()
    self.graph = train_graph

    # (borrowed from fx.GraphModule):
    # Store the Tracer class responsible for creating a Graph separately as part of the
    # GraphModule state, except when the Tracer is defined in a local namespace.
    # Locally defined Tracers are not pickleable. This is needed because torch.package will
    # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
    # to re-create the Graph during deserialization.
    # (Both graphs were checked above to share the same tracer class.)
    self._tracer_cls = None
    if self.graph._tracer_cls and "<locals>" not in self.graph._tracer_cls.__qualname__:
        self._tracer_cls = self.graph._tracer_cls
| 320 | def train(self, mode=True): |
| 321 | """ |