(
scheduler: ComposerScheduler,
ssr: float,
test_times: list[str],
expected_lrs: list[float],
dummy_schedulers_state: State,
)
| 294 | ], |
| 295 | ) |
| 296 | def test_scheduler_init( |
| 297 | scheduler: ComposerScheduler, |
| 298 | ssr: float, |
| 299 | test_times: list[str], |
| 300 | expected_lrs: list[float], |
| 301 | dummy_schedulers_state: State, |
| 302 | ): |
| 303 | |
| 304 | state = dummy_schedulers_state |
| 305 | assert state.dataloader_len is not None |
| 306 | assert state.max_duration is not None |
| 307 | state.max_duration = Time(value=int(state.max_duration.value * ssr), unit=state.max_duration.unit) |
| 308 | for test_time, expected_lr in zip(test_times, expected_lrs): |
| 309 | parsed_time = Time.from_timestring(test_time) |
| 310 | assert parsed_time.unit in [TimeUnit.EPOCH, TimeUnit.BATCH] |
| 311 | if parsed_time.unit == TimeUnit.EPOCH: |
| 312 | state.timestamp = state.timestamp.copy( |
| 313 | epoch=parsed_time, |
| 314 | batch=Time(int(state.dataloader_len) * int(parsed_time), TimeUnit.BATCH), |
| 315 | ) |
| 316 | else: |
| 317 | state.timestamp = state.timestamp.copy( |
| 318 | batch=parsed_time, |
| 319 | epoch=Time(int(parsed_time) // int(state.dataloader_len), TimeUnit.EPOCH), |
| 320 | ) |
| 321 | |
| 322 | lr = scheduler(state, ssr) |
| 323 | assert lr == pytest.approx(expected_lr, abs=1e-3) |
| 324 | |
| 325 | |
| 326 | @pytest.mark.parametrize( |
nothing calls this directly
no test coverage detected