Repository: github.com/geekcomputers/Python (branch main)

Function main

Defined in ML/examples/train_custom.py, lines 12–44
Signature: main()

```python
# Earlier lines of this file (not shown) are assumed to import the other names
# used below: Config, SyntheticDataset, DataLoaderBuilder, ResNet18, AdamW,
# Trainer, and torch.nn as nn.
from src.python.neuralforge.optim.schedulers import CosineAnnealingWarmRestarts


def main():
    # Hyperparameters for this example run.
    config = Config()
    config.batch_size = 64
    config.epochs = 100
    config.learning_rate = 0.001
    config.num_classes = 100
    config.model_name = "resnet18_custom"

    # Synthetic data: 10,000 training samples and 2,000 validation samples.
    train_dataset = SyntheticDataset(num_samples=10000, num_classes=config.num_classes)
    val_dataset = SyntheticDataset(num_samples=2000, num_classes=config.num_classes)

    # Build train/val loaders from the shared config (presumably using config.batch_size).
    loader_builder = DataLoaderBuilder(config)
    train_loader = loader_builder.build_train_loader(train_dataset)
    val_loader = loader_builder.build_val_loader(val_dataset)

    # ResNet-18 with a 100-class output, cross-entropy loss, and AdamW.
    model = ResNet18(num_classes=config.num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=config.learning_rate, weight_decay=0.01)
    # The first cosine cycle spans T_0=10 scheduler steps (epochs, if stepped once
    # per epoch); each later cycle is T_mult=2 times longer than the previous one.
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        config=config,
        scheduler=scheduler,
    )

    trainer.train()

    print(f"Best validation loss: {trainer.best_val_loss:.4f}")


if __name__ == '__main__':
    main()
```
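With `T_0=10` and `T_mult=2`, the restart cycles last 10, 20, 40, 80, … epochs, so the learning rate jumps back to its initial value at epochs 10, 30 and 70 (0-based), and a 100-epoch run stops 30 epochs into the fourth, 80-epoch cycle, with the learning rate still at roughly 71% of its initial value. The sketch below reproduces that schedule in plain Python. It assumes that neuralforge's `CosineAnnealingWarmRestarts` follows the same SGDR formula as PyTorch's scheduler of the same name, that the trainer steps it once per epoch, and that the minimum learning rate `eta_min` is 0; the helper names are hypothetical.

```python
import math


def cycle_position(epoch: int, t_0: int, t_mult: int) -> tuple[int, int]:
    """Return (t_cur, t_i) for a 0-based epoch: epochs into the current cycle and its length."""
    t_i, t_cur = t_0, epoch
    # Walk through completed cycles, each t_mult times longer than the last.
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult
    return t_cur, t_i


def sgdr_lr(epoch: int, base_lr: float, t_0: int, t_mult: int, eta_min: float = 0.0) -> float:
    """Cosine annealing with warm restarts, evaluated once per epoch (assumed semantics)."""
    t_cur, t_i = cycle_position(epoch, t_0, t_mult)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_i))


def restart_epochs(num_epochs: int, t_0: int, t_mult: int) -> list[int]:
    """0-based epochs (after epoch 0) at which the learning rate resets to base_lr."""
    return [e for e in range(1, num_epochs) if cycle_position(e, t_0, t_mult)[0] == 0]


if __name__ == "__main__":
    # Values taken from main(): learning_rate=0.001, epochs=100, T_0=10, T_mult=2.
    print(restart_epochs(100, t_0=10, t_mult=2))  # → [10, 30, 70]
    print(f"LR in the last epoch: {sgdr_lr(99, 0.001, t_0=10, t_mult=2):.6f}")
```

If the run is meant to finish at the bottom of a cosine cycle, an epoch count that lands on a restart boundary (for example 70 or 150 with these settings) avoids ending mid-cycle.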
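The optimizer line pairs `lr=0.001` with `weight_decay=0.01`. Assuming neuralforge's `AdamW` uses decoupled weight decay the way `torch.optim.AdamW` does, each step shrinks every weight by a factor of `1 - lr * weight_decay` (here `1 - 1e-5`) and then applies the ordinary Adam update, instead of adding an L2 term to the gradient. A scalar sketch in plain Python follows; `AdamWState` and `adamw_step` are hypothetical names, and PyTorch's default `betas=(0.9, 0.999)` and `eps=1e-8` are assumed.

```python
import math
from dataclasses import dataclass


@dataclass
class AdamWState:
    m: float = 0.0  # running mean of gradients (first moment)
    v: float = 0.0  # running mean of squared gradients (second moment)
    t: int = 0      # number of steps taken so far


def adamw_step(param: float, grad: float, state: AdamWState, lr: float = 0.001,
               weight_decay: float = 0.01, beta1: float = 0.9, beta2: float = 0.999,
               eps: float = 1e-8) -> float:
    """One AdamW update for a single scalar parameter; returns the updated value."""
    state.t += 1
    # Decoupled weight decay: shrink the weight directly, independent of the gradient.
    param *= 1 - lr * weight_decay
    # Standard Adam moment estimates with bias correction.
    state.m = beta1 * state.m + (1 - beta1) * grad
    state.v = beta2 * state.v + (1 - beta2) * grad * grad
    m_hat = state.m / (1 - beta1 ** state.t)
    v_hat = state.v / (1 - beta2 ** state.t)
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)


if __name__ == "__main__":
    state = AdamWState()
    w = 1.0
    for _ in range(3):
        # With a zero gradient the Adam term vanishes and only the decay acts.
        w = adamw_step(w, grad=0.0, state=state)
    print(w)  # approximately (1 - 1e-5) ** 3
```

Because the decay is applied to the weight rather than folded into `grad`, its strength does not get rescaled by Adam's per-parameter `sqrt(v_hat)` normalization.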

Callers 1

train_custom.py (file) · 0.70

Calls 10

build_train_loader (method) · 0.95
build_val_loader (method) · 0.95
train (method) · 0.95
Config (class) · 0.90
SyntheticDataset (class) · 0.90
DataLoaderBuilder (class) · 0.90
ResNet18 (function) · 0.90
AdamW (class) · 0.90
Trainer (class) · 0.90

Tested by

no test coverage detected