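Core function to wrap the essential training components with our functionality, based on the config loaded into gpc.config. Args: model (:class:`torch.nn.Module` or Callable): Your model instance or a function to build the model. optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`): Your optimizer instance. criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance. train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training. test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing. lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance. verbose (bool, optional): Whether to print logs. Returns: A tuple ``(engine, train_dataloader, test_dataloader, lr_scheduler)`` in which only ``engine`` is guaranteed not to be None.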
(
model: nn.Module,
optimizer: Optimizer,
criterion: Optional[_Loss] = None,
train_dataloader: Optional[Iterable] = None,
test_dataloader: Optional[Iterable] = None,
lr_scheduler: Optional[_LRScheduler] = None,
ophooks: Optional[List[BaseOpHook]] = None,
verbose: bool = True,
)
| 240 | |
| 241 | |
| 242 | def initialize( |
| 243 | model: nn.Module, |
| 244 | optimizer: Optimizer, |
| 245 | criterion: Optional[_Loss] = None, |
| 246 | train_dataloader: Optional[Iterable] = None, |
| 247 | test_dataloader: Optional[Iterable] = None, |
| 248 | lr_scheduler: Optional[_LRScheduler] = None, |
| 249 | ophooks: Optional[List[BaseOpHook]] = None, |
| 250 | verbose: bool = True, |
| 251 | ) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]: |
| 252 | """Core function to wrap the essential training components with our functionality based on the config which is |
| 253 | loaded into gpc.config. |
| 254 | |
| 255 | Args: |
| 256 | model (:class:`torch.nn.Module` or Callable): Your model instance or a function to build the model. |
| 257 | optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`): |
| 258 | Your optimizer instance. |
| 259 | criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance. |
| 260 | train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training. |
| 261 | test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing. |
| 262 | lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional. |
| 263 | verbose (bool, optional): Whether to print logs. |
| 264 | |
| 265 | Returns: |
| 266 | Tuple (engine, train_dataloader, test_dataloader, lr_scheduler): |
| 267 | A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)`` |
| 268 | where only ``engine`` could not be None. |
| 269 | """ |
| 270 | # get logger |
| 271 | logger = get_dist_logger() |
| 272 | gpc.verbose = verbose |
| 273 | |
| 274 | # get config from gpc |
| 275 | config = gpc.config |
| 276 | |
| 277 | # print config |
| 278 | if verbose: |
| 279 | logger.info( |
| 280 | f"\n========== Your Config ========\n" |
| 281 | f"{pprint.pformat(gpc.config)}\n" |
| 282 | f"================================\n", |
| 283 | ranks=[0], |
| 284 | ) |
| 285 | |
| 286 | # cudnn |
| 287 | cudnn_benchmark = config.get("cudnn_benchmark", False) |
| 288 | cudnn_deterministic = config.get("cudnn_deterministic", False) |
| 289 | torch.backends.cudnn.benchmark = cudnn_benchmark |
| 290 | torch.backends.cudnn.deterministic = cudnn_deterministic |
| 291 | if verbose: |
| 292 | logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0]) |
| 293 | |
| 294 | # zero |
| 295 | use_zero = hasattr(gpc.config, "zero") |
| 296 | if use_zero: |
| 297 | zero_cfg = gpc.config.get("zero", None) |
| 298 | if zero_cfg is not None: |
| 299 | cfg_ = zero_cfg.copy() |
nothing calls this directly
no test coverage detected
searching dependent graphs…