(self, event: Event, state: State, logger: Optional[Logger])
| 165 | return event == Event.AFTER_LOAD |
| 166 | |
def apply(self, event: Event, state: State, logger: Optional[Logger]) -> Optional[int]:
    """Wrap every optimizer on ``state`` in a ``SAMOptimizer`` and swap in a ``ClosureGradScaler``.

    Args:
        event: The event being handled; not read by this method.
        state: Object whose ``optimizers`` attribute is replaced by a tuple of
            ``SAMOptimizer`` wrappers and whose ``scaler`` attribute is replaced
            by a new ``ClosureGradScaler``.
        logger: Not read by this method.

    Returns:
        Always ``None``; the method contains no ``return`` statement.

    Raises:
        AssertionError: If ``state.optimizers`` is ``None``.
    """
    # The optimizers must already be set before they can be wrapped.
    assert state.optimizers is not None

    # Build one wrapper per existing optimizer, configured from this
    # instance's rho / epsilon / interval, then install them all at once.
    wrapped = []
    for base in ensure_tuple(state.optimizers):
        sam = SAMOptimizer(
            base_optimizer=base,
            rho=self.rho,
            epsilon=self.epsilon,
            interval=self.interval,
        )
        wrapped.append(sam)
    state.optimizers = tuple(wrapped)

    # Switch to ClosureGradScaler as SAM supports and requires it
    state.scaler = ClosureGradScaler()
nothing calls this directly
no test coverage detected