(model, skip_list=(), skip_keywords=())
| 57 | |
| 58 | |
def set_weight_decay(model, skip_list=(), skip_keywords=()):
    """Split a model's trainable parameters into decay / no-decay groups.

    A parameter is exempt from weight decay when any of the following holds:

    * it has at most one dimension (scalars and vectors),
    * its name ends with ``".bias"``,
    * its name is contained in ``skip_list``,
    * ``check_keywords_in_name(name, skip_keywords)`` returns a truthy value.

    Parameters whose ``requires_grad`` is false are left out of both groups.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs; each ``param`` must provide
            ``requires_grad`` and ``shape``.
        skip_list: Container of exact parameter names to exempt from decay.
        skip_keywords: Forwarded unchanged to ``check_keywords_in_name``.

    Returns:
        A two-element list of parameter-group dicts: the first holds the
        parameters that keep weight decay, the second holds the exempt
        parameters together with ``'weight_decay': 0.``.
    """
    has_decay = []
    no_decay = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        # `<= 1` rather than `== 1`: 0-dim scalar parameters are exempt from
        # decay just like 1-D vectors.
        if (
            len(param.shape) <= 1
            or name.endswith(".bias")
            or name in skip_list
            or check_keywords_in_name(name, skip_keywords)
        ):
            no_decay.append(param)
        else:
            has_decay.append(param)

    return [{'params': has_decay},
            {'params': no_decay, 'weight_decay': 0.}]
| 74 | |
| 75 | |
| 76 | def check_keywords_in_name(name, keywords=()): |
no test coverage detected