MCPcopy
hub / github.com/fundamentalvision/Deformable-DETR / main

Function main

main.py:129–318  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

127
128
129def main(args):
130 utils.init_distributed_mode(args)
131 print("git:\n {}\n".format(utils.get_sha()))
132
133 if args.frozen_weights is not None:
134 assert args.masks, "Frozen training is meant for segmentation only"
135 print(args)
136
137 device = torch.device(args.device)
138
139 # fix the seed for reproducibility
140 seed = args.seed + utils.get_rank()
141 torch.manual_seed(seed)
142 np.random.seed(seed)
143 random.seed(seed)
144
145 model, criterion, postprocessors = build_model(args)
146 model.to(device)
147
148 model_without_ddp = model
149 n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
150 print('number of params:', n_parameters)
151
152 dataset_train = build_dataset(image_set='train', args=args)
153 dataset_val = build_dataset(image_set='val', args=args)
154
155 if args.distributed:
156 if args.cache_mode:
157 sampler_train = samplers.NodeDistributedSampler(dataset_train)
158 sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False)
159 else:
160 sampler_train = samplers.DistributedSampler(dataset_train)
161 sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False)
162 else:
163 sampler_train = torch.utils.data.RandomSampler(dataset_train)
164 sampler_val = torch.utils.data.SequentialSampler(dataset_val)
165
166 batch_sampler_train = torch.utils.data.BatchSampler(
167 sampler_train, args.batch_size, drop_last=True)
168
169 data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
170 collate_fn=utils.collate_fn, num_workers=args.num_workers,
171 pin_memory=True)
172 data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
173 drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers,
174 pin_memory=True)
175
176 # lr_backbone_names = ["backbone.0", "backbone.neck", "input_proj", "transformer.encoder"]
def match_name_keywords(n, name_keywords):
    """Return True if any keyword in *name_keywords* occurs as a substring of *n*.

    Args:
        n: The string to test. Presumably a parameter name from
            ``named_parameters()``; the call site is outside this view, so confirm.
        name_keywords: An iterable of substrings to look for in ``n``.

    Returns:
        ``True`` as soon as one keyword is found in ``n``. ``False`` if none
        match, including when ``name_keywords`` is empty.
    """
    # any() short-circuits on the first hit, like the original flag-and-break loop.
    return any(keyword in n for keyword in name_keywords)
184
185 for n, p in model_without_ddp.named_parameters():
186 print(n)

Callers 1

main.py (File) · 0.70

Calls 9

build_model (Function) · 0.90
build_dataset (Function) · 0.90
evaluate (Function) · 0.90
train_one_epoch (Function) · 0.90
print (Function) · 0.85
match_name_keywords (Function) · 0.85
to (Method) · 0.80
set_epoch (Method) · 0.45

Tested by

no test coverage detected