Function main(args)

Repository: github.com/microsoft/Cream (main)
Source: EfficientViT/classification/main.py, lines 194–442


```python
def main(args):
    utils.init_distributed_mode(args)

    if args.distillation_type != 'none' and args.finetune and not args.eval:
        raise NotImplementedError(
            "Finetuning with distillation not yet supported")

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    if True:  # args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.repeated_aug:
            sampler_train = RASampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                      'This will slightly alter validation results as extra duplicate entries are added to achieve '
                      'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    if args.ThreeAugment:
        data_loader_train.dataset.transform = new_data_aug_generator(args)

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=int(1.5 * args.batch_size),
        # ... (listing ends here; main continues through line 442)
```
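The warning printed when args.dist_eval is set exists because DistributedSampler pads the index list so that every process receives the same number of samples. The pure-Python sketch below models that padding for the unshuffled, drop_last=False case used for sampler_val. The function name distributed_indices is hypothetical, and this is a simplified model of PyTorch's sampler behavior, not its source code.

```python
import math
from collections import Counter
from typing import List


def distributed_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Simplified model of DistributedSampler with shuffle=False, drop_last=False."""
    indices = list(range(dataset_len))
    # Every rank must receive the same number of samples.
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    padding = total_size - dataset_len
    if padding > 0:
        # Pad by wrapping around to the start of the index list (these are the duplicates).
        repeats = math.ceil(padding / dataset_len)
        indices += (indices * repeats)[:padding]
    # Rank r takes every num_replicas-th index starting at position r.
    return indices[rank:total_size:num_replicas]


if __name__ == "__main__":
    shards = [distributed_indices(10, 4, r) for r in range(4)]
    print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]
    counts = Counter(i for shard in shards for i in shard)
    print({i: n for i, n in counts.items() if n > 1})  # {0: 2, 1: 2}
```

With 10 validation images on 4 processes, images 0 and 1 are evaluated twice, so metrics gathered across ranks are computed over 12 predictions rather than 10. When the dataset size is divisible by the number of processes, no padding is added and no image is duplicated.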

Callers (1)

main.py (file)

Calls (15)

build_dataset (function)
RASampler (class)
new_data_aug_generator (function)
Mixup (class)
ModelEma (class)
DistillationLoss (class)
evaluate (function)
train_one_epoch (function)
create_model (function)
get_state_dict (function)
to (method)
format (method)
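RASampler is the training sampler main uses when args.repeated_aug is set. Its source is not part of this listing, so the following is a hedged sketch of the repeated-augmentation idea, assuming a DeiT-style scheme: each index is repeated (assumed here to be three times), and the repeated list is strided across processes so that copies of one image land on different ranks, where each copy receives its own random augmentation. The function repeated_aug_indices and its parameters are hypothetical, and the sketch omits any epoch-length truncation the real sampler may apply.

```python
import math
import random
from typing import List


def repeated_aug_indices(dataset_len: int, num_replicas: int, rank: int,
                         epoch: int, num_repeats: int = 3) -> List[int]:
    """Hypothetical sketch of repeated-augmentation (RA) index sampling."""
    indices = list(range(dataset_len))
    # Seed with the epoch so every rank computes the same permutation.
    random.Random(epoch).shuffle(indices)
    # Repeat each index num_repeats times, keeping the copies adjacent.
    indices = [i for i in indices for _ in range(num_repeats)]
    num_samples = math.ceil(len(indices) / num_replicas)
    total_size = num_samples * num_replicas
    # Pad (by wrapping) so every rank gets exactly num_samples indices.
    indices += indices[: total_size - len(indices)]
    # Striding sends adjacent copies of an image to different ranks.
    return indices[rank:total_size:num_replicas]


if __name__ == "__main__":
    for r in range(3):
        print(r, repeated_aug_indices(dataset_len=4, num_replicas=3, rank=r, epoch=0))
```

With three processes and three repeats, every rank in this sketch iterates over the same images in the same order; what differs between the three copies of an image is the augmentation applied on each process.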

Tested by

No test coverage detected.