MCPcopy
hub / github.com/OpenPPL/ppq / __init__

Method __init__

ppq/samples/QAT/trainer.py:29–45  ·  view source on GitHub ↗
(
        self, graph: BaseGraph, device: str = 'cuda')

Source from the content-addressed store, hash-verified

27 """
28
def __init__(
    self, graph: BaseGraph, device: str = 'cuda', lr: float = 3e-5) -> None:
    """Prepare executor, loss, trainable parameters and optimizer for ``graph``.

    Args:
        graph: the graph to train. It is wrapped by a ``TorchExecutor``
            (``self._executor``) and a ``TrainableGraph``
            (``self._training_graph``), and is also kept as ``self.graph``.
        device: device string passed to the executor and used as the target
            of the cross-entropy loss module's ``.to()`` call.
        lr: learning rate for the RAdam optimizer. Defaults to ``3e-5``,
            the value that used to be hard-coded, so existing callers see
            no change.
    """
    # Training counters and best metric seen so far; presumably advanced by
    # other methods of this class (not visible here) — TODO confirm.
    self._epoch = 0
    self._step = 0
    self._best_metric = 0
    self._loss_fn = torch.nn.CrossEntropyLoss().to(device)
    self._executor = TorchExecutor(graph, device=device)
    self._training_graph = TrainableGraph(graph)
    self.graph = graph

    # Collect the trainable tensors once so that the tensors flagged for
    # gradient tracking are exactly the ones handed to the optimizer, and
    # the parameter collection is not rebuilt a second time.
    params = list(self._training_graph.parameters())
    for tensor in params:
        tensor.requires_grad = True

    self._optimizer = torch.optim.RAdam(params=params, lr=lr)
    # No learning-rate schedule is configured here.
    self._lr_scheduler = None
46
47 def epoch(self, dataloader: Iterable) -> float:
48 """Do one epoch Training with given dataloader.

Callers

nothing calls this directly

Calls 4

TorchExecutor (class) · 0.90
TrainableGraph (class) · 0.90
to (method) · 0.80
parameters (method) · 0.45

Tested by

no test coverage detected