MCPcopy
hub / github.com/pytorch/vision / __init__

Method __init__

torchvision/models/feature_extraction.py:277–318  ·  view source on GitHub ↗

Args: `root` (nn.Module): the module from which the copied module hierarchy is built; `train_graph` (fx.Graph): the graph used in train mode; `eval_graph` (fx.Graph): the graph used in eval mode; `class_name` (str, default `"GraphModule"`): the name assigned to the instance's class.

(
        self, root: torch.nn.Module, train_graph: fx.Graph, eval_graph: fx.Graph, class_name: str = "GraphModule"
    )

Source from the content-addressed store, hash-verified

275 """
276
def __init__(
    self, root: torch.nn.Module, train_graph: fx.Graph, eval_graph: fx.Graph, class_name: str = "GraphModule"
) -> None:
    """
    Args:
        root (nn.Module): module from which the copied module hierarchy is
            built
        train_graph (fx.Graph): the graph that should be used in train mode
        eval_graph (fx.Graph): the graph that should be used in eval mode
        class_name (str): name assigned to ``self.__class__.__name__``
            (default: ``"GraphModule"``)

    Raises:
        TypeError: if ``train_graph`` and ``eval_graph`` report different
            ``_tracer_cls`` values, or if a ``get_attr``/``call_module`` node
            in either graph has a target that is not a ``str``.
    """
    # Validate the inputs first. This check only reads the two graphs, so
    # doing it up front avoids copying attributes from ``root`` and
    # installing ``self.graph`` for a constructor call that is going to fail.
    if eval_graph._tracer_cls != train_graph._tracer_cls:
        raise TypeError(
            f"Train mode and eval mode should use the same tracer class. Instead got {eval_graph._tracer_cls} for eval vs {train_graph._tracer_cls} for train"
        )

    # Start the MRO lookup *after* fx.GraphModule, deliberately bypassing
    # fx.GraphModule.__init__; this class installs its graphs itself below.
    super(fx.GraphModule, self).__init__()

    # NOTE(review): this mutates the class object, not the instance. It only
    # stays per-instance if each instance gets its own (sub)class — confirm
    # against fx.GraphModule.__new__.
    self.__class__.__name__ = class_name

    self.train_graph = train_graph
    self.eval_graph = eval_graph

    # Copy every attribute/submodule referenced by a get_attr or call_module
    # node in EITHER graph, so that both graphs can run against this module.
    for node in chain(train_graph.nodes, eval_graph.nodes):
        if node.op in ("get_attr", "call_module"):
            if not isinstance(node.target, str):
                raise TypeError(f"node.target should be of type str instead of {type(node.target)}")
            _copy_attr(root, self, node.target)

    # Start in train mode by default, running the train graph.
    self.train()
    self.graph = train_graph

    # (borrowed from fx.GraphModule):
    # Store the Tracer class responsible for creating a Graph separately as part of the
    # GraphModule state, except when the Tracer is defined in a local namespace.
    # Locally defined Tracers are not pickleable. This is needed because torch.package will
    # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
    # to re-create the Graph during deserialization.
    self._tracer_cls = None
    if self.graph._tracer_cls and "<locals>" not in self.graph._tracer_cls.__qualname__:
        self._tracer_cls = self.graph._tracer_cls
319
320 def train(self, mode=True):
321 """

Callers

nothing calls this directly

Calls 2

train (method) · 0.95
__init__ (method) · 0.45

Tested by

no test coverage detected