github.com/hustvl/Vim

Function main(args)

vim/main.py:222–544

```python
# Relies on module-level names defined elsewhere in vim/main.py:
# torch, np, cudnn, mlflow, utils, build_dataset, RASampler.
def main(args):
    utils.init_distributed_mode(args)

    print(args)

    if args.distillation_type != 'none' and args.finetune and not args.eval:
        raise NotImplementedError("Finetuning with distillation not yet supported")

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    # log the run name and all args to MLflow (only when local_rank == 0 and gpu == 0)
    run_name = args.output_dir.split("/")[-1]
    if args.local_rank == 0 and args.gpu == 0:
        mlflow.start_run(run_name=run_name)
        for key, value in vars(args).items():
            mlflow.log_param(key, value)

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.repeated_aug:
            sampler_train = RASampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                      'This will slightly alter validation results as extra duplicate entries are added to achieve '
                      'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    # ... remainder of main() (through line 544) not shown in this excerpt
```
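`main` gives each process its own seed, `args.seed + utils.get_rank()`, so processes do not share a single random stream, while a rerun with the same `args.seed` and the same number of processes reproduces every rank's stream. The minimal sketch below illustrates that scheme with Python's `random.Random` standing in for the `torch` and NumPy generators; `rank_seed`, `first_draws`, and the world size of 4 are hypothetical, not part of the Vim code.

```python
import random
from typing import List


def rank_seed(base_seed: int, rank: int) -> int:
    # Same scheme as main(): offset the shared base seed by the process rank.
    return base_seed + rank


def first_draws(base_seed: int, rank: int, n: int = 3) -> List[float]:
    # Hypothetical helper: a private generator per rank, so the demo
    # does not touch the global random state.
    rng = random.Random(rank_seed(base_seed, rank))
    return [rng.random() for _ in range(n)]


if __name__ == "__main__":
    world_size = 4  # assumed number of processes for the demo
    for rank in range(world_size):
        print(rank, rank_seed(0, rank), first_draws(0, rank))
```

Two caveats about the seeding in `main`: because `random.seed(seed)` is commented out, Python's own `random` module is not seeded there, and `cudnn.benchmark = True` lets cuDNN pick kernels by timing, which can keep GPU results from being bit-exact even with fixed seeds.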

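With `args.dist_eval`, each process evaluates only its shard of the validation set. `DistributedSampler` (default `drop_last=False`) rounds the per-process count up and pads the index list with samples from the start of the dataset, which is why `main` warns when `len(dataset_val)` is not divisible by the number of processes. The sketch below re-implements that padding and striding in plain Python for illustration only; `shard_indices` is a hypothetical helper, not the PyTorch source.

```python
import math
from collections import Counter
from typing import List


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Indices one rank would evaluate, mimicking the padding logic of
    DistributedSampler(shuffle=False, drop_last=False) for illustration."""
    num_samples = math.ceil(dataset_len / num_replicas)  # per-rank count, rounded up
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    padding = total_size - dataset_len
    # Pad by wrapping around to the start of the dataset (these are duplicates).
    indices += (indices * math.ceil(padding / dataset_len))[:padding]
    # Strided split: rank r takes positions r, r + num_replicas, ...
    return indices[rank:total_size:num_replicas]


if __name__ == "__main__":
    shards = [shard_indices(10, 4, r) for r in range(4)]
    print(shards)
    counts = Counter(i for shard in shards for i in shard)
    print(sorted(i for i, c in counts.items() if c > 1))
```

With 10 validation images on 4 processes, ranks 2 and 3 evaluate images 0 and 1 a second time; those duplicates are the slight change to the validation results that the warning describes.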
Callers 1

main.py · File · 0.70

Calls 15

build_dataset · Function · 0.90
RASampler · Class · 0.90
new_data_aug_generator · Function · 0.90
DistillationLoss · Class · 0.90
evaluate · Function · 0.90
train_one_epoch · Function · 0.90
print · Function · 0.85
get_state_dict · Function · 0.85
set_epoch · Method · 0.80
device · Method · 0.45
load · Method · 0.45
state_dict · Method · 0.45
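When `args.repeated_aug` is set, `main` builds the training sampler from `RASampler` instead of `DistributedSampler`. Repeated augmentation puts several copies of each image into the same epoch so that different processes see the same image under different random augmentations. The sketch below illustrates that idea in plain Python and is not the `RASampler` source: `repeated_aug_indices`, `num_repeats=3`, and the trimming rule are assumptions, so check the `RASampler` definition in the repository for its exact behavior.

```python
import math
import random
from typing import List


def repeated_aug_indices(dataset_len: int, num_replicas: int, rank: int,
                         epoch: int, num_repeats: int = 3) -> List[int]:
    """Hypothetical sketch of repeated-augmentation sampling for one rank.
    Assumes the dataset is much larger than the number of ranks."""
    # All ranks shuffle with the same per-epoch seed, so they agree on the order.
    order = list(range(dataset_len))
    random.Random(epoch).shuffle(order)
    # Repeat each index consecutively: [a, a, a, b, b, b, ...].
    repeated = [i for i in order for _ in range(num_repeats)]
    # Pad (by wrapping around) so the list splits evenly across ranks.
    num_samples = math.ceil(len(repeated) / num_replicas)
    total_size = num_samples * num_replicas
    repeated += repeated[: total_size - len(repeated)]
    # Strided split: consecutive copies of an index land on different ranks.
    shard = repeated[rank:total_size:num_replicas]
    # Trim so one epoch per rank stays about as long as without repetition.
    return shard[: math.ceil(dataset_len / num_replicas)]


if __name__ == "__main__":
    for rank in range(3):
        print(rank, repeated_aug_indices(8, num_replicas=3, rank=rank, epoch=0))
```

With three processes and three repeats, every rank receives the same indices in the same order, so each image is processed under three independent random augmentations in the same step. The `epoch` argument plays the role that the `set_epoch` call listed above presumably plays for the real sampler: a new shuffle each epoch that all ranks agree on.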

Tested by

no test coverage detected