Launch segmentor training.
train_segmentor(model, dataset, cfg, distributed=False, validate=False, timestamp=None, meta=None)
| 35 | |
| 36 | |
| 37 | def train_segmentor(model, |
| 38 | dataset, |
| 39 | cfg, |
| 40 | distributed=False, |
| 41 | validate=False, |
| 42 | timestamp=None, |
| 43 | meta=None): |
| 44 | """Launch segmentor training.""" |
| 45 | logger = get_root_logger(cfg.log_level) |
| 46 | |
| 47 | # prepare data loaders |
| 48 | dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] |
| 49 | data_loaders = [ |
| 50 | build_dataloader( |
| 51 | ds, |
| 52 | cfg.data.samples_per_gpu, |
| 53 | cfg.data.workers_per_gpu, |
| 54 | # cfg.gpus will be ignored if distributed |
| 55 | len(cfg.gpu_ids), |
| 56 | dist=distributed, |
| 57 | seed=cfg.seed, |
| 58 | drop_last=True) for ds in dataset |
| 59 | ] |
| 60 | |
| 61 | # build optimizer |
| 62 | optimizer = build_optimizer(model, cfg.optimizer) |
| 63 | |
| 64 | # use apex fp16 optimizer |
| 65 | if cfg.optimizer_config.get("type", None) and cfg.optimizer_config["type"] == "DistOptimizerHook": |
| 66 | if cfg.optimizer_config.get("use_fp16", False): |
| 67 | model, optimizer = apex.amp.initialize( |
| 68 | model.cuda(), optimizer, opt_level="O1") |
| 69 | for m in model.modules(): |
| 70 | if hasattr(m, "fp16_enabled"): |
| 71 | m.fp16_enabled = True |
| 72 | |
| 73 | # put model on gpus |
| 74 | if distributed: |
| 75 | find_unused_parameters = cfg.get('find_unused_parameters', False) |
| 76 | # Sets the `find_unused_parameters` parameter in |
| 77 | # torch.nn.parallel.DistributedDataParallel |
| 78 | model = MMDistributedDataParallel( |
| 79 | model.cuda(), |
| 80 | device_ids=[torch.cuda.current_device()], |
| 81 | broadcast_buffers=False, |
| 82 | find_unused_parameters=find_unused_parameters) |
| 83 | else: |
| 84 | model = MMDataParallel( |
| 85 | model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) |
| 86 | |
| 87 | if cfg.get('runner') is None: |
| 88 | cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters} |
| 89 | warnings.warn( |
| 90 | 'config is now expected to have a `runner` section, ' |
| 91 | 'please set `runner` in your config.', UserWarning) |
| 92 | |
| 93 | runner = build_runner( |
| 94 | cfg.runner, |
no test coverage detected