
Function initialize_trainer

internlm/initialize/initialize_trainer.py, lines 32–137 (github.com/InternLM/InternLM)

Core function that wraps the essential training components (model, optimizer, criterion, dataloaders, and schedulers) with InternLM's training functionality, driven by the config loaded into gpc.config. model may be a torch.nn.Module instance or a callable that builds the model; optimizer must be a BaseOptimizer instance.
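Because model may be either a module instance or a factory, the order of the two isinstance checks in the source listing matters: torch.nn.Module instances are themselves callable, so a Callable check placed first would try to call an existing module as if it were a factory. The sketch below reproduces that dispatch using a hypothetical FakeModule stand-in (so it runs without PyTorch) and a hypothetical resolve_model helper. It uses the builtin callable() in place of isinstance(model, Callable). Unlike the listing, which silently leaves an unsupported argument untouched, it raises TypeError; that is an assumption of this sketch, not InternLM's behaviour.

```python
from typing import Callable, Union


class FakeModule:
    """Hypothetical stand-in for torch.nn.Module. Like a real module, an
    instance is callable and has a .to() that moves it and returns self."""

    def __init__(self) -> None:
        self.device = "cpu"

    def __call__(self, x):
        return x

    def to(self, device: str) -> "FakeModule":
        self.device = device
        return self


def resolve_model(
    model: Union[FakeModule, Callable[[], FakeModule]], device: str
) -> FakeModule:
    # Module check comes first: a module instance is also callable, so testing
    # callable() first would misclassify it as a factory.
    if isinstance(model, FakeModule):
        return model.to(device)
    # Otherwise treat the argument as a zero-argument factory (a function,
    # a lambda, or the class itself) and move the freshly built model.
    if callable(model):
        return model().to(device)
    # Assumption of this sketch: fail loudly instead of falling through.
    raise TypeError(f"expected a module or a model factory, got {type(model).__name__}")
```

For example, resolve_model(FakeModule, "cuda:0") builds a new instance through the class, while resolve_model(existing, "cuda:0") moves and returns the object it was given.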

(
    model: nn.Module,
    optimizer: Optimizer,
    criterion: Optional[_Loss] = None,
    train_dataloader: Optional[Iterable] = None,
    test_dataloader: Optional[Iterable] = None,
    lr_scheduler: Optional[_LRScheduler] = None,
    beta2_scheduler: Optional[Beta2Scheduler] = None,
    scheduler_hooks: Optional[List[SchedulerHook]] = None,
) -> Tuple[Trainer, DataLoader, DataLoader, _LRScheduler]

Source

def initialize_trainer(
    model: nn.Module,
    optimizer: Optimizer,
    criterion: Optional[_Loss] = None,
    train_dataloader: Optional[Iterable] = None,
    test_dataloader: Optional[Iterable] = None,
    lr_scheduler: Optional[_LRScheduler] = None,
    beta2_scheduler: Optional[Beta2Scheduler] = None,
    scheduler_hooks: Optional[List[SchedulerHook]] = None,
) -> Tuple[Trainer, DataLoader, DataLoader, _LRScheduler]:
    """Core function to wrap the essential training components with our functionality based on the config which is
    loaded into gpc.config.

    Args:
        model (:class:`torch.nn.Module` or Callable): Your model instance or a function to build the model.
        optimizer (:class:`BaseOptimizer`): Your optimizer instance.
        criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
        train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
        test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
        lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance.

    Returns:
        Tuple (trainer, train_dataloader, test_dataloader, lr_scheduler):
            A tuple of ``(trainer, train_dataloader, test_dataloader, lr_scheduler)``
            where only ``trainer`` is guaranteed not to be None.
    """

    if isinstance(model, nn.Module):
        # move the model to the current device
        model.to(get_current_device())
    elif isinstance(model, Callable):
        model = model().to(get_current_device())

    # clip grad norm
    clip_grad_norm = gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0)

    assert isinstance(optimizer, BaseOptimizer), "optimizer must be instance of BaseOptimizer"

    # gradient handlers: only PipelineSharedModuleGradientHandler is supported for now
    if gpc.is_using_pp():
        gpc.config.gradient_handler = [dict(type="PipelineSharedModuleGradientHandler")]
    gradient_handler_cfg = gpc.config.get("gradient_handler", [])
    gradient_handlers = []
    assert isinstance(gradient_handler_cfg, list), f"gradient_handler must be list but got {type(gradient_handler_cfg)}"
    for config in gradient_handler_cfg:
        if isinstance(config, dict) and config.get("type") == "PipelineSharedModuleGradientHandler":
            handler = PipelineSharedModuleGradientHandler(model=model, optimizer=optimizer)
            gradient_handlers.append(handler)

    # initialize scheduler for trainer
    scheduler = None
    if gpc.config.model.use_flash_attn:
        data_fn = None
    else:
        data_fn = unpack_data
    if gpc.is_using_pp():
        gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
        tensor_shape = get_tensor_shape()
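The config-driven steps in the listing (gradient-handler construction, the clip_grad_norm default, the data_fn choice and the micro-batch count under pipeline parallelism) can be modelled without InternLM. This is a minimal sketch under stated assumptions: gpc.config is replaced by a plain dict, gpc.is_using_pp() by a using_pp flag, PipelineSharedModuleGradientHandler by a hypothetical stub class, and unpack_data by whatever function the caller passes as unpack_fn. build_gradient_handlers and derive_settings are hypothetical helper names. The sketch raises TypeError where the listing uses assert, so the check also survives python -O.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple


class SharedModuleGradHandlerStub:
    """Hypothetical stand-in for PipelineSharedModuleGradientHandler."""

    def __init__(self, model: Any, optimizer: Any) -> None:
        self.model = model
        self.optimizer = optimizer


def build_gradient_handlers(
    config: Dict[str, Any], using_pp: bool, model: Any, optimizer: Any
) -> List[SharedModuleGradHandlerStub]:
    # With pipeline parallelism enabled, any user-supplied list is replaced
    # by a single shared-module handler entry (mutating the config in place).
    if using_pp:
        config["gradient_handler"] = [dict(type="PipelineSharedModuleGradientHandler")]
    handler_cfg = config.get("gradient_handler", [])
    if not isinstance(handler_cfg, list):
        raise TypeError(f"gradient_handler must be list but got {type(handler_cfg)}")
    handlers = []
    for entry in handler_cfg:
        # Only the shared-module handler is recognised; other entries are skipped.
        if isinstance(entry, dict) and entry.get("type") == "PipelineSharedModuleGradientHandler":
            handlers.append(SharedModuleGradHandlerStub(model=model, optimizer=optimizer))
    return handlers


def derive_settings(
    config: Dict[str, Any], using_pp: bool, unpack_fn: Callable[..., Any]
) -> Tuple[float, Optional[Callable[..., Any]]]:
    # Gradient clipping threshold, defaulting to 0.0 when the key is absent.
    clip_grad_norm = config["hybrid_zero_optimizer"].get("clip_grad_norm", 0.0)
    # No data_fn when flash attention is enabled; otherwise use the unpack function.
    data_fn = None if config["model"]["use_flash_attn"] else unpack_fn
    if using_pp:
        # The pipeline path records the micro-batch count in the config.
        config["NUM_MICRO_BATCHES"] = config["data"]["micro_num"]
    return clip_grad_norm, data_fn
```

As in the listing, entries whose type is not PipelineSharedModuleGradientHandler are dropped without a warning, and enabling pipeline parallelism discards any user-supplied gradient_handler list.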

Callers

nothing calls this directly

Calls (10)

get_current_device · Function · 0.90
get_tensor_shape · Function · 0.90
PipelineScheduler · Class · 0.90
Engine · Class · 0.90
Trainer · Class · 0.90
is_using_pp · Method · 0.80
is_initialized · Method · 0.80

Tested by

no test coverage detected