MCPcopy
hub / github.com/facebookresearch/mmf / train

Method train

pythia/trainers/base_trainer.py:194–248  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

192 torch.backends.cudnn.benchmark = False
193
def train(self):
    """Run the training loop, or only inference when training is not requested.

    The model is first written to ``self.writer``. If ``"train"`` is not
    part of ``self.run_type``, ``self.inference()`` is called and the
    method returns without training.

    Otherwise exactly one stopping budget is kept finite: when
    ``self.max_epochs`` is ``None`` it is replaced by ``math.inf`` (leaving
    ``self.max_iterations`` as the bound); in every other case
    ``self.max_iterations`` is overwritten with ``math.inf`` (leaving
    ``self.max_epochs`` as the bound).

    Every batch drawn from ``self.train_loader`` is passed through
    ``_run_scheduler``, ``_forward_pass``, ``_update_meter``,
    ``_extract_loss``, ``_backward`` and ``_logistics``, in that order. A
    truthy return value from ``_logistics`` stops training.
    ``self.finalize()`` is called once the loop has ended.

    Note:
        ``self.current_epoch`` and ``self.current_iteration`` are
        incremented (and registered) *before* they are compared with their
        limits, so either counter can end up one past its limit when the
        loop exits.
    """
    self.writer.write("===== Model =====")
    self.writer.write(self.model)

    # Inference-only run: nothing below applies.
    if "train" not in self.run_type:
        self.inference()
        return

    # Keep a single finite budget; the other one is disabled with infinity.
    if self.max_epochs is not None:
        self.max_iterations = math.inf
    else:
        self.max_epochs = math.inf

    self.model.train()
    self.train_timer = Timer()
    self.snapshot_timer = Timer()

    self.profile("Setup Time")

    # NOTE(review): anomaly detection is switched on unconditionally here.
    # PyTorch documents this mode as a debugging aid that slows execution;
    # confirm that it is meant to stay enabled for regular training runs.
    torch.autograd.set_detect_anomaly(True)

    self.writer.write("Starting training...")

    stop_requested = False
    while self.current_iteration < self.max_iterations and not stop_requested:
        self.current_epoch += 1
        registry.register("current_epoch", self.current_epoch)

        # Seed the sampler in case it is distributed.
        self.task_loader.seed_sampler("train", self.current_epoch)

        # Epoch budget exhausted (checked after the counter was bumped).
        if self.current_epoch > self.max_epochs:
            break

        for batch in self.train_loader:
            self.profile("Batch load time")

            self.current_iteration += 1
            self.writer.write(self.current_iteration, "debug")
            registry.register("current_iteration", self.current_iteration)

            # Iteration budget exhausted; the outer condition then ends
            # the epoch loop as well.
            if self.current_iteration > self.max_iterations:
                break

            self._run_scheduler()
            report = self._forward_pass(batch)
            self._update_meter(report, self.meter)
            self._backward(self._extract_loss(report))

            # A truthy result ends both the batch loop and the epoch loop.
            stop_requested = self._logistics(report)
            if stop_requested:
                break

    self.finalize()
249
250 def _run_scheduler(self):
251 if self.lr_scheduler is not None:

Callers 5

runFunction · 0.80
train.pyFile · 0.80
one_stage_run_modelFunction · 0.80
evaluateMethod · 0.80
predict_for_evalaiMethod · 0.80

Calls 13

inferenceMethod · 0.95
profileMethod · 0.95
_run_schedulerMethod · 0.95
_forward_passMethod · 0.95
_update_meterMethod · 0.95
_extract_lossMethod · 0.95
_backwardMethod · 0.95
_logisticsMethod · 0.95
finalizeMethod · 0.95
TimerClass · 0.90
writeMethod · 0.80
registerMethod · 0.80

Tested by

no test coverage detected