(model, total_samples, hidden_dim, device, dtype=preferred_dtype())
| 273 | |
| 274 | |
def random_dataloader(model, total_samples, hidden_dim, device, dtype=preferred_dtype()):
    """Build a DataLoader over a ``random_dataset``, batched per the model's config.

    Args:
        model: Object exposing ``train_micro_batch_size_per_gpu()``; the value
            it returns is used as the DataLoader ``batch_size``.
        total_samples: Forwarded unchanged to ``random_dataset``.
        hidden_dim: Forwarded unchanged to ``random_dataset``.
        device: Forwarded unchanged to ``random_dataset``.
        dtype: Forwarded to ``random_dataset`` as ``dtype=``. The default is
            the result of ``preferred_dtype()`` evaluated once, when this
            function is defined.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping the generated dataset.
    """
    # Query the batch size before creating the dataset, as the callers expect.
    micro_batch = model.train_micro_batch_size_per_gpu()
    samples = random_dataset(total_samples, hidden_dim, device, dtype=dtype)
    return torch.utils.data.DataLoader(samples, batch_size=micro_batch)
| 280 | |
| 281 | |
| 282 | def sequence_dataloader(model, total_samples, hidden_dim, device, seq_len: int = 32, dtype=preferred_dtype()): |
searching dependent graphs…