Repository: github.com/deepspeedai/DeepSpeed

Function get_data_loader

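Defined in tests/small_model_debugging/test_model.py, lines 40–47.

Signature: get_data_loader(model, total_samples, hidden_dim, device)

Builds a DataLoader over total_samples random fp16 input vectors of width hidden_dim, each paired with an integer label in [0, hidden_dim). The batch size comes from the engine's train_micro_batch_size_per_gpu(), and a DistributedSampler shards the samples across ranks.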


```python
import torch
from torch.utils.data.distributed import DistributedSampler


def get_data_loader(model, total_samples, hidden_dim, device):
    # Per-GPU micro-batch size as configured on the DeepSpeed engine
    batch_size = model.train_micro_batch_size_per_gpu()
    # Random fp16 inputs and integer class labels drawn from [0, hidden_dim)
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half)
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    # Shard samples across ranks; needs an initialized torch.distributed process group
    sampler = DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
    return train_loader
```
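Because DistributedSampler(train_dataset) looks up the world size and rank from torch.distributed, the helper above only runs inside an initialized process group with a real DeepSpeed engine. The sketch below is a single-process illustration of the same sharding, under stated assumptions: StubEngine and get_data_loader_for_rank are hypothetical names, the stub only imitates the one engine method the helper calls, rank information is passed explicitly through DistributedSampler's num_replicas and rank arguments, and float32 replaces torch.half so it runs on CPU. It is not DeepSpeed's code.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


class StubEngine:
    """Hypothetical stand-in for a DeepSpeed engine (only the method the helper uses)."""

    def __init__(self, micro_batch_size):
        self._micro_batch_size = micro_batch_size

    def train_micro_batch_size_per_gpu(self):
        return self._micro_batch_size


def get_data_loader_for_rank(model, total_samples, hidden_dim, device, num_replicas, rank):
    """Hypothetical variant of get_data_loader with explicit rank info, so no
    process group is required. Returns the loader and its sampler."""
    batch_size = model.train_micro_batch_size_per_gpu()
    # float32 (default dtype) instead of torch.half so the sketch runs on CPU
    train_data = torch.randn(total_samples, hidden_dim, device=device)
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
    dataset = TensorDataset(train_data, train_label)
    # Each rank receives ceil(total_samples / num_replicas) indices; the shuffled
    # index list is padded by repetition so every rank gets the same count
    sampler = DistributedSampler(dataset, num_replicas=num_replicas, rank=rank)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return loader, sampler


if __name__ == "__main__":
    engine = StubEngine(micro_batch_size=4)
    for r in range(3):
        loader, sampler = get_data_loader_for_rank(engine, 10, 8, "cpu", num_replicas=3, rank=r)
        # Print this rank's sample count, batch count and assigned dataset indices
        print(f"rank {r}: {len(sampler)} samples, {len(loader)} batch(es), indices {list(sampler)}")
```

With 10 samples and 3 ranks, each rank gets 4 indices (12 in total, so 2 are repeated), which fills exactly one micro-batch of 4. Note that the original helper never calls sampler.set_epoch(epoch); with DistributedSampler's default shuffle=True, PyTorch documents that the same ordering is reused every epoch unless set_epoch is called before each epoch's iteration. A caller can reach the sampler through train_loader.sampler.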

Callers (1): test_model.py (File · 0.70)

Calls: 1

Tested by: no test coverage detected
