(model, optim_cfg)
| 9 | |
| 10 | |
def build_optimizer(model, optim_cfg):
    """Create the optimizer selected by ``optim_cfg.OPTIMIZER``.

    Args:
        model: Module whose ``parameters()`` (and, for ``'adam_onecycle'``,
            ``children()``) are handed to the optimizer.
        optim_cfg: Config object. ``OPTIMIZER`` selects the branch;
            ``LR`` and ``WEIGHT_DECAY`` are read for ``'adam'`` and ``'sgd'``,
            ``MOMENTUM`` additionally for ``'sgd'``, and only
            ``WEIGHT_DECAY`` for ``'adam_onecycle'``.

    Returns:
        An ``optim.Adam`` for ``'adam'``, an ``optim.SGD`` for ``'sgd'``, or
        the object returned by ``OptimWrapper.create`` for
        ``'adam_onecycle'``.

    Raises:
        NotImplementedError: If ``optim_cfg.OPTIMIZER`` is none of the
            supported names.
    """
    optimizer_name = optim_cfg.OPTIMIZER

    if optimizer_name == 'adam':
        return optim.Adam(
            model.parameters(), lr=optim_cfg.LR, weight_decay=optim_cfg.WEIGHT_DECAY
        )

    if optimizer_name == 'sgd':
        return optim.SGD(
            model.parameters(), lr=optim_cfg.LR, weight_decay=optim_cfg.WEIGHT_DECAY,
            momentum=optim_cfg.MOMENTUM
        )

    if optimizer_name == 'adam_onecycle':
        def flatten_model(m: nn.Module) -> list:
            """Return the leaf modules (those without children) of ``m``, depth-first."""
            sub_modules = list(m.children())
            if not sub_modules:
                return [m]
            leaves = []
            for sub_module in sub_modules:
                leaves.extend(flatten_model(sub_module))
            return leaves

        # All leaf modules are placed in a single layer group.
        layer_groups = [nn.Sequential(*flatten_model(model))]

        optimizer_func = partial(optim.Adam, betas=(0.9, 0.99))
        # NOTE(review): the learning rate passed here is the literal 3e-3, not
        # optim_cfg.LR -- presumably a scheduler overrides it later; confirm.
        return OptimWrapper.create(
            optimizer_func, 3e-3, layer_groups, wd=optim_cfg.WEIGHT_DECAY, true_wd=True, bn_wd=True
        )

    # Name the rejected value so a config typo is easy to diagnose.
    raise NotImplementedError(
        "Unsupported optimizer {!r}; expected 'adam', 'sgd' or 'adam_onecycle'".format(optimizer_name)
    )
| 37 | |
| 38 | |
| 39 | def build_scheduler(optimizer, total_iters_each_epoch, total_epochs, last_epoch, optim_cfg): |
no test coverage detected