(hyp, opt, device, callbacks)
| 67 | |
| 68 | |
| 69 | def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictionary |
| 70 | save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = \ |
| 71 | Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \ |
| 72 | opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze |
| 73 | callbacks.run('on_pretrain_routine_start') |
| 74 | |
| 75 | # Directories |
| 76 | w = save_dir / 'weights' # weights dir |
| 77 | (w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir |
| 78 | last, best = w / 'last.pt', w / 'best.pt' |
| 79 | |
| 80 | # Hyperparameters |
| 81 | if isinstance(hyp, str): |
| 82 | with open(hyp, errors='ignore') as f: |
| 83 | hyp = yaml.safe_load(f) # load hyps dict |
| 84 | LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items())) |
| 85 | opt.hyp = hyp.copy() # for saving hyps to checkpoints |
| 86 | |
| 87 | # Save run settings |
| 88 | if not evolve: |
| 89 | yaml_save(save_dir / 'hyp.yaml', hyp) |
| 90 | yaml_save(save_dir / 'opt.yaml', vars(opt)) |
| 91 | |
| 92 | # Loggers |
| 93 | data_dict = None |
| 94 | if RANK in {-1, 0}: |
| 95 | loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance |
| 96 | |
| 97 | # Register actions |
| 98 | for k in methods(loggers): |
| 99 | callbacks.register_action(k, callback=getattr(loggers, k)) |
| 100 | |
| 101 | # Process custom dataset artifact link |
| 102 | data_dict = loggers.remote_dataset |
| 103 | if resume: # If resuming runs from remote artifact |
| 104 | weights, epochs, hyp, batch_size = opt.weights, opt.epochs, opt.hyp, opt.batch_size |
| 105 | |
| 106 | # Config |
| 107 | plots = not evolve and not opt.noplots # create plots |
| 108 | cuda = device.type != 'cpu' |
| 109 | init_seeds(opt.seed + 1 + RANK, deterministic=True) |
| 110 | with torch_distributed_zero_first(LOCAL_RANK): |
| 111 | data_dict = data_dict or check_dataset(data) # check if None |
| 112 | train_path, val_path = data_dict['train'], data_dict['val'] |
| 113 | nc = 1 if single_cls else int(data_dict['nc']) # number of classes |
| 114 | names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names |
| 115 | is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset |
| 116 | |
| 117 | # Model |
| 118 | check_suffix(weights, '.pt') # check weights |
| 119 | pretrained = weights.endswith('.pt') |
| 120 | if pretrained: |
| 121 | with torch_distributed_zero_first(LOCAL_RANK): |
| 122 | weights = attempt_download(weights) # download if not found locally |
| 123 | ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak |
| 124 | model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create |
| 125 | exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys |
| 126 | csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 |
no test coverage detected