(model, total_samples, hidden_dim, device)
| 38 | |
| 39 | |
def get_data_loader(model, total_samples, hidden_dim, device, *, dtype=torch.half, num_replicas=None, rank=None):
    """Build a DataLoader over random (input, label) pairs for an engine's micro-batch size.

    Args:
        model: Object exposing ``train_micro_batch_size_per_gpu()``; its return
            value is used as the DataLoader batch size.
        total_samples: Number of random samples to generate.
        hidden_dim: Width of each input row; labels are drawn from ``[0, hidden_dim)``.
        device: Device on which the data and label tensors are allocated.
        dtype: Dtype of the random inputs (default ``torch.half``, matching the
            original behaviour).
        num_replicas: Forwarded to ``DistributedSampler``. When ``None`` the
            sampler asks the default process group for the world size.
        rank: Forwarded to ``DistributedSampler``. When ``None`` the sampler
            asks the default process group for the current rank.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``[inputs, labels]`` batches,
        where inputs have shape ``(batch, hidden_dim)`` and labels are int64.

    Raises:
        RuntimeError: If ``num_replicas`` or ``rank`` is ``None`` and
            ``torch.distributed`` is not available/initialized (raised by
            ``DistributedSampler``).
        ValueError: If ``rank`` is outside ``[0, num_replicas - 1]``.
    """
    batch_size = model.train_micro_batch_size_per_gpu()
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
    # random_(n) fills with integers in [0, n), so labels are valid class indices for hidden_dim classes.
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    # Passing None for both reproduces DistributedSampler's default (query the process group);
    # explicit values let callers build a loader without an initialized process group.
    sampler = DistributedSampler(train_dataset, num_replicas=num_replicas, rank=rank)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
    return train_loader
| 48 | |
| 49 | |
| 50 | def get_args(tmpdir, config_dict): |
no test coverage detected
searching dependent graphs…