TrainableGraph offers a set of functions that provide training interfaces.
| 7 | |
| 8 | |
class TrainableGraph(GraphCommandProcessor):
    """Trainable Graph offers a bunch of functions that provide training interfaces.

    Exposes the graph's float parameter variables through methods whose names
    mirror ``torch.nn.Module`` (``parameters``, ``zero_grad``, ``state_dict``).
    A variable is considered trainable when ``var.is_parameter`` is true and
    ``var.dtype == torch.float``.
    """

    def __init__(self, graph_or_processor: Union[BaseGraph, Callable]) -> None:
        super().__init__(graph_or_processor)

    def _trainable_variables(self) -> list:
        """Return the graph variables that are float parameters.

        Single source of truth for the filter shared by ``parameters``,
        ``zero_grad`` and ``state_dict``. Order follows iteration order of
        ``self.graph.variables``.
        """
        return [
            var for var in self.graph.variables.values()
            if var.is_parameter and var.dtype == torch.float
        ]

    def parameters(self) -> List[torch.Tensor]:
        """Return the value tensors of all trainable variables.

        The returned tensors are the variables' own ``value`` objects, not
        copies, so in-place updates (e.g. by an optimizer) modify the graph.
        """
        return [var.value for var in self._trainable_variables()]

    def zero_grad(self, set_to_none: bool = False) -> None:
        """Reset the gradients of all trainable variables.

        Args:
            set_to_none: when False (default, the original behaviour), existing
                gradients are zeroed in place; when True, they are replaced by
                ``None`` instead, as ``torch.optim.Optimizer.zero_grad`` does.
                Variables whose gradient is already ``None`` are skipped.
        """
        for var in self._trainable_variables():
            # Public ``Tensor.grad`` rather than the private ``_grad`` attribute.
            grad = var.value.grad
            if grad is None:
                continue
            if set_to_none:
                var.value.grad = None
            else:
                grad.zero_()

    def state_dict(self) -> dict:
        """Map each trainable variable's ``name`` to its value tensor.

        Values are the live tensors (not copies). If two trainable variables
        share a name, the later one in iteration order wins.
        """
        return {var.name: var.value for var in self._trainable_variables()}

    # NOTE(review): presumably required overrides of GraphCommandProcessor;
    # this processor accepts no command types and does no processing.
    def _acceptable_command_types(self) -> None:
        return None

    def process(self) -> None:
        return None