MCPcopy
hub / github.com/meta-pytorch/opacus / DPLossFastGradientClipping

Class DPLossFastGradientClipping

opacus/utils/fast_gradient_clipping_utils.py:204–276  ·  view source on GitHub ↗

Wrapper around the loss function for use with Fast Gradient Clipping and Ghost Clipping. It computes the per-sample loss and wraps it in DPTensorFastGradientClipping.

Source from the content-addressed store, hash-verified

202
203
204class DPLossFastGradientClipping:
205 """
206 Wrapper on the loss function to be used with Fast Gradient and Ghost Clipping. It computes the per-sample loss, and wraps it in DPTensorFastGradientClipping.
207 """
208
def __init__(
    self,
    module: Union[
        GradSampleModuleFastGradientClipping, GradSampleHooksFastGradientClipping
    ],
    optimizer: DPOptimizerFastGradientClipping,
    criterion,
    loss_reduction: str = "mean",
):
    """
    Store the module, optimizer and criterion after checking that they all
    agree on how the loss is reduced.

    Args:
        module: Grad-sample module (or hooks object) used for fast gradient /
            ghost clipping. Its ``loss_reduction`` must equal ``loss_reduction``.
        optimizer: DP optimizer whose ``loss_reduction`` must equal
            ``loss_reduction``.
        criterion: Loss callable. If it has no ``reduction`` attribute, one is
            attached to it (mutating the passed object) with the value of
            ``module.loss_reduction``. Its ``reduction`` must then equal
            ``loss_reduction``.
        loss_reduction: Either ``"mean"`` or ``"sum"``.

    Raises:
        AssertionError: If ``loss_reduction`` is neither ``"mean"`` nor
            ``"sum"``, or if the four reduction settings are not all equal.
            NOTE: these checks are plain ``assert`` statements and are
            therefore skipped entirely under ``python -O``.
    """
    assert loss_reduction in (
        "mean",
        "sum",
    ), "loss_reduction should be either 'mean' or 'sum'"

    # A criterion without its own `reduction` attribute (e.g. a plain callable)
    # inherits the module's setting, so the consistency check below can pass
    # and the attribute exists for __call__ to temporarily override.
    if not hasattr(criterion, "reduction"):
        criterion.reduction = module.loss_reduction

    # The argument, criterion, module and optimizer must all use one reduction.
    assert (
        loss_reduction == criterion.reduction
        and criterion.reduction == module.loss_reduction
        and module.loss_reduction == optimizer.loss_reduction
    ), "loss_reduction should be the same across GradSampleModule, Optimizer, Criterion, and loss_reduction"

    self.module = module
    self.optimizer = optimizer
    self.criterion = criterion
    self.loss_reduction = loss_reduction
238
239 def __call__(self, *args, shape=None, **kwargs) -> DPTensorFastGradientClipping:
240 """
241 Redefining the forward function to compute per-sample loss and wrap it in DPTensorFastGradientClipping
242 """
243 old_reduction = self.criterion.reduction
244 self.criterion.reduction = "none"
245 loss_per_sample = self.criterion(*args, **kwargs)
246 self.criterion.reduction = old_reduction
247
248 if shape is not None and loss_per_sample.shape[0] == shape[0] * shape[1]:
249 # Note that the privacy unit for generative NLP tasks is per sequence.
250 # The shape variable is the shape of the logits before flattening i.e., [batch_size, sequence_lenght, vocab_size].
251 # This variable is necessary for ghost clipping to work with generative NLP tasks.
252 loss_per_sample = loss_per_sample.view(shape[0], shape[1]) # BxT
253 if self.loss_reduction == "mean":
254 # When the criterion has ignore_index, positions matching it
255 # produce zero loss but should also be excluded from the
256 # denominator (matching PyTorch's CrossEntropyLoss behavior).
257 ignore_index = getattr(self.criterion, "ignore_index", None)
258 if ignore_index is not None and len(args) >= 2:
259 targets = args[1]
260 if "target" in kwargs:
261 targets = kwargs["target"]

Calls

no outgoing calls