(args)
| 34 | |
| 35 | |
| 36 | def main(args): |
| 37 | _min_gpu_count = 2 |
| 38 | if not verify_min_gpu_count(min_gpus=_min_gpu_count): |
| 39 | print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.") |
| 40 | exit() |
| 41 | rank = int(os.environ["LOCAL_RANK"]) |
| 42 | if torch.accelerator.is_available(): |
| 43 | device_type = torch.accelerator.current_accelerator() |
| 44 | device = torch.device(f"{device_type}:{rank}") |
| 45 | torch.accelerator.device_index(rank) |
| 46 | print(f"Running on rank {rank} on device {device}") |
| 47 | else: |
| 48 | device = torch.device("cpu") |
| 49 | print(f"Running on device {device}") |
| 50 | |
| 51 | backend = torch.distributed.get_default_backend_for_device(device) |
| 52 | torch.distributed.init_process_group(backend=backend, device_id=device) |
| 53 | |
| 54 | torch.manual_seed(0) |
| 55 | vocab_size = 1024 |
| 56 | batch_size = 32 |
| 57 | seq_len = 64 |
| 58 | model_args = ModelArgs( |
| 59 | n_layers=10, |
| 60 | n_heads=4, |
| 61 | vocab_size=vocab_size, |
| 62 | max_seq_len=seq_len, |
| 63 | dropout_p=0, |
| 64 | ) |
| 65 | with torch.device("meta"): |
| 66 | model = Transformer(model_args) |
| 67 | fsdp_kwargs = {} |
| 68 | if args.mixed_precision: |
| 69 | fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy( |
| 70 | param_dtype=torch.bfloat16, |
| 71 | reduce_dtype=torch.float32, |
| 72 | ) |
| 73 | for layer in model.layers: |
| 74 | fully_shard(layer, **fsdp_kwargs) |
| 75 | fully_shard(model, **fsdp_kwargs) |
| 76 | |
| 77 | inspect_model(model) |
| 78 | |
| 79 | if args.explicit_prefetching: |
| 80 | set_modules_to_forward_prefetch(model, num_to_forward_prefetch=2) |
| 81 | set_modules_to_backward_prefetch(model, num_to_backward_prefetch=2) |
| 82 | |
| 83 | checkpointer = Checkpointer("checkpoints", dcp_api=args.dcp_api) |
| 84 | if checkpointer.last_training_time is None: |
| 85 | model.to_empty(device=device) |
| 86 | model.reset_parameters() |
| 87 | else: |
| 88 | checkpointer.load_model(model) |
| 89 | |
| 90 | if args.mixed_precision: |
| 91 | inspect_mixed_precision(model) |
| 92 | |
| 93 | optim = torch.optim.Adam(model.parameters(), lr=1e-2) |
no test coverage detected