Train the model
(args, train_dataset, model, tokenizer)
| 46 | |
| 47 | |
| 48 | def train(args, train_dataset, model, tokenizer): |
| 49 | """ Train the model """ |
| 50 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) |
| 51 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) |
| 52 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, |
| 53 | collate_fn=xlnet_collate_fn if args.model_type in ['xlnet'] else collate_fn) |
| 54 | |
| 55 | if args.max_steps > 0: |
| 56 | t_total = args.max_steps |
| 57 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 |
| 58 | else: |
| 59 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs |
| 60 | args.warmup_steps = int(t_total * args.warmup_proportion) |
| 61 | # Prepare optimizer and schedule (linear warmup and decay) |
| 62 | no_decay = ['bias', 'LayerNorm.weight'] |
| 63 | optimizer_grouped_parameters = [ |
| 64 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], |
| 65 | 'weight_decay': args.weight_decay}, |
| 66 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} |
| 67 | ] |
| 68 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) |
| 69 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) |
| 70 | if args.fp16: |
| 71 | try: |
| 72 | from apex import amp |
| 73 | except ImportError: |
| 74 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") |
| 75 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) |
| 76 | |
| 77 | # multi-gpu training (should be after apex fp16 initialization) |
| 78 | if args.n_gpu > 1: |
| 79 | model = torch.nn.DataParallel(model) |
| 80 | |
| 81 | # Distributed training (should be after apex fp16 initialization) |
| 82 | if args.local_rank != -1: |
| 83 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], |
| 84 | output_device=args.local_rank, |
| 85 | find_unused_parameters=True) |
| 86 | |
| 87 | # Train! |
| 88 | logger.info("***** Running training *****") |
| 89 | logger.info(" Num examples = %d", len(train_dataset)) |
| 90 | logger.info(" Num Epochs = %d", args.num_train_epochs) |
| 91 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) |
| 92 | logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", |
| 93 | args.train_batch_size * args.gradient_accumulation_steps * ( |
| 94 | torch.distributed.get_world_size() if args.local_rank != -1 else 1)) |
| 95 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) |
| 96 | logger.info(" Total optimization steps = %d", t_total) |
| 97 | |
| 98 | global_step = 0 |
| 99 | tr_loss, logging_loss = 0.0, 0.0 |
| 100 | model.zero_grad() |
| 101 | seed_everything(args.seed) # Added here for reproductibility (even between python 2 and 3) |
| 102 | for _ in range(int(args.num_train_epochs)): |
| 103 | pbar = ProgressBar(n_total=len(train_dataloader), desc='Training') |
| 104 | for step, batch in enumerate(train_dataloader): |
| 105 | model.train() |
no test coverage detected