MCPcopy Index your code
hub / github.com/geekcomputers/Python / train_epoch

Method train_epoch

ML/src/python/neuralforge/trainer.py:50–103  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

48 self.logger.info(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
49
def train_epoch(self) -> Dict[str, float]:
    """Run one optimisation pass over ``self.train_loader``.

    For every batch the inputs and targets are moved to ``self.device``,
    a forward/backward pass is executed and the optimizer is stepped.
    When ``self.scaler`` is set, the forward pass runs under
    ``amp.autocast('cuda')`` and the backward/step go through the scaler
    (presumably a ``GradScaler`` -- confirm against ``__init__``).
    Gradients are norm-clipped when ``self.config.grad_clip > 0``.

    ``self.global_step`` is incremented once per batch.

    Returns:
        A dict with:
        ``'loss'``     -- sum of per-batch losses divided by the number of
                          batches actually processed;
        ``'accuracy'`` -- percentage (0-100) of samples whose arg-max over
                          dimension 1 of the model output equals the target.
        Both values are ``0.0`` when the loader yields no batches/samples.
    """
    self.model.train()
    epoch_loss = 0.0
    correct = 0
    total = 0
    # Count batches ourselves instead of relying on len(self.train_loader):
    # that call fails for loaders without a length and is zero for empty ones.
    num_batches = 0

    pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}/{self.config.epochs}")

    for batch_idx, (inputs, targets) in enumerate(pbar):
        inputs = inputs.to(self.device, non_blocking=True)
        targets = targets.to(self.device, non_blocking=True)

        self.optimizer.zero_grad(set_to_none=True)

        if self.scaler is not None:
            # Mixed-precision path: forward under autocast, scaled backward.
            with amp.autocast('cuda'):
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

            self.scaler.scale(loss).backward()

            if self.config.grad_clip > 0:
                # Gradients must be unscaled before clipping so the
                # threshold applies to their true magnitude.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)

            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            loss.backward()

            if self.config.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)

            self.optimizer.step()

        # Read the scalar once; it is reused for the progress bar below.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        num_batches += 1

        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        self.global_step += 1

        if batch_idx % 10 == 0:
            pbar.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'acc': f'{100. * correct / total:.2f}%'
            })

    # Guard both divisions so an empty loader yields zeros rather than
    # raising ZeroDivisionError.
    avg_loss = epoch_loss / num_batches if num_batches else 0.0
    accuracy = 100. * correct / total if total else 0.0

    return {'loss': avg_loss, 'accuracy': accuracy}
104
105 def validate(self) -> Dict[str, float]:
106 if self.val_loader is None:

Callers 1

trainMethod · 0.95

Calls 4

sizeMethod · 0.80
trainMethod · 0.45
stepMethod · 0.45
updateMethod · 0.45

Tested by

no test coverage detected