(self)
| 249 | |
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
    """Build the FSDP modules and wrappers for each role this worker holds.

    Which attributes get set on ``self`` depends on the role flags:

    * ``_is_actor`` or ``_is_rollout``: ``actor_module_fsdp``,
      ``actor_optimizer``, ``actor_lr_scheduler``, ``actor_model_config``
      and ``actor_module`` (the module wrapped by FSDP).
    * ``_is_actor``: ``actor`` (a ``DataParallelPPOActor``).
    * ``_is_rollout``: ``rollout`` and ``sharding_manager``.
    * ``_is_ref``: ``ref_module_fsdp`` and ``ref_policy``.

    Returns nothing. The last step is ``torch.cuda.synchronize()``.
    """
    # NOTE(review): block nesting was reconstructed from a listing that had
    # lost its indentation -- confirm against the upstream file.
    from verl.workers.actor import DataParallelPPOActor

    model_cfg = self.config.model

    # Make any user-supplied external library importable for huggingface.
    import_external_libs(model_cfg.get('external_lib', None))

    from omegaconf import OmegaConf
    override_model_config = OmegaConf.to_container(
        model_cfg.get('override_config', OmegaConf.create()))

    if self._is_actor or self._is_rollout:
        # Both roles need the policy model; only the actor gets an
        # optimizer config and its own FSDP config.
        if self._is_actor:
            optim_config, fsdp_config = self.config.actor.optim, self.config.actor.fsdp_config
        else:
            optim_config, fsdp_config = None, OmegaConf.create()

        built = self._build_model_optimizer(
            model_path=model_cfg.path,
            fsdp_config=fsdp_config,
            optim_config=optim_config,
            override_model_config=override_model_config,
            enable_gradient_checkpointing=model_cfg.get('enable_gradient_checkpointing', False),
            trust_remote_code=model_cfg.get('trust_remote_code', False),
        )
        (self.actor_module_fsdp,
         self.actor_optimizer,
         self.actor_lr_scheduler,
         self.actor_model_config) = built

        # Keep a handle on the module underneath the FSDP wrapper.
        self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module

        if self._is_offload_param:
            # Only grads are offloaded here: the params are required during
            # state_dict in the sharding manager.
            offload_fsdp_grad(module=self.actor_module_fsdp)
            log_gpu_memory_usage('After offload actor grad during init', logger=logger)

        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.actor_optimizer)
            log_gpu_memory_usage('After offload actor optimizer during init', logger=logger)

    if self._is_actor:
        # Struct mode: unknown keys on the actor config now raise.
        OmegaConf.set_struct(self.config.actor, True)
        self.actor = DataParallelPPOActor(
            config=self.config.actor,
            actor_module=self.actor_module_fsdp,
            actor_optimizer=self.actor_optimizer,
        )

    if self._is_rollout:
        rollout_pair = self._build_rollout()
        self.rollout, self.sharding_manager = rollout_pair

    if self._is_ref:
        # The reference policy is built without an optimizer config; only
        # the first element of the builder's result (the module) is kept.
        ref_built = self._build_model_optimizer(
            model_path=model_cfg.path,
            fsdp_config=self.config.ref.fsdp_config,
            optim_config=None,
            override_model_config=override_model_config,
            trust_remote_code=model_cfg.get('trust_remote_code', False),
        )
        self.ref_module_fsdp = ref_built[0]

        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.ref_module_fsdp,
                                        offload_grad=self._is_offload_grad)

        OmegaConf.set_struct(self.config.ref, True)
        self.ref_policy = DataParallelPPOActor(config=self.config.ref,
                                               actor_module=self.ref_module_fsdp)

    torch.cuda.synchronize()
nothing calls this directly
no test coverage detected