hub / github.com/pytorch/examples / main

Function main

distributed/FSDP2/example.py:36–108
(args)


def main(args):
    _min_gpu_count = 2
    if not verify_min_gpu_count(min_gpus=_min_gpu_count):
        print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
        exit()
    rank = int(os.environ["LOCAL_RANK"])
    if torch.accelerator.is_available():
        device_type = torch.accelerator.current_accelerator()
        device = torch.device(f"{device_type}:{rank}")
        torch.accelerator.device_index(rank)
        print(f"Running on rank {rank} on device {device}")
    else:
        device = torch.device("cpu")
        print(f"Running on device {device}")

    backend = torch.distributed.get_default_backend_for_device(device)
    torch.distributed.init_process_group(backend=backend, device_id=device)

    torch.manual_seed(0)
    vocab_size = 1024
    batch_size = 32
    seq_len = 64
    model_args = ModelArgs(
        n_layers=10,
        n_heads=4,
        vocab_size=vocab_size,
        max_seq_len=seq_len,
        dropout_p=0,
    )
    with torch.device("meta"):
        model = Transformer(model_args)
    fsdp_kwargs = {}
    if args.mixed_precision:
        fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
        )
    for layer in model.layers:
        fully_shard(layer, **fsdp_kwargs)
    fully_shard(model, **fsdp_kwargs)

    inspect_model(model)

    if args.explicit_prefetching:
        set_modules_to_forward_prefetch(model, num_to_forward_prefetch=2)
        set_modules_to_backward_prefetch(model, num_to_backward_prefetch=2)

    checkpointer = Checkpointer("checkpoints", dcp_api=args.dcp_api)
    if checkpointer.last_training_time is None:
        model.to_empty(device=device)
        model.reset_parameters()
    else:
        checkpointer.load_model(model)

    if args.mixed_precision:
        inspect_mixed_precision(model)

    optim = torch.optim.Adam(model.parameters(), lr=1e-2)
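
The listing builds the model under `torch.device("meta")`, shards it, and only then calls `to_empty` followed by `reset_parameters` when no checkpoint exists. The sketch below isolates that deferred-initialization pattern on a single process. It assumes a plain `nn.Linear` as a stand-in for the example's `Transformer` and uses the CPU as the target device so that it runs without an accelerator or a process group; no sharding is performed here.

```python
import torch
import torch.nn as nn

# Construct on the meta device: parameters have shape and dtype but no
# storage, so construction cost does not depend on model size.
with torch.device("meta"):
    model = nn.Linear(8, 4)  # stand-in for Transformer(model_args)

print(model.weight.is_meta)  # parameters are still meta tensors here

# Allocate real but uninitialized storage on the target device.
# main() passes the per-rank device; CPU is an assumption of this sketch.
model.to_empty(device="cpu")

# to_empty() does not initialize values, so they are set explicitly.
# nn.Linear ships reset_parameters(); the Calls list suggests the
# example's model defines a method with the same name.
torch.manual_seed(0)
model.reset_parameters()

print(model.weight.device, tuple(model.weight.shape))
```

In `main`, the `fully_shard` calls sit between the meta construction and `to_empty`, so materialization happens on parameters that have already been sharded rather than on a full unsharded copy.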

Callers 1

example.py · File · 0.70
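
The listing reads three attributes from `args`: `mixed_precision`, `explicit_prefetching`, and `dcp_api`. The module-level caller is not shown on this page, so the following is only a sketch of how such a namespace could be built with `argparse`. The flag spellings and the `build_parser` helper are assumptions and are not taken from example.py.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser producing the attributes that main(args) reads."""
    parser = argparse.ArgumentParser(description="FSDP2 example (sketch)")
    # argparse converts "--mixed-precision" to args.mixed_precision, etc.
    parser.add_argument("--explicit-prefetching", action="store_true")
    parser.add_argument("--mixed-precision", action="store_true")
    parser.add_argument("--dcp-api", action="store_true")
    return parser


if __name__ == "__main__":
    # An explicit argument list keeps the sketch independent of sys.argv.
    args = build_parser().parse_args(["--mixed-precision"])
    print(args.mixed_precision, args.explicit_prefetching, args.dcp_api)
```

Because `main` reads `os.environ["LOCAL_RANK"]` unconditionally, the script is meant to be started by a launcher that sets that variable per process, such as `torchrun`. Assuming the flag names above, a two-process run might look like `torchrun --nproc_per_node 2 example.py --mixed-precision`.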

Calls 12

reset_parameters · Method · 0.95
load_model · Method · 0.95
load_optim · Method · 0.95
save · Method · 0.95
ModelArgs · Class · 0.90
Transformer · Class · 0.90
inspect_model · Function · 0.90
Checkpointer · Class · 0.90
inspect_mixed_precision · Function · 0.90
verify_min_gpu_count · Function · 0.70

Tested by

no test coverage detected