(
self,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer],
dl: DataLoader,
max_steps: Optional[int] = None,
)
| 156 | return model, optimizer, poisson_dl, privacy_engine |
| 157 | |
def _train_steps(
    self,
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer],
    dl: DataLoader,
    max_steps: Optional[int] = None,
):
    """Run forward/backward passes over ``dl``, stepping ``optimizer`` if given.

    Each batch is passed through ``model``, scored with ``self.criterion``
    and back-propagated.

    Args:
        model: Module called on each input batch ``x``.
        optimizer: Optimizer that is zeroed before and stepped after every
            backward pass, or ``None`` to only run backward (gradients are
            then not cleared between batches).
        dl: Iterable of ``(x, y)`` batches. ``len(dl)`` is used to work out
            how many passes over ``dl`` are needed to reach ``max_steps``.
        max_steps: Upper bound on the total number of batches processed.
            ``None`` means a single pass over ``dl``; ``0`` or a negative
            value processes nothing.
    """
    if max_steps is None:
        epochs = 1
    elif len(dl) == 0:
        # Empty loader: nothing to iterate, and avoids dividing by zero below.
        epochs = 0
    else:
        epochs = math.ceil(max_steps / len(dl))

    steps = 0
    for _ in range(epochs):
        for x, y in dl:
            if optimizer is not None:
                optimizer.zero_grad()
            logits = model(x)
            loss = self.criterion(logits, y)
            loss.backward()
            if optimizer is not None:
                optimizer.step()

            steps += 1
            # Use return, not break: a break would only leave the inner loop,
            # and each remaining epoch would then run one extra step past the
            # cap whenever an epoch yields more batches than len(dl) reports.
            if max_steps is not None and steps >= max_steps:
                return
| 182 | def _train_steps_with_closure( |
| 183 | self, |
no test coverage detected