MCPcopy
hub / github.com/deepspeedai/DeepSpeed / test_2_param_groups

Method test_2_param_groups

tests/unit/v1/zero/test_zero.py:314–419  ·  view source on GitHub ↗
(self, tmpdir, zero_stage, freeze_params)

Source from the content-addressed store, hash-verified

312 model.destroy()
313
314 def test_2_param_groups(self, tmpdir, zero_stage, freeze_params):
315 # TODO:
316 # - need to test with multiple param groups
317 # force all params to be partitioned by forcing threshold=0
318 config_dict = {
319 "train_micro_batch_size_per_gpu": 2,
320 "gradient_accumulation_steps": 2,
321 "steps_per_print": 1,
322 "zero_allow_untested_optimizer": 1,
323 "zero_optimization": {
324 "stage": zero_stage,
325 "stage3_param_persistence_threshold": 0,
326 },
327 "optimizer": {
328 "type": "Adam",
329 "params": {
330 "lr": 1e-3
331 }
332 },
333 }
334 if get_accelerator().is_bf16_supported():
335 config_dict["bf16"] = {"enabled": True}
336 elif get_accelerator().is_fp16_supported():
337 config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
338
class MyModel(torch.nn.Module):
    """Chain of square ``Linear`` layers whose output is scored with cross-entropy.

    Args:
        hidden_dim: Input and output width of every linear layer.
        n_layers: Number of linear layers stacked in ``self.ll``.
        freeze_params: When truthy, the first layer's weight and bias have
            ``requires_grad`` switched off.
    """

    def __init__(self, hidden_dim, n_layers, freeze_params):
        super().__init__()
        # Build the layers first, then register them in one go.
        layers = [torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(n_layers)]
        self.ll = torch.nn.ModuleList(layers)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        if freeze_params:
            first_layer = self.ll[0]
            for frozen in (first_layer.weight, first_layer.bias):
                frozen.requires_grad = False

    def forward(self, x, y):
        """Feed ``x`` through each layer in order and return the loss against ``y``."""
        activations = x
        for layer in self.ll:
            activations = layer(activations)
        loss = self.cross_entropy_loss(activations, y)
        return loss
354
355 hidden_dim = 3
356
357 world_size = dist.get_world_size()
358 n_layers = world_size * 2
359 model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers, freeze_params=freeze_params)
360
361 optim_groups = [
362 {
363 "params": [l.weight for l in model.ll],
364 "weight_decay": 0.01,
365 },
366 {
367 "params": [l.bias for l in model.ll],
368 "weight_decay": 0.0
369 },
370 ]
371 optim = torch.optim.SGD(optim_groups, lr=0.1)

Callers

nothing calls this directly

Calls 15

get_accelerator — Function · 0.90
random_dataloader — Function · 0.90
get_world_size — Method · 0.80
initialize — Method · 0.80
save_checkpoint — Method · 0.80
MyModel — Class · 0.70
is_bf16_supported — Method · 0.45
is_fp16_supported — Method · 0.45
parameters — Method · 0.45
empty_partition_cache — Method · 0.45
backward — Method · 0.45

Tested by

no test coverage detected