MCPcopy
hub / github.com/mosaicml/composer / test_eval_token_interval

Function test_eval_token_interval

tests/trainer/test_trainer_eval.py:231–300  ·  view source on GitHub ↗

Tests that the trainer evaluates the model at the correct intervals when using token-based intervals.

(
    tiny_bert_tokenizer,
    eval_interval: str,
    batch_size: int,
    sequence_length: int,
    tmp_path: pathlib.Path,
)

Source from the content-addressed store, hash-verified

229@pytest.mark.parametrize('batch_size', [1, 4, 5])
230@pytest.mark.parametrize('sequence_length', [1, 16])
231def test_eval_token_interval(
232 tiny_bert_tokenizer,
233 eval_interval: str,
234 batch_size: int,
235 sequence_length: int,
236 tmp_path: pathlib.Path,
237):
238 """Tests that the trainer evaluates the model at the correct intervals when using token-based intervals."""
239 tokens_per_batch = batch_size * sequence_length
240 max_duration_time = Time.from_timestring('5ba')
241 eval_interval_time = Time.from_timestring(eval_interval)
242 max_duration_tokens = max_duration_time.value * tokens_per_batch
243
244 # calculate the expected number of evals
245 last_token_iter = 0
246 next_multiple = eval_interval_time.value
247 expected_num_evals = 0
248 last_multiple_added = -1
249 for token_iter in range(0, max_duration_tokens + tokens_per_batch, tokens_per_batch):
250 if last_token_iter < next_multiple <= token_iter:
251 last_multiple_added = next_multiple
252 expected_num_evals += 1
253 last_token_iter = token_iter
254 while next_multiple <= last_token_iter:
255 next_multiple += eval_interval_time.value
256
257 if last_multiple_added + tokens_per_batch <= max_duration_tokens:
258 expected_num_evals += 1
259
260 num_eval_batches = 2
261 expected_batch_evals = expected_num_evals * num_eval_batches
262
263 transformers = pytest.importorskip('transformers')
264 model = SimpleTransformerMaskedLM(vocab_size=tiny_bert_tokenizer.vocab_size)
265 pretraining_train_dataset = RandomTextLMDataset(
266 size=100,
267 vocab_size=tiny_bert_tokenizer.vocab_size,
268 sequence_length=sequence_length,
269 use_keys=True,
270 )
271
272 collator = transformers.DataCollatorForLanguageModeling(tokenizer=tiny_bert_tokenizer, mlm_probability=0.15)
273 dataloader = DataLoader(
274 pretraining_train_dataset,
275 batch_size=batch_size,
276 sampler=dist.get_sampler(pretraining_train_dataset),
277 collate_fn=collator,
278 )
279 eval_dataloader = DataLoader(
280 pretraining_train_dataset,
281 batch_size=batch_size,
282 sampler=dist.get_sampler(pretraining_train_dataset),
283 collate_fn=collator,
284 )
285
286 event_counter_callback = EventCounterCallback()
287 trainer = Trainer(
288 model=model,

Callers

nothing calls this directly

Calls 6

fit (Method) · 0.95
RandomTextLMDataset (Class) · 0.90
Trainer (Class) · 0.90
from_timestring (Method) · 0.80

Tested by

no test coverage detected