(
self,
swa_start: str = '0.7dur',
swa_end: str = '0.97dur',
update_interval: str = '1ep',
schedule_swa_lr: bool = False,
anneal_strategy: str = 'linear',
anneal_steps: int = 10,
swa_lr: Optional[float] = None,
)
| 109 | """ |
| 110 | |
def __init__(
    self,
    swa_start: str = '0.7dur',
    swa_end: str = '0.97dur',
    update_interval: str = '1ep',
    schedule_swa_lr: bool = False,
    anneal_strategy: str = 'linear',
    anneal_steps: int = 10,
    swa_lr: Optional[float] = None,
):
    """Store the SWA configuration and validate it.

    Args:
        swa_start: Timestring parsed with ``Time.from_timestring`` and
            stored as ``self.swa_start``.
        swa_end: Timestring parsed with ``Time.from_timestring`` and
            stored as ``self.swa_end``.
        update_interval: Timestring parsed with ``Time.from_timestring``.
            Its unit selects the event stored in ``self.match_event``:
            ``Event.BATCH_END`` for batches, ``Event.EPOCH_END`` for epochs.
        schedule_swa_lr: Stored unchanged as ``self.schedule_swa_lr``.
        anneal_strategy: ``'linear'``/``'lin'`` or ``'cos'``/``'cosine'``
            (case-insensitive); stored normalized as ``'linear'`` or
            ``'cos'``.
        anneal_steps: Must be greater than 0.
        swa_lr: Stored unchanged as ``self.swa_lr``.

    Raises:
        ValueError: If ``anneal_steps`` is not greater than 0, or if
            ``anneal_strategy`` is not one of the accepted spellings.
            ``Time.from_timestring`` and ``self._validate_time()`` run
            first and may raise their own errors.
    """
    # stacklevel=2 attributes the warning to the code constructing this
    # object rather than to this line inside the library.
    warnings.warn(
        'SWA has known issues when resuming from a checkpoint on multiple GPUs, which will cause an error when resuming without `load_weights_only=True`.',
        stacklevel=2,
    )

    self.schedule_swa_lr = schedule_swa_lr
    self.anneal_strategy = anneal_strategy
    self.anneal_steps = anneal_steps
    self.swa_lr = swa_lr

    # Both start out unset; nothing in the constructor populates them.
    self.swa_model: Optional[torch.nn.Module] = None
    self.swa_scheduler = None

    self.swa_completed = False
    self.swa_started = False

    # Check timestrings are parsable and convert into time objects
    self.swa_start = Time.from_timestring(swa_start)
    self.swa_end = Time.from_timestring(swa_end)
    self.update_interval = Time.from_timestring(update_interval)

    self._validate_time()

    if anneal_steps <= 0:
        raise ValueError('anneal_steps must be greater than 0')

    # Normalize the accepted spellings to the canonical 'linear' / 'cos'.
    strategy = self.anneal_strategy.lower()
    if strategy in ('linear', 'lin'):
        self.anneal_strategy = 'linear'
    elif strategy in ('cos', 'cosine'):
        self.anneal_strategy = 'cos'
    else:
        raise ValueError("anneal_strategy must be one of {'linear', 'cos'}.")

    # Keeps track of # steps so that we can know when to update averaged model
    self.step_counter = 0

    # Check units for update_interval and set match event accordingly.
    # NOTE(review): any other unit leaves ``match_event`` unset; presumably
    # ``_validate_time`` rejects such units beforehand -- confirm.
    if self.update_interval.unit == TimeUnit.BATCH:
        self.match_event = Event.BATCH_END
    elif self.update_interval.unit == TimeUnit.EPOCH:
        self.match_event = Event.EPOCH_END
| 163 | def _validate_time(self): |
| 164 | # validate time units |
nothing calls this directly
no test coverage detected