Initialize the DeepSpeed Engine.
initialize(args=None,
model: torch.nn.Module = None,
optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None,
model_parameters: Optional[torch.nn.Module] = None,
training_data: Optional[torch.utils.data.Dataset] = None,
lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None,
distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT,
mpu=None,
dist_init_required: Optional[bool] = None,
collate_fn=None,
config=None,
mesh_param=None,
config_params=None)
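As a usage sketch, the signature above can be exercised by passing the configuration through `config` instead of `args`. The config keys (`train_batch_size`, `optimizer`) and the helper names `ds_config`, `write_config`, and `build_engine` are assumptions chosen for illustration, not taken from this page. `build_engine` imports `deepspeed` lazily, so the snippet loads even where DeepSpeed is not installed.

```python
import json
import os
import tempfile

# Illustrative config; the keys and values are assumptions, not from this page.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}


def write_config(config, directory):
    """Write the config dict to a JSON file and return its path."""
    path = os.path.join(directory, "ds_config.json")
    with open(path, "w") as f:
        json.dump(config, f)
    return path


def build_engine(model, config):
    """Call initialize() with `config` given either as a dict or as a path."""
    import deepspeed  # imported lazily; requires DeepSpeed to be installed

    # initialize() returns a 4-tuple, as documented in its Returns section.
    engine, optimizer, training_dataloader, lr_scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=config,
    )
    return engine, optimizer, training_dataloader, lr_scheduler


# Check that the dict form and the file form describe the same configuration.
with tempfile.TemporaryDirectory() as tmp:
    config_path = write_config(ds_config, tmp)
    with open(config_path) as f:
        round_trip = json.load(f)
```

Either `ds_config` or a path such as `config_path` could then be handed to `build_engine` along with a `torch.nn.Module`, since `config` is documented as accepting a path or a dictionary.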

def initialize(args=None,
               model: torch.nn.Module = None,
               optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None,
               model_parameters: Optional[torch.nn.Module] = None,
               training_data: Optional[torch.utils.data.Dataset] = None,
               lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None,
               distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT,
               mpu=None,
               dist_init_required: Optional[bool] = None,
               collate_fn=None,
               config=None,
               mesh_param=None,
               config_params=None):
    """Initialize the DeepSpeed Engine.

    Arguments:
        args: an object containing local_rank and deepspeed_config fields.
            This is optional if `config` is passed.

        model: Required: nn.module class before apply any wrappers

        optimizer: Optional: a user defined Optimizer or Callable that returns an Optimizer object.
            This overrides any optimizer definition in the DeepSpeed json config.

        model_parameters: Optional: An iterable of torch.Tensors or dicts.
            Specifies what Tensors should be optimized.

        training_data: Optional: Dataset of type torch.utils.data.Dataset

        lr_scheduler: Optional: Learning Rate Scheduler Object or a Callable that takes an Optimizer and returns a Scheduler object.
            The scheduler object should define a get_lr(), step(), state_dict(), and load_state_dict() methods

        distributed_port: Optional: Master node (rank 0)'s free port that needs to be used for communication during distributed training

        mpu: Optional: A model parallelism unit object that implements
            get_{model,data}_parallel_{rank,group,world_size}()

        dist_init_required: Optional: None will auto-initialize torch distributed if needed,
            otherwise the user can force it to be initialized or not via boolean.

        collate_fn: Optional: Merges a list of samples to form a
            mini-batch of Tensor(s).  Used when using batched loading from a
            map-style dataset.

        config: Optional: Instead of requiring args.deepspeed_config you can pass your deepspeed config
            as an argument instead, as a path or a dictionary.

        config_params: Optional: Same as `config`, kept for backwards compatibility.

    Returns:
        A tuple of ``engine``, ``optimizer``, ``training_dataloader``, ``lr_scheduler``

        * ``engine``: DeepSpeed runtime engine which wraps the client model for distributed training.

        * ``optimizer``: Wrapped optimizer if a user defined ``optimizer`` is supplied, or if
          optimizer is specified in json config else ``None``.

        * ``training_dataloader``: DeepSpeed dataloader if ``training_data`` was supplied,
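The docstring says an `lr_scheduler` may be an object defining `get_lr()`, `step()`, `state_dict()`, and `load_state_dict()`, or a callable that takes an optimizer and returns such an object. A minimal sketch of both forms follows. `LinearWarmup`, `StubOptimizer`, and `make_scheduler` are hypothetical names, and the stub optimizer only mimics the `param_groups` list of a PyTorch optimizer so that the snippet runs without PyTorch.

```python
class LinearWarmup:
    """Hypothetical scheduler exposing the four methods the docstring lists."""

    def __init__(self, optimizer, base_lr, warmup_steps):
        self.optimizer = optimizer  # any object with a `param_groups` list of dicts
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.last_step = 0

    def get_lr(self):
        # Linear ramp from 0 to base_lr over warmup_steps, constant afterwards.
        scale = min(1.0, self.last_step / self.warmup_steps)
        return [self.base_lr * scale for _ in self.optimizer.param_groups]

    def step(self):
        # Advance one step and write the new rate into every parameter group.
        self.last_step += 1
        for group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            group["lr"] = lr

    def state_dict(self):
        return {"last_step": self.last_step}

    def load_state_dict(self, state):
        self.last_step = state["last_step"]


class StubOptimizer:
    """Stand-in for a real optimizer (assumption: only `param_groups` is needed)."""

    def __init__(self):
        self.param_groups = [{"lr": 0.0}]


def make_scheduler(optimizer):
    """Callable form: takes an optimizer and returns a scheduler object."""
    return LinearWarmup(optimizer, base_lr=0.1, warmup_steps=4)


opt = StubOptimizer()
sched = make_scheduler(opt)
sched.step()
sched.step()  # halfway through a 4-step warmup
```

Under these assumptions, either `sched` or `make_scheduler` itself would be a candidate value for the `lr_scheduler` argument. The callable form defers construction until an optimizer exists.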
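The `mpu` entry in the docstring uses the shorthand `get_{model,data}_parallel_{rank,group,world_size}()`, which stands for six method names. The sketch below expands that shorthand and checks an object against it. `SingleProcessMPU` and `missing_methods` are hypothetical helpers, and returning `None` for a group is a placeholder assumption; a real implementation would return a process group.

```python
from itertools import product

# Expand get_{model,data}_parallel_{rank,group,world_size} into six names.
MPU_METHODS = [
    f"get_{kind}_parallel_{what}"
    for kind, what in product(("model", "data"), ("rank", "group", "world_size"))
]


class SingleProcessMPU:
    """Hypothetical stub for one process: every rank is 0, every world size is 1."""

    def get_model_parallel_rank(self):
        return 0

    def get_model_parallel_group(self):
        return None  # placeholder; a real mpu returns a process group

    def get_model_parallel_world_size(self):
        return 1

    def get_data_parallel_rank(self):
        return 0

    def get_data_parallel_group(self):
        return None  # placeholder; a real mpu returns a process group

    def get_data_parallel_world_size(self):
        return 1


def missing_methods(mpu):
    """Return the names from MPU_METHODS that `mpu` does not implement."""
    return [name for name in MPU_METHODS if not callable(getattr(mpu, name, None))]
```

A check like `missing_methods(my_mpu)` before calling `initialize` is one way to catch an incomplete object early. Whether the library performs a similar check is not stated here.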