MCPcopy
hub / github.com/microsoft/Cream / main

Function main

CDARTS/CDARTS_detection/test.py:130–201  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

128
129
130def main():
131 args = parse_args()
132
133 if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
134 raise ValueError('The output file must be a pkl file.')
135
136 cfg = mmcv.Config.fromfile(args.config)
137 # set cudnn_benchmark
138 if cfg.get('cudnn_benchmark', False):
139 torch.backends.cudnn.benchmark = True
140 cfg.model.pretrained = None
141 cfg.data.test.test_mode = True
142
143 # init distributed env first, since logger depends on the dist info.
144 if args.launcher == 'none':
145 distributed = False
146 else:
147 distributed = True
148 init_dist(args.launcher, **cfg.dist_params)
149
150 # build the dataloader
151 # TODO: support multiple images per gpu (only minor changes are needed)
152 dataset = build_dataset(cfg.data.test)
153 data_loader = build_dataloader(
154 dataset,
155 imgs_per_gpu=1,
156 workers_per_gpu=cfg.data.workers_per_gpu,
157 dist=distributed,
158 shuffle=False)
159
160 # build the model and load checkpoint
161 model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
162 fp16_cfg = cfg.get('fp16', None)
163 if fp16_cfg is not None:
164 wrap_fp16_model(model)
165 checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
166 # old versions did not save class info in checkpoint, this walkaround is
167 # for backward compatibility
168 if 'CLASSES' in checkpoint['meta']:
169 model.CLASSES = checkpoint['meta']['CLASSES']
170 else:
171 model.CLASSES = dataset.CLASSES
172
173 if not distributed:
174 model = MMDataParallel(model, device_ids=[0])
175 outputs = single_gpu_test(model, data_loader, args.show)
176 else:
177 model = MMDistributedDataParallel(model.cuda())
178 outputs = multi_gpu_test(model, data_loader, args.tmpdir)
179
180 rank, _ = get_dist_info()
181 if args.out and rank == 0:
182 print('\nwriting results to {}'.format(args.out))
183 mmcv.dump(outputs, args.out)
184 eval_types = args.eval
185 if eval_types:
186 print('Starting evaluate {}'.format(' and '.join(eval_types)))
187 if eval_types == ['proposal_fast']:

Callers 1

test.py File · 0.70

Calls 15

init_dist Function · 0.90
build_dataset Function · 0.90
build_dataloader Function · 0.90
build_detector Function · 0.90
wrap_fp16_model Function · 0.90
load_checkpoint Function · 0.90
MMDataParallel Class · 0.90
get_dist_info Function · 0.90
coco_eval Function · 0.90
results2json Function · 0.90
fromfile Method · 0.80

Tested by

no test coverage detected