MCPcopy
hub / github.com/OpenPPL/ppq / TrainableGraph

Class TrainableGraph

ppq/IR/training.py:9–36  ·  view source on GitHub ↗

Trainable Graph offers a bunch of functions that provide training interfaces.

Source from the content-addressed store, hash-verified

7
8
class TrainableGraph(GraphCommandProcessor):
    """Expose a graph's float parameters through training-style accessors.

    Every accessor below looks at the variables of ``self.graph`` and keeps
    only those flagged ``is_parameter`` whose ``dtype`` equals
    ``torch.float``.
    """

    def __init__(self, graph_or_processor: Union[BaseGraph, Callable]) -> None:
        super().__init__(graph_or_processor)

    def parameters(self) -> List[torch.Tensor]:
        """Return the values of all float parameter variables.

        The list follows the iteration order of ``self.graph.variables``.
        """
        return [
            variable.value
            for variable in self.graph.variables.values()
            if variable.is_parameter and variable.dtype == torch.float
        ]

    def zero_grad(self):
        """Zero, in place, the existing gradient of each float parameter."""
        for variable in self.graph.variables.values():
            if not (variable.is_parameter and variable.dtype == torch.float):
                continue
            gradient = variable.value._grad
            # A parameter without a gradient (None) is left untouched.
            if gradient is not None:
                gradient.zero_()

    def state_dict(self) -> dict:
        """Return a dict mapping each float parameter's name to its value."""
        return {
            variable.name: variable.value
            for variable in self.graph.variables.values()
            if variable.is_parameter and variable.dtype == torch.float
        }

    def _acceptable_command_types(self):
        """Return None: this class names no command types."""
        return None

    def process(self):
        """Return None: this class performs no command processing."""
        return None

Callers 1

__init__Method · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected