hub / github.com/microsoft/Cream / divide_param_groups_by_lr_scale

Function divide_param_groups_by_lr_scale

TinyViT/tinyvit_utils.py:79–132

Divide parameters with different lr scale into different groups. Takes `param_groups`, a list of dicts of `torch.nn.Parameter`, and returns a new list of dicts in which every group also carries the key `lr_scale`.

(param_groups)

Source from the content-addressed store, hash-verified

import copy


def divide_param_groups_by_lr_scale(param_groups):
    """
    Divide parameters with different lr scale into different groups.

    Inputs
    ------
    param_groups: a list of dict of torch.nn.Parameter
    ```
    # example:
    param1.lr_scale = param2.lr_scale = param3.lr_scale = 0.6
    param4.lr_scale = param5.lr_scale = param6.lr_scale = 0.3
    param_groups = [{'params': [param1, param2, param4]},
                    {'params': [param3, param5, param6], 'weight_decay': 0.}]

    param_groups = divide_param_groups_by_lr_scale(param_groups)
    ```

    Outputs
    -------
    new_param_groups: a list of dict containing the key `lr_scale`,
    ordered by input group, then by first appearance of each lr_scale
    ```
    param_groups = [
        {'params': [param1, param2], 'lr_scale': 0.6},
        {'params': [param4], 'lr_scale': 0.3},
        {'params': [param3], 'weight_decay': 0., 'lr_scale': 0.6},
        {'params': [param5, param6], 'weight_decay': 0., 'lr_scale': 0.3}
    ]
    ```
    """
    new_groups = []
    for group in param_groups:
        # NOTE: pop() removes 'params' from the caller's dict,
        # so the input groups are modified in place.
        params = group.pop('params')

        # divide parameters to different groups by lr_scale
        lr_scale_groups = dict()
        for p in params:
            # parameters without an `lr_scale` attribute default to 1.0
            lr_scale = getattr(p, 'lr_scale', 1.0)

            # create a list if it does not exist yet
            if lr_scale not in lr_scale_groups:
                lr_scale_groups[lr_scale] = list()

            # add the parameter with `lr_scale` into the specific group.
            lr_scale_groups[lr_scale].append(p)

        for lr_scale, params in lr_scale_groups.items():
            # shallow-copy other parameter information like `weight_decay`
            new_group = copy.copy(group)
            new_group['params'] = params
            new_group['lr_scale'] = lr_scale
            new_groups.append(new_group)
    return new_groups
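
The function only reads an optional `lr_scale` attribute from each parameter, so its behavior can be checked without PyTorch. The following is a minimal sketch, assuming a hypothetical stand-in class `FakeParam` in place of `torch.nn.Parameter` and a compact copy of the function shown above. It extends the docstring example with a hypothetical `param7` that has no `lr_scale`, to exercise the 1.0 default.

```python
import copy


class FakeParam:
    """Hypothetical stand-in for torch.nn.Parameter (illustration only)."""

    def __init__(self, name, lr_scale=None):
        self.name = name
        if lr_scale is not None:
            self.lr_scale = lr_scale  # left unset to exercise the 1.0 default

    def __repr__(self):
        return self.name


def divide_param_groups_by_lr_scale(param_groups):
    # Compact copy of the function shown above, intended to behave the same.
    new_groups = []
    for group in param_groups:
        params = group.pop('params')  # removes 'params' from the caller's dict
        lr_scale_groups = {}
        for p in params:
            lr_scale_groups.setdefault(getattr(p, 'lr_scale', 1.0), []).append(p)
        for lr_scale, scaled_params in lr_scale_groups.items():
            new_group = copy.copy(group)  # carries over keys like 'weight_decay'
            new_group['params'] = scaled_params
            new_group['lr_scale'] = lr_scale
            new_groups.append(new_group)
    return new_groups


p1, p2, p3 = (FakeParam(f'param{i}', 0.6) for i in (1, 2, 3))
p4, p5, p6 = (FakeParam(f'param{i}', 0.3) for i in (4, 5, 6))
p7 = FakeParam('param7')  # no lr_scale attribute

original = [{'params': [p1, p2, p4]},
            {'params': [p3, p5, p6, p7], 'weight_decay': 0.}]
result = divide_param_groups_by_lr_scale(original)

for g in result:
    print(g)
```

Two details are worth noticing. The output is ordered by input group first and by first appearance of each `lr_scale` within that group second. After the call, the dicts in `original` no longer contain `params`, so a caller that still needs the input should pass copies.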

Callers 1

build_optimizer · Function · 0.90
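
This page does not show how `build_optimizer` consumes the returned groups. One plausible use of the `lr_scale` key, sketched purely as an assumption, is to multiply a base learning rate by each group's scale. The names `apply_lr_scale` and `base_lr` are hypothetical and do not come from the repository, and plain strings stand in for parameters.

```python
def apply_lr_scale(param_groups, base_lr):
    """Hypothetical helper: set each group's 'lr' to base_lr * lr_scale."""
    for group in param_groups:
        # Groups without the key keep the base rate (scale 1.0).
        group['lr'] = base_lr * group.get('lr_scale', 1.0)
    return param_groups


groups = [
    {'params': ['w1'], 'lr_scale': 0.5},
    {'params': ['w2'], 'weight_decay': 0., 'lr_scale': 0.25},
    {'params': ['w3']},
]
apply_lr_scale(groups, base_lr=1e-3)

for g in groups:
    print(g['params'], g['lr'])
```

Keeping the scale on the group rather than baking it into `lr` once would let a scheduler reapply it every time the base rate changes. Whether the repository does this is not confirmed by this page.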

Calls

no outgoing calls

Tested by

no test coverage detected