Divide parameters with different lr scale into different groups. Inputs ------ param_groups: a list of dict of torch.nn.Parameter ``` # example: param1.lr_scale = param2.lr_scale = param3.lr_scale = 0.6 param4.lr_scale = param5.lr_scale = param6.lr_scale = 0.3 p
(param_groups)
| 77 | |
| 78 | |
def divide_param_groups_by_lr_scale(param_groups):
    """
    Divide parameters with different lr scale into different groups.

    Each input group is split into one output group per distinct
    ``lr_scale`` value found on its parameters. A parameter without an
    ``lr_scale`` attribute is treated as having ``lr_scale = 1.0``.
    Every other key of the input group (e.g. ``weight_decay``) is
    shallow-copied into each group derived from it. The input groups
    themselves are not modified.

    Inputs
    ------
    param_groups: a list of dict of torch.nn.Parameter
    ```
    # example:
    param1.lr_scale = param2.lr_scale = param3.lr_scale = 0.6
    param4.lr_scale = param5.lr_scale = param6.lr_scale = 0.3
    param_groups = [{'params': [param1, param2, param4]},
                    {'params': [param3, param5, param6], 'weight_decay': 0.}]

    param_groups = divide_param_groups_by_lr_scale(param_groups)
    ```

    Outputs
    -------
    new_param_groups: a list of dict containing the key `lr_scale`.
    Output groups follow the order of the input groups; within one input
    group they follow the order in which each lr_scale first appears.
    ```
    param_groups = [
        {'params': [param1, param2], 'lr_scale': 0.6},
        {'params': [param4], 'lr_scale': 0.3},
        {'params': [param3], 'weight_decay': 0., 'lr_scale': 0.6},
        {'params': [param5, param6], 'weight_decay': 0., 'lr_scale': 0.3},
    ]
    ```
    """
    new_groups = []
    for group in param_groups:
        # Bucket this group's parameters by lr_scale. Dicts keep insertion
        # order, so buckets appear in first-occurrence order.
        lr_scale_groups = {}
        for p in group['params']:
            lr_scale = getattr(p, 'lr_scale', 1.0)
            lr_scale_groups.setdefault(lr_scale, []).append(p)

        for lr_scale, bucket in lr_scale_groups.items():
            # 'params' is read (not popped) above so the caller's dicts stay
            # intact. The shallow copy carries the other settings such as
            # `weight_decay`, and its 'params' entry is replaced here.
            new_group = copy.copy(group)
            new_group['params'] = bucket
            new_group['lr_scale'] = lr_scale
            new_groups.append(new_group)
    return new_groups
| 133 | |
| 134 | |
| 135 | def set_weight_decay(model): |