(self, tmpdir, zero_stage, freeze_params)
| 312 | model.destroy() |
| 313 | |
| 314 | def test_2_param_groups(self, tmpdir, zero_stage, freeze_params): |
| 315 | # TODO: |
| 316 | # - need to test with multiple param groups |
| 317 | # force all params to be partitioned by forcing threshold=0 |
| 318 | config_dict = { |
| 319 | "train_micro_batch_size_per_gpu": 2, |
| 320 | "gradient_accumulation_steps": 2, |
| 321 | "steps_per_print": 1, |
| 322 | "zero_allow_untested_optimizer": 1, |
| 323 | "zero_optimization": { |
| 324 | "stage": zero_stage, |
| 325 | "stage3_param_persistence_threshold": 0, |
| 326 | }, |
| 327 | "optimizer": { |
| 328 | "type": "Adam", |
| 329 | "params": { |
| 330 | "lr": 1e-3 |
| 331 | } |
| 332 | }, |
| 333 | } |
| 334 | if get_accelerator().is_bf16_supported(): |
| 335 | config_dict["bf16"] = {"enabled": True} |
| 336 | elif get_accelerator().is_fp16_supported(): |
| 337 | config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} |
| 338 | |
class MyModel(torch.nn.Module):
    """Chain of square ``Linear`` layers whose forward pass yields a cross-entropy loss."""

    def __init__(self, hidden_dim, n_layers, freeze_params):
        """Create ``n_layers`` layers of shape ``Linear(hidden_dim, hidden_dim)``.

        When ``freeze_params`` is truthy, the weight and bias of the first
        layer get ``requires_grad = False``.
        """
        super().__init__()
        layers = [torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(n_layers)]
        self.ll = torch.nn.ModuleList(layers)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        if not freeze_params:
            return
        first_layer = self.ll[0]
        for frozen in (first_layer.weight, first_layer.bias):
            frozen.requires_grad = False

    def forward(self, x, y):
        """Pass ``x`` through every layer in order and return the loss against ``y``."""
        activations = x
        for layer in self.ll:
            activations = layer(activations)
        return self.cross_entropy_loss(activations, y)
| 354 | |
| 355 | hidden_dim = 3 |
| 356 | |
| 357 | world_size = dist.get_world_size() |
| 358 | n_layers = world_size * 2 |
| 359 | model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers, freeze_params=freeze_params) |
| 360 | |
| 361 | optim_groups = [ |
| 362 | { |
| 363 | "params": [l.weight for l in model.ll], |
| 364 | "weight_decay": 0.01, |
| 365 | }, |
| 366 | { |
| 367 | "params": [l.bias for l in model.ll], |
| 368 | "weight_decay": 0.0 |
| 369 | }, |
| 370 | ] |
| 371 | optim = torch.optim.SGD(optim_groups, lr=0.1) |
nothing calls this directly
no test coverage detected