(config, epochs, step_each_epoch, model)
| 69 | |
| 70 | |
def build_optimizer(config, epochs, step_each_epoch, model):
    """Build the optimizer, its learning-rate object and an optional weight-decay scheduler.

    The incoming ``config`` is deep-copied first, so the caller's mapping is
    never mutated. The keys consumed here are:

    * ``"lr"`` (required): passed to ``build_lr_scheduler``.
    * ``"name"`` (required): attribute name looked up on the local
      ``optimizer`` module.
    * ``"regularizer"`` (optional): mapping with a ``"name"`` entry plus
      constructor kwargs; a null value means "no regularizer".
    * ``"weight_decay"`` (optional): used only when no regularizer is given.
    * ``"clip_norm"`` / ``"clip_norm_global"`` (optional): select
      ``ClipGradByNorm`` or ``ClipGradByGlobalNorm``; ``"clip_norm"`` wins
      when both are present.

    Every remaining entry is forwarded verbatim as a keyword argument to the
    optimizer constructor.

    Args:
        config: Optimizer configuration mapping (see keys above).
        epochs: Number of training epochs.
        step_each_epoch: Number of optimizer steps per epoch.
        model: Passed to the optimizer factory to produce the final optimizer.

    Returns:
        A ``(optimizer, lr, wd_scheduler)`` tuple. ``wd_scheduler`` is a
        ``CosineWeightDecayScheduler`` when the regularizer is a
        ``CosineL2Decay`` instance, otherwise ``None``.

    Raises:
        KeyError: If ``"lr"``, the optimizer ``"name"`` or the regularizer
            ``"name"`` is missing.
        AttributeError: If the regularizer or optimizer name cannot be
            resolved on its module.
    """
    from . import regularizer, optimizer

    config = copy.deepcopy(config)
    # step1 build lr
    lr = build_lr_scheduler(config.pop("lr"), epochs, step_each_epoch)

    # step2 build regularization
    wd_kwargs = None
    # Pop unconditionally: a null "regularizer" entry must not stay in
    # ``config``, otherwise it is forwarded to the optimizer constructor
    # below as an unexpected ``regularizer=None`` keyword.
    reg_config = config.pop("regularizer", None)
    if reg_config is not None:
        reg_name = reg_config.pop("name")
        # Allow the short spelling: retry with the "Decay" suffix appended.
        if not hasattr(regularizer, reg_name):
            reg_name += "Decay"
        reg_obj = getattr(regularizer, reg_name)(**reg_config)
        reg = reg_obj()

        # Build weight decay scheduler for CosineL2Decay
        if isinstance(reg_obj, regularizer.CosineL2Decay):
            wd_kwargs = {
                "start_factor": reg_obj.start_factor,
                "end_factor": reg_obj.end_factor,
                "total_steps": step_each_epoch * epochs,
                # warmup_epoch is multiplied by steps-per-epoch and rounded
                # to a whole number of steps.
                "warmup_steps": round(reg_obj.warmup_epoch * step_each_epoch),
            }
        # NOTE(review): a "weight_decay" key given together with a non-null
        # regularizer is left in ``config`` and would collide with the
        # explicit ``weight_decay=`` argument below -- confirm whether such
        # configs should be rejected earlier.
    elif "weight_decay" in config:
        reg = config.pop("weight_decay")
    else:
        reg = None

    # step3 build optimizer
    optim_name = config.pop("name")
    if "clip_norm" in config:
        clip_norm = config.pop("clip_norm")
        grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
    elif "clip_norm_global" in config:
        clip_norm = config.pop("clip_norm_global")
        grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=clip_norm)
    else:
        grad_clip = None
    optim = getattr(optimizer, optim_name)(
        learning_rate=lr, weight_decay=reg, grad_clip=grad_clip, **config
    )
    built_optim = optim(model)

    # Instantiate the scheduler now that we have the real optimizer
    wd_scheduler = None
    if wd_kwargs is not None:
        wd_scheduler = CosineWeightDecayScheduler(built_optim, **wd_kwargs)

    return built_optim, lr, wd_scheduler
no test coverage detected
searching dependent graphs…