hub / github.com/HKUDS/AI-Researcher / main

Function main

examples/con_flowmatching/project/run_training_testing_v2.py:81–168
main()

Source from the content-addressed store, hash-verified

    return metrics_history

def main():
    # Configuration
    data_dir = os.path.join('data', 'cifar-10-batches-py')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create necessary directories
    os.makedirs('experiments/logs', exist_ok=True)
    os.makedirs('experiments/checkpoints', exist_ok=True)
    os.makedirs('experiments/results', exist_ok=True)

    # Training configurations
    configs = {
        'baseline': {
            'model_type': 'simple',
            'hidden_dims': [512, 512, 512],
            'activation': 'relu',
            'learning_rate': 2e-4,
            'alpha': 0.1,
            'epochs': 100,
            'eval_interval': 5,
            'batch_size': 512
        },
        'improved': {
            'model_type': 'resnet',
            'hidden_dims': [128, 256, 512, 256, 128],
            'activation': 'relu',
            'learning_rate': 2e-4,
            'alpha': 0.1,
            'epochs': 100,
            'eval_interval': 5,
            'batch_size': 512
        }
    }

    # Set up data loaders
    logger.info("Setting up data loaders...")
    for config_name, config in configs.items():
        train_loader, test_loader = get_data_loaders(
            data_dir,
            batch_size=config['batch_size'],
            num_workers=4
        )

        logger.info(f"\nStarting experiments for {config_name} configuration")

        try:
            # Initialize model
            if config['model_type'] == 'simple':
                velocity_net = VelocityNetwork(
                    hidden_dims=config['hidden_dims'],
                    activation=config['activation']
                )
                model = CNF(velocity_net)
            else:
                velocity_net = ResNetVelocity(
                    hidden_dims=config['hidden_dims'],
                    activation=config['activation']
                )
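
The listing above selects a velocity network from `config['model_type']` with an if/else, so any value other than `'simple'` falls through to the ResNet branch. As a minimal sketch of the same config-driven dispatch, the block below uses a lookup table that rejects unknown types instead. `StubVelocity`, `StubFlow`, `VELOCITY_FACTORIES`, and `build_model` are hypothetical names introduced for illustration; they stand in for the repository's `VelocityNetwork`, `ResNetVelocity`, and `CNF` classes, whose constructors are only known from the calls shown in the excerpt. The wrapper used in the ResNet branch is not visible in the excerpt, so the sketch reuses one stand-in wrapper for both.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class StubVelocity:
    """Hypothetical stand-in for VelocityNetwork / ResNetVelocity."""
    kind: str
    hidden_dims: List[int]
    activation: str


@dataclass
class StubFlow:
    """Hypothetical stand-in for the flow model that wraps a velocity network."""
    velocity_net: StubVelocity


# Assumed registry: model_type -> factory for the velocity network.
VELOCITY_FACTORIES: Dict[str, Callable[..., StubVelocity]] = {
    "simple": lambda **kw: StubVelocity("simple", **kw),
    "resnet": lambda **kw: StubVelocity("resnet", **kw),
}


def build_model(config: dict) -> StubFlow:
    """Build a model from one entry of a configs dict like the one in main()."""
    model_type = config["model_type"]
    if model_type not in VELOCITY_FACTORIES:
        # Unlike the if/else in the listing, an unknown type is an error here.
        raise ValueError(f"unknown model_type: {model_type!r}")
    velocity_net = VELOCITY_FACTORIES[model_type](
        hidden_dims=config["hidden_dims"],
        activation=config["activation"],
    )
    return StubFlow(velocity_net)


if __name__ == "__main__":
    demo = build_model(
        {"model_type": "simple", "hidden_dims": [512, 512, 512], "activation": "relu"}
    )
    print(demo.velocity_net.kind, demo.velocity_net.hidden_dims)
```

Failing on an unrecognised `model_type` is a design choice of this sketch, not behaviour of the original function; it trades the listing's permissive fallback for earlier detection of a misspelled configuration value.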

Callers 1

Calls 7

get_data_loaders · Function · 0.90
VelocityNetwork · Class · 0.90
CNF · Class · 0.90
ResNetVelocity · Class · 0.90
ImprovedCNF · Class · 0.90
info · Method · 0.80
train_and_evaluate · Function · 0.70

Tested by

no test coverage detected
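
Since no tests are detected for `main`, one low-cost starting point might be a check on the configuration dictionary alone, which needs neither the dataset nor a GPU. The sketch below assumes the `configs` dict were moved from inside `main()` to module scope (in the source it is a local variable); `REQUIRED_KEYS` and `validate_config` are hypothetical names, and the expected types are inferred from the values shown in the listing.

```python
# Copy of the configs dict from the listing; assumed here to live at module scope.
CONFIGS = {
    "baseline": {
        "model_type": "simple",
        "hidden_dims": [512, 512, 512],
        "activation": "relu",
        "learning_rate": 2e-4,
        "alpha": 0.1,
        "epochs": 100,
        "eval_interval": 5,
        "batch_size": 512,
    },
    "improved": {
        "model_type": "resnet",
        "hidden_dims": [128, 256, 512, 256, 128],
        "activation": "relu",
        "learning_rate": 2e-4,
        "alpha": 0.1,
        "epochs": 100,
        "eval_interval": 5,
        "batch_size": 512,
    },
}

# Hypothetical schema: key -> expected Python type, inferred from the listing.
REQUIRED_KEYS = {
    "model_type": str,
    "hidden_dims": list,
    "activation": str,
    "learning_rate": float,
    "alpha": float,
    "epochs": int,
    "eval_interval": int,
    "batch_size": int,
}


def validate_config(name, config):
    """Return a list of human-readable problems; an empty list means valid."""
    problems = []
    for key, expected in REQUIRED_KEYS.items():
        if key not in config:
            problems.append(f"{name}: missing key {key!r}")
        elif not isinstance(config[key], expected):
            problems.append(f"{name}: {key!r} should be {expected.__name__}")
    # Only compare the two values once both are known to be present and typed.
    if not problems and config["eval_interval"] > config["epochs"]:
        problems.append(f"{name}: eval_interval exceeds epochs")
    return problems


def test_configs_are_well_formed():
    for name, config in CONFIGS.items():
        assert validate_config(name, config) == []
```

The test function follows the pytest naming convention, so it would be collected automatically if the file were placed under a test directory.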