Tests that the trainer evaluates the model at the correct intervals when using sample-based intervals.
(
tiny_bert_tokenizer,
eval_interval: str,
batch_size: int,
sequence_length: int,
tmp_path: pathlib.Path,
)
| 304 | @pytest.mark.parametrize('batch_size', [1, 4, 5]) |
| 305 | @pytest.mark.parametrize('sequence_length', [1, 16]) |
| 306 | def test_eval_sample_interval( |
| 307 | tiny_bert_tokenizer, |
| 308 | eval_interval: str, |
| 309 | batch_size: int, |
| 310 | sequence_length: int, |
| 311 | tmp_path: pathlib.Path, |
| 312 | ): |
| 313 | """Tests that the trainer evaluates the model at the correct intervals when using sample-based intervals.""" |
| 314 | max_duration_time = Time.from_timestring('5ba') |
| 315 | eval_interval_time = Time.from_timestring(eval_interval) |
| 316 | max_duration_samples = max_duration_time.value * batch_size |
| 317 | |
| 318 | # calculate the expected number of evals |
| 319 | last_sample_iter = 0 |
| 320 | next_multiple = eval_interval_time.value |
| 321 | expected_num_evals = 0 |
| 322 | last_multiple_added = -1 |
| 323 | for sample_iter in range(0, max_duration_samples + batch_size, batch_size): |
| 324 | if last_sample_iter < next_multiple <= sample_iter: |
| 325 | last_multiple_added = next_multiple |
| 326 | expected_num_evals += 1 |
| 327 | last_token_iter = sample_iter |
| 328 | while next_multiple <= last_token_iter: |
| 329 | next_multiple += eval_interval_time.value |
| 330 | |
| 331 | if last_multiple_added + batch_size <= max_duration_samples: |
| 332 | expected_num_evals += 1 |
| 333 | |
| 334 | num_eval_batches = 2 |
| 335 | expected_batch_evals = expected_num_evals * num_eval_batches |
| 336 | |
| 337 | transformers = pytest.importorskip('transformers') |
| 338 | model = SimpleTransformerMaskedLM(vocab_size=tiny_bert_tokenizer.vocab_size) |
| 339 | pretraining_train_dataset = RandomTextLMDataset( |
| 340 | size=100, |
| 341 | vocab_size=tiny_bert_tokenizer.vocab_size, |
| 342 | sequence_length=sequence_length, |
| 343 | use_keys=True, |
| 344 | ) |
| 345 | |
| 346 | collator = transformers.DataCollatorForLanguageModeling(tokenizer=tiny_bert_tokenizer, mlm_probability=0.15) |
| 347 | dataloader = DataLoader( |
| 348 | pretraining_train_dataset, |
| 349 | batch_size=batch_size, |
| 350 | sampler=dist.get_sampler(pretraining_train_dataset), |
| 351 | collate_fn=collator, |
| 352 | ) |
| 353 | eval_dataloader = DataLoader( |
| 354 | pretraining_train_dataset, |
| 355 | batch_size=batch_size, |
| 356 | sampler=dist.get_sampler(pretraining_train_dataset), |
| 357 | collate_fn=collator, |
| 358 | ) |
| 359 | |
| 360 | event_counter_callback = EventCounterCallback() |
| 361 | trainer = Trainer( |
| 362 | model=model, |
| 363 | train_dataloader=dataloader, |
nothing calls this directly
no test coverage detected