github.com/QData/TextAttack

Method get_optimizer_and_scheduler

textattack/trainer.py:327–386

Returns the optimizer and scheduler to use for training. If you override this method and do not want to use a scheduler, return None for the scheduler. Args: model (torch.nn.Module), the model to be trained, whose parameters are passed to the optimizer; num_training_steps (int), the total number of training steps.

(self, model, num_training_steps)


def get_optimizer_and_scheduler(self, model, num_training_steps):
    """Returns optimizer and scheduler to use for training. If you are
    overriding this method and do not want to use a scheduler, simply
    return :obj:`None` for scheduler.

    Args:
        model (:obj:`torch.nn.Module`):
            Model to be trained. Pass its parameters to optimizer for training.
        num_training_steps (:obj:`int`):
            Number of total training steps.
    Returns:
        Tuple of optimizer and scheduler :obj:`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]`
    """
    if isinstance(model, torch.nn.DataParallel):
        model = model.module

    if isinstance(model, transformers.PreTrainedModel):
        # Reference https://huggingface.co/transformers/training.html
        param_optimizer = list(model.named_parameters())
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p
                    for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": self.training_args.weight_decay,
            },
            {
                "params": [
                    p for n, p in param_optimizer if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]

        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters, lr=self.training_args.learning_rate
        )
        if isinstance(self.training_args.num_warmup_steps, float):
            num_warmup_steps = math.ceil(
                self.training_args.num_warmup_steps * num_training_steps
            )
        else:
            num_warmup_steps = self.training_args.num_warmup_steps

        scheduler = transformers.optimization.get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )
    else:
        optimizer = torch.optim.Adam(
            filter(lambda x: x.requires_grad, model.parameters()),
            lr=self.training_args.learning_rate,
        )
        scheduler = None

    return optimizer, scheduler
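The transformers branch of get_optimizer_and_scheduler splits parameters into a decayed group and a weight_decay=0.0 group purely by name: a parameter skips weight decay if any string in no_decay occurs as a substring of its name from named_parameters(). The torch-free sketch below applies the same test to plain strings so the grouping can be inspected without loading a model. The helper split_by_decay and the parameter names are hypothetical, chosen to resemble a BERT-style checkpoint.

```python
from typing import Iterable, List, Tuple

# Same substrings as the no_decay list in get_optimizer_and_scheduler.
NO_DECAY = ["bias", "LayerNorm.bias", "LayerNorm.weight"]


def split_by_decay(
    names: Iterable[str], no_decay: List[str] = NO_DECAY
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: return (decayed, not_decayed) names via substring matching."""
    decayed, not_decayed = [], []
    for name in names:
        # A parameter skips weight decay if ANY no_decay substring occurs in its name.
        if any(nd in name for nd in no_decay):
            not_decayed.append(name)
        else:
            decayed.append(name)
    return decayed, not_decayed


# Hypothetical parameter names in the style of a Hugging Face checkpoint.
names = [
    "bert.embeddings.word_embeddings.weight",
    "bert.embeddings.LayerNorm.weight",
    "bert.embeddings.LayerNorm.bias",
    "bert.encoder.layer.0.attention.self.query.weight",
    "bert.encoder.layer.0.attention.self.query.bias",
    "classifier.weight",
    "classifier.bias",
    "decoder.layers.0.self_attn_layer_norm.weight",
]
decayed, not_decayed = split_by_decay(names)
print("decayed:", decayed)
print("not decayed:", not_decayed)
```

Because the test is a substring match, "bias" alone already catches every bias parameter, including LayerNorm.bias. Normalization weights spelled differently, such as self_attn_layer_norm.weight in the example, match no entry and therefore land in the decayed group.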

Callers (1)

train (method) · 0.95

Calls

no outgoing calls

Tested by

no test coverage detected
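In the transformers.PreTrainedModel branch of get_optimizer_and_scheduler, training_args.num_warmup_steps may be an int (an absolute step count) or a float (a fraction of num_training_steps, rounded up with math.ceil). The scheduler from transformers.optimization.get_linear_schedule_with_warmup then scales the base learning rate by a factor that rises linearly from 0 to 1 during warmup and falls linearly to 0 at num_training_steps. The torch-free sketch below reproduces both pieces so the arithmetic can be checked in isolation. The helper names resolve_warmup_steps and linear_warmup_factor are hypothetical, and the factor formula is modeled on the transformers implementation; verify it against the installed version.

```python
import math
from typing import Union


def resolve_warmup_steps(num_warmup_steps: Union[int, float], num_training_steps: int) -> int:
    """Hypothetical helper mirroring the method: a float is a fraction of all steps."""
    if isinstance(num_warmup_steps, float):
        return math.ceil(num_warmup_steps * num_training_steps)
    return num_warmup_steps


def linear_warmup_factor(step: int, num_warmup_steps: int, num_training_steps: int) -> float:
    """Hypothetical helper: multiplier applied to the base learning rate at `step`.

    Ramps linearly 0 -> 1 over the warmup steps, then decays linearly to 0
    at num_training_steps (clamped so it never goes negative).
    """
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))


total_steps = 1000
warmup = resolve_warmup_steps(0.1, total_steps)  # float: ceil(0.1 * 1000)
base_lr = 5e-5
for step in (0, 50, 100, 550, 1000):
    print(step, base_lr * linear_warmup_factor(step, warmup, total_steps))
```

Note that the isinstance check keys on type, not magnitude: num_warmup_steps=1.0 means warmup over every training step, while num_warmup_steps=1 means a single warmup step.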