Wrapper around the loss function for use with Fast Gradient Clipping and Ghost Clipping. It computes the per-sample loss and wraps it in a DPTensorFastGradientClipping.
| 202 | |
| 203 | |
| 204 | class DPLossFastGradientClipping: |
| 205 | """ |
| 206 | Wrapper on the loss function to be used with Fast Gradient and Ghost Clipping. It computes the per-sample loss, and wraps it in DPTensorFastGradientClipping. |
| 207 | """ |
| 208 | |
def __init__(
    self,
    module: Union[
        GradSampleModuleFastGradientClipping, GradSampleHooksFastGradientClipping
    ],
    optimizer: DPOptimizerFastGradientClipping,
    criterion,
    loss_reduction: str = "mean",
):
    """
    Validate the reduction settings and store the components of the wrapper.

    Args:
        module: A GradSampleModuleFastGradientClipping or
            GradSampleHooksFastGradientClipping instance. Its
            ``loss_reduction`` must equal ``loss_reduction``.
        optimizer: A DPOptimizerFastGradientClipping instance. Its
            ``loss_reduction`` must equal ``loss_reduction``.
        criterion: The loss callable. If it has no ``reduction`` attribute,
            one is set on it (mutating the passed object) from
            ``module.loss_reduction``. Its ``reduction`` must equal
            ``loss_reduction``.
        loss_reduction: Either ``"mean"`` or ``"sum"``.

    Raises:
        AssertionError: If ``loss_reduction`` is not ``"mean"`` or ``"sum"``,
            or if it differs from the reduction of ``criterion``, ``module``,
            or ``optimizer``.
    """
    # Raise explicitly instead of using ``assert`` statements, which are
    # removed under ``python -O`` and would silently skip this validation.
    # AssertionError is kept so callers that already expect it still work.
    if loss_reduction not in ("mean", "sum"):
        raise AssertionError("loss_reduction should be either 'mean' or 'sum'")

    # If the criterion is missing a reduction attribute, use the module's
    # reduction so the consistency check below can still be applied.
    if not hasattr(criterion, "reduction"):
        criterion.reduction = module.loss_reduction

    # All four sources of the reduction mode must agree.
    if not (
        loss_reduction
        == criterion.reduction
        == module.loss_reduction
        == optimizer.loss_reduction
    ):
        raise AssertionError(
            "loss_reduction should be the same across GradSampleModule, "
            "Optimizer, Criterion, and loss_reduction"
        )

    self.optimizer = optimizer
    self.module = module
    self.criterion = criterion
    self.loss_reduction = loss_reduction
| 238 | |
| 239 | def __call__(self, *args, shape=None, **kwargs) -> DPTensorFastGradientClipping: |
| 240 | """ |
| 241 | Redefining the forward function to compute per-sample loss and wrap it in DPTensorFastGradientClipping |
| 242 | """ |
| 243 | old_reduction = self.criterion.reduction |
| 244 | self.criterion.reduction = "none" |
| 245 | loss_per_sample = self.criterion(*args, **kwargs) |
| 246 | self.criterion.reduction = old_reduction |
| 247 | |
| 248 | if shape is not None and loss_per_sample.shape[0] == shape[0] * shape[1]: |
| 249 | # Note that the privacy unit for generative NLP tasks is per sequence. |
| 250 | # The shape variable is the shape of the logits before flattening i.e., [batch_size, sequence_lenght, vocab_size]. |
| 251 | # This variable is necessary for ghost clipping to work with generative NLP tasks. |
| 252 | loss_per_sample = loss_per_sample.view(shape[0], shape[1]) # BxT |
| 253 | if self.loss_reduction == "mean": |
| 254 | # When the criterion has ignore_index, positions matching it |
| 255 | # produce zero loss but should also be excluded from the |
| 256 | # denominator (matching PyTorch's CrossEntropyLoss behavior). |
| 257 | ignore_index = getattr(self.criterion, "ignore_index", None) |
| 258 | if ignore_index is not None and len(args) >= 2: |
| 259 | targets = args[1] |
| 260 | if "target" in kwargs: |
| 261 | targets = kwargs["target"] |
no outgoing calls