(args)
| 127 | |
| 128 | |
| 129 | def main(args): |
| 130 | utils.init_distributed_mode(args) |
| 131 | print("git:\n {}\n".format(utils.get_sha())) |
| 132 | |
| 133 | if args.frozen_weights is not None: |
| 134 | assert args.masks, "Frozen training is meant for segmentation only" |
| 135 | print(args) |
| 136 | |
| 137 | device = torch.device(args.device) |
| 138 | |
| 139 | # fix the seed for reproducibility |
| 140 | seed = args.seed + utils.get_rank() |
| 141 | torch.manual_seed(seed) |
| 142 | np.random.seed(seed) |
| 143 | random.seed(seed) |
| 144 | |
| 145 | model, criterion, postprocessors = build_model(args) |
| 146 | model.to(device) |
| 147 | |
| 148 | model_without_ddp = model |
| 149 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| 150 | print('number of params:', n_parameters) |
| 151 | |
| 152 | dataset_train = build_dataset(image_set='train', args=args) |
| 153 | dataset_val = build_dataset(image_set='val', args=args) |
| 154 | |
| 155 | if args.distributed: |
| 156 | if args.cache_mode: |
| 157 | sampler_train = samplers.NodeDistributedSampler(dataset_train) |
| 158 | sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False) |
| 159 | else: |
| 160 | sampler_train = samplers.DistributedSampler(dataset_train) |
| 161 | sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False) |
| 162 | else: |
| 163 | sampler_train = torch.utils.data.RandomSampler(dataset_train) |
| 164 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) |
| 165 | |
| 166 | batch_sampler_train = torch.utils.data.BatchSampler( |
| 167 | sampler_train, args.batch_size, drop_last=True) |
| 168 | |
| 169 | data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, |
| 170 | collate_fn=utils.collate_fn, num_workers=args.num_workers, |
| 171 | pin_memory=True) |
| 172 | data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val, |
| 173 | drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers, |
| 174 | pin_memory=True) |
| 175 | |
| 176 | # lr_backbone_names = ["backbone.0", "backbone.neck", "input_proj", "transformer.encoder"] |
def match_name_keywords(n, name_keywords):
    """Return True if any keyword in *name_keywords* occurs as a substring of *n*.

    Matching is plain substring containment (``keyword in n``), not an exact
    or prefix match. For example, the keyword ``"backbone.0"`` would also
    match ``"backbone.01"``.

    Args:
        n: The string to search. Presumably a parameter name yielded by
            ``named_parameters()``; confirm at the call site.
        name_keywords: Iterable of substrings to look for.

    Returns:
        ``True`` on the first keyword found in *n*, otherwise ``False``
        (including when *name_keywords* is empty).
    """
    # any() short-circuits on the first hit, like the original loop's break.
    return any(keyword in n for keyword in name_keywords)
| 184 | |
| 185 | for n, p in model_without_ddp.named_parameters(): |
| 186 | print(n) |
no test coverage detected