(args)
| 38 | |
| 39 | |
| 40 | def train(args) -> None: |
| 41 | # ============================== |
| 42 | # Initialize Distributed Training |
| 43 | # ============================== |
| 44 | colossalai.launch_from_torch() |
| 45 | accelerator = get_accelerator() |
| 46 | coordinator = DistCoordinator() |
| 47 | |
| 48 | # ============================== |
| 49 | # Initialize Tensorboard and Save Config |
| 50 | # ============================== |
| 51 | if coordinator.is_master(): |
| 52 | os.makedirs(args.tensorboard_dir, exist_ok=True) |
| 53 | writer = SummaryWriter(args.tensorboard_dir) |
| 54 | |
| 55 | with open(args.config_file, "w") as f: |
| 56 | json.dump(args.__dict__, f, indent=4) |
| 57 | |
| 58 | # ============================== |
| 59 | # Initialize Booster |
| 60 | # ============================== |
| 61 | if args.plugin == "ddp": |
| 62 | plugin = TorchDDPPlugin(find_unused_parameters=True if args.use_grad_checkpoint is False else False) |
| 63 | elif args.plugin == "gemini": |
| 64 | plugin = GeminiPlugin( |
| 65 | precision=args.mixed_precision, |
| 66 | initial_scale=2**16, |
| 67 | max_norm=args.grad_clip, |
| 68 | enable_gradient_accumulation=(args.accumulation_steps > 1), |
| 69 | enable_fused_normalization=get_accelerator().is_available(), |
| 70 | enable_flash_attention=args.use_flash_attn, |
| 71 | ) |
| 72 | elif args.plugin == "gemini_auto": |
| 73 | plugin = GeminiPlugin( |
| 74 | precision=args.mixed_precision, |
| 75 | placement_policy="auto", |
| 76 | initial_scale=2**16, |
| 77 | max_norm=args.grad_clip, |
| 78 | enable_gradient_accumulation=(args.accumulation_steps > 1), |
| 79 | enable_fused_normalization=get_accelerator().is_available(), |
| 80 | enable_flash_attention=args.use_flash_attn, |
| 81 | ) |
| 82 | elif args.plugin == "zero2": |
| 83 | plugin = LowLevelZeroPlugin( |
| 84 | stage=2, |
| 85 | precision=args.mixed_precision, |
| 86 | initial_scale=2**16, |
| 87 | max_norm=args.grad_clip, |
| 88 | ) |
| 89 | elif args.plugin == "zero2_cpu": |
| 90 | plugin = LowLevelZeroPlugin( |
| 91 | stage=2, |
| 92 | precision=args.mixed_precision, |
| 93 | initial_scale=2**16, |
| 94 | cpu_offload=True, |
| 95 | max_norm=args.grad_clip, |
| 96 | ) |
| 97 | elif args.plugin == "3d": |
no test coverage detected
searching dependent graphs…