MCPcopy
hub / github.com/mosaicml/composer / test_eval_sample_interval

Function test_eval_sample_interval

tests/trainer/test_trainer_eval.py:306–374  ·  view source on GitHub ↗

Tests that the trainer evaluates the model at the correct intervals when using sample-based intervals.

(
    tiny_bert_tokenizer,
    eval_interval: str,
    batch_size: int,
    sequence_length: int,
    tmp_path: pathlib.Path,
)

Source from the content-addressed store, hash-verified

304@pytest.mark.parametrize('batch_size', [1, 4, 5])
305@pytest.mark.parametrize('sequence_length', [1, 16])
306def test_eval_sample_interval(
307 tiny_bert_tokenizer,
308 eval_interval: str,
309 batch_size: int,
310 sequence_length: int,
311 tmp_path: pathlib.Path,
312):
313 """Tests that the trainer evaluates the model at the correct intervals when using sample-based intervals."""
314 max_duration_time = Time.from_timestring('5ba')
315 eval_interval_time = Time.from_timestring(eval_interval)
316 max_duration_samples = max_duration_time.value * batch_size
317
318 # calculate the expected number of evals
319 last_sample_iter = 0
320 next_multiple = eval_interval_time.value
321 expected_num_evals = 0
322 last_multiple_added = -1
323 for sample_iter in range(0, max_duration_samples + batch_size, batch_size):
324 if last_sample_iter < next_multiple <= sample_iter:
325 last_multiple_added = next_multiple
326 expected_num_evals += 1
327 last_token_iter = sample_iter
328 while next_multiple <= last_token_iter:
329 next_multiple += eval_interval_time.value
330
331 if last_multiple_added + batch_size <= max_duration_samples:
332 expected_num_evals += 1
333
334 num_eval_batches = 2
335 expected_batch_evals = expected_num_evals * num_eval_batches
336
337 transformers = pytest.importorskip('transformers')
338 model = SimpleTransformerMaskedLM(vocab_size=tiny_bert_tokenizer.vocab_size)
339 pretraining_train_dataset = RandomTextLMDataset(
340 size=100,
341 vocab_size=tiny_bert_tokenizer.vocab_size,
342 sequence_length=sequence_length,
343 use_keys=True,
344 )
345
346 collator = transformers.DataCollatorForLanguageModeling(tokenizer=tiny_bert_tokenizer, mlm_probability=0.15)
347 dataloader = DataLoader(
348 pretraining_train_dataset,
349 batch_size=batch_size,
350 sampler=dist.get_sampler(pretraining_train_dataset),
351 collate_fn=collator,
352 )
353 eval_dataloader = DataLoader(
354 pretraining_train_dataset,
355 batch_size=batch_size,
356 sampler=dist.get_sampler(pretraining_train_dataset),
357 collate_fn=collator,
358 )
359
360 event_counter_callback = EventCounterCallback()
361 trainer = Trainer(
362 model=model,
363 train_dataloader=dataloader,

Callers

nothing calls this directly

Calls 6

fit (Method) · 0.95
RandomTextLMDataset (Class) · 0.90
Trainer (Class) · 0.90
from_timestring (Method) · 0.80

Tested by

no test coverage detected