MCPcopy
hub / github.com/ultralytics/yolov5 / train

Function train

train.py:69–430  ·  view source on GitHub ↗
(hyp, opt, device, callbacks)

Source from the content-addressed store, hash-verified

67
68
69def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictionary
70 save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = \
71 Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
72 opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
73 callbacks.run('on_pretrain_routine_start')
74
75 # Directories
76 w = save_dir / 'weights' # weights dir
77 (w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir
78 last, best = w / 'last.pt', w / 'best.pt'
79
80 # Hyperparameters
81 if isinstance(hyp, str):
82 with open(hyp, errors='ignore') as f:
83 hyp = yaml.safe_load(f) # load hyps dict
84 LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
85 opt.hyp = hyp.copy() # for saving hyps to checkpoints
86
87 # Save run settings
88 if not evolve:
89 yaml_save(save_dir / 'hyp.yaml', hyp)
90 yaml_save(save_dir / 'opt.yaml', vars(opt))
91
92 # Loggers
93 data_dict = None
94 if RANK in {-1, 0}:
95 loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
96
97 # Register actions
98 for k in methods(loggers):
99 callbacks.register_action(k, callback=getattr(loggers, k))
100
101 # Process custom dataset artifact link
102 data_dict = loggers.remote_dataset
103 if resume: # If resuming runs from remote artifact
104 weights, epochs, hyp, batch_size = opt.weights, opt.epochs, opt.hyp, opt.batch_size
105
106 # Config
107 plots = not evolve and not opt.noplots # create plots
108 cuda = device.type != 'cpu'
109 init_seeds(opt.seed + 1 + RANK, deterministic=True)
110 with torch_distributed_zero_first(LOCAL_RANK):
111 data_dict = data_dict or check_dataset(data) # check if None
112 train_path, val_path = data_dict['train'], data_dict['val']
113 nc = 1 if single_cls else int(data_dict['nc']) # number of classes
114 names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
115 is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset
116
117 # Model
118 check_suffix(weights, '.pt') # check weights
119 pretrained = weights.endswith('.pt')
120 if pretrained:
121 with torch_distributed_zero_first(LOCAL_RANK):
122 weights = attempt_download(weights) # download if not found locally
123 ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
124 model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
125 exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
126 csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32

Callers 3

sweep (Function) · 0.90
run (Function) · 0.90
main (Function) · 0.70

Calls 15

on_params_update (Method) · 0.95
colorstr (Function) · 0.90
yaml_save (Function) · 0.90
Loggers (Class) · 0.90
methods (Function) · 0.90
init_seeds (Function) · 0.90
check_dataset (Function) · 0.90
check_suffix (Function) · 0.90
attempt_download (Function) · 0.90
intersect_dicts (Function) · 0.90
check_amp (Function) · 0.90

Tested by

no test coverage detected