Repository: github.com/deepspeedai/DeepSpeed

Class SimpleModel

tests/small_model_debugging/partial_offload_test.py:15–33


import torch


class SimpleModel(torch.nn.Module):
    """Four stacked hidden_dim x hidden_dim Linear layers followed by a cross-entropy loss."""

    def __init__(self, hidden_dim, empty_grad=False):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.linear4 = torch.nn.Linear(hidden_dim, hidden_dim)
        if empty_grad:
            # Registered as a submodule but never used in forward(), so its
            # parameters receive no gradient during backward.
            self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)])
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        # x: (batch, hidden_dim) float inputs.
        # y: (batch,) integer class labels in [0, hidden_dim); the output width
        # of the last Linear layer doubles as the number of classes.
        hidden = x
        hidden = self.linear(hidden)
        hidden = self.linear2(hidden)
        hidden = self.linear3(hidden)
        hidden = self.linear4(hidden)
        # Returns a scalar loss (mean over the batch by default).
        return self.cross_entropy_loss(hidden, y)

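A minimal usage sketch, assuming only PyTorch is installed (no DeepSpeed engine or offload config is involved). It repeats the class so it runs standalone, performs one forward/backward pass on random data, and checks what `empty_grad=True` does at the PyTorch level: `layers2` is registered but never called in `forward`, so its parameters keep `.grad` as `None`. The helper `run_one_step` and its default sizes are hypothetical, chosen only for illustration; how the DeepSpeed test itself uses this flag is not shown here.

```python
import torch


class SimpleModel(torch.nn.Module):
    # Copy of the class above so this sketch runs on its own.
    def __init__(self, hidden_dim, empty_grad=False):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.linear4 = torch.nn.Linear(hidden_dim, hidden_dim)
        if empty_grad:
            self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)])
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        hidden = self.linear(x)
        hidden = self.linear2(hidden)
        hidden = self.linear3(hidden)
        hidden = self.linear4(hidden)
        return self.cross_entropy_loss(hidden, y)


def run_one_step(hidden_dim=10, batch_size=4, empty_grad=True, seed=0):
    """Hypothetical helper: one forward/backward pass on random data."""
    torch.manual_seed(seed)
    model = SimpleModel(hidden_dim, empty_grad=empty_grad)
    x = torch.randn(batch_size, hidden_dim)
    # Labels are class indices in [0, hidden_dim).
    y = torch.randint(0, hidden_dim, (batch_size,))
    loss = model(x, y)
    loss.backward()
    return model, loss


model, loss = run_one_step()
print(loss.dim())                            # 0: scalar loss
print(model.linear.weight.grad is not None)  # True: linear is used in forward
print(model.layers2[0].weight.grad is None)  # True: layers2 is never called
```

Because `layers2` is only created when `empty_grad=True`, a model built with the default flag has no `layers2` attribute at all.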

Callers: 1
Calls: no outgoing calls
Tested by: no test coverage detected
