MCPcopy
hub / github.com/deepspeedai/DeepSpeed / test

Method test

tests/unit/v1/zero/test_zero.py:1288–1376  ·  view source on GitHub ↗
(self, dtype)

Source from the content-addressed store, hash-verified

1286 world_size = 2
1287
1288 def test(self, dtype):
1289
1290 if not dtype in get_accelerator().supported_dtypes():
1291 pytest.skip("{dtype} is not supported")
1292
1293 config_dict = {
1294 "train_batch_size": 4,
1295 "steps_per_print": 1,
1296 "optimizer": {
1297 "type": "Adam",
1298 "params": {
1299 "lr": 1e-4
1300 }
1301 },
1302 "zero_optimization": {
1303 "stage": 3
1304 },
1305 }
1306
1307 if dtype == torch.bfloat16:
1308 if get_accelerator().is_bf16_supported():
1309 config_dict["bf16"] = {"enabled": True}
1310 else:
1311 pytest.skip("bfloat16 is not supported on this accelerator")
1312 elif dtype == torch.float16:
1313 if get_accelerator().is_fp16_supported():
1314 config_dict["fp16"] = {"enabled": True}
1315 else:
1316 pytest.skip("fp16 is not supported on this accelerator")
1317 hidden_dim = 10
1318
1319 class SubModel(torch.nn.Module):
1320
1321 def __init__(self, input_size, output_size, dropout_prob=0.5, device=None):
1322 super(SubModel, self).__init__()
1323 self.linear = torch.nn.Linear(input_size, output_size, device=device)
1324 self.dropout = torch.nn.Dropout(dropout_prob)
1325 self.module_list = torch.nn.ModuleList([torch.nn.Linear(input_size, output_size, device=device)])
1326
1327 def forward(self, x):
1328 x = self.linear(x)
1329 x = self.dropout(x)
1330 x = self.module_list[0](x)
1331 return x
1332
1333 class MyModel(torch.nn.Module):
1334
1335 def __init__(self, hidden_dim):
1336 super(MyModel, self).__init__()
1337 self.l1 = skip_init(Linear, hidden_dim, hidden_dim)
1338 self.l2 = skip_init(SubModel, hidden_dim, hidden_dim)
1339 self.l3 = torch.nn.Linear(hidden_dim, hidden_dim)
1340 self.cel = torch.nn.CrossEntropyLoss()
1341 self.l4 = skip_init(SubModel, hidden_dim, hidden_dim)
1342
1343 def forward(self, x, y):
1344 x = self.l1(x)
1345 x = self.l2(x)

Callers

nothing calls this directly

Calls 14

get_accelerator — Function · 0.90
random_dataloader — Function · 0.90
get_world_size — Method · 0.80
numel — Method · 0.80
initialize — Method · 0.80
MyModel — Class · 0.70
supported_dtypes — Method · 0.45
is_bf16_supported — Method · 0.45
is_fp16_supported — Method · 0.45
parameters — Method · 0.45
barrier — Method · 0.45
backward — Method · 0.45

Tested by

no test coverage detected