github.com/hpcaitech/ColossalAI

Function initialize

colossalai/legacy/initialize.py, lines 242–490

Core function that wraps the essential training components with ColossalAI's functionality, based on the config loaded into gpc.config. The arguments are documented in the docstring of the source listing below.
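Because every option initialize reads comes from gpc.config, it helps to see what such a config looks like. This is a minimal sketch, assuming the legacy ColossalAI convention of a Python config file whose top-level variables become gpc.config entries. It sets only the two cuDNN keys visible in the listing; the file name and values are hypothetical.

```python
# config.py: a hypothetical legacy ColossalAI config file (values are illustrative).
# initialize() reads these keys with gpc.config.get(...); both default to False
# when they are omitted.
cudnn_benchmark = True
cudnn_deterministic = False

# A top-level `zero` entry is what switches on the ZeRO branch, because
# initialize() checks hasattr(gpc.config, "zero"). It is omitted here because
# the fields it expects are not shown on this page.
```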

(
    model: nn.Module,
    optimizer: Optimizer,
    criterion: Optional[_Loss] = None,
    train_dataloader: Optional[Iterable] = None,
    test_dataloader: Optional[Iterable] = None,
    lr_scheduler: Optional[_LRScheduler] = None,
    ophooks: Optional[List[BaseOpHook]] = None,
    verbose: bool = True,
) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]


def initialize(
    model: nn.Module,
    optimizer: Optimizer,
    criterion: Optional[_Loss] = None,
    train_dataloader: Optional[Iterable] = None,
    test_dataloader: Optional[Iterable] = None,
    lr_scheduler: Optional[_LRScheduler] = None,
    ophooks: Optional[List[BaseOpHook]] = None,
    verbose: bool = True,
) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
    """Core function to wrap the essential training components with our functionality based on the config which is
    loaded into gpc.config.

    Args:
        model (:class:`torch.nn.Module` or Callable): Your model instance or a function to build the model.
        optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`):
            Your optimizer instance.
        criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
        train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
        test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
        lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance.
        verbose (bool, optional): Whether to print logs.

    Returns:
        Tuple (engine, train_dataloader, test_dataloader, lr_scheduler):
            A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)``
            where only ``engine`` is guaranteed not to be None.
    """
    # get logger
    logger = get_dist_logger()
    gpc.verbose = verbose

    # get config from gpc
    config = gpc.config

    # print config
    if verbose:
        logger.info(
            f"\n========== Your Config ========\n"
            f"{pprint.pformat(gpc.config)}\n"
            f"================================\n",
            ranks=[0],
        )

    # cudnn
    cudnn_benchmark = config.get("cudnn_benchmark", False)
    cudnn_deterministic = config.get("cudnn_deterministic", False)
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.backends.cudnn.deterministic = cudnn_deterministic
    if verbose:
        logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])

    # zero
    use_zero = hasattr(gpc.config, "zero")
    if use_zero:
        zero_cfg = gpc.config.get("zero", None)
        if zero_cfg is not None:
            cfg_ = zero_cfg.copy()
            # ... (remainder of the function not shown)
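The first three steps of the listing (printing the config, resolving the cuDNN flags, detecting ZeRO) reduce to a few dictionary operations. The sketch below mirrors them using only the standard library, under stated assumptions. AttrConfig is a hypothetical stand-in for gpc.config, assuming it behaves like a dict whose keys can also be read as attributes, so that hasattr(config, "zero") tests for key presence. format_config_banner, resolve_cudnn_flags and extract_zero_config are hypothetical helper names, not ColossalAI APIs. The sketch returns the cuDNN values instead of assigning torch.backends.cudnn.benchmark and .deterministic, so it runs without PyTorch.

```python
import pprint
from typing import Any, Dict, Mapping, Optional, Tuple


class AttrConfig(dict):
    """Hypothetical stand-in for gpc.config: a dict whose keys are also attributes."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            # AttributeError (not KeyError) is what makes hasattr() return False.
            raise AttributeError(name) from None


def format_config_banner(config: Mapping[str, Any]) -> str:
    # Same layout as the message logged on rank 0 when verbose=True.
    return (
        f"\n========== Your Config ========\n"
        f"{pprint.pformat(config)}\n"
        f"================================\n"
    )


def resolve_cudnn_flags(config: Mapping[str, Any]) -> Tuple[bool, bool]:
    # Both flags default to False when the key is missing from the config.
    return (
        config.get("cudnn_benchmark", False),
        config.get("cudnn_deterministic", False),
    )


def extract_zero_config(config: AttrConfig) -> Tuple[bool, Optional[Dict[str, Any]]]:
    # hasattr() decides whether the ZeRO branch is taken at all.
    use_zero = hasattr(config, "zero")
    zero_copy = None
    if use_zero:
        zero_cfg = config.get("zero", None)
        if zero_cfg is not None:
            # Shallow copy, as in cfg_ = zero_cfg.copy().
            zero_copy = zero_cfg.copy()
    return use_zero, zero_copy


if __name__ == "__main__":
    # Illustrative keys only; "some_option" is not a real ZeRO setting.
    cfg = AttrConfig(cudnn_benchmark=True, zero={"some_option": 1})
    print(format_config_banner(cfg))
    print(resolve_cudnn_flags(cfg))
    print(extract_zero_config(cfg))
```

Running it prints the banner, then (True, False), then (True, {'some_option': 1}). Note that a config with zero = None still reports use_zero as True but yields no copy; what the real function does in that case is not visible in the listing. The shallow copy presumably keeps later edits to the ZeRO settings from leaking back into the shared config, though the code that consumes cfg_ is not shown.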

Callers

nothing calls this directly

Calls (15; 11 listed)

get_dist_logger · Function · 0.90
convert_to_zero_v2 · Function · 0.90
get_accelerator · Function · 0.90
is_using_sequence · Function · 0.90
sync_model_param · Function · 0.90
is_using_ddp · Function · 0.90
ConfigException · Class · 0.90
is_using_pp · Function · 0.90
convert_to_amp · Function · 0.90
get_tensor_shape · Function · 0.90
PipelineSchedule · Class · 0.90

Tested by

no test coverage detected
