MCPcopy
hub / github.com/hustvl/Vim / train_segmentor

Function train_segmentor

seg/mmcv_custom/train_api.py:37–129  ·  view source on GitHub ↗

Launch segmentor training.

train_segmentor(model, dataset, cfg, distributed=False, validate=False, timestamp=None, meta=None)

Source from the content-addressed store, hash-verified

35
36
37def train_segmentor(model,
38 dataset,
39 cfg,
40 distributed=False,
41 validate=False,
42 timestamp=None,
43 meta=None):
44 """Launch segmentor training."""
45 logger = get_root_logger(cfg.log_level)
46
47 # prepare data loaders
48 dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
49 data_loaders = [
50 build_dataloader(
51 ds,
52 cfg.data.samples_per_gpu,
53 cfg.data.workers_per_gpu,
54 # cfg.gpus will be ignored if distributed
55 len(cfg.gpu_ids),
56 dist=distributed,
57 seed=cfg.seed,
58 drop_last=True) for ds in dataset
59 ]
60
61 # build optimizer
62 optimizer = build_optimizer(model, cfg.optimizer)
63
64 # use apex fp16 optimizer
65 if cfg.optimizer_config.get("type", None) and cfg.optimizer_config["type"] == "DistOptimizerHook":
66 if cfg.optimizer_config.get("use_fp16", False):
67 model, optimizer = apex.amp.initialize(
68 model.cuda(), optimizer, opt_level="O1")
69 for m in model.modules():
70 if hasattr(m, "fp16_enabled"):
71 m.fp16_enabled = True
72
73 # put model on gpus
74 if distributed:
75 find_unused_parameters = cfg.get('find_unused_parameters', False)
76 # Sets the `find_unused_parameters` parameter in
77 # torch.nn.parallel.DistributedDataParallel
78 model = MMDistributedDataParallel(
79 model.cuda(),
80 device_ids=[torch.cuda.current_device()],
81 broadcast_buffers=False,
82 find_unused_parameters=find_unused_parameters)
83 else:
84 model = MMDataParallel(
85 model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
86
87 if cfg.get('runner') is None:
88 cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}
89 warnings.warn(
90 'config is now expected to have a `runner` section, '
91 'please set `runner` in your config.', UserWarning)
92
93 runner = build_runner(
94 cfg.runner,

Callers 1

main · Function · 0.90

Calls 5

build_dataset · Function · 0.90
build_optimizer · Function · 0.85
resume · Method · 0.80
get · Method · 0.45
run · Method · 0.45

Tested by

no test coverage detected