Function main()

github.com/geekcomputers/Python / main · ML/train.py:48–193


def main():
    # Command-line interface for the training run.
    parser = argparse.ArgumentParser(description='NeuralForge Training')
    parser.add_argument('--config', type=str, default=None, help='Path to config file')
    parser.add_argument('--model', type=str, default='simple', choices=['simple', 'resnet18', 'efficientnet', 'vit'])
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--num-samples', type=int, default=5000, help='Number of synthetic samples')
    parser.add_argument('--num-classes', type=int, default=10)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--dataset', type=str, default='synthetic',
                        choices=['synthetic', 'cifar10', 'cifar100', 'mnist', 'fashion_mnist', 'stl10',
                                 'tiny_imagenet', 'imagenet', 'food101', 'caltech256', 'oxford_pets'],
                        help='Dataset to use')
    args = parser.parse_args()

    if args.config:
        # Load all settings from the given config file.
        config = Config.load(args.config)
    else:
        # No config file: start from defaults and apply the CLI values.
        config = Config()
        config.batch_size = args.batch_size
        config.epochs = args.epochs
        config.learning_rate = args.lr
        config.device = args.device
        config.num_classes = args.num_classes
        config.seed = args.seed

    # Seed the random number generators for reproducible runs.
    set_seed(config.seed)

    logger = Logger(config.log_dir, "training")
    logger.info("=" * 80)
    logger.info("NeuralForge Training Framework")
    logger.info("=" * 80)
    logger.info(f"Configuration:\n{config}")

    if args.dataset == 'synthetic':
        logger.info("Creating synthetic dataset...")
        train_dataset = SyntheticDataset(
            num_samples=args.num_samples,
            num_classes=config.num_classes,
            image_size=config.image_size,
            channels=3
        )

        # The validation set is one fifth the size of the training set.
        val_dataset = SyntheticDataset(
            num_samples=args.num_samples // 5,
            num_classes=config.num_classes,
            image_size=config.image_size,
            channels=3
        )
    else:
        logger.info(f"Downloading and loading {args.dataset} dataset...")
        # A named dataset determines the class count, overriding the configured value.
        config.num_classes = get_num_classes(args.dataset)

        train_dataset = get_dataset(args.dataset, root=config.data_path, train=True, download=True)
        val_dataset = get_dataset(args.dataset, root=config.data_path, train=False, download=True)
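The argument handling at the top of `main` parses the flags, then either loads a config file or copies the flag values onto a default config. The sketch below reproduces that precedence in a runnable form, assuming (as in the listing above) that the `config.* = args.*` assignments run only when `--config` is absent. `SketchConfig`, `load_config`, and `build_config` are hypothetical stand-ins; the project's real `Config` class and `Config.load` are not shown in this excerpt, and the fake loader simply returns fixed values instead of reading a file.

```python
import argparse
from dataclasses import dataclass


@dataclass
class SketchConfig:
    # Hypothetical stand-in for the project's Config class; only the fields
    # main() overrides are modeled, with the same defaults as the CLI flags.
    batch_size: int = 32
    epochs: int = 50
    learning_rate: float = 0.001
    num_classes: int = 10
    seed: int = 42


def load_config(path: str) -> SketchConfig:
    # Hypothetical loader standing in for Config.load(); it ignores `path`
    # and returns fixed values so the precedence is easy to observe.
    return SketchConfig(batch_size=128, epochs=5)


def build_config(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default=None)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--num-classes', type=int, default=10)
    parser.add_argument('--seed', type=int, default=42)
    # Passing an explicit list instead of reading sys.argv makes this testable.
    args = parser.parse_args(argv)

    if args.config:
        # The config file wins; hyperparameter flags are not applied.
        return load_config(args.config)

    config = SketchConfig()
    # argparse stores '--batch-size' under the attribute name 'batch_size'.
    config.batch_size = args.batch_size
    config.epochs = args.epochs
    config.learning_rate = args.lr
    config.num_classes = args.num_classes
    config.seed = args.seed
    return config
```

With this structure, `build_config(['--batch-size', '64'])` yields a batch size of 64, while `build_config(['--config', 'run.yaml', '--batch-size', '64'])` yields the loader's 128, because the flag is never copied when a config file is supplied.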
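`main` calls `set_seed(config.seed)` before building any datasets, but the helper's body is outside this excerpt. A common approach is to seed every random number generator the run touches. The function below is a hypothetical sketch (`set_seed_sketch` is not the project's `set_seed`) that treats NumPy and PyTorch as optional; `random.seed`, `numpy.random.seed`, `torch.manual_seed`, and `torch.cuda.manual_seed_all` are the standard calls for this.

```python
import random


def set_seed_sketch(seed: int) -> None:
    # Hypothetical seeding helper: always seeds Python's `random` module,
    # and also NumPy and PyTorch when they are installed.
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)
        # Safe to call without a GPU; PyTorch ignores it when CUDA is unavailable.
        torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass
```

Calling `set_seed_sketch(42)` twice and drawing `random.random()` after each call returns the same value both times, which is the property the training run relies on.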

Callers (1)

train.py · file · 0.70

Calls (15)

info · method · 0.95
build_train_loader · method · 0.95
build_val_loader · method · 0.95
log_model_summary · method · 0.95
train · method · 0.95
save · method · 0.95
Config · class · 0.90
Logger · class · 0.90
SyntheticDataset · class · 0.90
get_num_classes · function · 0.90
get_dataset · function · 0.90
DataLoaderBuilder · class · 0.90
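Among these callees, `SyntheticDataset`, `get_num_classes`, and `get_dataset` drive the dataset branch of `main`: the synthetic path builds a training set of `--num-samples` items and a validation set one fifth that size, while a named dataset overwrites `num_classes` with the dataset's own label count. The sketch below mirrors that control flow in plain Python without downloading anything. `SyntheticDatasetSketch`, `make_synthetic_split`, `resolve_num_classes`, and the `NUM_CLASSES` table are hypothetical; the table is deliberately partial and lists the standard label counts of those public datasets, not values read from the project's `get_num_classes`. Giving the two splits different seeds is also an assumption, since the excerpt does not show how `SyntheticDataset` generates its data.

```python
import random
from typing import List, Tuple

# Hypothetical, partial table of standard label counts for some named datasets.
NUM_CLASSES = {
    'cifar10': 10,
    'cifar100': 100,
    'mnist': 10,
    'fashion_mnist': 10,
    'food101': 101,
}


class SyntheticDatasetSketch:
    """Hypothetical stand-in for SyntheticDataset: random images, random labels."""

    def __init__(self, num_samples: int, num_classes: int, image_size: int,
                 channels: int = 3, seed: int = 0):
        self.num_samples = num_samples
        self.num_classes = num_classes
        self.image_size = image_size
        self.channels = channels
        self.seed = seed

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> Tuple[List[List[List[float]]], int]:
        if not 0 <= idx < self.num_samples:
            raise IndexError(idx)
        # Derive a per-item RNG so the same index always yields the same sample.
        rng = random.Random(self.seed * 1_000_003 + idx)
        image = [[[rng.random() for _ in range(self.image_size)]
                  for _ in range(self.image_size)]
                 for _ in range(self.channels)]
        label = rng.randrange(self.num_classes)
        return image, label


def make_synthetic_split(num_samples: int, num_classes: int, image_size: int):
    # Mirrors main(): the validation set is num_samples // 5 items.
    train = SyntheticDatasetSketch(num_samples, num_classes, image_size, seed=0)
    val = SyntheticDatasetSketch(num_samples // 5, num_classes, image_size, seed=1)
    return train, val


def resolve_num_classes(dataset: str, requested: int) -> int:
    # Synthetic data keeps the requested count; a named dataset overrides it.
    if dataset == 'synthetic':
        return requested
    return NUM_CLASSES[dataset]
```

With the CLI defaults (`--num-samples 5000`), `make_synthetic_split` produces 5000 training and 1000 validation items, and `resolve_num_classes('cifar100', 10)` returns 100 regardless of the requested value.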

Tested by

no test coverage detected