        )

    def _prepare_criterion(
        self,
        *,
        module: GradSampleModule,
        optimizer: DPOptimizer,
        criterion,
        loss_reduction: str = "mean",
        **kwargs,
    ) -> DPLossFastGradientClipping:
        """
        Prepare the DP loss class, which packages the two backward passes for fast gradient clipping.

        Args:
            module: GradSampleModule used for training.
            optimizer: DPOptimizer used for training.
            criterion: Loss function used for training.
            loss_reduction: "mean" or "sum"; indicates whether the loss reduction
                (for aggregating the gradients) is a mean or a sum.
        """
        return DPLossFastGradientClipping(module, optimizer, criterion, loss_reduction)

    def is_compatible(
        self,
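The `_prepare_criterion` docstring says the DP loss class packages two backward passes for fast gradient clipping. The sketch below illustrates the general idea in plain Python, with no PyTorch or Opacus dependency, and is not the Opacus implementation. It assumes a single linear layer y = W x with non-sequential inputs, per-sample output gradients e_i = dL_i/dy_i, and a clipping factor of the form min(1, C / (norm + 1e-6)). The first pass gets each per-sample gradient norm without building the per-sample gradient; the second pass backpropagates the reweighted loss sum_i c_i * L_i. All function names are hypothetical.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def l2_norm(v: Vector) -> float:
    return math.sqrt(sum(t * t for t in v))


def clip_factor(norm: float, max_grad_norm: float, eps: float = 1e-6) -> float:
    # Scale that brings a per-sample gradient of this norm down to at most max_grad_norm.
    return min(1.0, max_grad_norm / (norm + eps))


def ghost_clipped_sum(errors: List[Vector], inputs: List[Vector],
                      max_grad_norm: float) -> Tuple[Matrix, List[float]]:
    """Sum of clipped per-sample gradients of W for y = W x (hypothetical helper).

    errors[i] is dL_i/dy_i (length out_features), inputs[i] is x_i (length in_features).
    """
    # Pass 1: the per-sample gradient is the outer product e_i x_i^T, whose Frobenius
    # norm equals ||e_i|| * ||x_i||, so no out x in matrix is built per sample.
    coeffs = [clip_factor(l2_norm(e) * l2_norm(x), max_grad_norm)
              for e, x in zip(errors, inputs)]
    # Pass 2: backpropagating sum_i c_i * L_i (c_i treated as constants) yields
    # sum_i c_i * e_i x_i^T; it is accumulated here entry by entry.
    n_out, n_in = len(errors[0]), len(inputs[0])
    grad = [[0.0] * n_in for _ in range(n_out)]
    for c, e, x in zip(coeffs, errors, inputs):
        for r in range(n_out):
            for k in range(n_in):
                grad[r][k] += c * e[r] * x[k]
    return grad, coeffs


def naive_clipped_sum(errors: List[Vector], inputs: List[Vector],
                      max_grad_norm: float) -> Matrix:
    """Reference: materialize each per-sample gradient, clip it, then sum."""
    n_out, n_in = len(errors[0]), len(inputs[0])
    total = [[0.0] * n_in for _ in range(n_out)]
    for e, x in zip(errors, inputs):
        g = [[er * xk for xk in x] for er in e]
        norm = math.sqrt(sum(v * v for row in g for v in row))
        c = clip_factor(norm, max_grad_norm)
        for r in range(n_out):
            for k in range(n_in):
                total[r][k] += c * g[r][k]
    return total


# Usage: two samples, a 2x3 weight matrix, clipping bound C = 1.0.
errors = [[3.0, 4.0], [0.1, 0.0]]
inputs = [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]
fast, coeffs = ghost_clipped_sum(errors, inputs, max_grad_norm=1.0)
reference = naive_clipped_sum(errors, inputs, max_grad_norm=1.0)
```

The naive version exists only to check the result. The saving comes from the first pass never forming the per-sample out x in matrices; for layers with sequence inputs the per-sample norm no longer factorizes into a simple product of two vector norms.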
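The `loss_reduction` argument matters for per-sample clipping because a "mean" reduction scales each sample's contribution to the batch gradient by 1/B, while "sum" leaves it unscaled. The sketch below assumes per-sample losses and gradients are plain Python floats; `reduce_losses`, `per_sample_scale` and `recover_per_sample_grad` are hypothetical helpers written for illustration, not Opacus functions. It shows how the reduction is applied and how the 1/B factor would be undone to recover the gradient of L_i alone, which is the quantity whose norm is compared against the clipping bound.

```python
from typing import List


def reduce_losses(per_sample_losses: List[float], loss_reduction: str = "mean") -> float:
    """Aggregate per-sample losses as a criterion with this reduction would."""
    if loss_reduction == "sum":
        return sum(per_sample_losses)
    if loss_reduction == "mean":
        return sum(per_sample_losses) / len(per_sample_losses)
    raise ValueError(f"loss_reduction must be 'mean' or 'sum', got {loss_reduction!r}")


def per_sample_scale(batch_size: int, loss_reduction: str = "mean") -> float:
    """d(reduced loss) / d(L_i): the factor backprop applies to each sample's gradient."""
    if loss_reduction == "sum":
        return 1.0
    if loss_reduction == "mean":
        return 1.0 / batch_size
    raise ValueError(f"loss_reduction must be 'mean' or 'sum', got {loss_reduction!r}")


def recover_per_sample_grad(scaled_grad: List[float], batch_size: int,
                            loss_reduction: str = "mean") -> List[float]:
    # Undo the reduction factor so the result is the gradient of L_i by itself.
    scale = per_sample_scale(batch_size, loss_reduction)
    return [g / scale for g in scaled_grad]


# Usage: a batch of four per-sample losses.
losses = [1.0, 2.0, 3.0, 6.0]
batch_loss = reduce_losses(losses, "mean")
grad_of_sample = recover_per_sample_grad([0.25, -0.5], batch_size=4, loss_reduction="mean")
```

Getting this factor wrong would make every per-sample norm B times too small under "mean", so samples would be clipped far less often than intended.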