Repository: github.com/hpcaitech/ColossalAI

Function train

applications/Colossal-LLaMA/train.py, lines 40–433
Signature: train(args) -> None


# Imports this excerpt depends on. They sit near the top of train.py (not shown),
# so treat these as equivalent imports rather than a verbatim copy.
import json
import os

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from torch.utils.tensorboard import SummaryWriter


def train(args) -> None:
    # ==============================
    # Initialize Distributed Training
    # ==============================
    colossalai.launch_from_torch()
    accelerator = get_accelerator()
    coordinator = DistCoordinator()

    # ==============================
    # Initialize Tensorboard and Save Config
    # ==============================
    if coordinator.is_master():
        os.makedirs(args.tensorboard_dir, exist_ok=True)
        writer = SummaryWriter(args.tensorboard_dir)

        with open(args.config_file, "w") as f:
            json.dump(args.__dict__, f, indent=4)

    # ==============================
    # Initialize Booster
    # ==============================
    if args.plugin == "ddp":
        plugin = TorchDDPPlugin(find_unused_parameters=True if args.use_grad_checkpoint is False else False)
    elif args.plugin == "gemini":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            initial_scale=2**16,
            max_norm=args.grad_clip,
            enable_gradient_accumulation=(args.accumulation_steps > 1),
            enable_fused_normalization=get_accelerator().is_available(),
            enable_flash_attention=args.use_flash_attn,
        )
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            placement_policy="auto",
            initial_scale=2**16,
            max_norm=args.grad_clip,
            enable_gradient_accumulation=(args.accumulation_steps > 1),
            enable_fused_normalization=get_accelerator().is_available(),
            enable_flash_attention=args.use_flash_attn,
        )
    elif args.plugin == "zero2":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "zero2_cpu":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            cpu_offload=True,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "3d":
        ...  # excerpt ends here; the rest of train() (through line 433) is not shown
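The plugin branches above differ mainly in the keyword arguments they pass. As a minimal sketch (it does not import ColossalAI; the helper name `plugin_config`, its `accelerator_available` parameter, and the returned `(class_name, kwargs)` pair are all hypothetical), the selection for the five plugin names visible in this excerpt can be expressed as plain data. The `"3d"` branch is left out because its body is not shown. Note that for a boolean flag, `True if x is False else False` is the same as `not x`, which the sketch uses.

```python
from argparse import Namespace
from typing import Any, Dict, Tuple


def plugin_config(args: Namespace, accelerator_available: bool) -> Tuple[str, Dict[str, Any]]:
    """Hypothetical helper: return (plugin class name, constructor kwargs)
    mirroring the branches shown in the train() excerpt. It only builds
    plain data; nothing from ColossalAI is imported or instantiated."""
    if args.plugin == "ddp":
        # Look for unused parameters only when gradient checkpointing is off.
        return "TorchDDPPlugin", {"find_unused_parameters": not args.use_grad_checkpoint}

    if args.plugin in ("gemini", "gemini_auto"):
        kwargs: Dict[str, Any] = {
            "precision": args.mixed_precision,
            "initial_scale": 2**16,
            "max_norm": args.grad_clip,
            "enable_gradient_accumulation": args.accumulation_steps > 1,
            # Stands in for get_accelerator().is_available() (assumption).
            "enable_fused_normalization": accelerator_available,
            "enable_flash_attention": args.use_flash_attn,
        }
        if args.plugin == "gemini_auto":
            kwargs["placement_policy"] = "auto"
        return "GeminiPlugin", kwargs

    if args.plugin in ("zero2", "zero2_cpu"):
        kwargs = {
            "stage": 2,
            "precision": args.mixed_precision,
            "initial_scale": 2**16,
            "max_norm": args.grad_clip,
        }
        if args.plugin == "zero2_cpu":
            kwargs["cpu_offload"] = True  # offload optimizer state to CPU
        return "LowLevelZeroPlugin", kwargs

    # "3d" and any other value are outside the scope of this sketch.
    raise ValueError(f"plugin {args.plugin!r} is not covered by this sketch")


args = Namespace(plugin="zero2_cpu", mixed_precision="bf16", grad_clip=1.0,
                 accumulation_steps=8, use_grad_checkpoint=True, use_flash_attn=False)
name, kwargs = plugin_config(args, accelerator_available=True)
print(name, kwargs["cpu_offload"])  # LowLevelZeroPlugin True
```

Keeping the choice as data makes it easy to see that `gemini_auto` differs from `gemini` only by `placement_policy="auto"`, and `zero2_cpu` from `zero2` only by `cpu_offload=True`.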

Callers (1)

train.py (file)

Calls (15), including:

is_master (method)
print_on_master (method)
prepare_dataloader (method)
enable_lora (method)
boost (method)
load_model (method)
execute_pipeline (method)
step (method)
backward (method)
save_model (method)
get_accelerator (function)
DistCoordinator (class)
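The excerpt turns on gradient accumulation in the Gemini plugins whenever `args.accumulation_steps > 1`, and `backward` and `step` appear among the calls. A common pattern, assumed here rather than taken from the unshown loop, is to call backward on every micro-batch and step the optimizer once per `accumulation_steps` micro-batches. The pure-Python sketch below only computes that schedule; the helper name `optimizer_step_indices` is hypothetical.

```python
from typing import List


def optimizer_step_indices(num_micro_batches: int, accumulation_steps: int) -> List[int]:
    """Hypothetical helper: 0-based micro-batch indices after which an
    optimizer step would run when gradients are accumulated over
    `accumulation_steps` micro-batches. A trailing partial group also
    triggers a final step (a design assumption, not taken from train.py)."""
    if accumulation_steps < 1:
        raise ValueError("accumulation_steps must be >= 1")
    steps = []
    for i in range(num_micro_batches):
        # backward() would run on every micro-batch; step() only at a group boundary.
        is_boundary = (i + 1) % accumulation_steps == 0
        is_last = i == num_micro_batches - 1
        if is_boundary or is_last:
            steps.append(i)
    return steps


print(optimizer_step_indices(10, 4))  # [3, 7, 9]
```

With `accumulation_steps=1` every micro-batch gets its own step, which is why the plugins only enable accumulation when the value exceeds 1. Whether the real loop steps on a trailing partial group is not visible in this excerpt.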

Tested by: no test coverage detected
