MCPcopy Index your code
hub / github.com/FoundationVision/ByteTrack / main

Function main

tutorials/transtrack/main_track.py:150–367  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

148
149
150def main(args):
151 utils.init_distributed_mode(args)
152 print("git:\n {}\n".format(utils.get_sha()))
153
154 if args.frozen_weights is not None:
155 assert args.masks, "Frozen training is meant for segmentation only"
156 print(args)
157
158 device = torch.device(args.device)
159
160 # fix the seed for reproducibility
161 seed = args.seed + utils.get_rank()
162 torch.manual_seed(seed)
163 np.random.seed(seed)
164 random.seed(seed)
165
166 if args.det_val:
167 assert args.eval, 'only support eval mode of detector for track'
168 model, criterion, postprocessors = build_model(args)
169 elif args.eval:
170 model, criterion, postprocessors = build_tracktest_model(args)
171 else:
172 model, criterion, postprocessors = build_tracktrain_model(args)
173
174 model.to(device)
175
176 model_without_ddp = model
177 n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
178 print('number of params:', n_parameters)
179
180 dataset_train = build_dataset(image_set=args.track_train_split, args=args)
181 dataset_val = build_dataset(image_set=args.track_eval_split, args=args)
182
183 if args.distributed:
184 if args.cache_mode:
185 sampler_train = samplers.NodeDistributedSampler(dataset_train)
186 sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False)
187 else:
188 sampler_train = samplers.DistributedSampler(dataset_train)
189 sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False)
190 else:
191 sampler_train = torch.utils.data.RandomSampler(dataset_train)
192 sampler_val = torch.utils.data.SequentialSampler(dataset_val)
193
194 batch_sampler_train = torch.utils.data.BatchSampler(
195 sampler_train, args.batch_size, drop_last=True)
196
197 data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
198 collate_fn=utils.collate_fn, num_workers=args.num_workers,
199 pin_memory=True)
200 data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
201 drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers,
202 pin_memory=True)
203
204 # lr_backbone_names = ["backbone.0", "backbone.neck", "input_proj", "transformer.encoder"]
205 def match_name_keywords(n, name_keywords):
206 out = False
207 for b in name_keywords:

Callers 1

main_track.py — File · 0.70

Calls 10

DataLoader — Class · 0.90
BYTETracker — Class · 0.90
evaluate_track — Function · 0.90
train_one_epoch — Function · 0.90
evaluate — Function · 0.90
match_name_keywords — Function · 0.85
save_track — Function · 0.85
set_epoch — Method · 0.80
write — Method · 0.80
step — Method · 0.45

Tested by

no test coverage detected