MCPcopy
hub / github.com/QData/TextAttack / _save_model_checkpoint

Method _save_model_checkpoint

textattack/trainer.py:271–301  ·  view source on GitHub ↗
(
        self, model, tokenizer, step=None, epoch=None, best=False, last=False
    )

Source from the content-addressed store, hash-verified

269 logger.info(f" Total optimization steps = {total_training_steps}")
270
def _save_model_checkpoint(
    self, model, tokenizer, step=None, epoch=None, best=False, last=False
):
    """Save ``model`` (and, for HuggingFace models, ``tokenizer``) to a checkpoint directory.

    The directory is created under ``self.training_args.output_dir``. Its name
    is chosen from the identifying arguments. When several are given, the
    later one in this order wins: ``step`` < ``epoch`` < ``best`` < ``last``.

    Args:
        model: Model to save. A ``torch.nn.DataParallel`` wrapper is unwrapped
            before saving.
        tokenizer: Saved alongside the model only when ``model`` is a
            ``transformers.PreTrainedModel``. It is ignored for other model types.
        step: Step number. Names the directory ``checkpoint-step-{step}``.
        epoch: Epoch number. Names the directory ``checkpoint-epoch-{epoch}``.
        best: If true, save to ``best_model``.
        last: If true, save to ``last_model``.

    Raises:
        ValueError: If none of ``step``, ``epoch``, ``best`` or ``last``
            identifies the checkpoint.
    """
    # Later assignments override earlier ones, preserving the precedence
    # last > best > epoch > step. ``is not None`` lets step 0 / epoch 0 name a
    # checkpoint instead of being silently treated as "not provided".
    dir_name = None
    if step is not None:
        dir_name = f"checkpoint-step-{step}"
    if epoch is not None:
        dir_name = f"checkpoint-epoch-{epoch}"
    if best:
        dir_name = "best_model"
    if last:
        dir_name = "last_model"
    if dir_name is None:
        # Previously this fell through to an UnboundLocalError on dir_name.
        raise ValueError(
            "One of `step`, `epoch`, `best`, or `last` must be provided "
            "to name the checkpoint directory."
        )

    output_dir = os.path.join(self.training_args.output_dir, dir_name)
    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(output_dir, exist_ok=True)

    # Save the wrapped model itself rather than the DataParallel container.
    if isinstance(model, torch.nn.DataParallel):
        model = model.module

    if isinstance(model, (WordCNNForClassification, LSTMForClassification)):
        # TextAttack's own models provide save_pretrained; no tokenizer is saved here.
        model.save_pretrained(output_dir)
    elif isinstance(model, transformers.PreTrainedModel):
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
    else:
        # Any other model: persist only its state dict, with tensors moved to CPU.
        state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
        torch.save(
            state_dict,
            os.path.join(output_dir, "pytorch_model.bin"),
        )
302
303 def _tb_log(self, log, step):
304 if not hasattr(self, "_tb_writer"):

Callers 1

train (method) · 0.95

Calls 2

save (method) · 0.80
save_pretrained (method) · 0.45

Tested by

no test coverage detected