Train the model on the given training dataset.
(self)
| 596 | return preds.cpu(), _targets |
| 597 | |
| 598 | def train(self): |
| 599 | """Train the model on given training dataset.""" |
| 600 | if not self.train_dataset: |
| 601 | raise ValueError("No `train_dataset` available for training.") |
| 602 | |
| 603 | textattack.shared.utils.set_seed(self.training_args.random_seed) |
| 604 | if not os.path.exists(self.training_args.output_dir): |
| 605 | os.makedirs(self.training_args.output_dir) |
| 606 | |
| 607 | # Attach a file handler so that logger output is also written to a file |
| 608 | log_txt_path = os.path.join(self.training_args.output_dir, "train_log.txt") |
| 609 | fh = logging.FileHandler(log_txt_path) |
| 610 | fh.setLevel(logging.DEBUG) |
| 611 | logger.addHandler(fh) |
| 612 | logger.info(f"Writing logs to {log_txt_path}.") |
| 613 | |
| 614 | # Save original self.training_args to file |
| 615 | args_save_path = os.path.join( |
| 616 | self.training_args.output_dir, "training_args.json" |
| 617 | ) |
| 618 | with open(args_save_path, "w", encoding="utf-8") as f: |
| 619 | json.dump(self.training_args.__dict__, f) |
| 620 | logger.info(f"Wrote original training args to {args_save_path}.") |
| 621 | |
| 622 | num_gpus = torch.cuda.device_count() |
| 623 | tokenizer = self.model_wrapper.tokenizer |
| 624 | model = self.model_wrapper.model |
| 625 | |
| 626 | if self.training_args.parallel and num_gpus > 1: |
| 627 | # TODO: torch.nn.parallel.DistributedDataParallel |
| 628 | # Supposedly faster than DataParallel, but requires more work to setup properly. |
| 629 | model = torch.nn.DataParallel(model) |
| 630 | logger.info(f"Training on {num_gpus} GPUs via `torch.nn.DataParallel`.") |
| 631 | train_batch_size = self.training_args.per_device_train_batch_size * num_gpus |
| 632 | else: |
| 633 | train_batch_size = self.training_args.per_device_train_batch_size |
| 634 | |
| 635 | if self.attack is None: |
| 636 | num_clean_epochs = self.training_args.num_epochs |
| 637 | else: |
| 638 | num_clean_epochs = self.training_args.num_clean_epochs |
| 639 | |
| 640 | total_clean_training_steps = ( |
| 641 | math.ceil( |
| 642 | len(self.train_dataset) |
| 643 | / (train_batch_size * self.training_args.gradient_accumulation_steps) |
| 644 | ) |
| 645 | * num_clean_epochs |
| 646 | ) |
| 647 | |
| 648 | # Calculate total_adv_training_data_length based on the type of |
| 649 | # num_train_adv_examples. |
| 650 | # If num_train_adv_examples is a float, it is treated as a fraction of len(train_dataset). |
| 651 | if isinstance(self.training_args.num_train_adv_examples, float): |
| 652 | total_adv_training_data_length = ( |
| 653 | len(self.train_dataset) * self.training_args.num_train_adv_examples |
| 654 | ) |
| 655 |
no test coverage detected