The main training loop: iterates over the training data (i.e. `train_iter_fct`) and runs validation (i.e. iterates over `valid_iter_fct`). Args: train_iter_fct (function): a function that returns the train iterator, e.g. something like `train_iter_fct = lambda: generator(*args, **kwargs)`.
(self, train_iter_fct, train_steps, valid_iter_fct=None, valid_steps=-1)
        self.model.train()

    def train(self, train_iter_fct, train_steps, valid_iter_fct=None, valid_steps=-1):
        """
        The main training loop: iterates over the training data
        (i.e. `train_iter_fct`) and runs validation
        (i.e. iterates over `valid_iter_fct`).

        Args:
            train_iter_fct(function): a function that returns the train
                iterator. e.g. something like
                train_iter_fct = lambda: generator(*args, **kwargs)
            valid_iter_fct(function): same as train_iter_fct, for valid data
            train_steps(int): training runs while `step <= train_steps`
            valid_steps(int):
            save_checkpoint_steps(int):

        Return:
            None
        """
        logger.info('Start training...')

        # step = self.optim._step + 1
        step = self.optimizer._step + 1
        true_batchs = []   # batches collected for the next parameter update
        accum = 0          # number of batches collected so far
        normalization = 0  # total batch size over the collected batches
        train_iter = train_iter_fct()

        total_stats = Statistics()
        report_stats = Statistics()
        self._start_report_manager(start_time=total_stats.start_time)

        while step <= train_steps:

            reduce_counter = 0
            for i, batch in enumerate(train_iter):
                # print(batch.src)
                # print(len(batch))
                # Without GPUs every batch is used; otherwise each rank keeps
                # only the batches whose index maps to it.
                if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):

                    true_batchs.append(batch)
                    normalization += batch.batch_size
                    accum += 1
                    if accum == self.grad_accum_count:
                        reduce_counter += 1
                        if self.n_gpu > 1:
                            # Sum the normalization term across all ranks.
                            normalization = sum(distributed.all_gather_list(normalization))

                        self._gradient_accumulation(true_batchs, normalization, total_stats, report_stats)

                        report_stats = self._maybe_report_training(step, train_steps, self.optimizer.learning_rate,
                                                                   report_stats)
                        # Reset the accumulators for the next update.
                        true_batchs = []
                        accum = 0
                        normalization = 0
                        # Only rank 0 writes checkpoints.
                        if step % self.save_checkpoint_steps == 0 and self.gpu_rank == 0:
                            self._save(step)

                        step += 1
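The accumulate-then-update control flow of `train` can be tried out without a model. The sketch below is a simplified stand-in, not the trainer itself: `simulate_training_loop` is a hypothetical helper, plain integers play the role of batches (each integer is a batch size), and the early `break` once `step` exceeds `train_steps` is an assumption added so the sketch terminates, since the listing above ends before showing how the loop exits.

```python
def simulate_training_loop(batches, train_steps, grad_accum_count=2,
                           n_gpu=0, gpu_rank=0, save_checkpoint_steps=2):
    """Mimic the batch sharding and gradient-accumulation bookkeeping.

    `batches` is a list of ints standing in for real batches; each int is
    that batch's size. Returns (updates, checkpoints), where `updates` holds
    one (step, normalization, num_batches) tuple per parameter update and
    `checkpoints` holds the steps at which a checkpoint would be saved.
    """
    step = 1
    true_batches, accum, normalization = [], 0, 0
    updates, checkpoints = [], []

    for i, batch_size in enumerate(batches):
        # With no GPU every batch is used; otherwise this rank keeps only
        # the batches whose index maps to it.
        if n_gpu == 0 or i % n_gpu == gpu_rank:
            true_batches.append(batch_size)
            normalization += batch_size
            accum += 1

            if accum == grad_accum_count:
                # A real trainer would run forward/backward over
                # `true_batches` here and then step the optimizer.
                updates.append((step, normalization, len(true_batches)))

                # Only rank 0 saves, every `save_checkpoint_steps` updates.
                if step % save_checkpoint_steps == 0 and gpu_rank == 0:
                    checkpoints.append(step)

                true_batches, accum, normalization = [], 0, 0
                step += 1
                if step > train_steps:  # assumed exit condition
                    break

    return updates, checkpoints


updates, checkpoints = simulate_training_loop([4, 4, 4, 4, 4, 4], train_steps=3)
print(updates)
print(checkpoints)
```

Note that `step` counts parameter updates, not batches: with `grad_accum_count=2`, six batches yield three updates.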
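One likely reason `train` takes a function that returns an iterator, rather than the iterator itself, is that a Python generator can only be consumed once; a factory lets the caller obtain a fresh pass over the data whenever needed. A minimal sketch, where `make_iter_fct` and the toy data are hypothetical and only illustrate the `train_iter_fct = lambda: generator(*args, **kwargs)` pattern from the docstring:

```python
def batch_generator(data, batch_size):
    """Yield consecutive slices of `data` with at most `batch_size` items."""
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]


def make_iter_fct(data, batch_size):
    """Return a zero-argument factory that builds a fresh generator per call."""
    return lambda: batch_generator(data, batch_size)


train_iter_fct = make_iter_fct(list(range(5)), batch_size=2)

first_pass = list(train_iter_fct())
second_pass = list(train_iter_fct())  # a new generator, so the data is seen again

# By contrast, a single generator is empty after one pass.
single = batch_generator(list(range(5)), 2)
consumed = list(single)
leftover = list(single)
```

Here `first_pass` and `second_pass` contain the same batches, while `leftover` is empty because `single` was already exhausted.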