MCPcopy
hub / github.com/CLUEbenchmark/CLUE / train

Function train

baselines/models_pytorch/classifier_pytorch/run_classifier.py:48–157  ·  view source on GitHub ↗

Train the model

(args, train_dataset, model, tokenizer)

Source from the content-addressed store, hash-verified

46
47
48def train(args, train_dataset, model, tokenizer):
49 """ Train the model """
50 args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
51 train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
52 train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size,
53 collate_fn=xlnet_collate_fn if args.model_type in ['xlnet'] else collate_fn)
54
55 if args.max_steps > 0:
56 t_total = args.max_steps
57 args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
58 else:
59 t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
60 args.warmup_steps = int(t_total * args.warmup_proportion)
61 # Prepare optimizer and schedule (linear warmup and decay)
62 no_decay = ['bias', 'LayerNorm.weight']
63 optimizer_grouped_parameters = [
64 {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
65 'weight_decay': args.weight_decay},
66 {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
67 ]
68 optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
69 scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
70 if args.fp16:
71 try:
72 from apex import amp
73 except ImportError:
74 raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
75 model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
76
77 # multi-gpu training (should be after apex fp16 initialization)
78 if args.n_gpu > 1:
79 model = torch.nn.DataParallel(model)
80
81 # Distributed training (should be after apex fp16 initialization)
82 if args.local_rank != -1:
83 model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
84 output_device=args.local_rank,
85 find_unused_parameters=True)
86
87 # Train!
88 logger.info("***** Running training *****")
89 logger.info(" Num examples = %d", len(train_dataset))
90 logger.info(" Num Epochs = %d", args.num_train_epochs)
91 logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
92 logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
93 args.train_batch_size * args.gradient_accumulation_steps * (
94 torch.distributed.get_world_size() if args.local_rank != -1 else 1))
95 logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
96 logger.info(" Total optimization steps = %d", t_total)
97
98 global_step = 0
99 tr_loss, logging_loss = 0.0, 0.0
100 model.zero_grad()
101 seed_everything(args.seed) # Added here for reproducibility (even between python 2 and 3)
102 for _ in range(int(args.num_train_epochs)):
103 pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
104 for step, batch in enumerate(train_dataloader):
105 model.train()

Callers 1

mainFunction · 0.85

Calls 10

stepMethod · 0.95
AdamWClass · 0.90
seed_everythingFunction · 0.90
ProgressBarClass · 0.90
trainMethod · 0.80
joinMethod · 0.80
evaluateFunction · 0.70
save_pretrainedMethod · 0.45
save_vocabularyMethod · 0.45

Tested by

no test coverage detected