hub / github.com/deepspeedai/DeepSpeed / initialize

Function initialize

deepspeed/__init__.py:80–252

Initialize the DeepSpeed Engine.

initialize(args=None,
           model: torch.nn.Module = None,
           optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None,
           model_parameters: Optional[torch.nn.Module] = None,
           training_data: Optional[torch.utils.data.Dataset] = None,
           lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None,
           distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT,
           mpu=None,
           dist_init_required: Optional[bool] = None,
           collate_fn=None,
           config=None,
           mesh_param=None,
           config_params=None)
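
A minimal sketch of a typical call, assuming the config is passed through the `config` argument rather than `args.deepspeed_config`. The config keys are standard DeepSpeed config entries, but the values are only illustrative, and the helpers `write_config` and `build_engine` are hypothetical names introduced for this example. The DeepSpeed import is deferred so the module loads even where DeepSpeed is not installed.

```python
import json
import tempfile
from pathlib import Path

# Illustrative values only; tune them for a real training job.
ds_config = {
    "train_micro_batch_size_per_gpu": 8,
    "gradient_accumulation_steps": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
}


def write_config(config: dict, directory: str) -> str:
    """Hypothetical helper: write the config as JSON and return the file path."""
    path = Path(directory) / "ds_config.json"
    path.write_text(json.dumps(config, indent=2))
    return str(path)


def build_engine(model, config):
    """Hypothetical helper: wrap `model` with DeepSpeed.

    `config` may be a dict or a path to a JSON file, as the docstring of
    `initialize` describes. Requires DeepSpeed and PyTorch at call time.
    """
    import deepspeed  # deferred import, see the lead-in

    engine, optimizer, training_dataloader, lr_scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=config,
    )
    return engine, optimizer, training_dataloader, lr_scheduler
```

Calling `build_engine(model, ds_config)` and `build_engine(model, write_config(ds_config, some_dir))` should be equivalent, since `config` accepts either form. Because the config above defines an optimizer and no `optimizer` argument is passed, the optimizer is expected to come from the config.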

Source from the content-addressed store, hash-verified

def initialize(args=None,
               model: torch.nn.Module = None,
               optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None,
               model_parameters: Optional[torch.nn.Module] = None,
               training_data: Optional[torch.utils.data.Dataset] = None,
               lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None,
               distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT,
               mpu=None,
               dist_init_required: Optional[bool] = None,
               collate_fn=None,
               config=None,
               mesh_param=None,
               config_params=None):
    """Initialize the DeepSpeed Engine.

    Arguments:
        args: an object containing local_rank and deepspeed_config fields.
            This is optional if `config` is passed.

        model: Required: nn.module class before apply any wrappers

        optimizer: Optional: a user defined Optimizer or Callable that returns an Optimizer object.
            This overrides any optimizer definition in the DeepSpeed json config.

        model_parameters: Optional: An iterable of torch.Tensors or dicts.
            Specifies what Tensors should be optimized.

        training_data: Optional: Dataset of type torch.utils.data.Dataset

        lr_scheduler: Optional: Learning Rate Scheduler Object or a Callable that takes an Optimizer and returns a Scheduler object.
            The scheduler object should define a get_lr(), step(), state_dict(), and load_state_dict() methods

        distributed_port: Optional: Master node (rank 0)'s free port that needs to be used for communication during distributed training

        mpu: Optional: A model parallelism unit object that implements
            get_{model,data}_parallel_{rank,group,world_size}()

        dist_init_required: Optional: None will auto-initialize torch distributed if needed,
            otherwise the user can force it to be initialized or not via boolean.

        collate_fn: Optional: Merges a list of samples to form a
            mini-batch of Tensor(s). Used when using batched loading from a
            map-style dataset.

        config: Optional: Instead of requiring args.deepspeed_config you can pass your deepspeed config
            as an argument instead, as a path or a dictionary.

        config_params: Optional: Same as `config`, kept for backwards compatibility.

    Returns:
        A tuple of ``engine``, ``optimizer``, ``training_dataloader``, ``lr_scheduler``

        * ``engine``: DeepSpeed runtime engine which wraps the client model for distributed training.

        * ``optimizer``: Wrapped optimizer if a user defined ``optimizer`` is supplied, or if
          optimizer is specified in json config else ``None``.

        * ``training_dataloader``: DeepSpeed dataloader if ``training_data`` was supplied,
Calls 14

log_dist · Function · 0.85
get_accelerator · Function · 0.85
load_ds_config · Function · 0.85
set_autotp_mode · Function · 0.85
DeepSpeedConfig · Class · 0.85
set_optimizer_flags · Function · 0.85
DeepSpeedEngine · Class · 0.85
PipelineEngine · Class · 0.85
warning · Method · 0.80
mpu · Method · 0.80
