(self, dtype)
| 1286 | world_size = 2 |
| 1287 | |
| 1288 | def test(self, dtype): |
| 1289 | |
| 1290 | if dtype not in get_accelerator().supported_dtypes(): |
| 1291 | pytest.skip(f"{dtype} is not supported") |
| 1292 | |
| 1293 | config_dict = { |
| 1294 | "train_batch_size": 4, |
| 1295 | "steps_per_print": 1, |
| 1296 | "optimizer": { |
| 1297 | "type": "Adam", |
| 1298 | "params": { |
| 1299 | "lr": 1e-4 |
| 1300 | } |
| 1301 | }, |
| 1302 | "zero_optimization": { |
| 1303 | "stage": 3 |
| 1304 | }, |
| 1305 | } |
| 1306 | |
| 1307 | if dtype == torch.bfloat16: |
| 1308 | if get_accelerator().is_bf16_supported(): |
| 1309 | config_dict["bf16"] = {"enabled": True} |
| 1310 | else: |
| 1311 | pytest.skip("bfloat16 is not supported on this accelerator") |
| 1312 | elif dtype == torch.float16: |
| 1313 | if get_accelerator().is_fp16_supported(): |
| 1314 | config_dict["fp16"] = {"enabled": True} |
| 1315 | else: |
| 1316 | pytest.skip("fp16 is not supported on this accelerator") |
| 1317 | hidden_dim = 10 |
| 1318 | |
| 1319 | class SubModel(torch.nn.Module): |
| 1320 | |
| 1321 | def __init__(self, input_size, output_size, dropout_prob=0.5, device=None): |
| 1322 | super(SubModel, self).__init__() |
| 1323 | self.linear = torch.nn.Linear(input_size, output_size, device=device) |
| 1324 | self.dropout = torch.nn.Dropout(dropout_prob) |
| 1325 | self.module_list = torch.nn.ModuleList([torch.nn.Linear(input_size, output_size, device=device)]) |
| 1326 | |
| 1327 | def forward(self, x): |
| 1328 | x = self.linear(x) |
| 1329 | x = self.dropout(x) |
| 1330 | x = self.module_list[0](x) |
| 1331 | return x |
| 1332 | |
| 1333 | class MyModel(torch.nn.Module): |
| 1334 | |
| 1335 | def __init__(self, hidden_dim): |
| 1336 | super(MyModel, self).__init__() |
| 1337 | self.l1 = skip_init(Linear, hidden_dim, hidden_dim) |
| 1338 | self.l2 = skip_init(SubModel, hidden_dim, hidden_dim) |
| 1339 | self.l3 = torch.nn.Linear(hidden_dim, hidden_dim) |
| 1340 | self.cel = torch.nn.CrossEntropyLoss() |
| 1341 | self.l4 = skip_init(SubModel, hidden_dim, hidden_dim) |
| 1342 | |
| 1343 | def forward(self, x, y): |
| 1344 | x = self.l1(x) |
| 1345 | x = self.l2(x) |
nothing calls this directly
no test coverage detected