(self)
| 46 | self.update_lr() |
| 47 | |
def update_lr(self):
    """Re-apply per-parameter learning-rate scales to the optimizer's groups.

    For every param group of ``self.optimizer`` that has an ``'lr_scale'``
    key, the scale is re-derived from the ``lr_scale`` attribute of the
    group's parameters. Parameters without that attribute are ignored, and
    all parameters that do define it must agree.

    When the derived scale differs from the one stored on the group, the
    change is printed (main process only) and stored in
    ``group['lr_scale']``. If a scale was found (not ``None``),
    ``group['lr']`` is multiplied by it. Groups without an ``'lr_scale'`` key
    are left untouched.

    Raises:
        ValueError: if two parameters in the same group carry different
            ``lr_scale`` values.
    """
    for group in self.optimizer.param_groups:
        if 'lr_scale' not in group:
            continue
        params = group['params']

        # Derive the group's scale from its parameters; all must agree.
        lr_scale = None
        for p in params:
            if not hasattr(p, 'lr_scale'):
                continue
            if lr_scale is None:
                lr_scale = p.lr_scale
            elif lr_scale != p.lr_scale:
                # Explicit raise rather than ``assert`` so the consistency
                # check is not stripped when running under ``python -O``.
                raise ValueError(
                    f"inconsistent lr_scale within one param group: "
                    f"{lr_scale} vs {p.lr_scale}")

        if lr_scale != group['lr_scale']:
            if is_main_process():
                print('=' * 30)
                # Not every parameter is guaranteed to carry ``param_name``;
                # a missing attribute must not crash the report.
                print("params:",
                      [getattr(e, 'param_name', None) for e in params])
                print(
                    f"change lr scale: {group['lr_scale']} to {lr_scale}")
            group['lr_scale'] = lr_scale

        # NOTE(review): the scale is multiplied in on every call, so this
        # presumably relies on the wrapped scheduler recomputing group['lr']
        # from its base value before each call (otherwise the scale would
        # compound) — confirm this is meant to run outside the
        # change-detection branch above.
        if lr_scale is not None:
            group['lr'] *= lr_scale
def state_dict(self):
    """Delegate serialization to the wrapped learning-rate scheduler.

    Only the inner ``lr_scheduler``'s state is returned; this wrapper adds
    nothing of its own.
    """
    inner = self.lr_scheduler
    return inner.state_dict()
no test coverage detected