(
self, graph: BaseGraph, device: str = 'cuda')
| 27 | """ |
| 28 | |
def __init__(
        self, graph: BaseGraph, device: str = 'cuda',
        lr: float = 3e-5) -> None:
    """Prepare a trainer for ``graph``.

    Builds a ``TorchExecutor`` and a ``TrainableGraph`` over ``graph``,
    sets ``requires_grad = True`` on every tensor returned by the
    trainable graph's ``parameters()``, and creates a cross-entropy loss
    and a RAdam optimizer over those same tensors.

    Args:
        graph: the graph to be trained; kept as ``self.graph``.
        device: device string forwarded to ``TorchExecutor`` and used to
            place the loss module (default ``'cuda'``).
        lr: learning rate for the RAdam optimizer. Defaults to ``3e-5``,
            the value this constructor previously hard-coded, so existing
            callers see no change.
    """
    # Training-progress counters; all start at zero.
    self._epoch = 0
    self._step = 0
    self._best_metric = 0
    self._loss_fn = torch.nn.CrossEntropyLoss().to(device)
    self._executor = TorchExecutor(graph, device=device)
    self._training_graph = TrainableGraph(graph)
    self.graph = graph

    # Materialize the parameters once so the tensors flagged for gradients
    # are exactly the ones handed to the optimizer, and the graph is only
    # walked a single time.
    # NOTE(review): assumes parameters() yields torch tensors — confirm.
    params = list(self._training_graph.parameters())
    for tensor in params:
        tensor.requires_grad = True

    self._optimizer = torch.optim.RAdam(params=params, lr=lr)
    # No learning-rate scheduler is configured here.
    self._lr_scheduler = None
| 46 | |
| 47 | def epoch(self, dataloader: Iterable) -> float: |
| 48 | """Do one epoch Training with given dataloader. |
nothing calls this directly
no test coverage detected