MCPcopy
hub / github.com/QData/TextAttack / train

Method train

textattack/trainer.py:598–881  ·  view source on GitHub ↗

Train the model on the given training dataset.

(self)

Source from the content-addressed store, hash-verified

596 return preds.cpu(), _targets
597
598 def train(self):
599 """Train the model on given training dataset."""
600 if not self.train_dataset:
601 raise ValueError("No `train_dataset` available for training.")
602
603 textattack.shared.utils.set_seed(self.training_args.random_seed)
604 if not os.path.exists(self.training_args.output_dir):
605 os.makedirs(self.training_args.output_dir)
606
607 # Save logger writes to file
608 log_txt_path = os.path.join(self.training_args.output_dir, "train_log.txt")
609 fh = logging.FileHandler(log_txt_path)
610 fh.setLevel(logging.DEBUG)
611 logger.addHandler(fh)
612 logger.info(f"Writing logs to {log_txt_path}.")
613
614 # Save original self.training_args to file
615 args_save_path = os.path.join(
616 self.training_args.output_dir, "training_args.json"
617 )
618 with open(args_save_path, "w", encoding="utf-8") as f:
619 json.dump(self.training_args.__dict__, f)
620 logger.info(f"Wrote original training args to {args_save_path}.")
621
622 num_gpus = torch.cuda.device_count()
623 tokenizer = self.model_wrapper.tokenizer
624 model = self.model_wrapper.model
625
626 if self.training_args.parallel and num_gpus > 1:
627 # TODO: torch.nn.parallel.DistributedDataParallel
628 # Supposedly faster than DataParallel, but requires more work to setup properly.
629 model = torch.nn.DataParallel(model)
630 logger.info(f"Training on {num_gpus} GPUs via `torch.nn.DataParallel`.")
631 train_batch_size = self.training_args.per_device_train_batch_size * num_gpus
632 else:
633 train_batch_size = self.training_args.per_device_train_batch_size
634
635 if self.attack is None:
636 num_clean_epochs = self.training_args.num_epochs
637 else:
638 num_clean_epochs = self.training_args.num_clean_epochs
639
640 total_clean_training_steps = (
641 math.ceil(
642 len(self.train_dataset)
643 / (train_batch_size * self.training_args.gradient_accumulation_steps)
644 )
645 * num_clean_epochs
646 )
647
648 # calculate total_adv_training_data_length based on type of
649 # num_train_adv_examples.
650 # if num_train_adv_examples is float , num_train_adv_examples is a portion of train_dataset.
651 if isinstance(self.training_args.num_train_adv_examples, float):
652 total_adv_training_data_length = (
653 len(self.train_dataset) * self.training_args.num_train_adv_examples
654 )
655

Callers 3

runMethod · 0.95
get_gradMethod · 0.80
get_gradMethod · 0.80

Calls 14

_print_training_argsMethod · 0.95
get_train_dataloaderMethod · 0.95
training_stepMethod · 0.95
_tb_logMethod · 0.95
_wandb_logMethod · 0.95
evaluateMethod · 0.95
_write_readmeMethod · 0.95
toMethod · 0.80
loadMethod · 0.80

Tested by

no test coverage detected