hub / github.com/appvision-ai/fast-bert / _train_batch

Method _train_batch

fast_bert/learner_cls.py:895–938
(self, train_iter)

Source from the content-addressed store, hash-verified

893         self.plot()
894
895     def _train_batch(self, train_iter):
896         self.model.train()
897         total_loss = None  # for late initialization
898
899         self.optimizer.zero_grad()
900         for i in range(self.grad_accumulation_steps):
901             batch = next(train_iter)
902
903             batch = tuple(t.to(self.device) for t in batch)
904             inputs = {
905                 "input_ids": batch[0],
906                 "attention_mask": batch[1],
907                 "labels": batch[3],
908             }
909
910             if self.model_type in ["bert", "xlnet"]:
911                 inputs["token_type_ids"] = batch[2]
912
913             if self.is_fp16:
914                 with autocast():
915                     outputs = self.model(**inputs)
916             else:
917                 outputs = self.model(**inputs)
918
919             loss = outputs[0]
920
921             if self.n_gpu > 1:
922                 loss = loss.mean()  # mean() to average on multi-gpu parallel training
923
924             loss /= self.grad_accumulation_steps
925
926             if self.is_fp16:
927                 self.scaler.scale(loss).backward()
928             else:
929                 loss.backward()
930
931             if total_loss is None:
932                 total_loss = loss
933             else:
934                 total_loss += loss
935
936         self.optimizer.step()
937
938         return total_loss.item()
939
940     def _validate(self, val_iter):
941         # Set model to evaluation mode and disable gradient computation
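
The division `loss /= self.grad_accumulation_steps` in the listing above is what makes the summed micro-batch gradients match a single large-batch gradient. The following is a minimal pure-Python sketch of that arithmetic, not the library's code: the toy model `y = w * x`, the helper names `grad_mse` and `accumulated_grad`, and the data are all assumptions made for illustration. The equality holds here because the loss is mean-reduced and the micro-batches are equally sized.

```python
# Toy model: y = w * x with a mean squared error loss.
# All names and numbers below are illustrative, not part of fast-bert.

def grad_mse(w, batch):
    """Gradient with respect to w of the mean squared error over one batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w, micro_batches):
    """Mimic the accumulation loop: each micro-batch gradient is divided by
    the number of accumulation steps before being added to a running total."""
    steps = len(micro_batches)
    total = 0.0  # plays the role of the gradient buffer after zero_grad()
    for batch in micro_batches:
        total += grad_mse(w, batch) / steps
    return total


data = [(1.0, 2.0), (2.0, 4.5), (3.0, 5.5), (4.0, 8.5)]
micro_batches = [data[:2], data[2:]]  # two equally sized micro-batches
w = 0.5

full = grad_mse(w, data)                     # one pass over the full batch
accum = accumulated_grad(w, micro_batches)   # two accumulated passes

# Both routes give the same gradient for the single optimizer step.
assert abs(full - accum) < 1e-12
```

Without the division, the accumulated gradient would be larger by a factor of the number of steps, which behaves like an unintended increase in learning rate.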

Callers 1

lr_find · Method · 0.95

Calls 2

zero_grad · Method · 0.80
step · Method · 0.45
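
A note on the `step` call: in the fp16 branch the listing backpropagates `self.scaler.scale(loss)` and then calls `self.optimizer.step()` directly. With PyTorch's `GradScaler`, the usual pattern is `scaler.step(optimizer)` followed by `scaler.update()`, which unscales the gradients before the update. The listing does not show whether this difference matters for `lr_find`. The sketch below is a toy with plain SGD and no PyTorch dependency. `MiniScaler` and `sgd_step` are hypothetical stand-ins, not real APIs, and the numbers are invented to show the effect of stepping on a still-scaled gradient.

```python
# Toy illustration only. MiniScaler is a hypothetical stand-in and does not
# reproduce the behavior of torch's GradScaler beyond a constant scale factor.

class MiniScaler:
    def __init__(self, scale=1024.0):
        self.scale_factor = scale

    def scale(self, grad):
        # Backpropagating a scaled loss multiplies every gradient by the scale.
        return grad * self.scale_factor

    def unscale(self, scaled_grad):
        # Undo the scaling so the optimizer sees the true gradient.
        return scaled_grad / self.scale_factor


def sgd_step(w, grad, lr):
    """One step of plain gradient descent."""
    return w - lr * grad


w, lr, true_grad = 1.0, 0.1, 0.5
scaler = MiniScaler(scale=1024.0)
scaled_grad = scaler.scale(true_grad)

w_plain = sgd_step(w, scaled_grad, lr)                      # steps on the scaled gradient
w_unscaled = sgd_step(w, scaler.unscale(scaled_grad), lr)   # unscales first

print(w_plain, w_unscaled)
```

In this toy the first update is 1024 times larger than the second. Adaptive optimizers such as Adam are much less sensitive to a constant gradient scale, so the practical effect depends on the optimizer in use.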

Tested by

no test coverage detected