| 893 | self.plot() |
| 894 | |
def _train_batch(self, train_iter):
    """Run one optimizer update built from accumulated micro-batches.

    Pulls ``self.grad_accumulation_steps`` batches from ``train_iter``,
    runs a forward/backward pass on each, and then applies a single
    optimizer step on the accumulated gradients.

    Args:
        train_iter: Iterator yielding batches. Each batch is an indexable
            collection of objects with a ``.to(device)`` method, read as
            ``[0]`` input ids, ``[1]`` attention mask, ``[2]`` token type
            ids (only used for "bert"/"xlnet") and ``[3]`` labels.

    Returns:
        The sum of the per-micro-batch losses as a Python number. Each
        loss is divided by ``self.grad_accumulation_steps`` before it is
        summed, so the result is the mean loss over the micro-batches.

    Raises:
        StopIteration: Propagated from ``next(train_iter)`` when the
            iterator runs out before all micro-batches were read.
    """
    self.model.train()
    total_loss = None  # initialised from the first micro-batch loss

    self.optimizer.zero_grad()
    for _ in range(self.grad_accumulation_steps):
        batch = tuple(t.to(self.device) for t in next(train_iter))
        inputs = {
            "input_ids": batch[0],
            "attention_mask": batch[1],
            "labels": batch[3],
        }

        # Only these model types are given segment ids.
        if self.model_type in ("bert", "xlnet"):
            inputs["token_type_ids"] = batch[2]

        if self.is_fp16:
            with autocast():
                outputs = self.model(**inputs)
        else:
            outputs = self.model(**inputs)

        loss = outputs[0]

        if self.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        # Out-of-place division: scale the loss so the accumulated gradient
        # is an average, without mutating the model's output tensor.
        loss = loss / self.grad_accumulation_steps

        if self.is_fp16:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

        # Accumulate a detached value so the running total neither aliases
        # nor mutates a tensor that is still attached to the autograd graph.
        step_loss = loss.detach()
        total_loss = step_loss if total_loss is None else total_loss + step_loss

    if self.is_fp16:
        # Gradients were produced from a scaled loss: the scaler must
        # unscale them (and skip the step on inf/NaN) and then refresh its
        # scale factor. A plain optimizer.step() would apply scaled grads.
        self.scaler.step(self.optimizer)
        self.scaler.update()
    else:
        self.optimizer.step()

    return total_loss.item()
| 939 | |
| 940 | def _validate(self, val_iter): |
| 941 | # Set model to evaluation mode and disable gradient computation |