Tests that the trainer evaluates the model at the correct intervals when using token-based intervals.
(
tiny_bert_tokenizer,
eval_interval: str,
batch_size: int,
sequence_length: int,
tmp_path: pathlib.Path,
)
| 229 | @pytest.mark.parametrize('batch_size', [1, 4, 5]) |
| 230 | @pytest.mark.parametrize('sequence_length', [1, 16]) |
| 231 | def test_eval_token_interval( |
| 232 | tiny_bert_tokenizer, |
| 233 | eval_interval: str, |
| 234 | batch_size: int, |
| 235 | sequence_length: int, |
| 236 | tmp_path: pathlib.Path, |
| 237 | ): |
| 238 | """Tests that the trainer evaluates the model at the correct intervals when using token-based intervals.""" |
| 239 | tokens_per_batch = batch_size * sequence_length |
| 240 | max_duration_time = Time.from_timestring('5ba') |
| 241 | eval_interval_time = Time.from_timestring(eval_interval) |
| 242 | max_duration_tokens = max_duration_time.value * tokens_per_batch |
| 243 | |
| 244 | # calculate the expected number of evals |
| 245 | last_token_iter = 0 |
| 246 | next_multiple = eval_interval_time.value |
| 247 | expected_num_evals = 0 |
| 248 | last_multiple_added = -1 |
| 249 | for token_iter in range(0, max_duration_tokens + tokens_per_batch, tokens_per_batch): |
| 250 | if last_token_iter < next_multiple <= token_iter: |
| 251 | last_multiple_added = next_multiple |
| 252 | expected_num_evals += 1 |
| 253 | last_token_iter = token_iter |
| 254 | while next_multiple <= last_token_iter: |
| 255 | next_multiple += eval_interval_time.value |
| 256 | |
| 257 | if last_multiple_added + tokens_per_batch <= max_duration_tokens: |
| 258 | expected_num_evals += 1 |
| 259 | |
| 260 | num_eval_batches = 2 |
| 261 | expected_batch_evals = expected_num_evals * num_eval_batches |
| 262 | |
| 263 | transformers = pytest.importorskip('transformers') |
| 264 | model = SimpleTransformerMaskedLM(vocab_size=tiny_bert_tokenizer.vocab_size) |
| 265 | pretraining_train_dataset = RandomTextLMDataset( |
| 266 | size=100, |
| 267 | vocab_size=tiny_bert_tokenizer.vocab_size, |
| 268 | sequence_length=sequence_length, |
| 269 | use_keys=True, |
| 270 | ) |
| 271 | |
| 272 | collator = transformers.DataCollatorForLanguageModeling(tokenizer=tiny_bert_tokenizer, mlm_probability=0.15) |
| 273 | dataloader = DataLoader( |
| 274 | pretraining_train_dataset, |
| 275 | batch_size=batch_size, |
| 276 | sampler=dist.get_sampler(pretraining_train_dataset), |
| 277 | collate_fn=collator, |
| 278 | ) |
| 279 | eval_dataloader = DataLoader( |
| 280 | pretraining_train_dataset, |
| 281 | batch_size=batch_size, |
| 282 | sampler=dist.get_sampler(pretraining_train_dataset), |
| 283 | collate_fn=collator, |
| 284 | ) |
| 285 | |
| 286 | event_counter_callback = EventCounterCallback() |
| 287 | trainer = Trainer( |
| 288 | model=model, |
Nothing calls this function directly.
No test coverage detected.