MCPcopy
hub / github.com/mosaicml/composer / __init__

Method __init__

composer/algorithms/swa/swa.py:111–161  ·  view source on GitHub ↗
(
        self,
        swa_start: str = '0.7dur',
        swa_end: str = '0.97dur',
        update_interval: str = '1ep',
        schedule_swa_lr: bool = False,
        anneal_strategy: str = 'linear',
        anneal_steps: int = 10,
        swa_lr: Optional[float] = None,
    )

Source from the content-addressed store, hash-verified

109 """
110
def __init__(
    self,
    swa_start: str = '0.7dur',
    swa_end: str = '0.97dur',
    update_interval: str = '1ep',
    schedule_swa_lr: bool = False,
    anneal_strategy: str = 'linear',
    anneal_steps: int = 10,
    swa_lr: Optional[float] = None,
):
    """Configure the SWA algorithm and validate its arguments.

    Args:
        swa_start (str): Timestring for when averaging begins, parsed with
            ``Time.from_timestring``. Default: ``'0.7dur'``.
        swa_end (str): Timestring for when averaging ends, parsed with
            ``Time.from_timestring``. Default: ``'0.97dur'``.
        update_interval (str): Timestring for how often the averaged model is
            updated. Its unit selects the matched event: batches map to
            ``Event.BATCH_END``, epochs to ``Event.EPOCH_END``.
            Default: ``'1ep'``.
        schedule_swa_lr (bool): Stored unchanged on the instance.
            Default: ``False``.
        anneal_strategy (str): ``'linear'``/``'lin'`` or ``'cos'``/``'cosine'``
            (case-insensitive); normalized to ``'linear'`` or ``'cos'``.
            Default: ``'linear'``.
        anneal_steps (int): Must be greater than 0. Default: ``10``.
        swa_lr (float, optional): Stored unchanged on the instance.
            Default: ``None``.

    Raises:
        ValueError: If ``anneal_steps <= 0``, if ``anneal_strategy`` is not a
            recognized name, or if ``update_interval`` is not expressed in
            batches or epochs. ``Time.from_timestring`` and
            ``_validate_time`` may also raise for unparsable or inconsistent
            timestrings.
    """
    # Emitted on every construction. stacklevel=2 attributes the warning to
    # the code that instantiated the algorithm instead of to this line.
    warnings.warn(
        'SWA has known issues when resuming from a checkpoint on multiple GPUs, which will cause an error when resuming without `load_weights_only=True`.',
        stacklevel=2,
    )
    self.schedule_swa_lr = schedule_swa_lr
    self.anneal_strategy = anneal_strategy
    self.anneal_steps = anneal_steps
    self.swa_lr = swa_lr
    # Not created by this method; both start out as None.
    self.swa_model: Optional[torch.nn.Module] = None
    self.swa_scheduler = None
    self.swa_completed = False
    self.swa_started = False

    # Check timestrings are parsable and convert into time objects
    self.swa_start = Time.from_timestring(swa_start)
    self.swa_end = Time.from_timestring(swa_end)
    self.update_interval = Time.from_timestring(update_interval)

    self._validate_time()

    if anneal_steps <= 0:
        raise ValueError('anneal_steps must be greater than 0')

    # Normalize the accepted aliases to one canonical strategy name.
    strategy = self.anneal_strategy.lower()
    if strategy in ('linear', 'lin'):
        self.anneal_strategy = 'linear'
    elif strategy in ('cos', 'cosine'):
        self.anneal_strategy = 'cos'
    else:
        raise ValueError("anneal_strategy must be one of {'linear', 'cos'}.")

    # Keeps track of # steps so that we can know when to update averaged model
    self.step_counter = 0

    # Check units for update_interval and set match event accordingly
    if self.update_interval.unit == TimeUnit.BATCH:
        self.match_event = Event.BATCH_END
    elif self.update_interval.unit == TimeUnit.EPOCH:
        self.match_event = Event.EPOCH_END
    else:
        # Fail fast: otherwise match_event would be left unset and the error
        # would only surface later as an AttributeError.
        raise ValueError(
            f'update_interval must be in batches or epochs, got unit {self.update_interval.unit}.',
        )
162
163 def _validate_time(self):
164 # validate time units

Callers

nothing calls this directly

Calls 2

_validate_time — Method · 0.95
from_timestring — Method · 0.80

Tested by

no test coverage detected